# K-Nearest Neighbours search (WIP)¶

Let’s compare the performances of PyTorch, JAX, FAISS and KeOps for K-NN queries on random samples and standard datasets.

Note

In this demo, we use exact bruteforce computations (tensorized for PyTorch and online for KeOps), without leveraging any multiscale or low-rank (Nystroem/multipole) decomposition of the Kernel matrix. First support for these approximation schemes is scheduled for May-June 2021.

## Setup¶

import numpy as np
import torch
from matplotlib import pyplot as plt
from functools import partial

from benchmark_utils import (
flatten,
random_normal,
full_benchmark,
timer,
tensor,
int_tensor,
jax_tensor,
)
from dataset_utils import generate_samples

# Run on the GPU whenever CUDA is available:
use_cuda = torch.cuda.is_available()


Benchmark specifications:

# Values of K that we'll loop upon:
# (these are the "problem sizes" probed by full_benchmark below)
Ks = [1, 2, 5, 10, 20, 50, 100]


## Simple bruteforce implementations¶

Define a simple Gaussian RBF product, using a tensorized implementation. Note that expanding the squared norm $$\|x-y\|^2$$ as a sum $$\|x\|^2 - 2 \langle x, y \rangle + \|y\|^2$$ allows us to leverage the fast matrix-matrix product of the BLAS/cuBLAS libraries.

PyTorch bruteforce:

"""
def KNN_KeOps(K, metric="euclidean", **kwargs):
def fit(x_train):
# Setup the K-NN estimator:
x_train = tensor(x_train)
start = timer()

# N.B.: The "training" time here should be negligible.
elapsed = timer() - start

def f(x_test):
x_test = tensor(x_test)
start = timer()

# Actual K-NN query:

elapsed = timer() - start

indices = indices.cpu().numpy()
return indices, elapsed

return f, elapsed

return fit
"""

def KNN_torch(K, metric="euclidean", **kwargs):
    """Returns a PyTorch bruteforce K-NN routine for the benchmark.

    fit(x_train) returns (f, train_time), where f(x_test) returns the
    (Ntest, K) NumPy array of neighbor indices and the query time.
    The whole (Ntest, Ntrain) distance matrix is materialized at once,
    which may overflow the device memory on large datasets.
    """

    def fit(x_train):
        # Setup the K-NN estimator:
        x_train = tensor(x_train)
        start = timer()
        # The "training" time here should be negligible:
        # we only pre-compute the squared norms of the database points,
        # which lets us expand |x-y|^2 with a fast matrix-matrix product.
        x_train_norm = (x_train ** 2).sum(-1)
        elapsed = timer() - start

        def f(x_test):
            x_test = tensor(x_test)
            start = timer()

            # Actual K-NN query:
            if metric == "euclidean":
                x_test_norm = (x_test ** 2).sum(-1)
                diss = (
                    x_test_norm.view(-1, 1)
                    + x_train_norm.view(1, -1)
                    - 2 * x_test @ x_train.t()
                )

            elif metric == "manhattan":
                diss = (x_test[:, None, :] - x_train[None, :, :]).abs().sum(dim=2)

            elif metric == "angular":
                diss = -x_test @ x_train.t()

            elif metric == "hyperbolic":
                # Squared Euclidean distance, rescaled by the product of the
                # first ("height") coordinates of the query and database points:
                x_test_norm = (x_test ** 2).sum(-1)
                diss = (
                    x_test_norm.view(-1, 1)
                    + x_train_norm.view(1, -1)
                    - 2 * x_test @ x_train.t()
                )
                diss /= x_test[:, 0].view(-1, 1) * x_train[:, 0].view(1, -1)

            else:
                # BUGFIX: an unknown metric used to fall through and crash
                # with a NameError on `diss`; fail with an explicit message.
                raise NotImplementedError(f"The '{metric}' distance is not supported.")

            out = diss.topk(K, dim=1, largest=False)

            elapsed = timer() - start
            indices = out.indices.cpu().numpy()
            return indices, elapsed

        return f, elapsed

    return fit


