from typing import Iterable
import torch
from imodal.Manifolds.Abstract import Manifold
from imodal.StructuredFields import StructuredField_m, StructuredField_0
from imodal.StructuredFields.Abstract import SumStructuredField
class NormalFrame(Manifold):
    """Manifold whose elements are points that each carry a square frame matrix.

    Every component (``gd``, ``tan``, ``cotan``) is a pair ``(points, frames)``
    where, as enforced by the constructor checks, ``points`` has shape
    ``(nb_pts, dim)`` and ``frames`` has shape ``(nb_pts, dim, dim)``.
    """

    def __init__(self, dim, nb_pts, gd=None, tan=None, cotan=None, device=None):
        """Build the manifold and validate the shapes of the supplied components.

        Args:
            dim: Ambient dimension of the points.
            nb_pts: Number of points.
            gd: Optional ``(points, frames)`` pair, forwarded to the base class.
            tan: Optional ``(points, frames)`` pair, forwarded to the base class.
            cotan: Optional ``(points, frames)`` pair, forwarded to the base class.
            device: Forwarded unchanged to the base class constructor.

        Raises:
            AssertionError: If a provided component does not have the shapes
                ``(nb_pts, dim)`` and ``(nb_pts, dim, dim)``.
        """
        # Same check, in the same order (gd, tan, cotan), for each component.
        for label, component in (("gd", gd), ("tan", tan), ("cotan", cotan)):
            assert NormalFrame.__shapes_ok(component, nb_pts, dim), (
                f"{label} must be a pair of tensors of shapes "
                f"({nb_pts}, {dim}) and ({nb_pts}, {dim}, {dim})"
            )

        super().__init__(((dim,), (dim, dim)), nb_pts, gd, tan, cotan, device=device)
        self.__dim = dim

    @staticmethod
    def __shapes_ok(component, nb_pts, dim):
        """Return True when `component` is None or a correctly shaped pair."""
        if component is None:
            return True
        points, frames = component[0], component[1]
        return (
            points.shape[0] == nb_pts
            and points.shape[1] == dim
            and frames.shape[0] == nb_pts
            and frames.shape[1] == dim
            and frames.shape[2] == dim
        )

    @property
    def dim(self):
        """Ambient dimension passed at construction time."""
        return self.__dim

    def inner_prod_field(self, field):
        """Pair the cotangent of this manifold with the action of `field`.

        The result is the sum of the dot product between the flattened point
        cotangent and the point velocity, and the full contraction of the
        frame cotangent with the frame velocity.

        Args:
            field: Callable accepted by :meth:`infinitesimal_action`.

        Returns:
            A zero-dimensional tensor.
        """
        velocity = self.infinitesimal_action(field)
        points_term = torch.dot(self.cotan[0].flatten(), velocity.tan[0].flatten())
        frames_term = torch.einsum('nij, nij->', self.cotan[1], velocity.tan[1])
        return points_term + frames_term

    def infinitesimal_action(self, field):
        """Apply `field` to the points and to the frames attached to them.

        Points move with ``field(points)``. Frames are left-multiplied by the
        skew-symmetric part of ``field(points, k=1)``.

        Args:
            field: Callable evaluated as ``field(x)`` and ``field(x, k=1)``;
                the latter is presumably the first derivative of the field,
                returned with shape ``(nb_pts, dim, dim)`` -- TODO confirm.

        Returns:
            A new :class:`NormalFrame` sharing this manifold's ``gd`` and
            carrying the computed velocities as its ``tan``.
        """
        points = self.gd[0]
        vx = field(points)
        d_vx = field(points, k=1)
        # Skew-symmetric part of d_vx, taken over the last two axes.
        skew = 0.5 * (d_vx - torch.transpose(d_vx, 1, 2))
        vr = torch.bmm(skew, self.gd[1])

        # NOTE(review): `device` is not forwarded here, so the base class default
        # is used -- confirm this is intended for manifolds on a non-default device.
        return NormalFrame(self.__dim, self.nb_pts, gd=self.gd, tan=(vx, vr))

    def cot_to_vs(self, sigma, backend=None):
        """Build the structured field induced by the cotangent of this manifold.

        Args:
            sigma: Forwarded to both structured fields (presumably the kernel
                scale -- TODO confirm).
            backend: Forwarded to both structured fields.

        Returns:
            A :class:`SumStructuredField` combining a ``StructuredField_0``
            built from the point cotangent and a ``StructuredField_m`` built
            from the frame moments.
        """
        points = self.gd[0]
        v0 = StructuredField_0(points, self.cotan[0], sigma, backend=backend)
        # Per-point product cotan[1] @ gd[1]^T.
        moments = torch.einsum('nik, njk->nij', self.cotan[1], self.gd[1])
        vm = StructuredField_m(points, moments, sigma, backend=backend)
        return SumStructuredField([v0, vm])