import torch
import math
from pykeops.torch import Genred, KernelSolve
from imodal.DeformationModules.Abstract import DeformationModule, create_deformation_module_with_backends
from imodal.Kernels.kernels import K_xy, K_xx
from imodal.Manifolds import Landmarks
from imodal.StructuredFields import StructuredField_0
class LocalConstrainedTranslationsBase(DeformationModule):
    """Module generating sum of constrained translations.

    The generated field is a ``StructuredField_0`` whose support points and
    vectors are both computed from the geometrical descriptor of a
    ``Landmarks`` manifold through the callables ``f_support`` and
    ``f_vectors``. A single control multiplies all the vectors at once.

    Concrete subclasses are expected to provide a ``backend`` property (used by
    ``field_generator`` and ``adjoint``) and to implement ``cost`` and
    ``compute_geodesic_control``.
    """

    def __init__(self, manifold, sigma, descstr, f_support, f_vectors, coeff, label):
        """Store the manifold and the parameters of the module.

        Parameters
        ----------
        manifold : Landmarks
            Manifold holding the geometrical descriptor of the module.
        sigma : float
            Kernel size handed to the generated structured field.
        descstr : str
            Description string, printed by ``__str__()``.
        f_support : callable
            Maps the geometrical descriptor to the support points of the field.
        f_vectors : callable
            Maps the geometrical descriptor to the vectors of the field for a
            control equal to one.
        coeff : float
            Coefficient of the deformation module.
        label :
            Optional identifier, forwarded to the parent constructor.
        """
        assert isinstance(manifold, Landmarks)
        super().__init__(label)
        self.__manifold = manifold
        self.__sigma = sigma
        self.__descstr = descstr
        # The control starts as a zero-dimensional tensor equal to zero.
        self.__controls = torch.zeros(1).view([])
        self.__coeff = coeff
        self._f_support = f_support
        self._f_vectors = f_vectors

    def __str__(self):
        """Return a multi-line, human readable summary of the module."""
        outstr = "Local constrained translation module\n"
        if self.label:
            outstr += " Label=" + self.label + "\n"
        outstr += " Type=" + self.descstr + "\n"
        outstr += " Sigma=" + str(self.__sigma) + "\n"
        outstr += " Coeff=" + str(self.__coeff)
        return outstr

    @classmethod
    def build(cls, dim, nb_pts, sigma, descstr, f_support, f_vectors, coeff=1., gd=None, tan=None, cotan=None, label=None):
        """Alternate constructor creating the ``Landmarks`` manifold itself.

        ``dim``, ``nb_pts``, ``gd``, ``tan`` and ``cotan`` are forwarded to
        ``Landmarks``; the remaining arguments are forwarded to ``__init__``.
        """
        return cls(Landmarks(dim, nb_pts, gd=gd, tan=tan, cotan=cotan), sigma, descstr, f_support, f_vectors, coeff, label)

    def to_(self, *args, **kwargs):
        """Move the manifold and the control, forwarding all arguments to them."""
        self.__manifold.to_(*args, **kwargs)
        self.__controls = self.__controls.to(*args, **kwargs)

    @property
    def descstr(self):
        """Description string. Used by __str__()."""
        return self.__descstr

    @property
    def coeff(self):
        """Coefficient of the deformation module."""
        return self.__coeff

    @property
    def device(self):
        """Device of the underlying manifold."""
        return self.__manifold.device

    @property
    def manifold(self):
        """Manifold holding the geometrical descriptor of the module."""
        return self.__manifold

    @property
    def dim(self):
        """Dimension of the underlying manifold."""
        return self.__manifold.dim

    @property
    def sigma(self):
        """Kernel size handed to the generated structured field."""
        return self.__sigma

    def __get_controls(self):
        return self.__controls

    def fill_controls(self, controls):
        """Replace the current control by ``controls``."""
        self.__controls = controls

    controls = property(__get_controls, fill_controls)

    def fill_controls_zero(self):
        """Reset the control to a zero tensor of shape (1,) requiring gradient.

        The tensor is created on the device of the manifold so that it can be
        combined with the vectors computed from the geometrical descriptor
        after the module has been moved with ``to_()``.
        """
        self.__controls = torch.zeros(1, requires_grad=True, device=self.device)

    def __call__(self, points, k=0):
        """Evaluate the generated field on ``points``.

        ``k`` is forwarded unchanged to the structured field.
        """
        return self.field_generator()(points, k)

    def cost(self):
        """Cost of the module; must be implemented by subclasses."""
        raise NotImplementedError()

    def compute_geodesic_control(self, man):
        """Compute the geodesic control; must be implemented by subclasses."""
        raise NotImplementedError()

    def field_generator(self):
        """Build the structured field generated by the current control.

        Support points and vectors are recomputed from the current geometrical
        descriptor; the vectors are scaled by the control.
        """
        support = self._f_support(self.__manifold.gd)
        vectors = self._f_vectors(self.__manifold.gd)
        return StructuredField_0(support, self.__controls*vectors, self.__sigma, device=self.device, backend=self.backend)

    def adjoint(self, manifold):
        """Return ``manifold.cot_to_vs`` evaluated with this module's kernel size and backend."""
        return manifold.cot_to_vs(self.__sigma, backend=self.backend)
