Backpropagation
Last but not least, KeOps fully supports automatic differentiation.
Most of the magic required is implemented by the F::DiffT
attributes of KeOps formulas and reductions, as discussed in
previous pages.
Backprop through a Sum reduction
To implement the PyTorch backward
of the KeOps Genred
operator, we simply have to remember that if
\((g_i) \in \mathbb{R}^{\mathrm{M}\times \mathrm{E}}\) is a “gradient to backpropagate” with
respect to the output \((a_i) \in \mathbb{R}^{\mathrm{M}\times \mathrm{E}}\)
of a Genred
call with a Sum reduction, then for all variations
\((\delta p,\delta x_i, \delta y_j)\) of the parameters, \(i\)-
and \(j\)-variables, we have, at order 1:
\[\begin{split}\begin{aligned}
\Big\langle&
\sum_{j=1}^\mathrm{N} F(p+\delta p, x_i+\delta x_i, y_j + \delta y_j)
- F(p, x_i, y_j)~,~
g_i
\Big\rangle_{\mathbb{R}^{\mathrm{M}\times \mathrm{E}}}\\
~=~&
\sum_{i=1}^\mathrm{M} \sum_{j=1}^\mathrm{N}
\Big(
\left\langle \partial_p F(\dots) \cdot g_i, \delta p \right\rangle
\,+\,
\left\langle \partial_{x_i} F(\dots) \cdot g_i, \delta x_i \right\rangle
\,+\,
\langle \partial_{y_j} F(\dots) \cdot g_i, \delta y_j \rangle
\Big).\end{aligned}\end{split}\]
Consequently, after performing the appropriate permutations of sums, we get:
\[\begin{split}\begin{aligned}
\partial_{x_i}
\Big[ \sum_{j=1}^\mathrm{N} F(p,x_i,y_j)\Big] \cdot (g_i)
~&=~
\phantom{\sum_{i=1}^\mathrm{M} }
\sum_{j=1}^\mathrm{N} \Big(
\partial_{x_i}
\Big[ F(p,x_i,y_j)\Big] \cdot g_i \Big), \\
\partial_{y_j}
\Big[ \sum_{j=1}^\mathrm{N} F(p,x_i,y_j)\Big] \cdot (g_i)
~&=~
\phantom{\sum_{i=1}^\mathrm{M} }
\sum_{i=1}^\mathrm{M} \Big(
\partial_{y_j}
\Big[ F(p,x_i,y_j)\Big] \cdot g_i \Big), \\
\partial_{p}
\Big[ \sum_{j=1}^\mathrm{N} F(p,x_i,y_j)\Big] \cdot (g_i)
~&=~
\sum_{i=1}^\mathrm{M} \sum_{j=1}^\mathrm{N} \Big(
\partial_{p}
\Big[ F(p,x_i,y_j)\Big] \cdot g_i \Big).\end{aligned}\end{split}\]
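As a quick sanity check, here is a small PyTorch-only sketch (no KeOps involved) that verifies these three identities numerically. The formula \(F(p, x_i, y_j) = \exp(-p\,\|x_i - y_j\|^2)\), the shapes and the variable names are purely illustrative.

```python
import torch

# Toy check of the three Sum-reduction identities above, using the
# illustrative formula F(p, x_i, y_j) = exp(-p * |x_i - y_j|^2), so E = 1.
M, N, D = 5, 7, 3
torch.manual_seed(0)
p = torch.rand(1, requires_grad=True)
x = torch.randn(M, D, requires_grad=True)
y = torch.randn(N, D, requires_grad=True)
g = torch.randn(M, 1)  # "gradient to backpropagate" w.r.t. the output (a_i)

diff = x[:, None, :] - y[None, :, :]   # (M, N, D),  x_i - y_j
sqdist = (diff ** 2).sum(-1)           # (M, N),    |x_i - y_j|^2
F = (-p * sqdist).exp()                # (M, N),    F(p, x_i, y_j)
a = F.sum(dim=1, keepdim=True)         # (M, 1),    Sum reduction over j

# Left-hand sides, as computed by torch.autograd:
grad_p, grad_x, grad_y = torch.autograd.grad(a, [p, x, y], grad_outputs=g)

# Right-hand sides, with the sums permuted by hand:
dF_dx = -2 * p * diff * F[:, :, None]  # ∂_{x_i} F(p, x_i, y_j)
dF_dy = 2 * p * diff * F[:, :, None]   # ∂_{y_j} F(p, x_i, y_j)
dF_dp = -sqdist * F                    # ∂_p     F(p, x_i, y_j)

print(torch.allclose(grad_x, (dF_dx * g[:, :, None]).sum(dim=1), atol=1e-5))  # sum over j
print(torch.allclose(grad_y, (dF_dy * g[:, :, None]).sum(dim=0), atol=1e-5))  # sum over i
print(torch.allclose(grad_p, (dF_dp * g).sum().reshape(1), atol=1e-5))        # sum over i and j
```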
Backprop through a Log-Sum-Exp reduction
Similarly, when \((a_i)\) is given through a Log-Sum-Exp
reduction:
\[\begin{aligned}
a_i~=~ \log \sum_{j=1}^\mathrm{N} \exp F(p,x_i,y_j),\end{aligned}\]
straightforward computations show that:
\[\begin{split}\begin{aligned}
\partial_{x_i}
\Big[ \log \sum_{j=1}^\mathrm{N} \exp F(p,x_i,y_j)\Big] \cdot (g_i)
~&=~
\phantom{\sum_{i=1}^\mathrm{M} }
\sum_{j=1}^\mathrm{N}
e^{F(p,x_i,y_j) - a_i}\cdot
\Big(
\partial_{x_i}
\Big[ F(p,x_i,y_j)\Big] \cdot g_i\Big), \\
\partial_{y_j}
\Big[ \log \sum_{j=1}^\mathrm{N} \exp F(p,x_i,y_j)\Big] \cdot (g_i)
~&=~
\phantom{\sum_{i=1}^\mathrm{M} }
\sum_{i=1}^\mathrm{M}
e^{F(p,x_i,y_j) - a_i}\cdot
\Big(
\partial_{y_j}
\Big[ F(p,x_i,y_j)\Big] \cdot g_i\Big), \\
\partial_{p}
\Big[ \log \sum_{j=1}^\mathrm{N} \exp F(p,x_i,y_j)\Big] \cdot (g_i)
~&=~
\sum_{i=1}^\mathrm{M} \sum_{j=1}^\mathrm{N}
e^{F(p,x_i,y_j) - a_i}\cdot
\Big(
\partial_{p}
\Big[ F(p,x_i,y_j)\Big] \cdot g_i\Big).\end{aligned}\end{split}\]
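The Log-Sum-Exp case can be checked in the same spirit. The sketch below uses the illustrative formula \(F(p, x_i, y_j) = -p\,\|x_i - y_j\|^2\) and verifies the \(x_i\) identity, with its softmax weights \(e^{F(p,x_i,y_j) - a_i}\).

```python
import torch

# Toy check of the Log-Sum-Exp identity for x_i, with the illustrative
# formula F(p, x_i, y_j) = -p * |x_i - y_j|^2.
M, N, D = 5, 7, 3
torch.manual_seed(0)
p = torch.rand(1, requires_grad=True)
x = torch.randn(M, D, requires_grad=True)
y = torch.randn(N, D)
g = torch.randn(M, 1)  # gradient to backpropagate w.r.t. (a_i)

diff = x[:, None, :] - y[None, :, :]   # (M, N, D)
F = -p * (diff ** 2).sum(-1)           # (M, N), F(p, x_i, y_j)
a = F.logsumexp(dim=1, keepdim=True)   # (M, 1), Log-Sum-Exp reduction over j

(grad_x,) = torch.autograd.grad(a, [x], grad_outputs=g)

# Softmax weights exp(F(p, x_i, y_j) - a_i), then the sum over j:
weights = (F - a).exp()                # (M, N)
dF_dx = -2 * p * diff                  # ∂_{x_i} F(p, x_i, y_j)
grad_x_manual = (weights[:, :, None] * dF_dx * g[:, :, None]).sum(dim=1)

print(torch.allclose(grad_x, grad_x_manual, atol=1e-5))
```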
In other words, a backward pass through a Genred
call that involves
a Sum or a Log-Sum-Exp reduction can
always be written as a symbolic Map-Reduce computation.
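To make this concrete, here is a LazyTensor sketch (assuming pykeops is installed; sizes and the Gaussian formula are illustrative): the gradient of a Gaussian convolution \(a_i = \sum_j \exp(-\|x_i - y_j\|^2)\, b_j\), obtained through torch.autograd, is compared with the same gradient written out by hand as a second symbolic reduction over \(j\).

```python
import torch
from pykeops.torch import LazyTensor

# Gradient of a Gaussian convolution a_i = sum_j exp(-|x_i - y_j|^2) b_j,
# computed by torch.autograd and re-derived by hand as a Map-Reduce.
M, N, D, E = 100, 200, 3, 2
x = torch.randn(M, D, requires_grad=True)
y = torch.randn(N, D)
b = torch.randn(N, E)
g = torch.randn(M, E)  # gradient to backpropagate w.r.t. (a_i)

x_i = LazyTensor(x[:, None, :])   # symbolic (M, 1, D)
y_j = LazyTensor(y[None, :, :])   # symbolic (1, N, D)
b_j = LazyTensor(b[None, :, :])   # symbolic (1, N, E)
g_i = LazyTensor(g[:, None, :])   # symbolic (M, 1, E)

K_ij = (-((x_i - y_j) ** 2).sum(-1)).exp()   # symbolic Gaussian kernel
a = (K_ij * b_j).sum(dim=1)                  # (M, E), Sum reduction over j

# Backward pass, handled by torch.autograd and KeOps:
(grad_x,) = torch.autograd.grad(a, [x], grad_outputs=g)

# The same quantity, spelled out as a symbolic reduction over j:
grad_x_manual = (-2 * K_ij * (b_j * g_i).sum(-1) * (x_i - y_j)).sum(dim=1)

print(torch.allclose(grad_x, grad_x_manual, atol=1e-4))
```

The hand-written gradient goes through the same symbolic Map-Reduce engine as the forward pass, which is precisely the property stated above.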
Bootstrapping derivatives of arbitrary order
By applying these commutation rules between the differential operator
\(\partial_\texttt{V}\) and the Sum and Log-Sum-Exp reductions, the
pykeops/torch/generic/generic_red.py
module provides full
compatibility between KeOps LazyTensors
and the
torch.autograd
package. Thanks to recursive calls to the Genred
operator and to our
symbolic math engine, everything works out of the box, including
derivatives of arbitrary order.
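For instance, under the same assumptions as above (pykeops installed, illustrative sizes and formula), passing create_graph=True to torch.autograd.grad lets us differentiate a KeOps reduction a second time, e.g. to compute a Hessian-vector product:

```python
import torch
from pykeops.torch import LazyTensor

# Higher-order derivatives through a KeOps reduction: differentiate once
# with create_graph=True, then differentiate the result again.
M, N, D = 100, 200, 3
x = torch.randn(M, D, requires_grad=True)
y = torch.randn(N, D, requires_grad=True)

x_i = LazyTensor(x[:, None, :])
y_j = LazyTensor(y[None, :, :])
K_ij = (-((x_i - y_j) ** 2).sum(-1)).exp()

# Scalar energy E(x, y) = sum_i sum_j exp(-|x_i - y_j|^2):
E = K_ij.sum(dim=1).sum()

# First-order derivative, kept in the graph for further differentiation:
(grad_x,) = torch.autograd.grad(E, [x], create_graph=True)

# Second-order derivative: the Hessian-vector product ∂_y <grad_x, v>:
v = torch.randn_like(x)
(hvp_y,) = torch.autograd.grad((grad_x * v).sum(), [y])
print(grad_x.shape, hvp_y.shape)  # torch.Size([100, 3]) torch.Size([200, 3])
```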