Sampling on the 3D rotation group
Let’s show how to sample some distributions on the compact manifold SO(3).
Introduction
First, some standard imports.
import numpy as np
import torch
from matplotlib import pyplot as plt
import warnings
import matplotlib.cbook
warnings.filterwarnings("ignore", category=matplotlib.cbook.mplDeprecation)
plt.rcParams.update({"figure.max_open_warning": 0})
use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
All functions that are relevant to the rotation group are stored in the monaco.rotations submodule. We first create our manifold, on the GPU whenever possible.
from monaco.rotations import Rotations
space = Rotations(dtype=dtype)
Rotations are encoded as quaternions and displayed as Euler vectors (the rotation axis scaled by the rotation angle) in the ball of radius π. Antipodal points on the boundary sphere are identified with each other, since they describe the same rotation by an angle of π. Here, we create two arbitrary rotations and sample points at random in their neighborhood.
N = 10000 if use_cuda else 50
ref = torch.ones(N, 1).type(dtype) * torch.FloatTensor([1, 0, 0, 0]).type(dtype)
ref2 = torch.ones(N, 1).type(dtype) * torch.FloatTensor([1, 5, 5, 0]).type(dtype)
ref = torch.cat((ref, ref2), dim=0)
from monaco.rotations import BallProposal
proposal = BallProposal(space, scale=0.5)
x = proposal.sample(ref)
# Display the initial configuration:
plt.figure(figsize=(8, 8))
space.scatter(x, "red")
space.draw_frame()
plt.tight_layout()
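The display convention and the BallProposal sampler above can be mimicked in a few lines of NumPy. The sketch below is not monaco's implementation. It assumes scalar-first quaternions (w, x, y, z), as suggested by the reference rotation [1, 0, 0, 0]. It also assumes that a ball proposal draws a perturbation uniformly in a ball of Euler vectors and composes it with the reference rotation. The helpers quat_multiply, exp_map, quat_to_euler_vector and ball_proposal_sample are hypothetical names.

```python
import numpy as np


def quat_multiply(p, q):
    """Hamilton product of quaternions (w, x, y, z), batched over the first axis."""
    pw, px, py, pz = p.T
    qw, qx, qy, qz = q.T
    return np.stack([
        pw * qw - px * qx - py * qy - pz * qz,
        pw * qx + px * qw + py * qz - pz * qy,
        pw * qy - px * qz + py * qw + pz * qx,
        pw * qz + px * qy - py * qx + pz * qw,
    ], axis=1)


def exp_map(v):
    """Converts Euler vectors (N, 3) into unit quaternions (N, 4)."""
    angle = np.linalg.norm(v, axis=1, keepdims=True)
    safe = np.where(angle > 0, angle, 1.0)  # avoids 0 / 0 at the identity
    # sin(angle / 2) / angle, with its limit 1/2 at angle = 0:
    coef = np.where(angle > 0, np.sin(angle / 2) / safe, 0.5)
    return np.concatenate([np.cos(angle / 2), coef * v], axis=1)


def quat_to_euler_vector(q):
    """Maps quaternions (N, 4) to Euler vectors (N, 3) of norm at most pi."""
    q = np.asarray(q, dtype=float)
    q = q / np.linalg.norm(q, axis=1, keepdims=True)  # project onto unit quaternions
    # q and -q encode the same rotation: keep the representative with w >= 0.
    q = np.where(q[:, :1] < 0, -q, q)
    w, v = q[:, 0], q[:, 1:]
    sin_half = np.linalg.norm(v, axis=1)  # |v| = sin(angle / 2)
    angle = 2 * np.arctan2(sin_half, w)  # in [0, pi] because w >= 0
    # Unit rotation axis, set to zero for the identity rotation:
    axis = np.divide(v, sin_half[:, None], out=np.zeros_like(v),
                     where=sin_half[:, None] > 0)
    return axis * angle[:, None]


def ball_proposal_sample(ref, scale, rng):
    """Composes each quaternion of ref (N, 4) with a random rotation whose
    Euler vector is uniform in the ball of radius `scale` (hypothetical sketch)."""
    n = len(ref)
    direction = rng.normal(size=(n, 3))
    direction /= np.linalg.norm(direction, axis=1, keepdims=True)
    radius = scale * rng.uniform(size=(n, 1)) ** (1 / 3)  # uniform in volume
    return quat_multiply(ref, exp_map(radius * direction))


# Usage: perturb the identity and measure the size of the resulting rotations.
rng = np.random.default_rng(0)
ref_demo = np.tile([1.0, 0.0, 0.0, 0.0], (5, 1))
x_demo = ball_proposal_sample(ref_demo, 0.5, rng)
print(np.linalg.norm(quat_to_euler_vector(x_demo), axis=1))  # all values <= 0.5
```

Drawing the radius as scale * U**(1/3) makes the perturbation uniform in volume in Euler-vector coordinates. That is not the same as being uniform with respect to the Haar measure of SO(3). At an angle of exactly π, q and -q both have w = 0 and land on antipodal points of the sphere of radius π. This is why those boundary points are identified.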
Procrustes analysis
We consider the von Mises distribution associated with a Procrustes registration problem:
from monaco.rotations import quat_to_matrices
class ProcrustesDistribution(object):
    def __init__(self, source, target, temperature=1.0):
        self.source = source  # (npoints, 3) point cloud to rotate
        self.target = target  # (npoints, 3) point cloud, in corresponding order
        self.temperature = temperature
    def potential(self, q):
        """Evaluates the potential at a batch of quaternions q."""
        R = quat_to_matrices(q)  # (N, 3, 3)
        models = R @ self.source.t()  # (N, 3, npoints)
        # Mean squared distance between corresponding points, for each rotation:
        V_i = ((models - self.target.t().view(1, 3, -1)) ** 2).mean(2).sum(1)
        return V_i.view(-1) / (2 * self.temperature)  # (N,)
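Why call this a von Mises distribution? Write a_i and b_i for the n points of source and target. As in the code, the two clouds are assumed to have the same number of points, listed in corresponding order. Write T for the temperature. The potential method computes V(R) below. Since rotations preserve norms, expanding the square leaves a single term that is linear in R:

```latex
V(R) = \frac{1}{2T}\cdot\frac{1}{n}\sum_{i=1}^{n}\|R\,a_i - b_i\|^2
     = \frac{1}{2nT}\sum_{i=1}^{n}\big(\|a_i\|^2 + \|b_i\|^2\big)
       - \frac{1}{nT}\,\operatorname{tr}\big(F^\top R\big),
\qquad F = \sum_{i=1}^{n} b_i\,a_i^\top,

p(R) \;\propto\; \exp\big(-V(R)\big) \;\propto\; \exp\Big(\tfrac{1}{nT}\,\operatorname{tr}\big(F^\top R\big)\Big).
```

The first sum does not depend on R, so the target density is a matrix von Mises-Fisher distribution on SO(3) with parameter F/(nT). It peaks at the rotation that maximizes tr(F^T R), which is the classical Procrustes (Kabsch) solution, and it sharpens as T decreases.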
Then, we load two proteins as point clouds in the ambient 3D space:
def load_csv(fname):
    x = np.loadtxt(fname, skiprows=1, delimiter=",")  # skip the header row
    x = torch.from_numpy(x).type(dtype)
    x -= x.mean(0)  # center the point cloud at the origin
    scale = (x**2).sum(1).mean(0)  # mean squared norm of the points
    x /= scale
    return x
A = load_csv("data/Ca1UBQ.csv")
B = load_csv("data/Ca1D3Z_1.csv")
distribution = ProcrustesDistribution(A, B, temperature=1e-4)
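Readers without monaco or a GPU can prototype the preprocessing and the Procrustes potential in plain NumPy. In the sketch below, normalize_cloud mirrors the body of load_csv on an in-memory array. Note that it divides by the mean squared norm s rather than by its square root, so the rescaled cloud has mean squared norm 1/s. quat_to_matrix_np uses the standard rotation-matrix formula for scalar-first quaternions. procrustes_potential_np mirrors ProcrustesDistribution.potential. All three names are hypothetical stand-ins, not monaco functions.

```python
import numpy as np


def normalize_cloud(x):
    """Mirrors load_csv's preprocessing on an (npoints, 3) array."""
    x = np.asarray(x, dtype=float)
    x = x - x.mean(axis=0)  # center the cloud at the origin
    scale = (x**2).sum(axis=1).mean(axis=0)  # mean squared norm s
    return x / scale


def quat_to_matrix_np(q):
    """Converts quaternions (N, 4), scalar first, to rotation matrices (N, 3, 3)."""
    q = q / np.linalg.norm(q, axis=1, keepdims=True)
    w, x, y, z = q.T
    return np.stack([
        np.stack([1 - 2 * (y**2 + z**2), 2 * (x * y - w * z), 2 * (x * z + w * y)], axis=1),
        np.stack([2 * (x * y + w * z), 1 - 2 * (x**2 + z**2), 2 * (y * z - w * x)], axis=1),
        np.stack([2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x**2 + y**2)], axis=1),
    ], axis=1)


def procrustes_potential_np(q, source, target, temperature=1.0):
    """Mean squared registration error over 2 * temperature, one value per quaternion."""
    R = quat_to_matrix_np(q)  # (N, 3, 3)
    models = R @ source.T  # (N, 3, npoints)
    sq_err = ((models - target.T[None]) ** 2).mean(axis=2).sum(axis=1)
    return sq_err / (2 * temperature)  # (N,)


# Usage: register a cloud onto a rotated copy of itself.
rng = np.random.default_rng(0)
src = normalize_cloud(rng.normal(size=(30, 3)))
q_true = np.array([[np.cos(np.pi / 8), 0.0, 0.0, np.sin(np.pi / 8)]])  # 45° about z
tgt = src @ quat_to_matrix_np(q_true)[0].T
# q_true and -q_true encode the same rotation, so both give a zero potential:
print(procrustes_potential_np(np.concatenate([q_true, -q_true]), src, tgt))
```

The matrix formula is quadratic in q, so q and -q always give the same rotation and the same potential. This is consistent with the antipodal identification used for the display.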
Finally, we use the MOKA sampler on this distribution:
from monaco.samplers import MOKA_CMC, display_samples
N = 10000 if use_cuda else 50
start = space.uniform_sample(N)
proposal = BallProposal(space, scale=[0.1, 0.2, 0.5, 1.0, 2.0])
moka_sampler = MOKA_CMC(space, start, proposal, annealing=5).fit(distribution)
display_samples(moka_sampler, iterations=100, runs=2, small=False)
Out:
{'iteration': array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,
91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 1,
2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66,
67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92,
93, 94, 95, 96, 97, 98, 99, 100, 101, 102]), 'rate': array([0.95429999, 0.57730001, 0.51959997, 0.51449996, 0.51849997,
0.52719998, 0.53240001, 0.54629999, 0.56330001, 0.59969997,
0.63690001, 0.68039995, 0.72009999, 0.74169999, 0.75379997,
0.75580001, 0.7719 , 0.75940001, 0.77289999, 0.7737 ,
0.7622 , 0.77569997, 0.76559997, 0.76479995, 0.77739996,
0.77129996, 0.77349997, 0.77639997, 0.7694 , 0.7701 ,
0.77160001, 0.77819997, 0.76809996, 0.78239995, 0.77749997,
0.78789997, 0.76980001, 0.77669996, 0.77179998, 0.76879996,
0.77289999, 0.76889998, 0.77359998, 0.77569997, 0.77269995,
0.7723 , 0.77599996, 0.76429999, 0.7802 , 0.77209997,
0.77770001, 0.76010001, 0.77859998, 0.7665 , 0.7809 ,
0.77109998, 0.77019995, 0.78049999, 0.77019995, 0.76959997,
0.77340001, 0.76569998, 0.76349998, 0.76620001, 0.77139997,
0.78329998, 0.75949997, 0.7723 , 0.76559997, 0.7784 ,
0.76819998, 0.76999998, 0.7669 , 0.7762 , 0.7665 ,
0.78189999, 0.77289999, 0.7694 , 0.7766 , 0.77520001,
0.7694 , 0.77869999, 0.78639996, 0.76669997, 0.77629995,
0.76809996, 0.76289999, 0.78599995, 0.76739997, 0.77599996,
0.77289999, 0.76639998, 0.78740001, 0.7633 , 0.77129996,
0.77139997, 0.77639997, 0.76620001, 0.76440001, 0.76370001,
0.76419997, 0.77429998, 0.90880001, 0.57609999, 0.5068 ,
0.5126 , 0.51249999, 0.51539999, 0.52489996, 0.5449 ,
0.56529999, 0.60679996, 0.62939996, 0.68579996, 0.72959995,
0.74970001, 0.7626 , 0.75709999, 0.77459997, 0.76349998,
0.75989997, 0.77139997, 0.77999997, 0.76959997, 0.76909995,
0.77329999, 0.7723 , 0.76629996, 0.77700001, 0.77179998,
0.77520001, 0.77649999, 0.76489997, 0.77509999, 0.76069999,
0.76789999, 0.76370001, 0.76550001, 0.78609997, 0.77179998,
0.78219998, 0.77489996, 0.76069999, 0.7845 , 0.76959997,
0.77129996, 0.77179998, 0.7766 , 0.77239996, 0.76199996,
0.76319999, 0.7802 , 0.7712 , 0.76550001, 0.78059995,
0.76529998, 0.77559996, 0.75729996, 0.77429998, 0.75709999,
0.77459997, 0.76539999, 0.76370001, 0.77489996, 0.7694 ,
0.76629996, 0.78219998, 0.77449995, 0.76370001, 0.7651 ,
0.77279997, 0.77569997, 0.76669997, 0.78299999, 0.77559996,
0.7726 , 0.77219999, 0.78079998, 0.77199996, 0.77309996,
0.76909995, 0.76739997, 0.76309997, 0.76629996, 0.77199996,
0.7694 , 0.77999997, 0.76999998, 0.76319999, 0.77809995,
0.7712 , 0.7773 , 0.7694 , 0.76949996, 0.78169996,
0.76909995, 0.77129996, 0.76859999, 0.7608 , 0.77579999,
0.7766 , 0.76989996, 0.77779996, 0.77419996]), 'normalizing constant': array([0.000318 , 0.00043131, 0.0006636 , 0.00058932, 0.00062415,
0.00055452, 0.00058131, 0.00053881, 0.0005416 , 0.00057076,
0.0005522 , 0.00054144, 0.00054438, 0.00054982, 0.00054566,
0.00055339, 0.00054652, 0.00055227, 0.00054904, 0.00054955,
0.00054651, 0.00054974, 0.00054796, 0.00054754, 0.0005532 ,
0.00055009, 0.00055059, 0.00054802, 0.00054532, 0.00055228,
0.00054816, 0.00055046, 0.00054803, 0.00055509, 0.00054939,
0.00055028, 0.00055258, 0.00055124, 0.00054663, 0.00054746,
0.00054896, 0.00054569, 0.00054929, 0.00055151, 0.00055121,
0.00054733, 0.00054864, 0.00054647, 0.00055306, 0.00054993,
0.00055126, 0.00054835, 0.00054966, 0.00055007, 0.00055092,
0.00054907, 0.0005484 , 0.00054693, 0.00055041, 0.0005486 ,
0.00054802, 0.00054776, 0.00054809, 0.00055102, 0.00054364,
0.00055325, 0.00054825, 0.00055118, 0.0005491 , 0.00055134,
0.00055103, 0.00055057, 0.00054697, 0.00054814, 0.00054818,
0.0005522 , 0.00054856, 0.00055061, 0.00055313, 0.00054963,
0.00054927, 0.00054784, 0.00055016, 0.00054709, 0.00054855,
0.00054391, 0.00055235, 0.00055381, 0.00054771, 0.00055002,
0.00055176, 0.00055466, 0.0005504 , 0.00054883, 0.0005503 ,
0.00055376, 0.00054883, 0.00054734, 0.00054695, 0.00054594,
0.00054756, 0.0005522 , 0.00064543, 0.00048963, 0.00065833,
0.00045164, 0.00054626, 0.0006193 , 0.00052453, 0.00051454,
0.00055273, 0.00052458, 0.00055163, 0.00055412, 0.00054958,
0.00055538, 0.00055066, 0.00054616, 0.00054876, 0.00054885,
0.00054347, 0.00055073, 0.00055212, 0.00055024, 0.00054668,
0.00054807, 0.00054899, 0.00054994, 0.00054819, 0.00054899,
0.0005547 , 0.00054934, 0.00055066, 0.00055071, 0.00054689,
0.00054766, 0.00054693, 0.00054984, 0.00055203, 0.00054608,
0.00055067, 0.00055245, 0.00054761, 0.00054997, 0.00054892,
0.00054899, 0.0005503 , 0.00055039, 0.00054872, 0.00054632,
0.00054669, 0.00054876, 0.0005501 , 0.0005502 , 0.00055035,
0.00054524, 0.00054873, 0.00054805, 0.00054687, 0.00054954,
0.00055027, 0.00054842, 0.00054826, 0.00055131, 0.00054932,
0.00054745, 0.00055105, 0.00055144, 0.00054731, 0.00054585,
0.0005502 , 0.00054897, 0.00055401, 0.00055015, 0.0005483 ,
0.00055173, 0.00054861, 0.00054828, 0.00054702, 0.00054853,
0.00054761, 0.00055084, 0.00054905, 0.00054787, 0.00054629,
0.00055021, 0.00054942, 0.0005485 , 0.00055015, 0.00055086,
0.00055104, 0.0005503 , 0.00054993, 0.00055019, 0.00054945,
0.00055052, 0.00054615, 0.00054817, 0.00054363, 0.00055264,
0.00054863, 0.00055311, 0.00055165, 0.00055429]), 'error': [], 'fluctuation': [], 'probas': array([[0.15004013, 0.20752017, 0.20654385, ..., 0.45694146, 0.4498341 ,
0.45580137],
[0.19785951, 0.2065739 , 0.36189184, ..., 0.39082795, 0.39970857,
0.39993352],
[0.21729353, 0.17098795, 0.29464155, ..., 0.13261054, 0.13084137,
0.12465435],
[0.21744092, 0.2359102 , 0.12694044, ..., 0.00981004, 0.00980794,
0.00980539],
[0.21736595, 0.17900777, 0.00998236, ..., 0.00981004, 0.00980794,
0.00980539]], dtype=float32), 'number of neighbours': array([[4.66498604e+01, 4.47234802e+01, 5.65862732e+01, ...,
1.46109678e+04, 1.46680381e+04, 1.45732832e+04],
[3.73198883e+02, 3.57787842e+02, 4.52690186e+02, ...,
1.16887742e+05, 1.17344305e+05, 1.16586266e+05],
[5.83123291e+03, 5.59043457e+03, 7.07328369e+03, ...,
1.82637088e+06, 1.83350462e+06, 1.82166050e+06],
[4.66498633e+04, 4.47234766e+04, 5.65862695e+04, ...,
1.46109670e+07, 1.46680370e+07, 1.45732840e+07],
[3.73198906e+05, 3.57787812e+05, 4.52690156e+05, ...,
1.16887736e+08, 1.17344296e+08, 1.16586272e+08]], dtype=float32), 'ESS': []}
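This example does not describe the internals of MOKA_CMC. For comparison only, here is a plain random-walk Metropolis-Hastings sampler on SO(3) in NumPy. It runs many chains in parallel and uses a fixed, uniform mixture over the same proposal radii [0.1, 0.2, 0.5, 1.0, 2.0]. In the output above, the 'probas' entry has one row per radius and each of its columns sums to one. This suggests that MOKA adapts the mixture weights over the iterations; the sketch below does not. It is not a reimplementation of MOKA, all function names are hypothetical, and the quaternion helpers are repeated so that the block is self-contained.

```python
import numpy as np


def quat_multiply(p, q):
    """Hamilton product of quaternions (w, x, y, z), batched over the first axis."""
    pw, px, py, pz = p.T
    qw, qx, qy, qz = q.T
    return np.stack([
        pw * qw - px * qx - py * qy - pz * qz,
        pw * qx + px * qw + py * qz - pz * qy,
        pw * qy - px * qz + py * qw + pz * qx,
        pw * qz + px * qy - py * qx + pz * qw,
    ], axis=1)


def exp_map(v):
    """Converts Euler vectors (N, 3) into unit quaternions (N, 4)."""
    angle = np.linalg.norm(v, axis=1, keepdims=True)
    safe = np.where(angle > 0, angle, 1.0)
    coef = np.where(angle > 0, np.sin(angle / 2) / safe, 0.5)
    return np.concatenate([np.cos(angle / 2), coef * v], axis=1)


def metropolis_so3(potential, start, scales, n_iter, rng):
    """Random-walk Metropolis-Hastings on SO(3) with a uniform mixture of ball proposals.

    potential maps quaternions (N, 4) to energies (N,) and must be even in q.
    The target density is proportional to exp(-potential) w.r.t. the Haar measure.
    Returns the final quaternions and the acceptance rate at each iteration.
    """
    q = start / np.linalg.norm(start, axis=1, keepdims=True)
    V = potential(q)
    scales = np.asarray(scales, dtype=float)
    rates = []
    for _ in range(n_iter):
        n = len(q)
        # Pick one radius per chain, then a perturbation uniform in that ball:
        r_max = rng.choice(scales, size=(n, 1))
        direction = rng.normal(size=(n, 3))
        direction /= np.linalg.norm(direction, axis=1, keepdims=True)
        radius = r_max * rng.uniform(size=(n, 1)) ** (1 / 3)
        proposal = quat_multiply(q, exp_map(radius * direction))
        V_new = potential(proposal)
        # v and -v are equally likely, so the kernel is symmetric and the
        # acceptance test only involves the target density.
        # log1p(-u) = log(1 - u) stays finite because u < 1.
        accept = np.log1p(-rng.uniform(size=n)) < V - V_new
        q = np.where(accept[:, None], proposal, q)
        V = np.where(accept, V_new, V)
        rates.append(accept.mean())
    return q, np.array(rates)


# Usage: a toy target concentrated around the identity rotation.
def toy_potential(q):
    return -50.0 * q[:, 0] ** 2  # even in q, as required


rng = np.random.default_rng(0)
q_final, rates = metropolis_so3(toy_potential, rng.normal(size=(2000, 4)),
                                [0.1, 0.2, 0.5, 1.0, 2.0], 200, rng)
print(rates[-5:], np.mean(q_final[:, 0] ** 2))
```

The perturbation is composed on the right and its Euler vector is symmetric under v -> -v, so proposing g or its inverse is equally likely. This is what keeps the acceptance ratio free of proposal terms.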
Wasserstein potential
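We rely on the GeomLoss library to define a Procrustes-like potential. Here, the discrepancy between the two point clouds is measured with a Sinkhorn (entropy-regularized) approximation of the squared Wasserstein distance. Unlike the Procrustes potential, this discrepancy does not require the points of the two clouds to be listed in corresponding order.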
N = 10000 if use_cuda else 50
from monaco.rotations import quat_to_matrices
from geomloss import SamplesLoss
wasserstein = SamplesLoss("sinkhorn", p=2, blur=0.05)
class WassersteinDistribution(object):
    def __init__(self, source, target, temperature=1.0):
        self.source = source
        self.target = target
        self.temperature = temperature
    def potential(self, q):
        """Evaluates the potential at a batch of quaternions q."""
        R = quat_to_matrices(q)  # (N, 3, 3)
        models = R @ self.source.t()  # (N, 3, npoints)
        N = len(models)
        # GeomLoss expects batches of point clouds with shape (N, npoints, 3):
        models = models.permute(0, 2, 1).contiguous()
        targets = self.target.repeat(N, 1, 1).contiguous()
        # One Sinkhorn loss value per candidate rotation:
        V_i = wasserstein(models, targets)
        return V_i.view(-1) / self.temperature  # (N,)
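The following is a small, unoptimized NumPy reimplementation of a debiased Sinkhorn divergence between two uniformly weighted point clouds. It gives an idea of what SamplesLoss("sinkhorn", p=2, blur=0.05) computes for each candidate rotation. It assumes what we understand to be GeomLoss's conventions for p=2: a ground cost of |x - y|^2 / 2 and an entropic regularization ε = blur**2. This is not GeomLoss's code, and it omits refinements such as ε-scaling. The function names are hypothetical.

```python
import numpy as np


def _logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = a.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True))).squeeze(axis)


def entropic_ot(x, y, eps, n_iter):
    """Entropic OT cost between uniform clouds x (n, d) and y (m, d), log-domain Sinkhorn."""
    C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1) / 2  # (n, m) cost |x - y|^2 / 2
    log_a = np.full(len(x), -np.log(len(x)))  # uniform weights, in log scale
    log_b = np.full(len(y), -np.log(len(y)))
    f, g = np.zeros(len(x)), np.zeros(len(y))
    for _ in range(n_iter):
        # Alternate "soft-min" updates of the two dual potentials:
        f = -eps * _logsumexp(log_b[None, :] + (g[None, :] - C) / eps, axis=1)
        g = -eps * _logsumexp(log_a[:, None] + (f[:, None] - C) / eps, axis=0)
    # At convergence, the dual objective reduces to <a, f> + <b, g>:
    return np.exp(log_a) @ f + np.exp(log_b) @ g


def sinkhorn_divergence(x, y, blur=0.05, n_iter=200):
    """Debiased divergence S(x, y) = OT(x, y) - OT(x, x) / 2 - OT(y, y) / 2."""
    eps = blur**2
    return (entropic_ot(x, y, eps, n_iter)
            - 0.5 * entropic_ot(x, x, eps, n_iter)
            - 0.5 * entropic_ot(y, y, eps, n_iter))


# Usage: compare a cloud with a translated copy of itself.
rng = np.random.default_rng(0)
cloud = 0.3 * rng.normal(size=(20, 3))
t = np.array([0.3, -0.2, 0.1])
print(sinkhorn_divergence(cloud, cloud + t, blur=0.5), 0.5 * (t**2).sum())
```

With the |x - y|^2 / 2 cost, a pure translation by t gives exactly |t|^2 / 2 once the iterations have converged, that is, half the squared Wasserstein distance. Without ε-scaling, small blur values such as 0.05 may need many more iterations to converge. In WassersteinDistribution.potential, one such divergence is evaluated per candidate rotation, between the rotated source and the fixed target, and then divided by the temperature.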
Just as in the example above, we use the MOKA algorithm to generate samples from this distribution:
distribution = WassersteinDistribution(A, B, temperature=1e-4)
start = space.uniform_sample(N)
proposal = BallProposal(space, scale=[0.1, 0.2, 0.5, 1.0, 2.0])
moka_sampler = MOKA_CMC(space, start, proposal, annealing=5).fit(distribution)
display_samples(moka_sampler, iterations=100, runs=1, small=False)
Out:
{'iteration': array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,
91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102]), 'rate': array([0.95559996, 0.93789995, 0.90779996, 0.88439995, 0.86429995,
0.85339999, 0.8398 , 0.8351 , 0.8276 , 0.82239997,
0.82989997, 0.83199996, 0.82379997, 0.8193 , 0.8251 ,
0.81949997, 0.81769997, 0.824 , 0.82019997, 0.8179 ,
0.81689996, 0.81629997, 0.80539995, 0.81799996, 0.82159996,
0.81439996, 0.81309998, 0.81489998, 0.81379998, 0.81629997,
0.81619996, 0.81900001, 0.81639999, 0.81089997, 0.8161 ,
0.81209999, 0.81239998, 0.824 , 0.81349999, 0.8168 ,
0.81449997, 0.81559998, 0.80789995, 0.81569999, 0.81009996,
0.8161 , 0.80409998, 0.8114 , 0.81290001, 0.81219995,
0.81229997, 0.8168 , 0.81809998, 0.81779999, 0.81830001,
0.8211 , 0.81629997, 0.81419998, 0.81119996, 0.81769997,
0.81219995, 0.81459999, 0.81639999, 0.82139999, 0.81699997,
0.81239998, 0.81589997, 0.81040001, 0.81509995, 0.81739998,
0.81279999, 0.81979996, 0.81089997, 0.81169999, 0.8175 ,
0.82059997, 0.82169998, 0.82139999, 0.81299996, 0.80799997,
0.81299996, 0.81729996, 0.81949997, 0.81040001, 0.81729996,
0.82229996, 0.81989998, 0.82209998, 0.80829996, 0.81409997,
0.81909996, 0.81239998, 0.81659997, 0.81819999, 0.81989998,
0.81110001, 0.81659997, 0.82679999, 0.81009996, 0.82029998,
0.82129997, 0.8204 ]), 'normalizing constant': array([0.04336737, 0.04369855, 0.04423331, 0.04363485, 0.04395279,
0.04410308, 0.04401544, 0.0437651 , 0.04463526, 0.04405245,
0.04408799, 0.04414004, 0.04418202, 0.04404185, 0.04407706,
0.04434944, 0.04393895, 0.04406662, 0.04378387, 0.04391291,
0.04423869, 0.04363356, 0.04396987, 0.04437705, 0.04406813,
0.04403404, 0.04413763, 0.04390783, 0.04421007, 0.04419962,
0.04380873, 0.04389819, 0.04387028, 0.0439248 , 0.04394285,
0.04409145, 0.04431931, 0.04445085, 0.04371682, 0.04400166,
0.04377104, 0.04387679, 0.04387242, 0.04416407, 0.04382678,
0.04377542, 0.04353883, 0.04408424, 0.04378124, 0.04399851,
0.04408044, 0.04396119, 0.04424458, 0.04400456, 0.04402829,
0.04392023, 0.04391998, 0.0437169 , 0.04381058, 0.04398689,
0.04392833, 0.04381116, 0.04414421, 0.04387758, 0.04383671,
0.04452438, 0.04385154, 0.04405428, 0.04390055, 0.04382679,
0.04401871, 0.04418558, 0.04344092, 0.04382248, 0.04427725,
0.04445632, 0.04394558, 0.04416358, 0.043821 , 0.04390873,
0.04402443, 0.04386914, 0.04394275, 0.04393733, 0.04420728,
0.04415243, 0.04422145, 0.04390053, 0.04373772, 0.0439169 ,
0.04444473, 0.04385469, 0.04405254, 0.04382789, 0.04412382,
0.0439774 , 0.0443418 , 0.04400473, 0.04398995, 0.04396579,
0.04442929, 0.04415386]), 'error': [], 'fluctuation': [], 'probas': array([[0.1485162 , 0.16055454, 0.16194163, 0.16323854, 0.165397 ,
0.16621989, 0.16728093, 0.17298429, 0.17842184, 0.18131582,
0.18153068, 0.17789325, 0.1813869 , 0.18076019, 0.18275514,
0.18365256, 0.18592648, 0.18325569, 0.18655397, 0.18346013,
0.18728304, 0.179383 , 0.18785241, 0.18468739, 0.18248786,
0.18625754, 0.18594848, 0.18960007, 0.1854024 , 0.18682131,
0.1864778 , 0.18224381, 0.18711585, 0.18683943, 0.18406609,
0.19090426, 0.1924864 , 0.18520172, 0.1846443 , 0.1929851 ,
0.18172711, 0.1836783 , 0.18908027, 0.18223165, 0.193563 ,
0.17798285, 0.1862951 , 0.1862 , 0.18678845, 0.18335605,
0.19117722, 0.18448731, 0.18774757, 0.18200655, 0.19009076,
0.18407753, 0.18604074, 0.18397592, 0.18950382, 0.1848921 ,
0.18820594, 0.18761402, 0.18486324, 0.18718396, 0.18315272,
0.19165844, 0.18590954, 0.18055874, 0.19307667, 0.18678352,
0.19253562, 0.18572135, 0.18937008, 0.18272325, 0.18419279,
0.19023734, 0.18418981, 0.18741153, 0.18475959, 0.18409261,
0.18734142, 0.18401104, 0.18547107, 0.18180841, 0.18173346,
0.19195673, 0.18321392, 0.18623051, 0.18521367, 0.18962799,
0.18197635, 0.18261091, 0.1892677 , 0.18726113, 0.18275808,
0.1861751 , 0.18952785, 0.18479809, 0.183695 , 0.18695967,
0.17729138, 0.19055195],
[0.19695036, 0.1952997 , 0.1963708 , 0.20266855, 0.20458776,
0.21107852, 0.21435012, 0.22044064, 0.2116901 , 0.22071116,
0.22535945, 0.22171246, 0.22021502, 0.22291917, 0.22676362,
0.23151723, 0.2231986 , 0.22244409, 0.22920397, 0.22123633,
0.22816384, 0.22116703, 0.2270531 , 0.22712754, 0.2278181 ,
0.22829005, 0.22584152, 0.21881564, 0.22274776, 0.22581019,
0.22367638, 0.22128773, 0.22824468, 0.22568 , 0.22601405,
0.22555912, 0.22718042, 0.2281449 , 0.22549398, 0.22531363,
0.2243658 , 0.23421022, 0.2307628 , 0.22594686, 0.2281148 ,
0.23078103, 0.2261846 , 0.22598433, 0.22653121, 0.22734319,
0.22857134, 0.22943072, 0.22432452, 0.22385977, 0.229133 ,
0.23056032, 0.22723328, 0.22323352, 0.22479899, 0.22198647,
0.22635092, 0.22984894, 0.22773677, 0.2289532 , 0.22364406,
0.2244531 , 0.22951028, 0.22641128, 0.22176535, 0.22352435,
0.22655046, 0.22729242, 0.22853047, 0.22577965, 0.23312789,
0.22426799, 0.2258728 , 0.22626734, 0.23163481, 0.2271303 ,
0.2238996 , 0.23266205, 0.23015009, 0.22860742, 0.23014596,
0.22237347, 0.22659877, 0.2246113 , 0.23056054, 0.22889037,
0.22729112, 0.2314258 , 0.22499506, 0.23068833, 0.22764853,
0.2262422 , 0.22736192, 0.22999698, 0.2285015 , 0.22283384,
0.23345605, 0.22579303],
[0.21667337, 0.21320991, 0.210758 , 0.21933709, 0.22118852,
0.22280113, 0.22226848, 0.22120504, 0.22483476, 0.2266392 ,
0.22933356, 0.22368371, 0.22578923, 0.229657 , 0.22989716,
0.2280411 , 0.22829853, 0.23074262, 0.2277067 , 0.23302622,
0.23149994, 0.23580757, 0.22534515, 0.22992378, 0.22979376,
0.22899309, 0.23109256, 0.22946812, 0.23072276, 0.22923347,
0.23557179, 0.23238409, 0.2243902 , 0.23136432, 0.22487636,
0.23321244, 0.23017202, 0.2327922 , 0.23939423, 0.22442845,
0.23565231, 0.22703998, 0.2241762 , 0.23035897, 0.22896367,
0.23492745, 0.23067477, 0.23332123, 0.22982734, 0.23373495,
0.2306382 , 0.22959946, 0.234819 , 0.23699102, 0.22628915,
0.23106912, 0.22704528, 0.23206265, 0.22835125, 0.2375098 ,
0.22831424, 0.22638877, 0.22893187, 0.23165943, 0.23086727,
0.23069027, 0.22865061, 0.23296542, 0.22881323, 0.22867419,
0.22928163, 0.23255555, 0.23240142, 0.2306141 , 0.22801651,
0.2303999 , 0.22782049, 0.2311663 , 0.22596866, 0.22869745,
0.23056997, 0.23063104, 0.23105143, 0.22809877, 0.23043825,
0.22579089, 0.22618362, 0.22405076, 0.2313529 , 0.22929037,
0.22828597, 0.22507906, 0.22755624, 0.22866166, 0.23259835,
0.22428328, 0.2289164 , 0.22679083, 0.23322482, 0.22809243,
0.22874312, 0.22969708],
[0.21873702, 0.21478498, 0.21511282, 0.21073158, 0.21264848,
0.20805344, 0.20862307, 0.2016053 , 0.20614582, 0.20526274,
0.1934086 , 0.20644084, 0.1998549 , 0.19951342, 0.19595125,
0.1983026 , 0.19979505, 0.20181738, 0.19496965, 0.19976951,
0.19974917, 0.19827159, 0.19965018, 0.19903651, 0.19414812,
0.19613774, 0.19732763, 0.20016123, 0.199242 , 0.20071582,
0.19595002, 0.20309903, 0.19664107, 0.19848093, 0.20610625,
0.1936942 , 0.19422592, 0.19131985, 0.19035204, 0.19516057,
0.19756751, 0.19478549, 0.19614144, 0.1965923 , 0.19277209,
0.1929744 , 0.19930099, 0.19569592, 0.19462654, 0.1923654 ,
0.19340381, 0.19902703, 0.19423525, 0.19715865, 0.19890656,
0.1957541 , 0.19803362, 0.19821084, 0.19210425, 0.19498955,
0.19883785, 0.19853753, 0.20037489, 0.19634695, 0.1971254 ,
0.19375683, 0.19723134, 0.19563098, 0.19665174, 0.19697207,
0.19418135, 0.19441092, 0.1925318 , 0.20035046, 0.19254498,
0.19543938, 0.19714406, 0.19314252, 0.19607449, 0.19951957,
0.19830583, 0.19310701, 0.19535136, 0.20170636, 0.1947327 ,
0.20092243, 0.19808438, 0.20003602, 0.19214189, 0.19308993,
0.20139639, 0.19975169, 0.19974798, 0.19206996, 0.19768326,
0.19906256, 0.19842745, 0.19665974, 0.19099855, 0.20497857,
0.19913922, 0.19149065],
[0.21912304, 0.21615088, 0.21581674, 0.2040242 , 0.19617833,
0.19184701, 0.1874774 , 0.18376476, 0.1789075 , 0.1660711 ,
0.17036766, 0.1702698 , 0.17275392, 0.1671502 , 0.16463287,
0.15848656, 0.1627813 , 0.16174027, 0.16156572, 0.1625078 ,
0.15330403, 0.16537087, 0.16009918, 0.15922482, 0.1657522 ,
0.1603216 , 0.15978988, 0.16195492, 0.16188505, 0.15741919,
0.15832408, 0.16098537, 0.16360822, 0.15763529, 0.15893728,
0.15662996, 0.15593524, 0.16254133, 0.16011542, 0.16211228,
0.16068721, 0.1602861 , 0.15983929, 0.16487019, 0.1565864 ,
0.16333425, 0.15754461, 0.15879855, 0.16222644, 0.1632004 ,
0.15620948, 0.15745552, 0.15887368, 0.15998402, 0.15558062,
0.15853895, 0.16164702, 0.16251707, 0.16524166, 0.16062212,
0.15829106, 0.15761073, 0.15809323, 0.15585648, 0.16521049,
0.1594414 , 0.15869813, 0.1644336 , 0.159693 , 0.16404587,
0.15745091, 0.16001981, 0.15716627, 0.16053256, 0.16211785,
0.15965539, 0.16497283, 0.16201235, 0.16156249, 0.16056012,
0.15988323, 0.15958895, 0.15797612, 0.15977909, 0.16294965,
0.15895648, 0.16591929, 0.16507144, 0.160731 , 0.15910137,
0.16105023, 0.16113253, 0.15843302, 0.16131893, 0.15931177,
0.1642369 , 0.15576638, 0.1617544 , 0.16358018, 0.15713553,
0.16137026, 0.1624673 ]], dtype=float32), 'number of neighbours': array([[4.6646168e+01, 4.4917942e+01, 4.5352409e+01, 4.5982750e+01,
4.6964226e+01, 4.8985584e+01, 5.2064220e+01, 5.5067158e+01,
5.8666492e+01, 6.2912903e+01, 6.7091217e+01, 6.9701340e+01,
7.1786278e+01, 7.3883736e+01, 7.5293114e+01, 7.6890228e+01,
7.8401169e+01, 7.9652412e+01, 7.8117981e+01, 7.8679825e+01,
7.9638657e+01, 8.1112122e+01, 8.2286110e+01, 8.4492500e+01,
8.3100800e+01, 8.2240402e+01, 8.3979042e+01, 8.6303482e+01,
8.5534882e+01, 8.5168755e+01, 8.5894913e+01, 8.6136116e+01,
8.5650635e+01, 8.3945740e+01, 8.3991249e+01, 8.3299355e+01,
8.5752243e+01, 8.8673615e+01, 8.8633476e+01, 8.7025909e+01,
8.4008553e+01, 8.4636703e+01, 8.6026840e+01, 8.6068710e+01,
8.5076920e+01, 8.5684525e+01, 8.3068527e+01, 8.3674316e+01,
8.5310738e+01, 8.5778114e+01, 8.5076950e+01, 8.6482010e+01,
8.6463264e+01, 8.4942200e+01, 8.6351471e+01, 8.5012993e+01,
8.3857819e+01, 8.4557312e+01, 8.5139671e+01, 8.5521576e+01,
8.4552528e+01, 8.4903816e+01, 8.5593071e+01, 8.5911354e+01,
8.5759277e+01, 8.5463158e+01, 8.7545631e+01, 8.7243431e+01,
8.6353928e+01, 8.6105888e+01, 8.3930893e+01, 8.3844971e+01,
8.3831993e+01, 8.5468452e+01, 8.5832359e+01, 8.7372589e+01,
8.5430794e+01, 8.3483696e+01, 8.2409454e+01, 8.1731087e+01,
8.3420128e+01, 8.4673447e+01, 8.6019058e+01, 8.5739235e+01,
8.6289330e+01, 8.5565010e+01, 8.7241325e+01, 8.4839790e+01,
8.2576126e+01, 8.3405739e+01, 8.5599960e+01, 8.4623940e+01,
8.4694252e+01, 8.5954735e+01, 8.5201485e+01, 8.3126953e+01,
8.4245003e+01, 8.6582397e+01, 8.4537247e+01, 8.4084785e+01,
8.3297211e+01, 8.3406021e+01],
[3.7316934e+02, 3.5934354e+02, 3.6281927e+02, 3.6786200e+02,
3.7571381e+02, 3.9188467e+02, 4.1651376e+02, 4.4053726e+02,
4.6933194e+02, 5.0330322e+02, 5.3672974e+02, 5.5761072e+02,
5.7429022e+02, 5.9106989e+02, 6.0234491e+02, 6.1512183e+02,
6.2720935e+02, 6.3721930e+02, 6.2494385e+02, 6.2943860e+02,
6.3710925e+02, 6.4889697e+02, 6.5828888e+02, 6.7594000e+02,
6.6480640e+02, 6.5792322e+02, 6.7183234e+02, 6.9042786e+02,
6.8427905e+02, 6.8135004e+02, 6.8715930e+02, 6.8908893e+02,
6.8520508e+02, 6.7156592e+02, 6.7192999e+02, 6.6639484e+02,
6.8601794e+02, 7.0938892e+02, 7.0906781e+02, 6.9620728e+02,
6.7206842e+02, 6.7709363e+02, 6.8821472e+02, 6.8854968e+02,
6.8061536e+02, 6.8547620e+02, 6.6454822e+02, 6.6939453e+02,
6.8248590e+02, 6.8622491e+02, 6.8061560e+02, 6.9185608e+02,
6.9170612e+02, 6.7953760e+02, 6.9081177e+02, 6.8010394e+02,
6.7086255e+02, 6.7645850e+02, 6.8111737e+02, 6.8417261e+02,
6.7642023e+02, 6.7923053e+02, 6.8474457e+02, 6.8729083e+02,
6.8607422e+02, 6.8370526e+02, 7.0036505e+02, 6.9794745e+02,
6.9083142e+02, 6.8884711e+02, 6.7144714e+02, 6.7075977e+02,
6.7065594e+02, 6.8374762e+02, 6.8665887e+02, 6.9898071e+02,
6.8344635e+02, 6.6786957e+02, 6.5927563e+02, 6.5384869e+02,
6.6736102e+02, 6.7738757e+02, 6.8815247e+02, 6.8591388e+02,
6.9031464e+02, 6.8452008e+02, 6.9793060e+02, 6.7871832e+02,
6.6060901e+02, 6.6724591e+02, 6.8479968e+02, 6.7699152e+02,
6.7755402e+02, 6.8763788e+02, 6.8161188e+02, 6.6501562e+02,
6.7396002e+02, 6.9265918e+02, 6.7629797e+02, 6.7267828e+02,
6.6637769e+02, 6.6724817e+02],
[5.8307705e+03, 5.6147432e+03, 5.6690508e+03, 5.7478433e+03,
5.8705283e+03, 6.1231978e+03, 6.5080273e+03, 6.8833945e+03,
7.3333110e+03, 7.8641128e+03, 8.3864014e+03, 8.7126670e+03,
8.9732852e+03, 9.2354668e+03, 9.4116387e+03, 9.6112773e+03,
9.8001455e+03, 9.9565518e+03, 9.7647461e+03, 9.8349775e+03,
9.9548330e+03, 1.0139015e+04, 1.0285764e+04, 1.0561562e+04,
1.0387600e+04, 1.0280050e+04, 1.0497379e+04, 1.0787936e+04,
1.0691859e+04, 1.0646094e+04, 1.0736863e+04, 1.0767014e+04,
1.0706328e+04, 1.0493217e+04, 1.0498906e+04, 1.0412419e+04,
1.0719031e+04, 1.1084201e+04, 1.1079184e+04, 1.0878238e+04,
1.0501068e+04, 1.0579589e+04, 1.0753355e+04, 1.0758588e+04,
1.0634615e+04, 1.0710566e+04, 1.0383565e+04, 1.0459288e+04,
1.0663842e+04, 1.0722264e+04, 1.0634618e+04, 1.0810250e+04,
1.0807907e+04, 1.0617773e+04, 1.0793936e+04, 1.0626625e+04,
1.0482228e+04, 1.0569663e+04, 1.0642459e+04, 1.0690195e+04,
1.0569067e+04, 1.0612977e+04, 1.0699133e+04, 1.0738918e+04,
1.0719910e+04, 1.0682894e+04, 1.0943204e+04, 1.0905429e+04,
1.0794240e+04, 1.0763235e+04, 1.0491361e+04, 1.0480621e+04,
1.0479000e+04, 1.0683556e+04, 1.0729044e+04, 1.0921572e+04,
1.0678850e+04, 1.0435463e+04, 1.0301183e+04, 1.0216387e+04,
1.0427516e+04, 1.0584181e+04, 1.0752382e+04, 1.0717404e+04,
1.0786166e+04, 1.0695627e+04, 1.0905166e+04, 1.0604973e+04,
1.0322016e+04, 1.0425717e+04, 1.0699994e+04, 1.0577991e+04,
1.0586781e+04, 1.0744342e+04, 1.0650186e+04, 1.0390869e+04,
1.0530624e+04, 1.0822799e+04, 1.0567156e+04, 1.0510598e+04,
1.0412151e+04, 1.0425751e+04],
[4.6646164e+04, 4.4917945e+04, 4.5352406e+04, 4.5982746e+04,
4.6964227e+04, 4.8985582e+04, 5.2064219e+04, 5.5067156e+04,
5.8666488e+04, 6.2912902e+04, 6.7091211e+04, 6.9701336e+04,
7.1786281e+04, 7.3883734e+04, 7.5293109e+04, 7.6890219e+04,
7.8401164e+04, 7.9652414e+04, 7.8117969e+04, 7.8679820e+04,
7.9638664e+04, 8.1112117e+04, 8.2286109e+04, 8.4492492e+04,
8.3100797e+04, 8.2240398e+04, 8.3979031e+04, 8.6303484e+04,
8.5534875e+04, 8.5168750e+04, 8.5894906e+04, 8.6136109e+04,
8.5650625e+04, 8.3945734e+04, 8.3991250e+04, 8.3299352e+04,
8.5752250e+04, 8.8673609e+04, 8.8633469e+04, 8.7025906e+04,
8.4008547e+04, 8.4636711e+04, 8.6026844e+04, 8.6068703e+04,
8.5076922e+04, 8.5684531e+04, 8.3068523e+04, 8.3674305e+04,
8.5310734e+04, 8.5778109e+04, 8.5076945e+04, 8.6482000e+04,
8.6463258e+04, 8.4942188e+04, 8.6351484e+04, 8.5013000e+04,
8.3857820e+04, 8.4557305e+04, 8.5139672e+04, 8.5521562e+04,
8.4552539e+04, 8.4903812e+04, 8.5593062e+04, 8.5911344e+04,
8.5759281e+04, 8.5463148e+04, 8.7545633e+04, 8.7243430e+04,
8.6353922e+04, 8.6105883e+04, 8.3930891e+04, 8.3844969e+04,
8.3832000e+04, 8.5468445e+04, 8.5832352e+04, 8.7372578e+04,
8.5430797e+04, 8.3483703e+04, 8.2409461e+04, 8.1731094e+04,
8.3420125e+04, 8.4673445e+04, 8.6019055e+04, 8.5739234e+04,
8.6289328e+04, 8.5565016e+04, 8.7241328e+04, 8.4839781e+04,
8.2576125e+04, 8.3405734e+04, 8.5599953e+04, 8.4623930e+04,
8.4694250e+04, 8.5954734e+04, 8.5201484e+04, 8.3126953e+04,
8.4244992e+04, 8.6582391e+04, 8.4537250e+04, 8.4084781e+04,
8.3297211e+04, 8.3406008e+04],
[3.7316931e+05, 3.5934356e+05, 3.6281925e+05, 3.6786197e+05,
3.7571381e+05, 3.9188466e+05, 4.1651375e+05, 4.4053725e+05,
4.6933191e+05, 5.0330322e+05, 5.3672969e+05, 5.5761069e+05,
5.7429025e+05, 5.9106988e+05, 6.0234488e+05, 6.1512175e+05,
6.2720931e+05, 6.3721931e+05, 6.2494375e+05, 6.2943856e+05,
6.3710931e+05, 6.4889694e+05, 6.5828888e+05, 6.7593994e+05,
6.6480638e+05, 6.5792319e+05, 6.7183225e+05, 6.9042788e+05,
6.8427900e+05, 6.8135000e+05, 6.8715925e+05, 6.8908888e+05,
6.8520500e+05, 6.7156588e+05, 6.7193000e+05, 6.6639481e+05,
6.8601800e+05, 7.0938888e+05, 7.0906775e+05, 6.9620725e+05,
6.7206838e+05, 6.7709369e+05, 6.8821475e+05, 6.8854962e+05,
6.8061538e+05, 6.8547625e+05, 6.6454819e+05, 6.6939444e+05,
6.8248588e+05, 6.8622488e+05, 6.8061556e+05, 6.9185600e+05,
6.9170606e+05, 6.7953750e+05, 6.9081188e+05, 6.8010400e+05,
6.7086256e+05, 6.7645844e+05, 6.8111738e+05, 6.8417250e+05,
6.7642031e+05, 6.7923050e+05, 6.8474450e+05, 6.8729075e+05,
6.8607425e+05, 6.8370519e+05, 7.0036506e+05, 6.9794744e+05,
6.9083138e+05, 6.8884706e+05, 6.7144712e+05, 6.7075975e+05,
6.7065600e+05, 6.8374756e+05, 6.8665881e+05, 6.9898062e+05,
6.8344638e+05, 6.6786962e+05, 6.5927569e+05, 6.5384875e+05,
6.6736100e+05, 6.7738756e+05, 6.8815244e+05, 6.8591388e+05,
6.9031462e+05, 6.8452012e+05, 6.9793062e+05, 6.7871825e+05,
6.6060900e+05, 6.6724588e+05, 6.8479962e+05, 6.7699144e+05,
6.7755400e+05, 6.8763788e+05, 6.8161188e+05, 6.6501562e+05,
6.7395994e+05, 6.9265912e+05, 6.7629800e+05, 6.7267825e+05,
6.6637769e+05, 6.6724806e+05]], dtype=float32), 'ESS': []}
As a sanity check, we perform the same computation with simple pairs of points.
def load_coordinates(coordinates):
    x = torch.FloatTensor(coordinates).type(dtype)
    x -= x.mean(0)
    scale = (x**2).sum(1).mean(0)
    x /= scale
    return x
A = load_coordinates([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]])
B = load_coordinates([[0.0, 1.0, 0.0], [0.0, -1.0, 0.0]])
distribution = WassersteinDistribution(A, B, temperature=1e-4)
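We can predict where the samples should concentrate before running the sampler. For two uniformly weighted clouds with the same number of points, some optimal transport plan is a permutation, so the exact squared Wasserstein distance is a minimum over matchings. The NumPy sketch below computes this exact distance by brute force rather than with GeomLoss's Sinkhorn approximation, and its helper names are hypothetical. It checks that every rotation of the form Rz(±π/2) Rx(φ), which sends the x-axis to ±y, aligns the two pairs perfectly, whatever the twist φ. These optimal rotations form two circles in SO(3).

```python
import itertools

import numpy as np


def exact_w2_squared(x, y):
    """Exact squared Wasserstein distance between two uniform clouds of equal size.

    Brute force over all matchings, only practical for a handful of points.
    """
    n = len(x)
    return min(((x - y[list(perm)]) ** 2).sum(axis=1).mean()
               for perm in itertools.permutations(range(n)))


def rot_x(phi):
    c, s = np.cos(phi), np.sin(phi)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


def rot_z(theta):
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


A_pair = np.array([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]])
B_pair = np.array([[0.0, 1.0, 0.0], [0.0, -1.0, 0.0]])

# Rz(+pi/2) Rx(phi) and Rz(-pi/2) Rx(phi) trace two circles of optimal rotations:
for sign in (+1, -1):
    for phi in np.linspace(0, 2 * np.pi, 7):
        R = rot_z(sign * np.pi / 2) @ rot_x(phi)
        print(sign, round(phi, 2), exact_w2_squared(A_pair @ R.T, B_pair))

# The identity is not optimal: each point is sqrt(2) away from its match.
print(exact_w2_squared(A_pair, B_pair))
```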
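As expected, our sampler respects all the symmetries of the problem. Any rotation that maps the x-axis onto the y-axis, in either direction, aligns the two pairs of points perfectly, so the samples should spread over all of these optimal rotations: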
start = space.uniform_sample(N)
proposal = BallProposal(space, scale=[0.1, 0.2, 0.5, 1.0, 2.0])
moka_sampler = MOKA_CMC(space, start, proposal, annealing=5).fit(distribution)
display_samples(moka_sampler, iterations=100, runs=2, small=False)
plt.show()
Total running time of the script: ( 0 minutes 32.490 seconds)