Surface registration

Example of a diffeomorphic matching of surfaces using varifold metrics. We perform an LDDMM matching of two meshes using the geodesic shooting algorithm.

Define our dataset

Standard imports

import os
import time

import imageio
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from mpl_toolkits.mplot3d import Axes3D
from torch.autograd import grad

from pykeops.torch import Kernel, kernel_product
from pykeops.torch.kernel_product.formula import *

import torch  # explicit: do not rely on the ``formula import *`` above to provide ``torch``

# torch type and device
use_cuda = torch.cuda.is_available()
# Build a torch.device in both branches so that ``.index`` below is meaningful:
# on the plain string 'cpu', ``.index`` would be the ``str.index`` method.
torchdeviceId = torch.device('cuda:0') if use_cuda else torch.device('cpu')
torchdtype = torch.float32

# PyKeOps counterpart
KeOpsdeviceId = torchdeviceId.index  # id of GPU device (None when running on CPU)
KeOpsdtype = str(torchdtype).split('.')[-1]  # 'float32'

Select the data file, one of:

  • “hippos.pt” : original data (6611 vertices),

  • “hippos_red.pt” : reduced size (1654 vertices),

  • “hippos_reduc.pt” : further reduced (662 vertices),

  • “hippos_reduc_reduc.pt” : further reduced (68 vertices)

# The full-resolution mesh (6611 vertices) is used on GPU only; on CPU fall
# back to the smallest version (68 vertices) to keep the run time reasonable.
datafile = 'data/hippos.pt' if use_cuda else 'data/hippos_reduc_reduc.pt'

Define the kernels

Define the Gaussian kernel \((K(x,y)b)_i = \sum_j \exp(-\gamma\|x_i-y_j\|^2)b_j\), with \(\gamma = 1/\sigma^2\).

def GaussKernel(sigma):
    """Return a Gaussian kernel operator of bandwidth ``sigma``.

    The returned callable ``K(x, y, b)`` evaluates
    ``(K(x, y) b)_i = sum_j exp(-||x_i - y_j||^2 / sigma^2) b_j``
    through PyKeOps' ``kernel_product`` (backend chosen automatically).
    """
    def K(x, y, b):
        kernel_params = dict(
            id=Kernel('gaussian(x,y)'),
            gamma=1 / (sigma * sigma),  # gamma = 1 / sigma^2
            backend='auto',
        )
        return kernel_product(kernel_params, x, y, b)

    return K

Define the “Gaussian-CauchyBinet” kernel \((K(x,y,u,v)b)_i = \sum_j \exp(-\gamma\|x_i-y_j\|^2) \langle u_i,v_j\rangle^2 b_j\), with \(\gamma = 1/\sigma^2\).

def GaussLinKernel(sigma):
    """Return the "Gaussian-CauchyBinet" kernel operator of bandwidth ``sigma``.

    The returned callable ``K(x, y, u, v, b)`` evaluates
    ``(K b)_i = sum_j exp(-||x_i - y_j||^2 / sigma^2) <u_i, v_j>^2 b_j``:
    a Gaussian on positions times a squared linear kernel on directions.
    """
    def K(x, y, u, v, b):
        kernel_params = dict(
            id=Kernel('gaussian(x,y) * linear(u,v)**2'),
            # one gamma per factor; the linear kernel takes no parameter
            gamma=(1 / (sigma * sigma), None),
            backend='auto',
        )
        return kernel_product(kernel_params, (x, u), (y, v), b)

    return K

