A wrapper for NumPy and PyTorch arrays
KeOps brings semi-symbolic calculus to modern computing libraries: it alleviates the need for huge intermediate variables such as kernel or distance matrices in machine learning and computational geometry.
First steps
A simple interface to the KeOps inner routines is provided by the pykeops.numpy.LazyTensor and pykeops.torch.LazyTensor symbolic wrappers, to be used with NumPy arrays and PyTorch tensors respectively.
To illustrate its main features on a simple example, let’s generate two point clouds in the unit square:
import numpy as np
M, N = 1000, 2000  # Number of points in each cloud
x = np.random.rand(M, 2)  # (M, 2) array: M points in the unit square
y = np.random.rand(N, 2)  # (N, 2) array: N points in the unit square
With NumPy, an efficient way of computing the index of the nearest y-neighbor for all points x_i is to perform a numpy.argmin() reduction on the M-by-N matrix of squared distances, computed using tensorized, broadcasted operators:
x_i = x[:, None, :] # (M, 1, 2) numpy array
y_j = y[None, :, :] # (1, N, 2) numpy array
D_ij = ((x_i - y_j) ** 2).sum(-1) # (M, N) array of squared distances |x_i-y_j|^2
s_i = np.argmin(D_ij, axis=1) # (M,) array of integer indices
print(s_i[:10])
[1489 184 540 1013 1911 89 661 309 926 846]
That’s good! Going further, we can speed up these computations using the CUDA routines of the PyTorch library:
import torch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
x_i = torch.as_tensor(x[:, None, :], dtype=torch.float32, device=device)  # (M, 1, 2) torch tensor
y_j = torch.as_tensor(y[None, :, :], dtype=torch.float32, device=device)  # (1, N, 2) torch tensor
D_ij = ((x_i - y_j) ** 2).sum(-1)  # (M, N) tensor of squared distances |x_i-y_j|^2
s_i = D_ij.argmin(dim=1)  # (M,) tensor of integer indices
print(s_i[:10])
tensor([1489, 184, 540, 1013, 1911, 89, 661, 309, 926, 846],
device='cuda:0')
But can we scale to larger point clouds?
Unfortunately, tensorized codes will throw an exception as soon as the M-by-N matrix of squared distances stops fitting in the device memory:
M, N = (100000, 200000) if use_cuda else (1000, 2000)
x = np.random.rand(M, 2)
y = np.random.rand(N, 2)
x_i = torch.as_tensor(x[:, None, :], dtype=torch.float32, device=device)  # (M, 1, 2) torch tensor
y_j = torch.as_tensor(y[None, :, :], dtype=torch.float32, device=device)  # (1, N, 2) torch tensor
try:
    D_ij = ((x_i - y_j) ** 2).sum(-1)  # (M, N) tensor of squared distances |x_i-y_j|^2
except RuntimeError as err:
    print(err)
CUDA out of memory. Tried to allocate 149.01 GiB. GPU 0 has a total capacity of 23.68 GiB of which 23.35 GiB is free. Including non-PyTorch memory, this process has 314.00 MiB memory in use. Of the allocated memory 9.93 MiB is allocated by PyTorch, and 8.07 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
That’s unfortunate… And unexpected! After all, modern GPUs routinely handle the real-time rendering of scenes with millions of triangles moving around. So how do graphics programmers achieve such a level of performance?
The key to efficient numerical schemes is to remark that even though the M-by-N matrix of squared distances is huge, it is fully determined by the two small arrays x and y: there is no need to store it explicitly in memory. KeOps lets us exploit this structure by defining D_ij as a symbolic LazyTensor:
from pykeops.numpy import LazyTensor as LazyTensor_np
x_i = LazyTensor_np(x[:, None, :])  # (M, 1, 2) KeOps LazyTensor, wrapped around the numpy array x
y_j = LazyTensor_np(y[None, :, :])  # (1, N, 2) KeOps LazyTensor, wrapped around the numpy array y
D_ij = ((x_i - y_j) ** 2).sum(-1) # **Symbolic** (M, N) matrix of squared distances
print(D_ij)
KeOps LazyTensor
formula: Sum(Square((Var(0,2,0) - Var(1,2,1))))
shape: (100000, 200000)
With KeOps, implementing lazy numerical schemes really is that simple! Our LazyTensor variables are encoded as a list of data arrays plus an arbitrary symbolic formula, written with a custom mathematical syntax that is modified after each “pythonic” operation such as -, **2 or .exp().
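For instance (a quick sketch, with an arbitrary name K_exp), applying one more operation simply appends a node to the symbolic formula, without triggering any computation:
K_exp = (-D_ij).exp()  # Still a symbolic (M, N) LazyTensor: nothing has been computed yet
print(K_exp)  # The printed formula now applies an exponential on top of the squared distance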
We can then perform a pykeops.torch.LazyTensor.argmin() reduction with an efficient Map-Reduce scheme, implemented as a templated CUDA kernel around our custom formula. As evidenced by our benchmarks, the KeOps routines have a linear memory footprint and generally outperform tensorized GPU implementations by two orders of magnitude.
s_i = D_ij.argmin(dim=1).ravel() # genuine (M,) array of integer indices
print("s_i is now a {} of shape {}.".format(type(s_i), s_i.shape))
print(s_i[:10])
s_i is now a <class 'numpy.ndarray'> of shape (100000,).
[192686 18738 126043 98105 86023 155003 136028 111120 32777 166987]
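The same symbolic matrix supports other reductions just as well. As a quick sketch (K = 5 and the name ind_knn are arbitrary choices), here is how the argKmin reduction retrieves the indices of the K nearest y-neighbors of each point x_i:
K = 5  # Number of neighbors to retrieve (arbitrary choice for this sketch)
ind_knn = D_ij.argKmin(K, dim=1)  # (M, K) array of integer indices
print(ind_knn.shape)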
Going further, we can combine LazyTensors
using a wide range of mathematical operations.
For instance, with data arrays stored directly on the GPU,
an exponential kernel dot product
in dimension D=10 can be performed with:
from pykeops.torch import LazyTensor
D = 10
x = torch.randn(M, D, device=device)  # M target points in dimension D, stored on the GPU
y = torch.randn(N, D, device=device)  # N source points in dimension D, stored on the GPU
b = torch.randn(N, 4, device=device)  # N values of the 4D source signal, stored on the GPU
x.requires_grad = True # In the next section, we'll compute gradients wrt. x!
x_i = LazyTensor(x[:, None, :]) # (M, 1, D) LazyTensor
y_j = LazyTensor(y[None, :, :]) # (1, N, D) LazyTensor
D_ij = ((x_i - y_j) ** 2).sum(-1).sqrt() # Symbolic (M, N) matrix of distances
K_ij = (-D_ij).exp() # Symbolic (M, N) Laplacian (aka. exponential) kernel matrix
a_i = K_ij @ b # The matrix-vector product "@" can be used on "raw" PyTorch tensors!
print("a_i is now a {} of shape {}.".format(type(a_i), a_i.shape))
a_i is now a <class 'torch.Tensor'> of shape torch.Size([100000, 4]).
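As a quick sanity check (a sketch with arbitrary names m, n, a_dense and a_keops; it is easiest on small subsets, where the dense kernel matrix fits in memory), we can verify that the symbolic computation matches the equivalent dense PyTorch one:
m, n = 50, 60  # Small sizes, so that the dense (m, n) kernel matrix fits in memory
x_s, y_s, b_s = x[:m].detach(), y[:n], b[:n]
# Dense PyTorch computation:
D_dense = ((x_s[:, None, :] - y_s[None, :, :]) ** 2).sum(-1).sqrt()  # (m, n) matrix of distances
a_dense = (-D_dense).exp() @ b_s  # (m, 4) kernel dot product
# Same computation with KeOps LazyTensors:
x_si = LazyTensor(x_s[:, None, :])  # (m, 1, D)
y_sj = LazyTensor(y_s[None, :, :])  # (1, n, D)
a_keops = (-((x_si - y_sj) ** 2).sum(-1).sqrt()).exp() @ b_s  # (m, 4)
print(torch.allclose(a_dense, a_keops, atol=1e-4))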
Note
KeOps LazyTensors have two symbolic or “virtual” axes at positions -3 and -2. Operations on the last “vector” dimension (-1) or on optional “batch” dimensions (-4 and beyond) are evaluated lazily. On the other hand, a reduction on one of the two symbolic axes (-2 or -3) triggers an explicit computation: we return a standard dense array with no symbolic axes.
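As a minimal illustration of this rule (a sketch; the names lazy_ij and dense_i are arbitrary), using the x_i and y_j LazyTensors defined above:
lazy_ij = ((x_i - y_j) ** 2).sum(-1)  # Reduction over the vector axis -1: still a symbolic LazyTensor
dense_i = lazy_ij.sum(dim=1)  # Reduction over the symbolic "j" axis: a dense (M, 1) torch tensor
print(type(lazy_ij), type(dense_i))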
Automatic differentiation
KeOps fully supports the torch.autograd engine: we can backprop through KeOps reductions as easily as through vanilla PyTorch operations.
For instance, coming back to the kernel dot product above, we can compute the gradient of (a_i ** 2).sum() with respect to x with:
[g_i] = torch.autograd.grad((a_i**2).sum(), [x], create_graph=True)
print("g_i is now a {} of shape {}.".format(type(g_i), g_i.shape))
g_i is now a <class 'torch.Tensor'> of shape torch.Size([100000, 10]).
As usual with PyTorch, having set the create_graph=True
option
allows us to compute higher-order derivatives as needed:
[h_i] = torch.autograd.grad(g_i.exp().sum(), [x], create_graph=True)
print("h_i is now a {} of shape {}.".format(type(h_i), h_i.shape))
h_i is now a <class 'torch.Tensor'> of shape torch.Size([100000, 10]).
Warning
As of today, backpropagation is not supported through the pykeops.torch.LazyTensor.min(), pykeops.torch.LazyTensor.max() or pykeops.torch.LazyTensor.Kmin() reductions: we’re working on it, but are not there just yet.
Until then, a simple workaround is to use the indices computed by the pykeops.torch.LazyTensor.argmin(), pykeops.torch.LazyTensor.argmax() or pykeops.torch.LazyTensor.argKmin() reductions to define a fully differentiable PyTorch tensor, as we now explain.
Coming back to our nearest-neighbor example, this time with Gaussian point clouds in dimension 3:
x = torch.randn(M, 3, device=device)
y = torch.randn(N, 3, device=device)
x.requires_grad = True
x_i = LazyTensor(x[:, None, :]) # (M, 1, 3) LazyTensor
y_j = LazyTensor(y[None, :, :]) # (1, N, 3) LazyTensor
D_ij = ((x_i - y_j) ** 2).sum(-1) # Symbolic (M, N) matrix of squared distances
We could compute the (M,)
vector of squared distances to the nearest y-neighbor with:
to_nn = D_ij.min(dim=1).view(-1)
But instead, using:
s_i = D_ij.argmin(dim=1).view(-1) # (M,) integer Torch tensor
to_nn_alt = ((x - y[s_i, :]) ** 2).sum(-1)
outputs the same result, while also allowing us to compute arbitrary gradients:
print(
"Difference between the two vectors: {:.2e}".format((to_nn - to_nn_alt).abs().max())
)
[g_i] = torch.autograd.grad(to_nn_alt.sum(), [x])
print("g_i is now a {} of shape {}.".format(type(g_i), g_i.shape))
Difference between the two vectors: 1.19e-07
g_i is now a <class 'torch.Tensor'> of shape torch.Size([100000, 3]).
The only real downside here is that we had to write out the “squared distance” formula twice to specify our computation. We hope to fix this (minor) inconvenience sooner rather than later!
Batch processing
As should be expected, LazyTensors also provide full support for batch processing, with broadcasting over dummy (=1) batch dimensions:
A, B = 7, 3 # Batch dimensions
x_i = LazyTensor(torch.randn(A, B, M, 1, D, device=device))
l_i = LazyTensor(torch.randn(1, 1, M, 1, D, device=device))
y_j = LazyTensor(torch.randn(1, B, 1, N, D, device=device))
s = LazyTensor(torch.rand(A, 1, 1, 1, 1, device=device))
D_ij = ((l_i * x_i - y_j) ** 2).sum(-1) # Symbolic (A, B, M, N, 1) LazyTensor
K_ij = -1.6 * D_ij / (1 + s**2) # Some arbitrary (A, B, M, N, 1) Kernel matrix
a_i = K_ij.sum(dim=3)
print("a_i is now a {} of shape {}.".format(type(a_i), a_i.shape))
a_i is now a <class 'torch.Tensor'> of shape torch.Size([7, 3, 100000, 1]).
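The reduction axis is up to us: reducing over the other symbolic axis (position 2, i.e. the M “i” points) is just as easy. A quick sketch, with an arbitrary name b_j:
b_j = K_ij.sum(dim=2)  # Reduction over the M "i" points: an (A, B, N, 1) torch tensor
print("b_j is now a {} of shape {}.".format(type(b_j), b_j.shape))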
Everything works just fine, with two major caveats:
- The structure of KeOps computations is still a little bit rigid: LazyTensors should only be used in situations where the large dimensions M and N over which the main reduction is performed are in positions -3 and -2 (respectively), with vector variables in position -1 and an arbitrary number of batch dimensions beforehand. We’re working towards full support for tensor variables, but this will probably take some time to implement and test properly.
- KeOps LazyTensors never collapse their last “dimension”, even after a .sum(-1) reduction whose keepdim argument is implicitly set to True.
print("Convenient, numpy-friendly shape: ", K_ij.shape)
print("Actual shape, used internally by KeOps: ", K_ij._shape)
Convenient, numpy-friendly shape: (7, 3, 100000, 200000)
Actual shape, used internally by KeOps: (7, 3, 100000, 200000, 1)
This is the reason why, in the example above, a_i is a 4D Tensor of shape (7, 3, 100000, 1) and not a 3D Tensor of shape (7, 3, 100000).
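If a squeezed result is more convenient, the trailing axis of the dense output can simply be dropped with a standard PyTorch call (a usage sketch, not a KeOps-specific feature):
print(a_i.squeeze(-1).shape)  # torch.Size([7, 3, 100000])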
Supported formulas
The full range of mathematical operations supported by
LazyTensors
is described
in our API documentation.
Let’s just mention that the lines below define valid computations:
x_i = LazyTensor(torch.randn(A, B, M, 1, D, device=device))
l_i = LazyTensor(torch.randn(1, 1, M, 1, D, device=device))
y_j = LazyTensor(torch.randn(1, B, 1, N, D, device=device))
s = LazyTensor(torch.rand(A, 1, 1, 1, 1, device=device))
F_ij = (
(x_i**1.5 + y_j / l_i).cos() - (x_i | y_j) + (x_i[:, :, :, :, 2] * s.relu() * y_j)
)
print(F_ij)
a_j = F_ij.sum(dim=2)
print("a_j is now a {} of shape {}.".format(type(a_j), a_j.shape))
KeOps LazyTensor
formula: ((Cos((Powf(Var(0,10,0), Var(1,1,2)) + (Var(2,10,1) / Var(3,10,0)))) - (Var(0,10,0) | Var(2,10,1))) + ((Elem(Var(0,10,0),2) * ReLU(Var(4,1,2))) * Var(2,10,1)))
shape: (7, 3, 100000, 200000, 10)
a_j is now a <class 'torch.Tensor'> of shape torch.Size([7, 3, 200000, 10]).
Enjoy! And feel free to check the next tutorial for a discussion of the varied reduction operations that can be applied to KeOps LazyTensors.