K-Nearest Neighbours search (WIP)

Let's compare the performance of Scikit-Learn, PyTorch, JAX and KeOps for K-NN queries on random samples and standard datasets.

Note

In this demo, we use exact bruteforce computations (tensorized for PyTorch and JAX, online for KeOps), without leveraging any multiscale or low-rank (Nyström/multipole) decomposition of the kernel matrix. First support for these approximation schemes is scheduled for May-June 2021.

Setup

import numpy as np
import torch
from matplotlib import pyplot as plt
from functools import partial

from benchmark_utils import (
    flatten,
    random_normal,
    full_benchmark,
    timer,
    tensor,
    int_tensor,
    jax_tensor,
)
from dataset_utils import generate_samples

use_cuda = torch.cuda.is_available()

Benchmark specifications:

# Values of K that we'll loop over:
Ks = [1, 2, 5, 10, 20, 50, 100]

Simple bruteforce implementations

Define a simple K-NN query, using a tensorized implementation that computes the full distance matrix. Note that expanding the squared norm \(\|x-y\|^2\) as a sum \(\|x\|^2 - 2 \langle x, y \rangle + \|y\|^2\) allows us to leverage the fast matrix-matrix product of the BLAS/cuBLAS libraries.
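
As a quick sanity check of this identity, on small random tensors:

x = torch.randn(5, 3)
y = torch.randn(7, 3)
direct = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
expanded = (x ** 2).sum(-1)[:, None] + (y ** 2).sum(-1)[None, :] - 2 * x @ y.t()
assert torch.allclose(direct, expanded, atol=1e-5)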

All of our routines follow the same interface: a fit(x_train) function performs the (negligible) pre-processing and returns a query routine f(x_test), together with the elapsed "training" time:

"""
def KNN_template(K, metric="euclidean", **kwargs):
    def fit(x_train):
        # Setup the K-NN estimator:
        x_train = tensor(x_train)
        start = timer()

        # ... pre-computations go here ...
        # N.B.: The "training" time here should be negligible.
        elapsed = timer() - start

        def f(x_test):
            x_test = tensor(x_test)
            start = timer()

            # ... actual K-NN query goes here, filling the `indices` array ...

            elapsed = timer() - start

            indices = indices.cpu().numpy()
            return indices, elapsed

        return f, elapsed

    return fit
"""

PyTorch bruteforce:


def KNN_torch(K, metric="euclidean", **kwargs):
    def fit(x_train):
        # Setup the K-NN estimator:
        x_train = tensor(x_train)
        start = timer()
        # The "training" time here should be negligible:
        x_train_norm = (x_train ** 2).sum(-1)
        elapsed = timer() - start

        def f(x_test):
            x_test = tensor(x_test)
            start = timer()

            # Actual K-NN query:
            if metric == "euclidean":
                x_test_norm = (x_test ** 2).sum(-1)
                diss = (
                    x_test_norm.view(-1, 1)
                    + x_train_norm.view(1, -1)
                    - 2 * x_test @ x_train.t()
                )

            elif metric == "manhattan":
                diss = (x_test[:, None, :] - x_train[None, :, :]).abs().sum(dim=2)

            elif metric == "angular":
                diss = -x_test @ x_train.t()

            elif metric == "hyperbolic":
                x_test_norm = (x_test ** 2).sum(-1)
                diss = (
                    x_test_norm.view(-1, 1)
                    + x_train_norm.view(1, -1)
                    - 2 * x_test @ x_train.t()
                )
                diss /= x_test[:, 0].view(-1, 1) * x_train[:, 0].view(1, -1)

            out = diss.topk(K, dim=1, largest=False)

            elapsed = timer() - start
            indices = out.indices.cpu().numpy()
            return indices, elapsed

        return f, elapsed

    return fit
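
For reference, here is a minimal usage sketch of this interface, on hypothetical random arrays (as in the benchmarks below, the tensor() helper is assumed to accept NumPy inputs):

x_train = np.random.randn(1000, 3).astype(np.float32)
x_test = np.random.randn(100, 3).astype(np.float32)

f, train_time = KNN_torch(K=10)(x_train)  # build the query routine
indices, test_time = f(x_test)            # indices has shape (100, 10)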

PyTorch bruteforce, with small batches to avoid memory overflows:

def KNN_torch_batch_loop(K, metric="euclidean", **kwargs):
    def fit(x_train):
        # Setup the K-NN estimator:
        x_train = tensor(x_train)
        Ntrain, D = x_train.shape
        start = timer()
        # The "training" time here should be negligible:
        x_train_norm = (x_train ** 2).sum(-1)
        elapsed = timer() - start

        def f(x_test):
            x_test = tensor(x_test)

            # Estimate the largest reasonable batch size:
            Ntest = x_test.shape[0]
            # Use a conservative 500 MB budget; the actual capacity could be queried
            # with torch.cuda.get_device_properties(device).total_memory.
            av_mem = int(5e8)
            Ntest_loop = min(max(1, av_mem // (4 * D * Ntrain)), Ntest)
            Nloop = (Ntest - 1) // Ntest_loop + 1
            # print(f"{Ntest} queries, split in {Nloop} batches of {Ntest_loop} queries each.")
            out = int_tensor(Ntest, K)

            start = timer()
            # Actual K-NN query:
            for k in range(Nloop):
                x_test_k = x_test[Ntest_loop * k : Ntest_loop * (k + 1), :]
                if metric == "euclidean":
                    x_test_norm = (x_test_k ** 2).sum(-1)
                    diss = (
                        x_test_norm.view(-1, 1)
                        + x_train_norm.view(1, -1)
                        - 2 * x_test_k @ x_train.t()
                    )

                elif metric == "manhattan":
                    diss = (x_test_k[:, None, :] - x_train[None, :, :]).abs().sum(dim=2)

                elif metric == "angular":
                    diss = -x_test_k @ x_train.t()

                elif metric == "hyperbolic":
                    x_test_norm = (x_test_k ** 2).sum(-1)
                    diss = (
                        x_test_norm.view(-1, 1)
                        + x_train_norm.view(1, -1)
                        - 2 * x_test_k @ x_train.t()
                    )
                    diss /= x_test_k[:, 0].view(-1, 1) * x_train[:, 0].view(1, -1)

                out[Ntest_loop * k : Ntest_loop * (k + 1), :] = diss.topk(
                    K, dim=1, largest=False
                ).indices
                del diss
            # torch.cuda.empty_cache()

            elapsed = timer() - start
            indices = out.cpu().numpy()
            return indices, elapsed

        return f, elapsed

    return fit
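
As a rough sanity check of this heuristic: with the conservative 500 MB budget and float32 storage (4 bytes per entry), Ntrain = 1,000,000 points in dimension D = 3 yield batches of 5e8 // (4 * 3 * 1e6) = 41 queries; each batch's distance matrix then weighs about 41 × 10⁶ × 4 B ≈ 164 MB, leaving headroom for intermediate buffers.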

Distance matrices with JAX:

from functools import partial
import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(2, 3))
def knn_jax_fun(x_train, x_test, K, metric):
    if metric == "euclidean":
        diss = (
            (x_test ** 2).sum(-1)[:, None]
            + (x_train ** 2).sum(-1)[None, :]
            - 2 * x_test @ x_train.T
        )
    elif metric == "manhattan":
        diss = jax.lax.abs(x_test[:, None, :] - x_train[None, :, :]).sum(-1)
    elif metric == "angular":
        diss = -x_test @ x_train.T
    elif metric == "hyperbolic":
        diss = (
            (x_test ** 2).sum(-1)[:, None]
            + (x_train ** 2).sum(-1)[None, :]
            - 2 * x_test @ x_train.T
        )
        diss = diss / (x_test[:, 0][:, None] * x_train[:, 0][None, :])

    indices = jax.lax.top_k(-diss, K)[1]
    return indices
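
Since jax.lax.top_k returns the largest entries of its input, we negate the distances to retrieve the indices of the K smallest ones. A minimal check, on illustrative values:

d = jnp.array([3.0, 1.0, 2.0])
print(jax.lax.top_k(-d, 2)[1])  # [1 2]: the two smallest distances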

JAX bruteforce:

def KNN_JAX(K, metric="euclidean", **kwargs):
    def fit(x_train):

        # Setup the K-NN estimator:
        start = timer(use_torch=False)
        x_train = jax_tensor(x_train)
        elapsed = timer(use_torch=False) - start

        def f(x_test):
            x_test = jax_tensor(x_test)

            # Actual K-NN query:
            start = timer(use_torch=False)
            indices = knn_jax_fun(x_train, x_test, K, metric)
            indices = np.array(indices)
            elapsed = timer(use_torch=False) - start
            return indices, elapsed

        return f, elapsed

    return fit
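
Note that JAX dispatches GPU computations asynchronously: the np.array(indices) conversion above blocks until the query has completed (one could equivalently call indices.block_until_ready()), which keeps the reported timings meaningful.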

JAX bruteforce, with small batches to avoid memory overflows:

def KNN_JAX_batch_loop(K, metric="euclidean", **kwargs):
    def fit(x_train):

        # Setup the K-NN estimator:
        start = timer(use_torch=False)
        x_train = jax_tensor(x_train)
        elapsed = timer(use_torch=False) - start

        def f(x_test):
            x_test = jax_tensor(x_test)

            # Estimate the largest reasonable batch size, using a conservative
            # 500 MB budget as in the PyTorch batched version above:
            av_mem = int(5e8)
            Ntrain, D = x_train.shape
            Ntest = x_test.shape[0]
            Ntest_loop = min(max(1, av_mem // (4 * D * Ntrain)), Ntest)
            Nloop = (Ntest - 1) // Ntest_loop + 1
            # print(f"{Ntest} queries, split in {Nloop} batches of {Ntest_loop} queries each.")
            indices = np.zeros((Ntest, K), dtype=int)

            start = timer(use_torch=False)
            # Actual K-NN query:
            for k in range(Nloop):
                x_test_k = x_test[Ntest_loop * k : Ntest_loop * (k + 1), :]
                indices[Ntest_loop * k : Ntest_loop * (k + 1), :] = knn_jax_fun(
                    x_train, x_test_k, K, metric
                )
            elapsed = timer(use_torch=False) - start
            return indices, elapsed

        return f, elapsed

    return fit
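
Here, each assignment into the NumPy buffer indices triggers a device-to-host copy, so every batch is implicitly synchronized before the final time is recorded.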

KeOps bruteforce implementation:

from pykeops.torch import LazyTensor, Vi, Vj


def KNN_KeOps(K, metric="euclidean", **kwargs):
    def fit(x_train):
        # Setup the K-NN estimator:
        x_train = tensor(x_train)
        start = timer()

        # Encoding as KeOps LazyTensors:
        D = x_train.shape[1]
        X_i = Vi(0, D)
        X_j = Vj(1, D)

        # Symbolic distance matrix:
        if metric == "euclidean":
            D_ij = ((X_i - X_j) ** 2).sum(-1)
        elif metric == "manhattan":
            D_ij = (X_i - X_j).abs().sum(-1)
        elif metric == "angular":
            D_ij = -(X_i | X_j)
        elif metric == "hyperbolic":
            D_ij = ((X_i - X_j) ** 2).sum(-1) / (X_i[0] * X_j[0])

        # K-NN query operator:
        KNN_fun = D_ij.argKmin(K, dim=1)

        # N.B.: The "training" time here should be negligible.
        elapsed = timer() - start

        def f(x_test):
            x_test = tensor(x_test)
            start = timer()

            # Actual K-NN query:
            indices = KNN_fun(x_test, x_train)

            elapsed = timer() - start

            indices = indices.cpu().numpy()
            return indices, elapsed

        return f, elapsed

    return fit
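
For reference, the same Euclidean query can be written with explicit LazyTensors instead of the Vi/Vj placeholders; a minimal sketch, assuming x_test and x_train are (M, D) and (N, D) torch tensors:

X_i = LazyTensor(x_test[:, None, :])   # (M, 1, D) symbolic "queries"
X_j = LazyTensor(x_train[None, :, :])  # (1, N, D) symbolic "database"
D_ij = ((X_i - X_j) ** 2).sum(-1)      # (M, N) symbolic squared distances
indices = D_ij.argKmin(K, dim=1)       # (M, K) indices of the K nearest neighbours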

Scikit-Learn tree-based and bruteforce methods

from sklearn.neighbors import NearestNeighbors


def KNN_sklearn(K, metric="euclidean", algorithm=None, **kwargs):

    if metric in ["euclidean", "angular"]:
        p = 2
    elif metric == "manhattan":
        p = 1
    else:
        raise NotImplementedError("This distance is not supported.")

    KNN_meth = NearestNeighbors(n_neighbors=K, algorithm=algorithm, p=p, n_jobs=-1)

    def fit(x_train):
        # Setup the K-NN estimator:
        start = timer()
        KNN_fun = KNN_meth.fit(x_train).kneighbors
        elapsed = timer() - start

        def f(x_test):
            start = timer()
            distances, indices = KNN_fun(x_test)
            elapsed = timer() - start

            return indices, elapsed

        return f, elapsed

    return fit


KNN_sklearn_auto = partial(KNN_sklearn, algorithm="auto")
KNN_sklearn_ball_tree = partial(KNN_sklearn, algorithm="ball_tree")
KNN_sklearn_kd_tree = partial(KNN_sklearn, algorithm="kd_tree")
KNN_sklearn_brute = partial(KNN_sklearn, algorithm="brute")
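
Note that relying on the Euclidean distance (\(p = 2\)) for the "angular" metric is only legitimate when the vectors are normalized: if \(\|x\| = \|y\| = 1\), then \(\|x - y\|^2 = 2 - 2 \langle x, y \rangle\), so both metrics induce the same neighbour ranking.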

Scikit-Learn (CPU) vs. PyTorch, JAX and KeOps (GPU)

def run_KNN_benchmark(name):

    # Load the dataset and some info:
    dataset = generate_samples(name)(1)
    N_train, dimension = dataset["train"].shape
    N_test, _ = dataset["test"].shape
    metric = dataset["metric"]

    # Routines to benchmark:
    routines = [
        (KNN_sklearn_auto, "sklearn, auto (CPU)", {}),
        (KNN_sklearn_ball_tree, "sklearn, Ball-tree (CPU)", {}),
        (KNN_sklearn_kd_tree, "sklearn, KD-tree (CPU)", {}),
        (KNN_sklearn_brute, "sklearn, bruteforce (CPU)", {}),
        (KNN_torch, "PyTorch (GPU)", {}),
        (KNN_torch_batch_loop, "PyTorch (small batches, GPU)", {}),
        (KNN_KeOps, "KeOps (GPU)", {}),
        (KNN_JAX, "JAX (GPU)", {}),
        (KNN_JAX_batch_loop, "JAX (small batches, GPU)", {}),
    ]

    # Actual run:
    full_benchmark(
        f"K-NN search on {name}: {N_test:,} queries on a dataset of {N_train:,} points\nin dimension {dimension:,} with a {metric} metric.",
        routines,
        generate_samples(name),
        min_time=1e-4,
        max_time=10,
        problem_sizes=Ks,
        xlabel="Number of neighbours K",
    )

On random samples:

Small dataset in \(\mathbb{R}^3\):

run_KNN_benchmark("R^D a")
[Figure: K-NN search on R^D a: 10,000 queries on a dataset of 10,000 points in dimension 3 with a euclidean metric.]

Out:

Benchmarking : K-NN search on R^D a: 10,000 queries on a dataset of 10,000 points
in dimension 3 with a euclidean metric. ===============================
sklearn, auto (CPU) -------------
  1x100 loops of size         1:
    train =   1x100x0.006913s
    test  =   1x100x0.132252s
  1x 10 loops of size         2:
    train =   1x 10x0.006147s
    test  =   1x 10x0.143104s
  1x  1 loops of size         5:
    train =   1x  1x0.009165s
    test  =   1x  1x0.129589s
  1x  1 loops of size        10:
    train =   1x  1x0.009278s
    test  =   1x  1x0.242093s
  1x  1 loops of size        20:
    train =   1x  1x0.008777s
    test  =   1x  1x0.241809s
  1x  1 loops of size        50:
    train =   1x  1x0.005599s
    test  =   1x  1x0.352888s
  1x  1 loops of size       100:
    train =   1x  1x0.006884s
    test  =   1x  1x0.555457s
sklearn, Ball-tree (CPU) -------------
  1x100 loops of size         1:
    train =   1x100x0.005426s
    test  =   1x100x1.174859s
  1x 10 loops of size         2:
    train =   1x 10x0.005629s
    test  =   1x 10x1.514498s
  1x  1 loops of size         5:
    train =   1x  1x0.012689s
    test  =   1x  1x1.361847s
  1x  1 loops of size        10:
    train =   1x  1x0.006375s
    test  =   1x  1x1.532763s
  1x  1 loops of size        20:
    train =   1x  1x0.005746s
    test  =   1x  1x1.654460s
  1x  1 loops of size        50:
    train =   1x  1x0.012803s
    test  =   1x  1x2.155772s
  1x  1 loops of size       100:
    train =   1x  1x0.007468s
    test  =   1x  1x1.945825s
sklearn, KD-tree (CPU) -------------
  1x100 loops of size         1:
    train =   1x100x0.006627s
    test  =   1x100x0.134348s
  1x 10 loops of size         2:
    train =   1x 10x0.006702s
    test  =   1x 10x0.133760s
  1x  1 loops of size         5:
    train =   1x  1x0.006728s
    test  =   1x  1x0.131101s
  1x  1 loops of size        10:
    train =   1x  1x0.009019s
    test  =   1x  1x0.231460s
  1x  1 loops of size        20:
    train =   1x  1x0.012306s
    test  =   1x  1x0.361526s
  1x  1 loops of size        50:
    train =   1x  1x0.011660s
    test  =   1x  1x0.329645s
  1x  1 loops of size       100:
    train =   1x  1x0.010894s
    test  =   1x  1x0.551188s
sklearn, bruteforce (CPU) -------------
**
Too slow !
PyTorch (GPU) -------------
  1x100 loops of size         1:
    train =   1x100x0.000079s
    test  =   1x100x0.008414s
  1x 10 loops of size         2:
    train =   1x 10x0.000068s
    test  =   1x 10x0.008607s
  1x 10 loops of size         5:
    train =   1x 10x0.000133s
    test  =   1x 10x0.009002s
  1x 10 loops of size        10:
    train =   1x 10x0.000065s
    test  =   1x 10x0.009213s
  1x 10 loops of size        20:
    train =   1x 10x0.000111s
    test  =   1x 10x0.009527s
  1x 10 loops of size        50:
    train =   1x 10x0.000075s
    test  =   1x 10x0.010202s
  1x 10 loops of size       100:
    train =   1x 10x0.000141s
    test  =   1x 10x0.010330s
PyTorch (small batches, GPU) -------------
  1x100 loops of size         1:
    train =   1x100x0.000057s
    test  =   1x100x0.008201s
  1x 10 loops of size         2:
    train =   1x 10x0.000055s
    test  =   1x 10x0.008747s
  1x 10 loops of size         5:
    train =   1x 10x0.000054s
    test  =   1x 10x0.009433s
  1x 10 loops of size        10:
    train =   1x 10x0.000056s
    test  =   1x 10x0.009302s
  1x 10 loops of size        20:
    train =   1x 10x0.000065s
    test  =   1x 10x0.009586s
  1x 10 loops of size        50:
    train =   1x 10x0.000056s
    test  =   1x 10x0.010070s
  1x 10 loops of size       100:
    train =   1x 10x0.000065s
    test  =   1x 10x0.010461s
KeOps (GPU) -------------
  1x100 loops of size         1:
    train =   1x100x0.000047s
    test  =   1x100x0.000534s
  1x100 loops of size         2:
    train =   1x100x0.000046s
    test  =   1x100x0.000842s
  1x100 loops of size         5:
    train =   1x100x0.000047s
    test  =   1x100x0.000974s
  1x100 loops of size        10:
    train =   1x100x0.000047s
    test  =   1x100x0.001060s
  1x100 loops of size        20:
    train =   1x100x0.000046s
    test  =   1x100x0.001588s
  1x100 loops of size        50:
    train =   1x100x0.000049s
    test  =   1x100x0.003767s
  1x 10 loops of size       100:
    train =   1x 10x0.000060s
    test  =   1x 10x0.115337s
JAX (GPU) -------------
  1x100 loops of size         1:
    train =   1x100x0.000073s
    test  =   1x100x0.044641s
  1x 10 loops of size         2:
    train =   1x 10x0.000101s
    test  =   1x 10x0.044295s
  1x  1 loops of size         5:
    train =   1x  1x0.000352s
    test  =   1x  1x0.044471s
  1x  1 loops of size        10:
    train =   1x  1x0.000814s
    test  =   1x  1x0.045054s
  1x  1 loops of size        20:
    train =   1x  1x0.000378s
    test  =   1x  1x0.058827s
  1x  1 loops of size        50:
    train =   1x  1x0.000372s
    test  =   1x  1x0.060803s
  1x  1 loops of size       100:
    train =   1x  1x0.000428s
    test  =   1x  1x0.049157s
JAX (small batches, GPU) -------------
  1x100 loops of size         1:
    train =   1x100x0.000090s
    test  =   1x100x0.049987s
  1x 10 loops of size         2:
    train =   1x 10x0.000097s
    test  =   1x 10x0.054213s
  1x  1 loops of size         5:
    train =   1x  1x0.000390s
    test  =   1x  1x0.064223s
  1x  1 loops of size        10:
    train =   1x  1x0.000425s
    test  =   1x  1x0.064441s
  1x  1 loops of size        20:
    train =   1x  1x0.000376s
    test  =   1x  1x0.064874s
  1x  1 loops of size        50:
    train =   1x  1x0.000362s
    test  =   1x  1x0.066706s
  1x  1 loops of size       100:
    train =   1x  1x0.000444s
    test  =   1x  1x0.071569s

Large dataset in \(\mathbb{R}^3\):

run_KNN_benchmark("R^D b")

plt.show()
[Figure: K-NN search on R^D b: 10,000 queries on a dataset of 1,000,000 points in dimension 3 with a euclidean metric.]

Out:

Benchmarking : K-NN search on R^D b: 10,000 queries on a dataset of 1,000,000 points
in dimension 3 with a euclidean metric. ===============================
sklearn, auto (CPU) -------------
  1x100 loops of size         1:
    train =   1x100x1.598607s
    test  =   1x100x0.143958s
  1x 10 loops of size         2:
    train =   1x 10x1.612554s
    test  =   1x 10x0.133257s
  1x  1 loops of size         5:
    train =   1x  1x1.532201s
    test  =   1x  1x0.236985s
  1x  1 loops of size        10:
    train =   1x  1x1.611835s
    test  =   1x  1x0.241700s
  1x  1 loops of size        20:
    train =   1x  1x1.656041s
    test  =   1x  1x0.241299s
  1x  1 loops of size        50:
    train =   1x  1x1.730085s
    test  =   1x  1x0.340788s
  1x  1 loops of size       100:
    train =   1x  1x1.726777s
    test  =   1x  1x0.540386s
sklearn, Ball-tree (CPU) -------------
  1x100 loops of size         1:
    train =   1x100x1.369651s
    test  =   1x100x2.344115s
  1x 10 loops of size         2:
    train =   1x 10x1.353685s
    test  =   1x 10x3.061345s
  1x  1 loops of size         5:
    train =   1x  1x1.513042s
    test  =   1x  1x3.353290s
  1x  1 loops of size        10:
    train =   1x  1x1.183532s
    test  =   1x  1x5.071523s
  1x  1 loops of size        20:
    train =   1x  1x1.261427s
    test  =   1x  1x4.351152s
  1x  1 loops of size        50:
    train =   1x  1x1.236903s
    test  =   1x  1x5.147325s
  1x  1 loops of size       100:
    train =   1x  1x1.364799s
    test  =   1x  1x6.874878s
sklearn, KD-tree (CPU) -------------
  1x100 loops of size         1:
    train =   1x100x1.462515s
    test  =   1x100x0.131533s
  1x 10 loops of size         2:
    train =   1x 10x1.514099s
    test  =   1x 10x0.147783s
  1x  1 loops of size         5:
    train =   1x  1x1.524585s
    test  =   1x  1x0.125420s
  1x  1 loops of size        10:
    train =   1x  1x1.615530s
    test  =   1x  1x0.243877s
  1x  1 loops of size        20:
    train =   1x  1x1.406488s
    test  =   1x  1x0.241338s
  1x  1 loops of size        50:
    train =   1x  1x1.469080s
    test  =   1x  1x0.414169s
  1x  1 loops of size       100:
    train =   1x  1x1.486899s
    test  =   1x  1x0.346360s
sklearn, bruteforce (CPU) -------------
**
Too slow !
PyTorch (GPU) -------------
**
Memory overflow !
PyTorch (small batches, GPU) -------------
  1x100 loops of size         1:
    train =   1x100x0.000179s
    test  =   1x100x1.845937s
  1x 10 loops of size         2:
    train =   1x 10x0.000149s
    test  =   1x 10x1.960062s
  1x  1 loops of size         5:
    train =   1x  1x0.000181s
    test  =   1x  1x2.083780s
  1x  1 loops of size        10:
    train =   1x  1x0.000167s
    test  =   1x  1x2.148774s
  1x  1 loops of size        20:
    train =   1x  1x0.000218s
    test  =   1x  1x2.215750s
  1x  1 loops of size        50:
    train =   1x  1x0.000200s
    test  =   1x  1x2.312260s
  1x  1 loops of size       100:
    train =   1x  1x0.000215s
    test  =   1x  1x2.409131s
KeOps (GPU) -------------
  1x100 loops of size         1:
    train =   1x100x0.000075s
    test  =   1x100x0.027534s
  1x 10 loops of size         2:
    train =   1x 10x0.000093s
    test  =   1x 10x0.052842s
  1x  1 loops of size         5:
    train =   1x  1x0.000117s
    test  =   1x  1x0.059585s
  1x  1 loops of size        10:
    train =   1x  1x0.000116s
    test  =   1x  1x0.055962s
  1x  1 loops of size        20:
    train =   1x  1x0.000131s
    test  =   1x  1x0.059440s
  1x  1 loops of size        50:
    train =   1x  1x0.000114s
    test  =   1x  1x0.068099s
  1x  1 loops of size       100:
    train =   1x  1x0.000116s
    test  =   1x  1x0.457150s
JAX (GPU) -------------
**
Memory overflow !
JAX (small batches, GPU) -------------
**
Too slow !

Total running time of the script: (23 minutes 33.842 seconds)
