Surface registration

Example of a diffeomorphic matching of surfaces using varifold metrics. We perform an LDDMM matching of two meshes using the geodesic shooting algorithm.

Define our dataset

Standard imports

import os
import time

import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from mpl_toolkits.mplot3d import Axes3D
from torch.autograd import grad

from pykeops.torch import Kernel, kernel_product
from pykeops.torch.kernel_product.formula import *

# torch type and device
use_cuda = torch.cuda.is_available()
# Always build a real torch.device: a bare 'cpu' string would make
# `torchdeviceId.index` below resolve to the bound method `str.index`
# instead of a device id.
torchdeviceId = torch.device('cuda:0' if use_cuda else 'cpu')
torchdtype = torch.float32

# PyKeOps counterpart
KeOpsdeviceId = torchdeviceId.index  # GPU id (0) when CUDA is used, None on CPU
KeOpsdtype = str(torchdtype).split('.')[1]  # 'float32'

Choose the data file to load, one of:

  • “hippos.pt” : original data (6611 vertices),

  • “hippos_red.pt” : reduced size (1654 vertices),

  • “hippos_reduc.pt” : further reduced (662 vertices),

  • “hippos_reduc_reduc.pt” : further reduced (68 vertices)

# Use the full-resolution mesh when a GPU is available, and the smallest
# (68-vertex) one otherwise.
datafile = 'data/hippos.pt' if use_cuda else 'data/hippos_reduc_reduc.pt'

Define the kernels

Define the Gaussian kernel \((K(x,y)b)_i = \sum_j \exp(-\gamma\|x_i-y_j\|^2)b_j\), with \(\gamma = 1/\sigma^2\)

def GaussKernel(sigma):
    """Return a Gaussian kernel-product operator of bandwidth ``sigma``.

    The returned callable ``K(x, y, b)`` evaluates, through PyKeOps'
    ``kernel_product``,
    ``(K(x, y) b)_i = sum_j exp(-gamma * ||x_i - y_j||^2) b_j`` with
    ``gamma = 1 / sigma**2``.
    """
    def K(x, y, b):
        # gamma is recomputed at every call, exactly as sigma currently reads.
        kernel_params = dict(
            id=Kernel('gaussian(x,y)'),
            gamma=1 / (sigma * sigma),
            backend='auto',
        )
        return kernel_product(kernel_params, x, y, b)

    return K

Define the “Gaussian–Cauchy-Binet” kernel \((K(x,y,u,v)b)_i = \sum_j \exp(-\gamma\|x_i-y_j\|^2) \langle u_i,v_j\rangle^2 b_j\)

def GaussLinKernel(sigma):
    """Return the Gaussian x squared-linear ("Cauchy-Binet") kernel operator.

    The returned callable ``K(x, y, u, v, b)`` evaluates, through PyKeOps'
    ``kernel_product``,
    ``(K b)_i = sum_j exp(-gamma * ||x_i - y_j||^2) * <u_i, v_j>**2 * b_j``
    with ``gamma = 1 / sigma**2``. Points ``x``/``y`` are paired with the
    vectors ``u``/``v`` respectively.
    """
    def K(x, y, u, v, b):
        kernel_params = dict(
            id=Kernel('gaussian(x,y) * linear(u,v)**2'),
            # One gamma per factor; the linear factor takes no bandwidth.
            gamma=(1 / (sigma * sigma), None),
            backend='auto',
        )
        sources, targets = (x, u), (y, v)
        return kernel_product(kernel_params, sources, targets, b)

    return K

