Surface registration

Example of a diffeomorphic matching of surfaces using varifold metrics: we perform an LDDMM matching of two meshes using the geodesic shooting algorithm.

Define our dataset

Standard imports

import os
import time

import numpy as np
import torch
from torch.autograd import grad

import plotly
import plotly.graph_objs as go

from pykeops.torch import Kernel, kernel_product, LazyTensor, Vi, Vj
from pykeops.torch.kernel_product.formula import *

# torch type and device
use_cuda = torch.cuda.is_available()
# Always build a real torch.device: with the bare string 'cpu', the `.index`
# lookup below would return the bound method `str.index`, not a device id.
torchdeviceId = torch.device('cuda:0') if use_cuda else torch.device('cpu')
torchdtype = torch.float32

# PyKeOps counterpart
KeOpsdeviceId = torchdeviceId.index  # GPU id (0) when CUDA is used, None on CPU
KeOpsdtype = str(torchdtype).split('.')[1]  # 'torch.float32' -> 'float32'

Import the data file, one of:

  • “hippos.pt” : original data (6611 vertices),

  • “hippos_red.pt” : reduced size (1654 vertices),

  • “hippos_reduc.pt” : further reduced (662 vertices),

  • “hippos_reduc_reduc.pt” : further reduced (68 vertices)

# Use the full-resolution mesh when CUDA is available, otherwise the
# smallest (68-vertex) version of the dataset.
datafile = 'data/hippos.pt' if use_cuda else 'data/hippos_reduc_reduc.pt'

Define the kernels

Define the Gaussian kernel \((K(x,y)b)_i = \sum_j \exp(-\gamma\|x_i-y_j\|^2)b_j\)

def GaussKernel(sigma):
    """Build a KeOps Gaussian kernel reduction.

    The returned routine evaluates
    (K(x, y) b)_i = sum_j exp(-|x_i - y_j|^2 / sigma^2) b_j,
    with x the i-indexed 3D points and y, b the j-indexed 3D points and vectors.
    """
    gamma = 1 / (sigma * sigma)
    x_i = Vi(0, 3)
    y_j = Vj(1, 3)
    b_j = Vj(2, 3)
    gauss_ij = (-x_i.sqdist(y_j) * gamma).exp()
    return (gauss_ij * b_j).sum_reduction(axis=1)

Define the “Gaussian–Cauchy-Binet” kernel \((K(x,y,u,v)b)_i = \sum_j \exp(-\gamma\|x_i-y_j\|^2) \langle u_i,v_j\rangle^2 b_j\)

def GaussLinKernel(sigma):
    """Build a KeOps "Gaussian x Cauchy-Binet" kernel reduction.

    The returned routine evaluates
    (K(x, y, u, v) b)_i = sum_j exp(-|x_i - y_j|^2 / sigma^2) <u_i, v_j>^2 b_j,
    where x, y are 3D points, u, v are 3D vectors (e.g. face normals) and
    b holds one scalar weight per j-indexed point.
    """
    gamma = 1 / (sigma * sigma)
    x_i, y_j = Vi(0, 3), Vj(1, 3)
    u_i, v_j = Vi(2, 3), Vj(3, 3)
    b_j = Vj(4, 1)
    gauss_ij = (-x_i.sqdist(y_j) * gamma).exp()
    cauchy_binet_ij = (u_i * v_j).sum() ** 2
    return (gauss_ij * cauchy_binet_ij * b_j).sum_reduction(axis=1)

