# Source code for DeformationModules.ImplicitModule1

import torch
import pickle

from pykeops.torch import Genred, KernelSolve

from imodal.DeformationModules.Abstract import DeformationModule, create_deformation_module_with_backends
from imodal.Kernels.SKS import eta, compute_sks, A
from imodal.Manifolds import NormalFrame
from imodal.StructuredFields import StructuredField_p


class ImplicitModule1Base(DeformationModule):
    """ Implicit module of order 1.

    Backend-independent part of the module. It stores the ``NormalFrame``
    manifold, the tensor ``C`` and the scalar parameters ``sigma``, ``nu`` and
    ``coeff``, and exposes them through properties. The numerical routines
    (:meth:`cost`, :meth:`compute_geodesic_control`,
    :meth:`compute_moments`) must be provided by a backend subclass.
    """

    def __init__(self, manifold, sigma, C, nu, coeff, label):
        """
        Parameters
        ----------
        manifold : NormalFrame
            Manifold carrying the module's geometric descriptors.
        sigma :
            Kernel scale, forwarded to the structured field and to
            ``manifold.cot_to_vs``.
        C : torch.Tensor
            Tensor whose size along dimension 2 is the number of controls.
        nu :
            Stored parameter, read by the backend subclasses.
        coeff :
            Stored coefficient, read by the backend subclasses.
        label :
            Module label, forwarded to ``DeformationModule``.
        """
        assert isinstance(manifold, NormalFrame)
        super().__init__(label)
        self.__manifold = manifold
        self.__C = C
        self.__sigma = sigma
        self.__nu = nu
        self.__coeff = coeff
        self.__dim_controls = C.shape[2]
        # Number of independent entries of a symmetric dim x dim matrix;
        # integer arithmetic avoids the float round-trip of int(x / 2).
        self.__sym_dim = self.manifold.dim * (self.manifold.dim + 1) // 2
        # Controls are contracted with C, so give them C's dtype instead of
        # torch's global default (which breaks double-precision modules).
        self.__controls = torch.zeros(self.__dim_controls, dtype=C.dtype, device=self.__manifold.device)

    def __str__(self):
        outstr = "Implicit module of order 1\n"
        if self.label:
            outstr += " Label=" + self.label + "\n"
        outstr += " Sigma=" + str(self.sigma) + "\n"
        outstr += " Nu=" + str(self.__nu) + "\n"
        outstr += " Coeff=" + str(self.__coeff) + "\n"
        outstr += " Dim controls=" + str(self.__dim_controls) + "\n"
        outstr += " Nb pts=" + str(self.__manifold.nb_pts) + "\n"
        return outstr

    @classmethod
    def build(cls, dim, nb_pts, sigma, C, nu=0., coeff=1., gd=None, tan=None, cotan=None, label=None):
        """Create the module together with a new ``NormalFrame`` manifold.

        ``dim``, ``nb_pts``, ``gd``, ``tan`` and ``cotan`` are forwarded to
        ``NormalFrame``; the remaining arguments go to the constructor.
        """
        return cls(NormalFrame(dim, nb_pts, gd=gd, tan=tan, cotan=cotan), sigma, C, nu, coeff, label)

    @property
    def dim(self):
        """Dimension of the ambient space (``manifold.dim``)."""
        return self.__manifold.dim

    def to_(self, *args, **kwargs):
        """Move/cast the manifold, the controls and ``C`` in place.

        Arguments are forwarded unchanged to ``manifold.to_`` and
        ``torch.Tensor.to``.
        """
        self.__manifold.to_(*args, **kwargs)
        self.__controls = self.__controls.to(*args, **kwargs)
        self.__C = self.__C.to(*args, **kwargs)

    @property
    def device(self):
        return self.__manifold.device

    @property
    def manifold(self):
        return self.__manifold

    def __get_C(self):
        return self.__C

    def __set_C(self, C):
        self.__C = C

    C = property(__get_C, __set_C)

    @property
    def sigma(self):
        return self.__sigma

    @property
    def nu(self):
        return self.__nu

    @property
    def sym_dim(self):
        """``dim * (dim + 1) / 2``, computed once at construction."""
        return self.__sym_dim

    @property
    def dim_controls(self):
        """Number of controls, i.e. ``C.shape[2]``."""
        return self.__dim_controls

    def __get_controls(self):
        return self.__controls

    def fill_controls(self, controls):
        """Set the controls, then refresh the moments via ``compute_moments``."""
        self.__controls = controls
        self.compute_moments()

    def __get_coeff(self):
        return self.__coeff

    def __set_coeff(self, coeff):
        self.__coeff = coeff

    # Assigning ``controls`` goes through fill_controls, so it also
    # recomputes the moments.
    controls = property(__get_controls, fill_controls)
    coeff = property(__get_coeff, __set_coeff)

    def fill_controls_zero(self):
        """Reset every control to zero (same dtype as ``C``) and refresh the moments."""
        self.fill_controls(torch.zeros(self.__dim_controls, dtype=self.__C.dtype, device=self.device))

    def __call__(self, points, k=0):
        """Evaluate the generated field (or its ``k``-th derivative) at ``points``."""
        return self.field_generator()(points, k)

    def cost(self):
        raise NotImplementedError()

    def compute_geodesic_control(self, man):
        raise NotImplementedError()

    def compute_moments(self):
        raise NotImplementedError()

    def field_generator(self):
        """Build a ``StructuredField_p`` from the manifold points and ``self.moments``.

        ``self.moments`` is set by the backend's ``compute_moments``; it must
        have been called before this method.
        """
        return StructuredField_p(self.__manifold.gd[0], self.moments, self.__sigma, backend=self.backend)

    def adjoint(self, manifold):
        """Return ``manifold.cot_to_vs`` evaluated with this module's sigma and backend."""
        return manifold.cot_to_vs(self.__sigma, backend=self.backend)
