Sampling on the 3D rotation group

This example shows how to sample from several probability distributions on SO(3), the compact manifold of 3D rotations.

Introduction

First, some standard imports.

import numpy as np
import torch
from matplotlib import pyplot as plt

import warnings
import matplotlib.cbook

warnings.filterwarnings("ignore", category=matplotlib.cbook.mplDeprecation)
plt.rcParams.update({"figure.max_open_warning": 0})

use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

All functions that are relevant to the rotation group are stored in the monaco.rotations submodule. We first create our manifold, on the GPU whenever possible.

from monaco.rotations import Rotations

space = Rotations(dtype=dtype)

Rotations are encoded as quaternions and displayed as Euler (axis-angle) vectors in the ball of radius π: a rotation of angle θ ∈ [0, π] about the unit axis u is drawn at the point θu. Antipodal points on the boundary sphere (half-turns) are identified with each other. Here, we create two arbitrary rotations and sample points at random in their neighborhood.

N = 10000 if use_cuda else 50

ref = torch.ones(N, 1).type(dtype) * torch.FloatTensor([1, 0, 0, 0]).type(dtype)
ref2 = torch.ones(N, 1).type(dtype) * torch.FloatTensor([1, 5, 5, 0]).type(dtype)

ref = torch.cat((ref, ref2), dim=0)

from monaco.rotations import BallProposal

proposal = BallProposal(space, scale=0.5)

x = proposal.sample(ref)

# Display the initial configuration:
plt.figure(figsize=(8, 8))
space.scatter(x, "red")
space.draw_frame()
plt.tight_layout()
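The display convention and the proposal used above can be sketched in a few lines of NumPy. These helpers are our own illustration, not monaco's implementation. They assume quaternions are stored scalar-first as (w, x, y, z), which matches the identity rotation [1, 0, 0, 0] used above. `quat_to_euler_vector` maps rotations to Euler vectors in the ball of radius π. `ball_proposal_sample` is one plausible reading of a "ball proposal": it perturbs each reference rotation by a random rotation whose Euler vector is uniform in a ball of radius `scale`. The actual `BallProposal` may work differently.

```python
import numpy as np


def quat_multiply(p, q):
    """Hamilton product of quaternions stored as (w, x, y, z), arrays of shape (N, 4)."""
    pw, pv = p[:, :1], p[:, 1:]
    qw, qv = q[:, :1], q[:, 1:]
    w = pw * qw - np.sum(pv * qv, axis=1, keepdims=True)
    v = pw * qv + qw * pv + np.cross(pv, qv)
    return np.concatenate([w, v], axis=1)


def quat_to_euler_vector(q):
    """Map quaternions (N, 4) to Euler vectors (N, 3) in the closed ball of radius pi."""
    q = np.asarray(q, dtype=float)
    q = q / np.linalg.norm(q, axis=1, keepdims=True)  # project onto the unit sphere S^3
    q = np.where(q[:, :1] < 0, -q, q)  # q and -q encode the same rotation: keep w >= 0
    w, v = q[:, 0], q[:, 1:]
    s = np.linalg.norm(v, axis=1)  # = sin(angle / 2)
    angle = 2 * np.arctan2(s, w)  # rotation angle in [0, pi]
    # Unit rotation axis; the identity (s = 0) is sent to the zero vector instead.
    axis = np.divide(v, s[:, None], out=np.zeros_like(v), where=s[:, None] > 0)
    return angle[:, None] * axis


def euler_vector_to_quat(u):
    """Inverse map: rotation of angle |u| about the axis u / |u|, as a unit quaternion."""
    angle = np.linalg.norm(u, axis=1, keepdims=True)
    # sin(angle / 2) / angle, written with np.sinc (sinc(x) = sin(pi x) / (pi x))
    # so that it stays finite at angle = 0.
    factor = 0.5 * np.sinc(angle / (2 * np.pi))
    return np.concatenate([np.cos(angle / 2), factor * u], axis=1)


def ball_proposal_sample(q_ref, scale, rng):
    """Hypothetical ball proposal: right-multiply each reference rotation by a random
    rotation whose Euler vector is uniformly distributed in the ball of radius `scale`."""
    q_ref = np.asarray(q_ref, dtype=float)
    q_ref = q_ref / np.linalg.norm(q_ref, axis=1, keepdims=True)
    n = len(q_ref)
    direction = rng.normal(size=(n, 3))
    direction /= np.linalg.norm(direction, axis=1, keepdims=True)
    radius = scale * rng.uniform(size=(n, 1)) ** (1 / 3)  # uniform density in the 3D ball
    dq = euler_vector_to_quat(radius * direction)
    return quat_multiply(q_ref, dq)
```

For instance, `quat_to_euler_vector([[0, 1, 0, 0]])` and `quat_to_euler_vector([[0, -1, 0, 0]])` describe the same half-turn about the x axis, yet land on the antipodal boundary points (π, 0, 0) and (-π, 0, 0). This is why such points are identified. Right-multiplication perturbs each rotation in its body frame; left-multiplication would be an equally reasonable choice.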

Procrustes analysis

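We consider the von Mises distribution associated with a Procrustes registration problem. Its potential is the mean squared distance between the rotated source points and the target points, divided by twice the temperature: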

from monaco.rotations import quat_to_matrices


class ProcrustesDistribution(object):
    def __init__(self, source, target, temperature=1.0):
        self.source = source
        self.target = target
        self.temperature = temperature

    def potential(self, q):
        """Evaluates the potential at a batch of rotations q, encoded as quaternions (N, 4)."""
        R = quat_to_matrices(q)  # (N, 3, 3)
        models = R @ self.source.t()  # (N, 3, npoints)

        V_i = ((models - self.target.t().view(1, 3, -1)) ** 2).mean(2).sum(1)

        return V_i.view(-1) / (2 * self.temperature)  # (N,)
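In formulas, write a_j and b_j (j = 1, …, n) for the rows of `source` and `target`, R for the rotation matrix of q and T for the temperature. Assuming, as the name suggests, that the sampler targets a density proportional to exp(-V), expanding the square and using ‖R a_j‖ = ‖a_j‖ shows that this is a matrix von Mises-Fisher distribution on SO(3):

```latex
V(R) = \frac{1}{2Tn} \sum_{j=1}^{n} \lVert R a_j - b_j \rVert^2
     = \underbrace{\frac{1}{2Tn} \sum_{j=1}^{n} \left( \lVert a_j \rVert^2 + \lVert b_j \rVert^2 \right)}_{\text{independent of } R}
       - \frac{1}{T} \operatorname{tr}\left( F^\top R \right),
\qquad
F = \frac{1}{n} \sum_{j=1}^{n} b_j a_j^\top,
\\[1ex]
\exp\left( -V(R) \right) \propto \exp\left( \frac{1}{T} \operatorname{tr}\left( F^\top R \right) \right).
```

A small temperature such as T = 10⁻⁴ therefore concentrates the distribution sharply around the maximizers of tr(FᵀR).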

Then, we load two proteins as point clouds in the ambient 3D space:

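def load_csv(fname):
    # Read an (npoints, 3) array of coordinates, skipping the header row.
    x = np.loadtxt(fname, skiprows=1, delimiter=",")
    x = torch.from_numpy(x).type(dtype)
    x -= x.mean(0)  # center the point cloud at the origin
    # Divide by the mean squared norm (not its square root): this fixes
    # the overall scale without enforcing a unit RMS radius.
    scale = (x**2).sum(1).mean(0)
    x /= scale
    return x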


A = load_csv("data/Ca1UBQ.csv")
B = load_csv("data/Ca1D3Z_1.csv")

distribution = ProcrustesDistribution(A, B, temperature=1e-4)
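The rotation that minimizes this potential solves the classical orthogonal Procrustes problem, which has a closed-form solution: the Kabsch algorithm recovers it from a single SVD of the 3 × 3 cross-covariance matrix. The NumPy sketch below uses our own helper names (not part of monaco). Like the potential above, it assumes that the rows of the source and target clouds are in one-to-one correspondence. Its output is a useful reference against which the samples can be compared.

```python
import numpy as np


def kabsch_rotation(source, target):
    """Rotation R in SO(3) minimizing sum_j ||R a_j - b_j||^2 (Kabsch algorithm).

    source, target: arrays of shape (n, 3) whose rows a_j, b_j are in correspondence.
    """
    F = target.T @ source  # = sum_j b_j a_j^T (the 1/n factor does not change the SVD)
    U, _, Vt = np.linalg.svd(F)
    # Flip the last singular direction if needed so that det(R) = +1 (no reflection).
    d = np.sign(np.linalg.det(U @ Vt))
    D = np.diag([1.0, 1.0, d])
    return U @ D @ Vt


def procrustes_potential(R, source, target, temperature):
    """NumPy transcription of ProcrustesDistribution.potential for one rotation matrix."""
    residuals = source @ R.T - target  # rows are R a_j - b_j
    return (residuals**2).sum(axis=1).mean() / (2 * temperature)
```

With the proteins loaded above, `kabsch_rotation(A.cpu().numpy(), B.cpu().numpy())` would give the best rigid alignment, provided that both CSV files list corresponding atoms in the same order.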

Finally, we use the MOKA sampler on this distribution:

from monaco.samplers import MOKA_CMC, display_samples

N = 10000 if use_cuda else 50

start = space.uniform_sample(N)
proposal = BallProposal(space, scale=[0.1, 0.2, 0.5, 1.0, 2.0])

moka_sampler = MOKA_CMC(space, start, proposal, annealing=5).fit(distribution)
display_samples(moka_sampler, iterations=100, runs=2, small=False)
(Figures: samples at iterations 0, 1, 2, 5, 10, 20, 50, 80 and 100, followed by four additional plots.)

Out:

{'iteration': array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102,   1,
         2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,
        15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
        28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,
        41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,
        54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,
        67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,  79,
        80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,
        93,  94,  95,  96,  97,  98,  99, 100, 101, 102]), 'rate': array([0.95429999, 0.57730001, 0.51959997, 0.51449996, 0.51849997,
       0.52719998, 0.53240001, 0.54629999, 0.56330001, 0.59969997,
       0.63690001, 0.68039995, 0.72009999, 0.74169999, 0.75379997,
       0.75580001, 0.7719    , 0.75940001, 0.77289999, 0.7737    ,
       0.7622    , 0.77569997, 0.76559997, 0.76479995, 0.77739996,
       0.77129996, 0.77349997, 0.77639997, 0.7694    , 0.7701    ,
       0.77160001, 0.77819997, 0.76809996, 0.78239995, 0.77749997,
       0.78789997, 0.76980001, 0.77669996, 0.77179998, 0.76879996,
       0.77289999, 0.76889998, 0.77359998, 0.77569997, 0.77269995,
       0.7723    , 0.77599996, 0.76429999, 0.7802    , 0.77209997,
       0.77770001, 0.76010001, 0.77859998, 0.7665    , 0.7809    ,
       0.77109998, 0.77019995, 0.78049999, 0.77019995, 0.76959997,
       0.77340001, 0.76569998, 0.76349998, 0.76620001, 0.77139997,
       0.78329998, 0.75949997, 0.7723    , 0.76559997, 0.7784    ,
       0.76819998, 0.76999998, 0.7669    , 0.7762    , 0.7665    ,
       0.78189999, 0.77289999, 0.7694    , 0.7766    , 0.77520001,
       0.7694    , 0.77869999, 0.78639996, 0.76669997, 0.77629995,
       0.76809996, 0.76289999, 0.78599995, 0.76739997, 0.77599996,
       0.77289999, 0.76639998, 0.78740001, 0.7633    , 0.77129996,
       0.77139997, 0.77639997, 0.76620001, 0.76440001, 0.76370001,
       0.76419997, 0.77429998, 0.90880001, 0.57609999, 0.5068    ,
       0.5126    , 0.51249999, 0.51539999, 0.52489996, 0.5449    ,
       0.56529999, 0.60679996, 0.62939996, 0.68579996, 0.72959995,
       0.74970001, 0.7626    , 0.75709999, 0.77459997, 0.76349998,
       0.75989997, 0.77139997, 0.77999997, 0.76959997, 0.76909995,
       0.77329999, 0.7723    , 0.76629996, 0.77700001, 0.77179998,
       0.77520001, 0.77649999, 0.76489997, 0.77509999, 0.76069999,
       0.76789999, 0.76370001, 0.76550001, 0.78609997, 0.77179998,
       0.78219998, 0.77489996, 0.76069999, 0.7845    , 0.76959997,
       0.77129996, 0.77179998, 0.7766    , 0.77239996, 0.76199996,
       0.76319999, 0.7802    , 0.7712    , 0.76550001, 0.78059995,
       0.76529998, 0.77559996, 0.75729996, 0.77429998, 0.75709999,
       0.77459997, 0.76539999, 0.76370001, 0.77489996, 0.7694    ,
       0.76629996, 0.78219998, 0.77449995, 0.76370001, 0.7651    ,
       0.77279997, 0.77569997, 0.76669997, 0.78299999, 0.77559996,
       0.7726    , 0.77219999, 0.78079998, 0.77199996, 0.77309996,
       0.76909995, 0.76739997, 0.76309997, 0.76629996, 0.77199996,
       0.7694    , 0.77999997, 0.76999998, 0.76319999, 0.77809995,
       0.7712    , 0.7773    , 0.7694    , 0.76949996, 0.78169996,
       0.76909995, 0.77129996, 0.76859999, 0.7608    , 0.77579999,
       0.7766    , 0.76989996, 0.77779996, 0.77419996]), 'normalizing constant': array([0.000318  , 0.00043131, 0.0006636 , 0.00058932, 0.00062415,
       0.00055452, 0.00058131, 0.00053881, 0.0005416 , 0.00057076,
       0.0005522 , 0.00054144, 0.00054438, 0.00054982, 0.00054566,
       0.00055339, 0.00054652, 0.00055227, 0.00054904, 0.00054955,
       0.00054651, 0.00054974, 0.00054796, 0.00054754, 0.0005532 ,
       0.00055009, 0.00055059, 0.00054802, 0.00054532, 0.00055228,
       0.00054816, 0.00055046, 0.00054803, 0.00055509, 0.00054939,
       0.00055028, 0.00055258, 0.00055124, 0.00054663, 0.00054746,
       0.00054896, 0.00054569, 0.00054929, 0.00055151, 0.00055121,
       0.00054733, 0.00054864, 0.00054647, 0.00055306, 0.00054993,
       0.00055126, 0.00054835, 0.00054966, 0.00055007, 0.00055092,
       0.00054907, 0.0005484 , 0.00054693, 0.00055041, 0.0005486 ,
       0.00054802, 0.00054776, 0.00054809, 0.00055102, 0.00054364,
       0.00055325, 0.00054825, 0.00055118, 0.0005491 , 0.00055134,
       0.00055103, 0.00055057, 0.00054697, 0.00054814, 0.00054818,
       0.0005522 , 0.00054856, 0.00055061, 0.00055313, 0.00054963,
       0.00054927, 0.00054784, 0.00055016, 0.00054709, 0.00054855,
       0.00054391, 0.00055235, 0.00055381, 0.00054771, 0.00055002,
       0.00055176, 0.00055466, 0.0005504 , 0.00054883, 0.0005503 ,
       0.00055376, 0.00054883, 0.00054734, 0.00054695, 0.00054594,
       0.00054756, 0.0005522 , 0.00064543, 0.00048963, 0.00065833,
       0.00045164, 0.00054626, 0.0006193 , 0.00052453, 0.00051454,
       0.00055273, 0.00052458, 0.00055163, 0.00055412, 0.00054958,
       0.00055538, 0.00055066, 0.00054616, 0.00054876, 0.00054885,
       0.00054347, 0.00055073, 0.00055212, 0.00055024, 0.00054668,
       0.00054807, 0.00054899, 0.00054994, 0.00054819, 0.00054899,
       0.0005547 , 0.00054934, 0.00055066, 0.00055071, 0.00054689,
       0.00054766, 0.00054693, 0.00054984, 0.00055203, 0.00054608,
       0.00055067, 0.00055245, 0.00054761, 0.00054997, 0.00054892,
       0.00054899, 0.0005503 , 0.00055039, 0.00054872, 0.00054632,
       0.00054669, 0.00054876, 0.0005501 , 0.0005502 , 0.00055035,
       0.00054524, 0.00054873, 0.00054805, 0.00054687, 0.00054954,
       0.00055027, 0.00054842, 0.00054826, 0.00055131, 0.00054932,
       0.00054745, 0.00055105, 0.00055144, 0.00054731, 0.00054585,
       0.0005502 , 0.00054897, 0.00055401, 0.00055015, 0.0005483 ,
       0.00055173, 0.00054861, 0.00054828, 0.00054702, 0.00054853,
       0.00054761, 0.00055084, 0.00054905, 0.00054787, 0.00054629,
       0.00055021, 0.00054942, 0.0005485 , 0.00055015, 0.00055086,
       0.00055104, 0.0005503 , 0.00054993, 0.00055019, 0.00054945,
       0.00055052, 0.00054615, 0.00054817, 0.00054363, 0.00055264,
       0.00054863, 0.00055311, 0.00055165, 0.00055429]), 'error': [], 'fluctuation': [], 'probas': array([[0.15004013, 0.20752017, 0.20654385, ..., 0.45694146, 0.4498341 ,
        0.45580137],
       [0.19785951, 0.2065739 , 0.36189184, ..., 0.39082795, 0.39970857,
        0.39993352],
       [0.21729353, 0.17098795, 0.29464155, ..., 0.13261054, 0.13084137,
        0.12465435],
       [0.21744092, 0.2359102 , 0.12694044, ..., 0.00981004, 0.00980794,
        0.00980539],
       [0.21736595, 0.17900777, 0.00998236, ..., 0.00981004, 0.00980794,
        0.00980539]], dtype=float32), 'number of neighbours': array([[4.66498604e+01, 4.47234802e+01, 5.65862732e+01, ...,
        1.46109678e+04, 1.46680381e+04, 1.45732832e+04],
       [3.73198883e+02, 3.57787842e+02, 4.52690186e+02, ...,
        1.16887742e+05, 1.17344305e+05, 1.16586266e+05],
       [5.83123291e+03, 5.59043457e+03, 7.07328369e+03, ...,
        1.82637088e+06, 1.83350462e+06, 1.82166050e+06],
       [4.66498633e+04, 4.47234766e+04, 5.65862695e+04, ...,
        1.46109670e+07, 1.46680370e+07, 1.45732840e+07],
       [3.73198906e+05, 3.57787812e+05, 4.52690156e+05, ...,
        1.16887736e+08, 1.17344296e+08, 1.16586272e+08]], dtype=float32), 'ESS': []}

Wasserstein potential

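We rely on the GeomLoss library to define a Procrustes-like potential in which the discrepancy between the two point clouds is measured with an entropy-regularized (Sinkhorn) approximation of the squared Wasserstein distance. Unlike the Procrustes potential above, it does not depend on the order of the points, so no point-to-point correspondence is needed.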

N = 10000

from monaco.rotations import quat_to_matrices
from geomloss import SamplesLoss

wasserstein = SamplesLoss("sinkhorn", p=2, blur=0.05)


class WassersteinDistribution(object):
    def __init__(self, source, target, temperature=1.0):
        self.source = source
        self.target = target
        self.temperature = temperature

    def potential(self, q):
        """Evaluates the potential at a batch of rotations q, encoded as quaternions (N, 4)."""
        R = quat_to_matrices(q)  # (N, 3, 3)
        models = R @ self.source.t()  # (N, 3, npoints)

        N = len(models)
        models = models.permute(0, 2, 1).contiguous()
        targets = self.target.repeat(N, 1, 1).contiguous()

        V_i = wasserstein(models, targets)

        return V_i.view(-1) / self.temperature  # (N,)
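To give a rough idea of what the `wasserstein` loss computes, here is a plain NumPy sketch of a debiased Sinkhorn divergence between two uniformly weighted point clouds with ground cost |x - y|²/2. It is not GeomLoss's implementation, which adds refinements such as ε-scaling and processes whole batches on the GPU. We only assume, following the GeomLoss documentation, that the regularization strength is ε = blurᵖ, i.e. ε = 0.0025 for blur=0.05 and p=2. The helper names are ours.

```python
import numpy as np


def _logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def entropic_ot(x, y, eps, n_iter=200):
    """Entropy-regularized OT cost between uniform point clouds x (n, 3) and y (m, 3),
    with ground cost |x - y|^2 / 2, computed by log-domain Sinkhorn iterations."""
    n, m = len(x), len(y)
    log_a = np.full(n, -np.log(n))
    log_b = np.full(m, -np.log(m))
    C = 0.5 * ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # (n, m) cost matrix
    f = np.zeros(n)
    g = np.zeros(m)
    for _ in range(n_iter):
        # Alternate "soft-min" updates of the dual potentials f and g.
        f = -eps * _logsumexp(log_b[None, :] + (g[None, :] - C) / eps, axis=1)
        g = -eps * _logsumexp(log_a[:, None] + (f[:, None] - C) / eps, axis=0)
    # At convergence, the dual objective <a, f> + <b, g> equals the entropic OT cost.
    return np.exp(log_a) @ f + np.exp(log_b) @ g


def sinkhorn_divergence(x, y, eps, n_iter=200):
    """Debiased divergence S(x, y) = OT(x, y) - OT(x, x) / 2 - OT(y, y) / 2."""
    return (
        entropic_ot(x, y, eps, n_iter)
        - 0.5 * entropic_ot(x, x, eps, n_iter)
        - 0.5 * entropic_ot(y, y, eps, n_iter)
    )
```

Unlike the Procrustes potential, this quantity does not change when the rows of y are shuffled. For a small ε such as 0.0025 and protein-sized clouds, far more than 200 plain iterations may be needed; speeding this up is one of the purposes of GeomLoss's ε-scaling.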

Just as in the example above, we use the MOKA sampler to draw samples from this distribution:

distribution = WassersteinDistribution(A, B, temperature=1e-4)

start = space.uniform_sample(N)
proposal = BallProposal(space, scale=[0.1, 0.2, 0.5, 1.0, 2.0])

moka_sampler = MOKA_CMC(space, start, proposal, annealing=5).fit(distribution)
display_samples(moka_sampler, iterations=100, runs=1, small=False)
(Figures: samples at iterations 0, 1, 2, 5, 10, 20, 50, 80 and 100, followed by four additional plots.)

Out:

{'iteration': array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102]), 'rate': array([0.95559996, 0.93789995, 0.90779996, 0.88439995, 0.86429995,
       0.85339999, 0.8398    , 0.8351    , 0.8276    , 0.82239997,
       0.82989997, 0.83199996, 0.82379997, 0.8193    , 0.8251    ,
       0.81949997, 0.81769997, 0.824     , 0.82019997, 0.8179    ,
       0.81689996, 0.81629997, 0.80539995, 0.81799996, 0.82159996,
       0.81439996, 0.81309998, 0.81489998, 0.81379998, 0.81629997,
       0.81619996, 0.81900001, 0.81639999, 0.81089997, 0.8161    ,
       0.81209999, 0.81239998, 0.824     , 0.81349999, 0.8168    ,
       0.81449997, 0.81559998, 0.80789995, 0.81569999, 0.81009996,
       0.8161    , 0.80409998, 0.8114    , 0.81290001, 0.81219995,
       0.81229997, 0.8168    , 0.81809998, 0.81779999, 0.81830001,
       0.8211    , 0.81629997, 0.81419998, 0.81119996, 0.81769997,
       0.81219995, 0.81459999, 0.81639999, 0.82139999, 0.81699997,
       0.81239998, 0.81589997, 0.81040001, 0.81509995, 0.81739998,
       0.81279999, 0.81979996, 0.81089997, 0.81169999, 0.8175    ,
       0.82059997, 0.82169998, 0.82139999, 0.81299996, 0.80799997,
       0.81299996, 0.81729996, 0.81949997, 0.81040001, 0.81729996,
       0.82229996, 0.81989998, 0.82209998, 0.80829996, 0.81409997,
       0.81909996, 0.81239998, 0.81659997, 0.81819999, 0.81989998,
       0.81110001, 0.81659997, 0.82679999, 0.81009996, 0.82029998,
       0.82129997, 0.8204    ]), 'normalizing constant': array([0.04336737, 0.04369855, 0.04423331, 0.04363485, 0.04395279,
       0.04410308, 0.04401544, 0.0437651 , 0.04463526, 0.04405245,
       0.04408799, 0.04414004, 0.04418202, 0.04404185, 0.04407706,
       0.04434944, 0.04393895, 0.04406662, 0.04378387, 0.04391291,
       0.04423869, 0.04363356, 0.04396987, 0.04437705, 0.04406813,
       0.04403404, 0.04413763, 0.04390783, 0.04421007, 0.04419962,
       0.04380873, 0.04389819, 0.04387028, 0.0439248 , 0.04394285,
       0.04409145, 0.04431931, 0.04445085, 0.04371682, 0.04400166,
       0.04377104, 0.04387679, 0.04387242, 0.04416407, 0.04382678,
       0.04377542, 0.04353883, 0.04408424, 0.04378124, 0.04399851,
       0.04408044, 0.04396119, 0.04424458, 0.04400456, 0.04402829,
       0.04392023, 0.04391998, 0.0437169 , 0.04381058, 0.04398689,
       0.04392833, 0.04381116, 0.04414421, 0.04387758, 0.04383671,
       0.04452438, 0.04385154, 0.04405428, 0.04390055, 0.04382679,
       0.04401871, 0.04418558, 0.04344092, 0.04382248, 0.04427725,
       0.04445632, 0.04394558, 0.04416358, 0.043821  , 0.04390873,
       0.04402443, 0.04386914, 0.04394275, 0.04393733, 0.04420728,
       0.04415243, 0.04422145, 0.04390053, 0.04373772, 0.0439169 ,
       0.04444473, 0.04385469, 0.04405254, 0.04382789, 0.04412382,
       0.0439774 , 0.0443418 , 0.04400473, 0.04398995, 0.04396579,
       0.04442929, 0.04415386]), 'error': [], 'fluctuation': [], 'probas': array([[0.1485162 , 0.16055454, 0.16194163, 0.16323854, 0.165397  ,
        0.16621989, 0.16728093, 0.17298429, 0.17842184, 0.18131582,
        0.18153068, 0.17789325, 0.1813869 , 0.18076019, 0.18275514,
        0.18365256, 0.18592648, 0.18325569, 0.18655397, 0.18346013,
        0.18728304, 0.179383  , 0.18785241, 0.18468739, 0.18248786,
        0.18625754, 0.18594848, 0.18960007, 0.1854024 , 0.18682131,
        0.1864778 , 0.18224381, 0.18711585, 0.18683943, 0.18406609,
        0.19090426, 0.1924864 , 0.18520172, 0.1846443 , 0.1929851 ,
        0.18172711, 0.1836783 , 0.18908027, 0.18223165, 0.193563  ,
        0.17798285, 0.1862951 , 0.1862    , 0.18678845, 0.18335605,
        0.19117722, 0.18448731, 0.18774757, 0.18200655, 0.19009076,
        0.18407753, 0.18604074, 0.18397592, 0.18950382, 0.1848921 ,
        0.18820594, 0.18761402, 0.18486324, 0.18718396, 0.18315272,
        0.19165844, 0.18590954, 0.18055874, 0.19307667, 0.18678352,
        0.19253562, 0.18572135, 0.18937008, 0.18272325, 0.18419279,
        0.19023734, 0.18418981, 0.18741153, 0.18475959, 0.18409261,
        0.18734142, 0.18401104, 0.18547107, 0.18180841, 0.18173346,
        0.19195673, 0.18321392, 0.18623051, 0.18521367, 0.18962799,
        0.18197635, 0.18261091, 0.1892677 , 0.18726113, 0.18275808,
        0.1861751 , 0.18952785, 0.18479809, 0.183695  , 0.18695967,
        0.17729138, 0.19055195],
       [0.19695036, 0.1952997 , 0.1963708 , 0.20266855, 0.20458776,
        0.21107852, 0.21435012, 0.22044064, 0.2116901 , 0.22071116,
        0.22535945, 0.22171246, 0.22021502, 0.22291917, 0.22676362,
        0.23151723, 0.2231986 , 0.22244409, 0.22920397, 0.22123633,
        0.22816384, 0.22116703, 0.2270531 , 0.22712754, 0.2278181 ,
        0.22829005, 0.22584152, 0.21881564, 0.22274776, 0.22581019,
        0.22367638, 0.22128773, 0.22824468, 0.22568   , 0.22601405,
        0.22555912, 0.22718042, 0.2281449 , 0.22549398, 0.22531363,
        0.2243658 , 0.23421022, 0.2307628 , 0.22594686, 0.2281148 ,
        0.23078103, 0.2261846 , 0.22598433, 0.22653121, 0.22734319,
        0.22857134, 0.22943072, 0.22432452, 0.22385977, 0.229133  ,
        0.23056032, 0.22723328, 0.22323352, 0.22479899, 0.22198647,
        0.22635092, 0.22984894, 0.22773677, 0.2289532 , 0.22364406,
        0.2244531 , 0.22951028, 0.22641128, 0.22176535, 0.22352435,
        0.22655046, 0.22729242, 0.22853047, 0.22577965, 0.23312789,
        0.22426799, 0.2258728 , 0.22626734, 0.23163481, 0.2271303 ,
        0.2238996 , 0.23266205, 0.23015009, 0.22860742, 0.23014596,
        0.22237347, 0.22659877, 0.2246113 , 0.23056054, 0.22889037,
        0.22729112, 0.2314258 , 0.22499506, 0.23068833, 0.22764853,
        0.2262422 , 0.22736192, 0.22999698, 0.2285015 , 0.22283384,
        0.23345605, 0.22579303],
       [0.21667337, 0.21320991, 0.210758  , 0.21933709, 0.22118852,
        0.22280113, 0.22226848, 0.22120504, 0.22483476, 0.2266392 ,
        0.22933356, 0.22368371, 0.22578923, 0.229657  , 0.22989716,
        0.2280411 , 0.22829853, 0.23074262, 0.2277067 , 0.23302622,
        0.23149994, 0.23580757, 0.22534515, 0.22992378, 0.22979376,
        0.22899309, 0.23109256, 0.22946812, 0.23072276, 0.22923347,
        0.23557179, 0.23238409, 0.2243902 , 0.23136432, 0.22487636,
        0.23321244, 0.23017202, 0.2327922 , 0.23939423, 0.22442845,
        0.23565231, 0.22703998, 0.2241762 , 0.23035897, 0.22896367,
        0.23492745, 0.23067477, 0.23332123, 0.22982734, 0.23373495,
        0.2306382 , 0.22959946, 0.234819  , 0.23699102, 0.22628915,
        0.23106912, 0.22704528, 0.23206265, 0.22835125, 0.2375098 ,
        0.22831424, 0.22638877, 0.22893187, 0.23165943, 0.23086727,
        0.23069027, 0.22865061, 0.23296542, 0.22881323, 0.22867419,
        0.22928163, 0.23255555, 0.23240142, 0.2306141 , 0.22801651,
        0.2303999 , 0.22782049, 0.2311663 , 0.22596866, 0.22869745,
        0.23056997, 0.23063104, 0.23105143, 0.22809877, 0.23043825,
        0.22579089, 0.22618362, 0.22405076, 0.2313529 , 0.22929037,
        0.22828597, 0.22507906, 0.22755624, 0.22866166, 0.23259835,
        0.22428328, 0.2289164 , 0.22679083, 0.23322482, 0.22809243,
        0.22874312, 0.22969708],
       [0.21873702, 0.21478498, 0.21511282, 0.21073158, 0.21264848,
        0.20805344, 0.20862307, 0.2016053 , 0.20614582, 0.20526274,
        0.1934086 , 0.20644084, 0.1998549 , 0.19951342, 0.19595125,
        0.1983026 , 0.19979505, 0.20181738, 0.19496965, 0.19976951,
        0.19974917, 0.19827159, 0.19965018, 0.19903651, 0.19414812,
        0.19613774, 0.19732763, 0.20016123, 0.199242  , 0.20071582,
        0.19595002, 0.20309903, 0.19664107, 0.19848093, 0.20610625,
        0.1936942 , 0.19422592, 0.19131985, 0.19035204, 0.19516057,
        0.19756751, 0.19478549, 0.19614144, 0.1965923 , 0.19277209,
        0.1929744 , 0.19930099, 0.19569592, 0.19462654, 0.1923654 ,
        0.19340381, 0.19902703, 0.19423525, 0.19715865, 0.19890656,
        0.1957541 , 0.19803362, 0.19821084, 0.19210425, 0.19498955,
        0.19883785, 0.19853753, 0.20037489, 0.19634695, 0.1971254 ,
        0.19375683, 0.19723134, 0.19563098, 0.19665174, 0.19697207,
        0.19418135, 0.19441092, 0.1925318 , 0.20035046, 0.19254498,
        0.19543938, 0.19714406, 0.19314252, 0.19607449, 0.19951957,
        0.19830583, 0.19310701, 0.19535136, 0.20170636, 0.1947327 ,
        0.20092243, 0.19808438, 0.20003602, 0.19214189, 0.19308993,
        0.20139639, 0.19975169, 0.19974798, 0.19206996, 0.19768326,
        0.19906256, 0.19842745, 0.19665974, 0.19099855, 0.20497857,
        0.19913922, 0.19149065],
       [0.21912304, 0.21615088, 0.21581674, 0.2040242 , 0.19617833,
        0.19184701, 0.1874774 , 0.18376476, 0.1789075 , 0.1660711 ,
        0.17036766, 0.1702698 , 0.17275392, 0.1671502 , 0.16463287,
        0.15848656, 0.1627813 , 0.16174027, 0.16156572, 0.1625078 ,
        0.15330403, 0.16537087, 0.16009918, 0.15922482, 0.1657522 ,
        0.1603216 , 0.15978988, 0.16195492, 0.16188505, 0.15741919,
        0.15832408, 0.16098537, 0.16360822, 0.15763529, 0.15893728,
        0.15662996, 0.15593524, 0.16254133, 0.16011542, 0.16211228,
        0.16068721, 0.1602861 , 0.15983929, 0.16487019, 0.1565864 ,
        0.16333425, 0.15754461, 0.15879855, 0.16222644, 0.1632004 ,
        0.15620948, 0.15745552, 0.15887368, 0.15998402, 0.15558062,
        0.15853895, 0.16164702, 0.16251707, 0.16524166, 0.16062212,
        0.15829106, 0.15761073, 0.15809323, 0.15585648, 0.16521049,
        0.1594414 , 0.15869813, 0.1644336 , 0.159693  , 0.16404587,
        0.15745091, 0.16001981, 0.15716627, 0.16053256, 0.16211785,
        0.15965539, 0.16497283, 0.16201235, 0.16156249, 0.16056012,
        0.15988323, 0.15958895, 0.15797612, 0.15977909, 0.16294965,
        0.15895648, 0.16591929, 0.16507144, 0.160731  , 0.15910137,
        0.16105023, 0.16113253, 0.15843302, 0.16131893, 0.15931177,
        0.1642369 , 0.15576638, 0.1617544 , 0.16358018, 0.15713553,
        0.16137026, 0.1624673 ]], dtype=float32), 'number of neighbours': array([[4.6646168e+01, 4.4917942e+01, 4.5352409e+01, 4.5982750e+01,
        4.6964226e+01, 4.8985584e+01, 5.2064220e+01, 5.5067158e+01,
        5.8666492e+01, 6.2912903e+01, 6.7091217e+01, 6.9701340e+01,
        7.1786278e+01, 7.3883736e+01, 7.5293114e+01, 7.6890228e+01,
        7.8401169e+01, 7.9652412e+01, 7.8117981e+01, 7.8679825e+01,
        7.9638657e+01, 8.1112122e+01, 8.2286110e+01, 8.4492500e+01,
        8.3100800e+01, 8.2240402e+01, 8.3979042e+01, 8.6303482e+01,
        8.5534882e+01, 8.5168755e+01, 8.5894913e+01, 8.6136116e+01,
        8.5650635e+01, 8.3945740e+01, 8.3991249e+01, 8.3299355e+01,
        8.5752243e+01, 8.8673615e+01, 8.8633476e+01, 8.7025909e+01,
        8.4008553e+01, 8.4636703e+01, 8.6026840e+01, 8.6068710e+01,
        8.5076920e+01, 8.5684525e+01, 8.3068527e+01, 8.3674316e+01,
        8.5310738e+01, 8.5778114e+01, 8.5076950e+01, 8.6482010e+01,
        8.6463264e+01, 8.4942200e+01, 8.6351471e+01, 8.5012993e+01,
        8.3857819e+01, 8.4557312e+01, 8.5139671e+01, 8.5521576e+01,
        8.4552528e+01, 8.4903816e+01, 8.5593071e+01, 8.5911354e+01,
        8.5759277e+01, 8.5463158e+01, 8.7545631e+01, 8.7243431e+01,
        8.6353928e+01, 8.6105888e+01, 8.3930893e+01, 8.3844971e+01,
        8.3831993e+01, 8.5468452e+01, 8.5832359e+01, 8.7372589e+01,
        8.5430794e+01, 8.3483696e+01, 8.2409454e+01, 8.1731087e+01,
        8.3420128e+01, 8.4673447e+01, 8.6019058e+01, 8.5739235e+01,
        8.6289330e+01, 8.5565010e+01, 8.7241325e+01, 8.4839790e+01,
        8.2576126e+01, 8.3405739e+01, 8.5599960e+01, 8.4623940e+01,
        8.4694252e+01, 8.5954735e+01, 8.5201485e+01, 8.3126953e+01,
        8.4245003e+01, 8.6582397e+01, 8.4537247e+01, 8.4084785e+01,
        8.3297211e+01, 8.3406021e+01],
       [3.7316934e+02, 3.5934354e+02, 3.6281927e+02, 3.6786200e+02,
        3.7571381e+02, 3.9188467e+02, 4.1651376e+02, 4.4053726e+02,
        4.6933194e+02, 5.0330322e+02, 5.3672974e+02, 5.5761072e+02,
        5.7429022e+02, 5.9106989e+02, 6.0234491e+02, 6.1512183e+02,
        6.2720935e+02, 6.3721930e+02, 6.2494385e+02, 6.2943860e+02,
        6.3710925e+02, 6.4889697e+02, 6.5828888e+02, 6.7594000e+02,
        6.6480640e+02, 6.5792322e+02, 6.7183234e+02, 6.9042786e+02,
        6.8427905e+02, 6.8135004e+02, 6.8715930e+02, 6.8908893e+02,
        6.8520508e+02, 6.7156592e+02, 6.7192999e+02, 6.6639484e+02,
        6.8601794e+02, 7.0938892e+02, 7.0906781e+02, 6.9620728e+02,
        6.7206842e+02, 6.7709363e+02, 6.8821472e+02, 6.8854968e+02,
        6.8061536e+02, 6.8547620e+02, 6.6454822e+02, 6.6939453e+02,
        6.8248590e+02, 6.8622491e+02, 6.8061560e+02, 6.9185608e+02,
        6.9170612e+02, 6.7953760e+02, 6.9081177e+02, 6.8010394e+02,
        6.7086255e+02, 6.7645850e+02, 6.8111737e+02, 6.8417261e+02,
        6.7642023e+02, 6.7923053e+02, 6.8474457e+02, 6.8729083e+02,
        6.8607422e+02, 6.8370526e+02, 7.0036505e+02, 6.9794745e+02,
        6.9083142e+02, 6.8884711e+02, 6.7144714e+02, 6.7075977e+02,
        6.7065594e+02, 6.8374762e+02, 6.8665887e+02, 6.9898071e+02,
        6.8344635e+02, 6.6786957e+02, 6.5927563e+02, 6.5384869e+02,
        6.6736102e+02, 6.7738757e+02, 6.8815247e+02, 6.8591388e+02,
        6.9031464e+02, 6.8452008e+02, 6.9793060e+02, 6.7871832e+02,
        6.6060901e+02, 6.6724591e+02, 6.8479968e+02, 6.7699152e+02,
        6.7755402e+02, 6.8763788e+02, 6.8161188e+02, 6.6501562e+02,
        6.7396002e+02, 6.9265918e+02, 6.7629797e+02, 6.7267828e+02,
        6.6637769e+02, 6.6724817e+02],
       [5.8307705e+03, 5.6147432e+03, 5.6690508e+03, 5.7478433e+03,
        5.8705283e+03, 6.1231978e+03, 6.5080273e+03, 6.8833945e+03,
        7.3333110e+03, 7.8641128e+03, 8.3864014e+03, 8.7126670e+03,
        8.9732852e+03, 9.2354668e+03, 9.4116387e+03, 9.6112773e+03,
        9.8001455e+03, 9.9565518e+03, 9.7647461e+03, 9.8349775e+03,
        9.9548330e+03, 1.0139015e+04, 1.0285764e+04, 1.0561562e+04,
        1.0387600e+04, 1.0280050e+04, 1.0497379e+04, 1.0787936e+04,
        1.0691859e+04, 1.0646094e+04, 1.0736863e+04, 1.0767014e+04,
        1.0706328e+04, 1.0493217e+04, 1.0498906e+04, 1.0412419e+04,
        1.0719031e+04, 1.1084201e+04, 1.1079184e+04, 1.0878238e+04,
        1.0501068e+04, 1.0579589e+04, 1.0753355e+04, 1.0758588e+04,
        1.0634615e+04, 1.0710566e+04, 1.0383565e+04, 1.0459288e+04,
        1.0663842e+04, 1.0722264e+04, 1.0634618e+04, 1.0810250e+04,
        1.0807907e+04, 1.0617773e+04, 1.0793936e+04, 1.0626625e+04,
        1.0482228e+04, 1.0569663e+04, 1.0642459e+04, 1.0690195e+04,
        1.0569067e+04, 1.0612977e+04, 1.0699133e+04, 1.0738918e+04,
        1.0719910e+04, 1.0682894e+04, 1.0943204e+04, 1.0905429e+04,
        1.0794240e+04, 1.0763235e+04, 1.0491361e+04, 1.0480621e+04,
        1.0479000e+04, 1.0683556e+04, 1.0729044e+04, 1.0921572e+04,
        1.0678850e+04, 1.0435463e+04, 1.0301183e+04, 1.0216387e+04,
        1.0427516e+04, 1.0584181e+04, 1.0752382e+04, 1.0717404e+04,
        1.0786166e+04, 1.0695627e+04, 1.0905166e+04, 1.0604973e+04,
        1.0322016e+04, 1.0425717e+04, 1.0699994e+04, 1.0577991e+04,
        1.0586781e+04, 1.0744342e+04, 1.0650186e+04, 1.0390869e+04,
        1.0530624e+04, 1.0822799e+04, 1.0567156e+04, 1.0510598e+04,
        1.0412151e+04, 1.0425751e+04],
       [4.6646164e+04, 4.4917945e+04, 4.5352406e+04, 4.5982746e+04,
        4.6964227e+04, 4.8985582e+04, 5.2064219e+04, 5.5067156e+04,
        5.8666488e+04, 6.2912902e+04, 6.7091211e+04, 6.9701336e+04,
        7.1786281e+04, 7.3883734e+04, 7.5293109e+04, 7.6890219e+04,
        7.8401164e+04, 7.9652414e+04, 7.8117969e+04, 7.8679820e+04,
        7.9638664e+04, 8.1112117e+04, 8.2286109e+04, 8.4492492e+04,
        8.3100797e+04, 8.2240398e+04, 8.3979031e+04, 8.6303484e+04,
        8.5534875e+04, 8.5168750e+04, 8.5894906e+04, 8.6136109e+04,
        8.5650625e+04, 8.3945734e+04, 8.3991250e+04, 8.3299352e+04,
        8.5752250e+04, 8.8673609e+04, 8.8633469e+04, 8.7025906e+04,
        8.4008547e+04, 8.4636711e+04, 8.6026844e+04, 8.6068703e+04,
        8.5076922e+04, 8.5684531e+04, 8.3068523e+04, 8.3674305e+04,
        8.5310734e+04, 8.5778109e+04, 8.5076945e+04, 8.6482000e+04,
        8.6463258e+04, 8.4942188e+04, 8.6351484e+04, 8.5013000e+04,
        8.3857820e+04, 8.4557305e+04, 8.5139672e+04, 8.5521562e+04,
        8.4552539e+04, 8.4903812e+04, 8.5593062e+04, 8.5911344e+04,
        8.5759281e+04, 8.5463148e+04, 8.7545633e+04, 8.7243430e+04,
        8.6353922e+04, 8.6105883e+04, 8.3930891e+04, 8.3844969e+04,
        8.3832000e+04, 8.5468445e+04, 8.5832352e+04, 8.7372578e+04,
        8.5430797e+04, 8.3483703e+04, 8.2409461e+04, 8.1731094e+04,
        8.3420125e+04, 8.4673445e+04, 8.6019055e+04, 8.5739234e+04,
        8.6289328e+04, 8.5565016e+04, 8.7241328e+04, 8.4839781e+04,
        8.2576125e+04, 8.3405734e+04, 8.5599953e+04, 8.4623930e+04,
        8.4694250e+04, 8.5954734e+04, 8.5201484e+04, 8.3126953e+04,
        8.4244992e+04, 8.6582391e+04, 8.4537250e+04, 8.4084781e+04,
        8.3297211e+04, 8.3406008e+04],
       [3.7316931e+05, 3.5934356e+05, 3.6281925e+05, 3.6786197e+05,
        3.7571381e+05, 3.9188466e+05, 4.1651375e+05, 4.4053725e+05,
        4.6933191e+05, 5.0330322e+05, 5.3672969e+05, 5.5761069e+05,
        5.7429025e+05, 5.9106988e+05, 6.0234488e+05, 6.1512175e+05,
        6.2720931e+05, 6.3721931e+05, 6.2494375e+05, 6.2943856e+05,
        6.3710931e+05, 6.4889694e+05, 6.5828888e+05, 6.7593994e+05,
        6.6480638e+05, 6.5792319e+05, 6.7183225e+05, 6.9042788e+05,
        6.8427900e+05, 6.8135000e+05, 6.8715925e+05, 6.8908888e+05,
        6.8520500e+05, 6.7156588e+05, 6.7193000e+05, 6.6639481e+05,
        6.8601800e+05, 7.0938888e+05, 7.0906775e+05, 6.9620725e+05,
        6.7206838e+05, 6.7709369e+05, 6.8821475e+05, 6.8854962e+05,
        6.8061538e+05, 6.8547625e+05, 6.6454819e+05, 6.6939444e+05,
        6.8248588e+05, 6.8622488e+05, 6.8061556e+05, 6.9185600e+05,
        6.9170606e+05, 6.7953750e+05, 6.9081188e+05, 6.8010400e+05,
        6.7086256e+05, 6.7645844e+05, 6.8111738e+05, 6.8417250e+05,
        6.7642031e+05, 6.7923050e+05, 6.8474450e+05, 6.8729075e+05,
        6.8607425e+05, 6.8370519e+05, 7.0036506e+05, 6.9794744e+05,
        6.9083138e+05, 6.8884706e+05, 6.7144712e+05, 6.7075975e+05,
        6.7065600e+05, 6.8374756e+05, 6.8665881e+05, 6.9898062e+05,
        6.8344638e+05, 6.6786962e+05, 6.5927569e+05, 6.5384875e+05,
        6.6736100e+05, 6.7738756e+05, 6.8815244e+05, 6.8591388e+05,
        6.9031462e+05, 6.8452012e+05, 6.9793062e+05, 6.7871825e+05,
        6.6060900e+05, 6.6724588e+05, 6.8479962e+05, 6.7699144e+05,
        6.7755400e+05, 6.8763788e+05, 6.8161188e+05, 6.6501562e+05,
        6.7395994e+05, 6.9265912e+05, 6.7629800e+05, 6.7267825e+05,
        6.6637769e+05, 6.6724806e+05]], dtype=float32), 'ESS': []}

As a sanity check, we perform the same computation on two elementary point clouds: A is a pair of antipodal points on the x axis, and B is the same pair turned onto the y axis.

def load_coordinates(coordinates):
    x = torch.FloatTensor(coordinates).type(dtype)
    x -= x.mean(0)
    scale = (x**2).sum(1).mean(0)
    x /= scale
    return x


A = load_coordinates([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]])
B = load_coordinates([[0.0, 1.0, 0.0], [0.0, -1.0, 0.0]])

distribution = WassersteinDistribution(A, B, temperature=1e-4)
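Before sampling, it helps to know what to expect. Because a Wasserstein-type discrepancy ignores the order of the points, any rotation that sends the x axis onto the y axis, in either direction, aligns A with B perfectly. These optimal rotations form two circles in SO(3): a quarter-turn about z taking e_x to +e_y or to -e_y, followed by any spin about the y axis. The NumPy check below uses our own illustrative helpers (`rot_y`, `rot_z`, `matching_cost`). For two-point clouds, an exact brute-force matching replaces the Sinkhorn approximation used by GeomLoss.

```python
import itertools

import numpy as np


def rot_z(t):
    """Rotation matrix of angle t about the z axis."""
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rot_y(t):
    """Rotation matrix of angle t about the y axis."""
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def matching_cost(x, y):
    """Exact squared W2 cost between two uniform clouds of equal size:
    the best mean squared distance over all one-to-one matchings."""
    return min(
        np.mean(np.sum((x - y[list(p)]) ** 2, axis=1))
        for p in itertools.permutations(range(len(y)))
    )


pair_A = np.array([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]])
pair_B = np.array([[0.0, 1.0, 0.0], [0.0, -1.0, 0.0]])

# Two one-parameter families: send e_x to +e_y or -e_y, then spin by any angle
# theta about the y axis (which fixes both points of pair_B).
costs = [
    matching_cost(pair_A @ (rot_y(theta) @ rot_z(sign * np.pi / 2)).T, pair_B)
    for theta in np.linspace(0.0, 2 * np.pi, 13)
    for sign in (1.0, -1.0)
]
print(max(costs))
```

All costs vanish up to rounding errors, whereas the identity rotation gives `matching_cost(pair_A, pair_B) == 2.0`. The samples should therefore concentrate on these two circles.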

As expected, all the symmetries of the problem are respected by our sampler:

start = space.uniform_sample(N)
proposal = BallProposal(space, scale=[0.1, 0.2, 0.5, 1.0, 2.0])

moka_sampler = MOKA_CMC(space, start, proposal, annealing=5).fit(distribution)
display_samples(moka_sampler, iterations=100, runs=2, small=False)


plt.show()
(Figures: samples at iterations 0, 1, 2, 5, 10, 20, 50, 80 and 100, followed by four additional plots.)

Total running time of the script: ( 0 minutes 32.490 seconds)
