Reductions
Following the same design principles, \(\operatorname{Reduction}\) operators are implemented in the keops/core/reductions/*.h headers. Taking as input an arbitrary symbolic formula F, Reduction templates encode generic Map-Reduce schemes and should implement a few standard routines.
Summation
In the case of the simple Sum reduction (Sum_Reduction.h header), these can be described as:
An InitializeReduction method, which fills up the running buffer “\(a\)” of our Map-Reduce algorithm (a vector of size F::DIM) with zeros before the start of the loop on the reduction index \(j\).
A ReducePair method, which takes as input a pointer to the running buffer \(a\), a pointer to the result \(F_{i,j} = F(p^1,\dots,x^1_i,\dots,y^1_j,\dots)\) and implements the in-place reduction:
\[\begin{aligned} a~\gets~a~+~F_{i,j}. \end{aligned}\]
A FinalizeOutput method, which post-processes the buffer \(a\) before saving its value in the output array. This is a useful step for argmin-like reductions; but in the case of the sum, no post-processing is needed.
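The three routines above can be sketched in plain Python to show how they fit into a Map-Reduce loop. This is only an illustrative analogue: the class `SumReduction`, the driver `map_reduce` and the toy `formula` are hypothetical names introduced here, and they do not reproduce the actual C++ templates of Sum_Reduction.h.

```python
# Illustrative Python analogue of a Sum reduction.
# All names here are hypothetical; the real code is a C++ template.

class SumReduction:
    def __init__(self, dim):
        self.dim = dim  # plays the role of F::DIM

    def initialize_reduction(self):
        # Running buffer "a": a vector of size dim, filled with zeros.
        return [0.0] * self.dim

    def reduce_pair(self, a, f_ij):
        # In-place update: a <- a + F_ij.
        for k in range(self.dim):
            a[k] += f_ij[k]

    def finalize_output(self, a):
        # A plain sum needs no post-processing: copy the buffer as is.
        return list(a)


def map_reduce(reduction, formula, xs, ys):
    """For each index i, reduce formula(x_i, y_j) over the index j."""
    out = []
    for x_i in xs:
        a = reduction.initialize_reduction()
        for y_j in ys:
            reduction.reduce_pair(a, formula(x_i, y_j))
        out.append(reduction.finalize_output(a))
    return out


def formula(x_i, y_j):
    # Toy vector-valued formula of dimension 2 (an assumption for the demo).
    return [x_i * y_j, x_i + y_j]


result = map_reduce(SumReduction(dim=2), formula, [1.0, 2.0], [10.0, 20.0, 30.0])
```

Keeping the three routines separate from the loop means that the same driver can, in principle, be reused with other reductions that only change how the buffer is initialized, updated and finalized.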
The online Log-Sum-Exp trick
More interestingly, the Max_SumShiftExp_Reduction.h header implements an online version of the well-known Log-Sum-Exp trick: a factorization of the maximum in the computation of
\[\begin{aligned} \log \sum_{j} \exp(F_{i,j}) ~=~ m_i + \log \sum_{j} \exp(F_{i,j} - m_i), \qquad \text{with}~~ m_i = \max_{j} F_{i,j}, \end{aligned}\]
that ensures the computation of this important quantity (the linchpin of maximum likelihood estimators and entropic Optimal Transport solvers) without numeric overflows.
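To see why factoring out the maximum matters, here is a small Python comparison between a naive Log-Sum-Exp and the max-shifted version. The function names are hypothetical and this is the classical two-pass form of the trick, not the online scheme implemented in the header.

```python
import math

def logsumexp_naive(values):
    # Direct evaluation: exp() overflows for large inputs.
    return math.log(sum(math.exp(v) for v in values))

def logsumexp_stable(values):
    # Factor out the maximum m, so that every exponent is <= 0.
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

values = [1000.0, 1000.0]

try:
    logsumexp_naive(values)
    naive_overflowed = False
except OverflowError:
    # math.exp(1000.0) exceeds the range of double-precision floats.
    naive_overflowed = True

stable = logsumexp_stable(values)  # equals 1000 + log(2)
```

The stable version needs the maximum before the summation starts, which means two passes over the data. The online variant removes this constraint by updating the maximum on the fly.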
Combining the content of our C++ header with the Python post-processing step implemented in pykeops/common/operations.py, and assuming that \(F_{i,j} = F(p^1,\dots,x^1_i,\dots,y^1_j,\dots)\) is a scalar quantity, we may describe the behaviour of this reduction as follows:
The InitializeReduction method ensures that our running buffer \(a\) is a vector of size 2 that encodes the current value of the inner summation as an explicit (exponent, mantissa) or “(maximum, residual)” pair of float numbers: at any stage of the computation, the pair \((m,r)\) encodes the positive number \(e^{m}\cdot r\) with the required precision. We initially set the value of \(a\) to \((-\infty, 0)\simeq e^{-\infty}\cdot 0\).
The ReducePair method takes as input a pointer to the result \(F_{i,j}\) of the computation, a pointer to the running buffer \(a = (m, r) \simeq e^m\cdot r\) and implements the in-place update:
\[\begin{split}\begin{aligned} (m,r) ~\gets~ \begin{cases} \big( ~m~, ~\,r + \phantom{r\cdot{}} e^{F_{i,j} - m} \big) & \text{if}~ m \geqslant F_{i,j}\\ \big( F_{i,j},~ 1 + r \cdot e^{m - F_{i,j}} \big) & \text{otherwise.} \end{cases} \end{aligned}\end{split}\]
This is a numerically stable way of writing the sum reduction:
\[\begin{split}\begin{aligned} e^m \cdot r ~\gets~ e^m\cdot r \, +\, e^{F_{i,j}} ~=~ \begin{cases} ~e^m~\cdot(\,r+ \phantom{r\cdot{}} e^{F_{i,j}-m}) & \text{if}~ m \geqslant F_{i,j}\\ e^{F_{i,j}}\cdot (1 + r\cdot e^{m-F_{i,j}}) & \text{otherwise.} \end{cases} \end{aligned}\end{split}\]
FinalizeOutput post-processes the buffer \(a = (m,r) \simeq e^{m}\cdot r\) by applying the final “\(\log\)” operation, returning a value of \(m+\log(r)\) for the full reduction.
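The (maximum, residual) update described above can be sketched in a few lines of Python. The function names mirror the three routines but are hypothetical Python stand-ins, not the C++ or PyKeOps code, and the sketch assumes that every \(F_{i,j}\) is a finite scalar.

```python
import math

def initialize_reduction():
    # Buffer a = (m, r), encoding exp(m) * r; starts at (-inf, 0).
    return [-math.inf, 0.0]

def reduce_pair(a, f_ij):
    # In-place, single-pass update of the (maximum, residual) pair.
    m, r = a
    if m >= f_ij:
        # Maximum unchanged: add the new term, rescaled by exp(-m).
        a[1] = r + math.exp(f_ij - m)
    else:
        # New maximum: rescale the old residual by exp(m - f_ij) <= 1.
        a[0] = f_ij
        a[1] = 1.0 + r * math.exp(m - f_ij)

def finalize_output(a):
    # Final "log": log(exp(m) * r) = m + log(r).
    m, r = a
    return m + math.log(r)

def online_logsumexp(values):
    a = initialize_reduction()
    for f_ij in values:
        reduce_pair(a, f_ij)
    return finalize_output(a)

# Values large enough that a direct exp() would overflow.
result = online_logsumexp([1000.0, 999.0, 1001.0])
expected = 1001.0 + math.log(1.0 + math.exp(-1.0) + math.exp(-2.0))
```

Every call to `exp` receives a non-positive argument, so no intermediate value can overflow, and the data is traversed only once. Inputs equal to \(-\infty\) would need separate handling, since the difference \(F_{i,j} - m\) is then undefined on the first update.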