PyTorch bruteforce, with small batches to avoid memory overflows:

def KNN_torch_batch_loop(K, metric="euclidean", **kwargs):
    """Returns a PyTorch bruteforce K-NN routine that splits the queries
    into small batches to avoid memory overflows on large datasets.

    fit(x_train) returns (f, train_time), where f(x_test) returns the
    (Ntest, K) NumPy array of neighbor indices and the query time.
    """

    def fit(x_train):
        # Setup the K-NN estimator:
        x_train = tensor(x_train)
        Ntrain, D = x_train.shape
        start = timer()
        # The "training" time here should be negligible:
        x_train_norm = (x_train ** 2).sum(-1)
        elapsed = timer() - start

        def f(x_test):
            x_test = tensor(x_test)

            # Estimate the largest reasonable batch size, assuming ~5e8 bytes
            # of free device memory. A finer estimate could rely on
            # torch.cuda.get_device_properties(deviceId).total_memory.
            # BUGFIX: `.shape` is a tuple - we need the number of query rows:
            Ntest = x_test.shape[0]
            av_mem = int(5e8)
            Ntest_loop = min(max(1, av_mem // (4 * D * Ntrain)), Ntest)
            Nloop = (Ntest - 1) // Ntest_loop + 1
            # print(f"{Ntest} queries, split in {Nloop} batches of {Ntest_loop} queries each.")
            out = int_tensor(Ntest, K)

            start = timer()
            # Actual K-NN query, batch by batch:
            for k in range(Nloop):
                x_test_k = x_test[Ntest_loop * k : Ntest_loop * (k + 1), :]
                if metric == "euclidean":
                    x_test_norm = (x_test_k ** 2).sum(-1)
                    diss = (
                        x_test_norm.view(-1, 1)
                        + x_train_norm.view(1, -1)
                        - 2 * x_test_k @ x_train.t()
                    )

                elif metric == "manhattan":
                    diss = (x_test_k[:, None, :] - x_train[None, :, :]).abs().sum(dim=2)

                elif metric == "angular":
                    diss = -x_test_k @ x_train.t()

                elif metric == "hyperbolic":
                    x_test_norm = (x_test_k ** 2).sum(-1)
                    diss = (
                        x_test_norm.view(-1, 1)
                        + x_train_norm.view(1, -1)
                        - 2 * x_test_k @ x_train.t()
                    )
                    diss /= x_test_k[:, 0].view(-1, 1) * x_train[:, 0].view(1, -1)

                else:
                    raise NotImplementedError(f"The '{metric}' distance is not supported.")

                out[Ntest_loop * k : Ntest_loop * (k + 1), :] = diss.topk(
                    K, dim=1, largest=False
                ).indices
                # Free the batch's distance matrix right away to keep the
                # peak memory footprint low:
                del diss
                # torch.cuda.empty_cache()

            elapsed = timer() - start
            indices = out.cpu().numpy()
            return indices, elapsed

        return f, elapsed

    return fit


Distance matrices with JAX:

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(2, 3))
def knn_jax_fun(x_train, x_test, K, metric):
    """JIT-compiled bruteforce K-NN query.

    Returns the (Ntest, K) array of the indices of the K nearest neighbors
    of each row of x_test among the rows of x_train. K and metric are
    static arguments: one kernel is compiled per (K, metric) pair.
    """
    if metric == "euclidean":
        # Expand |x-y|^2 = |x|^2 - 2<x,y> + |y|^2 to leverage matrix products:
        diss = (
            (x_test ** 2).sum(-1)[:, None]
            + (x_train ** 2).sum(-1)[None, :]
            - 2 * x_test @ x_train.T
        )
    elif metric == "manhattan":
        diss = jax.lax.abs(x_test[:, None, :] - x_train[None, :, :]).sum(-1)
    elif metric == "angular":
        diss = -x_test @ x_train.T
    elif metric == "hyperbolic":
        diss = (
            (x_test ** 2).sum(-1)[:, None]
            + (x_train ** 2).sum(-1)[None, :]
            - 2 * x_test @ x_train.T
        )
        diss = diss / (x_test[:, 0][:, None] * x_train[:, 0][None, :])
    else:
        raise NotImplementedError(f"The '{metric}' distance is not supported.")

    # BUGFIX: jax.lax.top_k returns a (values, indices) pair - keep only
    # the indices of the K smallest distances (= K largest of -diss):
    indices = jax.lax.top_k(-diss, K)[1]
    return indices


JAX bruteforce:

def KNN_JAX(K, metric="euclidean", **kwargs):
    """Returns a JAX bruteforce K-NN routine (no batching) for the benchmark.

    fit(x_train) returns (query, train_time); query(x_test) returns the
    neighbor indices as a NumPy array together with the query time.
    """

    def fit(x_train):

        # "Training" is just moving the database onto the device:
        t0 = timer(use_torch=False)
        x_train = jax_tensor(x_train)
        elapsed = timer(use_torch=False) - t0

        def query(x_test):
            x_test = jax_tensor(x_test)

            # Actual K-NN query, delegated to the JIT-compiled kernel:
            t0 = timer(use_torch=False)
            indices = knn_jax_fun(x_train, x_test, K, metric)
            indices = np.array(indices)
            return indices, timer(use_torch=False) - t0

        return query, elapsed

    return fit


JAX bruteforce, with small batches to avoid memory overflows:

def KNN_JAX_batch_loop(K, metric="euclidean", **kwargs):
    """Returns a JAX bruteforce K-NN routine that splits the queries
    into small batches to avoid memory overflows on large datasets.

    fit(x_train) returns (f, train_time), where f(x_test) returns the
    (Ntest, K) NumPy array of neighbor indices and the query time.
    """

    def fit(x_train):

        # Setup the K-NN estimator (just a host-to-device copy):
        start = timer(use_torch=False)
        x_train = jax_tensor(x_train)
        elapsed = timer(use_torch=False) - start

        def f(x_test):
            x_test = jax_tensor(x_test)

            # Estimate the largest reasonable batch size, assuming ~5e8 bytes
            # of free device memory (compare with
            # torch.cuda.get_device_properties(deviceId).total_memory):
            av_mem = int(5e8)
            Ntrain, D = x_train.shape
            # BUGFIX: `.shape` is a tuple - we need the number of query rows:
            Ntest = x_test.shape[0]
            Ntest_loop = min(max(1, av_mem // (4 * D * Ntrain)), Ntest)
            Nloop = (Ntest - 1) // Ntest_loop + 1
            # print(f"{Ntest} queries, split in {Nloop} batches of {Ntest_loop} queries each.")
            indices = np.zeros((Ntest, K), dtype=int)

            start = timer(use_torch=False)
            # Actual K-NN query, batch by batch:
            for k in range(Nloop):
                x_test_k = x_test[Ntest_loop * k : Ntest_loop * (k + 1), :]
                indices[Ntest_loop * k : Ntest_loop * (k + 1), :] = knn_jax_fun(
                    x_train, x_test_k, K, metric
                )
            elapsed = timer(use_torch=False) - start
            return indices, elapsed

        return f, elapsed

    return fit


KeOps bruteforce implementation:

from pykeops.torch import LazyTensor, Vi, Vj

def KNN_KeOps(K, metric="euclidean", **kwargs):
    """Returns a KeOps bruteforce K-NN routine for the benchmark.

    The distance matrix is encoded as a symbolic LazyTensor and reduced
    with an argKmin() call, so it is never materialized in device memory.
    fit(x_train) returns (f, train_time); f(x_test) returns the (Ntest, K)
    NumPy array of neighbor indices and the query time.
    """

    def fit(x_train):
        # Setup the K-NN estimator:
        x_train = tensor(x_train)
        start = timer()

        # Encoding as KeOps LazyTensors:
        # BUGFIX: Vi/Vj expect an integer dimension, not the full shape tuple:
        D = x_train.shape[1]
        X_i = Vi(0, D)  # "i"-indexed (query) variable of dimension D
        X_j = Vj(1, D)  # "j"-indexed (database) variable of dimension D

        # Symbolic distance matrix:
        if metric == "euclidean":
            D_ij = ((X_i - X_j) ** 2).sum(-1)
        elif metric == "manhattan":
            D_ij = (X_i - X_j).abs().sum(-1)
        elif metric == "angular":
            D_ij = -(X_i | X_j)
        elif metric == "hyperbolic":
            # BUGFIX: rescale by the product of the *first coordinates* of the
            # two points, as in the PyTorch implementations - dividing by the
            # D-dimensional product X_i * X_j is a dimension mismatch:
            D_ij = ((X_i - X_j) ** 2).sum(-1) / (X_i[0] * X_j[0])
        else:
            raise NotImplementedError(f"The '{metric}' distance is not supported.")

        # K-NN query operator:
        KNN_fun = D_ij.argKmin(K, dim=1)

        # N.B.: The "training" time here should be negligible.
        elapsed = timer() - start

        def f(x_test):
            x_test = tensor(x_test)
            start = timer()

            # Actual K-NN query:
            indices = KNN_fun(x_test, x_train)

            elapsed = timer() - start

            indices = indices.cpu().numpy()
            return indices, elapsed

        return f, elapsed

    return fit


## SciKit-Learn tree-based and bruteforce methods¶

from sklearn.neighbors import NearestNeighbors

def KNN_sklearn(K, metric="euclidean", algorithm=None, **kwargs):
    """Wraps scikit-learn's NearestNeighbors as a (fit -> query) routine.

    fit(x_train) returns (f, train_time); f(x_test) returns the neighbor
    indices and the query time. The "angular" metric is approximated by
    the Euclidean one (p=2), as in the other CPU baselines.
    """

    # Map our metric names onto Minkowski exponents:
    if metric == "manhattan":
        p = 1
    elif metric in ("euclidean", "angular"):
        p = 2
    else:
        raise NotImplementedError("This distance is not supported.")

    KNN_meth = NearestNeighbors(n_neighbors=K, algorithm=algorithm, p=p, n_jobs=-1)

    def fit(x_train):
        # Setup the K-NN estimator (tree construction happens here):
        t0 = timer()
        KNN_fun = KNN_meth.fit(x_train).kneighbors
        elapsed = timer() - t0

        def f(x_test):
            t0 = timer()
            _, indices = KNN_fun(x_test)
            return indices, timer() - t0

        return f, elapsed

    return fit

# Pre-configured variants of the SciKit-Learn wrapper, one per "algorithm":
KNN_sklearn_auto = partial(KNN_sklearn, algorithm="auto")
KNN_sklearn_ball_tree = partial(KNN_sklearn, algorithm="ball_tree")
KNN_sklearn_kd_tree = partial(KNN_sklearn, algorithm="kd_tree")
KNN_sklearn_brute = partial(KNN_sklearn, algorithm="brute")


## NumPy vs. PyTorch vs. KeOps (Gpu)¶

def run_KNN_benchmark(name):
    """Benchmark every K-NN routine on the dataset `name`, for each K in Ks."""

    # Load the dataset and some info:
    dataset = generate_samples(name)(1)
    N_train, dimension = dataset["train"].shape
    N_test, _ = dataset["test"].shape
    metric = dataset["metric"]

    # (routine, display label) pairs to benchmark:
    methods = [
        (KNN_sklearn_auto, "sklearn, auto (CPU)"),
        (KNN_sklearn_ball_tree, "sklearn, Ball-tree (CPU)"),
        (KNN_sklearn_kd_tree, "sklearn, KD-tree (CPU)"),
        (KNN_sklearn_brute, "sklearn, bruteforce (CPU)"),
        (KNN_torch, "PyTorch (GPU)"),
        (KNN_torch_batch_loop, "PyTorch (small batches, GPU)"),
        (KNN_KeOps, "KeOps (GPU)"),
        (KNN_JAX, "JAX (GPU)"),
        (KNN_JAX_batch_loop, "JAX (small batches, GPU)"),
    ]
    # full_benchmark expects (routine, label, kwargs) triples:
    routines = [(method, label, {}) for method, label in methods]

    # Actual run:
    full_benchmark(
        f"K-NN search on {name}: {N_test:,} queries on a dataset of {N_train:,} points\nin dimension {dimension:,} with a {metric} metric.",
        routines,
        generate_samples(name),
        min_time=1e-4,
        max_time=10,
        problem_sizes=Ks,
        xlabel="Number of neighbours K",
    )


## On random samples:¶

Small dataset in $$\mathbb{R}^3$$:

run_KNN_benchmark("R^D a")

Out:

Benchmarking : K-NN search on R^D a: 10,000 queries on a dataset of 10,000 points
in dimension 3 with a euclidean metric. ===============================
sklearn, auto (CPU) -------------
1x100 loops of size         1:
train =   1x100x0.006913s
test  =   1x100x0.132252s
1x 10 loops of size         2:
train =   1x 10x0.006147s
test  =   1x 10x0.143104s
1x  1 loops of size         5:
train =   1x  1x0.009165s
test  =   1x  1x0.129589s
1x  1 loops of size        10:
train =   1x  1x0.009278s
test  =   1x  1x0.242093s
1x  1 loops of size        20:
train =   1x  1x0.008777s
test  =   1x  1x0.241809s
1x  1 loops of size        50:
train =   1x  1x0.005599s
test  =   1x  1x0.352888s
1x  1 loops of size       100:
train =   1x  1x0.006884s
test  =   1x  1x0.555457s
sklearn, Ball-tree (CPU) -------------
1x100 loops of size         1:
train =   1x100x0.005426s
test  =   1x100x1.174859s
1x 10 loops of size         2:
train =   1x 10x0.005629s
test  =   1x 10x1.514498s
1x  1 loops of size         5:
train =   1x  1x0.012689s
test  =   1x  1x1.361847s
1x  1 loops of size        10:
train =   1x  1x0.006375s
test  =   1x  1x1.532763s
1x  1 loops of size        20:
train =   1x  1x0.005746s
test  =   1x  1x1.654460s
1x  1 loops of size        50:
train =   1x  1x0.012803s
test  =   1x  1x2.155772s
1x  1 loops of size       100:
train =   1x  1x0.007468s
test  =   1x  1x1.945825s
sklearn, KD-tree (CPU) -------------
1x100 loops of size         1:
train =   1x100x0.006627s
test  =   1x100x0.134348s
1x 10 loops of size         2:
train =   1x 10x0.006702s
test  =   1x 10x0.133760s
1x  1 loops of size         5:
train =   1x  1x0.006728s
test  =   1x  1x0.131101s
1x  1 loops of size        10:
train =   1x  1x0.009019s
test  =   1x  1x0.231460s
1x  1 loops of size        20:
train =   1x  1x0.012306s
test  =   1x  1x0.361526s
1x  1 loops of size        50:
train =   1x  1x0.011660s
test  =   1x  1x0.329645s
1x  1 loops of size       100:
train =   1x  1x0.010894s
test  =   1x  1x0.551188s
sklearn, bruteforce (CPU) -------------
**
Too slow !
PyTorch (GPU) -------------
1x100 loops of size         1:
train =   1x100x0.000079s
test  =   1x100x0.008414s
1x 10 loops of size         2:
train =   1x 10x0.000068s
test  =   1x 10x0.008607s
1x 10 loops of size         5:
train =   1x 10x0.000133s
test  =   1x 10x0.009002s
1x 10 loops of size        10:
train =   1x 10x0.000065s
test  =   1x 10x0.009213s
1x 10 loops of size        20:
train =   1x 10x0.000111s
test  =   1x 10x0.009527s
1x 10 loops of size        50:
train =   1x 10x0.000075s
test  =   1x 10x0.010202s
1x 10 loops of size       100:
train =   1x 10x0.000141s
test  =   1x 10x0.010330s
PyTorch (small batches, GPU) -------------
1x100 loops of size         1:
train =   1x100x0.000057s
test  =   1x100x0.008201s
1x 10 loops of size         2:
train =   1x 10x0.000055s
test  =   1x 10x0.008747s
1x 10 loops of size         5:
train =   1x 10x0.000054s
test  =   1x 10x0.009433s
1x 10 loops of size        10:
train =   1x 10x0.000056s
test  =   1x 10x0.009302s
1x 10 loops of size        20:
train =   1x 10x0.000065s
test  =   1x 10x0.009586s
1x 10 loops of size        50:
train =   1x 10x0.000056s
test  =   1x 10x0.010070s
1x 10 loops of size       100:
train =   1x 10x0.000065s
test  =   1x 10x0.010461s
KeOps (GPU) -------------
1x100 loops of size         1:
train =   1x100x0.000047s
test  =   1x100x0.000534s
1x100 loops of size         2:
train =   1x100x0.000046s
test  =   1x100x0.000842s
1x100 loops of size         5:
train =   1x100x0.000047s
test  =   1x100x0.000974s
1x100 loops of size        10:
train =   1x100x0.000047s
test  =   1x100x0.001060s
1x100 loops of size        20:
train =   1x100x0.000046s
test  =   1x100x0.001588s
1x100 loops of size        50:
train =   1x100x0.000049s
test  =   1x100x0.003767s
1x 10 loops of size       100:
train =   1x 10x0.000060s
test  =   1x 10x0.115337s
JAX (GPU) -------------
1x100 loops of size         1:
train =   1x100x0.000073s
test  =   1x100x0.044641s
1x 10 loops of size         2:
train =   1x 10x0.000101s
test  =   1x 10x0.044295s
1x  1 loops of size         5:
train =   1x  1x0.000352s
test  =   1x  1x0.044471s
1x  1 loops of size        10:
train =   1x  1x0.000814s
test  =   1x  1x0.045054s
1x  1 loops of size        20:
train =   1x  1x0.000378s
test  =   1x  1x0.058827s
1x  1 loops of size        50:
train =   1x  1x0.000372s
test  =   1x  1x0.060803s
1x  1 loops of size       100:
train =   1x  1x0.000428s
test  =   1x  1x0.049157s
JAX (small batches, GPU) -------------
1x100 loops of size         1:
train =   1x100x0.000090s
test  =   1x100x0.049987s
1x 10 loops of size         2:
train =   1x 10x0.000097s
test  =   1x 10x0.054213s
1x  1 loops of size         5:
train =   1x  1x0.000390s
test  =   1x  1x0.064223s
1x  1 loops of size        10:
train =   1x  1x0.000425s
test  =   1x  1x0.064441s
1x  1 loops of size        20:
train =   1x  1x0.000376s
test  =   1x  1x0.064874s
1x  1 loops of size        50:
train =   1x  1x0.000362s
test  =   1x  1x0.066706s
1x  1 loops of size       100:
train =   1x  1x0.000444s
test  =   1x  1x0.071569s


Large dataset in $$\mathbb{R}^3$$:

run_KNN_benchmark("R^D b")

plt.show()

Out:

Benchmarking : K-NN search on R^D b: 10,000 queries on a dataset of 1,000,000 points
in dimension 3 with a euclidean metric. ===============================
sklearn, auto (CPU) -------------
1x100 loops of size         1:
train =   1x100x1.598607s
test  =   1x100x0.143958s
1x 10 loops of size         2:
train =   1x 10x1.612554s
test  =   1x 10x0.133257s
1x  1 loops of size         5:
train =   1x  1x1.532201s
test  =   1x  1x0.236985s
1x  1 loops of size        10:
train =   1x  1x1.611835s
test  =   1x  1x0.241700s
1x  1 loops of size        20:
train =   1x  1x1.656041s
test  =   1x  1x0.241299s
1x  1 loops of size        50:
train =   1x  1x1.730085s
test  =   1x  1x0.340788s
1x  1 loops of size       100:
train =   1x  1x1.726777s
test  =   1x  1x0.540386s
sklearn, Ball-tree (CPU) -------------
1x100 loops of size         1:
train =   1x100x1.369651s
test  =   1x100x2.344115s
1x 10 loops of size         2:
train =   1x 10x1.353685s
test  =   1x 10x3.061345s
1x  1 loops of size         5:
train =   1x  1x1.513042s
test  =   1x  1x3.353290s
1x  1 loops of size        10:
train =   1x  1x1.183532s
test  =   1x  1x5.071523s
1x  1 loops of size        20:
train =   1x  1x1.261427s
test  =   1x  1x4.351152s
1x  1 loops of size        50:
train =   1x  1x1.236903s
test  =   1x  1x5.147325s
1x  1 loops of size       100:
train =   1x  1x1.364799s
test  =   1x  1x6.874878s
sklearn, KD-tree (CPU) -------------
1x100 loops of size         1:
train =   1x100x1.462515s
test  =   1x100x0.131533s
1x 10 loops of size         2:
train =   1x 10x1.514099s
test  =   1x 10x0.147783s
1x  1 loops of size         5:
train =   1x  1x1.524585s
test  =   1x  1x0.125420s
1x  1 loops of size        10:
train =   1x  1x1.615530s
test  =   1x  1x0.243877s
1x  1 loops of size        20:
train =   1x  1x1.406488s
test  =   1x  1x0.241338s
1x  1 loops of size        50:
train =   1x  1x1.469080s
test  =   1x  1x0.414169s
1x  1 loops of size       100:
train =   1x  1x1.486899s
test  =   1x  1x0.346360s
sklearn, bruteforce (CPU) -------------
**
Too slow !
PyTorch (GPU) -------------
**
Memory overflow !
PyTorch (small batches, GPU) -------------
1x100 loops of size         1:
train =   1x100x0.000179s
test  =   1x100x1.845937s
1x 10 loops of size         2:
train =   1x 10x0.000149s
test  =   1x 10x1.960062s
1x  1 loops of size         5:
train =   1x  1x0.000181s
test  =   1x  1x2.083780s
1x  1 loops of size        10:
train =   1x  1x0.000167s
test  =   1x  1x2.148774s
1x  1 loops of size        20:
train =   1x  1x0.000218s
test  =   1x  1x2.215750s
1x  1 loops of size        50:
train =   1x  1x0.000200s
test  =   1x  1x2.312260s
1x  1 loops of size       100:
train =   1x  1x0.000215s
test  =   1x  1x2.409131s
KeOps (GPU) -------------
1x100 loops of size         1:
train =   1x100x0.000075s
test  =   1x100x0.027534s
1x 10 loops of size         2:
train =   1x 10x0.000093s
test  =   1x 10x0.052842s
1x  1 loops of size         5:
train =   1x  1x0.000117s
test  =   1x  1x0.059585s
1x  1 loops of size        10:
train =   1x  1x0.000116s
test  =   1x  1x0.055962s
1x  1 loops of size        20:
train =   1x  1x0.000131s
test  =   1x  1x0.059440s
1x  1 loops of size        50:
train =   1x  1x0.000114s
test  =   1x  1x0.068099s
1x  1 loops of size       100:
train =   1x  1x0.000116s
test  =   1x  1x0.457150s
JAX (GPU) -------------
**
Memory overflow !
JAX (small batches, GPU) -------------
**
Too slow !


Total running time of the script: ( 23 minutes 33.842 seconds)

Gallery generated by Sphinx-Gallery