Creating a fancy interpolation video between 3D meshes.

N.B.: I am currently very busy writing my PhD thesis. Comments will come soon!

Setup

Standard imports.

import numpy as np
import torch
import os

# Run on the GPU whenever CUDA is available, and fall back to the CPU otherwise.
use_cuda = torch.cuda.is_available()

# Constructor used throughout the script to build float tensors on the selected device.
if use_cuda:
    tensor = torch.cuda.FloatTensor
else:
    tensor = torch.FloatTensor


def numpy(x):
    """Return tensor ``x`` as a NumPy array, detached from autograd and moved to the CPU."""
    return x.detach().cpu().numpy()

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from geomloss import SamplesLoss
from pykeops.torch import LazyTensor

Utility: turn a triangle mesh into a weighted point cloud.

def to_measure(points, triangles):
    """Encode a triangle mesh as a weighted point cloud, with one atom per face.

    Args:
        points: array of vertex coordinates, one row per vertex
            (presumably of shape (n_vertices, 3) -- the cross product and the
            printed message both assume 3D data).
        triangles: integer array with one row per face; columns 0, 1 and 2
            hold the indices of the three corners of the face.

    Returns:
        A pair ``(weights, locations)`` of tensors: the face areas rescaled to
        sum to one, and the face barycenters.
    """
    # Gather the three corners of every face.
    corner_a = points[triangles[:, 0]]
    corner_b = points[triangles[:, 1]]
    corner_c = points[triangles[:, 2]]

    # Each Dirac atom sits at the barycenter of its face...
    barycenters = (corner_a + corner_b + corner_c) / 3
    # ... and weighs as much as the face area: half the norm of the cross
    # product of two edge vectors.
    edge_cross = np.cross(corner_b - corner_a, corner_c - corner_a)
    areas = np.sqrt(np.sum(edge_cross ** 2, 1)) / 2

    n_atoms = len(barycenters)
    print(
        "File loaded, and encoded as the weighted sum of {:,} atoms in 3D.".format(
            n_atoms
        )
    )

    # Normalize the areas into a probability vector before wrapping as tensors.
    return tensor(areas / np.sum(areas)), tensor(barycenters)

Utility: load a “.ply” mesh file.

from plyfile import PlyData, PlyElement


def load_ply_file(fname):
    """Read a .ply mesh and return it as weighted Dirac atoms, one per triangle face.

    Args:
        fname: path to the .ply file.

    Returns:
        The ``(weights, locations)`` pair produced by ``to_measure``.
    """
    mesh = PlyData.read(fname)

    # Connectivity: stack the per-face index lists into one row per triangle.
    faces = np.vstack(mesh["face"].data["vertex_indices"])

    # Geometry: keep the first three fields of every vertex record
    # (presumably the x, y, z coordinates -- confirm against the file header).
    vertices = np.vstack(
        [[vertex[0], vertex[1], vertex[2]] for vertex in mesh["vertex"]]
    )

    return to_measure(vertices, faces)

Utility: load a “.nii” volume file.

import SimpleITK as sitk
from skimage.measure import marching_cubes


def load_nii_file(fname, threshold=0.5):
    """Extract a surface from a .nii binary mask and return it as a weighted point cloud.

    Args:
        fname: path to the volume file, read with SimpleITK.
        threshold: level passed to the marching cubes algorithm.

    Returns:
        The ``(weights, locations)`` pair produced by ``to_measure``.
    """
    volume = sitk.GetArrayFromImage(sitk.ReadImage(fname))
    # Optional speed-up, disabled by default:
    # volume = skimage.transform.downscale_local_mean(volume, (4,4,4))

    # Only the vertices and faces are needed; the normals and values are discarded.
    vertices, triangles, _normals, _values = marching_cubes(volume, threshold)

    return to_measure(vertices, triangles)

Synthetic sphere - a typical source measure:

def create_sphere(n_samples=1000):
    """Sample the unit sphere with a deterministic, evenly spread point set.

    Points are laid out along a spiral: their heights ``cos(phi)`` are evenly
    spaced in (-1, 1) and the azimuth grows by a fixed angle of
    ``pi * (1 + sqrt(5))`` from one point to the next.

    Args:
        n_samples: number of points; cast to ``int`` so that values such as
            ``1e4`` are accepted.

    Returns:
        A pair ``(weights, points)`` of tensors: uniform weights that sum to
        one, and one row of (x, y, z) coordinates per point.
    """
    n_samples = int(n_samples)

    # Half-integer ranks 0.5, 1.5, ..., n_samples - 0.5.
    ranks = np.arange(0, n_samples, dtype=float) + 0.5

    polar = np.arccos(1 - 2 * ranks / n_samples)
    azimuth = np.pi * (1 + 5 ** 0.5) * ranks

    # Spherical to Cartesian coordinates.
    sin_polar = np.sin(polar)
    x = np.cos(azimuth) * sin_polar
    y = np.sin(azimuth) * sin_polar
    z = np.cos(polar)

    points = np.vstack((x, y, z)).T
    weights = np.ones(n_samples) / n_samples

    return tensor(weights), tensor(points)

Simple (slow) display routine:

def display_cloud(ax, measure, color):
    """Scatter-plot a weighted point cloud on a 3D matplotlib axis.

    Args:
        ax: 3D axis to draw on.
        measure: pair ``(weights, locations)`` of tensors.
        color: value forwarded to the ``c`` argument of ``ax.scatter``.
    """
    masses = numpy(measure[0])
    positions = numpy(measure[1])

    ax.view_init(elev=110, azim=-90)
    # ax.set_aspect('equal')

    # Marker sizes are proportional to the normalized weights.
    marker_scale = 25 * 500
    sizes = marker_scale * (masses / masses.sum())
    ax.scatter(
        positions[:, 0],
        positions[:, 1],
        positions[:, 2],
        s=sizes,
        c=color,
    )

    # Fixed viewing box, identical on the three axes.
    ax.axes.set_xlim3d(left=-1.4, right=1.4)
    ax.axes.set_ylim3d(bottom=-1.4, top=1.4)
    ax.axes.set_zlim3d(bottom=-1.4, top=1.4)

Save the output as VTK files in a dedicated folder, to be rendered with ParaView:

# Destination of every VTK file written by this script. The trailing slash
# matters: file names are appended to this string as-is.
folder = "output/wasserstein_3D/"
# Create the directory from `folder` itself, so that the two never drift apart.
os.makedirs(os.path.dirname(folder), exist_ok=True)

from pyvtk import PolyData, PointData, CellData, Scalars, VtkData, PointData


def save_vtk(fname, points, colors):
    """Write a colored point cloud to ``folder + fname`` as a binary VTK file.

    N.B.: Paraview is a good VTK viewer, which supports ray-tracing.

    Args:
        fname: file name, appended to the module-level ``folder``.
        points: sequence of point coordinates, one entry per point.
        colors: one scalar per point, stored under the name "colors".
    """
    # Each point is registered as its own vertex cell.
    cloud = PolyData(points=points, vertices=np.arange(len(points)))
    labels = PointData(Scalars(colors, name="colors"))

    VtkData(cloud, labels).tofile(folder + fname, "binary")

Data

Shall we work on subsampled data or at full resolution?

# Default to the fast demo on CPU-only machines; edit by hand to force a mode.
fast_demo = not use_cuda

# Number of points per cloud. Kept as a float: the consumers cast it to int.
if not use_cuda:
    Npoints = 1e3
elif fast_demo:
    Npoints = 1e4
else:
    Npoints = 2e5

Create a reference template:

# Source measure: a (weights, points) pair sampled on the unit sphere.
template = create_sphere(Npoints)

Use color labels to track the particles:

# Frequency of the color stripes painted on the template.
K = 12

# One scalar label per template point: a cosine wave along the first coordinate,
# exported as a flat NumPy array for the plotting and VTK routines.
first_coordinate = template[1][:, 0]
colors = torch.cos(K * first_coordinate)
colors = colors.view(-1).detach().cpu().numpy()