Custom ODE solver (Ralston's second-order Runge–Kutta method) for ODE systems whose state is a tuple of tensors

def RalstonIntegrator():
    """Return a fixed-step Ralston (second-order Runge-Kutta) integrator.

    The returned ``f(ODESystem, x0, nt, deltat=1.0)`` integrates
    ``x' = ODESystem(*x)`` for a state ``x`` given as a tuple of tensors,
    taking ``nt`` steps of size ``deltat / nt``. It returns the list of all
    ``nt + 1`` visited states, the first one being a clone of ``x0``.
    """
    def f(ODESystem, x0, nt, deltat=1.0):
        dt = deltat / nt
        state = tuple(component.clone() for component in x0)
        trajectory = [state]
        for _ in range(nt):
            k1 = ODESystem(*state)
            # Evaluate the slope again two thirds of the way along k1.
            probe = tuple(s + (2 * dt / 3) * d1 for s, d1 in zip(state, k1))
            k2 = ODESystem(*probe)
            # Ralston weights: 1/4 on the first slope, 3/4 on the second.
            state = tuple(
                s + (.25 * dt) * (d1 + 3 * d2) for s, d1, d2 in zip(state, k1, k2)
            )
            trajectory.append(state)
        return trajectory

    return f

LDDMM implementation

Deformations: diffeomorphism

Hamiltonian system

def Hamiltonian(K):
    """Return the LDDMM Hamiltonian H(p, q) = 1/2 * sum(p * K(q, q, p)) for kernel K."""
    def H(p, q):
        velocity = K(q, q, p)
        return .5 * (p * velocity).sum()
    return H


def HamiltonianSystem(K):
    """Return the right-hand side of Hamilton's equations for kernel K.

    The returned ``HS(p, q)`` gives ``(dp/dt, dq/dt) = (-dH/dq, dH/dp)``.
    ``create_graph=True`` keeps these derivatives differentiable, so a loss
    built on the shooting can be back-propagated through them.
    """
    hamiltonian = Hamiltonian(K)

    def HS(p, q):
        dH_dp, dH_dq = grad(hamiltonian(p, q), (p, q), create_graph=True)
        return -dH_dq, dH_dp

    return HS

Shooting approach

def Shooting(p0, q0, K, nt=10, Integrator=RalstonIntegrator()):
    """Integrate the geodesic (Hamiltonian) equations starting from (p0, q0).

    Returns the list of (p, q) states produced by ``Integrator`` with ``nt``
    time steps, beginning with the initial state.
    """
    geodesic_rhs = HamiltonianSystem(K)
    initial_state = (p0, q0)
    return Integrator(geodesic_rhs, initial_state, nt)


def Flow(x0, p0, q0, K, deltat=1.0, Integrator=RalstonIntegrator(), nt=10):
    """Transport the points ``x0`` along the deformation generated by ``(p0, q0)``.

    The points follow ``dx/dt = K(x, q, p)`` while the control points and
    momenta ``(q, p)`` follow the geodesic (Hamiltonian) equations.

    Args:
        x0: points to deform.
        p0: initial momenta.
        q0: initial control points.
        K: kernel called as ``K(x, y, b)``.
        deltat: total integration time.
        Integrator: ODE integrator called as ``(ODESystem, x0, nt, deltat)``.
        nt: number of integration time steps (keyword added with a default,
            so existing calls keep working).

    Returns:
        The transported points at the final time.
    """
    HS = HamiltonianSystem(K)

    def FlowEq(x, p, q):
        return (K(x, q, p),) + HS(p, q)

    # Pass nt and deltat explicitly: as the third positional argument,
    # deltat would have been taken as the step count nt. Take the LAST state
    # ([-1]); its first component is the transported x.
    return Integrator(FlowEq, (x0, p0, q0), nt, deltat)[-1][0]


def LDDMMloss(K, dataloss, gamma=0):
    """Build the LDDMM objective as a function of initial momenta and points.

    ``loss(p0, q0) = gamma * H(p0, q0) + dataloss(q1)``, where ``q1`` is the
    final position obtained by geodesic shooting from ``(p0, q0)``. With the
    default ``gamma=0`` only the data attachment term contributes.
    """
    hamiltonian = Hamiltonian(K)

    def loss(p0, q0):
        _, q_final = Shooting(p0, q0, K)[-1]
        regularity = gamma * hamiltonian(p0, q0)
        return regularity + dataloss(q_final)

    return loss

Data attachment term

Varifold data attachment loss for surfaces

# VT: vertices coordinates of target surface,
# FS,FT : Face connectivity of source and target surfaces
# K kernel
def lossVarifoldSurf(FS, VT, FT, K):
    """Build a varifold data-attachment loss between a source and a target surface.

    Each triangle is represented by its center, its area and its unit normal.
    The returned loss is ``<T, T> + <S, S> - 2 <S, T>``, where
    ``<A, B> = sum_i area_i * K(c_i, c_j, n_i, n_j, area_j)`` (summed over i).

    Args:
        FS: (n_faces_S, 3) integer face connectivity of the source surface.
        VT: (n_vertices_T, 3) vertex coordinates of the target surface.
        FT: (n_faces_T, 3) integer face connectivity of the target surface.
        K: kernel called as ``K(x, y, u, v, b)``.

    Returns:
        A function ``loss(VS)`` of the source vertex coordinates.
    """
    def get_center_length_normal(F, V):
        V0, V1, V2 = (V.index_select(0, F[:, c]) for c in range(3))
        centers = (V0 + V1 + V2) / 3
        # Cross product along the coordinate axis. Without an explicit dim,
        # torch.cross uses the first axis of size 3, which is the face axis
        # whenever the mesh has exactly 3 faces.
        normals = .5 * torch.cross(V1 - V0, V2 - V0, dim=1)
        # Half the norm of the edge cross product is the triangle area.
        length = (normals ** 2).sum(dim=1)[:, None].sqrt()
        return centers, length, normals / length

    CT, LT, NTn = get_center_length_normal(FT, VT)
    # The target self-term does not depend on VS: compute it once.
    cst = (LT * K(CT, CT, NTn, NTn, LT)).sum()

    def loss(VS):
        CS, LS, NSn = get_center_length_normal(FS, VS)
        return cst + (LS * K(CS, CS, NSn, NSn, LS)).sum() - 2 * (LS * K(CS, CT, NSn, NTn, LT)).sum()

    return loss

Registration

Load the dataset and plot it

VS, FS, VT, FT = torch.load(datafile)

# Move everything to the working device/dtype. The source vertices q0 are
# tracked by autograd; face arrays are integer vertex indices.
q0 = VS.clone().detach().to(dtype=torchdtype, device=torchdeviceId).requires_grad_(True)
VT = VT.clone().detach().to(dtype=torchdtype, device=torchdeviceId)
FS, FT = (F.clone().detach().to(dtype=torch.long, device=torchdeviceId) for F in (FS, FT))
sigma = torch.tensor([20], dtype=torchdtype, device=torchdeviceId)

# Split coordinates and face indices into per-column NumPy arrays for Plotly.
x, y, z = q0.detach().cpu().numpy().T
i, j, k = FS.detach().cpu().numpy().T

xt, yt, zt = VT.detach().cpu().numpy().T
it, jt, kt = FT.detach().cpu().numpy().T

save_folder = '../../../doc/_build/html/_images/'
os.makedirs(save_folder, exist_ok=True)

# Target surface in blue, source surface in red, both half transparent.
target_mesh = go.Mesh3d(x=xt, y=yt, z=zt, i=it, j=jt, k=kt, color='blue', opacity=0.50)
source_mesh = go.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, color='red', opacity=0.50)
fig = go.Figure(data=[target_mesh, source_mesh])
fig.write_html(os.path.join(save_folder, 'data.html'), auto_open=False)
# sphinx_gallery_thumbnail_path = '_static/plot_LDDMM_Surface_thumb.png'

