Creating a fancy interpolation video between 3D meshes.
Standard imports.
import numpy as np
import torch
import os
# Select the tensor type once: CUDA float tensors when a GPU is available, CPU ones otherwise.
use_cuda = torch.cuda.is_available()
if use_cuda:
    tensor = torch.cuda.FloatTensor
else:
    tensor = torch.FloatTensor


def numpy(x):
    """Return a host-side NumPy copy of tensor `x`, detached from the autograd graph."""
    return x.detach().cpu().numpy()
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from geomloss import SamplesLoss
from pykeops.torch import LazyTensor
Utility: turn a triangle mesh into a weighted point cloud.
def to_measure(points, triangles):
    """Turn a triangle mesh into a weighted point cloud: one Dirac atom per face.

    Args:
        points: (V, 3) array of vertex coordinates.
        triangles: (F, 3) integer array; each row holds the vertex indices of one face.

    Returns:
        (weights, locations): an (F,) tensor of face areas normalized to sum to 1,
        and an (F, 3) tensor of face centroids, both built with the module-level
        `tensor` constructor.
    """
    # Our mesh is given as a collection of ABC triangles:
    A, B, C = points[triangles[:, 0]], points[triangles[:, 1]], points[triangles[:, 2]]

    # Locations and weights of our Dirac atoms:
    X = (A + B + C) / 3  # centers of the faces
    S = np.sqrt(np.sum(np.cross(B - A, C - A) ** 2, 1)) / 2  # areas of the faces

    print(
        "File loaded, and encoded as the weighted sum of {:,} atoms in 3D.".format(
            len(X)
        )
    )

    # We return a (normalized) vector of weights + a "list" of points
    return tensor(S / np.sum(S)), tensor(X)
Utility: load “.ply” mesh file.
from plyfile import PlyData, PlyElement
def load_ply_file(fname):
    """Load a .ply triangle mesh as a weighted point cloud: one Dirac atom per face.

    Args:
        fname: path to the .ply file.

    Returns:
        The (weights, locations) pair produced by `to_measure`.
    """
    # Load the data, and read the connectivity information:
    plydata = PlyData.read(fname)
    triangles = np.vstack(plydata["face"].data["vertex_indices"])

    # Read the vertex coordinates: the first three properties of each vertex
    # (presumably x, y, z — confirm for unusual .ply layouts). No normalization here;
    # that is done later by `normalize`.
    points = np.vstack([[v[0], v[1], v[2]] for v in plydata["vertex"]])

    return to_measure(points, triangles)
Utility: load “.nii” volume file.
import SimpleITK as sitk
from skimage.measure import marching_cubes
def load_nii_file(fname, threshold=0.5):
    """Extract the iso-surface of a .nii binary mask and return it as a weighted point cloud.

    The mask is meshed with the marching cubes algorithm at level `threshold`, and the
    resulting triangles are converted by `to_measure` (one atom per face).

    NOTE(review): SimpleITK arrays are indexed (z, y, x) in voxel units, so the vertex
    coordinates come out in that order and ignore the image spacing — confirm intended.
    """
    image = sitk.ReadImage(fname)
    mask = sitk.GetArrayFromImage(image)
    vertices, triangles, _normals, _values = marching_cubes(mask, threshold)
    return to_measure(vertices, triangles)
Synthetic sphere - a typical source measure:
def create_sphere(n_samples=1000):
    """Sample `n_samples` points quasi-uniformly on the unit sphere, with uniform weights.

    Points follow a golden-angle (Fibonacci) spiral: the polar angle makes cos(phi)
    uniform in (-1, 1), and the azimuth advances by pi * (1 + sqrt(5)) per point.

    Returns:
        (weights, points): an (n,) tensor of weights 1/n and an (n, 3) tensor of points.
    """
    count = int(n_samples)
    k = np.arange(count, dtype=float) + 0.5

    polar = np.arccos(1 - 2 * k / count)
    azimuth = np.pi * (1 + 5 ** 0.5) * k

    sin_polar = np.sin(polar)
    points = np.stack(
        [np.cos(azimuth) * sin_polar, np.sin(azimuth) * sin_polar, np.cos(polar)],
        axis=1,
    )
    weights = np.full(count, 1.0 / count)
    return tensor(weights), tensor(points)
