Note
Go to the end to download the full example code
Surface registration
Example of a diffeomorphic matching of surfaces using varifold metrics. We perform an LDDMM matching of two meshes using the geodesic shooting algorithm.
Define our dataset
Standard imports
import os
import time
import torch
from torch.autograd import grad
import plotly.graph_objs as go
from pykeops.torch import Vi, Vj
# torch type and device
use_cuda = torch.cuda.is_available()
# Use a torch.device in both branches so that `.index` below is meaningful:
# an int on GPU, None on CPU. A plain "cpu" string would expose str.index
# (a bound method) instead of a device id.
torchdeviceId = torch.device("cuda:0" if use_cuda else "cpu")
torchdtype = torch.float32
# PyKeOps counterpart
KeOpsdeviceId = torchdeviceId.index  # GPU id when CUDA is used, None on CPU
KeOpsdtype = str(torchdtype).split(".")[1]  # 'float32'
Choose the data file to import, one of:
“hippos.pt” : original data (6611 vertices),
“hippos_red.pt” : reduced size (1654 vertices),
“hippos_reduc.pt” : further reduced (662 vertices),
“hippos_reduc_reduc.pt” : further reduced (68 vertices)
# Full-resolution meshes are only practical on GPU; on CPU fall back to the
# smallest (68-vertex) version so the example still runs quickly.
datafile = "data/hippos.pt" if use_cuda else "data/hippos_reduc_reduc.pt"
Define the kernels
Define Gaussian kernel \((K(x,y)b)_i = \sum_j \exp(-\gamma\|x_i-y_j\|^2)b_j\)
def GaussKernel(sigma):
    """Build the KeOps Gaussian kernel reduction.

    Returns a KeOps reduction computing
    ``out_i = sum_j exp(-|x_i - y_j|^2 / sigma^2) * b_j``,
    with ``x`` indexed by i and ``y``, ``b`` indexed by j (all 3-vectors).
    """
    xi = Vi(0, 3)
    yj = Vj(1, 3)
    bj = Vj(2, 3)
    inv_sigma2 = 1 / (sigma * sigma)
    weights = (-xi.sqdist(yj) * inv_sigma2).exp()
    return (weights * bj).sum_reduction(axis=1)
Define “Gaussian-CauchyBinet” kernel \((K(x,y,u,v)b)_i = \sum_j \exp(-\gamma\|x_i-y_j\|^2) \langle u_i,v_j\rangle^2 b_j\)
def GaussLinKernel(sigma):
    """Build the KeOps "Gaussian-CauchyBinet" kernel reduction.

    Returns a KeOps reduction computing
    ``out_i = sum_j exp(-|x_i - y_j|^2 / sigma^2) * <u_i, v_j>^2 * b_j``.
    ``x`` and ``u`` are i-indexed 3-vectors, ``y`` and ``v`` are j-indexed
    3-vectors, and ``b`` is a j-indexed scalar weight.
    """
    x, u = Vi(0, 3), Vi(2, 3)
    y, v, b = Vj(1, 3), Vj(3, 3), Vj(4, 1)
    inv_sigma2 = 1 / (sigma * sigma)
    gaussian = (-x.sqdist(y) * inv_sigma2).exp()
    # Squared inner product: insensitive to the orientation (sign) of u and v.
    cauchy_binet = (u * v).sum() ** 2
    return (gaussian * cauchy_binet * b).sum_reduction(axis=1)