Define data attachment and LDDMM functional

# Deformation kernel: Gaussian of scale sigma.
Kv = GaussKernel(sigma=sigma)
# Data term: varifold distance with a Gaussian x Cauchy-Binet kernel of the same scale.
dataloss = lossVarifoldSurf(FS, VT, FT, GaussLinKernel(sigma=sigma))
loss = LDDMMloss(Kv, dataloss)

Perform optimization

# Initial momenta are optimized; start from zero (no deformation).
p0 = torch.zeros(q0.shape, dtype=torchdtype, device=torchdeviceId, requires_grad=True)

optimizer = torch.optim.LBFGS([p0], max_eval=10, max_iter=10)
print('performing optimization...')
start = time.time()


def closure():
    """Re-evaluate the LDDMM loss and its gradient w.r.t. p0 for L-BFGS."""
    optimizer.zero_grad()
    current = loss(p0, q0)
    print('loss', current.detach().cpu().numpy())
    current.backward()
    return current


for i in range(10):
    # Each outer step runs L-BFGS with at most max_iter / max_eval = 10 evaluations.
    print('it ', i, ': ', end='')
    optimizer.step(closure)

elapsed = round(time.time() - start, 2)
print('Optimization (L-BFGS) time: ', elapsed, ' seconds')

Out:

