from collections import OrderedDict
from collections.abc import Iterable

from imodal.DeformationModules import CompoundModule
from imodal.Manifolds import CompoundManifold
from imodal.Models import BaseModel, deformables_compute_deformed
class RegistrationModel(BaseModel):
    """Model registering one or several deformables onto targets.

    The model owns an initial compound manifold (the silent manifolds of the
    deformables followed by the manifolds of the deformation modules), the
    attachment terms used to compare deformed sources with targets, and the
    ordered dictionary of parameter groups handed to the optimizer.

    Parameters
    ----------
    deformables : deformable or iterable of deformables
        Objects to deform. Each one must expose ``silent_module`` and
        ``to_device``. A single object is wrapped in a list.
    deformation_modules : module or iterable of modules
        Deformation modules generating the deformation. A single module is
        wrapped in a list.
    attachments : attachment or iterable of attachments
        One callable per deformable, invoked as
        ``attachment(deformed_source, target.geometry)``.
    fit_gd : iterable, optional
        One entry per deformation module. An entry equal to ``True`` marks
        every geometrical descriptor of that module as optimized; an iterable
        entry selects descriptors one by one. Falsy (default) disables
        fitting of geometrical descriptors.
    lam : float, optional
        Weight applied to the attachment cost. Default is ``1.``.
    precompute_callback : callable, optional
        Called at the start of each evaluation as
        ``callback(init_manifold, modules, parameters, deformables)``. A
        return value other than ``None`` is stored as the ``'precompute'``
        cost.
    other_parameters : optional
        Extra parameter groups merged into the parameter dictionary with
        ``dict.update``. Defaults to an empty list.
    """

    def __init__(self, deformables, deformation_modules, attachments, fit_gd=None, lam=1., precompute_callback=None, other_parameters=None):
        # Accept single objects as well as collections.
        if not isinstance(deformables, Iterable):
            deformables = [deformables]

        if not isinstance(deformation_modules, Iterable):
            deformation_modules = [deformation_modules]

        if not isinstance(attachments, Iterable):
            attachments = [attachments]

        assert len(deformables) == len(attachments), \
            "Exactly one attachment is required per deformable."

        self.__deformables = deformables
        self.__attachments = attachments
        self.__deformation_modules = deformation_modules
        self.__precompute_callback = precompute_callback
        self.__fit_gd = fit_gd
        self.__lam = lam

        if other_parameters is None:
            other_parameters = []

        # The initial manifold stores the silent manifolds of the deformables
        # first, then the manifolds of the deformation modules. This ordering
        # is relied upon by _compute_parameters().
        deformable_manifolds = [deformable.silent_module.manifold.clone(False) for deformable in self.__deformables]
        deformation_modules_manifolds = CompoundModule(deformation_modules).manifold.clone(False)

        self.__init_manifold = CompoundManifold([*deformable_manifolds, *deformation_modules_manifolds])

        # Initial momenta are always optimized.
        for manifold in self.__init_manifold:
            manifold.cotan_requires_grad_()

        self.__init_other_parameters = other_parameters

        # Silent modules of the deformables first, then the deformation modules.
        self.__modules = [deformable.silent_module for deformable in self.__deformables]
        self.__modules.extend(deformation_modules)

        # Update the parameter dict
        self._compute_parameters()

        super().__init__()

    @property
    def modules(self):
        """Silent modules of the deformables followed by the deformation modules."""
        return self.__modules

    @property
    def deformation_modules(self):
        """Deformation modules given at construction."""
        return self.__deformation_modules

    @property
    def attachments(self):
        """Attachment terms, one per deformable."""
        return self.__attachments

    @property
    def precompute_callback(self):
        """Callback invoked at the start of each evaluation, or ``None``."""
        return self.__precompute_callback

    @property
    def fit_gd(self):
        """Specification of the geometrical descriptors to optimize."""
        return self.__fit_gd

    def __get_init_manifold(self):
        return self.__init_manifold

    def fill_init_manifold(self, init_manifold):
        """Replace the initial manifold and rebuild the parameter dictionary."""
        self.__init_manifold = init_manifold
        self._compute_parameters()

    init_manifold = property(__get_init_manifold, fill_init_manifold)

    @property
    def init_other_parameters(self):
        """Extra parameter groups given at construction."""
        return self.__init_other_parameters

    @property
    def parameters(self):
        """Ordered dictionary of parameter groups given to the optimizer."""
        return self.__parameters

    @property
    def lam(self):
        """Weight applied to the attachment cost."""
        return self.__lam

    @property
    def deformables(self):
        """Deformables registered by this model."""
        return self.__deformables

    def to_device(self, device):
        """Move the model to ``device`` and rebuild the parameter dictionary.

        The deformation modules, the deformables and every manifold of the
        initial manifold are moved, in that order.

        Parameters
        ----------
        device
            Target device, forwarded to ``to_(device=...)`` and ``to_device``.
        """
        # NOTE: the class previously defined to_device twice; only the later
        # definition was ever bound, so its behaviour is the one kept here.
        for deformation_module in self.__deformation_modules:
            deformation_module.to_(device=device)

        for deformable in self.__deformables:
            deformable.to_device(device)

        for manifold in self.__init_manifold:
            manifold.to_(device=device)

        # The parameter groups must reference the moved tensors.
        self._compute_parameters()

    def __str__(self):
        """Return a human readable summary of the model and its modules."""
        parts = [
            "Registration model\n",
            "=================\n",
            "Attachment={}\n".format(self.__attachments),
            "Lambda={}\n".format(self.__lam),
            "Fit geometrical descriptors={}\n".format(self.__fit_gd),
            "Precompute callback={}\n".format(self.__precompute_callback is not None),
            "Other parameters={}\n".format(len(self.__init_other_parameters) != 0),
            "\n",
            "Modules\n",
            "=======\n",
        ]

        # NOTE(review): only the first module is skipped, although there is
        # one silent module per deformable at the head of the list - confirm
        # whether all silent modules were meant to be skipped.
        for module in self.modules[1:]:
            parts.append("\n" + str(module))

        return "".join(parts)

    def _compute_parameters(self):
        """Rebuild the ordered dictionary of parameter groups for the optimizer."""
        # Fill the parameter dictionary that will be given to the optimizer.

        # For Python versions before 3.6, the order of a dictionary is not guaranteed.
        # For Python 3.6, order is guaranteed by the CPython implementation but not standardised in the language.
        # For Python versions beyond 3.6, order is guaranteed by the language specification.
        # Since the order of the parameter list is important, and to ensure it is preserved with any Python version, we use an OrderedDict.
        self.__parameters = OrderedDict()

        # Initial moments
        self.__parameters['cotan'] = {'params': self.__init_manifold.unroll_cotan()}

        # Geometrical descriptors if specified
        # NOTE(review): fit_gd is zipped below, so a bare ``True`` would raise
        # a TypeError; an iterable with one entry per deformation module is
        # expected.
        if self.__fit_gd:
            self.__parameters['gd'] = {'params': []}

            # Skip the silent manifolds of the deformables: only the manifolds
            # of the deformation modules are candidates.
            module_manifolds = self.__init_manifold[len(self.__deformables):]
            for fit_flag, module_manifold in zip(self.__fit_gd, module_manifolds):
                if isinstance(fit_flag, bool) and fit_flag:
                    module_manifold.gd_requires_grad_()
                    self.__parameters['gd']['params'].extend(module_manifold.unroll_gd())
                # Geometrical descriptor is multidimensional
                elif isinstance(fit_flag, Iterable):
                    for fit_gdi, manifold_gdi in zip(fit_flag, module_manifold.unroll_gd()):
                        if fit_gdi:
                            self.__parameters['gd']['params'].append(manifold_gdi)

        # Other parameters
        self.__parameters.update(self.__init_other_parameters)

    def evaluate(self, target, solver, it, costs=None, backpropagation=True):
        """ Evaluate the model and output its cost.

        Parameters
        ----------
        target : object or iterable of objects
            Targets we try to approach, one per deformable. Each target must
            expose a ``geometry`` attribute, which is what the attachment
            terms receive. A single target is wrapped in a list.
        solver : str
            Solver to use for the shooting.
        it : int
            Number of iterations for the integration.
        costs : dict, optional
            Dictionary the costs are written into. A new one is created when
            omitted.
        backpropagation : bool, optional
            If ``True`` (default) and the total cost requires gradients,
            ``backward()`` is called on the total cost.

        Returns
        -------
        dict
            Dictionary of costs keyed by name. Values are floats when the
            backward pass has been performed; otherwise the cost objects are
            returned unchanged.
        """
        if costs is None:
            costs = {}

        if not isinstance(target, Iterable):
            target = [target]

        assert len(target) == len(self.__deformables), \
            "Exactly one target is required per deformable."

        # Call precompute callback if available
        if self.precompute_callback is not None:
            precompute_cost = self.precompute_callback(self.init_manifold, self.modules, self.parameters, self.deformables)

            if precompute_cost is not None:
                costs['precompute'] = precompute_cost

        # compute_deformed is not defined in this part of the class; it is
        # expected to be provided further down or by BaseModel.
        deformed_sources = self.compute_deformed(solver, it, costs=costs)
        costs['attach'] = self.__lam * self._compute_attachment_cost(deformed_sources, target)

        total_cost = sum(costs.values())

        if total_cost.requires_grad and backpropagation:
            # Compute backward and return costs as a dictionary of floats
            total_cost.backward()
            return {key: cost.item() for key, cost in costs.items()}

        return costs

    def _compute_attachment_cost(self, deformed_sources, targets, deformation_costs=None):
        """Sum the attachment terms between deformed sources and targets.

        ``deformation_costs`` is accepted for interface compatibility and is
        not used.
        """
        return sum(
            attachment(deformed_source, target.geometry)
            for attachment, deformed_source, target in zip(self.__attachments, deformed_sources, targets)
        )