def _linalg_solve(A, B):
    """Return ``X`` solving ``A @ X = B``.

    ``torch.solve(B, A)`` (reversed argument order) was deprecated in
    PyTorch 1.9 and later removed in favour of ``torch.linalg.solve(A, B)``.
    The fallback keeps older PyTorch releases, which lack
    ``torch.linalg.solve``, working.
    """
    linalg = getattr(torch, "linalg", None)
    if linalg is not None and hasattr(linalg, "solve"):
        return linalg.solve(A, B)
    return torch.solve(B, A)[0]


class ImplicitModule1_Torch(ImplicitModule1Base):
    """Implicit module of order 1, dense PyTorch backend."""

    def __init__(self, manifold, sigma, C, nu, coeff, label):
        super().__init__(manifold, sigma, C, nu, coeff, label)

    @property
    def backend(self):
        return 'torch'

    def cost(self):
        """Return ``0.5 * coeff * <aqh, lambdas>`` for the last computed moments."""
        return 0.5 * self.coeff * torch.dot(self.__aqh.view(-1), self.__lambdas.view(-1))

    def compute_geodesic_control(self, man):
        """Compute the controls from the vector field associated with ``man``.

        The symmetrised first derivative of ``self.adjoint(man)`` at the
        manifold points is contracted with ``eta``. The result is solved
        against the regularised matrix ``sks``, divided by ``coeff``, and
        projected through the ``aq`` / ``aq^T sks^-1 aq`` system.
        """
        vs = self.adjoint(man)
        d_vx = vs(self.manifold.gd[0], k=1)
        S = 0.5 * (d_vx + torch.transpose(d_vx, 1, 2))
        S = torch.tensordot(S, eta(self.manifold.dim, device=self.device), dims=2)
        self.__compute_sks()
        tlambdas = _linalg_solve(self.__sks, S.view(-1, 1)) / self.coeff
        aq, aqkiaq = self.__compute_aqkiaq()
        c = _linalg_solve(aqkiaq, torch.mm(aq.t(), tlambdas))
        # The ``controls`` setter (fill_controls) calls compute_moments(),
        # which already recomputes sks and the moments for these controls;
        # no further call is needed.
        self.controls = c.reshape(-1)

    def compute_moments(self):
        """Recompute ``sks``, then the moments for the current controls."""
        self.__compute_sks()
        self.__compute_moments()

    def __compute_aqh(self, h):
        """Contract the frames ``gd[1]``, ``C`` and the control vector ``h`` with ``eta``.

        Returns a tensor of shape ``(nb_pts, sym_dim)`` (one row per point,
        per the einsum output ``nt``).
        """
        R = self.manifold.gd[1].view(-1, self.manifold.dim, self.manifold.dim)
        return torch.einsum('nli, nik, k, nui, niv, lvt->nt',
                            R, self.C, h,
                            torch.eye(self.manifold.dim, dtype=R.dtype, device=self.device).repeat(self.manifold.nb_pts, 1, 1),
                            torch.transpose(R, 1, 2),
                            eta(self.dim, device=self.device))

    def __compute_sks(self):
        """Store ``compute_sks(gd[0], sigma, 1) + nu * I`` as ``self.__sks``."""
        sks = compute_sks(self.manifold.gd[0], self.sigma, 1)
        self.__sks = sks + self.nu * torch.eye(self.sym_dim * self.manifold.nb_pts, dtype=sks.dtype, device=self.device)

    def __compute_moments(self):
        """Solve ``sks @ lambdas = aqh`` for the current controls and derive the moments."""
        self.__aqh = self.__compute_aqh(self.controls)
        lambdas = _linalg_solve(self.__sks, self.__aqh.view(-1, 1))
        self.__lambdas = lambdas.contiguous()
        self.moments = torch.tensordot(self.__lambdas.view(-1, self.sym_dim),
                                       torch.transpose(eta(self.manifold.dim, device=self.device), 0, 2),
                                       dims=1)

    def __compute_aqkiaq(self):
        """Return ``(aq, aq^T sks^-1 aq)``.

        Column ``i`` of ``aq`` is ``__compute_aqh(e_i)`` (flattened) for the
        ``i``-th canonical control vector.
        """
        # Basis vectors take C's dtype; the buffers below take the dtype of
        # the computed columns, so nothing is silently cast to the default dtype.
        basis = torch.eye(self.dim_controls, dtype=self.C.dtype, device=self.device)
        aq = torch.stack([self.__compute_aqh(h).flatten() for h in basis], dim=1)
        # One multi-right-hand-side solve replaces a solve per control.
        lambdas = _linalg_solve(self.__sks, aq).t()
        return (aq, torch.mm(lambdas, aq))


