from typing import Iterable
from imodal.DeformationModules.Abstract import DeformationModule
from imodal.Manifolds import CompoundManifold
from imodal.StructuredFields import SumStructuredField
class CompoundModule(DeformationModule, Iterable):
    """Combination of deformation modules.

    Wraps a list of deformation modules and forwards each operation to
    every sub-module, aggregating the results where the operation returns
    something (sum of vector fields, sum of costs, compound manifold, ...).
    """

    def __init__(self, modules, label=None):
        """Compound module constructor.

        Parameters
        ----------
        modules : Iterable
            Iterable of deformation modules we want to build the compound
            module from.
        label :
            Optional identifier
        """
        assert isinstance(modules, Iterable)
        super().__init__(label)
        # Materialise a private copy so later mutation of the caller's
        # iterable (or a one-shot generator) does not affect this module.
        self.__modules = [*modules]

    def __str__(self):
        """Return a human readable description listing every sub-module."""
        parts = ["Compound Module\n"]
        if self.label:
            parts.append(" Label=" + self.label + "\n")
        parts.append("Modules=\n")
        for module in self.__modules:
            parts.append("*" * 20)
            parts.append(str(module) + "\n")
        parts.append("*" * 20)
        return "".join(parts)

    def to(self, *args, **kwargs):
        """Forward ``to(*args, **kwargs)`` to every sub-module."""
        for mod in self.__modules:
            mod.to(*args, **kwargs)

    @property
    def device(self):
        """Device of the first sub-module.

        Raises ``IndexError`` when the compound holds no module.
        """
        return self.__modules[0].device

    @property
    def modules(self):
        """The internal list of sub-modules (not a copy)."""
        return self.__modules

    def todict(self):
        """Return the sub-modules as a dictionary keyed by their label.

        Each sub-module is keyed by its own ``label`` attribute. Modules
        sharing a label (including several unlabelled modules, presumably
        keyed by ``None`` -- TODO confirm against ``DeformationModule``)
        overwrite each other; the last one wins.
        """
        # The compound's own label is a single identifier, so zipping it
        # with the module list would pair its *characters* with modules.
        return {module.label: module for module in self.__modules}

    def __getitem__(self, itemid):
        """Access sub-modules by position, slice, or label.

        Parameters
        ----------
        itemid : int, slice or hashable
            An integer or slice indexes the internal module list; any other
            value is looked up as a label in :meth:`todict`.

        Raises
        ------
        IndexError
            If an integer index is out of range.
        KeyError
            If no sub-module carries the requested label.
        """
        if isinstance(itemid, (int, slice)):
            return self.__modules[itemid]
        return self.todict()[itemid]

    def __iter__(self):
        """Return a fresh, independent iterator over the sub-modules.

        Each call yields its own iterator, so nested or interleaved loops
        over the same compound module do not disturb one another.
        """
        # Reset the legacy cursor so code that still calls ``next()``
        # directly on the compound module after ``iter()`` keeps working.
        self.current = 0
        return iter(self.__modules)

    def __next__(self):
        """Advance the legacy shared cursor.

        Kept for backward compatibility with callers that invoke ``next()``
        on the compound module itself; ``for`` loops use the independent
        iterator returned by :meth:`__iter__` instead.
        """
        if self.current >= len(self.__modules):
            raise StopIteration
        module = self.__modules[self.current]
        self.current += 1
        return module

    @property
    def dim(self):
        """Dimension reported by the first sub-module.

        Raises ``IndexError`` when the compound holds no module.
        """
        return self.__modules[0].dim  # Dirty

    def __get_controls(self):
        """Return the list of controls, one entry per sub-module."""
        return [m.controls for m in self.__modules]

    def fill_controls(self, controls):
        """Set the controls of every sub-module.

        Parameters
        ----------
        controls : sequence
            One control per sub-module, in the same order as the modules.
        """
        assert len(controls) == len(self.__modules)
        for module, control in zip(self.__modules, controls):
            module.fill_controls(control)

    controls = property(__get_controls, fill_controls)

    def fill_controls_zero(self):
        """Call ``fill_controls_zero()`` on every sub-module."""
        for module in self.__modules:
            module.fill_controls_zero()

    @property
    def manifold(self):
        """Compound manifold built from the manifolds of all sub-modules."""
        return CompoundManifold([m.manifold for m in self.__modules])

    def __call__(self, points):
        """Applies the generated vector field on given points."""
        return sum(module(points) for module in self.__modules)

    def cost(self):
        """Returns the cost."""
        return sum(module.cost() for module in self.__modules)

    def compute_geodesic_control(self, man):
        r"""Computes geodesic control from \delta \in H^\ast."""
        for module in self.__modules:
            module.compute_geodesic_control(man)

    def field_generator(self):
        """Return the sum of the structured fields of all sub-modules."""
        return SumStructuredField([m.field_generator() for m in self.__modules])