Surface registration

Example of a diffeomorphic matching of surfaces using varifold metrics. We perform an LDDMM matching of two meshes using the geodesic shooting algorithm.

Define our dataset

Standard imports

import os
import time

import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from mpl_toolkits.mplot3d import Axes3D  # also registers the '3d' projection
from torch.autograd import grad

from pykeops.torch import Kernel, kernel_product
from pykeops.torch.kernel_product.formula import *

# torch type and device
use_cuda = torch.cuda.is_available()
# Use a torch.device in both branches: with a bare 'cpu' string, the
# `.index` lookup below would return the str.index method instead of None.
torchdeviceId = torch.device('cuda:0') if use_cuda else torch.device('cpu')
torchdtype = torch.float32

# PyKeOps counterpart
KeOpsdeviceId = torchdeviceId.index  # GPU id (0) when on CUDA, None on CPU
KeOpsdtype = str(torchdtype).split('.')[-1]  # 'float32'

Choose the data file to load, one of:

  • “hippos.pt” : original data (6611 vertices),

  • “hippos_red.pt” : reduced size (1654 vertices),

  • “hippos_reduc.pt” : further reduced (662 vertices),

  • “hippos_reduc_reduc.pt” : further reduced (68 vertices)

# Full-resolution mesh when a GPU is available; otherwise the smallest
# (68-vertex) version so that the CPU run stays short.
datafile = 'data/hippos.pt' if use_cuda else 'data/hippos_reduc_reduc.pt'

Define the kernels

Define the Gaussian kernel \((K(x,y)b)_i = \sum_j \exp(-\|x_i-y_j\|^2/\sigma^2)\,b_j\), where \(\sigma\) is the kernel width (passed to KeOps as \(\gamma = 1/\sigma^2\)).

def GaussKernel(sigma):
    """Build a Gaussian kernel-product operator of width ``sigma``.

    The returned callable ``K(x, y, b)`` forwards to KeOps ``kernel_product``
    with the ``'gaussian(x,y)'`` formula and ``gamma = 1 / sigma**2``
    (presumably exp(-gamma * |x_i - y_j|^2) summed against b_j — see the
    KeOps kernel_product documentation). ``backend='auto'`` lets KeOps choose
    the implementation.
    """
    def gauss_product(x, y, b):
        kernel_params = dict(
            id=Kernel('gaussian(x,y)'),
            gamma=1 / (sigma * sigma),
            backend='auto',
        )
        return kernel_product(kernel_params, x, y, b)

    return gauss_product

Define the “Gaussian–Cauchy-Binet” kernel \((K(x,y,u,v)b)_i = \sum_j \exp(-\|x_i-y_j\|^2/\sigma^2) \langle u_i,v_j\rangle^2 b_j\), which compares both positions \(x, y\) and unit normals \(u, v\).

def GaussLinKernel(sigma):
    """Build the Gaussian x squared-linear ("Cauchy-Binet") kernel operator.

    The returned callable ``K(x, y, u, v, b)`` passes the point pairs
    ``(x, u)`` and ``(y, v)`` to KeOps ``kernel_product`` with the formula
    ``'gaussian(x,y) * linear(u,v)**2'``. Only the Gaussian factor is scaled
    (``gamma = 1 / sigma**2``); the linear factor gets no gamma (``None``).
    """
    def gauss_lin_product(x, y, u, v, b):
        kernel_params = dict(
            id=Kernel('gaussian(x,y) * linear(u,v)**2'),
            gamma=(1 / (sigma * sigma), None),
            backend='auto',
        )
        sources, targets = (x, u), (y, v)
        return kernel_product(kernel_params, sources, targets, b)

    return gauss_lin_product