performing optimization...
it  0 : loss 87667.19
loss 81357.81
loss 21514.375
loss 15458.375
loss 9952.25
loss 7281.875
loss 4681.0
loss 4017.125
loss 3976.875
loss 3928.375
it  1 : loss 3928.375
loss 3887.625
loss 3773.5
loss 3567.875
loss 3267.875
loss 2933.5
loss 2562.625
loss 2287.5
loss 1969.375
loss 1708.125
it  2 : loss 1708.125
loss 1504.0
loss 1403.5
loss 1321.0
loss 1244.25
loss 1163.875
loss 1087.0
loss 1056.75
loss 1041.625
loss 1027.375
it  3 : loss 1027.375
loss 1014.25
loss 992.0
loss 950.25
loss 877.75
loss 836.0
loss 776.75
loss 733.5
loss 717.25
loss 701.5
it  4 : loss 701.5
loss 687.25
loss 674.75
loss 654.0
loss 616.625
loss 586.0
loss 567.0
loss 555.25
loss 549.0
loss 544.0
it  5 : loss 544.0
loss 536.0
loss 525.5
loss 515.375
loss 491.25
loss 463.875
loss 442.75
loss 434.0
loss 429.0
loss 425.75
it  6 : loss 425.75
loss 424.0
loss 422.375
loss 419.5
loss 415.75
loss 408.625
loss 399.75
loss 393.625
loss 391.625
loss 390.125
it  7 : loss 390.125
loss 389.25
loss 387.625
loss 385.5
loss 381.25
loss 372.625
loss 356.5
loss 336.5
loss 318.625
loss 315.875
it  8 : loss 315.875
loss 321.125
loss 311.375
loss 311.25
loss 310.625
loss 310.375
loss 310.25
loss 309.75
loss 309.375
loss 308.5
it  9 : loss 308.5
loss 307.25
loss 305.5
loss 304.0
loss 302.375
loss 300.25
loss 298.75
loss 297.5
loss 297.125
loss 296.75
Optimization (L-BFGS) time:  29.42  seconds

Display output

The animated version of the deformation:

# Re-shoot the optimized momenta with more time steps for display;
# listpq collects every (p, q) state along the trajectory.
nt = 15
listpq = Shooting(p0, q0, K=Kv, nt=nt)

The code to generate the figure:

VTnp, FTnp = VT.detach().cpu().numpy(), FT.detach().cpu().numpy()
q0np, FSnp = q0.detach().cpu().numpy(), FS.detach().cpu().numpy()

# Create figure; trace 0 is the target surface and stays visible at every step.
fig = go.Figure()
fig.add_trace(
        go.Mesh3d(
            visible=True,
            x=VTnp[:, 0], y=VTnp[:, 1], z=VTnp[:, 2],
            i=FTnp[:, 0], j=FTnp[:, 1], k=FTnp[:, 2],
        )
)

# Add one trace per slider step. listpq holds nt + 1 states (t = 0 ... nt);
# iterate over all of them so the final, registered shape is included.
for state in listpq:
    qnp = state[1].detach().cpu().numpy()
    fig.add_trace(
        go.Mesh3d(
            visible=False,
            x=qnp[:, 0], y=qnp[:, 1], z=qnp[:, 2],
            i=FSnp[:, 0], j=FSnp[:, 1], k=FSnp[:, 2],
            )
    )

# Show the first time step (initial source shape) alongside the target.
fig.data[1].visible = True

# Create and add slider: step i shows the target (trace 0) and trace i + 1,
# labelled with its time i / nt (shooting runs from time 0 to 1).
steps = []
for i in range(len(fig.data) - 1):
    visible = [False] * len(fig.data)
    visible[0] = True
    visible[i + 1] = True
    steps.append(dict(
        method="restyle",
        args=["visible", visible],
        label=f"{i / nt:.2f}",
    ))

sliders = [dict(
    active=0,
    currentvalue={"prefix": "time: "},
    pad={"t": 20},
    steps=steps
)]

fig.update_layout(
    sliders=sliders
)

fig.write_html(os.path.join(save_folder, "results.html"), auto_open=False)

Total running time of the script: ( 0 minutes 30.296 seconds)

Gallery generated by Sphinx-Gallery