Simple (slow) display routine:
def display_cloud(ax, measure, color):
    """Scatter-plot a weighted point cloud on a 3D matplotlib axis.

    Marker areas are proportional to the normalized weights. The camera and the
    [-1.4, 1.4]^3 bounding box are fixed so that successive plots are comparable.

    Args:
        ax: a 3D matplotlib axis.
        measure: (weights, locations) pair of tensors.
        color: anything accepted by `Axes3D.scatter(c=...)`.
    """
    weights_np = numpy(measure[0])
    points_np = numpy(measure[1])

    ax.view_init(elev=110, azim=-90)

    sizes = 25 * 500 * (weights_np / weights_np.sum())
    xs, ys, zs = points_np[:, 0], points_np[:, 1], points_np[:, 2]
    ax.scatter(xs, ys, zs, s=sizes, c=color)

    for set_limits in (ax.axes.set_xlim3d, ax.axes.set_ylim3d, ax.axes.set_zlim3d):
        set_limits(-1.4, 1.4)
Save the output as a folder of VTK files, to be rendered with ParaView:
# Destination of the VTK frames. Keep the trailing slash: `save_vtk` builds paths as folder + fname.
folder = "output/wasserstein_3D/"
os.makedirs(folder, exist_ok=True)
from pyvtk import PolyData, PointData, CellData, Scalars, VtkData, PointData
def save_vtk(fname, points, colors):
    """Write `points` as a binary VTK poly-data file of isolated vertices, colored by `colors`.

    The file goes to ``folder + fname``. N.B.: ParaView is a good VTK viewer, which
    supports ray-tracing.
    """
    vtk_file = VtkData(
        PolyData(points=points, vertices=np.arange(len(points))),
        PointData(Scalars(colors, name="colors")),
    )
    vtk_file.tofile(folder + fname, "binary")
Shall we work on subsampled data or at full resolution?
# Without a GPU, fall back to a quick, heavily subsampled demo.
fast_demo = not use_cuda

if use_cuda:
    # N.B.: fast_demo is False on this branch unless edited by hand above.
    Npoints = 1e4 if fast_demo else 2e5
else:
    Npoints = 1e3
Create a reference template:
# Source measure: Npoints atoms spread over the unit sphere, with uniform weights.
template = create_sphere(Npoints)
Use color labels to track the particles:
K = 12  # number of color oscillations across the sphere
# cos(K * x) depends only on the x coordinate, giving stripes that let us follow each
# particle of the template throughout the registrations.
colors = numpy((K * template[1][:, 0]).cos().view(-1))
Fetch the data:
os.makedirs("data", exist_ok=True)
if not os.path.exists("data/wasserstein_3D_models/Stanford_dragon_200k.ply"):
    print("Fetching the data... ", end="", flush=True)
    import shutil
    import urllib.request

    # Download the zipped models, then extract them into data/.
    # NOTE(review): the download call was lost from this copy; the URL below is
    # reconstructed — confirm it before relying on it.
    archive = "data/wasserstein_3D_models.zip"
    urllib.request.urlretrieve(
        "http://www.kernel-operations.io/data/wasserstein_3D_models.zip", archive
    )
    shutil.unpack_archive(archive, "data")
    print("Done.")
Load the data on the GPU:
print("Loading the data:")
# N.B.: Since Plyfile is far from being optimized, this may take some time!
# Order matters: targets[1] is the vertebra and targets[2] the brain (see the pose tweaks below).
# NOTE(review): the list entries were lost in this copy. The dragon path matches the
# download check; the vertebra/brain file names are reconstructed — confirm them
# against the contents of the downloaded archive.
targets = [
    load_ply_file("data/wasserstein_3D_models/Stanford_dragon_200k.ply"),
    load_ply_file("data/wasserstein_3D_models/vertebra.ply"),
    load_nii_file("data/wasserstein_3D_models/brain.nii.gz"),
]
Normalize every point cloud, subsampling it first if required:
def normalize(measure, n=None):
    """Subsample a weighted point cloud to at most `n` atoms, then normalize it.

    Weights are rescaled to sum to 1. Locations are centered on their weighted mean and
    scaled so that the weighted mean squared norm, sum_i w_i * |x_i|^2, equals 1.

    Args:
        measure: (weights, locations) pair of tensors with shapes (N,) and (N, D).
        n: optional maximum number of atoms (int or float). If smaller than N, a random
            subset of int(n) distinct atoms is kept.

    Returns:
        A new (weights, locations) pair. The input tensors are never modified.
    """
    weights, locations = measure
    N = len(weights)

    if n is not None and n < N:
        # Random subset of int(n) distinct atoms:
        indices = torch.randperm(N)[: int(n)]
        weights, locations = weights[indices], locations[indices]

    weights = (weights / weights.sum()).contiguous()
    locations = locations.contiguous()

    # Center on the weighted barycenter. Out-of-place on purpose: without subsampling,
    # `locations` may still alias the caller's tensor.
    mean = (weights.view(-1, 1) * locations).sum(dim=0)
    locations = locations - mean

    # Rescale to unit weighted second moment:
    std = (weights.view(-1) * (locations ** 2).sum(dim=1).view(-1)).sum().sqrt()
    locations = locations / std
    return weights, locations
