K-NN classification - NumPy API

The pykeops.numpy.generic_argkmin() routine lets us perform a brute-force k-nearest neighbor search in four lines of code. It can thus be used to implement a large-scale K-NN classifier without running into memory overflows.
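The sketch below shows what this reduction computes, using plain NumPy. It is not the KeOps API. The function `knn_bruteforce` and its `chunk` parameter are hypothetical names for illustration. For each query point it returns the indices of its k nearest samples, sorted from nearest to farthest. The block-by-block loop is only a crude stand-in for the way KeOps avoids storing the full query-by-sample distance matrix.

```python
import numpy as np


def knn_bruteforce(query, data, k, chunk=1024):
    """Hypothetical helper: indices of the k nearest rows of `data` for each row of `query`.

    Distances are computed one block of `chunk` queries at a time. The largest
    temporary array therefore has shape (chunk, len(data), D) rather than
    (len(query), len(data), D).
    """
    out = np.empty((len(query), k), dtype=np.int64)
    for start in range(0, len(query), chunk):
        q = query[start:start + chunk]
        # Squared Euclidean distances for this block, shape (len(q), len(data))
        d2 = ((q[:, None, :] - data[None, :, :]) ** 2).sum(axis=-1)
        # The k smallest distances per row, in arbitrary order...
        idx = np.argpartition(d2, k - 1, axis=1)[:, :k]
        # ...then reordered from nearest to farthest
        order = np.argsort(np.take_along_axis(d2, idx, axis=1), axis=1)
        out[start:start + chunk] = np.take_along_axis(idx, order, axis=1)
    return out


rng = np.random.default_rng(0)
data = rng.random((1000, 2)).astype("float32")
query = rng.random((5, 2)).astype("float32")
print(knn_bruteforce(query, data, k=3))  # a (5, 3) array of indices into `data`
```

`np.argpartition` followed by a sort of only the k selected entries avoids a full `argsort` of every distance row. A K-NN classifier then predicts each query's class by majority vote over the labels at these indices.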


Standard imports:

import time
import numpy as np
from pykeops.numpy import generic_argkmin
from pykeops.numpy.utils import IsGpuAvailable
from matplotlib import pyplot as plt

dtype = "float32"
use_cuda = IsGpuAvailable()

Dataset, in 2D:

N, D = 10000 if use_cuda else 1000, 2  # Number of samples, dimension
x = np.random.rand(N, D).astype(dtype)  # Random samples on the unit square

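# Noisy class labels: a sample is labeled True when it lies below the
# cubic curve y = fth(x), up to Gaussian noise of standard deviation 0.1
def fth(x):
    return 3 * x * (x - .5) * (x - 1) + x
cl = x[:, 1] + .1 * np.random.randn(N).astype(dtype) < fth(x[:, 0])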
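The labeling rule can be checked directly. Write u = x − 1/2. Then fth(x) = x + 3u(u² − 1/4), and its derivative is 9u² + 1/4 > 0. The boundary therefore increases from (0, 0) through (1/2, 1/2) to (1, 1) and stays inside the unit square. It also satisfies fth(1 − x) = 1 − fth(x). Together with the symmetric noise, this means that each class should contain about half of the samples. The following quick numerical check uses an arbitrary seed and sample size.

```python
import numpy as np


def fth(x):
    return 3 * x * (x - .5) * (x - 1) + x


# The boundary passes through two corners and the centre of the unit square
print(fth(0.0), fth(0.5), fth(1.0))  # 0.0 0.5 1.0

# fth is strictly increasing on [0, 1]
t = np.linspace(0, 1, 1001)
print(np.all(np.diff(fth(t)) > 0))  # True

# Point symmetry about (1/2, 1/2): fth(1 - x) == 1 - fth(x)
print(np.allclose(fth(1 - t), 1 - fth(t)))  # True

# Hence the two classes should be roughly balanced
rng = np.random.default_rng(0)
N = 100_000
x = rng.random((N, 2))
cl = x[:, 1] + .1 * rng.standard_normal(N) < fth(x[:, 0])
print(cl.mean())  # close to 0.5
```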

Reference sampling grid, on the unit square:

M = 1000 if use_cuda else 100
tmp = np.linspace(0, 1, M).astype(dtype)
g1, g2 = np.meshgrid(tmp, tmp)
g = np.hstack((g1.reshape(-1, 1), g2.reshape(-1, 1)))  # (M*M, 2) array of grid points
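The ordering of the points in `g` is easy to check on a tiny grid. This sketch uses M = 3, an arbitrary value chosen only for readability. With NumPy's default `indexing="xy"`, row k of `g` is the point (tmp[k % M], tmp[k // M]), so the first coordinate varies fastest. Any vector of per-point values can be reshaped to (M, M) to recover the image layout of `g1` and `g2`.

```python
import numpy as np

M = 3
tmp = np.linspace(0, 1, M).astype("float32")
# Default indexing="xy": g1[i, j] = tmp[j], g2[i, j] = tmp[i]
g1, g2 = np.meshgrid(tmp, tmp)
g = np.hstack((g1.reshape(-1, 1), g2.reshape(-1, 1)))

print(g.shape)  # (9, 2)

# Row k of g is the point (tmp[k % M], tmp[k // M])
for k in range(M * M):
    assert g[k, 0] == tmp[k % M] and g[k, 1] == tmp[k // M]

# The same array, built with np.stack
print(np.array_equal(g, np.stack((g1.ravel(), g2.ravel()), axis=1)))  # True

# Reshaping a per-point vector to (M, M) recovers the grid layout
print(np.array_equal(g[:, 0].reshape(M, M), g1))  # True
```

With the GPU settings above (M = 1000, N = 10000), the grid holds 10⁶ query points. A dense float32 distance matrix against the training set would therefore take 10⁶ × 10⁴ × 4 bytes = 40 GB. This is the memory overflow that the KeOps reduction is meant to avoid.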