# Source code for imodal.DeformationModules.OrientedTranslation

import torch
import numpy as np

from pykeops.torch import Genred, KernelSolve

from imodal.DeformationModules.Abstract import DeformationModule, create_deformation_module_with_backends
from imodal.Kernels.kernels import K_xy, K_xx
from imodal.Manifolds import LandmarksDirection
from imodal.StructuredFields import StructuredField_0


class OrientedTranslationsBase(DeformationModule):
    """Module generating sum of oriented translations.

    The module lives on a :class:`LandmarksDirection` manifold whose geometric
    descriptor ``gd`` holds landmark positions ``gd[0]`` and directions
    ``gd[1]`` (``gd[1]`` has one row per landmark; presumably both are
    ``(nb_pts, dim)`` — TODO confirm). It is driven by one scalar control per
    landmark. The generated field is a :class:`StructuredField_0` built from
    the positions ``gd[0]`` and the vectors ``controls[i] * gd[1][i]`` with
    kernel scale ``sigma``.

    Backend subclasses must provide ``backend``, ``cost`` and
    ``compute_geodesic_control``.
    """

    def __init__(self, manifold, sigma, coeff, label):
        """
        Parameters
        ----------
        manifold : LandmarksDirection
            Manifold carrying the landmark positions and directions.
        sigma
            Kernel scale, forwarded to the structured field and to ``adjoint``.
        coeff
            Weight exposed through ``coeff`` for use by backend ``cost`` /
            ``compute_geodesic_control`` implementations.
        label
            Module label forwarded to :class:`DeformationModule` (may be None).
        """
        assert isinstance(manifold, LandmarksDirection), \
            "OrientedTranslations requires a LandmarksDirection manifold"
        super().__init__(label)
        self.__manifold = manifold
        self.__sigma = sigma
        self.__coeff = coeff
        # One scalar control per landmark, on the manifold's device and dtype.
        self.__controls = torch.zeros(self.__manifold.nb_pts, device=manifold.device, dtype=manifold.dtype)

    def __str__(self):
        outstr = "Oriented translation\n"
        if self.label:
            outstr += " Label=" + self.label + "\n"
        outstr += " Sigma=" + str(self.__sigma) + "\n"
        outstr += " Coeff=" + str(self.__coeff) + "\n"
        outstr += " Nb pts=" + str(self.__manifold.nb_pts)
        return outstr

    @classmethod
    def build(cls, dim, nb_pts, sigma, transport='vector', coeff=1., gd=None, tan=None, cotan=None, label=None):
        """Construct the module together with its :class:`LandmarksDirection` manifold.

        ``dim``, ``nb_pts``, ``transport``, ``gd``, ``tan`` and ``cotan`` are
        forwarded to the ``LandmarksDirection`` constructor. ``sigma``,
        ``coeff`` and ``label`` are forwarded to ``cls``.
        """
        return cls(LandmarksDirection(dim, nb_pts, transport, gd=gd, tan=tan, cotan=cotan), sigma, coeff, label)

    def to_(self, *args, **kwargs):
        """Move the manifold and the controls in place (arguments as for ``Tensor.to``)."""
        self.__manifold.to_(*args, **kwargs)
        self.__controls = self.__controls.to(*args, **kwargs)

    @property
    def device(self):
        return self.__manifold.device

    @property
    def manifold(self):
        return self.__manifold

    @property
    def dim(self):
        return self.__manifold.dim

    @property
    def sigma(self):
        return self.__sigma

    @property
    def coeff(self):
        return self.__coeff

    def __get_controls(self):
        return self.__controls

    def fill_controls(self, controls):
        """Replace the per-landmark controls (stored as given, no copy)."""
        self.__controls = controls

    controls = property(__get_controls, fill_controls)

    def fill_controls_zero(self):
        """Reset the controls to a zero vector of length ``nb_pts``."""
        self.__controls = torch.zeros(self.__manifold.nb_pts, device=self.__manifold.device, dtype=self.__manifold.dtype)

    def __call__(self, points, k=0):
        """Evaluate the generated field at ``points``; ``k`` is forwarded to the field."""
        return self.field_generator()(points, k)

    def cost(self):
        raise NotImplementedError

    def compute_geodesic_control(self, man):
        raise NotImplementedError

    def field_generator(self):
        """Return the :class:`StructuredField_0` generated by the current controls."""
        # Broadcasting (nb_pts, 1) against gd[1] scales each direction by its
        # control without materialising a repeated copy of the controls.
        oriented_vectors = self.__controls.unsqueeze(1) * self.__manifold.gd[1]
        return StructuredField_0(self.__manifold.gd[0], oriented_vectors, self.__sigma,
                                 device=self.device, backend=self.backend)

    def adjoint(self, manifold):
        """Return ``manifold.cot_to_vs`` evaluated with this module's ``sigma`` and backend."""
        return manifold.cot_to_vs(self.__sigma, backend=self.backend)