Fetch the data:

os.makedirs(os.path.dirname("data/"), exist_ok=True)

# The dragon mesh acts as a marker: if it is missing, download and unpack the archive.
if not os.path.exists("data/wasserstein_3D_models/Stanford_dragon_200k.ply"):
    import shutil
    import urllib.request

    archive = "data/wasserstein_3D_models.zip"

    print("Fetching the data... ", end="", flush=True)
    urllib.request.urlretrieve(
        "http://www.kernel-operations.io/data/wasserstein_3D_models.zip",
        archive,
    )
    shutil.unpack_archive(archive, "data")
    print("Done.")

Load the data on the GPU (or on the CPU, if CUDA is not available):

print("Loading the data:")
# N.B.: Since Plyfile is far from being optimized, this may take some time!
dragon = load_ply_file("data/wasserstein_3D_models/Stanford_dragon_200k.ply")
vertebra = load_ply_file(
    "data/wasserstein_3D_models/vertebrae_400k_biol260_sketchfab_CC.ply"
)
brain = load_nii_file("data/wasserstein_3D_models/brain.nii.gz")

# Target measures, in the order in which they are matched and displayed.
targets = [dragon, vertebra, brain]
Loading the data:
File loaded, and encoded as the weighted sum of 202,520 atoms in 3D.
File loaded, and encoded as the weighted sum of 394,237 atoms in 3D.
File loaded, and encoded as the weighted sum of 382,074 atoms in 3D.

Normalize and subsample everyone, if required:

def normalize(measure, n=None):
    """Subsample, re-weight, center and rescale a weighted point cloud.

    The input tensors are never modified: all arithmetic is done out of place,
    so the caller's measure can safely be reused afterwards.

    Args:
        measure: pair ``(weights, locations)`` of torch tensors, with one
            weight and one row of coordinates per point.
        n: optional maximum number of points. When it is smaller than the
            size of the cloud, ``int(n)`` points are drawn at random without
            replacement. ``None`` (the default) keeps every point.

    Returns:
        A pair ``(weights, locations)`` where the weights sum to one, the
        weighted mean of the locations is zero and their weighted mean squared
        norm is one.
    """
    weights, locations = measure
    N = len(weights)

    if n is not None and n < N:
        n = int(n)  # Accept float sizes such as 1e4
        indices = torch.randperm(N)[:n]
        weights, locations = weights[indices], locations[indices]

    weights = weights / weights.sum()
    weights, locations = weights.contiguous(), locations.contiguous()

    # Center and normalize the point cloud. Out-of-place operations are used
    # on purpose: without subsampling, `locations` is still the caller's
    # tensor, and `-=` or `/=` would silently overwrite it.
    mean = (weights.view(-1, 1) * locations).sum(dim=0)
    locations = locations - mean
    std = (weights.view(-1) * (locations ** 2).sum(dim=1).view(-1)).sum().sqrt()
    locations = locations / std

    return weights, locations


# Bring every target to at most Npoints atoms, with unit total mass,
# zero weighted mean and unit weighted spread.
targets = [normalize(t, n=Npoints) for t in targets]

Fine tuning:

# Smaller sphere, towards the back of the dragon
sphere_weights, sphere_points = template
template = sphere_weights, sphere_points / 2 + tensor([0.5, 0.0, 0.0])

# Turn the vertebra: this matrix exchanges the first and third coordinates
swap_first_and_third = tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]])
vertebra_weights, vertebra_points = targets[1]
targets[1] = vertebra_weights, vertebra_points @ swap_first_and_third

# Flip the brain
brain_weights, brain_points = targets[2]
targets[2] = brain_weights, -brain_points

Optimal Transport matchings

Define our solver:

import time

# Sinkhorn loss between weighted point clouds, called as Loss(a, x, b, y).
# NOTE(review): per the GeomLoss documentation, `p=2` selects a squared-distance
# cost, `blur` is the finest length scale of the solver, `scaling` the ratio
# between successive scales and `truncate` a kernel truncation setting -- confirm
# against the installed GeomLoss version before changing any of these values.
Loss = SamplesLoss("sinkhorn", p=2, blur=0.01, scaling=0.5, truncate=1)