class ImplicitModule1_KeOps(ImplicitModule1Base):
    """Implicit module of order 1, KeOps backend (matrix-free solves via KernelSolve)."""

    def __init__(self, manifold, sigma, C, nu, coeff, label):
        super().__init__(manifold, sigma, C, nu, coeff, label)
        self.__keops_dtype = str(manifold.gd[0].dtype).split(".")[1]
        self.__keops_backend = 'CPU'
        if str(self.device) != 'cpu':
            self.__keops_backend = 'GPU'
        self.__keops_invsigmasq = torch.tensor([1./sigma/sigma], dtype=manifold.dtype, device=self.device)
        self.__keops_eye = torch.eye(self.dim, device=self.device, dtype=manifold.dtype).flatten()
        self.__keops_A = A(self.dim, device=self.device, dtype=manifold.dtype).flatten()

        formula_solve_sks = "TensorDot(TensorDot((-S*Exp(-S*SqNorm2(x_i - y_j)*IntInv(2))*(S*TensorDot(x_i - y_j, x_i - y_j, Ind({dim}), Ind({dim}), Ind(), Ind()) - eye)), A, Ind({dim}, {dim}), Ind({dim}, {dim}, {symdim}, {symdim}), Ind(0, 1), Ind(0, 1)), X, Ind({symdim}, {symdim}), Ind({symdim}), Ind(0), Ind(0))".format(dim=self.dim, symdim=self.sym_dim)
        alias_solve_sks = ["x_i=Vi({dim})".format(dim=self.dim),
                           "y_j=Vj({dim})".format(dim=self.dim),
                           "X=Vj({symdim})".format(symdim=self.sym_dim),
                           "eye=Pm({dimsq})".format(dimsq=self.dim*self.dim),
                           "S=Pm(1)",
                           "A=Pm({dima})".format(dima=self.__keops_A.numel())]
        # KernelSolve solves (alpha * Id + K) X = B for the kernel above;
        # every call site passes alpha=self.nu.
        self.solve_sks = KernelSolve(formula_solve_sks, alias_solve_sks, "X", axis=1, dtype=self.__keops_dtype)

        self.eps = 1e-6

    @property
    def backend(self):
        return 'keops'

    def to_(self, *args, **kwargs):
        """Move/cast the module and its KeOps parameters; update the KeOps backend."""
        super().to_(*args, **kwargs)
        self.__keops_invsigmasq = self.__keops_invsigmasq.to(*args, **kwargs)
        self.__keops_eye = self.__keops_eye.to(*args, **kwargs)
        self.__keops_A = self.__keops_A.to(*args, **kwargs)

        device = kwargs.get('device')
        if device is None:
            # Tensor.to also accepts the target device positionally.
            device = next((arg for arg in args if isinstance(arg, (str, torch.device))), None)
        if device is not None:
            # torch.device() accepts both strings ("cuda:0") and torch.device
            # objects; calling str.split on a torch.device would fail.
            device_type = torch.device(device).type
            if device_type == "cuda":
                self.__keops_backend = 'GPU'
            elif device_type == "cpu":
                self.__keops_backend = 'CPU'

    def cost(self):
        """Return ``0.5 * coeff * <aqh, lambdas>`` for the last computed moments."""
        return 0.5 * self.coeff * torch.dot(self.__aqh.view(-1), self.__lambdas.view(-1))

    def compute_geodesic_control(self, man):
        """Compute the controls from the vector field associated with ``man``.

        Same steps as the torch backend, but the ``sks`` systems are solved
        with KeOps ``KernelSolve``.
        """
        vs = self.adjoint(man)
        d_vx = vs(self.manifold.gd[0].view(-1, self.manifold.dim), k=1)

        S = 0.5 * (d_vx + torch.transpose(d_vx, 1, 2))
        S = torch.tensordot(S, eta(self.manifold.dim, device=self.device), dims=2)

        x = self.manifold.gd[0].reshape(-1, self.dim)
        tlambdas = self.solve_sks(x, x, S, self.__keops_eye, self.__keops_invsigmasq, self.__keops_A,
                                  backend=self.__keops_backend, alpha=self.nu, eps=self.eps) / self.coeff

        aq, aqkiaq = self.__compute_aqkiaq()

        c = _linalg_solve(aqkiaq, torch.mm(aq.t(), tlambdas.view(-1, 1)))
        # The ``controls`` setter (fill_controls) calls compute_moments(),
        # which already performs the KeOps solve for these controls.
        self.controls = c.flatten()

    def compute_moments(self):
        """Recompute the moments for the current controls."""
        self.__compute_moments()

    def __compute_aqh(self, h):
        """Contract the frames ``gd[1]``, ``C`` and the control vector ``h`` with ``eta``.

        Returns a tensor of shape ``(nb_pts, sym_dim)`` (einsum output ``nt``).
        """
        R = self.manifold.gd[1]
        return torch.einsum('nli, nik, k, nui, niv, lvt->nt',
                            R, self.C, h,
                            torch.eye(self.manifold.dim, dtype=R.dtype, device=self.device).repeat(self.manifold.nb_pts, 1, 1),
                            torch.transpose(R, 1, 2),
                            eta(self.manifold.dim, device=self.device))

    def __compute_moments(self):
        """Solve the ``sks`` system for the current controls and derive the moments."""
        self.__aqh = self.__compute_aqh(self.controls)
        x = self.manifold.gd[0].reshape(-1, self.dim)
        # eps passed explicitly so that changing self.eps affects every solve.
        self.__lambdas = self.solve_sks(x, x, self.__aqh, self.__keops_eye, self.__keops_invsigmasq, self.__keops_A,
                                        backend=self.__keops_backend, alpha=self.nu, eps=self.eps)
        self.moments = torch.tensordot(self.__lambdas.view(-1, self.sym_dim),
                                       torch.transpose(eta(self.manifold.dim, device=self.device), 0, 2),
                                       dims=1)

    def __compute_aqkiaq(self):
        """Return ``(aq, lambdas @ aq)``, where row ``i`` of ``lambdas`` solves the ``sks`` system for column ``i`` of ``aq``."""
        x = self.manifold.gd[0].reshape(-1, self.dim)
        basis = torch.eye(self.dim_controls, dtype=self.C.dtype, device=self.device)
        aq_columns = []
        lambda_rows = []
        for h in basis:
            aqi = self.__compute_aqh(h).flatten()
            aq_columns.append(aqi)
            lambda_rows.append(self.solve_sks(x, x, aqi.view(-1, self.sym_dim), self.__keops_eye,
                                              self.__keops_invsigmasq, self.__keops_A,
                                              backend=self.__keops_backend, alpha=self.nu,
                                              eps=self.eps).view(-1))
        # Stack instead of filling default-dtype buffers, so the results keep
        # their computed dtype.
        aq = torch.stack(aq_columns, dim=1)
        lambdas = torch.stack(lambda_rows, dim=0)
        return (aq, torch.mm(lambdas, aq))


ImplicitModule1 = create_deformation_module_with_backends(ImplicitModule1_Torch.build, ImplicitModule1_KeOps.build)