Scaling up Gaussian convolutions on 3D point clouds
Let’s compare the performance of PyTorch and KeOps on simple Gaussian RBF kernel products, as the number of samples grows from 100 to 1,000,000.
Note
In this demo, we use exact bruteforce computations (tensorized for PyTorch and online for KeOps), without leveraging any multiscale or low-rank (Nystroem/multipole) decomposition of the kernel matrix. We are working on providing transparent support for these approximations in KeOps.
Setup
import numpy as np
import torch
from matplotlib import pyplot as plt
from benchmark_utils import flatten, random_normal, full_benchmark
use_cuda = torch.cuda.is_available()
print(
f"Running torch version {torch.__version__} with {'GPU' if use_cuda else 'CPU'}..."
)
Running torch version 2.2.0 with GPU...
Benchmark specifications:
# Numbers of samples that we'll loop upon:
problem_sizes = flatten(
[[1 * 10**k, 2 * 10**k, 5 * 10**k] for k in [2, 3, 4, 5]] + [[10**6]]
)
D = 3 # We work with 3D points
MAX_TIME = 0.1 # Run each experiment for at most 0.1 second
Synthetic dataset. Feel free to use a Stanford Bunny, or whatever!
def generate_samples(N, device="cuda", lang="torch", batchsize=1, **kwargs):
"""Generates two point clouds x, y and a scalar signal b of size N.
Args:
        N (int): number of points per cloud.
device (str, optional): "cuda", "cpu", etc. Defaults to "cuda".
lang (str, optional): "torch", "numpy", etc. Defaults to "torch".
        batchsize (int, optional): number of experiments to run in parallel. Defaults to 1.
Returns:
        3-tuple of arrays: x, y, b
"""
randn = random_normal(device=device, lang=lang)
x = randn((batchsize, N, D))
y = randn((batchsize, N, D))
b = randn((batchsize, N, 1))
return x, y, b
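As a hypothetical variation (this helper is not part of the benchmark; the file name "bunny.npy" and the assumption that your cloud is stored as an (N, 3) float array are placeholders), you could swap in a real point cloud such as a Stanford Bunny scan:

def generate_samples_from_file(fname="bunny.npy", device="cuda", batchsize=1, **kwargs):
    """Loads an (N, 3) point cloud from disk and formats it like generate_samples."""
    pts = torch.from_numpy(np.load(fname)).float().to(device)  # (N, 3) coordinates
    x = pts.unsqueeze(0).repeat(batchsize, 1, 1)  # (B, N, 3)
    y = x.clone()  # re-use the same cloud as the "target" points
    b = torch.randn(batchsize, x.shape[1], 1, device=device)  # random scalar signal
    return x, y, b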
Define a simple Gaussian RBF product, using a tensorized implementation. Note that expanding the squared norm \(\|x-y\|^2\) as a sum \(\|x\|^2 - 2 \langle x, y \rangle + \|y\|^2\) allows us to leverage the fast matrix-matrix product of the BLAS/cuBLAS libraries.
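Spelled out on the full matrix of pairwise interactions, this standard identity (written here for reference) reads:

\[
D_{ij} \,=\, \|x_i - y_j\|^2 \,=\, \|x_i\|^2 \,-\, 2\,\langle x_i, y_j\rangle \,+\, \|y_j\|^2,
\qquad
K_{ij} \,=\, \exp(-D_{ij}),
\]

so that the cross term \(\langle x_i, y_j\rangle\) is simply the \((i,j)\) coefficient of the matrix product \(X Y^\top\), which BLAS/cuBLAS routines compute very efficiently.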
def gaussianconv_numpy(x, y, b, **kwargs):
"""(1,N,D), (1,N,D), (1,N,1) -> (1,N,1)"""
    # N.B.: we drop the dummy batch dimension and work with plain 2D arrays:
x, y, b = x.squeeze(0), y.squeeze(0), b.squeeze(0)
D_xx = np.sum((x**2), axis=-1)[:, None] # (N,1)
D_xy = x @ y.T # (N,D) @ (D,M) = (N,M)
D_yy = np.sum((y**2), axis=-1)[None, :] # (1,M)
D_xy = D_xx - 2 * D_xy + D_yy # (N,M)
    K_xy = np.exp(-D_xy)  # (N,M)
return K_xy @ b
def gaussianconv_pytorch_eager(x, y, b, tf32=False, cdist=False, **kwargs):
"""(B,N,D), (B,N,D), (B,N,1) -> (B,N,1)"""
# If False, we stick to float32 computations.
# If True, we use TensorFloat32 whenever possible.
# As of PyTorch 2.0, this has no impact on run times so we
# do not use this option.
torch.backends.cuda.matmul.allow_tf32 = tf32
    # Alternatively, we may rely on torch.cdist, which returns Euclidean
    # distances that we square to recover the squared norms:
    if cdist:
        D_xy = torch.cdist(x, y, p=2) ** 2  # (B,N,M)
else:
D_xx = (x * x).sum(-1).unsqueeze(2) # (B,N,1)
D_xy = torch.matmul(x, y.permute(0, 2, 1)) # (B,N,D) @ (B,D,M) = (B,N,M)
D_yy = (y * y).sum(-1).unsqueeze(1) # (B,1,M)
D_xy = D_xx - 2 * D_xy + D_yy # (B,N,M)
K_xy = (-D_xy).exp() # (B,N,M)
return K_xy @ b # (B,N,1)
PyTorch 2.0 introduced a new compiler that improves speed and memory usage.
We use it with dynamic shapes to avoid re-compilation for every value of N.
Please note that torch.compile(...) is still experimental: we will update this demo with new PyTorch releases.
# Inner function to be compiled:
def _gaussianconv_pytorch(x, y, b):
"""(B,N,D), (B,N,D), (B,N,1) -> (B,N,1)"""
# Note that cdist is not currently supported by torch.compile with dynamic=True.
D_xx = (x * x).sum(-1).unsqueeze(2) # (B,N,1)
D_xy = torch.matmul(x, y.permute(0, 2, 1)) # (B,N,D) @ (B,D,M) = (B,N,M)
D_yy = (y * y).sum(-1).unsqueeze(1) # (B,1,M)
D_xy = D_xx - 2 * D_xy + D_yy # (B,N,M)
K_xy = (-D_xy).exp() # (B,N,M)
return K_xy @ b # (B,N,1)
# Compile the function:
gaussianconv_pytorch_compiled = torch.compile(_gaussianconv_pytorch, dynamic=True)
# Wrap it to ignore optional keyword arguments:
def gaussianconv_pytorch_dynamic(x, y, b, **kwargs):
return gaussianconv_pytorch_compiled(x, y, b)
# Run the function once on dummy data, so that compilation happens ahead of the benchmark:
# On the GPU, if it is available:
if use_cuda:
    _ = gaussianconv_pytorch_compiled(*generate_samples(1000))
# And on the CPU, in any case:
# _ = gaussianconv_pytorch_compiled(*generate_samples(1000, device="cpu"))
Define a simple Gaussian RBF product, using an online implementation:
from pykeops.torch import generic_sum
fun_gaussianconv_keops = generic_sum(
"Exp(-SqDist(X,Y)) * B", # Formula
"A = Vi(1)", # Output
"X = Vi({})".format(D), # 1st argument
"Y = Vj({})".format(D), # 2nd argument
"B = Vj(1)", # 3rd argument
)
fun_gaussianconv_keops_no_fast_math = generic_sum(
"Exp(-SqDist(X,Y)) * B", # Formula
"A = Vi(1)", # Output
"X = Vi({})".format(D), # 1st argument
"Y = Vj({})".format(D), # 2nd argument
"B = Vj(1)", # 3rd argument
use_fast_math=False,
)
def gaussianconv_keops(x, y, b, backend="GPU", **kwargs):
"""(B,N,D), (B,N,D), (B,N,1) -> (B,N,1)"""
x, y, b = x.squeeze(), y.squeeze(), b.squeeze()
return fun_gaussianconv_keops(x, y, b, backend=backend)
def gaussianconv_keops_no_fast_math(x, y, b, backend="GPU", **kwargs):
"""(B,N,D), (B,N,D), (B,N,1) -> (B,N,1)"""
x, y, b = x.squeeze(), y.squeeze(), b.squeeze()
return fun_gaussianconv_keops_no_fast_math(x, y, b, backend=backend)
Finally, perform the same operation with our high-level pykeops.torch.LazyTensor wrapper:
from pykeops.torch import LazyTensor
def gaussianconv_lazytensor(x, y, b, backend="GPU", **kwargs):
"""(B,N,D), (B,N,D), (B,N,1) -> (B,N,1)"""
x_i = LazyTensor(x.unsqueeze(-2)) # (B, M, 1, D)
y_j = LazyTensor(y.unsqueeze(-3)) # (B, 1, N, D)
D_ij = ((x_i - y_j) ** 2).sum(-1) # (B, M, N, 1)
K_ij = (-D_ij).exp() # (B, M, N, 1)
S_ij = K_ij * b.unsqueeze(-3) # (B, M, N, 1) * (B, 1, N, 1)
return S_ij.sum(dim=2, backend=backend)
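As a quick sanity check (a minimal sketch that is not part of the original benchmark; it assumes a CUDA device is available), we can verify that the PyTorch and KeOps routines agree on a small problem:

if use_cuda:
    x, y, b = generate_samples(1000)
    res_torch = gaussianconv_pytorch_eager(x, y, b)  # (1, 1000, 1)
    res_keops = gaussianconv_lazytensor(x, y, b)  # (1, 1000, 1)
    # The tolerance is loose enough to absorb float32 round-off and fast-math differences:
    print(torch.allclose(res_torch, res_keops, atol=1e-3))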
NumPy vs. PyTorch vs. KeOps (GPU)
if use_cuda:
routines = [
(gaussianconv_numpy, "Numpy (CPU)", {"lang": "numpy"}),
(gaussianconv_pytorch_eager, "PyTorch (GPU, matmul)", {"cdist": False}),
(gaussianconv_pytorch_eager, "PyTorch (GPU, cdist)", {"cdist": True}),
(
gaussianconv_pytorch_dynamic,
"PyTorch (GPU, compiled with dynamic shapes)",
{},
),
(gaussianconv_lazytensor, "KeOps (GPU, LazyTensor)", {}),
(
gaussianconv_lazytensor,
"KeOps (GPU, LazyTensor, batchsize=100)",
{"batchsize": 100},
),
(gaussianconv_keops, "KeOps (GPU, Genred)", {}),
(gaussianconv_keops_no_fast_math, "KeOps (GPU, use_fast_math=False)", {}),
]
full_benchmark(
"Gaussian Matrix-Vector products (GPU)",
routines,
generate_samples,
problem_sizes=problem_sizes,
max_time=MAX_TIME,
)
Benchmarking : Gaussian Matrix-Vector products (GPU) ===============================
Numpy (CPU) -------------
1x100 loops of size 100 : 1x100x 41.7 µs
1x100 loops of size 200 : 1x100x 166.0 µs
1x100 loops of size 500 : 1x100x 815.0 µs
1x100 loops of size 1 k: 1x100x 3.1 ms
1x 10 loops of size 2 k: 1x 10x 14.2 ms
1x 1 loops of size 5 k: 1x 1x 132.3 ms
** Too slow!
PyTorch (GPU, matmul) -------------
1x100 loops of size 100 : 1x100x 166.1 µs
1x100 loops of size 200 : 1x100x 165.6 µs
1x100 loops of size 500 : 1x100x 163.8 µs
1x100 loops of size 1 k: 1x100x 162.8 µs
1x100 loops of size 2 k: 1x100x 161.0 µs
1x100 loops of size 5 k: 1x100x 809.7 µs
1x100 loops of size 10 k: 1x100x 2.9 ms
1x 10 loops of size 20 k: 1x 10x 11.4 ms
1x 1 loops of size 50 k: 1x 1x 72.0 ms
CUDA out of memory. Tried to allocate 37.25 GiB. GPU 0 has a total capacity of 79.15 GiB of which 381.38 MiB is free. Including non-PyTorch memory, this process has 78.62 GiB memory in use. Of the allocated memory 74.53 GiB is allocated by PyTorch, and 21.47 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
** Runtime error!
PyTorch (GPU, cdist) -------------
1x100 loops of size 100 : 1x100x 137.4 µs
1x100 loops of size 200 : 1x100x 138.4 µs
1x100 loops of size 500 : 1x100x 136.1 µs
1x100 loops of size 1 k: 1x100x 134.7 µs
1x100 loops of size 2 k: 1x100x 133.7 µs
1x100 loops of size 5 k: 1x100x 652.4 µs
1x100 loops of size 10 k: 1x100x 2.4 ms
1x 10 loops of size 20 k: 1x 10x 9.3 ms
1x 10 loops of size 50 k: 1x 10x 58.3 ms
CUDA out of memory. Tried to allocate 37.25 GiB. GPU 0 has a total capacity of 79.15 GiB of which 383.38 MiB is free. Including non-PyTorch memory, this process has 78.62 GiB memory in use. Of the allocated memory 74.53 GiB is allocated by PyTorch, and 20.24 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
** Runtime error!
PyTorch (GPU, compiled with dynamic shapes) -------------
1x100 loops of size 100 : 1x100x 152.6 µs
1x100 loops of size 200 : 1x100x 150.3 µs
1x100 loops of size 500 : 1x100x 147.6 µs
1x100 loops of size 1 k: 1x100x 146.6 µs
1x100 loops of size 2 k: 1x100x 146.7 µs
1x100 loops of size 5 k: 1x100x 271.2 µs
1x100 loops of size 10 k: 1x100x 970.1 µs
1x100 loops of size 20 k: 1x100x 3.8 ms
1x 10 loops of size 50 k: 1x 10x 23.9 ms
1x 1 loops of size 100 k: 1x 1x 119.1 ms
** Too slow!
KeOps (GPU, LazyTensor) -------------
1x100 loops of size 100 : 1x100x 353.5 µs
1x100 loops of size 200 : 1x100x 349.4 µs
1x100 loops of size 500 : 1x100x 367.7 µs
1x100 loops of size 1 k: 1x100x 397.2 µs
1x100 loops of size 2 k: 1x100x 456.7 µs
1x100 loops of size 5 k: 1x100x 632.1 µs
1x100 loops of size 10 k: 1x100x 925.4 µs
1x100 loops of size 20 k: 1x100x 1.5 ms
1x 10 loops of size 50 k: 1x 10x 4.3 ms
1x 10 loops of size 100 k: 1x 10x 12.7 ms
1x 1 loops of size 200 k: 1x 1x 45.0 ms
1x 1 loops of size 500 k: 1x 1x 262.0 ms
** Too slow!
KeOps (GPU, LazyTensor, batchsize=100) -------------
100x100 loops of size 100 : 100x100x 3.5 µs
100x100 loops of size 200 : 100x100x 3.6 µs
100x100 loops of size 500 : 100x100x 3.9 µs
100x100 loops of size 1 k: 100x100x 5.0 µs
100x100 loops of size 2 k: 100x100x 8.5 µs
100x100 loops of size 5 k: 100x100x 29.8 µs
100x100 loops of size 10 k: 100x100x 102.4 µs
100x100 loops of size 20 k: 100x100x 388.3 µs
100x100 loops of size 50 k: 100x100x 2.4 ms
100x 10 loops of size 100 k: 100x 10x 9.7 ms
100x 10 loops of size 200 k: 100x 10x 38.8 ms
100x 1 loops of size 500 k: 100x 1x 243.0 ms
** Too slow!
KeOps (GPU, Genred) -------------
1x100 loops of size 100 : 1x100x 149.3 µs
1x100 loops of size 200 : 1x100x 152.5 µs
1x100 loops of size 500 : 1x100x 164.5 µs
1x100 loops of size 1 k: 1x100x 183.5 µs
1x100 loops of size 2 k: 1x100x 222.6 µs
1x100 loops of size 5 k: 1x100x 337.6 µs
1x100 loops of size 10 k: 1x100x 529.5 µs
1x100 loops of size 20 k: 1x100x 914.1 µs
1x100 loops of size 50 k: 1x100x 3.9 ms
1x 10 loops of size 100 k: 1x 10x 11.8 ms
1x 1 loops of size 200 k: 1x 1x 44.4 ms
1x 1 loops of size 500 k: 1x 1x 267.7 ms
** Too slow!
KeOps (GPU, use_fast_math=False) -------------
1x100 loops of size 100 : 1x100x 151.1 µs
1x100 loops of size 200 : 1x100x 157.1 µs
1x100 loops of size 500 : 1x100x 174.3 µs
1x100 loops of size 1 k: 1x100x 202.2 µs
1x100 loops of size 2 k: 1x100x 257.8 µs
1x100 loops of size 5 k: 1x100x 426.0 µs
1x100 loops of size 10 k: 1x100x 704.6 µs
1x100 loops of size 20 k: 1x100x 1.3 ms
1x 10 loops of size 50 k: 1x 10x 6.0 ms
1x 10 loops of size 100 k: 1x 10x 18.7 ms
1x 1 loops of size 200 k: 1x 1x 71.3 ms
1x 1 loops of size 500 k: 1x 1x 434.1 ms
** Too slow!
We make several observations:
Asymptotically, all routines scale in O(N^2): multiplying N by 10 increases the computation time by a factor of 100. This is expected, since we are performing bruteforce computations. However, constants vary wildly between different implementations.
The NumPy implementation is slow, and prevents us from working efficiently with more than 10k points at a time.
The PyTorch GPU implementation is typically 100 times faster than the NumPy CPU code.
The torch.compile(...) function, introduced by PyTorch 2.0, makes a real difference: it outperforms eager mode by a factor of 2 to 3.
The CUDA kernel generated by KeOps is faster and more scalable than the PyTorch GPU implementation.
All GPU implementations have a constant overhead (< 1 ms) which makes them less attractive when working with a single, small point cloud. This overhead is especially large for the convenient LazyTensor syntax. As detailed below, this issue can be mitigated through the use of a batch dimension.
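For instance (a minimal sketch, assuming all clouds share the same size N), a hundred small convolutions can be computed in a single batched call:

if use_cuda:
    # Stack 100 independent problems of size N = 1,000 along the batch dimension:
    x, y, b = generate_samples(1000, batchsize=100)  # (100, 1000, 3), (100, 1000, 3), (100, 1000, 1)
    # A single KeOps call then processes all 100 convolutions at once:
    out = gaussianconv_lazytensor(x, y, b)  # (100, 1000, 1)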
NumPy vs. PyTorch vs. KeOps (CPU)
routines = [
(gaussianconv_numpy, "Numpy (CPU)", {"device": "cpu", "lang": "numpy"}),
(
gaussianconv_pytorch_eager,
"PyTorch (CPU, matmul)",
{"device": "cpu", "cdist": False},
),
(
gaussianconv_pytorch_eager,
"PyTorch (CPU, cdist)",
{"device": "cpu", "cdist": True},
),
(
gaussianconv_lazytensor,
"KeOps (CPU, LazyTensor)",
{"device": "cpu", "backend": "CPU"},
),
(gaussianconv_keops, "KeOps (CPU, Genred)", {"device": "cpu", "backend": "CPU"}),
]
full_benchmark(
"Gaussian Matrix-Vector products (CPU)",
routines,
generate_samples,
problem_sizes=problem_sizes,
max_time=MAX_TIME,
)
Benchmarking : Gaussian Matrix-Vector products (CPU) ===============================
Numpy (CPU) -------------
1x100 loops of size 100 : 1x100x 41.3 µs
1x100 loops of size 200 : 1x100x 167.7 µs
1x100 loops of size 500 : 1x100x 806.2 µs
1x100 loops of size 1 k: 1x100x 3.0 ms
1x 10 loops of size 2 k: 1x 10x 14.5 ms
1x 1 loops of size 5 k: 1x 1x 132.7 ms
** Too slow!
PyTorch (CPU, matmul) -------------
1x100 loops of size 100 : 1x100x 95.4 µs
1x100 loops of size 200 : 1x100x 134.2 µs
1x100 loops of size 500 : 1x100x 307.9 µs
1x100 loops of size 1 k: 1x100x 745.4 µs
1x100 loops of size 2 k: 1x100x 5.7 ms
1x 10 loops of size 5 k: 1x 10x 62.8 ms
1x 1 loops of size 10 k: 1x 1x 260.4 ms
** Too slow!
PyTorch (CPU, cdist) -------------
1x100 loops of size 100 : 1x100x 103.9 µs
1x100 loops of size 200 : 1x100x 142.0 µs
1x100 loops of size 500 : 1x100x 311.4 µs
1x100 loops of size 1 k: 1x100x 768.4 µs
1x100 loops of size 2 k: 1x100x 3.8 ms
1x 10 loops of size 5 k: 1x 10x 45.7 ms
1x 1 loops of size 10 k: 1x 1x 182.8 ms
** Too slow!
KeOps (CPU, LazyTensor) -------------
1x100 loops of size 100 : 1x100x 329.6 µs
1x100 loops of size 200 : 1x100x 368.9 µs
1x100 loops of size 500 : 1x100x 679.9 µs
1x100 loops of size 1 k: 1x100x 1.8 ms
1x 10 loops of size 2 k: 1x 10x 6.2 ms
1x 10 loops of size 5 k: 1x 10x 36.8 ms
1x 1 loops of size 10 k: 1x 1x 146.8 ms
** Too slow!
KeOps (CPU, Genred) -------------
1x100 loops of size 100 : 1x100x 149.7 µs
1x100 loops of size 200 : 1x100x 197.5 µs
1x100 loops of size 500 : 1x100x 538.8 µs
1x100 loops of size 1 k: 1x100x 1.8 ms
1x 10 loops of size 2 k: 1x 10x 6.7 ms
1x 10 loops of size 5 k: 1x 10x 41.1 ms
1x 1 loops of size 10 k: 1x 1x 163.4 ms
** Too slow!
We note that the KeOps CPU implementation is typically slower than the PyTorch CPU implementation. This is because over the 2017-22 period, we prioritized “peak GPU performance” for research code and provided a CPU backend mostly for testing and debugging. Going forward, as we work on making KeOps easier to integrate as a backend dependency in mature libraries, improving the KeOps CPU backend is a priority, both for compilation and runtime performance.
Genred vs. LazyTensor vs. batched LazyTensor
if use_cuda:
routines = [
(gaussianconv_keops, "KeOps (Genred)", {}),
(gaussianconv_lazytensor, "KeOps (LazyTensor)", {}),
(
gaussianconv_lazytensor,
"KeOps (LazyTensor, batchsize=10)",
{"batchsize": 10},
),
]
full_benchmark(
"Gaussian Matrix-Vector products (batch)",
routines,
generate_samples,
problem_sizes=problem_sizes,
max_time=MAX_TIME,
)
plt.show()
Benchmarking : Gaussian Matrix-Vector products (batch) ===============================
KeOps (Genred) -------------
1x100 loops of size 100 : 1x100x 152.4 µs
1x100 loops of size 200 : 1x100x 156.5 µs
1x100 loops of size 500 : 1x100x 170.0 µs
1x100 loops of size 1 k: 1x100x 193.8 µs
1x100 loops of size 2 k: 1x100x 240.5 µs
1x100 loops of size 5 k: 1x100x 382.0 µs
1x100 loops of size 10 k: 1x100x 617.0 µs
1x100 loops of size 20 k: 1x100x 950.4 µs
1x100 loops of size 50 k: 1x100x 3.9 ms
1x 10 loops of size 100 k: 1x 10x 11.8 ms
1x 1 loops of size 200 k: 1x 1x 44.4 ms
1x 1 loops of size 500 k: 1x 1x 267.6 ms
** Too slow!
KeOps (LazyTensor) -------------
1x100 loops of size 100 : 1x100x 344.4 µs
1x100 loops of size 200 : 1x100x 350.3 µs
1x100 loops of size 500 : 1x100x 370.5 µs
1x100 loops of size 1 k: 1x100x 399.0 µs
1x100 loops of size 2 k: 1x100x 459.2 µs
1x100 loops of size 5 k: 1x100x 633.1 µs
1x100 loops of size 10 k: 1x100x 926.8 µs
1x100 loops of size 20 k: 1x100x 1.5 ms
1x 10 loops of size 50 k: 1x 10x 4.3 ms
1x 10 loops of size 100 k: 1x 10x 12.7 ms
1x 1 loops of size 200 k: 1x 1x 44.7 ms
1x 1 loops of size 500 k: 1x 1x 262.0 ms
** Too slow!
KeOps (LazyTensor, batchsize=10) -------------
10x100 loops of size 100 : 10x100x 34.9 µs
10x100 loops of size 200 : 10x100x 35.4 µs
10x100 loops of size 500 : 10x100x 37.2 µs
10x100 loops of size 1 k: 10x100x 40.3 µs
10x100 loops of size 2 k: 10x100x 47.2 µs
10x100 loops of size 5 k: 10x100x 75.4 µs
10x100 loops of size 10 k: 10x100x 160.3 µs
10x100 loops of size 20 k: 10x100x 479.6 µs
10x100 loops of size 50 k: 10x100x 2.6 ms
10x 10 loops of size 100 k: 10x 10x 10.1 ms
10x 1 loops of size 200 k: 10x 1x 39.4 ms
10x 1 loops of size 500 k: 10x 1x 244.5 ms
** Too slow!
As expected, using a batch dimension reduces the relative overhead of the LazyTensor syntax.
Total running time of the script: (2 minutes 45.306 seconds)