Custom ODE solver (Ralston's explicit second-order Runge–Kutta method) for ODE systems whose state is a tuple of tensors
def RalstonIntegrator():
    """Return a Ralston (explicit 2nd-order Runge-Kutta) integrator for tuple-valued ODEs.

    The returned function ``f(ODESystem, x0, nt, deltat=1.0)`` integrates
    ``dx/dt = ODESystem(*x)`` from the tuple of tensors ``x0`` over total time
    ``deltat`` in ``nt`` equal steps. It returns the list of all ``nt + 1`` states
    (each a tuple), starting with a clone of ``x0``.
    """

    def f(ODESystem, x0, nt, deltat=1.0):
        state = tuple(component.clone() for component in x0)
        dt = deltat / nt
        trajectory = [state]
        for _ in range(nt):
            k1 = ODESystem(*state)
            # Ralston's intermediate point at 2/3 of the step
            probe = tuple(s + (2 * dt / 3) * d for s, d in zip(state, k1))
            k2 = ODESystem(*probe)
            state = tuple(
                s + (0.25 * dt) * (d1 + 3 * d2) for s, d1, d2 in zip(state, k1, k2)
            )
            trajectory.append(state)
        return trajectory

    return f
LDDMM implementation
Deformations: diffeomorphism
Hamiltonian system
def Hamiltonian(K):
    """Return the LDDMM Hamiltonian ``H(p, q) = 1/2 * <p, K(q, q) p>`` for kernel ``K``."""

    def H(p, q):
        velocity = K(q, q, p)
        return 0.5 * (p * velocity).sum()

    return H
def HamiltonianSystem(K):
    """Return Hamilton's equations ``(dp/dt, dq/dt) = (-dH/dq, dH/dp)`` for kernel ``K``.

    The gradients are taken with ``create_graph=True`` so that the whole
    trajectory stays differentiable with respect to the initial momenta.
    """
    energy = Hamiltonian(K)

    def HS(p, q):
        dH_dp, dH_dq = grad(energy(p, q), (p, q), create_graph=True)
        return -dH_dq, dH_dp

    return HS
Shooting approach
def Shooting(p0, q0, K, nt=10, Integrator=RalstonIntegrator()):
    """Integrate the geodesic equations from initial momenta ``p0`` at points ``q0``.

    Returns whatever ``Integrator`` returns. With the default RalstonIntegrator,
    this is the list of the ``nt + 1`` states ``(p, q)`` over unit time.
    """
    system = HamiltonianSystem(K)
    initial_state = (p0, q0)
    return Integrator(system, initial_state, nt)
def Flow(x0, p0, q0, K, deltat=1.0, Integrator=RalstonIntegrator(), nt=10):
    """Transport arbitrary points ``x0`` along the geodesic shot from ``(p0, q0)``.

    The points move with the velocity field ``K(x, q, p)`` generated by the
    control points ``q`` and momenta ``p``, while ``(p, q)`` themselves follow
    the Hamiltonian system of ``K``. All three are integrated together.

    Args:
        x0: points to deform.
        p0, q0: initial momenta and control-point positions.
        K: deformation kernel, called as ``K(x, y, b)``.
        deltat: total integration time.
        Integrator: integrator called as ``Integrator(system, state, nt, deltat)``
            and returning the list of successive states.
        nt: number of integration steps (same default as ``Shooting``).

    Returns:
        The deformed points, i.e. the ``x`` component of the final state.
    """
    HS = HamiltonianSystem(K)

    def FlowEq(x, p, q):
        # Points follow the kernel velocity field; (p, q) follow Hamilton's equations.
        return (K(x, q, p),) + HS(p, q)

    # The integrator's third positional argument is the *number of steps*;
    # the total time must be passed separately as deltat.
    states = Integrator(FlowEq, (x0, p0, q0), nt, deltat)
    # states[0] is the (cloned) initial state; the deformed points are in the last one.
    return states[-1][0]
def LDDMMloss(K, dataloss, gamma=0):
    """Return the LDDMM objective ``loss(p0, q0)``.

    The objective is ``gamma * H(p0, q0)`` (deformation energy) plus
    ``dataloss`` evaluated at the end point of the geodesic shot from ``(p0, q0)``.
    """
    energy = Hamiltonian(K)

    def loss(p0, q0):
        trajectory = Shooting(p0, q0, K)
        _, q_final = trajectory[-1]
        return gamma * energy(p0, q0) + dataloss(q_final)

    return loss
Data attachment term
Varifold data attachment loss for surfaces
# VT: vertices coordinates of target surface,
# FS,FT : Face connectivity of source and target surfaces
# K kernel
def lossVarifoldSurf(FS, VT, FT, K):
    """Build the varifold data-attachment loss between a source mesh and a fixed target mesh.

    Each triangle is summarized by its center, its area and its unit normal.
    The loss is the squared kernel distance ``<T,T> + <S,S> - 2 <S,T>`` with
    ``<A,B> = sum_ij area(a_i) * K(c_i, c_j, n_i, n_j) * area(b_j)``.

    Args:
        FS: integer tensor of source triangles (one row of 3 vertex indices per face).
        VT: target vertex coordinates, one 3D point per row.
        FT: integer tensor of target triangles (one row of 3 vertex indices per face).
        K: kernel called as ``K(x, y, u, v, b)`` (e.g. ``GaussLinKernel``), expected
            to return one value per row of ``x``.

    Returns:
        A function ``loss(VS)`` mapping source vertex coordinates to a scalar tensor.
    """

    def get_center_length_normal(F, V):
        V0, V1, V2 = (V.index_select(0, F[:, c]) for c in range(3))
        centers = (V0 + V1 + V2) / 3
        # Explicit dim=1: without it torch.cross crosses along the *first* axis of
        # size 3, which is the face axis whenever the mesh has exactly 3 faces
        # (and omitting dim is deprecated).
        normals = 0.5 * torch.cross(V1 - V0, V2 - V0, dim=1)
        # |normals| is the triangle area
        length = (normals**2).sum(dim=1)[:, None].sqrt()
        return centers, length, normals / length

    # The target-target term does not depend on VS: compute it once.
    CT, LT, NTn = get_center_length_normal(FT, VT)
    cst = (LT * K(CT, CT, NTn, NTn, LT)).sum()

    def loss(VS):
        CS, LS, NSn = get_center_length_normal(FS, VS)
        return (
            cst
            + (LS * K(CS, CS, NSn, NSn, LS)).sum()
            - 2 * (LS * K(CS, CT, NSn, NTn, LT)).sum()
        )

    return loss
Registration
Load the dataset and plot it
VS, FS, VT, FT = torch.load(datafile)
# Move the meshes to the working device/dtype. q0 (source vertices) requires
# grad because the Hamiltonian system differentiates with respect to positions.
q0 = VS.clone().detach().to(dtype=torchdtype, device=torchdeviceId).requires_grad_(True)
VT = VT.clone().detach().to(dtype=torchdtype, device=torchdeviceId)
FS = FS.clone().detach().to(dtype=torch.long, device=torchdeviceId)
FT = FT.clone().detach().to(dtype=torch.long, device=torchdeviceId)
# Kernel width, shared by the deformation and data-attachment kernels below
sigma = torch.tensor([20], dtype=torchdtype, device=torchdeviceId)

# Per-axis numpy arrays for plotly's Mesh3d: vertex coordinates (x, y, z)
# and triangle corner indices (i, j, k), for the source and the target.
x, y, z = (q0[:, c].detach().cpu().numpy() for c in range(3))
i, j, k = (FS[:, c].detach().cpu().numpy() for c in range(3))
xt, yt, zt = (VT[:, c].detach().cpu().numpy() for c in range(3))
it, jt, kt = (FT[:, c].detach().cpu().numpy() for c in range(3))

save_folder = os.path.join(*([".."] * 4), "doc", "_build", "html", "_images")
os.makedirs(save_folder, exist_ok=True)

target_mesh = go.Mesh3d(x=xt, y=yt, z=zt, i=it, j=jt, k=kt, color="blue", opacity=0.50)
source_mesh = go.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, color="red", opacity=0.50)
fig = go.Figure(data=[target_mesh, source_mesh])
fig.write_html(os.path.join(save_folder, "data.html"), auto_open=False)
# sphinx_gallery_thumbnail_path = '_static/plot_LDDMM_Surface_thumb.png'
Define data attachment and LDDMM functional
# Varifold term comparing the deformed source surface with the target surface
varifold_kernel = GaussLinKernel(sigma=sigma)
dataloss = lossVarifoldSurf(FS, VT, FT, varifold_kernel)
# Kernel generating the deformation velocity fields
Kv = GaussKernel(sigma=sigma)
loss = LDDMMloss(K=Kv, dataloss=dataloss)
Perform optimization
# Initialize the momentum vectors to zero (identity deformation); they are
# the only variables optimized by L-BFGS.
p0 = torch.zeros(q0.shape, dtype=torchdtype, device=torchdeviceId, requires_grad=True)
optimizer = torch.optim.LBFGS([p0], max_eval=10, max_iter=10)
print("performing optimization...")
# perf_counter is monotonic and high-resolution, unlike time.time(),
# so it is the appropriate clock for measuring elapsed time.
start = time.perf_counter()


