Backpropagation

Last but not least, KeOps fully supports automatic differentiation. Most of the magic required is implemented by the F::DiffT attributes of KeOps formulas and reductions, as discussed in previous pages.
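Concretely, here is a minimal sketch (with arbitrary sizes and a Gaussian convolution chosen for the example) of what this support looks like from the user's side: a Genred call can be differentiated with the standard PyTorch backward machinery.

```python
import torch
from pykeops.torch import Genred

# a_i = sum_j exp(-|x_i - y_j|^2) * b_j, an (M, 2) output indexed by i:
gauss_conv = Genred(
    "Exp(-SqDist(x, y)) * b",                 # formula F(x_i, y_j, b_j)
    ["x = Vi(3)", "y = Vj(3)", "b = Vj(2)"],  # variable categories and dimensions
    reduction_op="Sum",
    axis=1,                                   # reduce over the j index
)

M, N = 1000, 2000
x = torch.randn(M, 3, requires_grad=True)
y = torch.randn(N, 3, requires_grad=True)
b = torch.randn(N, 2, requires_grad=True)

a = gauss_conv(x, y, b)  # (M, 2) result of the Sum reduction
a.sum().backward()       # gradients flow back through the KeOps call
print(x.grad.shape, y.grad.shape, b.grad.shape)  # (M, 3), (N, 3), (N, 2)
```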

Backprop through a Sum reduction

To implement the PyTorch backward of the KeOps Genred operator, we simply have to remember that if \((g_i) \in \mathbb{R}^{\mathrm{M}\times \mathrm{E}}\) is a “gradient to backpropagate” with respect to the output \((a_i) \in \mathbb{R}^{\mathrm{M}\times \mathrm{E}}\) of a Genred call with a Sum reduction, then for all variations \((\delta p,\delta x_i, \delta y_j)\) of the parameters, \(i\)- and \(j\)-variables, at order 1:

\[\begin{split}\begin{aligned} \Big\langle& \sum_{j=1}^\mathrm{N} \big[ F(p+\delta p, x_i+\delta x_i, y_j + \delta y_j) - F(p, x_i, y_j) \big]~,~ g_i \Big\rangle_{\mathbb{R}^{\mathrm{M}\times \mathrm{E}}}\\ ~=~& \sum_{i=1}^\mathrm{M} \sum_{j=1}^\mathrm{N} \Big( \left\langle \partial_p F(\dots) \cdot g_i, \delta p \right\rangle \,+\, \left\langle \partial_{x_i} F(\dots) \cdot g_i, \delta x_i \right\rangle \,+\, \langle \partial_{y_j} F(\dots) \cdot g_i, \delta y_j \rangle \Big).\end{aligned}\end{split}\]

Consequently, performing the appropriate permutations of the sums, we get:

\[\begin{split}\begin{aligned} \partial_{x_i} \Big[ \sum_{j=1}^\mathrm{N} F(p,x_i,y_j)\Big] \cdot (g_i) ~&=~ \phantom{\sum_{i=1}^\mathrm{M} } \sum_{j=1}^\mathrm{N} \Big( \partial_{x_i} \Big[ F(p,x_i,y_j)\Big] \cdot g_i \Big), \\ \partial_{y_j} \Big[ \sum_{j=1}^\mathrm{N} F(p,x_i,y_j)\Big] \cdot (g_i) ~&=~ \phantom{\sum_{i=1}^\mathrm{M} } \sum_{i=1}^\mathrm{M} \Big( \partial_{y_j} \Big[ F(p,x_i,y_j)\Big] \cdot g_i \Big), \\ \partial_{p} \Big[ \sum_{j=1}^\mathrm{N} F(p,x_i,y_j)\Big] \cdot (g_i) ~&=~ \sum_{i=1}^\mathrm{M} \sum_{j=1}^\mathrm{N} \Big( \partial_{p} \Big[ F(p,x_i,y_j)\Big] \cdot g_i \Big).\end{aligned}\end{split}\]
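As a sanity check, these three identities can be verified numerically with plain PyTorch (no KeOps involved); the snippet below is an illustrative sketch, with an arbitrary toy formula F and arbitrary sizes.

```python
import torch

M, N, D = 5, 7, 3  # toy sizes; here the output dimension E equals D
p = torch.randn((), requires_grad=True)
x = torch.randn(M, D, requires_grad=True)
y = torch.randn(N, D, requires_grad=True)
g = torch.randn(M, D)  # "gradient to backpropagate" w.r.t. the output (a_i)

def F(p, x, y):
    # Arbitrary toy formula F(p, x_i, y_j) = p * <x_i, y_j> * y_j, shape (M, N, D):
    return p * (x[:, None, :] * y[None, :, :]).sum(-1, keepdim=True) * y[None, :, :]

a = F(p, x, y).sum(dim=1)  # Sum reduction over j, shape (M, D)
grad_p, grad_x, grad_y = torch.autograd.grad((a * g).sum(), (p, x, y))

# Third identity: the gradient w.r.t. the parameter p is the double sum over
# i and j of <d_p F(p, x_i, y_j), g_i>.
dF_dp = (x[:, None, :] * y[None, :, :]).sum(-1, keepdim=True) * y[None, :, :]
assert torch.allclose(grad_p, (dF_dp * g[:, None, :]).sum())
# grad_x and grad_y have shapes (M, D) and (N, D): single sums over j and i
# respectively, as in the first two identities.
```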

Backprop through a Log-Sum-Exp reduction

Similarly, when \((a_i)\) is given through a Log-Sum-Exp reduction:

\[\begin{aligned} a_i~=~ \log \sum_{j=1}^\mathrm{N} \exp F(p,x_i,y_j),\end{aligned}\]

straightforward computations show that:

\[\begin{split}\begin{aligned} \partial_{x_i} \Big[ \log \sum_{j=1}^\mathrm{N} \exp F(p,x_i,y_j)\Big] \cdot (g_i) ~&=~ \phantom{\sum_{i=1}^\mathrm{M} } \sum_{j=1}^\mathrm{N} e^{F(p,x_i,y_j) - a_i}\cdot \Big( \partial_{x_i} \Big[ F(p,x_i,y_j)\Big] \cdot g_i\Big), \\ \partial_{y_j} \Big[ \log \sum_{j=1}^\mathrm{N} \exp F(p,x_i,y_j)\Big] \cdot (g_i) ~&=~ \phantom{\sum_{i=1}^\mathrm{M} } \sum_{i=1}^\mathrm{M} e^{F(p,x_i,y_j) - a_i}\cdot \Big( \partial_{y_j} \Big[ F(p,x_i,y_j)\Big] \cdot g_i\Big), \\ \partial_{p} \Big[ \log \sum_{j=1}^\mathrm{N} \exp F(p,x_i,y_j)\Big] \cdot (g_i) ~&=~ \sum_{i=1}^\mathrm{M} \sum_{j=1}^\mathrm{N} e^{F(p,x_i,y_j) - a_i}\cdot \Big( \partial_{p} \Big[ F(p,x_i,y_j)\Big] \cdot g_i\Big).\end{aligned}\end{split}\]
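Once again, this can be checked numerically with plain PyTorch on a scalar-valued toy formula; the snippet below is an illustrative sketch, with arbitrary choices for F and the tensor sizes.

```python
import torch

M, N, D = 5, 7, 3
p = torch.rand((), requires_grad=True)
x = torch.randn(M, D, requires_grad=True)
y = torch.randn(N, D, requires_grad=True)
g = torch.randn(M)  # gradient to backpropagate w.r.t. (a_i)

# Arbitrary scalar-valued toy formula F(p, x_i, y_j) = -p * |x_i - y_j|^2:
F = -p * ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # shape (M, N)
a = torch.logsumexp(F, dim=1)                            # shape (M,)
(grad_y,) = torch.autograd.grad((a * g).sum(), (y,))

# Second identity: softmax weights exp(F - a_i) times d_{y_j} F . g_i, summed over i.
w = torch.exp(F - a[:, None]) * g[:, None]               # (M, N)
dF_dy = 2 * p * (x[:, None, :] - y[None, :, :])          # d_{y_j} F, shape (M, N, D)
assert torch.allclose(grad_y, (w[:, :, None] * dF_dy).sum(dim=0), atol=1e-6)
```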

In other words, a backward pass through a Genred call that involves a Sum or a Log-Sum-Exp reduction can always be written as a symbolic Map-Reduce computation.
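For instance, assuming the Grad(...) operation of the KeOps generic syntax, the backward of the Gaussian convolution sketched above with respect to x can be encoded as a single Genred call, with the “gradient to backpropagate” g passed as an extra \(i\)-variable:

```python
import torch
from pykeops.torch import Genred

M, N = 1000, 2000
x, y, b = torch.randn(M, 3), torch.randn(N, 3), torch.randn(N, 2)
g = torch.randn(M, 2)  # gradient to backpropagate w.r.t. the (M, 2) output

grad_x_conv = Genred(
    "Grad(Exp(-SqDist(x, y)) * b, x, g)",                  # d_{x_i} F . g_i
    ["x = Vi(3)", "y = Vj(3)", "b = Vj(2)", "g = Vi(2)"],
    reduction_op="Sum",
    axis=1,                                                 # sum over the j index
)
dx = grad_x_conv(x, y, b, g)  # (M, 3): the values that a.backward(g) would write into x.grad
```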

Bootstrapping derivatives of arbitrary order

By applying these commutation rules between the differential operator \(\partial_\texttt{V}\) and the Sum or Log-Sum-Exp reductions, the pykeops/torch/generic/generic_red.py module provides full compatibility between KeOps LazyTensors and the torch.autograd package. Thanks to recursive calls to the Genred operator and to our symbolic math engine, everything works just fine, even for high-order derivatives.
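As a closing illustration, here is a minimal sketch (with arbitrary sizes and a Gaussian kernel chosen for the example) of a second-order derivative computed through a LazyTensor reduction, using nothing but the standard torch.autograd.grad API:

```python
import torch
from pykeops.torch import LazyTensor

M, N = 1000, 2000
x = torch.randn(M, 3, requires_grad=True)
y = torch.randn(N, 3)

x_i = LazyTensor(x[:, None, :])                 # (M, 1, 3) symbolic tensor
y_j = LazyTensor(y[None, :, :])                 # (1, N, 3) symbolic tensor
K_ij = (-((x_i - y_j) ** 2).sum(-1)).exp()      # symbolic (M, N) Gaussian kernel
a = K_ij.sum(dim=1)                             # (M, 1) Sum reduction over j

# First-order gradient, with a graph that can be differentiated again:
[grad_x] = torch.autograd.grad(a.sum(), [x], create_graph=True)
# Second-order: differentiate the first gradient once more.
[grad2_x] = torch.autograd.grad(grad_x.sum(), [x])
print(grad_x.shape, grad2_x.shape)              # (M, 3), (M, 3)
```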