class LocalConstrainedTranslations_Torch(LocalConstrainedTranslationsBase):
    """Local constrained translations module computed with plain PyTorch."""

    def __init__(self, manifold, sigma, descstr, f_support, f_vectors, coeff, label):
        super().__init__(manifold, sigma, descstr, f_support, f_vectors, coeff, label)

    @property
    def backend(self):
        """Name of the computation backend."""
        return 'torch'

    def _unit_control_terms(self):
        """Return ``(support, vectors, energy)`` for a control equal to one.

        ``energy`` is the dot product between the flattened ``K_xx(support,
        sigma) @ vectors`` and the flattened ``vectors``.
        """
        pts = self._f_support(self.manifold.gd)
        directions = self._f_vectors(self.manifold.gd)
        gram = K_xx(pts, self.sigma)
        momentum = torch.mm(gram, directions)
        energy = torch.dot(momentum.flatten(), directions.flatten())
        return pts, directions, energy

    def cost(self):
        """Return half the coefficient times the unit energy times the squared control."""
        _, _, energy = self._unit_control_terms()
        scaled = 0.5 * self.coeff * energy.view([])
        return scaled * self.controls * self.controls

    def compute_geodesic_control(self, man):
        """Set the control from the inner product of ``man`` with the unit field.

        The inner product returned by ``man.inner_prod_field`` is divided by
        the coefficient times the unit energy.
        """
        pts, directions, energy = self._unit_control_terms()
        # Vector field generated by a control equal to one.
        unit_field = StructuredField_0(pts, directions, self.sigma, device=self.device, backend='torch')
        denominator = self.coeff * energy
        self.controls = man.inner_prod_field(unit_field)/denominator
class LocalConstrainedTranslations_KeOps(LocalConstrainedTranslationsBase):
    """KeOps flavour of the module; cost and geodesic control are not implemented."""

    def __init__(self, manifold, sigma, descstr, f_support, f_vectors, coeff, label):
        super().__init__(
            manifold, sigma, descstr, f_support, f_vectors, coeff, label
        )

    @property
    def backend(self):
        """Name of the computation backend."""
        return 'keops'

    def cost(self):
        """Not available for this backend."""
        raise NotImplementedError()

    def compute_geodesic_control(self, man):
        """Not available for this backend."""
        raise NotImplementedError()