Custom ODE solver (Ralston's second-order Runge–Kutta method) for ODE systems whose state is a tuple of tensors.

def RalstonIntegrator():
    """Return a Ralston (second-order Runge-Kutta) integrator for tuple states.

    The returned function ``integrate(ODESystem, x0, nt, deltat=1.0)``
    integrates dx/dt = ODESystem(*x) from the tuple ``x0`` over a time span
    ``deltat`` using ``nt`` equal steps. ``ODESystem`` must return a tuple of
    derivatives matching the state. The result is the list of ``nt + 1``
    states: a clone of ``x0`` followed by the state after each step.
    """
    def integrate(ODESystem, x0, nt, deltat=1.0):
        state = tuple(v.clone() for v in x0)
        step = deltat / nt
        trajectory = [state]
        for _ in range(nt):
            # Slope at the start of the step, then at 2/3 of the way along it.
            k1 = ODESystem(*state)
            probe = tuple(v + (2 * step / 3) * d for v, d in zip(state, k1))
            k2 = ODESystem(*probe)
            # Weighted combination 1/4 * k1 + 3/4 * k2.
            state = tuple(
                v + (.25 * step) * (d1 + 3 * d2)
                for v, d1, d2 in zip(state, k1, k2)
            )
            trajectory.append(state)
        return trajectory

    return integrate

LDDMM implementation

Deformations: diffeomorphism

Hamiltonian system

def Hamiltonian(K):
    """Return the LDDMM Hamiltonian H(p, q) = 1/2 * sum(p * K(q, q, p))."""
    def energy(p, q):
        # Kernel applied to the momenta, with both point sets equal to q.
        Kp = K(q, q, p)
        return .5 * (p * Kp).sum()

    return energy


def HamiltonianSystem(K):
    """Return Hamilton's equations (p, q) -> (-dH/dq, dH/dp) for kernel K."""
    H = Hamiltonian(K)

    def system(p, q):
        # create_graph=True keeps the derivative graph, so the shot trajectory
        # can itself be back-propagated through during optimization.
        dH_dp, dH_dq = grad(H(p, q), (p, q), create_graph=True)
        return -dH_dq, dH_dp

    return system

Shooting approach

def Shooting(p0, q0, K, nt=10, Integrator=RalstonIntegrator()):
    """Integrate Hamilton's equations from the initial momenta/points (p0, q0).

    Returns the integrator's trajectory: a list of (p, q) tuples, nt + 1 long
    with the default Ralston integrator (which integrates over time 1).
    """
    system = HamiltonianSystem(K)
    initial_state = (p0, q0)
    return Integrator(system, initial_state, nt)


def Flow(x0, p0, q0, K, deltat=1.0, Integrator=RalstonIntegrator(), nt=10):
    """Transport the points x0 along the deformation generated by (p0, q0).

    The points x move with velocity K(x, q, p) while (p, q) follow Hamilton's
    equations.

    Args:
        x0: points to transport.
        p0, q0: initial momenta and control points.
        K: kernel operator K(x, y, b).
        deltat: total integration time.
        Integrator: integrator with signature (ODESystem, x0, nt, deltat).
        nt: number of integration steps (new, defaults to Shooting's 10).

    Returns:
        The positions of x0 at time ``deltat``.
    """
    HS = HamiltonianSystem(K)

    def FlowEq(x, p, q):
        return (K(x, q, p),) + HS(p, q)

    # Pass the step count and the time span explicitly: the previous call
    # passed deltat in the nt slot (range(1.0) -> TypeError).
    trajectory = Integrator(FlowEq, (x0, p0, q0), nt, deltat)
    # Last state of the trajectory; its first component is the moved points.
    return trajectory[-1][0]


def LDDMMloss(K, dataloss, gamma=0):
    """Build the LDDMM objective for kernel K and a data-attachment term.

    The returned ``loss(p0, q0)`` equals
    ``gamma * H(p0, q0) + dataloss(q1)``, where q1 is the final point set
    of the geodesic shot from (p0, q0). With the default ``gamma=0`` only
    the data term contributes.
    """
    def loss(p0, q0):
        _, q_final = Shooting(p0, q0, K)[-1]
        regularity = Hamiltonian(K)(p0, q0)
        return gamma * regularity + dataloss(q_final)

    return loss

Data attachment term

Varifold data attachment loss for surfaces

def lossVarifoldSurf(FS, VT, FT, K):
    """Build the varifold data-attachment loss towards a target surface.

    Each triangle is summarised by its centroid C, its area L and its unit
    normal n. The loss is the squared kernel distance
        <S, S> + <T, T> - 2 <S, T>,
    with <A, B> = sum_{i,j} L_i K(C_i, C_j, n_i, n_j) L_j.

    Args:
        FS: source faces, one row of three vertex indices per triangle.
        VT: vertex coordinates of the target surface (3D points).
        FT: target faces, same layout as FS.
        K: kernel operator K(x, y, u, v, b), e.g. from GaussLinKernel.

    Returns:
        A function ``loss(VS)`` of the source vertex coordinates. The
        target self-term <T, T> is computed once, up front.
    """
    def CompCLNn(F, V):
        V0, V1, V2 = V.index_select(0, F[:, 0]), V.index_select(0, F[:, 1]), V.index_select(0, F[:, 2])
        # Triangle centroid (the previous .5 * sum placed it at 1.5x the centroid).
        C = (V0 + V1 + V2) / 3
        # Half the edge cross product: a normal whose norm is the triangle area.
        # dim=1 is explicit because, without it, torch.cross uses the first
        # dimension of size 3 — the face axis when a mesh has exactly 3 faces.
        N = .5 * torch.cross(V1 - V0, V2 - V0, dim=1)
        L = (N ** 2).sum(dim=1)[:, None].sqrt()
        return C, L, N / L

    CT, LT, NTn = CompCLNn(FT, VT)
    cst = (LT * K(CT, CT, NTn, NTn, LT)).sum()

    def loss(VS):
        CS, LS, NSn = CompCLNn(FS, VS)
        return cst + (LS * K(CS, CS, NSn, NSn, LS)).sum() - 2 * (LS * K(CS, CT, NSn, NTn, LT)).sum()

    return loss

Registration

Load the dataset and plot it

VS, FS, VT, FT = torch.load(datafile)
# Source vertices: q0 needs requires_grad because HamiltonianSystem
# differentiates H with respect to the points q.
q0 = VS.clone().detach().to(dtype=torchdtype, device=torchdeviceId).requires_grad_(True)
# Target vertices are constant.
VT = VT.clone().detach().to(dtype=torchdtype, device=torchdeviceId)
# Face arrays index vertices, hence the long (int64) dtype.
FS, FT = (F.clone().detach().to(dtype=torch.long, device=torchdeviceId) for F in (FS, FT))
# Kernel width shared by the deformation and data-attachment kernels.
sigma = torch.tensor([20], dtype=torchdtype, device=torchdeviceId)

# Plot the source (red) and target (blue) meshes.
fig = plt.figure()
# Axes3D(fig) no longer attaches itself to the figure in recent Matplotlib;
# add_subplot(projection='3d') is the supported way to create 3D axes.
ax = fig.add_subplot(projection='3d')
# White panes. The former w_xaxis/w_yaxis/w_zaxis aliases were removed.
for pane_axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    pane_axis.set_pane_color((1.0, 1.0, 1.0, 1.0))

# Convert each tensor to numpy once instead of once per coordinate.
src_np = q0.detach().cpu().numpy()
tgt_np = VT.detach().cpu().numpy()
ax.plot_trisurf(src_np[:, 0], src_np[:, 1], src_np[:, 2],
                triangles=FS.detach().cpu().numpy(),
                color=(0, 0, 0, 0), edgecolor=(1, 0, 0, .08), linewidth=1)
ax.plot_trisurf(tgt_np[:, 0], tgt_np[:, 1], tgt_np[:, 2],
                triangles=FT.detach().cpu().numpy(),
                color=(0, 0, 0, 0), edgecolor=(0, 0, 1, .3), linewidth=1)

# Proxy artists for the legend (trisurf collections have no legend handle).
blue_proxy = plt.Rectangle((0, 0), 1, 1, fc="b")
red_proxy = plt.Rectangle((0, 0), 1, 1, fc=(1, 0, 0, .5))
ax.legend([red_proxy, blue_proxy], ['source', 'target'])
ax.set_title('Data')
plt.show()
../../_images/sphx_glr_plot_LDDMM_Surface_001.png

Out:

/home/bcharlier/keops/pykeops/tutorials/surface_registration/plot_LDDMM_Surface.py:211: UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.
  plt.show()

Define data attachment and LDDMM functional

# Deformation kernel, used for the Hamiltonian dynamics.
Kv = GaussKernel(sigma=sigma)
# Varifold attachment term towards the target surface (VT, FT).
dataloss = lossVarifoldSurf(FS, VT, FT, GaussLinKernel(sigma=sigma))
# Default gamma=0: the objective is the data term at the shot endpoint.
loss = LDDMMloss(Kv, dataloss)

Perform optimization

# Initial momentum vectors, one per source vertex (same shape as q0).
# These are the only optimisation variables; q0 stays fixed.
p0 = torch.zeros(q0.shape, dtype=torchdtype, device=torchdeviceId, requires_grad=True)

optimizer = torch.optim.LBFGS([p0], max_eval=10, max_iter=10)
print('performing optimization...')
# perf_counter is monotonic and intended for elapsed-time measurement;
# time.time() follows the wall clock and can jump.
start = time.perf_counter()


def closure():
    """Re-evaluate the loss and its gradient w.r.t. p0 (called by L-BFGS)."""
    optimizer.zero_grad()
    L = loss(p0, q0)
    print('loss', L.detach().cpu().numpy())
    L.backward()
    return L


for i in range(10):
    print('it ', i, ': ', end='')
    optimizer.step(closure)

print('Optimization (L-BFGS) time: ', round(time.perf_counter() - start, 2), ' seconds')

Out:

performing optimization...
it  0 : loss 91181.375
loss 84040.125
loss 26011.562
loss 19569.188
loss 16383.25
loss 14535.3125
loss 10051.25
loss 9112.6875
loss 7602.5
loss 7533.25
it  1 : loss 7533.25
loss 7459.75
loss 7357.5
loss 7194.8125
loss 6773.6875
loss 6232.0
loss 5672.5
loss 5336.25
loss 5147.4375
loss 4827.25
it  2 : loss 4827.25
loss 4333.6875
loss 3663.625
loss 3325.6875
loss 2899.625
loss 2759.25
loss 2681.5
loss 2521.0625
loss 2276.4375
loss 2153.4375
it  3 : loss 2153.4375
loss 2100.4375
loss 2073.5
loss 2034.375
loss 1959.75
loss 1834.125
loss 1798.9375
loss 1699.0
loss 1662.6875
loss 1618.5
it  4 : loss 1618.5
loss 1552.5625
loss 1478.8125
loss 1393.6875
loss 1344.3125
loss 1326.875
loss 1314.3125
loss 1283.4375
loss 1243.5
loss 1207.6875
it  5 : loss 1207.6875
loss 1190.5625
loss 1179.25
loss 1167.375
loss 1154.75
loss 1139.9375
loss 1121.875
loss 1087.375
loss 1053.25
loss 1022.75
it  6 : loss 1022.75
loss 1009.625
loss 1001.8125
loss 999.1875
loss 997.125
loss 994.125
loss 986.25
loss 977.25
loss 967.4375
loss 957.6875
it  7 : loss 957.6875
loss 949.25
loss 944.4375
loss 942.5625
loss 940.75
loss 937.0625
loss 927.3125
loss 909.5625
loss 883.875
loss 860.125
it  8 : loss 860.125
loss 844.4375
loss 841.3125
loss 836.125
loss 833.9375
loss 829.5
loss 826.9375
loss 823.6875
loss 819.3125
loss 810.75
it  9 : loss 810.75
loss 798.9375
loss 782.375
loss 773.75
loss 771.1875
loss 770.625
loss 769.875
loss 768.1875
loss 764.25
loss 755.875
Optimization (L-BFGS) time:  193.31  seconds

Display output

The animated version of the deformation:

# Number of time steps used to sample the optimal geodesic for display.
nt = 15
# Trajectory of (p, q) states: the initial state plus one per time step.
listpq = Shooting(p0, q0, K=Kv, nt=nt)

The code to generate the .gif:

VTnp, FTnp = VT.detach().cpu().numpy(), FT.detach().cpu().numpy()
q0np, FSnp = q0.detach().cpu().numpy(), FS.detach().cpu().numpy()

# Legend proxies that do not change between frames. blue_proxy is defined
# here so this cell does not depend on the earlier plotting cell.
yellow_proxy = plt.Rectangle((0, 0), 1, 1, fc="y")
blue_proxy = plt.Rectangle((0, 0), 1, 1, fc="b")

images = []
# Shooting returns the initial state plus one state per step. Render every
# one of them; the previous range(nt) skipped the final, matched surface.
n_frames = len(listpq)
for t, (_, qt) in enumerate(listpq):
    qnp = qt.detach().cpu().numpy()
    # Source outline fades from fully visible (first frame) to invisible (last).
    fade = 1 - t / max(n_frames - 1, 1)

    # create Figure and link an Agg canvas to it (off-screen rendering)
    fig = Figure(figsize=(6, 5), dpi=100)
    canvas = FigureCanvasAgg(fig)

    # make the plot; add_subplot(projection='3d') replaces Axes3D(fig), which
    # no longer attaches itself to the figure in recent Matplotlib.
    ax = fig.add_subplot(projection='3d')
    for pane_axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        pane_axis.set_pane_color((1.0, 1.0, 1.0, 1.0))
    ax.plot_trisurf(q0np[:, 0], q0np[:, 1], q0np[:, 2], triangles=FSnp, color=(0, 0, 0, 0), edgecolor=(1, 0, 0, .08 * fade), linewidth=1)
    ax.plot_trisurf(qnp[:, 0], qnp[:, 1], qnp[:, 2], triangles=FSnp, color=(1, 1, 0, .5), edgecolor=(1, 1, 1, .3), linewidth=1)
    ax.plot_trisurf(VTnp[:, 0], VTnp[:, 1], VTnp[:, 2], triangles=FTnp, color=(0, 0, 0, 0), edgecolor=(0, 0, 1, .3), linewidth=1)

    red_proxy = plt.Rectangle((0, 0), 1, 1, fc=(1, 0, 0, .8 * fade))
    ax.legend([red_proxy, yellow_proxy, blue_proxy], ['source', 'deformed', 'target'])
    ax.set_title('LDDMM matching example, step ' + str(t))

    # draw it!
    canvas.draw()

    # save plot in a numpy array through the RGBA buffer
    s, (width, height) = canvas.print_to_buffer()
    images.append(np.frombuffer(s, np.uint8).reshape((height, width, 4)))

save_folder = '../../../doc/_build/html/_images/'
os.makedirs(save_folder, exist_ok=True)
# NOTE(review): newer imageio/Pillow GIF writers may read `duration` in
# milliseconds rather than seconds — confirm against the installed version.
imageio.mimsave(os.path.join(save_folder, 'surface_matching.gif'), images, duration=.5)

Total running time of the script: ( 3 minutes 53.500 seconds)

Gallery generated by Sphinx-Gallery