K-NN on the MNIST dataset - PyTorch API

The argKmin(K) reduction supported by the KeOps pykeops.torch.LazyTensor class lets us perform a brute-force k-nearest neighbors search in four lines of code. It can thus be used to implement a large-scale K-NN classifier that runs on the full MNIST dataset without memory overflows.
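For reference, here is a minimal dense sketch in plain PyTorch of what a brute-force K-NN search computes. It is not the KeOps implementation: the helper name dense_knn_indices and the toy points are hypothetical, and the full distance matrix is stored in memory, which is exactly the cost that a symbolic LazyTensor reduction is meant to avoid.

```python
import torch


def dense_knn_indices(x_query, x_ref, K):
    """Indices of the K nearest reference points for each query point."""
    # (M, N) matrix of squared Euclidean distances, fully materialized in memory.
    D_ij = ((x_query[:, None, :] - x_ref[None, :, :]) ** 2).sum(-1)
    # K smallest entries of each row, sorted by increasing distance.
    return D_ij.topk(K, dim=1, largest=False).indices


# Hypothetical toy data: four reference points and two query points in 2D.
x_ref = torch.tensor([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0], [6.0, 5.0]])
x_query = torch.tensor([[0.9, 0.1], [5.2, 4.9]])

ind_knn = dense_knn_indices(x_query, x_ref, K=2)
print(ind_knn)  # one row of K reference indices per query point
```

With 10,000 test points and 60,000 training points, the dense matrix above would hold 6 x 10^8 floats (about 2.4 GB in float32), which is why a reduction that never stores it is attractive.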


Standard imports:

import time

import torch
from matplotlib import pyplot as plt

from pykeops.torch import LazyTensor

use_cuda = torch.cuda.is_available()
tensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

Load the MNIST dataset: 70,000 images of shape (28,28).

try:
    from sklearn.datasets import fetch_openml
except ImportError:
    raise ImportError("This tutorial requires Scikit Learn version >= 0.20.")

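# as_frame=False (available in scikit-learn >= 0.22) returns NumPy arrays
# instead of a pandas DataFrame, which the tensor constructor cannot convert.
mnist = fetch_openml('mnist_784', cache=False, as_frame=False)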

x = tensor( mnist.data.astype('float32') )
y = tensor( mnist.target.astype('int64') )

Split it into a train and test set:

D = x.shape[1]
Ntrain, Ntest = (60000, 10000) if use_cuda else (1000, 100)
x_train, y_train = x[:Ntrain,:].contiguous(), y[:Ntrain].contiguous()
x_test,  y_test  = x[Ntrain:Ntrain+Ntest,:].contiguous(), y[Ntrain:Ntrain+Ntest].contiguous()
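The slicing logic of the split can be checked on a small stand-in dataset before running it on MNIST. The sketch below uses hypothetical toy tensors (12 samples, 5 features) and the same indexing pattern as above. It only verifies that the shapes are as requested and that the train and test sets are disjoint, consecutive slices of the data.

```python
import torch

# Hypothetical stand-in for MNIST: 12 samples with 5 features, labels in 0-9.
x = torch.arange(60, dtype=torch.float32).reshape(12, 5)
y = torch.arange(12) % 10

Ntrain, Ntest = 8, 4
x_train, y_train = x[:Ntrain, :].contiguous(), y[:Ntrain].contiguous()
x_test, y_test = (
    x[Ntrain:Ntrain + Ntest, :].contiguous(),
    y[Ntrain:Ntrain + Ntest].contiguous(),
)

# The shapes match the requested sizes, and stacking the two sets
# gives back the first Ntrain + Ntest rows, so they do not overlap.
split_ok = (
    x_train.shape == (Ntrain, x.shape[1])
    and x_test.shape == (Ntest, x.shape[1])
    and torch.equal(torch.cat([x_train, x_test]), x[:Ntrain + Ntest])
)
print(split_ok)
```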