Custom ODE solver (Ralston's second-order Runge–Kutta scheme) for ODE systems whose state is a tuple of tensors

def RalstonIntegrator():
    """Return an explicit Ralston (second-order Runge-Kutta) integrator.

    The returned ``f(ODESystem, x0, nt, deltat=1.0)`` integrates
    ``dx/dt = ODESystem(*x)`` over ``[0, deltat]`` with ``nt`` uniform steps,
    where the state ``x`` is a tuple of tensors and ``ODESystem`` returns a
    tuple of derivatives of matching length. It returns the list of the
    ``nt + 1`` visited states, the first one being a clone of ``x0``.
    """
    def f(ODESystem, x0, nt, deltat=1.0):
        state = tuple(z.clone() for z in x0)
        dt = deltat / nt
        trajectory = [state]
        for _ in range(nt):
            k1 = ODESystem(*state)
            # Ralston: probe the slope at 2/3 of the step...
            probe = tuple(z + (2 * dt / 3) * dz for z, dz in zip(state, k1))
            k2 = ODESystem(*probe)
            # ...then combine the two slopes with weights 1/4 and 3/4.
            state = tuple(
                z + (.25 * dt) * (dz1 + 3 * dz2)
                for z, dz1, dz2 in zip(state, k1, k2)
            )
            trajectory.append(state)
        return trajectory

    return f

LDDMM implementation

Deformations: diffeomorphism

Hamiltonian system

def Hamiltonian(K):
    """Return the LDDMM Hamiltonian ``H(p, q) = 1/2 * sum(p * K(q, q, p))``.

    ``K(x, y, b)`` is a kernel-product operator; ``p`` are the momenta
    attached to the points ``q``.
    """
    def H(p, q):
        velocity = K(q, q, p)
        return .5 * (p * velocity).sum()

    return H


def HamiltonianSystem(K):
    """Return the right-hand side of Hamilton's equations for ``Hamiltonian(K)``.

    The returned ``HS(p, q)`` gives ``(-dH/dq, dH/dp)``, i.e. the time
    derivatives of ``(p, q)``. Gradients are taken with ``create_graph=True``
    so that the result can itself be differentiated (needed when the loss is
    back-propagated through the shooting).
    """
    H = Hamiltonian(K)

    def HS(p, q):
        dH_dp, dH_dq = grad(H(p, q), (p, q), create_graph=True)
        p_dot, q_dot = -dH_dq, dH_dp
        return p_dot, q_dot

    return HS

Shooting approach

def Shooting(p0, q0, K, nt=10, Integrator=RalstonIntegrator()):
    """Integrate Hamilton's equations from the initial state ``(p0, q0)``.

    Uses ``nt`` time steps and the integrator's default time horizon
    (``deltat=1.0`` for ``RalstonIntegrator``). Returns the integrator's
    trajectory, a list of ``(p, q)`` tuples starting at the initial state.
    """
    initial_state = (p0, q0)
    system = HamiltonianSystem(K)
    return Integrator(system, initial_state, nt)


def Flow(x0, p0, q0, K, deltat=1.0, Integrator=RalstonIntegrator(), nt=10):
    """Transport arbitrary points ``x0`` along the geodesic shot from ``(p0, q0)``.

    The points follow the velocity field ``K(x, q, p)`` generated by the
    control points ``q`` and momenta ``p``, which themselves evolve under
    Hamilton's equations.

    Args:
        x0: points to transport.
        p0, q0: initial momenta and control points.
        K: kernel-product operator ``K(x, y, b)``.
        deltat: final integration time.
        Integrator: integrator with signature ``(ODESystem, x0, nt, deltat)``.
        nt: number of time steps (new, keyword with default for compatibility).

    Returns:
        The positions of ``x0`` at time ``deltat``.
    """
    HS = HamiltonianSystem(K)

    def FlowEq(x, p, q):
        return (K(x, q, p),) + HS(p, q)

    # The integrator's third positional argument is the number of steps, so
    # pass nt and deltat explicitly; then take the x-component of the *last*
    # state of the trajectory (index 0 would be the untouched initial state).
    return Integrator(FlowEq, (x0, p0, q0), nt, deltat)[-1][0]


def LDDMMloss(K, dataloss, gamma=0):
    """Build the LDDMM objective ``gamma * H(p0, q0) + dataloss(q_final)``.

    ``q_final`` is the position of the control points at the end of the
    shooting from ``(p0, q0)``; ``gamma`` weights the kinetic-energy
    (regularity) term given by the Hamiltonian at time 0.
    """
    def loss(p0, q0):
        _, q_final = Shooting(p0, q0, K)[-1]
        regularity = gamma * Hamiltonian(K)(p0, q0)
        return regularity + dataloss(q_final)

    return loss

Data attachment term

Varifold data attachment loss for surfaces

def lossVarifoldSurf(FS, VT, FT, K):
    """Build the varifold data-attachment loss towards a fixed target surface.

    Each triangle is represented by its centroid, its area and its unit
    normal. The loss is the squared kernel (varifold) distance
    ``<S, S> - 2 <S, T> + <T, T>`` between source ``S`` and target ``T``.

    Args:
        FS: integer tensor of source faces, one row of 3 vertex indices per face.
        VT: target vertex coordinates (3D points, one per row).
        FT: integer tensor of target faces, one row of 3 vertex indices per face.
        K: kernel ``K(x, y, u, v, b)`` on (centre, normal) pairs, e.g. the
            one returned by ``GaussLinKernel``.

    Returns:
        ``loss(VS)``, mapping source vertex coordinates to a scalar tensor.
    """
    def CompCLNn(F, V):
        """Return face centroids, areas (as an (n, 1) column) and unit normals."""
        V0, V1, V2 = V.index_select(0, F[:, 0]), V.index_select(0, F[:, 1]), V.index_select(0, F[:, 2])
        C = (V0 + V1 + V2) / 3
        # Explicit dim=1: without it torch.cross picks the first axis of size 3,
        # which is the face axis whenever the mesh has exactly 3 faces.
        N = .5 * torch.cross(V1 - V0, V2 - V0, dim=1)
        L = (N ** 2).sum(dim=1)[:, None].sqrt()  # triangle areas
        return C, L, N / L

    CT, LT, NTn = CompCLNn(FT, VT)
    # Target self-term <T, T> does not depend on VS: compute it once.
    cst = (LT * K(CT, CT, NTn, NTn, LT)).sum()

    def loss(VS):
        CS, LS, NSn = CompCLNn(FS, VS)
        return cst + (LS * K(CS, CS, NSn, NSn, LS)).sum() - 2 * (LS * K(CS, CT, NSn, NTn, LT)).sum()

    return loss

Registration

Load the dataset and plot it

VS, FS, VT, FT = torch.load(datafile)
q0 = VS.clone().detach().to(dtype=torchdtype, device=torchdeviceId).requires_grad_(True)
VT = VT.clone().detach().to(dtype=torchdtype, device=torchdeviceId)
FS = FS.clone().detach().to(dtype=torch.long, device=torchdeviceId)
FT = FT.clone().detach().to(dtype=torch.long, device=torchdeviceId)
sigma = torch.tensor([20], dtype=torchdtype, device=torchdeviceId)

# NumPy copies used only for plotting (converted once rather than once per axis).
source_np, source_faces_np = q0.detach().cpu().numpy(), FS.detach().cpu().numpy()
target_np, target_faces_np = VT.detach().cpu().numpy(), FT.detach().cpu().numpy()

fig = plt.figure()
# `Axes3D(fig)` no longer adds itself to the figure in recent Matplotlib
# (the figure would stay empty) and the `w_*axis` aliases were removed:
# create the 3D axes through the projection API and use xaxis/yaxis/zaxis.
ax = fig.add_subplot(projection='3d')
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.set_pane_color((1.0, 1.0, 1.0, 1.0))
ax.plot_trisurf(source_np[:, 0], source_np[:, 1], source_np[:, 2],
                triangles=source_faces_np,
                color=(0, 0, 0, 0), edgecolor=(1, 0, 0, .08), linewidth=1)
ax.plot_trisurf(target_np[:, 0], target_np[:, 1], target_np[:, 2],
                triangles=target_faces_np,
                color=(0, 0, 0, 0), edgecolor=(0, 0, 1, .3), linewidth=1)
blue_proxy = plt.Rectangle((0, 0), 1, 1, fc="b")
red_proxy = plt.Rectangle((0, 0), 1, 1, fc=(1, 0, 0, .5))
ax.legend([red_proxy, blue_proxy], ['source', 'target'])
ax.set_title('Data')
plt.show()
../../_images/sphx_glr_plot_LDDMM_Surface_001.png

Out:

/home/bcharlier/keops/pykeops/tutorials/surface_registration/plot_LDDMM_Surface.py:208: UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.
  plt.show()

Define the data attachment term and the LDDMM functional

# Kernel generating the deformation velocity fields.
Kv = GaussKernel(sigma=sigma)
# Varifold attachment to the target, with the Gaussian x Cauchy-Binet kernel.
dataloss = lossVarifoldSurf(FS, VT, FT, GaussLinKernel(sigma=sigma))
loss = LDDMMloss(Kv, dataloss)

Perform optimization

# Initial momenta, one vector per source vertex; all zeros means that the
# first shooting leaves the source unchanged.
p0 = torch.zeros(q0.shape, dtype=torchdtype, device=torchdeviceId, requires_grad=True)

optimizer = torch.optim.LBFGS([p0], max_eval=10, max_iter=10)
print('performing optimization...')
start = time.time()


def closure():
    """Re-evaluate the LDDMM loss at the current ``p0`` and back-propagate it."""
    optimizer.zero_grad()
    current = loss(p0, q0)
    print('loss', current.detach().cpu().numpy())
    current.backward()
    return current


for i in range(10):
    print('it ', i, ': ', end='')
    optimizer.step(closure)

elapsed = round(time.time() - start, 2)
print('Optimization (L-BFGS) time: ', elapsed, ' seconds')

Out:

performing optimization...
it  0 : loss 91181.375
loss 84040.06
loss 26011.562
loss 19569.25
loss 16383.3125
loss 14535.3125
loss 10051.25
loss 9112.6875
loss 7602.5
loss 7533.25
it  1 : loss 7533.25
loss 7459.875
loss 7357.5625
loss 7194.8125
loss 6773.6875
loss 6232.0
loss 5672.625
loss 5336.4375
loss 5147.4375
loss 4827.375
it  2 : loss 4827.375
loss 4333.6875
loss 3663.625
loss 3326.0
loss 2900.0625
loss 2759.75
loss 2682.0625
loss 2521.625
loss 2276.625
loss 2153.4375
it  3 : loss 2153.4375
loss 2100.5
loss 2073.5625
loss 2034.375
loss 1959.625
loss 1834.0
loss 1782.3125
loss 1690.75
loss 1648.4375
loss 1602.9375
it  4 : loss 1602.9375
loss 1532.625
loss 1461.3125
loss 1371.5
loss 1320.75
loss 1297.25
loss 1278.125
loss 1244.0
loss 1218.6875
loss 1201.4375
it  5 : loss 1201.4375
loss 1188.1875
loss 1175.0625
loss 1164.5625
loss 1147.5625
loss 1135.625
loss 1116.125
loss 1084.125
loss 1042.0
loss 1015.5
it  6 : loss 1015.5
loss 1005.875
loss 1000.375
loss 997.8125
loss 995.25
loss 990.75
loss 982.1875
loss 971.3125
loss 958.5625
loss 949.1875
it  7 : loss 949.1875
loss 943.0625
loss 940.875
loss 939.0625
loss 934.625
loss 924.25
loss 906.0625
loss 880.25
loss 859.125
loss 843.625
it  8 : loss 843.625
loss 835.0625
loss 831.9375
loss 830.3125
loss 827.9375
loss 824.75
loss 821.8125
loss 815.875
loss 808.25
loss 797.5
it  9 : loss 797.5
loss 786.375
loss 775.8125
loss 770.8125
loss 768.625
loss 767.1875
loss 766.375
loss 765.5625
loss 764.0
loss 761.125
Optimization (L-BFGS) time:  149.76  seconds

Display output

The animated version of the deformation:

# Re-shoot the optimized momenta with more time steps (15 instead of the
# default 10) to get a smoother animation.
nt = 15
listpq = Shooting(p0, q0, K=Kv, nt=nt)

The code to generate the .gif:

VTnp, FTnp = VT.detach().cpu().numpy(), FT.detach().cpu().numpy()
q0np, FSnp = q0.detach().cpu().numpy(), FS.detach().cpu().numpy()

# Legend handles, defined here so that this cell does not rely on the ones
# created when plotting the data.
yellow_proxy = plt.Rectangle((0, 0), 1, 1, fc="y")
blue_proxy = plt.Rectangle((0, 0), 1, 1, fc="b")

images = []
# `listpq` holds nt + 1 states (the initial one plus one per time step);
# iterate over all of them so the final, matched shape is part of the gif.
for t, (_, q) in enumerate(listpq):
    qnp = q.detach().cpu().numpy()

    # create Figure and link an Agg canvas to it (no GUI backend needed)
    fig = Figure(figsize=(6, 5), dpi=100)
    canvas = FigureCanvasAgg(fig)

    # make the plot; `Axes3D(fig)` no longer attaches itself to the figure in
    # recent Matplotlib and the `w_*axis` aliases were removed.
    ax = fig.add_subplot(projection='3d')
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color((1.0, 1.0, 1.0, 1.0))
    ax.plot_trisurf(qnp[:, 0], qnp[:, 1], qnp[:, 2], triangles=FSnp,
                    color=(1, 1, 0, .5), edgecolor=(1, 1, 1, .3), linewidth=1)
    ax.plot_trisurf(VTnp[:, 0], VTnp[:, 1], VTnp[:, 2], triangles=FTnp,
                    color=(0, 0, 0, 0), edgecolor=(0, 0, 1, .3), linewidth=1)
    ax.legend([yellow_proxy, blue_proxy], ['deformed', 'target'])
    ax.set_title('LDDMM matching example, step ' + str(t))

    # draw it, then copy the RGBA pixels out of the canvas buffer
    canvas.draw()
    s, (width, height) = canvas.print_to_buffer()
    images.append(np.frombuffer(s, np.uint8).reshape((height, width, 4)))

save_folder = os.path.join('..', '..', '..', 'doc', '_build', 'html', '_images')
os.makedirs(save_folder, exist_ok=True)
imageio.mimsave(os.path.join(save_folder, 'surface_matching.gif'), images, duration=.5)

Total running time of the script: ( 3 minutes 1.391 seconds)

Gallery generated by Sphinx-Gallery