Source code for Attachment.attachment_varifold

from collections import Iterable

import torch
from pykeops.torch import Genred

from imodal.Attachment.attachment import Attachment
from imodal.Utilities.meshutils import close_shape, compute_centers_normals_lengths
from imodal.Utilities.compute_backend import get_compute_backend
from imodal.Kernels import K_xy


class VarifoldAttachmentBase(Attachment):
    """Common base for varifold attachment terms evaluated at several scales.

    Subclasses supply ``dim`` and ``cost_varifold``; ``loss`` then adds up
    the single-scale costs over every scale stored on the instance.
    """

    def __init__(self, sigmas, weight=1.):
        """Store the collection of scales.

        Args:
            sigmas: Iterable of scales; one ``cost_varifold`` term is
                evaluated per element.
            weight: Forwarded unchanged to ``Attachment.__init__``.

        Raises:
            AssertionError: If ``sigmas`` is not an ``Iterable``.
        """
        assert isinstance(sigmas, Iterable)
        super().__init__(weight=weight)
        self.__sigmas = sigmas

    def __str__(self):
        """Return the parent description followed by a line listing the scales."""
        header = str(super().__str__())
        scales_line = " Sigmas={sigmas}".format(sigmas=self.__sigmas)
        return header + "\n" + scales_line

    @property
    def dim(self):
        """Dimension handled by the attachment; concrete subclasses override this."""
        raise NotImplementedError

    @property
    def sigmas(self):
        """Scales of the varifolds."""
        return self.__sigmas

    def cost_varifold(self, source, target, sigma):
        """Single-scale varifold cost between ``source`` and ``target``.

        Not implemented here; concrete subclasses override this.
        """
        raise NotImplementedError

    def loss(self, source, target):
        """Return the sum of ``cost_varifold(source, target, sigma)`` over all scales."""
        per_scale_costs = [
            self.cost_varifold(source, target, scale)
            for scale in self.__sigmas
        ]
        return sum(per_scale_costs)
class VarifoldAttachment2D(VarifoldAttachmentBase):
    """Varifold attachment for 2D data; ``cost_varifold`` is left to subclasses."""

    def __init__(self, sigmas, weight=1.):
        """Forward ``sigmas`` and ``weight`` to ``VarifoldAttachmentBase``."""
        super().__init__(sigmas, weight)

    @property
    def dim(self):
        """Dimension handled by this attachment (always 2)."""
        return 2


class VarifoldAttachment2D_Torch(VarifoldAttachment2D):
    """Pure PyTorch implementation of the 2D varifold cost."""

    def cost_varifold(self, source, target, sigma):
        """Return ``<S, S> + <T, T> - 2 <S, T>`` at scale ``sigma``.

        Args:
            source: Indexable whose element 0 is the source point tensor.
                Presumably an ``(n, 2)`` tensor of ordered curve vertices
                -- TODO confirm against callers.
            target: Same layout as ``source``, for the target shape.
            sigma: Kernel scale used in ``exp(-d2 / (2 * sigma ** 2))``.

        Returns:
            Scalar tensor holding the squared varifold distance.
        """
        source = source[0]
        target = target[0]

        def dot_varifold(x, y, sigma):
            """Varifold inner product between the segments of ``x`` and ``y``."""
            # NOTE(review): close_shape is an external helper; the slicing
            # below assumes it returns the points with (at least) one extra
            # row so that rows 1..n are the "next" vertices -- confirm.
            cx, cy = close_shape(x), close_shape(y)
            nx, ny = x.shape[0], y.shape[0]

            # Segment vectors (next vertex minus current) and midpoints.
            vx, vy = cx[1:nx + 1, :] - x, cy[1:ny + 1, :] - y
            mx, my = (cx[1:nx + 1, :] + x) / 2, (cy[1:ny + 1, :] + y) / 2

            # Contracting the outer product with the 2x2 identity gives the
            # (nx, ny) table of inner products between midpoints.
            xy = torch.tensordot(
                torch.transpose(torch.tensordot(mx, my, dims=0), 1, 2),
                torch.eye(2, dtype=source.dtype, device=source.device))

            # Pairwise squared distances |mx|^2 + |my|^2 - 2 <mx, my>.
            d2 = torch.sum(mx * mx, dim=1).reshape(nx, 1).repeat(1, ny) \
                + torch.sum(my * my, dim=1).repeat(nx, 1) - 2. * xy

            kxy = torch.exp(-d2 / (2. * sigma ** 2))

            # Squared inner products between segment vectors.
            vxvy = torch.tensordot(
                torch.transpose(torch.tensordot(vx, vy, dims=0), 1, 2),
                torch.eye(2, dtype=source.dtype, device=source.device)) ** 2

            nvx = torch.norm(vx, dim=1)
            nvy = torch.norm(vy, dim=1)

            # Only keep pairs with a non-zero numerator; this also skips the
            # pairs for which the norm product in the denominator is zero.
            mask = vxvy > 0

            return torch.sum(
                kxy[mask] * vxvy[mask]
                / (torch.tensordot(nvx, nvy, dims=0)[mask]))

        return dot_varifold(source, source, sigma) \
            + dot_varifold(target, target, sigma) \
            - 2 * dot_varifold(source, target, sigma)


