Label transfer with Optimal Transport
Let’s use a regularized Optimal Transport plan to transfer labels from one point cloud to another.
Setup
Standard imports:
import numpy as np
import matplotlib.pyplot as plt
import time
import torch
use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
Display routines:
import imageio.v2 as imageio  # imageio's v2 API, to avoid the v3 DeprecationWarning


def load_image(fname):
    img = imageio.imread(fname)[::-1, :, :3]  # RGB, without the alpha channel
    return img / 255.0  # Normalized to [0, 1]
def display_samples(ax, x, color="black"):
    x_ = x.detach().cpu().numpy()
    if type(color) is not str:
        color = color.detach().cpu().numpy()
    ax.scatter(x_[:, 0], x_[:, 1], 25 * 500 / len(x_), color, edgecolors="none")
Draw labeled samples from an RGB image:
from random import choices


def draw_samples(fname, n, dtype=torch.FloatTensor, labels=False):
    A = load_image(fname)
    xg, yg = np.meshgrid(
        np.arange(A.shape[0]),
        np.arange(A.shape[1]),
        indexing="xy",
    )

    # Draw random coordinates according to the input density:
    A_gray = (1 - A).sum(2)
    grid = list(zip(xg.ravel(), yg.ravel()))
    dens = A_gray.ravel() / A_gray.sum()
    dots = np.array(choices(grid, dens, k=n))

    # Pick the correct labels:
    if labels:
        labs = A[dots[:, 1], dots[:, 0]].reshape((n, 3))

    # Normalize the coordinates to fit in the unit square, and add some noise:
    dots = (dots.astype(float) + 0.5) / np.array([A.shape[0], A.shape[1]])
    dots += (0.5 / A.shape[0]) * np.random.standard_normal(dots.shape)

    if labels:
        return torch.from_numpy(dots).type(dtype), torch.from_numpy(labs).type(dtype)
    else:
        return torch.from_numpy(dots).type(dtype)
Dataset
Our source and target samples are drawn from measures whose densities are stored in simple PNG files. They allow us to define a pair of discrete probability measures

$$\alpha \;=\; \frac{1}{N}\sum_{i=1}^{N} \delta_{x_i}, \qquad \beta \;=\; \frac{1}{M}\sum_{j=1}^{M} \delta_{y_j}$$

with uniform weights $\tfrac{1}{N}$ and $\tfrac{1}{M}$:
N, M = (500, 500) if not use_cuda else (10000, 10000)
X_i = draw_samples("data/threeblobs_a.png", N, dtype)
Y_j, l_j = draw_samples("data/threeblobs_b.png", M, dtype, labels=True)
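As a quick sanity check (a sketch, not part of the original script), we can inspect the shapes of the tensors we have just drawn: coordinates live in the unit square, and labels are RGB vectors.

# Coordinates are (n, 2) tensors; labels form one (M, 3) tensor of RGB values.
print(X_i.shape)             # torch.Size([N, 2]) -- the unlabeled cloud
print(Y_j.shape, l_j.shape)  # torch.Size([M, 2]), torch.Size([M, 3])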
In this tutorial, the labels $l_j$ attached to the samples $y_j$ of the labeled cloud are encoded as one-hot vectors of $\mathbb{R}^3$: $(1, 0, 0)$ for red points, $(0, 1, 0)$ for green points and $(0, 0, 1)$ for blue points. In the next few paragraphs, we’ll see how to use regularized Optimal Transport plans to transfer these labels from the $y_j$’s onto the unlabeled points $x_i$.
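To make this encoding concrete, here is a small sketch (the exact rows depend on the random draw) that decodes the one-hot vectors into integer classes:

# One-hot RGB labels can be decoded with an argmax: 0 = red, 1 = green, 2 = blue.
classes_j = l_j.argmax(dim=1)
print(l_j[:3])       # three one-hot rows, e.g. [[0., 0., 1.], [1., 0., 0.], ...]
print(classes_j[:3])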
plt.figure(figsize=(8, 8))
ax = plt.gca()
ax.scatter([10], [10]) # shameless hack to prevent a slight change of axis...
# Fancy display:
display_samples(ax, Y_j, l_j)
display_samples(ax, X_i)
ax.set_title("Source (Labeled) and Target point clouds")
ax.axis([0, 1, 0, 1])
ax.set_aspect("equal", adjustable="box")
plt.tight_layout()

Regularized Optimal Transport
The SamplesLoss("sinkhorn") layer relies on a fast multiscale solver for the regularized Optimal Transport problem:

$$\text{OT}_\varepsilon(\alpha,\beta) \;=\; \min_{\pi_1 = \alpha,\; \pi_2 = \beta} \;\langle\, \pi,\, \mathrm{C}\,\rangle \;+\; \varepsilon\,\text{KL}(\pi\,\|\,\alpha\otimes\beta),$$

where $\mathrm{C}(x,y) = \tfrac{1}{2}\|x-y\|^2$ is the squared Euclidean cost (p = 2) and $\varepsilon = \texttt{blur}^p$ is the temperature. By default, SamplesLoss computes the unbiased (positive, definite) Sinkhorn divergence

$$\text{S}_\varepsilon(\alpha,\beta) \;=\; \text{OT}_\varepsilon(\alpha,\beta) \;-\; \tfrac{1}{2}\text{OT}_\varepsilon(\alpha,\alpha) \;-\; \tfrac{1}{2}\text{OT}_\varepsilon(\beta,\beta)$$

and returns a differentiable scalar value. But if we set the optional parameters debias to False and potentials to True, we will instead get access to the optimal dual potentials $F$ and $G$, sampled on the points $x_i$ and $y_j$ respectively.
Note
By default, SamplesLoss("sinkhorn") uses an aggressive optimization heuristic where the blurring scale is halved between two successive iterations of the Sinkhorn loop, until reaching the required target value (scaling = .5). This choice is sensible when the Optimal Transport plan is used as a (cheap) gradient for an outer registration loop… But in this tutorial, where the plan itself drives the label transfer, setting the trade-off between speed (scaling close to .5) and accuracy (scaling close to 1) to a more conservative scaling = .9 is well worth it.
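To make this trade-off concrete, here is a rough timing sketch (not part of the original script; exact figures depend on your hardware) comparing the two settings:

# Rough benchmark sketch: scaling = .5 (fast) vs scaling = .9 (accurate).
from geomloss import SamplesLoss

for scaling in (0.5, 0.9):
    solver = SamplesLoss(
        "sinkhorn", p=2, blur=0.05, scaling=scaling, debias=False, potentials=True
    )
    start = time.time()
    F, G = solver(X_i, Y_j)
    if use_cuda:
        torch.cuda.synchronize()  # wait for the GPU before reading the clock
    print("scaling = {}: {:.3f}s".format(scaling, time.time() - start))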
from geomloss import SamplesLoss
blur = 0.05
OT_solver = SamplesLoss(
    "sinkhorn", p=2, blur=blur, scaling=0.9, debias=False, potentials=True
)
F_i, G_j = OT_solver(X_i, Y_j)
With a linear memory footprint, these two dual vectors encode a full transport plan on the product space:

$$\pi \;=\; \exp\!\Big(\tfrac{1}{\varepsilon}\,(F \oplus G - \mathrm{C})\Big)\cdot(\alpha\otimes\beta), \qquad \text{i.e.} \qquad \pi_{i,j} \;=\; \frac{1}{NM}\,\exp\!\Big(\tfrac{1}{\varepsilon}\,\big(F_i + G_j - \tfrac{1}{2}\|x_i - y_j\|^2\big)\Big),$$

which is such that $\pi_1 = \alpha$ and $\pi_2 = \beta$, up to convergence in the Sinkhorn loop.
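As a sanity check, here is a small sketch (not part of the original script) that materializes a few rows of $\pi$ from the dual potentials and verifies the first marginal constraint; the full $N$-by-$M$ matrix would be far too large in the GPU setting above.

# Verify pi_1 ~ alpha on a slice: each row of pi should sum to 1/N (sketch).
eps = blur ** 2
C_ij = ((X_i[:100, None, :] - Y_j[None, :, :]) ** 2).sum(-1) / 2  # (100, M) costs
pi_rows = ((F_i[:100].view(-1, 1) + G_j.view(1, -1) - C_ij) / eps).exp() / (N * M)
print(N * pi_rows.sum(1)[:5])  # should be close to [1., 1., 1., 1., 1.]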
Transfer of labels. To transport our source labels $l_j$ onto the target points $x_i$, we may then simply compute the barycentric combination

$$\text{Lab}_i \;=\; N\sum_{j=1}^{M} \pi_{i,j}\, l_j \;=\; \frac{1}{M}\sum_{j=1}^{M} \exp\!\Big(\tfrac{1}{\varepsilon}\,\big(F_i + G_j - \tfrac{1}{2}\|x_i - y_j\|^2\big)\Big)\, l_j$$

for all points $x_i$.
Implicit computations. Keep in mind, however, that the full transport plan $\pi$ is an $N$-by-$M$ matrix that should never be stored explicitly in memory when $N$ and $M$ range in the tens of thousands. Thankfully, the KeOps library lets us evaluate the reduction above on the fly, with a linear memory footprint:
from pykeops.torch import generic_sum

# Define our KeOps CUDA kernel:
transfer = generic_sum(
    "Exp( (F_i + G_j - IntInv(2)*SqDist(X_i,Y_j)) / E ) * L_j",  # See the formula above
    "Lab = Vi(3)",  # Output: one vector of size 3 per line
    "E = Pm(1)",    # 1st arg: a scalar parameter, the temperature
    "X_i = Vi(2)",  # 2nd arg: one 2d-point per line
    "Y_j = Vj(2)",  # 3rd arg: one 2d-point per column
    "F_i = Vi(1)",  # 4th arg: one scalar value per line
    "G_j = Vj(1)",  # 5th arg: one scalar value per column
    "L_j = Vj(3)",  # 6th arg: one vector of size 3 per column
)
# And apply it on the data (KeOps is pretty picky on the input shapes...):
labels_i = (
    transfer(
        torch.Tensor([blur ** 2]).type(dtype),
        X_i,
        Y_j,
        F_i.view(-1, 1),
        G_j.view(-1, 1),
        l_j,
    )
    / M
)
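Since each row of the (balanced) transport plan sums to $1/N$, every soft label $\text{Lab}_i$ sums to approximately one and can be read as a distribution over the three classes. A possible post-processing step, sketched below with hypothetical variable names, extracts hard classes and a confidence score:

# Hypothetical post-processing: hard classes + transported-mass confidence.
hard_classes_i = labels_i.argmax(dim=1)         # 0 = red, 1 = green, 2 = blue
confidence_i = labels_i.sum(dim=1).clamp(0, 1)  # ~1 when the full mass arrived
print(hard_classes_i[:5], confidence_i[:5])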
That’s it! We may now display our target point cloud $x_i$ with its new set of soft labels:
# sphinx_gallery_thumbnail_number = 2
plt.figure(figsize=(8, 8))
ax = plt.gca()
ax.scatter([10], [10]) # shameless hack to prevent a slight change of axis...
# Fancy display:
display_samples(ax, Y_j, l_j)
display_samples(ax, X_i, labels_i.clamp(0, 1))
ax.set_title("Labels transferred with Optimal Transport")
ax.axis([0, 1, 0, 1])
ax.set_aspect("equal", adjustable="box")
plt.tight_layout()

Unbalanced Optimal Transport
As evidenced above, the blur parameter allows us to smooth our optimal transport plan to remove noise in the final labelling. In most real-life situations, we may also wish to gain robustness against outliers by preventing samples from having too much influence outside of a fixed neighborhood.
SamplesLoss("sinkhorn") allows us to do so through the reach parameter, which is set to None (i.e. $\rho = +\infty$) by default. Specifying a finite reach amounts to solving the unbalanced Optimal Transport problem

$$\text{OT}_{\varepsilon,\rho}(\alpha,\beta) \;=\; \min_{\pi \,\geqslant\, 0} \;\langle\, \pi,\, \mathrm{C}\,\rangle \;+\; \varepsilon\,\text{KL}(\pi\,\|\,\alpha\otimes\beta) \;+\; \rho\,\text{KL}(\pi_1\,\|\,\alpha) \;+\; \rho\,\text{KL}(\pi_2\,\|\,\beta),$$

where the hard marginal constraints have been replaced by a soft Kullback-Leibler penalty whose strength is specified through a positive parameter $\rho = \texttt{reach}^p$:
OT_solver = SamplesLoss(
    "sinkhorn", p=2, blur=blur, reach=0.2, scaling=0.9, debias=False, potentials=True
)
F_i, G_j = OT_solver(X_i, Y_j)
# And apply it on the data:
labels_i = (
    transfer(
        torch.Tensor([blur ** 2]).type(dtype),
        X_i,
        Y_j,
        F_i.view(-1, 1),
        G_j.view(-1, 1),
        l_j,
    )
    / M
)
As we display our new set of labels, we can check that colors don’t get transported beyond the specified reach = .2: target points which are too far away from the labeled source simply stay black, with a soft label $\text{Lab}_i$ that is close to $(0, 0, 0)$.
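Since the total transported mass now tells us how much information each point received, we can flag these out-of-reach points explicitly; the sketch below uses a hypothetical threshold of 0.5:

# Flag points that received less than half of their expected mass (sketch):
mass_i = labels_i.sum(dim=1)  # close to 1 within the reach, lower beyond it
outliers = mass_i < 0.5       # hypothetical confidence threshold
print(int(outliers.sum()), "target points are left (mostly) unlabeled")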
plt.figure(figsize=(8, 8))
ax = plt.gca()
ax.scatter([10], [10]) # shameless hack to prevent a slight change of axis...
display_samples(ax, Y_j, l_j)
display_samples(ax, X_i, labels_i.clamp(0, 1))
ax.set_title("Labels transferred with unbalanced Optimal Transport")
ax.axis([0, 1, 0, 1])
ax.set_aspect("equal", adjustable="box")
plt.tight_layout()
plt.show()