def _solve_linear_system(lhs, rhs):
    """Return ``X`` solving ``lhs @ X = rhs``.

    ``torch.solve`` took its arguments as ``(rhs, lhs)``. It has been deprecated
    since PyTorch 1.9 and is absent from recent releases. ``torch.linalg.solve``
    (PyTorch >= 1.8) takes ``(lhs, rhs)``. Prefer the latter and fall back only
    on releases that lack it.
    """
    linalg = getattr(torch, "linalg", None)
    if linalg is not None and hasattr(linalg, "solve"):
        return linalg.solve(lhs, rhs)
    return torch.solve(rhs, lhs)[0]


class OrientedTranslations_Torch(OrientedTranslationsBase):
    """PyTorch implementation of the oriented translations module."""

    def __init__(self, manifold, sigma, coeff, label):
        super().__init__(manifold, sigma, coeff, label)

    @property
    def backend(self):
        return 'torch'

    def cost(self):
        """Return ``0.5 * coeff * v^T K v``.

        ``K = K_xx(gd[0], sigma)`` and ``v_i = controls[i] * gd[1][i]``.
        """
        # (nb_pts, 1) broadcast against gd[1]; computed once and reused below.
        oriented_vectors = self.controls.unsqueeze(1) * self.manifold.gd[1]
        K_q = K_xx(self.manifold.gd[0], self.sigma)
        m = torch.mm(K_q, oriented_vectors)
        return 0.5 * self.coeff * torch.dot(m.flatten(), oriented_vectors.flatten())

    def compute_geodesic_control(self, man):
        """Set the controls solving ``coeff * Z h = b``.

        ``Z_ij = K(x_i, x_j) * <d_i, d_j>`` and ``b_i = <v(x_i), d_i>``, where
        ``v = self.adjoint(man)`` and ``x = gd[0]``, ``d = gd[1]``.
        """
        vs = self.adjoint(man)
        positions = self.manifold.gd[0]
        directions = self.manifold.gd[1]
        Z = K_xx(positions, self.sigma) * torch.mm(directions, directions.T)
        rhs = torch.einsum('ni, ni->n', vs(positions), directions).unsqueeze(1)
        controls = _solve_linear_system(Z, rhs)
        self.controls = controls.flatten().contiguous() / self.coeff


class OrientedTranslations_KeOps(OrientedTranslationsBase):
    """KeOps variant: cost and geodesic controls are not implemented yet."""

    def __init__(self, manifold, sigma, coeff, label):
        super().__init__(manifold, sigma, coeff, label)
        # Use manifold.dtype (as the base class does when allocating controls)
        # rather than manifold.gd.dtype: gd is indexed as gd[0]/gd[1] and may be
        # a tuple of tensors without a single .dtype — TODO confirm.
        self.__keops_dtype = str(manifold.dtype).split(".")[1]
        self.__keops_backend = 'CPU' if str(self.device) == 'cpu' else 'GPU'

    @property
    def backend(self):
        return 'keops'

    def cost(self):
        raise NotImplementedError()

    def compute_geodesic_control(self, man):
        raise NotImplementedError()


# Both builder arguments are the PyTorch one. The KeOps class above raises
# NotImplementedError from cost() and compute_geodesic_control(), so it is
# presumably left out on purpose. This assumes the second slot is the KeOps
# backend — confirm against create_deformation_module_with_backends.
OrientedTranslations = create_deformation_module_with_backends(OrientedTranslations_Torch.build, OrientedTranslations_Torch.build)