def OT_registration(source, target, name):
    """Move a copy of the source points onto the target with a few OT gradient steps.

    Args:
        source: pair ``(a, x)`` of tensors, weights and locations of the moving
            cloud. Neither tensor is modified.
        target: pair ``(b, y)`` of tensors, weights and locations of the fixed
            cloud.
        name: label used in the timing message.

    Returns:
        The registered point cloud: a tensor shaped like ``x`` that has
        ``requires_grad=True``, so convert it with ``.detach()`` before any
        NumPy export.
    """
    a, x = source  # weights, locations
    b, y = target  # weights, locations

    # Moving point cloud. It is a detached copy, so the caller's tensor keeps
    # its own `requires_grad` flag and is never written to.
    z = x.detach().clone().requires_grad_(True)

    if use_cuda:
        torch.cuda.synchronize()
    start = time.time()

    nits = 4 if fast_demo else 10

    for it in range(nits):
        wasserstein_zy = Loss(a, z, b, y)
        [grad_z] = torch.autograd.grad(wasserstein_zy, [z])

        # Apply the regularized Brenier map: dividing the gradient by the point
        # weights turns it into a per-point displacement. The update is excluded
        # from autograd so that no graph accumulates across iterations.
        with torch.no_grad():
            z -= grad_z / a[:, None]

        # save_vtk(f"matching_{name}_it_{it}.vtk", numpy(z), colors)

    # CUDA kernels run asynchronously: wait for them before stopping the clock.
    if use_cuda:
        torch.cuda.synchronize()
    end = time.time()
    print("Registered {} in {:.3f}s.".format(name, end - start))

    return z

Register the source onto the targets:

# One registered copy of the template per target, labelled shape1, shape2, ...
matchings = [
    OT_registration(template, shape, f"shape{i+1}")
    for i, shape in enumerate(targets)
]
Registered shape1 in 8.594s.
Registered shape2 in 17.520s.
Registered shape3 in 4.895s.

Display our matchings:

# One figure per target: the registered template (striped colors) over the target (blue).
for i, (matching, target) in enumerate(zip(matchings, targets)):
    fig = plt.figure(figsize=(6, 6))
    plt.set_cmap("hsv")
    ax = fig.add_subplot(1, 1, 1, projection="3d")

    # The registered cloud keeps the weights of the template.
    display_cloud(ax, (template[0], matching), colors)
    display_cloud(ax, target, "blue")

    title = "Registered (N={:,}) and target {} (M={:,}) point clouds".format(
        len(matching), i + 1, len(target[0])
    )
    ax.set_title(title)
    plt.tight_layout()
  • Registered (N=200,000) and target 1 (M=200,000) point clouds
  • Registered (N=200,000) and target 2 (M=200,000) point clouds
  • Registered (N=200,000) and target 3 (M=200,000) point clouds

Movie

Save them as a collection of VTK files:

# Number of frames written for each segment of the movie.
FPS = 32 if fast_demo else 32

source = template[1]

# The movie visits these shapes in order and comes back to the source at the end.
keyframes = [source, matchings[0], matchings[1], matchings[2], source]

# Each shape is first held still for one segment, then morphed into the next one.
pairs = []
for current, following in zip(keyframes, keyframes[1:]):
    pairs.append((current, current))
    pairs.append((current, following))

frame = 0

print("Save as a VTK movie...", end="", flush=True)
for first, last in pairs:
    first_np, last_np = numpy(first), numpy(last)
    # Linear interpolation between the two endpoints, both included.
    for t in np.linspace(0, 1, FPS):
        save_vtk(f"frame_{frame}.vtk", (1 - t) * first_np + t * last_np, colors)
        frame += 1

print("Done.")
plt.show()
Save as a VTK movie...Done.

Total running time of the script: ( 3 minutes 4.567 seconds)

Gallery generated by Sphinx-Gallery