K-NN classification - PyTorch API

The argKmin(K) reduction supported by the KeOps pykeops.torch.LazyTensor lets us perform a brute-force k-nearest neighbors search in four lines of code. Because the full matrix of pairwise distances is handled symbolically and never stored in memory, this reduction can be used to implement a large-scale K-NN classifier without memory overflows.
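For comparison, here is a minimal sketch of the same brute-force search written in plain, dense PyTorch rather than with KeOps. It materializes the full (M, N) distance matrix, which is exactly the memory cost that the argKmin(K) reduction avoids. The helper name knn_classify_dense and the toy data are assumptions made for illustration; they are not part of the tutorial or of the KeOps API.

```python
import torch


def knn_classify_dense(x_train, labels, x_query, K=3):
    """Brute-force K-NN on binary labels, using a dense distance matrix.

    Hypothetical helper for illustration: memory use grows as M * N.
    """
    # Squared Euclidean distances between every query and every training point.
    D_ij = ((x_query[:, None, :] - x_train[None, :, :]) ** 2).sum(-1)  # (M, N)
    # Indices of the K closest training points, for each query point.
    ind_knn = D_ij.topk(K, dim=1, largest=False).indices  # (M, K)
    # Majority vote among the K neighbors.
    votes = labels[ind_knn].float().mean(dim=1)  # (M,)
    return votes > 0.5


# Toy data: one cluster per class, near two opposite corners of the unit square.
x_train = torch.tensor(
    [[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [1.0, 1.0], [0.9, 1.0], [1.0, 0.9]]
)
labels = torch.tensor([False, False, False, True, True, True])
x_query = torch.tensor([[0.05, 0.05], [0.95, 0.95]])

pred = knn_classify_dense(x_train, labels, x_query, K=3)
print(pred)
```

With an odd K and binary labels the vote cannot tie. The dense version is fine for a few thousand points, but the (M, N) buffer quickly becomes the bottleneck as M and N grow.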


Standard imports:

import time
from matplotlib import pyplot as plt

import numpy as np
import torch

from pykeops.torch import LazyTensor

use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

Dataset, in 2D:

N, D = 10000 if use_cuda else 1000, 2  # Number of samples, dimension
x = torch.rand(N, D).type(dtype)       # Random samples on the unit square

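# Random-ish class labels: a point belongs to the class "True" when its
# noisy y-coordinate falls below the cubic curve y = fth(x).
def fth(x):
    return 3 * x * (x - .5) * (x - 1) + x
cl = x[:, 1] + .1 * torch.randn(N).type(dtype) < fth(x[:, 0])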
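To get a feel for the decision boundary used above, the short sketch below evaluates the cubic fth at a few abscissas. It only restates the formula from the tutorial; the sample points t are chosen here for illustration.

```python
import torch


def fth(x):
    # Same cubic as in the tutorial: it oscillates around the diagonal y = x.
    return 3 * x * (x - 0.5) * (x - 1) + x


t = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])
boundary = fth(t)
print(boundary)
```

The curve passes through (0, 0), (0.5, 0.5) and (1, 1), lying above the diagonal on the left half of the square and below it on the right half. The Gaussian noise of standard deviation 0.1 added to the y-coordinate then blurs the class boundary around this curve.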

Reference sampling grid, on the unit square:

M = 1000 if use_cuda else 100
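tmp = torch.linspace(0, 1, M).type(dtype)
# g1 holds the x-coordinates and g2 the y-coordinates of the M*M grid points.
# Passing indexing="ij" (PyTorch >= 1.10) keeps the historical behavior
# and silences the deprecation warning.
g2, g1 = torch.meshgrid(tmp, tmp, indexing="ij")
g = torch.cat( (g1.contiguous().view(-1,1), g2.contiguous().view(-1,1)), dim=1 )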
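The ordering of the grid points matters if predictions on g are later reshaped into an image. The sketch below rebuilds the grid with a tiny M = 3 on the CPU to make the layout visible. It assumes PyTorch 1.10 or later for the indexing="ij" argument, which matches the default behavior of older versions.

```python
import torch

M = 3  # Tiny grid, for inspection only
tmp = torch.linspace(0, 1, M)

# With "ij" indexing, the first output varies along rows, the second along columns.
g2, g1 = torch.meshgrid(tmp, tmp, indexing="ij")
g = torch.cat((g1.contiguous().view(-1, 1), g2.contiguous().view(-1, 1)), dim=1)

print(g.shape)  # One row per grid point, one column per coordinate
print(g[:4])    # The x-coordinate varies fastest, the y-coordinate slowest
```

Since the point of index i * M + j has coordinates (tmp[j], tmp[i]), a vector of M * M values computed on g can be viewed as an (M, M) array whose rows follow y and whose columns follow x.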