Streaming Log-sum-exp Computation

Sebastian Nowozin - Sun 08 May 2016

A common numerical operation in statistical computing is to compute

$$\log \sum_{i=1}^n \exp x_i,$$

where \(x_i \in \mathbb{R}\), and \(n\) is potentially very large.

We can implement the above computation by exponentiating each number, then summing them, then taking a logarithm as follows (written in Julia).

logsumexp_naive(X) = log(sum(exp.(X)))  # exp.(X) applies exp to each element

When the above function returns a finite number, that number is numerically accurate. However, the computation is not robust if one of the elements is very large (larger than about 709.78 for IEEE double precision floating point). In that case \(\exp(x_i)\) overflows to a floating point infinity and the entire computation returns a floating point infinity as well.
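The overflow is easy to reproduce. The following is a minimal sketch, with input values chosen purely for illustration. It uses the broadcast form `exp.(X)`, which current Julia versions require for elementwise application.

```julia
# Largest finite Float64 is about 1.8e308, and log of that is about 709.78.
exp(709.0)   # still finite
exp(710.0)   # Inf

# Naive version, written with elementwise broadcasting.
logsumexp_naive(X) = log(sum(exp.(X)))

logsumexp_naive([1.0, 2.0, 3.0])     # finite and accurate
logsumexp_naive([1.0, 2.0, 710.0])   # Inf, although the true value is close to 710
```

A single large element is enough to turn the whole result into `Inf`, even though the mathematically correct answer is an ordinary representable number.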

Standard Batch Solution

The standard solution to this problem is to use the mathematical identity

$$\log \sum_{i=1}^n \exp x_i = \alpha + \log \sum_{i=1}^n \exp (x_i - \alpha),$$

which holds for any \(\alpha \in \mathbb{R}\). By selecting \(\alpha = \max_{i=1,\dots,n} x_i\) no argument to the \(\exp\)-function will be larger than zero and the above naive computation can be applied on the transformed numbers. The code is as follows.

function logsumexp_batch(X)
    alpha = maximum(X)  # Find maximum value in X
    log(sum(exp.(X .- alpha))) + alpha  # shift by the maximum, then undo the shift
end

Code such as the above is used in almost all packages for performing statistical computation and is described as the standard solution, see e.g. here and here.
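As a quick sanity check of the shifted computation, here is a small sketch with inputs picked only for illustration. It restates the batch function using broadcast syntax so that the snippet runs on its own.

```julia
function logsumexp_batch(X)
    alpha = maximum(X)                  # shift so that the largest argument to exp is zero
    log(sum(exp.(X .- alpha))) + alpha
end

# Two equal elements: the exact answer is x + log(2).
logsumexp_batch([1000.0, 1000.0])     # exp(1000.0) alone would overflow to Inf
logsumexp_batch([-1000.0, -1000.0])   # exp(-1000.0) alone would underflow to 0.0
```

Because the largest shifted element is exactly zero, the inner sum always contains a term equal to one, so the logarithm is taken of a number that is at least one.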

However, there are the following problems:

  1. It requires two scans over the data array, one to find the maximum and one to compute the summation. For modern systems and large input arrays the computation is memory-bandwidth limited, so two memory scans mean roughly twice the runtime.
  2. It requires the complete set of elements, and hence their number, to be known before the computation starts.

Streaming log-sum-exp Computation

The solution is to also compute the maximum element in a streaming manner and to correct a running estimate whenever a new maximum is found. I have not seen this solution elsewhere, but I hope you may find it useful.

First, here is the code.

function logsumexp_stream(X)
    alpha = -Inf
    r = 0.0
    for x = X
        if x <= alpha
            r += exp(x - alpha)
        else
            r *= exp(alpha - x)
            r += 1.0
            alpha = x
        end
    end
    log(r) + alpha
end

As you can see by glancing over the code, only a single linear pass over the input is required, and we do not need to know the number of elements in advance.
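Since the loop carries only the pair `(alpha, r)` from one element to the next, the same update can be packaged as a small accumulator that is fed values as they arrive. The sketch below is one possible way to do this. The names `LogSumExpAcc`, `update!` and `value` are hypothetical and chosen here only for illustration.

```julia
# Hypothetical accumulator type; the name and fields are illustrative only.
mutable struct LogSumExpAcc
    alpha::Float64   # largest value seen so far
    r::Float64       # sum of exp(x - alpha) over all values seen so far
end

# Start in the same state as logsumexp_stream: alpha = -Inf, r = 0.
LogSumExpAcc() = LogSumExpAcc(-Inf, 0.0)

# Fold one new value into the accumulator (same update as in the loop above).
function update!(acc::LogSumExpAcc, x::Real)
    if x <= acc.alpha
        acc.r += exp(x - acc.alpha)
    else
        acc.r *= exp(acc.alpha - x)   # rescale the old sum to the new maximum
        acc.r += 1.0                  # contribution of x itself, exp(x - x)
        acc.alpha = x
    end
    return acc
end

# log-sum-exp of everything seen so far.
value(acc::LogSumExpAcc) = log(acc.r) + acc.alpha

acc = LogSumExpAcc()
for x in (1000.0 + k for k in 1:3)    # lazy generator, no array is materialized
    update!(acc, x)
end
value(acc)   # equals 1003 + log(1 + exp(-1) + exp(-2))
```

The intermediate result can be read off with `value(acc)` at any time, which is not possible with the two-pass version.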

To understand how the code works, assume we maintain two quantities. The first is the largest value seen after \(i\) elements,

$$\alpha_i := \max_{j = 1,\dots,i} x_j.$$

The second is the accumulated sum so far with the current maximum subtracted,

$$r_i := \sum_{j=1}^i \exp(x_j - \alpha_i).$$

Now when we visit a new element \(x_{i+1}\) there are two cases that can happen. If \(x_{i+1} \leq \alpha_i\) then \(\alpha_{i+1} = \alpha_i\) and we simply update

$$r_{i+1} = r_i + \exp(x_{i+1} - \alpha_{i+1}).$$

However, if we see a new largest element, that is, \(x_{i+1} > \alpha_i\), we can write \(r_i\) as

$$r_i = \sum_{j=1}^i \exp(x_j - \alpha_i) = \exp(-\alpha_i) \sum_{j=1}^i \exp(x_j).$$

We correct this estimate so that it uses the new maximum \(x_{i+1}\) and cancels the old maximum \(\alpha_i\),

$$r'_{i+1} = \exp(\alpha_i - x_{i+1}) \, r_i.$$

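Because \(x_{i+1} > \alpha_i\), the factor \(\exp(\alpha_i - x_{i+1})\) is always smaller than one, so the correction cannot overflow. Then we set \(\alpha_{i+1} = x_{i+1}\) and proceed to accumulate as normal to obtain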

$$r_{i+1} = r'_{i+1} + \exp(x_{i+1} - \alpha_{i+1}) = r'_{i+1} + 1.$$
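As a small worked illustration of the two update rules, take the input \(x = (1, 3, 2)\), chosen here purely as an example. The two quantities then evolve as follows:

$$\begin{aligned}
i = 1: \quad & \alpha_1 = 1, \quad r_1 = 1, \\
i = 2: \quad & x_2 = 3 > \alpha_1, \quad r_2 = \exp(1 - 3) \cdot 1 + 1 = 1 + e^{-2}, \quad \alpha_2 = 3, \\
i = 3: \quad & x_3 = 2 \leq \alpha_2, \quad r_3 = r_2 + \exp(2 - 3) = 1 + e^{-1} + e^{-2}, \quad \alpha_3 = 3.
\end{aligned}$$

The final result is \(\alpha_3 + \log r_3 = 3 + \log(1 + e^{-1} + e^{-2}) \approx 3.4076\), which agrees with \(\log(e^1 + e^3 + e^2)\).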

The above code is as numerically robust as the commonly used batch version and, for large arrays, can be up to twice as fast.
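One rough way to check the runtime claim on a given machine is sketched below. This assumes a simple wall-clock comparison with `@elapsed` is good enough. The array size is an arbitrary choice, and the measured times depend on hardware and Julia version, so no particular numbers are claimed here.

```julia
function logsumexp_batch(X)
    alpha = maximum(X)                  # first pass over X
    log(sum(exp.(X .- alpha))) + alpha  # second pass, allocates one temporary array
end

function logsumexp_stream(X)
    alpha = -Inf
    r = 0.0
    for x in X                          # single pass, no temporary array
        if x <= alpha
            r += exp(x - alpha)
        else
            r *= exp(alpha - x)
            r += 1.0
            alpha = x
        end
    end
    log(r) + alpha
end

X = 500.0 * randn(1_000_000)

# Call each function once first so that compilation time is not measured.
logsumexp_batch(X)
logsumexp_stream(X)

t_batch  = @elapsed logsumexp_batch(X)
t_stream = @elapsed logsumexp_stream(X)
println("batch: ", t_batch, " s, stream: ", t_stream, " s")
```

Note that the broadcast expression in the batch version also allocates a temporary array, so part of any measured difference may come from allocation rather than from the second scan alone.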

Example

Running

n = 10_000_000
X = 500.0*randn(n)

logsumexp_naive(X), logsumexp_batch(X), logsumexp_stream(X)

gives the following output

(Inf,2686.7659554831052,2686.7659554831052)