Custom ODE solver (Ralston's second-order Runge–Kutta scheme) for ODE systems whose state is a tuple of tensors.

def RalstonIntegrator():
    """Return an explicit Ralston (2nd-order Runge-Kutta) integrator for tuple-valued ODEs.

    The returned function ``f(ODESystem, x0, nt, deltat=1.0)`` integrates
    ``dx/dt = ODESystem(*x)`` from t = 0 to t = ``deltat`` in ``nt`` equal steps,
    where ``x0`` is a tuple of tensors. It returns the list of the ``nt + 1``
    visited states, starting with a clone of ``x0``.
    """
    def f(ODESystem, x0, nt, deltat=1.0):
        dt = deltat / nt
        state = tuple(component.clone() for component in x0)
        trajectory = [state]
        for _ in range(nt):
            k1 = ODESystem(*state)
            # intermediate stage evaluated at 2/3 of the step
            stage = tuple(s + (2 * dt / 3) * d for s, d in zip(state, k1))
            k2 = ODESystem(*stage)
            # Ralston weights: 1/4 for k1, 3/4 for k2
            state = tuple(
                s + (.25 * dt) * (d1 + 3 * d2) for s, d1, d2 in zip(state, k1, k2)
            )
            trajectory.append(state)
        return trajectory

    return f

LDDMM implementation

Deformations: diffeomorphism

Hamiltonian system

def Hamiltonian(K):
    """Return the LDDMM Hamiltonian ``H(p, q) = 1/2 * <p, K(q, q) p>`` for kernel ``K``."""
    def H(p, q):
        Kp = K(q, q, p)
        return (p * Kp).sum() * .5
    return H


def HamiltonianSystem(K):
    """Return the vector field of Hamilton's equations for the Hamiltonian built from ``K``.

    The returned ``HS(p, q)`` yields ``(dp/dt, dq/dt) = (-dH/dq, dH/dp)``. The
    gradients are taken with ``create_graph=True`` so that the integrated
    trajectory stays differentiable (the loss is back-propagated through it).
    """
    H = Hamiltonian(K)

    def HS(p, q):
        dH_dp, dH_dq = grad(H(p, q), (p, q), create_graph=True)
        return -dH_dq, dH_dp

    return HS

Shooting approach

def Shooting(p0, q0, K, nt=10, Integrator=RalstonIntegrator()):
    """Integrate the geodesic (Hamiltonian) equations from the initial state ``(p0, q0)``.

    Returns whatever ``Integrator`` returns; with the default Ralston integrator
    this is the list of the ``nt + 1`` ``(p, q)`` states over unit time.
    """
    system = HamiltonianSystem(K)
    initial_state = (p0, q0)
    return Integrator(system, initial_state, nt)


def Flow(x0, p0, q0, K, deltat=1.0, Integrator=RalstonIntegrator(), nt=10):
    """Transport the points ``x0`` along the deformation generated by ``(p0, q0)``.

    The points ``x`` follow ``dx/dt = K(x, q, p)`` while ``(p, q)`` evolve by the
    Hamiltonian system of ``K``, from t = 0 to t = ``deltat`` in ``nt`` steps.

    Args:
        x0: points to transport.
        p0, q0: initial momenta and control points.
        K: deformation kernel ``K(x, y, b)``.
        deltat: final integration time.
        Integrator: ODE integrator ``f(ODESystem, x0, nt, deltat)``.
        nt: number of integration steps (new, defaults to 10 like ``Shooting``).

    Returns:
        The transported points at time ``deltat``.
    """
    HS = HamiltonianSystem(K)

    def FlowEq(x, p, q):
        return (K(x, q, p),) + HS(p, q)

    # ``deltat`` is the integrator's time horizon (4th argument), not its number
    # of steps. The final state is the last trajectory entry; x is its first part.
    return Integrator(FlowEq, (x0, p0, q0), nt, deltat)[-1][0]


def LDDMMloss(K, dataloss, gamma=0):
    """Build the LDDMM objective ``gamma * H(p0, q0) + dataloss(q_final)``.

    ``q_final`` is the endpoint of the geodesic shot from ``(p0, q0)`` with kernel
    ``K``; ``gamma`` weights the Hamiltonian (deformation energy) term and
    defaults to 0, i.e. no regularization.
    """
    def loss(p0, q0):
        _, q_final = Shooting(p0, q0, K)[-1]
        energy = Hamiltonian(K)(p0, q0)
        return gamma * energy + dataloss(q_final)
    return loss

Data attachment term

Varifold data attachment loss for surfaces

# VT: vertices coordinates of target surface,
# FS,FT : Face connectivity of source and target surfaces
# K kernel
def lossVarifoldSurf(FS, VT, FT, K):
    """Build the varifold data-attachment loss towards a fixed target surface.

    Each triangle is summarised by a (scaled) center C, its area L and its unit
    normal N / L. The loss is the squared varifold distance
    ``<S, S> - 2 <S, T> + <T, T>`` evaluated with the kernel ``K``.

    Args:
        FS: integer face indices of the source surface, three vertex ids per row.
        VT: vertex coordinates of the target surface (3D points).
        FT: integer face indices of the target surface, three vertex ids per row.
        K: kernel ``K(x, y, u, v, b)`` on (center, normal) pairs, e.g. ``GaussLinKernel``.

    Returns:
        A function ``loss(VS)`` of the source vertex coordinates.
    """
    def CompCLNn(F, V):
        V0, V1, V2 = V.index_select(0, F[:, 0]), V.index_select(0, F[:, 1]), V.index_select(0, F[:, 2])
        # NOTE: .5 * (V0 + V1 + V2) is 1.5x the true centroid. Source and target use
        # the same convention, so this only rescales the effective kernel width;
        # kept as-is to preserve the tutorial's results.
        C = .5 * (V0 + V1 + V2)
        # Explicit dim=1: without it torch.cross picks the first dimension of size 3,
        # which is the face axis whenever a mesh has exactly 3 faces.
        N = .5 * torch.cross(V1 - V0, V2 - V0, dim=1)
        L = (N ** 2).sum(dim=1)[:, None].sqrt()  # triangle areas, as a column
        return C, L, N / L

    # The target self-interaction term does not depend on VS: compute it once.
    CT, LT, NTn = CompCLNn(FT, VT)
    cst = (LT * K(CT, CT, NTn, NTn, LT)).sum()

    def loss(VS):
        CS, LS, NSn = CompCLNn(FS, VS)
        return cst + (LS * K(CS, CS, NSn, NSn, LS)).sum() - 2 * (LS * K(CS, CT, NSn, NTn, LT)).sum()

    return loss

Registration

Load the dataset and plot it

VS, FS, VT, FT = torch.load(datafile)
# Source vertices are the control points q0. They must require grad because the
# Hamiltonian system differentiates with respect to q.
q0 = VS.clone().detach().to(dtype=torchdtype, device=torchdeviceId).requires_grad_(True)
# Target vertices are fixed data.
VT = VT.clone().detach().to(dtype=torchdtype, device=torchdeviceId)
# Face connectivity is used as an index tensor (index_select), hence int64.
FS, FT = (F.clone().detach().to(dtype=torch.long, device=torchdeviceId) for F in (FS, FT))
# Kernel bandwidth, shared by the deformation and data-attachment kernels below.
sigma = torch.tensor([20], dtype=torchdtype, device=torchdeviceId)

fig = plt.figure()
# ``Axes3D(fig)`` no longer attaches itself to the figure in recent Matplotlib,
# and ``w_xaxis``/``w_yaxis``/``w_zaxis`` were removed; ``add_subplot`` with
# ``projection='3d'`` and ``xaxis``/``yaxis``/``zaxis`` work across versions.
ax = fig.add_subplot(111, projection='3d')
for pane_axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    pane_axis.set_pane_color((1.0, 1.0, 1.0, 1.0))

# Convert each tensor to numpy once instead of once per coordinate.
q0_np = q0.detach().cpu().numpy()
VT_np = VT.detach().cpu().numpy()
ax.plot_trisurf(q0_np[:, 0], q0_np[:, 1], q0_np[:, 2],
                triangles=FS.detach().cpu().numpy(),
                color=(0, 0, 0, 0), edgecolor=(1, 0, 0, .08), linewidth=1)
ax.plot_trisurf(VT_np[:, 0], VT_np[:, 1], VT_np[:, 2],
                triangles=FT.detach().cpu().numpy(),
                color=(0, 0, 0, 0), edgecolor=(0, 0, 1, .3), linewidth=1)

# Proxy artists for the legend (surface plots have no legend handles of their own).
blue_proxy = plt.Rectangle((0, 0), 1, 1, fc="b")
red_proxy = plt.Rectangle((0, 0), 1, 1, fc=(1, 0, 0, .5))
ax.legend([red_proxy, blue_proxy], ['source', 'target'])
ax.set_title('Data')
plt.show()
../../_images/sphx_glr_plot_LDDMM_Surface_001.png

Out:

/home/bcharlier/keops/pykeops/tutorials/surface_registration/plot_LDDMM_Surface.py:208: UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.
  plt.show()

Define data attachment and LDDMM functional

# The deformation kernel and the varifold data term share the same bandwidth sigma.
Kv = GaussKernel(sigma=sigma)
dataloss = lossVarifoldSurf(FS, VT, FT, GaussLinKernel(sigma=sigma))
# gamma is left at its default value (0): no Hamiltonian regularization term.
loss = LDDMMloss(Kv, dataloss)

Perform optimization

# Initial momenta are zero: the optimization starts from the identity deformation.
p0 = torch.zeros(q0.shape, dtype=torchdtype, device=torchdeviceId, requires_grad=True)

optimizer = torch.optim.LBFGS([p0], max_eval=10, max_iter=10)
print('performing optimization...')
start = time.time()


def closure():
    """Re-evaluate the loss and its gradient w.r.t. p0, as L-BFGS requires."""
    optimizer.zero_grad()
    current = loss(p0, q0)
    print('loss', current.detach().cpu().numpy())
    current.backward()
    return current


# Each outer step runs up to 10 L-BFGS iterations (max_iter above).
for i in range(10):
    print('it ', i, ': ', end='')
    optimizer.step(closure)

print('Optimization (L-BFGS) time: ', round(time.time() - start, 2), ' seconds')

Out:

performing optimization...
it  0 : loss 91181.125
loss 84039.875
loss 26011.312
loss 19569.0
loss 16383.125
loss 14535.25
loss 10051.125
loss 9112.75
loss 7602.375
loss 7533.0625
it  1 : loss 7533.0625
loss 7459.625
loss 7357.3125
loss 7194.625
loss 6773.4375
loss 6231.25
loss 5671.375
loss 5335.125
loss 5146.5
loss 4826.75
it  2 : loss 4826.75
loss 4333.125
loss 3663.125
loss 3324.4375
loss 2898.125
loss 2757.625
loss 2679.875
loss 2519.125
loss 2275.3125
loss 2152.6875
it  3 : loss 2152.6875
loss 2099.875
loss 2072.875
loss 2033.8125
loss 1959.0625
loss 1833.375
loss 1819.75
loss 1708.5625
loss 1678.8125
loss 1638.3125
it  4 : loss 1638.3125
loss 1578.375
loss 1502.25
loss 1420.125
loss 1365.125
loss 1349.1875
loss 1339.25
loss 1320.875
loss 1260.3125
loss 1208.75
it  5 : loss 1208.75
loss 1181.375
loss 1167.9375
loss 1152.125
loss 1139.0625
loss 1129.5625
loss 1123.0
loss 1116.375
loss 1105.0
loss 1073.4375
it  6 : loss 1073.4375
loss 1031.3125
loss 1012.0625
loss 996.5
loss 993.375
loss 989.8125
loss 988.75
loss 986.5625
loss 984.6875
loss 980.625
it  7 : loss 980.625
loss 972.4375
loss 962.375
loss 952.375
loss 945.75
loss 941.8125
loss 939.5
loss 936.75
loss 931.0
loss 918.9375
it  8 : loss 918.9375
loss 898.875
loss 875.625
loss 857.25
loss 845.75
loss 839.75
loss 836.875
loss 832.0
loss 827.375
loss 824.9375
it  9 : loss 824.9375
loss 822.125
loss 819.625
loss 814.125
loss 805.875
loss 793.1875
loss 780.8125
loss 772.75
loss 768.625
loss 766.0
Optimization (L-BFGS) time:  215.4  seconds

Display output

The animated version of the deformation:

nt = 15  # number of time steps used to render the animation
listpq = Shooting(p0=p0, q0=q0, K=Kv, nt=nt)

The code to generate the .gif:

VTnp, FTnp = VT.detach().cpu().numpy(), FT.detach().cpu().numpy()
q0np, FSnp = q0.detach().cpu().numpy(), FS.detach().cpu().numpy()


images = []
# Shooting returns nt + 1 (p, q) states, the initial one included. Iterate over
# all of them so the final, fully deformed surface is part of the animation.
for t, (_, q_t) in enumerate(listpq):
    qnp = q_t.detach().cpu().numpy()

    # create Figure
    fig = Figure(figsize=(6, 5), dpi=100)
    # Link canvas to fig
    canvas = FigureCanvasAgg(fig)

    # make the plot (``add_subplot(projection='3d')`` works on every Matplotlib
    # version, unlike ``Axes3D(fig)`` and the removed ``w_*axis`` attributes)
    ax = fig.add_subplot(111, projection='3d')
    for pane_axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        pane_axis.set_pane_color((1.0, 1.0, 1.0, 1.0))
    ax.plot_trisurf(qnp[:, 0], qnp[:, 1], qnp[:, 2], triangles=FSnp, color=(1, 1, 0, .5), edgecolor=(1, 1, 1, .3), linewidth=1)
    ax.plot_trisurf(VTnp[:, 0], VTnp[:, 1], VTnp[:, 2], triangles=FTnp, color=(0, 0, 0, 0), edgecolor=(0, 0, 1, .3), linewidth=1)

    # Legend proxies are built here so this cell does not depend on the data-plot cell.
    yellow_proxy = plt.Rectangle((0, 0), 1, 1, fc="y")
    blue_proxy = plt.Rectangle((0, 0), 1, 1, fc="b")
    ax.legend([yellow_proxy, blue_proxy], ['deformed', 'target'])
    ax.set_title('LDDMM matching example, step ' + str(t))

    # draw it!
    canvas.draw()

    # save plot in a numpy array through buffer (RGBA, one byte per channel)
    s, (width, height) = canvas.print_to_buffer()
    images.append(np.frombuffer(s, np.uint8).reshape((height, width, 4)))

save_folder = '../../../doc/_build/html/_images/'
os.makedirs(save_folder, exist_ok=True)
imageio.mimsave(save_folder + 'surface_matching.gif', images, duration=.5)

Total running time of the script: ( 4 minutes 9.253 seconds)

Gallery generated by Sphinx-Gallery