# Public factory for the module. Both builder arguments are the PyTorch
# implementation's ``build``; LocalConstrainedTranslations_KeOps is not
# registered here (its cost and geodesic control raise NotImplementedError).
LocalConstrainedTranslations = create_deformation_module_with_backends(
    LocalConstrainedTranslations_Torch.build,
    LocalConstrainedTranslations_Torch.build,
)
def LocalScaling(dim, sigma, coeff=1., gd=None, tan=None, cotan=None, label=None, backend=None):
    """ Generates a local scaling deformation module.

    Local scaling is approximated by a local constrained translation
    deformation module with 3 vectors placed around the scaling center, at a
    distance sigma/3, each pointing away from the center.

    Parameters
    ----------
    dim : int
        Dimension of the ambient space the deformation module will live on.
        The vectors built here have two components.
    sigma : float
        Kernel size of the underlying vector space
    coeff : float
        Coefficient of the deformation module
    gd : torch.Tensor
        Geometrical descriptor of the deformation module i.e. the scale center
    tan : torch.Tensor
        Tangent tensor
    cotan : torch.Tensor
        Cotangent tensor
    label :
        Optional identifier
    backend : str
        Computation backend, forwarded to the module factory
    """
    def f_vectors(gd):
        # Unit vectors at angles 0, 2*pi/3 and 4*pi/3.
        rows = []
        for i in range(3):
            angle = 2.*math.pi/3.*i
            rows.append([math.cos(angle), math.sin(angle)])
        return torch.tensor(rows, device=gd.device, dtype=gd.dtype)

    def f_support(gd):
        # Each support point is the center shifted by sigma/3 along its vector.
        offsets = sigma/3. * f_vectors(gd)
        return gd.repeat(3, 1) + offsets

    return LocalConstrainedTranslations(
        dim, 1, sigma, "Local scaling", f_support, f_vectors,
        coeff=coeff, gd=gd, tan=tan, cotan=cotan, label=label, backend=backend)
def LocalRotation(dim, sigma, coeff=1., gd=None, tan=None, cotan=None, label=None, backend=None):
    """ Generates a local rotation deformation module.

    Local rotation is approximated by a local constrained translation
    deformation module.

    In the default (2D) construction, 3 support points are placed around the
    rotation center at a distance sigma/3, each carrying a unit vector
    perpendicular to the direction from the center to that point.

    When ``dim == 3``, the geometrical descriptor holds two points: ``gd[0]``
    is the center and ``gd[1] - gd[0]`` is used as the axis. Four support
    points are placed at ``gd[0] + sigma/3 * t`` for the tetrahedron vertices
    ``t``, and the vector at each of them is the cross product of ``t`` with
    the axis.

    Any ``dim`` other than 3 uses the 2D construction.

    Parameters
    ----------
    dim : int
        Dimension of the ambient space the deformation module will live on.
    sigma : float
        Kernel size of the underlying vector space
    coeff : float
        Coefficient of the deformation module
    gd : torch.Tensor
        Geometrical descriptor of the deformation module i.e. the rotation
        center (2D) or the rotation center followed by a second point
        defining the axis (3D)
    tan : torch.Tensor
        Tangent tensor
    cotan : torch.Tensor
        Cotangent tensor
    label :
        Optional identifier
    backend : str
        Computation backend, forwarded to the module factory
    """
    def tetrahedron(gd):
        # Four vertices of a tetrahedron, matching the device and dtype of gd.
        return torch.tensor([[1., 1., 1.], [1., -1., -1.], [-1., 1., -1.], [-1., -1., 1.]], device=gd.device, dtype=gd.dtype)

    def f_vectors_2d(gd):
        return torch.tensor([[-math.sin(2.*math.pi/3.*i), math.cos(2.*math.pi/3.*i)] for i in range(3)], device=gd.device, dtype=gd.dtype)

    def f_support_2d(gd):
        return gd.repeat(3, 1) + sigma/3. * torch.tensor([[math.cos(2.*math.pi/3.*i), math.sin(2.*math.pi/3.*i)] for i in range(3)], device=gd.device, dtype=gd.dtype)

    def f_vectors_3d(gd):
        tetra = tetrahedron(gd)
        vec = gd[1] - gd[0]
        # Both operands are (4, 3): the cross product is taken row by row.
        # dim is given explicitly instead of relying on the deprecated
        # "first dimension of size 3" default of torch.cross.
        return torch.cross(tetra, vec.repeat(4, 1), dim=1)

    def f_support_3d(gd):
        return gd[0].repeat(4, 1) + sigma/3. * tetrahedron(gd)

    f_vectors = f_vectors_2d
    f_support = f_support_2d
    pts_count = 1

    if dim == 3:
        f_vectors = f_vectors_3d
        f_support = f_support_3d
        pts_count = 2

    return LocalConstrainedTranslations(dim, pts_count, sigma, "Local rotation", f_support, f_vectors, coeff=coeff, gd=gd, tan=tan, cotan=cotan, label=label, backend=backend)