Creating a fancy interpolation video between 3D meshes.
N.B.: I am currently very busy writing my PhD thesis. Comments will come soon!
Setup
Standard imports.
import numpy as np
import torch
import os
use_cuda = torch.cuda.is_available()
tensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
numpy = lambda x: x.detach().cpu().numpy()
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from geomloss import SamplesLoss
from pykeops.torch import LazyTensor
Utility: turn a triangle mesh into a weighted point cloud.
def to_measure(points, triangles):
    """Turns a triangle mesh into a weighted point cloud."""

    # Our mesh is given as a collection of ABC triangles:
    A, B, C = points[triangles[:, 0]], points[triangles[:, 1]], points[triangles[:, 2]]

    # Locations and weights of our Dirac atoms:
    X = (A + B + C) / 3  # centers of the faces
    S = np.sqrt(np.sum(np.cross(B - A, C - A) ** 2, 1)) / 2  # areas of the faces

    print(
        "File loaded, and encoded as the weighted sum of {:,} atoms in 3D.".format(
            len(X)
        )
    )

    # We return a (normalized) vector of weights + a "list" of points:
    return tensor(S / np.sum(S)), tensor(X)
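As a quick illustration (not part of the original script), feeding to_measure a single right triangle produces one Dirac atom at the face barycenter, with weight 1 after normalization:

# Hypothetical toy mesh: one right triangle with legs of length 1 (area 0.5).
toy_points = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
toy_triangles = np.array([[0, 1, 2]])
toy_weights, toy_atoms = to_measure(toy_points, toy_triangles)
# toy_weights -> [1.], toy_atoms -> [[0.333, 0.333, 0.]]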
Utility: load “.ply” mesh file.
from plyfile import PlyData, PlyElement
def load_ply_file(fname):
    """Loads a .ply mesh to return a collection of weighted Dirac atoms: one per triangle face."""

    # Load the data, and read the connectivity information:
    plydata = PlyData.read(fname)
    triangles = np.vstack(plydata["face"].data["vertex_indices"])

    # Normalize the point cloud, as specified by the user:
    points = np.vstack([[v[0], v[1], v[2]] for v in plydata["vertex"]])

    return to_measure(points, triangles)
Utility: load “.nii” volume file.
import SimpleITK as sitk
from skimage.measure import marching_cubes
def load_nii_file(fname, threshold=0.5):
    """Uses the marching cubes algorithm to turn a .nii binary mask into a surface weighted point cloud."""
    mask = sitk.GetArrayFromImage(sitk.ReadImage(fname))
    # mask = skimage.transform.downscale_local_mean(mask, (4, 4, 4))
    verts, faces, normals, values = marching_cubes(mask, threshold)

    return to_measure(verts, faces)
Synthetic sphere - a typical source measure:
def create_sphere(n_samples=1000):
    """Creates a uniform sample on the unit sphere."""
    n_samples = int(n_samples)

    # Golden-angle (Fibonacci) spiral on the sphere:
    indices = np.arange(0, n_samples, dtype=float) + 0.5
    phi = np.arccos(1 - 2 * indices / n_samples)
    theta = np.pi * (1 + 5 ** 0.5) * indices

    x, y, z = np.cos(theta) * np.sin(phi), np.sin(theta) * np.sin(phi), np.cos(phi)
    points = np.vstack((x, y, z)).T
    weights = np.ones(n_samples) / n_samples

    return tensor(weights), tensor(points)
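A quick check (not in the original script): the resulting sample is centered and lies exactly on the unit sphere.

# Hypothetical sanity check on a small sample:
w_sph, x_sph = create_sphere(1000)
print(numpy(x_sph).mean(0))                           # ~ [0., 0., 0.]
print(np.abs((numpy(x_sph) ** 2).sum(1) - 1).max())   # ~ 0. (all points on the sphere)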
Simple (slow) display routine:
def display_cloud(ax, measure, color):
    w_i, x_i = numpy(measure[0]), numpy(measure[1])

    ax.view_init(elev=110, azim=-90)
    # ax.set_aspect('equal')

    weights = w_i / w_i.sum()
    ax.scatter(x_i[:, 0], x_i[:, 1], x_i[:, 2], s=25 * 500 * weights, c=color)

    ax.axes.set_xlim3d(left=-1.4, right=1.4)
    ax.axes.set_ylim3d(bottom=-1.4, top=1.4)
    ax.axes.set_zlim3d(bottom=-1.4, top=1.4)
Save the output as a VTK folder, to be rendered with Paraview:
folder = "output/wasserstein_3D/"
os.makedirs(os.path.dirname("output/wasserstein_3D/"), exist_ok=True)
from pyvtk import PolyData, PointData, CellData, Scalars, VtkData, PointData
def save_vtk(fname, points, colors):
"""N.B.: Paraview is a good VTK viewer, which supports ray-tracing."""
structure = PolyData(points=points, vertices=np.arange(len(points)))
values = PointData(Scalars(colors, name="colors"))
vtk = VtkData(structure, values)
vtk.tofile(folder + fname, "binary")
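As a hypothetical sanity check (not part of the original script), writing a small random cloud should produce a file that Paraview can open directly:

# Write 100 random points with a scalar "colors" attribute to output/wasserstein_3D/sanity_check.vtk:
test_points = np.random.randn(100, 3)
save_vtk("sanity_check.vtk", test_points, test_points[:, 0])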
Data
Shall we work on subsampled data or at full resolution?
fast_demo = not use_cuda  # run a lighter demo on CPU-only machines

if use_cuda:
    Npoints = 1e4 if fast_demo else 2e5
else:
    Npoints = 1e3
Create a reference template:
template = create_sphere(Npoints)
Use color labels to track the particles:
K = 12
colors = (K * template[1][:, 0]).cos()
colors = colors.view(-1).detach().cpu().numpy()
Fetch the data:
os.makedirs(os.path.dirname("data/"), exist_ok=True)
if not os.path.exists("data/wasserstein_3D_models/Stanford_dragon_200k.ply"):
    print("Fetching the data... ", end="", flush=True)
    import urllib.request

    urllib.request.urlretrieve(
        "http://www.kernel-operations.io/data/wasserstein_3D_models.zip",
        "data/wasserstein_3D_models.zip",
    )

    import shutil

    shutil.unpack_archive("data/wasserstein_3D_models.zip", "data")
    print("Done.")
Load the data on the GPU:
print("Loading the data:")
# N.B.: Since Plyfile is far from being optimized, this may take some time!
targets = [
    load_ply_file("data/wasserstein_3D_models/Stanford_dragon_200k.ply"),
    load_ply_file("data/wasserstein_3D_models/vertebrae_400k_biol260_sketchfab_CC.ply"),
    load_nii_file("data/wasserstein_3D_models/brain.nii.gz"),
]
Loading the data:
File loaded, and encoded as the weighted sum of 202,520 atoms in 3D.
File loaded, and encoded as the weighted sum of 394,237 atoms in 3D.
File loaded, and encoded as the weighted sum of 382,074 atoms in 3D.
Normalize and subsample everyone, if required:
def normalize(measure, n=None):
    """Reduces a point cloud to at most n points and normalizes the weights and point cloud."""
    weights, locations = measure
    N = len(weights)

    if n is not None and n < N:
        n = int(n)
        indices = torch.randperm(N)
        indices = indices[:n]
        weights, locations = weights[indices], locations[indices]
        weights = weights / weights.sum()

    weights, locations = weights.contiguous(), locations.contiguous()

    # Center and rescale the point cloud:
    mean = (weights.view(-1, 1) * locations).sum(dim=0)
    locations -= mean
    std = (weights.view(-1) * (locations ** 2).sum(dim=1).view(-1)).sum().sqrt()
    locations /= std

    return weights, locations
targets = [normalize(t, n=Npoints) for t in targets]
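As a quick check (not part of the original script), a normalized measure has weights summing to one, a barycenter at the origin and a unit standard deviation:

# Hypothetical check on a fresh template sphere:
w_chk, x_chk = normalize(create_sphere(1000))
print(numpy(w_chk.sum()))                               # ~ 1.0
print(numpy((w_chk.view(-1, 1) * x_chk).sum(dim=0)))    # ~ [0., 0., 0.]
print(numpy((w_chk * (x_chk ** 2).sum(dim=1)).sum()))   # ~ 1.0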
Fine tuning:
template = template[0], template[1] / 2 + tensor(
    [0.5, 0.0, 0.0]
)  # Smaller sphere, towards the back of the dragon

targets[1] = targets[1][0], targets[1][1] @ tensor(
    [[0, 0, 1], [0, 1, 0], [1, 0, 0]]
)  # Turn the vertebra

targets[2] = targets[2][0], -targets[2][1]  # Flip the brain
Optimal Transport matchings
Define our solver: a Sinkhorn divergence with a quadratic ground cost (p=2), a small entropic blur of 0.01 and an ε-scaling annealing factor of 0.5:
import time
Loss = SamplesLoss("sinkhorn", p=2, blur=0.01, scaling=0.5, truncate=1)
def OT_registration(source, target, name):
    a, x = source  # weights, locations
    b, y = target  # weights, locations

    x.requires_grad = True
    z = x.clone()  # Moving point cloud

    if use_cuda:
        torch.cuda.synchronize()
    start = time.time()

    nits = 4 if fast_demo else 10
    for it in range(nits):
        wasserstein_zy = Loss(a, z, b, y)
        [grad_z] = torch.autograd.grad(wasserstein_zy, [z])
        z -= grad_z / a[:, None]  # Apply the regularized Brenier map
        # save_vtk(f"matching_{name}_it_{it}.vtk", numpy(z), colors)

    end = time.time()
    print("Registered {} in {:.3f}s.".format(name, end - start))

    return z
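Each iteration above is a simple Lagrangian descent step: dividing the gradient of the Sinkhorn divergence by the weights a turns it into a displacement field that nudges every particle of z towards the target measure, as suggested by the "regularized Brenier map" comment. As a quick illustration (not part of the original script, with hypothetical variable names), a handful of such steps drives the loss between two small random clouds towards zero:

# Toy registration: 100 random source points onto 100 shifted target points.
a_toy = torch.ones(100).type(tensor) / 100
b_toy = torch.ones(100).type(tensor) / 100
z_toy = torch.randn(100, 3).type(tensor)
z_toy.requires_grad = True
y_toy = torch.randn(100, 3).type(tensor) + 2  # target cloud, shifted away

for _ in range(5):
    loss_toy = Loss(a_toy, z_toy, b_toy, y_toy)
    [g_toy] = torch.autograd.grad(loss_toy, [z_toy])
    with torch.no_grad():
        z_toy -= g_toy / a_toy[:, None]  # same update rule as in OT_registration
    print(numpy(loss_toy))  # values should decrease towards zero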
Register the source onto the targets:
matchings = [
    OT_registration(template, target, f"shape{i+1}")
    for (i, target) in enumerate(targets)
]
Registered shape1 in 8.594s.
Registered shape2 in 17.520s.
Registered shape3 in 4.895s.
Display our matchings:
for (i, (matching, target)) in enumerate(zip(matchings, targets)):
    fig = plt.figure(figsize=(6, 6))
    plt.set_cmap("hsv")
    ax = fig.add_subplot(1, 1, 1, projection="3d")

    display_cloud(ax, (template[0], matching), colors)
    display_cloud(ax, target, "blue")
    ax.set_title(
        "Registered (N={:,}) and target {} (M={:,}) point clouds".format(
            len(matching), i + 1, len(target[0])
        )
    )
    plt.tight_layout()
Movie
Save them as a collection of VTK files:
FPS = 32  # number of interpolation frames per transition
source = template[1]
pairs = [
    (source, source),
    (source, matchings[0]),
    (matchings[0], matchings[0]),
    (matchings[0], matchings[1]),
    (matchings[1], matchings[1]),
    (matchings[1], matchings[2]),
    (matchings[2], matchings[2]),
    (matchings[2], source),
]

frame = 0

print("Save as a VTK movie...", end="", flush=True)
for (A, B) in pairs:
    A, B = numpy(A), numpy(B)
    for t in np.linspace(0, 1, FPS):
        save_vtk(f"frame_{frame}.vtk", (1 - t) * A + t * B, colors)
        frame += 1
print("Done.")
plt.show()
Save as a VTK movie...Done.
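Paraview remains the intended renderer for the VTK frames above, but a quick preview can also be generated directly in Python by rasterizing the interpolated clouds with matplotlib and stitching them into a GIF. The following is a minimal sketch, not part of the original pipeline: it assumes that the imageio package is installed and that an Agg-style matplotlib backend is in use; the helper name, frame count and output file are arbitrary choices.

import imageio  # assumption: imageio is available in the environment

def render_frame(points):
    """Rasterizes one interpolated cloud into an RGB array (Agg-style matplotlib backends)."""
    fig = plt.figure(figsize=(4, 4))
    ax = fig.add_subplot(1, 1, 1, projection="3d")
    display_cloud(ax, (template[0], tensor(points)), colors)
    fig.canvas.draw()
    image = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
    plt.close(fig)
    return image

gif_frames = []
for (A, B) in pairs:
    A, B = numpy(A), numpy(B)
    for t in np.linspace(0, 1, 8):  # fewer steps than the VTK export, to keep things light
        gif_frames.append(render_frame((1 - t) * A + t * B))

imageio.mimsave(folder + "morph.gif", gif_frames)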
Total running time of the script: (3 minutes 4.567 seconds)