# Benchmarking Gaussian convolutions in high dimensions

Let’s compare the performance of PyTorch and KeOps on simple Gaussian RBF kernel products as the dimension grows.

## Setup

import torch
from matplotlib import pyplot as plt

from benchmark_utils import random_normal, full_benchmark

use_cuda = torch.cuda.is_available()


Benchmark specifications:

N = 10000  # Number of samples
# Dimensions to test:
Dims = [1, 3, 5, 10, 20, 30, 50, 80, 100, 120, 150, 200, 300, 500, 1000, 2000, 3000]


Synthetic dataset.

def generate_samples(D, device="cuda", lang="torch", batchsize=1, **kwargs):
    """Generates two point clouds x, y and a scalar signal b of size N.

    Args:
        D (int): dimension of the ambient space.
        device (str, optional): "cuda", "cpu", etc. Defaults to "cuda".
        lang (str, optional): "torch", "numpy", etc. Defaults to "torch".
        batchsize (int, optional): number of experiments to run in parallel. Defaults to 1.

    Returns:
        3-tuple of arrays: x, y, b
    """
    randn = random_normal(device=device, lang=lang)

    x = randn((batchsize, N, D))
    y = randn((batchsize, N, D))
    b = randn((batchsize, N, 1))

    return x, y, b


Define a simple Gaussian RBF product, using a tensorized implementation. Note that expanding the squared norm $$\|x-y\|^2$$ as a sum $$\|x\|^2 - 2 \langle x, y \rangle + \|y\|^2$$ allows us to leverage the fast matrix-matrix product of the BLAS/cuBLAS libraries.

def gaussianconv_pytorch(x, y, b, tf32=False, **kwargs):
    """(B,N,D), (B,N,D), (B,N,1) -> (B,N,1)"""

    # If False, we stick to float32 computations.
    # If True, we use TensorFloat32 whenever possible.
    torch.backends.cuda.matmul.allow_tf32 = tf32

    D_xx = (x * x).sum(-1).unsqueeze(2)  # (B,N,1)
    D_xy = torch.matmul(x, y.permute(0, 2, 1))  # (B,N,D) @ (B,D,M) = (B,N,M)
    D_yy = (y * y).sum(-1).unsqueeze(1)  # (B,1,M)
    D_xy = D_xx - 2 * D_xy + D_yy  # (B,N,M)
    K_xy = (-D_xy).exp()  # (B,N,M)

    return K_xy @ b  # (B,N,1)
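
The expansion of the squared norm used above can be checked on a single pair of points. The snippet below is a minimal, dependency-free sketch: the helper names `sqdist_direct` and `sqdist_expanded` and the toy dimension are illustrative and are not part of the benchmark.

```python
import math
import random


def sqdist_direct(x, y):
    """Squared Euclidean distance, computed coordinate by coordinate."""
    return sum((xi - yi) ** 2 for xi, yi in zip(x, y))


def sqdist_expanded(x, y):
    """Same quantity, written as |x|^2 - 2<x,y> + |y|^2."""
    xx = sum(xi * xi for xi in x)
    xy = sum(xi * yi for xi, yi in zip(x, y))
    yy = sum(yi * yi for yi in y)
    return xx - 2 * xy + yy


random.seed(0)
D = 5  # hypothetical small dimension, for illustration only
x = [random.gauss(0, 1) for _ in range(D)]
y = [random.gauss(0, 1) for _ in range(D)]

direct = sqdist_direct(x, y)
expanded = sqdist_expanded(x, y)
print(f"direct   = {direct:.12f}")
print(f"expanded = {expanded:.12f}")
print("match:", math.isclose(direct, expanded, rel_tol=1e-9, abs_tol=1e-9))
```

The two forms agree up to rounding. In lower precision, the expanded form can lose accuracy through cancellation when the two points are close to each other, which is worth keeping in mind when reusing it outside of a timing benchmark.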


Define a simple Gaussian RBF product, using an online implementation:

from pykeops.torch import generic_sum

def gaussianconv_keops(x, y, b, backend="GPU", **kwargs):
    D = x.shape[-1]
    fun = generic_sum(
        "Exp(X|Y) * B",  # Formula
        "A = Vi(1)",  # Output
        "X = Vi({})".format(D),  # 1st argument
        "Y = Vj({})".format(D),  # 2nd argument
        "B = Vj(1)",  # 3rd argument
    )
    ex = (-(x * x).sum(-1)).exp()[:, :, None]
    ey = (-(y * y).sum(-1)).exp()[:, :, None]
    return ex * fun(2 * x, y, b * ey, backend=backend)
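
The scaling by `ex` and `ey` in the routine above follows from the same expansion of the squared norm. Writing the output as a vector with entries indexed by i (a notation introduced here for illustration), the identity reads:

$$
a_i = \sum_{j} e^{-\|x_i - y_j\|^2} \, b_j
    = e^{-\|x_i\|^2} \sum_{j} e^{\langle 2 x_i, \, y_j \rangle} \left( e^{-\|y_j\|^2} \, b_j \right)
$$

The inner sum matches the formula `Exp(X|Y) * B` evaluated with `X = 2 * x`, `Y = y` and `B = b * ey`, and the outer factor is `ex`.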


The same routine, but without the chunked computation mode:

def gaussianconv_keops_nochunks(x, y, b, backend="GPU", **kwargs):
    D = x.shape[-1]
    fun = generic_sum(
        "Exp(X|Y) * B",  # Formula
        "A = Vi(1)",  # Output
        "X = Vi({})".format(D),  # 1st argument
        "Y = Vj({})".format(D),  # 2nd argument
        "B = Vj(1)",  # 3rd argument
        enable_chunks=False,
    )
    ex = (-(x * x).sum(-1)).exp()[:, :, None]
    ey = (-(y * y).sum(-1)).exp()[:, :, None]
    return ex * fun(2 * x, y, b * ey, backend=backend)
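
Both KeOps routines above compute the Gaussian product in a factored form, with the norms of the points pulled out of the exponential. The snippet below is a small pure-Python sketch that compares this factored form with the literal sum on a toy point cloud. The function names and sizes are hypothetical, and it does not call KeOps at all.

```python
import math
import random

random.seed(0)
N, D = 4, 3  # hypothetical toy sizes, far smaller than the benchmark


def randn_points(n, d):
    """n points in dimension d, with standard normal coordinates."""
    return [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]


def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))


def conv_direct(x, y, b):
    """a_i = sum_j exp(-|x_i - y_j|^2) * b_j, evaluated literally."""
    return [
        sum(
            math.exp(-sum((xi - yi) ** 2 for xi, yi in zip(x_i, y_j))) * b_j
            for y_j, b_j in zip(y, b)
        )
        for x_i in x
    ]


def conv_factored(x, y, b):
    """Same sum, with the norms of x_i and y_j pulled out of the exponential."""
    ex = [math.exp(-dot(x_i, x_i)) for x_i in x]
    ey = [math.exp(-dot(y_j, y_j)) for y_j in y]
    return [
        ex_i
        * sum(
            math.exp(dot([2 * c for c in x_i], y_j)) * (b_j * ey_j)
            for y_j, b_j, ey_j in zip(y, b, ey)
        )
        for x_i, ex_i in zip(x, ex)
    ]


x = randn_points(N, D)
y = randn_points(N, D)
b = [random.gauss(0, 1) for _ in range(N)]

a_direct = conv_direct(x, y, b)
a_factored = conv_factored(x, y, b)
max_gap = max(abs(u - v) for u, v in zip(a_direct, a_factored))
print(f"largest absolute gap: {max_gap:.2e}")
```

The two results agree up to rounding at this scale. The separate exponentials can underflow or overflow once the squared norms become large, which is likely for standard normal samples in high dimension, so the factored form is best read as a device for timing the reduction rather than as a numerically robust recipe.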


## PyTorch vs. KeOps (GPU)

routines = [
    (gaussianconv_pytorch, "PyTorch (GPU, TF32=False)", {"tf32": False}),
    (gaussianconv_pytorch, "PyTorch (GPU, TF32=True)", {"tf32": True}),
    (gaussianconv_keops_nochunks, "KeOps < 1.4.2 (GPU)", {}),
    (gaussianconv_keops, "KeOps >= 1.4.2 (GPU)", {}),
]

full_benchmark(
    f"Gaussian Matrix-Vector products in high dimension, with N={N:,} (GPU)",
    routines,
    generate_samples,
    problem_sizes=Dims,
    xlabel="Dimension of the points",
)

plt.show()

Benchmarking : Gaussian Matrix-Vector products in high dimension, with N=10,000 (GPU) ===============================
PyTorch (GPU, TF32=False) -------------
1x100 loops of size    1 :   1x100x   3.0 ms
1x100 loops of size    3 :   1x100x   2.9 ms
1x100 loops of size    5 :   1x100x   2.9 ms
1x100 loops of size   10 :   1x100x   2.9 ms
1x100 loops of size   20 :   1x100x   3.0 ms
1x100 loops of size   30 :   1x100x   3.1 ms
1x100 loops of size   50 :   1x100x   3.3 ms
1x100 loops of size   80 :   1x100x   3.6 ms
1x100 loops of size  100 :   1x100x   3.8 ms
1x100 loops of size  120 :   1x100x   4.0 ms
1x100 loops of size  150 :   1x100x   4.3 ms
1x100 loops of size  200 :   1x100x   4.8 ms
1x100 loops of size  300 :   1x100x   6.0 ms
1x100 loops of size  500 :   1x100x   8.1 ms
1x100 loops of size   1 k:   1x100x  13.4 ms
1x100 loops of size   2 k:   1x100x  24.2 ms
1x 10 loops of size   3 k:   1x 10x  35.0 ms
PyTorch (GPU, TF32=True) -------------
1x100 loops of size    1 :   1x100x   2.9 ms
1x100 loops of size    3 :   1x100x   2.9 ms
1x100 loops of size    5 :   1x100x   2.9 ms
1x100 loops of size   10 :   1x100x   2.9 ms
1x100 loops of size   20 :   1x100x   3.0 ms
1x100 loops of size   30 :   1x100x   3.1 ms
1x100 loops of size   50 :   1x100x   3.1 ms
1x100 loops of size   80 :   1x100x   3.0 ms
1x100 loops of size  100 :   1x100x   3.1 ms
1x100 loops of size  120 :   1x100x   3.1 ms
1x100 loops of size  150 :   1x100x   3.4 ms
1x100 loops of size  200 :   1x100x   3.2 ms
1x100 loops of size  300 :   1x100x   3.5 ms
1x100 loops of size  500 :   1x100x   3.7 ms
1x100 loops of size   1 k:   1x100x   4.5 ms
1x100 loops of size   2 k:   1x100x   6.1 ms
1x100 loops of size   3 k:   1x100x   7.9 ms
KeOps < 1.4.2 (GPU) -------------
1x100 loops of size    1 :   1x100x 589.8 µs
1x100 loops of size    3 :   1x100x 617.0 µs
1x100 loops of size    5 :   1x100x 675.4 µs
1x100 loops of size   10 :   1x100x 894.7 µs
1x100 loops of size   20 :   1x100x   1.5 ms
1x100 loops of size   30 :   1x100x   1.6 ms
1x100 loops of size   50 :   1x100x   4.1 ms
1x100 loops of size   80 :   1x100x   4.9 ms
1x100 loops of size  100 :   1x100x   4.3 ms
1x100 loops of size  120 :   1x100x   5.3 ms
1x100 loops of size  150 :   1x100x   8.5 ms
1x100 loops of size  200 :   1x100x  10.8 ms
1x100 loops of size  300 :   1x100x  65.8 ms
1x 10 loops of size  500 :   1x 10x 221.7 ms
1x  1 loops of size   1 k:   1x  1x    1.6 s
1x  1 loops of size   2 k:   1x  1x    6.9 s
1x  1 loops of size   3 k:   1x  1x   15.5 s
** Too slow!
KeOps >= 1.4.2 (GPU) -------------
1x100 loops of size    1 :   1x100x 584.1 µs
1x100 loops of size    3 :   1x100x 616.3 µs
1x100 loops of size    5 :   1x100x 678.4 µs
1x100 loops of size   10 :   1x100x 900.1 µs
1x100 loops of size   20 :   1x100x   1.5 ms
1x100 loops of size   30 :   1x100x   1.7 ms
1x100 loops of size   50 :   1x100x   4.1 ms
1x100 loops of size   80 :   1x100x   4.9 ms
1x100 loops of size  100 :   1x100x   7.9 ms
1x100 loops of size  120 :   1x100x  12.2 ms
1x100 loops of size  150 :   1x100x  29.2 ms
1x 10 loops of size  200 :   1x 10x  23.9 ms
1x 10 loops of size  300 :   1x 10x  33.7 ms
1x 10 loops of size  500 :   1x 10x  58.0 ms
1x 10 loops of size   1 k:   1x 10x 127.5 ms
1x 10 loops of size   2 k:   1x 10x 256.8 ms
1x  1 loops of size   3 k:   1x  1x 384.3 ms
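
To read the log above at a glance, the sketch below recomputes a few ratios from the reported timings. The numbers are copied by hand from this run and converted to milliseconds, the variable names are illustrative, and other hardware or library versions would likely give different figures.

```python
# Timings copied from the benchmark log above, in milliseconds.
dims = [1, 3, 5, 10, 20, 30, 50, 80, 100, 120, 150, 200, 300, 500, 1000, 2000, 3000]
pytorch_fp32_ms = [3.0, 2.9, 2.9, 2.9, 3.0, 3.1, 3.3, 3.6, 3.8, 4.0, 4.3, 4.8,
                   6.0, 8.1, 13.4, 24.2, 35.0]
keops_chunks_ms = [0.5841, 0.6163, 0.6784, 0.9001, 1.5, 1.7, 4.1, 4.9, 7.9, 12.2,
                   29.2, 23.9, 33.7, 58.0, 127.5, 256.8, 384.3]
keops_nochunks_3k_ms = 15500.0  # "KeOps < 1.4.2" at dimension 3,000

# First tested dimension at which the chunked KeOps routine is slower
# than the float32 PyTorch baseline.
crossover = next(
    d for d, p, k in zip(dims, pytorch_fp32_ms, keops_chunks_ms) if k > p
)

# Ratio of PyTorch to KeOps run times in dimension 1.
speedup_low_dim = pytorch_fp32_ms[0] / keops_chunks_ms[0]

# Ratio of unchunked to chunked KeOps run times in dimension 3,000.
chunk_gain_3k = keops_nochunks_3k_ms / keops_chunks_ms[-1]

print(f"KeOps slower than PyTorch (float32) from D = {crossover}")
print(f"PyTorch / KeOps time ratio at D = 1: {speedup_low_dim:.1f}")
print(f"no-chunks / chunks time ratio at D = 3000: {chunk_gain_3k:.1f}")
```

On this run, the chunked KeOps routine is faster than the float32 PyTorch baseline up to dimension 30, and PyTorch takes the lead from dimension 50 onwards.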


Total running time of the script: (1 minute 30.816 seconds)