def closure():
    """L-BFGS closure: recompute the LDDMM loss and its gradient w.r.t. p0."""
    optimizer.zero_grad()
    L = loss(p0, q0)
    print("loss", L.detach().cpu().numpy())
    L.backward()
    return L


# 10 outer steps, each running up to max_iter L-BFGS iterations.
for i in range(10):
    print("it ", i, ": ", end="")
    optimizer.step(closure)
print("Optimization (L-BFGS) time: ", round(time.perf_counter() - start, 2), " seconds")
performing optimization...
it 0 : loss 87667.31
loss 81358.0
loss 21514.5
loss 15458.375
loss 9952.25
loss 7281.875
loss 4681.25
loss 4017.125
loss 3976.875
loss 3928.5
it 1 : loss 3928.5
loss 3887.75
loss 3773.625
loss 3568.25
loss 3268.0
loss 2933.375
loss 2562.75
loss 2287.5
loss 1969.75
loss 1708.125
it 2 : loss 1708.125
loss 1504.0
loss 1403.625
loss 1321.25
loss 1244.5
loss 1164.375
loss 1087.375
loss 1056.375
loss 1041.25
loss 1026.375
it 3 : loss 1026.375
loss 1013.125
loss 990.625
loss 948.5
loss 874.75
loss 833.25
loss 774.25
loss 731.375
loss 715.25
loss 700.625
it 4 : loss 700.625
loss 687.0
loss 673.25
loss 650.375
loss 609.75
loss 580.625
loss 563.5
loss 553.5
loss 547.75
loss 542.25
it 5 : loss 542.25
loss 535.0
loss 525.125
loss 514.375
loss 489.375
loss 459.0
loss 438.875
loss 443.75
loss 427.0
loss 425.125
it 6 : loss 425.125
loss 423.625
loss 422.625
loss 421.625
loss 419.25
loss 412.875
loss 404.75
loss 396.0
loss 393.125
loss 392.75
it 7 : loss 392.75
loss 392.375
loss 391.375
loss 388.875
loss 382.625
loss 368.75
loss 346.625
loss 329.625
loss 319.375
loss 318.875
it 8 : loss 318.875
loss 313.0
loss 312.5
loss 311.75
loss 311.0
loss 310.5
loss 310.0
loss 309.5
loss 308.875
loss 307.75
it 9 : loss 307.75
loss 306.25
loss 304.75
loss 303.875
loss 303.125
loss 301.875
loss 299.125
loss 297.25
loss 296.0
loss 295.375
Optimization (L-BFGS) time: 24.96 seconds
Display output
The animated version of the deformation:
nt = 15  # number of time steps used to sample the geodesic for the animation
listpq = Shooting(p0, q0, Kv, nt)
The code to generate the figure:
VTnp, FTnp = VT.detach().cpu().numpy(), FT.detach().cpu().numpy()
q0np, FSnp = q0.detach().cpu().numpy(), FS.detach().cpu().numpy()
# Create figure: trace 0 is the target surface, always visible.
fig = go.Figure()
fig.add_trace(
    go.Mesh3d(
        visible=True,
        x=VTnp[:, 0],
        y=VTnp[:, 1],
        z=VTnp[:, 2],
        i=FTnp[:, 0],
        j=FTnp[:, 1],
        k=FTnp[:, 2],
    )
)
# Add one (initially hidden) trace per sampled time step. Shooting returns
# nt + 1 states (initial source ... final registered shape), so iterate over
# all of them; range(nt) would drop the final, registered configuration.
for _, q_t in listpq:
    qnp = q_t.detach().cpu().numpy()
    fig.add_trace(
        go.Mesh3d(
            visible=False,
            x=qnp[:, 0],
            y=qnp[:, 1],
            z=qnp[:, 2],
            i=FSnp[:, 0],
            j=FSnp[:, 1],
            k=FSnp[:, 2],
        )
    )
# Initially show the source at time 0 (trace 1) alongside the target.
fig.data[1].visible = True
# Create and add slider: step i shows the target plus the i-th time step.
steps = []
for i in range(len(fig.data) - 1):
    step = dict(
        method="restyle",
        args=["visible", [False] * len(fig.data)],
    )
    step["args"][1][0] = True  # keep the target visible
    step["args"][1][i + 1] = True  # toggle the i-th time step to "visible"
    steps.append(step)
sliders = [
    dict(active=0, currentvalue={"prefix": "time: "}, pad={"t": 20}, steps=steps)
]
fig.update_layout(sliders=sliders)
fig.write_html(os.path.join(save_folder, "results.html"), auto_open=False)
Total running time of the script: ( 0 minutes 25.493 seconds)