Reductions

Following the same design principles, Reduction operators are implemented in the keops/core/reductions/*.h headers. Taking as input an arbitrary symbolic formula F, Reduction templates encode generic Map-Reduce schemes and should implement a few standard routines.

Summation

In the case of the simple Sum reduction (Sum_Reduction.h header), these can be described as:

  1. An InitializeReduction method, which fills up the running buffer “a” of our Map-Reduce algorithm – a vector of size F::DIM – with zeros before the start of the loop on the reduction index j.

  2. A ReducePair method, which takes as input a pointer to the running buffer a, a pointer to the result $F_{i,j} = F(p^1, \dots, x^1_i, \dots, y^1_j, \dots)$ and implements the in-place reduction:

    $$a \gets a + F_{i,j}.$$

  3. A FinalizeOutput method, which post-processes the buffer a before saving its value in the output array. This is a useful step for argmin-like reductions; but in the case of the sum, no post-processing is needed.
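
The three routines above can be sketched in a few lines of plain Python. This is only an illustration of the Map-Reduce pattern, not the C++ template code from Sum_Reduction.h: the function names (sum_initialize, sum_reduce_pair, sum_finalize, sum_reduction) and the toy formula are hypothetical, and the formula is assumed to return a list of `dim` floats.

```python
from typing import Callable, List, Sequence


def sum_initialize(dim: int) -> List[float]:
    """Hypothetical analogue of InitializeReduction: a zero buffer of size dim."""
    return [0.0] * dim


def sum_reduce_pair(a: List[float], f_ij: Sequence[float]) -> None:
    """Hypothetical analogue of ReducePair: in-place update a <- a + F_ij."""
    for k, value in enumerate(f_ij):
        a[k] += value


def sum_finalize(a: List[float]) -> List[float]:
    """Hypothetical analogue of FinalizeOutput: the sum needs no post-processing."""
    return list(a)


def sum_reduction(
    formula: Callable[[Sequence[float], Sequence[float]], Sequence[float]],
    xs: Sequence[Sequence[float]],
    ys: Sequence[Sequence[float]],
    dim: int,
) -> List[List[float]]:
    """For each index i, reduce formula(x_i, y_j) over the reduction index j."""
    out = []
    for x_i in xs:
        a = sum_initialize(dim)          # reset the running buffer
        for y_j in ys:                   # loop on the reduction index j
            sum_reduce_pair(a, formula(x_i, y_j))
        out.append(sum_finalize(a))      # save the buffer in the output array
    return out


def difference(x: Sequence[float], y: Sequence[float]) -> List[float]:
    """Toy formula F(x_i, y_j) = x_i - y_j, chosen only for the example."""
    return [x[0] - y[0], x[1] - y[1]]


xs = [[0.0, 0.0], [1.0, 2.0]]
ys = [[1.0, 1.0], [2.0, 3.0], [3.0, 5.0]]
result = sum_reduction(difference, xs, ys, dim=2)
print(result)  # [[-6.0, -9.0], [-3.0, -3.0]]
```

Keeping the three steps as separate functions mirrors the structure described above: a different reduction only has to swap these routines, while the double loop stays unchanged.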

The online Log-Sum-Exp trick

More interestingly, the Max_SumShiftExp_Reduction.h header implements an online version of the well-known Log-Sum-Exp trick: a factorization of the maximum in the computation of

$$\log \sum_{j=1}^{N} \exp(F_{i,j}) = m_i + \log \sum_{j=1}^{N} \exp(F_{i,j} - m_i), \quad \text{with} \quad m_i = \max_{j=1}^{N} F_{i,j}$$

that ensures the computation of this important quantity – the linchpin of maximum likelihood estimators and entropic Optimal Transport solvers – without numeric overflows.
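
To see why the factorization matters, the short Python sketch below compares a direct evaluation with the shifted one on large inputs. It relies only on the standard `math` module; the helper names naive_logsumexp and stable_logsumexp are hypothetical and are not part of KeOps.

```python
import math


def naive_logsumexp(values):
    """Direct evaluation of log(sum_j exp(F_j)); overflows for large inputs."""
    return math.log(sum(math.exp(v) for v in values))


def stable_logsumexp(values):
    """Shifted evaluation: m + log(sum_j exp(F_j - m)) with m = max_j F_j."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


values = [1000.0, 1000.0]

# math.exp(1000.0) exceeds the range of a double, so the naive version fails.
try:
    naive_logsumexp(values)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

# After the shift, every exponent is <= 0 and every term lies in [0, 1].
stable = stable_logsumexp(values)
print(naive_overflowed, stable)  # True, and a value equal to 1000 + log(2)
```

This batch version needs the maximum before the summation starts, which means two passes over the data. The online variant described next avoids the second pass.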

Merging the content of our C++ header and of the Python post-processing step implemented in pykeops/common/operations.py, assuming that $F_{i,j} = F(p^1, \dots, x^1_i, \dots, y^1_j, \dots)$ is a scalar quantity, we may describe its behaviour as follows:

  1. The InitializeReduction method ensures that our running buffer a is a vector of size 2 that encodes the current value of the inner summation as an explicit (exponent, mantissa) or “(maximum, residual)” pair of float numbers: at any stage of the computation, the pair $(m, r)$ encodes the positive number $e^m \cdot r$ with the required precision. We initially set the value of a to $(-\infty, 0) \simeq e^{-\infty} \cdot 0$.

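  2. The ReducePair method takes as input a pointer to the result $F_{i,j}$ of the computation, a pointer to the running buffer $a = (m, r) \simeq e^m \cdot r$ and implements the in-place update:

    $$(m, r) \gets \begin{cases} (m, \; r + e^{F_{i,j} - m}) & \text{if } m \geqslant F_{i,j}, \\ (F_{i,j}, \; 1 + r \cdot e^{m - F_{i,j}}) & \text{otherwise.} \end{cases}$$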

    This is a numerically stable way of writing the sum reduction:

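    $$e^m \cdot r \gets e^m \cdot r + e^{F_{i,j}} = \begin{cases} e^m \cdot (r + e^{F_{i,j} - m}) & \text{if } m \geqslant F_{i,j}, \\ e^{F_{i,j}} \cdot (1 + r \cdot e^{m - F_{i,j}}) & \text{otherwise.} \end{cases}$$

  3. FinalizeOutput post-processes the buffer $a = (m, r) \simeq e^m \cdot r$ by applying the final “log” operation, returning a value of $m + \log(r)$ for the full reduction.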
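
The three steps of the online scheme can be written as a minimal Python sketch for scalar values. The names lse_initialize, lse_reduce_pair and lse_finalize are hypothetical stand-ins for the routines described above, not the actual KeOps code. The sketch assumes that every $F_{i,j}$ is finite and that at least one term is reduced before the final "log" is applied.

```python
import math


def lse_initialize():
    """Buffer a = (m, r) = (-inf, 0), which encodes e^m * r = 0."""
    return [-math.inf, 0.0]


def lse_reduce_pair(a, f_ij):
    """In-place update of (m, r) so that e^m * r absorbs the new term e^F_ij."""
    m, r = a
    if m >= f_ij:
        # The maximum is unchanged: rescale the new term by e^(-m).
        a[1] = r + math.exp(f_ij - m)
    else:
        # New maximum: rescale the old residual by e^(m - F_ij).
        a[0] = f_ij
        a[1] = 1.0 + r * math.exp(m - f_ij)


def lse_finalize(a):
    """Apply the final log: log(e^m * r) = m + log(r)."""
    m, r = a
    return m + math.log(r)


values = [1000.0, 999.0, 1001.0]

a = lse_initialize()
for f_ij in values:          # single pass over the reduction index j
    lse_reduce_pair(a, f_ij)
online = lse_finalize(a)

# Two-pass reference: shift by the maximum computed beforehand.
m_ref = max(values)
reference = m_ref + math.log(sum(math.exp(v - m_ref) for v in values))
print(abs(online - reference) < 1e-9)  # True
```

Every call to math.exp receives a non-positive argument, so no intermediate value can overflow, and the maximum never has to be known in advance.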