import torch
from imodal.StructuredFields import ConstantField
from imodal.Manifolds import EmptyManifold
from imodal.DeformationModules.Abstract import DeformationModule
class GlobalTranslation(DeformationModule):
    """Global translation deformation module.

    Generates a spatially constant vector field whose value is the current
    control tensor (see :meth:`field_generator`). The attached manifold is an
    :class:`EmptyManifold`, so the module carries no geometric data of its own.

    Parameters
    ----------
    dim : int
        Ambient dimension. The controls start as ``torch.zeros(dim)``.
    coeff : float, default 1.
        Weight of the quadratic cost ``0.5 * coeff * <controls, controls>``.
    label : optional
        Forwarded to ``DeformationModule.__init__``.
    """

    def __init__(self, dim, coeff=1., label=None):
        super().__init__(label)
        self.__controls = torch.zeros(dim)
        self.__coeff = coeff
        self.__manifold = EmptyManifold(dim)

    def __str__(self):
        outstr = "Global translation\n"
        if self.label:
            # str() so that a non-string label does not raise TypeError.
            outstr += " Label=" + str(self.label) + "\n"
        outstr += " Coeff=" + str(self.__coeff)
        return outstr

    @classmethod
    def build(cls, dim, coeff=1., label=None):
        """Alternate constructor, equivalent to ``cls(dim, coeff, label)``."""
        return cls(dim, coeff, label)

    def to_(self, *args, **kwargs):
        """Move the module in place.

        The arguments are forwarded unchanged to the manifold's ``to_`` and to
        :meth:`torch.Tensor.to` for the controls.
        """
        self.__manifold.to_(*args, **kwargs)
        self.__controls = self.__controls.to(*args, **kwargs)

    @property
    def manifold(self):
        """The empty manifold attached to this module."""
        return self.__manifold

    @property
    def device(self):
        """Device reported by the attached manifold."""
        return self.__manifold.device

    def __get_controls(self):
        return self.__controls

    def fill_controls(self, controls):
        """Store a clone of ``controls`` as the module's controls.

        Cloning means later in-place edits of the caller's tensor do not
        affect the module.
        """
        self.__controls = controls.clone()

    def __get_coeff(self):
        return self.__coeff

    def __set_coeff(self, coeff):
        self.__coeff = coeff

    # Read/write properties. Assigning ``controls`` goes through
    # fill_controls (i.e. stores a clone). ``coeff`` is defined only here, so
    # it is both readable and writable.
    controls = property(__get_controls, fill_controls)
    coeff = property(__get_coeff, __set_coeff)

    def fill_controls_zero(self):
        """Reset the controls to zeros with the same shape, dtype and device."""
        self.fill_controls(torch.zeros_like(self.__controls))

    def __call__(self, points):
        """Evaluate the generated constant vector field at ``points``."""
        return self.field_generator()(points)

    def cost(self):
        """Return the control cost ``0.5 * coeff * <controls, controls>``.

        Uses :func:`torch.dot`, so the controls must be a 1-D tensor.
        """
        return 0.5 * self.__coeff * torch.dot(self.__controls, self.__controls)

    def compute_geodesic_control(self, man):
        """Compute the geodesic controls and store them as the new controls.

        For every index ``i`` of the control vector, the constant field
        ``ConstantField(e_i)`` is built, where ``e_i`` is the i-th unit
        vector. The field is paired with ``man`` through
        ``man.inner_prod_field``, and the result divided by ``coeff`` becomes
        the i-th control.

        Parameters
        ----------
        man :
            Object exposing ``inner_prod_field(field)``. Presumably a manifold
            carrying momenta, with a scalar-valued pairing; TODO confirm
            against callers.
        """
        geodesic_controls = torch.zeros_like(self.__controls)
        for i in range(self.__controls.shape[0]):
            # Unit vector e_i, with the dtype/device of the current controls.
            cont_i = torch.zeros_like(self.__controls)
            cont_i[i] = 1.
            v_i = ConstantField(cont_i)
            geodesic_controls[i] = man.inner_prod_field(v_i) / self.__coeff
        self.__controls = geodesic_controls

    def field_generator(self):
        """Return the constant vector field equal to the current controls."""
        return ConstantField(self.__controls)