class VarifoldAttachment2D_KeOps(VarifoldAttachment2D):
    """KeOps flavour of the 2D varifold attachment (cost not implemented)."""

    def __init__(self, sigmas, weight=1.):
        """Forward ``sigmas`` and ``weight`` to ``VarifoldAttachment2D``."""
        super().__init__(sigmas, weight=weight)

    def cost_varifold(self, source, target, sigma):
        """Always raises ``NotImplementedError``."""
        raise NotImplementedError


class VarifoldAttachment3D(VarifoldAttachmentBase):
    """Varifold attachment for 3D meshes given as ``(vertices, faces)`` pairs.

    ``cost_varifold`` relies on ``self.varifold_scalar_product``, which is
    not defined in this class: concrete subclasses must provide it.
    """

    def __init__(self, sigmas, weight=1.):
        """Forward ``sigmas`` and ``weight`` to ``VarifoldAttachmentBase``."""
        super().__init__(sigmas, weight=weight)

    # Declared as a property to match VarifoldAttachmentBase.dim and
    # VarifoldAttachment2D.dim; as a plain method, ``obj.dim`` evaluated to
    # a bound method instead of the integer 3.
    @property
    def dim(self):
        """Dimension handled by this attachment (always 3)."""
        return 3

    def cost_varifold(self, source, target, sigma):
        """Return ``<T, T> + <S, S> - 2 <S, T>`` at scale ``sigma``.

        Args:
            source: Pair ``(vertices, faces)`` describing the source mesh.
            target: Pair ``(vertices, faces)`` describing the target mesh.
            sigma: Scale forwarded to ``varifold_scalar_product``.

        Returns:
            Whatever the three ``varifold_scalar_product`` terms combine to
            (a scalar tensor for the implementations in this module).
        """
        vertices_source, faces_source = source
        vertices_target, faces_target = target

        # NOTE(review): compute_centers_normals_lengths is an external
        # helper; judging by the names it yields per-face centers, normals
        # and normal lengths -- confirm shapes against its definition.
        centers_source, normals_source, lengths_source = \
            compute_centers_normals_lengths(vertices_source, faces_source)
        centers_target, normals_target, lengths_target = \
            compute_centers_normals_lengths(vertices_target, faces_target)

        normalized_source = normals_source/lengths_source
        normalized_target = normals_target/lengths_target

        return self.varifold_scalar_product(
            centers_target, centers_target,
            lengths_target, lengths_target,
            normalized_target, normalized_target, sigma) \
            + self.varifold_scalar_product(
                centers_source, centers_source,
                lengths_source, lengths_source,
                normalized_source, normalized_source, sigma) \
            - 2. * self.varifold_scalar_product(
                centers_source, centers_target,
                lengths_source, lengths_target,
                normalized_source, normalized_target, sigma)


