Source code for DeformationModules.Linear
import torch
from imodal.StructuredFields import StructuredField_Affine
from imodal.Manifolds import Landmarks
from imodal.DeformationModules.Abstract import DeformationModule
class LinearDeformation(DeformationModule):
    """Deformation module whose field is a fixed matrix scaled by one scalar control.

    The module stores a matrix ``A`` and a 0-dim control tensor ``c``.
    :meth:`field_generator` builds a ``StructuredField_Affine`` from
    ``c * A`` and the flattened geometric descriptor of ``manifold``.
    ``coeff`` divides the control computed in :meth:`compute_geodesic_control`.
    """

    def __init__(self, manifold, A, coeff=1., label=None):
        """
        Parameters
        ----------
        manifold :
            Manifold whose geometric descriptor (``manifold.gd``) the field is
            built around. :meth:`build` creates a ``Landmarks`` manifold.
        A : torch.Tensor
            Fixed matrix of the linear field. Its dtype and device are used
            for the control tensor.
        coeff :
            Scalar coefficient (default ``1.``) dividing the geodesic control.
        label :
            Optional label forwarded to ``DeformationModule``.
        """
        super().__init__(label)
        # 0-dim control tensor. It matches A's dtype and device so that
        # ``controls * A`` in field_generator needs no implicit transfer.
        self.__controls = torch.tensor(0., dtype=A.dtype, device=A.device)
        self.__coeff = coeff
        self.__manifold = manifold
        self.__A = A

    def __str__(self):
        """Return a human-readable summary: label, coefficient and matrix A."""
        outstr = "Linear deformation module\n"
        if self.label:
            outstr += " Label=" + self.label + "\n"
        # coeff is typically a float. It must be converted explicitly, because
        # str + float raises TypeError.
        outstr += " Coeff=" + str(self.__coeff) + "\n"
        outstr += " A=\n"
        outstr += str(self.__A.detach().cpu().tolist())
        return outstr

    @classmethod
    def build(cls, A, coeff=1., gd=None, tan=None, cotan=None, label=None):
        """Create the module on a freshly built ``Landmarks`` manifold.

        The manifold is ``Landmarks(A.shape[0], 1, gd=gd, tan=tan, cotan=cotan)``.
        This assumes Landmarks' first two positional arguments are (dimension,
        number of points), i.e. one landmark in dimension ``A.shape[0]``.
        TODO confirm against the ``Landmarks`` signature.
        """
        return cls(Landmarks(A.shape[0], 1, gd=gd, tan=tan, cotan=cotan), A, coeff, label)

    def to_(self, *args, **kwargs):
        """Move/cast the manifold, ``A`` and the controls in place.

        All arguments are forwarded unchanged to ``manifold.to_`` and
        ``torch.Tensor.to``.
        """
        self.__manifold.to_(*args, **kwargs)
        self.__A = self.__A.to(*args, **kwargs)
        self.__controls = self.__controls.to(*args, **kwargs)

    @property
    def A(self):
        """The fixed matrix of the linear field."""
        return self.__A

    @property
    def manifold(self):
        """The manifold the generated field is built around."""
        return self.__manifold

    def __get_controls(self):
        return self.__controls

    def fill_controls(self, controls):
        """Set the scalar control. ``controls`` must be a 0-dim tensor."""
        assert controls.shape == torch.Size([])
        self.__controls = controls

    def __get_coeff(self):
        return self.__coeff

    def __set_coeff(self, coeff):
        self.__coeff = coeff

    # Read/write properties. Assigning ``controls`` goes through fill_controls,
    # so the 0-dim shape check applies to attribute assignment too.
    controls = property(__get_controls, fill_controls)
    coeff = property(__get_coeff, __set_coeff)

    def compute_geodesic_control(self, man):
        """Compute and store the geodesic control.

        Builds the unscaled field (matrix ``A``, not ``controls * A``), takes
        ``man.inner_prod_field`` against it, and divides the result by
        ``coeff``. The result replaces the current controls.
        """
        gd = self.__manifold.gd.flatten()
        vs = StructuredField_Affine(self.__A, gd, torch.zeros_like(gd))
        self.__controls = man.inner_prod_field(vs)/self.__coeff

    def field_generator(self):
        """Return the vector field generated by the current control.

        The arguments passed to ``StructuredField_Affine`` are presumably
        (matrix, centre point, translation), with a zero translation here.
        TODO confirm against StructuredField_Affine's signature.
        """
        gd = self.__manifold.gd.flatten()
        return StructuredField_Affine(self.__controls*self.__A, gd, torch.zeros_like(gd))