The low-level interface of KeOps is the Genred module, which allows users to define and reduce generic operations. Depending on your framework, you may import Genred using either:
    from pykeops.numpy import Genred  # for NumPy users, or...
    from pykeops.torch import Genred  # for PyTorch users.
In both cases, Genred is a class with no methods: its instantiation simply returns a numerical function that can be called at will.
Genred(...) takes as input a set of strings that specify the desired computation. It returns a Python function or PyTorch layer, callable on NumPy arrays or Torch tensors. The syntax is:
    my_red = Genred(formula, aliases, reduction_op='Sum', axis=0, dtype='float32')
Call: The variable my_red now refers to a callable object wrapped around a set of custom CUDA routines. It may be used on any set of arrays (either NumPy arrays or Torch tensors) whose shapes match the variables declared in aliases:
    result = my_red(arg_1, arg_2, ..., arg_p, backend='auto', device_id=-1, ranges=None)
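To make the role of the axis argument concrete, here is a dense NumPy sketch (not KeOps code) of what a 'Sum' reduction computes. The helper dense_sum_reduction and the toy formula F(x_i, y_j) = |x_i - y_j|^2 are hypothetical names chosen for illustration. The sketch assumes the KeOps convention that axis=0 reduces over the index i (one output row per j) and axis=1 reduces over the index j (one output row per i).

```python
import numpy as np

def dense_sum_reduction(x, y, axis):
    """Dense reference for a 'Sum' reduction of F(x_i, y_j) = |x_i - y_j|^2.

    Hypothetical helper for illustration only: it materializes the full
    (M, N) array, which is what a KeOps reduction is designed to avoid.
    """
    # diff[i, j, :] = x[i, :] - y[j, :], shape (M, N, D)
    diff = x[:, None, :] - y[None, :, :]
    # f[i, j, 0] = F(x_i, y_j), kept as a dim-1 vector per (i, j) pair
    f = (diff ** 2).sum(axis=2, keepdims=True)
    # axis=0 sums over i (one output row per j);
    # axis=1 sums over j (one output row per i)
    return f.sum(axis=axis)

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 3))  # M = 5 points indexed by i
y = rng.standard_normal((7, 3))  # N = 7 points indexed by j

out_j = dense_sum_reduction(x, y, axis=0)  # shape (7, 1)
out_i = dense_sum_reduction(x, y, axis=1)  # shape (5, 1)
print(out_j.shape, out_i.shape)
```

Both outputs add up to the same grand total, since they sum the same M-by-N table along different directions.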
Using the generic syntax, computing a Gaussian-RBF kernel product

\[ a_i = \sum_j \exp\left(-\gamma \, \|x_i - y_j\|^2\right) b_j \]

can be done with:
    import torch
    from pykeops.torch import Genred

    # Notice that the parameter gamma is a dim-1 vector, *not* a scalar:
    gamma = torch.tensor([.5])
    # Generate the data as pytorch tensors. If you intend to compute gradients,
    # don't forget the `requires_grad` flag!
    x = torch.randn(1000,3)
    y = torch.randn(2000,3)
    b = torch.randn(2000,2)

    gaussian_conv = Genred('Exp(-G * SqDist(X,Y)) * B',  # F(g,x,y,b) = exp( -g*|x-y|^2 ) * b
                           ['G = Pm(1)',   # First arg is a parameter, of dim 1
                            'X = Vi(3)',   # Second arg is indexed by "i", of dim 3
                            'Y = Vj(3)',   # Third arg is indexed by "j", of dim 3
                            'B = Vj(2)'],  # Fourth arg is indexed by "j", of dim 2
                           reduction_op='Sum',
                           axis=1)         # Summation over "j"

    # N.B.: a.shape == [1000, 2]
    a = gaussian_conv(gamma, x, y, b)

    # By explicitly specifying the backend, you can try to optimize your pipeline:
    a = gaussian_conv(gamma, x, y, b, backend='GPU')
    a = gaussian_conv(gamma, x, y, b, backend='CPU')
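As a sanity check on what the formula string means, the same reduction can be written densely in NumPy. This is a reference sketch, not how KeOps evaluates the formula: gaussian_conv_dense is a hypothetical helper that builds the full (M, N) kernel matrix in memory, so it only suits small inputs. It assumes the conventions of the example above (gamma as a dim-1 vector, X indexed by i, Y and B indexed by j, summation over j).

```python
import numpy as np

def gaussian_conv_dense(gamma, x, y, b):
    """Dense reference: a_i = sum_j exp(-gamma * |x_i - y_j|^2) * b_j."""
    # sq_dist[i, j] = |x_i - y_j|^2, shape (M, N)
    sq_dist = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=2)
    # kernel[i, j] = exp(-gamma * |x_i - y_j|^2); gamma has shape (1,)
    kernel = np.exp(-gamma[0] * sq_dist)
    # Summing over j is a matrix product: (M, N) @ (N, 2) -> (M, 2)
    return kernel @ b

rng = np.random.default_rng(0)
gamma = np.array([0.5])
x = rng.standard_normal((100, 3))  # indexed by i
y = rng.standard_normal((200, 3))  # indexed by j
b = rng.standard_normal((200, 2))  # indexed by j

a = gaussian_conv_dense(gamma, x, y, b)
print(a.shape)  # (100, 2)
```

On small inputs, comparing such a dense result with the output of a Genred routine (for instance with np.allclose, after converting to matching array types and dtypes) is one way to validate a formula before scaling up.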
More examples can be found in the gallery.