class VarifoldAttachment3D_Torch(VarifoldAttachment3D):
    """
    Taken and adapted from https://gitlab.com/icm-institute/aramislab/deformetrica/-/blob/master/deformetrica/core/model_tools/attachments/multi_object_attachment.py
    """

    def __init__(self, sigmas, weight=1.):
        """Forward ``sigmas`` and ``weight`` to ``VarifoldAttachment3D``."""
        super().__init__(sigmas, weight=weight)

    def __convolve(self, x, y, p, sigma):
        """Apply the kernel matrix, weighted by squared normal products, to ``p``.

        ``x`` and ``y`` are ``(centers, unit_normals)`` pairs.
        """
        def binet(x):
            return x * x

        # NOTE(review): K_xy comes from imodal.Kernels; presumably the
        # pairwise kernel matrix between the two center sets -- confirm.
        K = K_xy(x[0], y[0], sigma)
        return torch.mm(K * binet(torch.mm(x[1], y[1].T)), p)

    def varifold_scalar_product(self, x, y, lengths_x, lengths_y,
                                normalized_x, normalized_y, sigma):
        """Return the dot product of ``lengths_x`` with the convolution of ``lengths_y``.

        ``lengths_x`` is flattened and ``lengths_y`` is reshaped to a column
        before the matrix product.
        """
        return torch.dot(
            lengths_x.view(-1),
            self.__convolve((x, normalized_x), (y, normalized_y),
                            lengths_y.view(-1, 1), sigma).view(-1))


class VarifoldAttachment3D_KeOps(VarifoldAttachment3D):
    """KeOps implementation of the 3D varifold scalar product."""

    def __init__(self, sigmas, weight=1.):
        """Forward the arguments and set up the lazily filled caches."""
        super().__init__(sigmas, weight)

        # Reduction routine, built on first use.
        self.__K = None
        # Cache of 1 / sigma**2 tensors, keyed by sigma.
        self.__oos2 = {}

    def varifold_scalar_product(self, x, y, lengths_x, lengths_y,
                                normalized_x, normalized_y, sigma):
        """Return the sum of ``lengths_x`` times the KeOps reduction.

        The reduction evaluates
        ``Exp(-S * |x - y|^2 / 2) * (u|v)^2 * p`` with ``S = 1 / sigma**2``,
        ``u``/``v`` the normalized normals and ``p`` the ``lengths_y``.
        """
        if self.__K is None:
            # The routine, its dtype and the backend are all fixed by the
            # tensors passed on the first call.
            formula = "Exp(-S*SqNorm2(x - y)/IntCst(2)) * Square((u|v))*p"
            alias = ["x=Vi(3)", "y=Vj(3)", "u=Vi(3)", "v=Vj(3)", "p=Vj(1)", "S=Pm(1)"]
            self.__K = Genred(formula, alias, reduction_op='Sum', axis=1,
                              dtype=str(x.dtype).split('.')[1])
            self.__keops_backend = 'CPU'
            if str(x.device) != 'cpu':
                self.__keops_backend = 'GPU'

        if sigma not in self.__oos2:
            self.__oos2[sigma] = torch.tensor([1./sigma/sigma],
                                              device=x.device, dtype=x.dtype)

        return (lengths_x * self.__K(x, y, normalized_x, normalized_y,
                                     lengths_y, self.__oos2[sigma],
                                     backend=self.__keops_backend)).sum()


def VarifoldAttachment(dim, sigmas, weight=1., backend=None):
    """Build the varifold attachment matching ``dim`` and ``backend``.

    Args:
        dim: Dimension of the data; 2 and 3 are supported.
        sigmas: Iterable of scales forwarded to the attachment.
        weight: Weight forwarded to the attachment.
        backend: ``'torch'`` or ``'keops'``. When ``None``, the value
            returned by ``get_compute_backend()`` is used.

    Returns:
        An instance of the ``VarifoldAttachment{2D,3D}_{Torch,KeOps}`` class
        selected by ``dim`` and ``backend``.

    Raises:
        RuntimeError: If ``dim`` is neither 2 nor 3, or if ``backend`` is
            not recognized for a supported dimension.
    """
    if backend is None:
        backend = get_compute_backend()

    # Manual switch between dimensions and compute backends.
    if dim == 2:
        if backend == 'torch':
            return VarifoldAttachment2D_Torch(sigmas, weight=weight)
        elif backend == 'keops':
            return VarifoldAttachment2D_KeOps(sigmas, weight=weight)
        else:
            raise RuntimeError("VarifoldAttachment: Unrecognized backend {backend}!".format(backend=backend))
    elif dim == 3:
        if backend == 'torch':
            return VarifoldAttachment3D_Torch(sigmas, weight=weight)
        elif backend == 'keops':
            return VarifoldAttachment3D_KeOps(sigmas, weight=weight)
        else:
            raise RuntimeError("VarifoldAttachment: Unrecognized backend {backend}!".format(backend=backend))
    else:
        raise RuntimeError("VarifoldAttachment: Dimension {dim} not supported!".format(dim=dim))