Fancy reductions, solving linear systems
As discussed in the previous notebook, KeOps LazyTensors support a wide range of mathematical formulas. Let us now discuss the different operators that may be used to reduce our large M-by-N symbolic tensors into vanilla NumPy arrays or PyTorch tensors.
Note
In this tutorial, we stick to the PyTorch interface; but apart from a few lines on backpropagation, everything here can be seamlessly translated to vanilla NumPy+KeOps code.
LogSumExp, KMin and advanced reductions
First, let’s build some large LazyTensors S_ij and V_ij, which respectively handle scalar and vector formulas:
import torch
from pykeops.torch import LazyTensor
use_cuda = torch.cuda.is_available()
tensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
M, N = (100000, 200000) if use_cuda else (1000, 2000)
D = 3
x = torch.randn(M, D).type(tensor)
y = torch.randn(N, D).type(tensor)
x_i = LazyTensor(x[:, None, :]) # (M, 1, D) LazyTensor
y_j = LazyTensor(y[None, :, :]) # (1, N, D) LazyTensor
V_ij = (x_i - y_j) # (M, N, D) symbolic tensor of differences
S_ij = (V_ij ** 2).sum(-1) # (M, N, 1) = (M, N) symbolic matrix of squared distances
print(S_ij)
print(V_ij)
Out:
KeOps LazyTensor
formula: Sum(Square((Var(0,3,0) - Var(1,3,1))))
shape: (100000, 200000)
KeOps LazyTensor
formula: (Var(0,3,0) - Var(1,3,1))
shape: (100000, 200000, 3)
As we’ve seen earlier, the pykeops.torch.LazyTensor.sum() reduction can be used on both S_ij and V_ij to produce genuine PyTorch 2D tensors:
print("Sum reduction of S_ij wrt. the 'N' dimension:", S_ij.sum(dim=1).shape)
Out:
Compiling libKeOpstorch5bc1766cab in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch5bc1766cab:
formula: Sum_Reduction(Sum(Square((Var(0,3,0) - Var(1,3,1)))),0)
aliases: Var(0,3,0); Var(1,3,1);
dtype : float32
... Done.
Sum reduction of S_ij wrt. the 'N' dimension: torch.Size([100000, 1])
Note that LazyTensors support reductions over both the indexing M and N dimensions, which may be specified using the PyTorch-friendly dim or the standard NumPy axis optional arguments:
print("Sum reduction of V_ij wrt. the 'M' dimension:", V_ij.sum(axis=0).shape)
Out:
Compiling libKeOpstorch104ae9182c in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch104ae9182c:
formula: Sum_Reduction((Var(0,3,0) - Var(1,3,1)),1)
aliases: Var(0,3,0); Var(1,3,1);
dtype : float32
... Done.
Sum reduction of V_ij wrt. the 'M' dimension: torch.Size([200000, 3])
Just like PyTorch tensors, pykeops.torch.LazyTensor objects also support a stabilized log-sum-exp reduction, computed efficiently with a running maximum in the CUDA loop. For example, the following line computes \(\log(\sum_i e^{S_{ij}})\):
print("LogSumExp reduction of S_ij wrt. the 'M' dimension:", S_ij.logsumexp(dim=0).shape)
Out:
Compiling libKeOpstorch65f45d9d1e in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch65f45d9d1e:
formula: Max_SumShiftExp_Reduction(Sum(Square((Var(0,3,0) - Var(1,3,1)))),1)
aliases: Var(0,3,0); Var(1,3,1);
dtype : float32
... Done.
LogSumExp reduction of S_ij wrt. the 'M' dimension: torch.Size([200000, 1])
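On small inputs, we can check this stabilized reduction against the dense PyTorch implementation. The sanity check below is a minimal sketch of ours, not part of the original script:
# Sanity check on small data: the KeOps reduction should match torch.logsumexp.
x_s = torch.randn(5, D).type(tensor)
y_s = torch.randn(7, D).type(tensor)
S_dense = ((x_s[:, None, :] - y_s[None, :, :]) ** 2).sum(-1)  # (5, 7) dense matrix
S_lazy = ((LazyTensor(x_s[:, None, :]) - LazyTensor(y_s[None, :, :])) ** 2).sum(-1)
print(torch.allclose(S_lazy.logsumexp(dim=0).view(-1),
                     torch.logsumexp(S_dense, dim=0), atol=1e-4))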
This reduction supports a weight parameter, which can be scalar or vector-valued. For example, the following line computes \(\log(\sum_j e^{S_{ij}} V_{ij})\):
print("LogSumExp reduction of S_ij, with 'weight' V_ij, wrt. the 'N' dimension:",
S_ij.logsumexp(dim=1, weight=V_ij).shape)
Out:
Compiling libKeOpstorch98b6c9f403 in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch98b6c9f403:
formula: Max_SumShiftExpWeight_Reduction(Sum(Square((Var(0,3,0) - Var(1,3,1)))),0,(Var(0,3,0) - Var(1,3,1)))
aliases: Var(0,3,0); Var(1,3,1);
dtype : float32
... Done.
LogSumExp reduction of S_ij, with 'weight' V_ij, wrt. the 'N' dimension: torch.Size([100000, 3])
Going further, the pykeops.torch.LazyTensor.min(), pykeops.torch.LazyTensor.max(), pykeops.torch.LazyTensor.argmin() and pykeops.torch.LazyTensor.argmax() reductions work as expected, following the (sensible) NumPy convention:
print("Min reduction of S_ij wrt. the 'M' dimension:", S_ij.min(dim=0).shape)
print("ArgMin reduction of S_ij wrt. the 'N' dimension:", S_ij.argmin(dim=1).shape)
print("Max reduction of V_ij wrt. the 'M' dimension:", V_ij.max(dim=0).shape)
print("ArgMax reduction of V_ij wrt. the 'N' dimension:", V_ij.argmax(dim=1).shape)
Out:
Compiling libKeOpstorch3b4051c5ea in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch3b4051c5ea:
formula: Min_Reduction(Sum(Square((Var(0,3,0) - Var(1,3,1)))),1)
aliases: Var(0,3,0); Var(1,3,1);
dtype : float32
... Done.
Min reduction of S_ij wrt. the 'M' dimension: torch.Size([200000, 1])
ArgMin reduction of S_ij wrt. the 'N' dimension: torch.Size([100000, 1])
Compiling libKeOpstorch68e2861a52 in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch68e2861a52:
formula: Max_Reduction((Var(0,3,0) - Var(1,3,1)),1)
aliases: Var(0,3,0); Var(1,3,1);
dtype : float32
... Done.
Max reduction of V_ij wrt. the 'M' dimension: torch.Size([200000, 3])
Compiling libKeOpstorchcb4f208f8f in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorchcb4f208f8f:
formula: ArgMax_Reduction((Var(0,3,0) - Var(1,3,1)),0)
aliases: Var(0,3,0); Var(1,3,1);
dtype : float32
... Done.
ArgMax reduction of V_ij wrt. the 'N' dimension: torch.Size([100000, 3])
To compute both quantities in a single pass, feel free to use the pykeops.torch.LazyTensor.min_argmin() and pykeops.torch.LazyTensor.max_argmax() reductions:
m_i, s_i = S_ij.min_argmin(dim=0)
print("Min-ArgMin reduction on S_ij wrt. the 'M' dimension:", m_i.shape, s_i.shape)
m_i, s_i = V_ij.max_argmax(dim=1)
print("Max-ArgMax reduction on V_ij wrt. the 'N' dimension:", m_i.shape, s_i.shape)
Out:
Compiling libKeOpstorchedd0520797 in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorchedd0520797:
formula: Min_ArgMin_Reduction(Sum(Square((Var(0,3,0) - Var(1,3,1)))),1)
aliases: Var(0,3,0); Var(1,3,1);
dtype : float32
... Done.
Min-ArgMin reduction on S_ij wrt. the 'M' dimension: torch.Size([200000, 1]) torch.Size([200000, 1])
Compiling libKeOpstorchefd2ce2506 in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorchefd2ce2506:
formula: Max_ArgMax_Reduction((Var(0,3,0) - Var(1,3,1)),0)
aliases: Var(0,3,0); Var(1,3,1);
dtype : float32
... Done.
Max-ArgMax reduction on V_ij wrt. the 'N' dimension: torch.Size([100000, 3]) torch.Size([100000, 3])
More interestingly, KeOps also provides support for the pykeops.torch.LazyTensor.Kmin(), pykeops.torch.LazyTensor.argKmin() and pykeops.torch.LazyTensor.Kmin_argKmin() reductions, which may be used to implement an efficient K-nearest neighbors algorithm:
K = 5
print("KMin reduction of S_ij wrt. the 'M' dimension:", S_ij.Kmin(K=K, dim=0).shape)
print("ArgKMin reduction of S_ij wrt. the 'N' dimension:", S_ij.argKmin(K=K, dim=1).shape)
Out:
Compiling libKeOpstorch7244a8c89b in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch7244a8c89b:
formula: KMin_Reduction(Sum(Square((Var(0,3,0) - Var(1,3,1)))),5,1)
aliases: Var(0,3,0); Var(1,3,1);
dtype : float32
... Done.
KMin reduction of S_ij wrt. the 'M' dimension: torch.Size([200000, 5])
Compiling libKeOpstorchfbc6a86081 in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorchfbc6a86081:
formula: ArgKMin_Reduction(Sum(Square((Var(0,3,0) - Var(1,3,1)))),5,0)
aliases: Var(0,3,0); Var(1,3,1);
dtype : float32
... Done.
ArgKMin reduction of S_ij wrt. the 'N' dimension: torch.Size([100000, 5])
It even works on vector formulas, with the reduction applied independently to each coordinate!
K = 7
print("KMin reduction of V_ij wrt. the 'M' dimension:", V_ij.Kmin(K=K, dim=0).shape)
print("ArgKMin reduction of V_ij wrt. the 'N' dimension:", V_ij.argKmin(K=K, dim=1).shape)
Out:
Compiling libKeOpstorch8bbe3aa0fa in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch8bbe3aa0fa:
formula: KMin_Reduction((Var(0,3,0) - Var(1,3,1)),7,1)
aliases: Var(0,3,0); Var(1,3,1);
dtype : float32
... Done.
KMin reduction of V_ij wrt. the 'M' dimension: torch.Size([200000, 7, 3])
Compiling libKeOpstorch4dcb3b2df8 in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch4dcb3b2df8:
formula: ArgKMin_Reduction((Var(0,3,0) - Var(1,3,1)),7,0)
aliases: Var(0,3,0); Var(1,3,1);
dtype : float32
... Done.
ArgKMin reduction of V_ij wrt. the 'N' dimension: torch.Size([100000, 7, 3])
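As an illustration, here is how one could put argKmin to work in a brute-force K-nearest neighbors query; this short sketch is ours, not part of the original script:
# For each point x_i, retrieve its K nearest neighbors among the y_j's,
# using the symbolic matrix of squared distances S_ij defined above:
indices = S_ij.argKmin(K=5, dim=1)  # (M, 5) tensor of integer indices
neighbors = y[indices]              # (M, 5, D) tensor of neighbor coordinates
print("Neighbor indices:", indices.shape, "- neighbor coordinates:", neighbors.shape)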
Finally, the pykeops.torch.LazyTensor.sumsoftmaxweight() reduction may be used to compute weighted SoftMax combinations \(a_i = \sum_j \exp(s_{i,j})\, v_{i,j} \,/\, \sum_j \exp(s_{i,j})\), with scalar coefficients \(s_{i,j}\) and arbitrary vector weights \(v_{i,j}\):
a_i = S_ij.sumsoftmaxweight(V_ij, dim=1)
print("SumSoftMaxWeight reduction of S_ij, with weights V_ij, wrt. the 'N' dimension:",
a_i.shape)
Out:
Compiling libKeOpstorch5e8a75195f in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch5e8a75195f:
formula: Max_SumShiftExpWeight_Reduction(Sum(Square((Var(0,3,0) - Var(1,3,1)))),0,Concat(IntCst(1),(Var(0,3,0) - Var(1,3,1))))
aliases: Var(0,3,0); Var(1,3,1);
dtype : float32
... Done.
SumSoftMaxWeight reduction of S_ij, with weights V_ij, wrt. the 'N' dimension: torch.Size([100000, 3])
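On small inputs, we can verify this behaviour against a dense PyTorch computation; the check below is a minimal sketch of ours:
# Dense check on small data: a_i = sum_j softmax_j(S_ij) * V_ij.
x_s = torch.randn(5, D).type(tensor)
y_s = torch.randn(7, D).type(tensor)
V_dense = x_s[:, None, :] - y_s[None, :, :]  # (5, 7, D) tensor of differences
S_dense = (V_dense ** 2).sum(-1)             # (5, 7) matrix of squared distances
a_dense = (torch.softmax(S_dense, dim=1)[:, :, None] * V_dense).sum(dim=1)
V_lazy = LazyTensor(x_s[:, None, :]) - LazyTensor(y_s[None, :, :])
a_lazy = (V_lazy ** 2).sum(-1).sumsoftmaxweight(V_lazy, dim=1)
print(torch.allclose(a_lazy, a_dense, atol=1e-4))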
Solving linear systems
Inverting large M-by-M linear systems is a fundamental problem in applied mathematics. To help you solve problems of the form \((\alpha \operatorname{Id} + K_{xx})\, a = b\), KeOps pykeops.torch.LazyTensor objects support a simple LazyTensor.solve(b, alpha=1e-10) operation that can be used as follows:
x = torch.randn(M, D, requires_grad=True).type(tensor) # Random point cloud
x_i = LazyTensor(x[:, None, :]) # (M, 1, D) LazyTensor
x_j = LazyTensor(x[None, :, :]) # (1, M, D) LazyTensor
K_xx = (- ((x_i - x_j) ** 2).sum(-1)).exp() # Symbolic (M, M) Gaussian kernel matrix
alpha = .1 # "Ridge" regularization parameter
b_i = torch.randn(M, 4).type(tensor) # Target signal, supported by the x_i's
a_i = K_xx.solve(b_i, alpha=alpha) # Source signal, supported by the x_i's
print("a_i is now a {} of shape {}.".format(type(a_i), a_i.shape))
Out:
Compiling libKeOpstorch7e7ac5a1fe in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch7e7ac5a1fe:
formula: Sum_Reduction((Exp(Minus(Sum(Square((Var(1,3,0) - Var(2,3,1)))))) * Var(0,4,1)),0)
aliases: Var(0,4,1); Var(1,3,0); Var(2,3,1);
dtype : float32
... Done.
a_i is now a <class 'torch.Tensor'> of shape torch.Size([100000, 4]).
As expected, we can now check that \(\alpha\, a_i + (K_{xx}\, a)_i \simeq b_i\):
c_i = alpha * a_i + K_xx @ a_i # Reconstructed target signal
print("Mean squared reconstruction error: {:.2e}".format(((c_i - b_i) ** 2).mean()))
Out:
Compiling libKeOpstorch2349177526 in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch2349177526:
formula: Sum_Reduction((Exp(Minus(Sum(Square((Var(0,3,0) - Var(1,3,1)))))) * Var(2,4,1)),0)
aliases: Var(0,3,0); Var(1,3,1); Var(2,4,1);
dtype : float32
... Done.
Mean squared reconstruction error: 7.29e-06
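For reference, a small instance of the same system can be solved densely; the comparison below is a minimal sketch of ours (at large M, the dense kernel matrix would not fit in memory):
# Dense validation on a small problem:
M_s = 100
x_s = torch.randn(M_s, D).type(tensor)
K_dense = (-((x_s[:, None, :] - x_s[None, :, :]) ** 2).sum(-1)).exp()  # (M_s, M_s)
b_s = torch.randn(M_s, 4).type(tensor)
a_dense = torch.linalg.solve(alpha * torch.eye(M_s).type(tensor) + K_dense, b_s)
K_lazy = (-((LazyTensor(x_s[:, None, :]) - LazyTensor(x_s[None, :, :])) ** 2).sum(-1)).exp()
a_keops = K_lazy.solve(b_s, alpha=alpha)
print(torch.allclose(a_keops, a_dense, atol=1e-3))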
Please note that, just like (nearly) all the other LazyTensor methods, pykeops.torch.LazyTensor.solve() fully supports the torch.autograd module:
[g_i] = torch.autograd.grad((a_i ** 2).sum(), [x])
print("g_i is now a {} of shape {}.".format(type(g_i), g_i.shape))
Out:
Compiling libKeOpstorch5bcd4459b2 in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch5bcd4459b2:
formula: Grad_WithSavedForward(Sum_Reduction((Exp(Minus(Sum(Square((Var(1,3,0) - Var(2,3,1)))))) * Var(0,4,1)),0), Var(1,3,0), Var(3,4,0), Var(4,4,0))
aliases: Var(0,4,1); Var(1,3,0); Var(2,3,1); Var(3,4,0); Var(4,4,0);
dtype : float32
... Done.
Compiling libKeOpstorch741fb857a0 in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch741fb857a0:
formula: Grad_WithSavedForward(Sum_Reduction((Exp(Minus(Sum(Square((Var(1,3,0) - Var(2,3,1)))))) * Var(0,4,1)),0), Var(2,3,1), Var(3,4,0), Var(4,4,0))
aliases: Var(0,4,1); Var(1,3,0); Var(2,3,1); Var(3,4,0); Var(4,4,0);
dtype : float32
... Done.
g_i is now a <class 'torch.Tensor'> of shape torch.Size([100000, 3]).
Warning
As of today, the pykeops.torch.LazyTensor.solve() operator only implements a conjugate gradient solver, under the assumption that K_xx is a symmetric, positive-definite matrix. To solve generic systems, you could either interface KeOps with the routines of the SciPy package (as sketched below) or implement your own solver, mimicking our reference implementation.
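To give a flavour of the first option, here is a minimal sketch of ours (not part of the KeOps API) that wraps a symbolic matrix-vector product as a SciPy LinearOperator and feeds it to a generic GMRES solver; it assumes SciPy is installed and uses the NumPy side of KeOps:
import numpy as np
from scipy.sparse.linalg import LinearOperator, gmres
from pykeops.numpy import LazyTensor as NumpyLazyTensor

xn = np.random.randn(1000, 3)
xn_i = NumpyLazyTensor(xn[:, None, :])        # (1000, 1, 3)
xn_j = NumpyLazyTensor(xn[None, :, :])        # (1, 1000, 3)
K_nn = (-((xn_i - xn_j) ** 2).sum(-1)).exp()  # Symbolic (1000, 1000) kernel matrix

# Wrap the regularized matrix-vector product "v -> 0.1 * v + K_nn @ v":
A = LinearOperator((1000, 1000), dtype=np.float64,
                   matvec=lambda v: 0.1 * v + (K_nn @ v.reshape(-1, 1)).reshape(-1))
bn = np.random.randn(1000)
an, info = gmres(A, bn)  # GMRES makes no symmetry assumption on the operator
print("GMRES converged:", info == 0)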
Total running time of the script: (7 minutes 5.267 seconds)