# Subsample every target to at most Npoints atoms, then center and rescale it:
targets = [normalize(t, n=Npoints) for t in targets]
Fine-tuning: adjust the initial poses of the template and targets:
# Smaller sphere, shifted towards the back of the dragon:
template = (template[0], template[1] / 2 + tensor([0.5, 0.0, 0.0]))

# Turn the vertebra: right-multiplying by this permutation matrix swaps the x and z
# coordinates (N.B.: an axis swap is a reflection, det = -1, not a rotation).
targets[1] = (targets[1][0], targets[1][1] @ tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]]))

# Flip the brain: point reflection through the origin.
targets[2] = (targets[2][0], -targets[2][1])
Optimal Transport matchings
Define our solver:
import time
# Entropy-regularized OT loss from GeomLoss ("sinkhorn"), with exponent p=2 for the
# ground cost and blur=0.01 as the regularization length scale.
# NOTE(review): `scaling` and `truncate` tune GeomLoss's annealing / multiscale
# behavior — check the SamplesLoss documentation before changing them.
Loss = SamplesLoss("sinkhorn", p=2, blur=0.01, scaling=0.5, truncate=1)
def OT_registration(source, target, name):
    """Register `source` onto `target` by iterating the regularized Brenier map.

    Each iteration evaluates `Loss` between the moving cloud z and the target. It then
    moves every atom by -grad_z / a_i, the gradient rescaled by that atom's weight.

    Args:
        source: (weights, locations) pair for the moving measure.
        target: (weights, locations) pair for the fixed measure.
        name: label used in the timing message.

    Returns:
        The registered locations: a tensor shaped like source[1] that requires grad.
        The caller's source tensor is left unchanged.
    """
    a, x = source  # weights, locations
    b, y = target  # weights, locations

    # Private leaf copy: the caller's tensor is neither moved nor flagged for autograd.
    z = x.detach().clone().requires_grad_(True)  # Moving point cloud

    if use_cuda:
        torch.cuda.synchronize()  # CUDA calls are asynchronous: start on an idle device.
    start = time.time()

    nits = 4 if fast_demo else 10
    for _ in range(nits):
        wasserstein_zy = Loss(a, z, b, y)
        [grad_z] = torch.autograd.grad(wasserstein_zy, [z])
        # Update outside autograd so the graph does not grow across iterations:
        with torch.no_grad():
            z -= grad_z / a[:, None]  # Apply the regularized Brenier map

    if use_cuda:
        torch.cuda.synchronize()  # Wait for the last kernels before stopping the clock.
    end = time.time()
    print("Registered {} in {:.3f}s.".format(name, end - start))

    return z
Register the source onto the targets:
# Register the template sphere onto each target in turn (shape1, shape2, ...):
matchings = [
    OT_registration(template, target, f"shape{i}")
    for (i, target) in enumerate(targets, start=1)
]
Registered shape1 in 8.594s.
Registered shape2 in 17.520s.
Registered shape3 in 4.895s.
Display our matchings:
# One figure per target: registered template (colored labels) over the target (blue).
for (i, (matching, target)) in enumerate(zip(matchings, targets)):
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(1, 1, 1, projection="3d")
    display_cloud(ax, (template[0], matching), colors)
    display_cloud(ax, target, "blue")
    ax.set_title(
        "Registered (N={:,}) and target {} (M={:,}) point clouds".format(
            len(matching), i + 1, len(target[0])
        )
    )
    plt.tight_layout()

plt.show()
Save them as a collection of VTK files:
# Number of frames rendered for each segment (hold or morph) of the movie.
FPS = 32
source = template[1]  # undeformed template locations: first and last keyframe
# Movie script: source -> matching 1 -> ... -> matching K -> source.
# For each consecutive pair of keyframes, first hold the start pose, then morph to the next one.
keyframes = [source, *matchings, source]
pairs = [
    pair
    for (start, stop) in zip(keyframes[:-1], keyframes[1:])
    for pair in ((start, start), (start, stop))
]
frame = 0
print("Save as a VTK movie...", end="", flush=True)
for (A, B) in pairs:
    A, B = numpy(A), numpy(B)
    # Linear interpolation between the two poses, FPS frames per segment:
    for t in np.linspace(0, 1, FPS):
        save_vtk(f"frame_{frame}.vtk", (1 - t) * A + t * B, colors)
        frame += 1
print("Done.")
Save as a VTK movie...Done.
Total running time of the script: (3 minutes 4.567 seconds)