Clusters
pykeops.torch.cluster
- Block-sparse reductions, allowing you to go beyond the baseline quadratic complexity of kernel operations with the optional ranges argument:
Summary
| cluster_centroids | Computes the (weighted) centroids of classes specified by a vector of labels. |
| cluster_ranges | Computes the [start,end) indices that specify clusters in a sorted point cloud. |
| cluster_ranges_centroids | Computes the cluster indices and centroids of a (weighted) point cloud with labels. |
| from_matrix | Turns a boolean matrix into a KeOps-friendly ranges argument. |
| grid_cluster | Simplistic clustering algorithm which distributes points into cubic bins. |
| sort_clusters | Sorts a list of points and labels to make sure that the clusters are contiguous in memory. |
| swap_axes | Swaps the "\(i\)" and "\(j\)" axes of a reduction's optional ranges parameter. |
Syntax
- pykeops.torch.cluster.cluster_centroids(x, lab, Nlab=None, weights=None, weights_c=None)
Computes the (weighted) centroids of classes specified by a vector of labels.
If points \(x_i \in\mathbb{R}^D\) are assigned to \(C\) different classes by the vector of integer labels \(\ell_i \in [0,C)\), this function returns a collection of \(C\) centroids
\[c_k = \frac{\sum_{i, \ell_i = k} w_i\cdot x_i}{\sum_{i, \ell_i=k} w_i},\]
where the weights \(w_i\) are set to 1 by default.
- Parameters:
x ((M,D) Tensor) – List of points \(x_i \in \mathbb{R}^D\).
lab ((M,) IntTensor) – Vector of class labels \(\ell_i\in\mathbb{N}\).
- Keyword Arguments:
Nlab ((C,) IntTensor) – Number of points per class. Recomputed if None.
weights ((M,) Tensor) – Positive weights \(w_i\) of each point.
weights_c ((C,) Tensor) – Total weight of each class. Recomputed if None.
- Returns:
List of centroids \(c_k \in \mathbb{R}^D\).
- Return type:
(C,D) Tensor
Example
>>> import torch
>>> from pykeops.torch.cluster import cluster_centroids
>>> x = torch.Tensor([ [0.], [1.], [4.], [5.], [6.] ])
>>> lab = torch.IntTensor([ 0, 0, 1, 1, 1 ])
>>> weights = torch.Tensor([ .5, .5, 2., 1., 1. ])
>>> centroids = cluster_centroids(x, lab, weights=weights)
>>> print(centroids)
tensor([[0.5000],
        [4.7500]])
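To make the centroid formula concrete, here is a minimal NumPy sketch of the same computation, written independently of KeOps. The helper name weighted_centroids is hypothetical (it is not part of pykeops), and the sketch assumes labels in \([0,C)\) with every class non-empty.

```python
import numpy as np

def weighted_centroids(x, lab, weights=None):
    """Hypothetical helper: c_k = sum_{l_i=k} w_i x_i / sum_{l_i=k} w_i."""
    x = np.asarray(x, dtype=float)            # (M, D) points
    lab = np.asarray(lab, dtype=np.int64)     # (M,) labels in [0, C)
    w = np.ones(len(lab)) if weights is None else np.asarray(weights, dtype=float)
    C = int(lab.max()) + 1                    # number of classes
    num = np.zeros((C, x.shape[1]))
    np.add.at(num, lab, w[:, None] * x)       # scatter-add w_i * x_i into row l_i
    den = np.bincount(lab, weights=w, minlength=C)  # total weight of each class
    return num / den[:, None]                 # assumes no empty class (den > 0)

# Same data as the example above: centroids 0.5 and 4.75
c = weighted_centroids([[0.], [1.], [4.], [5.], [6.]], [0, 0, 1, 1, 1],
                       weights=[.5, .5, 2., 1., 1.])
```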
- pykeops.torch.cluster.cluster_ranges(lab, Nlab=None)
Computes the [start,end) indices that specify clusters in a sorted point cloud.
If lab denotes a vector of labels \(\ell_i\in[0,C)\), sort_clusters() allows us to sort our point clouds and make sure that points that share the same label are stored next to each other in memory. cluster_ranges() is simply there to give you the slice indices that correspond to each of those \(C\) classes.
- Parameters:
lab ((M,) IntTensor) – Vector of class labels \(\ell_i\in\mathbb{N}\).
- Keyword Arguments:
Nlab ((C,) IntTensor, optional) – If you have computed it already, you may specify the number of points per class through this integer vector of length \(C\).
- Returns:
Stacked array of \([\text{start}_k, \text{end}_k )\) indices in \([0,M]\), for \(k\in[0,C)\).
- Return type:
(C,2) IntTensor
Example
>>> import torch
>>> from pykeops.torch.cluster import cluster_ranges, sort_clusters
>>> x = torch.Tensor( [ [0.], [5.], [.4], [.3], [2.] ])
>>> lab = torch.IntTensor([ 0, 2, 0, 0, 1 ])
>>> x_sorted, lab_sorted = sort_clusters(x, lab)
>>> print(x_sorted)
tensor([[0.0000],
        [0.4000],
        [0.3000],
        [2.0000],
        [5.0000]])
>>> print(lab_sorted)
tensor([0, 0, 0, 1, 2], dtype=torch.int32)
>>> ranges_i = cluster_ranges(lab)
>>> print( ranges_i )
tensor([[0, 3],
        [3, 4],
        [4, 5]], dtype=torch.int32)
--> cluster 0 = x_sorted[0:3, :]
--> cluster 1 = x_sorted[3:4, :]
--> cluster 2 = x_sorted[4:5, :]
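Since the slices describe the sorted layout, they can be derived from the number of points per class alone; this is consistent with the example, which calls cluster_ranges on the unsorted lab. The NumPy sketch below (the helper ranges_from_labels is hypothetical, not the pykeops source) counts points per class and takes a cumulative sum.

```python
import numpy as np

def ranges_from_labels(lab, Nlab=None):
    """Hypothetical helper: [start_k, end_k) slices of each class in the sorted layout."""
    lab = np.asarray(lab, dtype=np.int64)
    if Nlab is None:
        Nlab = np.bincount(lab)                 # number of points per class
    Nlab = np.asarray(Nlab)
    ends = np.cumsum(Nlab)                      # end_k = N_0 + ... + N_k
    starts = ends - Nlab                        # start_k = end_k - N_k
    return np.stack([starts, ends], axis=1).astype(np.int32)

r = ranges_from_labels([0, 2, 0, 0, 1])         # [[0, 3], [3, 4], [4, 5]]
```

Passing a precomputed Nlab skips the bincount, mirroring the role of the Nlab keyword argument.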
- pykeops.torch.cluster.cluster_ranges_centroids(x, lab, weights=None, min_weight=1e-09)
Computes the cluster indices and centroids of a (weighted) point cloud with labels.
If x and lab encode a cloud of points \(x_i\in\mathbb{R}^D\) with labels \(\ell_i\in[0,C)\), for \(i\in[0,M)\), this routine returns:
- Ranges \([\text{start}_k,\text{end}_k)\) compatible with sort_clusters(), for \(k\in[0,C)\).
- Centroids \(c_k\) for each cluster \(k\), computed as barycenters using the weights \(w_i \in \mathbb{R}_{>0}\):
\[c_k = \frac{\sum_{i, \ell_i=k} w_i\cdot x_i}{\sum_{i, \ell_i=k} w_i}.\]
- Total weights \(\sum_{i, \ell_i=k} w_i\), for \(k\in[0,C)\).
The weights \(w_i\) can be given through a vector weights of size \(M\), and are set by default to 1 for all points in the cloud.
- Parameters:
x ((M,D) Tensor) – List of points \(x_i \in \mathbb{R}^D\).
lab ((M,) IntTensor) – Vector of class labels \(\ell_i\in\mathbb{N}\).
- Keyword Arguments:
weights ((M,) Tensor) – Positive weights \(w_i\) that can be used to compute our barycenters.
min_weight (float) – For the sake of numerical stability, weights are clamped to be greater than or equal to this value.
- Returns:
ranges - Stacked array of \([\text{start}_k,\text{end}_k)\) indices in \([0,M]\), for \(k\in[0,C)\), compatible with the sort_clusters() routine.
centroids - List of centroids \(c_k \in \mathbb{R}^D\).
weights_c - Total weight of each cluster.
- Return type:
(C,2) IntTensor, (C,D) Tensor, (C,) Tensor
Example
>>> import torch
>>> from pykeops.torch.cluster import cluster_ranges_centroids
>>> x = torch.Tensor([ [0.], [.5], [1.], [2.], [3.] ])
>>> lab = torch.IntTensor([ 0, 0, 1, 1, 1 ])
>>> ranges, centroids, weights_c = cluster_ranges_centroids(x, lab)
>>> print(ranges)
tensor([[0, 2],
        [2, 5]], dtype=torch.int32)
--> cluster 0 = x[0:2, :]
--> cluster 1 = x[2:5, :]
>>> print(centroids)
tensor([[0.2500],
        [2.0000]])
>>> print(weights_c)
tensor([2., 3.])
>>> weights = torch.Tensor([ 1., .5, 1., 1., 10. ])
>>> ranges, centroids, weights_c = cluster_ranges_centroids(x, lab, weights=weights)
>>> print(ranges)
tensor([[0, 2],
        [2, 5]], dtype=torch.int32)
--> cluster 0 = x[0:2, :]
--> cluster 1 = x[2:5, :]
>>> print(centroids)
tensor([[0.1667],
        [2.7500]])
>>> print(weights_c)
tensor([1.5000, 12.0000])
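Because the returned ranges refer to the sorted layout produced by sort_clusters(), each cluster can be processed as one contiguous slice. The NumPy sketch below uses a hypothetical helper, centroids_via_ranges, and omits the min_weight clamping: it sorts by label with a stable sort, builds the ranges, then reduces each slice to a barycenter and a total weight, reproducing the weighted example above.

```python
import numpy as np

def centroids_via_ranges(x, lab, w):
    """Hypothetical helper: sort by label, build [start_k, end_k) ranges,
    then reduce each contiguous slice to a weighted barycenter and total weight."""
    x, lab, w = np.asarray(x, float), np.asarray(lab), np.asarray(w, float)
    order = np.argsort(lab, kind="stable")      # make clusters contiguous
    x_s, w_s = x[order], w[order]
    counts = np.bincount(lab)                   # points per cluster
    ends = np.cumsum(counts)
    ranges = np.stack([ends - counts, ends], axis=1)
    centroids, weights_c = [], []
    for start, end in ranges:                   # one contiguous slice per cluster
        wk = w_s[start:end]
        centroids.append((wk[:, None] * x_s[start:end]).sum(axis=0) / wk.sum())
        weights_c.append(wk.sum())
    return ranges, np.array(centroids), np.array(weights_c)

ranges, centroids, weights_c = centroids_via_ranges(
    [[0.], [.5], [1.], [2.], [3.]], [0, 0, 1, 1, 1], [1., .5, 1., 1., 10.])
```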
- pykeops.torch.cluster.from_matrix(ranges_i, ranges_j, keep)
Turns a boolean matrix into a KeOps-friendly ranges argument.
This routine is a helper for the block-sparse reduction mode of KeOps, allowing you to turn clustering information (ranges_i, ranges_j) and a cluster-to-cluster boolean mask (keep) into integer tensors of indices that can be used to schedule the KeOps routines.
Suppose that you’re working with variables \(x_i\) (\(i \in [0,10^6)\)), \(y_j\) (\(j \in [0,10^7)\)), and that you want to compute a KeOps reduction over indices \(i\) or \(j\): Instead of performing the full kernel dot product (\(10^6 \cdot 10^7 = 10^{13}\) operations!), you may want to restrict yourself to interactions between points \(x_i\) and \(y_j\) that are “close” to each other.
With KeOps, the simplest way of doing so is to:
1. Compute cluster labels for the \(x_i\)’s and \(y_j\)’s, using e.g. the grid_cluster() method.
2. Compute the ranges (ranges_i, ranges_j) and centroids associated to each cluster, using e.g. the cluster_ranges_centroids() method.
3. Sort the tensors x_i and y_j with sort_clusters() to make sure that the clusters are stored contiguously in memory (this step is critical for performance on GPUs).
   At this point:
   - the \(k\)-th cluster of \(x_i\)’s is given by x_i[ ranges_i[k,0]:ranges_i[k,1], : ], for \(k \in [0,M)\),
   - the \(\ell\)-th cluster of \(y_j\)’s is given by y_j[ ranges_j[l,0]:ranges_j[l,1], : ], for \(\ell \in [0,N)\).
4. Compute the \((M,N)\) matrix dist of pairwise distances between cluster centroids.
5. Apply a threshold on dist to generate a boolean matrix keep = dist < threshold.
6. Define a KeOps reduction my_genred = Genred(..., axis = 0 or 1), as usual.
7. Compute the block-sparse reduction through result = my_genred(x_i, y_j, ranges = from_matrix(ranges_i,ranges_j,keep) ).
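Steps 4 and 5, and the savings they buy, can be sketched in NumPy without calling KeOps (so steps 6 and 7 are left out). The helpers keep_mask and block_sparse_cost are hypothetical; the sketch thresholds squared Euclidean distances between centroids, then counts how many \((i,j)\) kernel evaluations remain when only the kept cluster pairs are computed.

```python
import numpy as np

def keep_mask(c_i, c_j, threshold):
    """Steps 4-5: (M, N) squared distances between cluster centroids, then a threshold."""
    c_i, c_j = np.asarray(c_i, float), np.asarray(c_j, float)  # (M, D), (N, D)
    dist = ((c_i[:, None, :] - c_j[None, :, :]) ** 2).sum(axis=-1)
    return dist < threshold

def block_sparse_cost(ranges_i, ranges_j, keep):
    """Number of (i, j) kernel evaluations if only cluster pairs with keep[k, l] are computed."""
    size_i = np.diff(np.asarray(ranges_i), axis=1)[:, 0]  # points per "i" cluster
    size_j = np.diff(np.asarray(ranges_j), axis=1)[:, 0]  # points per "j" cluster
    return int((np.outer(size_i, size_j) * keep).sum())

# Toy data: 2 "i" clusters (sizes 3, 5) and 3 "j" clusters (sizes 3, 5, 10)
ranges_i = np.array([[2, 5], [7, 12]])
ranges_j = np.array([[1, 4], [4, 9], [20, 30]])
keep = keep_mask([[1.], [0.]], [[1.5], [.5], [2.5]], threshold=1.)
sparse = block_sparse_cost(ranges_i, ranges_j, keep)               # 49
dense = block_sparse_cost(ranges_i, ranges_j, np.ones_like(keep))  # 144
```

Here the block-sparse scheme evaluates 49 of the 144 possible pairs; on large, well-clustered point clouds the ratio is typically far smaller.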
from_matrix() is thus the routine that turns a high-level description of your block-sparse computation (cluster ranges + boolean matrix) into a set of integer tensors (the ranges optional argument), used by KeOps to schedule computations on the GPU.
- Parameters:
ranges_i ((M,2) IntTensor) – List of \([\text{start}_k,\text{end}_k)\) indices. For \(k \in [0,M)\), the \(k\)-th cluster of “\(i\)” variables is given by x_i[ ranges_i[k,0]:ranges_i[k,1], : ], etc.
ranges_j ((N,2) IntTensor) – List of \([\text{start}_\ell,\text{end}_\ell)\) indices. For \(\ell \in [0,N)\), the \(\ell\)-th cluster of “\(j\)” variables is given by y_j[ ranges_j[l,0]:ranges_j[l,1], : ], etc.
keep ((M,N) BoolTensor) – If the output ranges of from_matrix() is used in a KeOps reduction, we will only compute and reduce the terms associated to pairs of “points” \(x_i\), \(y_j\) in clusters \(k\) and \(\ell\) if keep[k,l] == 1.
- Returns:
A 6-tuple of integer tensors that can be used as an optional ranges argument of torch.Genred. See the documentation of torch.Genred for reference.
Example
>>> import torch
>>> from pykeops.torch.cluster import from_matrix
>>> r_i = torch.IntTensor( [ [2,5], [7,12] ] )  # 2 clusters: X[0] = x_i[2:5], X[1] = x_i[7:12]
>>> r_j = torch.IntTensor( [ [1,4], [4,9], [20,30] ] )  # 3 clusters: Y[0] = y_j[1:4], Y[1] = y_j[4:9], Y[2] = y_j[20:30]
>>> x,y = torch.Tensor([1., 0.]), torch.Tensor([1.5, .5, 2.5])  # dummy "centroids"
>>> dist = (x[:,None] - y[None,:])**2
>>> keep = (dist <= 1)  # (2,3) matrix
>>> print(keep)
tensor([[ True,  True, False],
        [False,  True, False]])
--> X[0] interacts with Y[0] and Y[1], X[1] interacts with Y[1]
>>> (ranges_i,slices_i,redranges_j, ranges_j,slices_j,redranges_i) = from_matrix(r_i,r_j,keep)
--> (ranges_i,slices_i,redranges_j) will be used for reductions with respect to "j" (axis=1)
--> (ranges_j,slices_j,redranges_i) will be used for reductions with respect to "i" (axis=0)
Information relevant if axis = 1:
>>> print(ranges_i)  # = r_i
tensor([[ 2,  5],
        [ 7, 12]], dtype=torch.int32)
--> Two "target" clusters in a reduction wrt. j
>>> print(slices_i)
tensor([2, 3], dtype=torch.int32)
--> X[0] is associated to redranges_j[0:2]
--> X[1] is associated to redranges_j[2:3]
>>> print(redranges_j)
tensor([[1, 4],
        [4, 9],
        [4, 9]], dtype=torch.int32)
--> For X[0], i in [2,3,4], we'll reduce over j in [1,2,3] and [4,5,6,7,8]
--> For X[1], i in [7,8,9,10,11], we'll reduce over j in [4,5,6,7,8]
Information relevant if axis = 0:
>>> print(ranges_j)
tensor([[ 1,  4],
        [ 4,  9],
        [20, 30]], dtype=torch.int32)
--> Three "target" clusters in a reduction wrt. i
>>> print(slices_j)
tensor([1, 3, 3], dtype=torch.int32)
--> Y[0] is associated to redranges_i[0:1]
--> Y[1] is associated to redranges_i[1:3]
--> Y[2] is associated to redranges_i[3:3] = no one...
>>> print(redranges_i)
tensor([[ 2,  5],
        [ 2,  5],
        [ 7, 12]], dtype=torch.int32)
--> For Y[0], j in [1,2,3], we'll reduce over i in [2,3,4]
--> For Y[1], j in [4,5,6,7,8], we'll reduce over i in [2,3,4] and [7,8,9,10,11]
--> For Y[2], j in [20,21,...,29], there is no reduction to be done
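The bookkeeping shown in this example can be reproduced with a short NumPy sketch. The helper block_sparse_ranges is hypothetical and is not the pykeops source; it only illustrates the layout: for each row \(k\) of keep, the kept "\(j\)" ranges are listed in order and slices_i[k] stores the running count, and symmetrically for the columns of keep.

```python
import numpy as np

def block_sparse_ranges(ranges_i, ranges_j, keep):
    """Hypothetical helper: (ranges_i, slices_i, redranges_j, ranges_j, slices_j, redranges_i)."""
    ranges_i = np.asarray(ranges_i, dtype=np.int32)   # (M, 2)
    ranges_j = np.asarray(ranges_j, dtype=np.int32)   # (N, 2)
    keep = np.asarray(keep, dtype=bool)               # (M, N)

    # axis = 1: np.nonzero walks keep row by row, so kept "j" ranges are grouped by "i" cluster
    _, j_clusters = np.nonzero(keep)
    redranges_j = ranges_j[j_clusters]
    slices_i = np.cumsum(keep.sum(axis=1)).astype(np.int32)  # end offset of each group

    # axis = 0: same thing on the transposed mask, i.e. column by column
    _, i_clusters = np.nonzero(keep.T)
    redranges_i = ranges_i[i_clusters]
    slices_j = np.cumsum(keep.sum(axis=0)).astype(np.int32)

    return ranges_i, slices_i, redranges_j, ranges_j, slices_j, redranges_i

out = block_sparse_ranges([[2, 5], [7, 12]], [[1, 4], [4, 9], [20, 30]],
                          [[True, True, False], [False, True, False]])
```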
- pykeops.torch.cluster.grid_cluster(x, size)
Simplistic clustering algorithm which distributes points into cubic bins.
- Parameters:
x ((M,D) Tensor) – List of points \(x_i \in \mathbb{R}^D\).
size (float or (D,) Tensor) – Dimensions of the cubic cells (“voxels”).
- Returns:
Vector of integer labels. Two points x[i] and x[j] are in the same cluster if and only if labels[i] == labels[j]. Labels are sorted in a compact range \([0,C)\), where \(C\) is the number of non-empty cubic cells.
- Return type:
(M,) IntTensor
Example
>>> import torch
>>> from pykeops.torch.cluster import grid_cluster
>>> x = torch.Tensor([ [0.], [.1], [.9], [.05], [.5] ])  # points in the unit interval
>>> labels = grid_cluster(x, .2)  # bins of size .2
>>> print( labels )
tensor([0, 0, 2, 0, 1], dtype=torch.int32)
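The binning rule can be sketched in a few lines of NumPy: floor each coordinate divided by size to get integer voxel coordinates, then relabel the distinct non-empty voxels with compact integers. The helper grid_labels is hypothetical; with np.unique the labels follow the lexicographic order of voxel coordinates, which may not match the label order KeOps produces when \(D > 1\).

```python
import numpy as np

def grid_labels(x, size):
    """Hypothetical helper: bin points into cubic cells of side `size` (float or (D,) array),
    then relabel the non-empty cells with compact integers in [0, C)."""
    x = np.asarray(x, dtype=float)
    cells = np.floor(x / size).astype(np.int64)   # (M, D) integer voxel coordinates
    _, labels = np.unique(cells, axis=0, return_inverse=True)
    return labels.reshape(-1).astype(np.int32)    # one label per point

labels = grid_labels([[0.], [.1], [.9], [.05], [.5]], .2)  # cells 0, 0, 4, 0, 2
```

Only three of the five cells of width .2 are occupied (cells 0, 2 and 4), so they are renamed 0, 1 and 2, matching the example above.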
- pykeops.torch.cluster.sort_clusters(x, lab)
Sorts a list of points and labels to make sure that the clusters are contiguous in memory.
On the GPU, contiguous memory accesses are key to high performance. By making sure that points in the same cluster are stored next to each other in memory, this pre-processing routine allows KeOps to compute block-sparse reductions with maximum efficiency.
Warning
For unknown reasons, torch.bincount is much more efficient on unsorted arrays of labels… so make sure not to call bincount on the output of this routine!
- Parameters:
x ((M,D) Tensor or tuple/list of (M,..) Tensors) – List of points \(x_i \in \mathbb{R}^D\).
lab ((M,) IntTensor) – Vector of class labels \(\ell_i\in\mathbb{N}\).
- Returns:
Sorted point cloud(s) and vector of labels.
- Return type:
(M,D) Tensor or tuple/list of (M,..) Tensors, (M,) IntTensor
Example
>>> import torch
>>> from pykeops.torch.cluster import sort_clusters
>>> x = torch.Tensor( [ [0.], [5.], [.4], [.3], [2.] ])
>>> lab = torch.IntTensor([ 0, 2, 0, 0, 1 ])
>>> x_sorted, lab_sorted = sort_clusters(x, lab)
>>> print(x_sorted)
tensor([[0.0000],
        [0.4000],
        [0.3000],
        [2.0000],
        [5.0000]])
>>> print(lab_sorted)
tensor([0, 0, 0, 1, 2], dtype=torch.int32)
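A stable argsort of the labels is enough to make clusters contiguous. The NumPy sketch below uses a hypothetical helper, sort_by_label, which also accepts a tuple of arrays sharing the first dimension \(M\), mirroring the tuple/list form of x. Its stable sort reproduces the within-cluster order of the example above; the sketch makes no claim about which sort KeOps uses internally.

```python
import numpy as np

def sort_by_label(x, lab):
    """Hypothetical helper: reorder points (or a tuple of (M, ...) arrays) so that
    equal labels are contiguous; the stable sort keeps the original order within a cluster."""
    lab = np.asarray(lab)
    order = np.argsort(lab, kind="stable")
    if isinstance(x, tuple):
        x_sorted = tuple(np.asarray(t)[order] for t in x)  # same permutation for every array
    else:
        x_sorted = np.asarray(x)[order]
    return x_sorted, lab[order]

x = np.array([[0.], [5.], [.4], [.3], [2.]])
x_sorted, lab_sorted = sort_by_label(x, [0, 2, 0, 0, 1])
```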