"""Registration model (module ``imodal.Models.registration_model``)."""

from collections import OrderedDict
from collections.abc import Iterable

from imodal.DeformationModules import CompoundModule
from imodal.Manifolds import CompoundManifold
from imodal.Models import BaseModel, deformables_compute_deformed

class RegistrationModel(BaseModel):
    """Model that deforms a set of deformables (sources) toward a set of targets.

    The initial manifold handled by the model is a ``CompoundManifold`` made of
    clones of each deformable's silent-module manifold, followed by clones of
    the deformation modules' manifolds. Its cotangents (and optionally some
    geometrical descriptors) are exposed through :attr:`parameters` for an
    optimizer.

    Parameters
    ----------
    deformables : deformable or iterable of deformables
        Sources to deform. A single non-iterable object is wrapped in a list.
    deformation_modules : module or iterable of modules
        Deformation modules driving the deformation. A single non-iterable
        object is wrapped in a list.
    attachments : attachment or iterable of attachments
        One attachment per deformable; each is called with a deformed source
        and the matching target's ``geometry``.
    fit_gd : iterable, optional
        One entry per deformation module. ``True`` fits all geometrical
        descriptors of that module's manifold; an iterable of truthy/falsy
        flags selects which of its unrolled descriptors are fitted.
        ``None`` (default) fits no descriptor.
    lam : float, default=1.
        Weight multiplying the attachment cost.
    precompute_callback : callable, optional
        Called at the start of :meth:`evaluate` with
        ``(init_manifold, modules, parameters, deformables)``. A non-``None``
        return value is stored in the costs under ``'precompute'``.
    other_parameters : mapping or iterable of (key, value) pairs, optional
        Extra entries merged (via ``OrderedDict.update``) into the parameter
        dictionary. Defaults to an empty list.
    """

    def __init__(self, deformables, deformation_modules, attachments, fit_gd=None, lam=1.,
                 precompute_callback=None, other_parameters=None):
        if not isinstance(deformables, Iterable):
            deformables = [deformables]

        if not isinstance(deformation_modules, Iterable):
            deformation_modules = [deformation_modules]

        if not isinstance(attachments, Iterable):
            attachments = [attachments]

        # Each deformable is paired with exactly one attachment term.
        assert len(deformables) == len(attachments)

        self.__deformables = deformables
        self.__attachments = attachments
        self.__deformation_modules = deformation_modules
        self.__precompute_callback = precompute_callback
        self.__fit_gd = fit_gd
        self.__lam = lam

        if other_parameters is None:
            other_parameters = []

        # Work on clones of the manifolds so the optimised initial state is kept
        # separate from the modules' own manifolds (clone(False): flag meaning
        # defined by the Manifolds package -- NOTE(review): confirm).
        deformable_manifolds = [deformable.silent_module.manifold.clone(False)
                                for deformable in self.__deformables]
        deformation_modules_manifolds = CompoundModule(deformation_modules).manifold.clone(False)

        # Layout: deformables' manifolds first, then deformation modules' manifolds.
        self.__init_manifold = CompoundManifold([*deformable_manifolds, *deformation_modules_manifolds])

        # Initial cotangents are always optimised (see _compute_parameters).
        for manifold in self.__init_manifold:
            manifold.cotan_requires_grad_()

        self.__init_other_parameters = other_parameters

        self.__modules = [deformable.silent_module for deformable in self.__deformables]
        self.__modules.extend(deformation_modules)

        # Update the parameter dict
        self._compute_parameters()

        super().__init__()

    @property
    def modules(self):
        """List of the deformables' silent modules followed by the deformation modules."""
        return self.__modules

    @property
    def deformation_modules(self):
        """Deformation modules driving the deformation."""
        return self.__deformation_modules

    @property
    def attachments(self):
        """Attachment terms, one per deformable."""
        return self.__attachments

    @property
    def precompute_callback(self):
        """Callback invoked at the start of :meth:`evaluate`, or ``None``."""
        return self.__precompute_callback

    @property
    def fit_gd(self):
        """Per-deformation-module geometrical descriptor fitting specification."""
        return self.__fit_gd

    def __get_init_manifold(self):
        return self.__init_manifold

    def fill_init_manifold(self, init_manifold):
        """Replace the initial manifold and rebuild the parameter dictionary.

        Parameters
        ----------
        init_manifold : CompoundManifold
            New initial manifold, with the same layout as the one built in
            ``__init__`` (deformables' manifolds first, then the deformation
            modules' manifolds).
        """
        self.__init_manifold = init_manifold
        self._compute_parameters()

    init_manifold = property(__get_init_manifold, fill_init_manifold)

    @property
    def init_other_parameters(self):
        """Extra parameters merged into :attr:`parameters`."""
        return self.__init_other_parameters

    @property
    def parameters(self):
        """Ordered dictionary of parameter groups handed to the optimizer."""
        return self.__parameters

    @property
    def lam(self):
        """Weight of the attachment cost."""
        return self.__lam

    @property
    def deformables(self):
        """Deformables (sources) being registered."""
        return self.__deformables

    def to_device(self, device):
        """Move the deformation modules, the deformables and the initial manifold to ``device``.

        The parameter dictionary is then rebuilt from the moved initial manifold.
        """
        for deformation_module in self.__deformation_modules:
            deformation_module.to_(device=device)

        for deformable in self.__deformables:
            deformable.to_device(device)

        for manifold in self.__init_manifold:
            manifold.to_(device=device)

        self._compute_parameters()

    def __str__(self):
        outstr = "Registration model\n"
        outstr += "=================\n"
        outstr += ("Attachment={}\n".format(self.__attachments))
        outstr += ("Lambda={}\n".format(self.__lam))
        outstr += ("Fit geometrical descriptors={}\n".format(self.__fit_gd))
        outstr += ("Precompute callback={}\n".format(self.__precompute_callback is not None))
        outstr += ("Other parameters={}\n".format(len(self.__init_other_parameters) != 0))
        outstr += "\n"
        outstr += "Modules\n"
        outstr += "=======\n"
        # Skip every deformable's silent module (one per deformable, stored
        # first), not just the first one, so only deformation modules are listed.
        for module in self.__modules[len(self.__deformables):]:
            outstr += ("\n" + str(module))

        return outstr

    def _compute_parameters(self):
        # Fill the parameter dictionary that will be given to the optimizer.
        # Dictionary ordering is only an implementation detail in CPython 3.6 and is
        # guaranteed by the language from 3.7 on. Since the order of the parameter
        # groups matters, an OrderedDict is used so it is preserved on any version.
        self.__parameters = OrderedDict()

        # Initial moments (cotangents of the whole initial manifold)
        self.__parameters['cotan'] = {'params': self.__init_manifold.unroll_cotan()}

        # Geometrical descriptors of the deformation modules, if specified
        if self.__fit_gd:
            gd_params = []
            self.__parameters['gd'] = {'params': gd_params}
            module_manifolds = self.__init_manifold[len(self.__deformables):]
            for fit_gd, init_manifold in zip(self.__fit_gd, module_manifolds):
                if isinstance(fit_gd, bool) and fit_gd:
                    # Fit every geometrical descriptor of this manifold.
                    init_manifold.gd_requires_grad_()
                    gd_params.extend(init_manifold.unroll_gd())
                elif isinstance(fit_gd, Iterable):
                    # Geometrical descriptor is multidimensional: fit only the
                    # selected components. Like the boolean branch above, they
                    # must require grad or the optimizer never receives a gradient.
                    for fit_gdi, init_manifold_gdi in zip(fit_gd, init_manifold.unroll_gd()):
                        if fit_gdi:
                            init_manifold_gdi.requires_grad_()
                            gd_params.append(init_manifold_gdi)

        # Other parameters
        self.__parameters.update(self.__init_other_parameters)

    def evaluate(self, target, solver, it, costs=None, backpropagation=True):
        """Evaluate the model and output its cost.

        Parameters
        ----------
        target : object or list of objects
            Targets to approach, one per deformable (a single non-iterable
            target is wrapped in a list). Each target's ``geometry`` attribute
            is passed to the matching attachment.
        solver : str
            Solver to use for the shooting.
        it : int
            Number of iterations for the integration.
        costs : dict, optional
            Dictionary filled with the individual costs. A new one is created
            when ``None``.
        backpropagation : bool, default=True
            When ``True`` and the total cost requires grad, ``backward()`` is
            called on the total cost.

        Returns
        -------
        dict
            If a backward pass was performed, (str, float) pairs obtained with
            ``.item()``; otherwise the ``costs`` dictionary with its raw values.
        """
        if costs is None:
            costs = {}

        if not isinstance(target, Iterable):
            target = [target]

        assert len(target) == len(self.__deformables)

        # Call precompute callback if available
        precompute_cost = None
        if self.precompute_callback is not None:
            precompute_cost = self.precompute_callback(self.init_manifold, self.modules,
                                                       self.parameters, self.deformables)

            if precompute_cost is not None:
                costs['precompute'] = precompute_cost

        deformed_sources = self.compute_deformed(solver, it, costs=costs)
        costs['attach'] = self.__lam * self._compute_attachment_cost(deformed_sources, target)

        total_cost = sum(costs.values())

        if total_cost.requires_grad and backpropagation:
            # Compute backward and return costs as a dictionary of floats
            total_cost.backward()
            return {key: value.item() for key, value in costs.items()}

        return costs

    def _compute_attachment_cost(self, deformed_sources, targets, deformation_costs=None):
        """Sum the attachment terms between each deformed source and its target.

        ``deformation_costs`` is accepted for interface compatibility and unused.
        """
        return sum(attachment(deformed_source, target.geometry)
                   for attachment, deformed_source, target
                   in zip(self.__attachments, deformed_sources, targets))

    def compute_deformed(self, solver, it, t1=1., costs=None, intermediates=None):
        """Compute the deformed sources.

        Parameters
        ----------
        solver : str
            Solver to use for the shooting.
        it : int
            Number of iterations the integration method will do.
        t1 : float, default=1.
            Forwarded as ``t1`` to ``deformables_compute_deformed``
            (presumably the final integration time -- confirm).
        costs : dict, default=None
            If provided, will be filled with the costs associated to the deformation.
        intermediates : optional
            Forwarded unchanged to ``deformables_compute_deformed``.

        Returns
        -------
        list
            List of deformed sources.
        """
        # The deformation modules' part of the initial manifold comes after the
        # deformables' manifolds.
        module_manifolds = list(self.__init_manifold[len(self.__deformables):])

        compound_module = CompoundModule(self.__deformation_modules)
        compound_module.manifold.fill_gd([manifold.gd for manifold in module_manifolds])
        compound_module.manifold.fill_cotan([manifold.cotan for manifold in module_manifolds])

        # zip stops after the deformables, so only their leading manifolds are used.
        for deformable, deformable_manifold in zip(self.__deformables, self.__init_manifold):
            deformable.silent_module.manifold.fill(deformable_manifold)

        return deformables_compute_deformed(self.__deformables, compound_module, solver, it,
                                            t1=t1, costs=costs, intermediates=intermediates)