Stochastic Computation Graphs

Sebastian Nowozin - Fri 24 July 2015

This post is about a recent arXiv submission entitled Gradient Estimation Using Stochastic Computation Graphs, authored by John Schulman, Nicolas Heess, Theophane Weber, and Pieter Abbeel.

In a nutshell this paper generalizes the backpropagation algorithm to allow differentiation through expectations, that is, to compute unbiased estimates of

$$\frac{\partial}{\partial \theta} \mathbb{E}_{x \sim q(x|\theta)}[f(x,\theta)].$$

The paper also provides a nice calculus on directed graphs that allows quick derivation of unbiased gradient estimates. The basic technical results in the paper have been known and used in various communities before, and the arXiv submission properly discusses these.

But dismissing the paper as not novel would miss the point, in the same way that one would miss the point by stating that backpropagation is "just an application of the chain rule of differentiation". Instead, the contribution of the current paper lies in the practical utility of the graphical calculus and in a rich catalogue of machine learning problems where the computation of unbiased gradients of expectations is useful.

In typical statistical point estimation tasks, unbiasedness is often less important than the expected risk. Here, however, it is crucial: the applications where stochastic computation graphs are useful involve optimization over \(\theta\), and stochastic approximation methods such as stochastic gradient methods can be justified theoretically only when the gradient estimates are unbiased.

A Neat Derivative Trick

To get an idea of the flavour of derivatives involving expectations, let us look at a simpler case explained in Section 2.1 of the paper, whose proof also contains a neat trick worth knowing. The case is as above, but inside the expectation we have only \(f(x)\) instead of \(f(x,\theta)\). The "trick" is the identity (obvious in retrospect)

$$\frac{\partial}{\partial \theta} p(x|\theta) = p(x|\theta) \frac{\partial}{\partial \theta} \log p(x|\theta).$$
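
The identity is simply the chain rule applied to the logarithm, valid wherever \(p(x|\theta) > 0\):

$$\frac{\partial}{\partial \theta} \log p(x|\theta) = \frac{1}{p(x|\theta)} \frac{\partial}{\partial \theta} p(x|\theta),$$

and multiplying both sides by \(p(x|\theta)\) gives the identity above.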

This allows us to establish

\begin{eqnarray} \frac{\partial}{\partial \theta} \mathbb{E}_{x \sim p(x|\theta)}[f(x)] & = & \frac{\partial}{\partial \theta} \int p(x|\theta) f(x) \,\textrm{d}x\nonumber\\ & = & \int \frac{\partial}{\partial \theta} p(x|\theta) f(x) \,\textrm{d}x\nonumber\\ & = & \int p(x|\theta) f(x) \frac{\partial}{\partial \theta} \log p(x|\theta) \,\textrm{d}x\nonumber\\ & = & \mathbb{E}_{x \sim p(x|\theta)}[f(x) \frac{\partial}{\partial \theta} \log p(x|\theta)].\nonumber \end{eqnarray}
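
As a quick numerical check of this result, here is a small Python sketch using only the standard library. The choices \(p(x|\theta) = \mathcal{N}(\theta, 1)\) and \(f(x) = x^2\) are assumptions for illustration: then \(\mathbb{E}[f(x)] = \theta^2 + 1\), the true derivative is \(2\theta\), and the score is \(\frac{\partial}{\partial \theta} \log p(x|\theta) = x - \theta\).

```python
import random


def f(x):
    # Illustrative integrand (an assumption); it does not depend on theta.
    return x * x


def score(x, theta):
    # d/dtheta log N(x; theta, 1) = x - theta
    return x - theta


def score_function_gradient(theta, n, rng):
    """Average of f(x) * d/dtheta log p(x|theta) over n samples x ~ N(theta, 1)."""
    total = 0.0
    for _ in range(n):
        x = rng.gauss(theta, 1.0)
        total += f(x) * score(x, theta)
    return total / n


rng = random.Random(0)
theta = 0.7
estimate = score_function_gradient(theta, 200_000, rng)
print(estimate, 2 * theta)  # the estimate should be close to the exact value 1.4
```

The estimate uses only samples of \(x\) and evaluations of \(f\), never derivatives of \(f\), so it also applies when \(f\) is not differentiable.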

In this case the derivation was straightforward, but with multiple expectations a derivation based on this elementary definition of the expectation becomes cumbersome and error-prone. Stochastic computation graphs allow a much quicker derivation.

Stochastic Computation Graphs

Stochastic computation graphs are directed acyclic graphs that encode the dependency structure of the computation to be performed. The graphical notation generalizes that of directed graphical models. Here is an example graph.

Stochastic computation graph of problem (1) in Schulman et al.

There are three (or four) types of nodes in a stochastic computation graph:

  1. Input nodes. These are the fixed parameters with respect to which we would like to compute the derivative. In the example graph this is the \(\theta\) node; input nodes are drawn without any container. While technically it is possible to have graphs without input nodes, in order to compute gradients the graph must include at least one input node.
  2. Deterministic nodes. These compute a deterministic function of their parents. In the above graph this is the case for the \(x\) and \(f\) nodes.
  3. Stochastic nodes. These nodes specify a random variable through a distribution conditional on their parents. In the above graph this is true for the \(y\) node, and the circle mirrors the notation used in directed graphical models.
  4. Cost nodes. These are a subset of the deterministic nodes in the graph whose range is the real numbers. In the above graph the node \(f\) is a cost node. I draw them shaded; the original paper does not.

The entire stochastic computation graph specifies a single objective function whose arguments are the input nodes and whose scalar value is the sum of all cost nodes, taken in expectation over all stochastic nodes in the graph.

Therefore the above graph has the objective function

$$F(\theta) = \mathbb{E}_{y \sim p(y|x(\theta))}[f(y)].$$

Derivative Calculus

The notation used in the paper is a bit heavy and (for my taste at least) a bit too custom, but here it is. Let \(\Theta\) be the set of input nodes, \(\mathcal{C}\) the set of cost nodes, and \(\mathcal{S}\) the set of stochastic nodes. The notation \(u \prec v\) denotes that there exists a directed path from \(u\) to \(v\) in the graph. The notation \(u \prec^D v\) denotes that there exists a directed path from \(u\) to \(v\) whose intermediate nodes are all deterministic; the endpoints \(u\) and \(v\) may be of any type. We write \(\hat{c}\) for a sample realization of a cost node \(c\). The final notation needed for the result is

$$\textrm{DEPS}_v = \{ w \in \Theta \cup \mathcal{S} | w \prec^D v\}.$$
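
Computing \(\textrm{DEPS}_v\) is mechanical. The following Python sketch is an illustration, not code from the paper; the graph encoding (the `kinds` and `parents` dictionaries) and the function name `deps` are assumptions. It walks backwards from \(v\), passes through deterministic nodes, and stops at input or stochastic nodes, which are exactly the members of \(\textrm{DEPS}_v\). On the example graph it gives \(\textrm{DEPS}_y = \{\theta\}\) and \(\textrm{DEPS}_f = \{y\}\), so \(\theta \prec^D y\) holds but \(\theta \prec^D f\) does not.

```python
# Node kinds; cost nodes are deterministic nodes with a real-valued output.
INPUT, DETERMINISTIC, STOCHASTIC = "input", "deterministic", "stochastic"


def deps(node, kinds, parents):
    """Return DEPS_node: all input/stochastic nodes w such that w precedes^D node.

    Walks backwards from `node` through its parents, continuing through
    deterministic nodes and stopping at input or stochastic nodes.
    """
    result = set()
    stack = list(parents.get(node, []))
    seen = set()
    while stack:
        w = stack.pop()
        if w in seen:
            continue
        seen.add(w)
        if kinds[w] in (INPUT, STOCHASTIC):
            result.add(w)  # w belongs to DEPS; do not walk past it
        else:
            stack.extend(parents.get(w, []))  # keep walking through deterministic nodes
    return result


# Example graph: theta -> x -> y -> f, with y stochastic and f a cost node.
kinds = {"theta": INPUT, "x": DETERMINISTIC, "y": STOCHASTIC, "f": DETERMINISTIC}
parents = {"x": ["theta"], "y": ["x"], "f": ["y"]}

for v in ["x", "y", "f"]:
    print(v, sorted(deps(v, kinds, parents)))
# x ['theta']
# y ['theta']
# f ['y']
```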

The key result of the paper, Theorem 1, is now stated as follows:

$$\frac{\partial}{\partial \theta} \mathbb{E}\left[\sum_{c \in \mathcal{C}} c\right] = \mathbb{E}\Bigg[\underbrace{\sum_{w \in \mathcal{S}, \theta \prec^D w} \left( \frac{\partial}{\partial \theta} \log p(w|\textrm{DEPS}_w) \right) \sum_{c \in \mathcal{C}, w \prec c} \hat{c}}_{\textrm{(A)}} + \underbrace{\sum_{c \in \mathcal{C}, \theta \prec^D c} \frac{\partial}{\partial \theta} c(\textrm{DEPS}_c)}_{\textrm{(B)}}\Bigg].$$

The two parts (A) and (B) can be interpreted as follows. If the computation is entirely deterministic, so that \(\mathcal{S} = \emptyset\), as in an ordinary feedforward neural network, then (A) vanishes and (B) is just the ordinary derivative, which we evaluate by applying the chain rule. Part (A) contains one term for each stochastic node that \(\theta\) influences, and the downstream consequences of each such node are absorbed into the sample realizations \(\hat{c}\).

It takes a bit of practice to apply Theorem 1 quickly to a given graph. I found it easier to instead execute Algorithm 1 of the paper manually, on a piece of paper; it generalizes backpropagation and builds the derivative node by node by traversing the graph backwards.


To illustrate the basic technique, I apply it to the graph above, which is problem (1) in the paper (Section 2.3), with concrete choices for each node:


$$x(\theta) = (\theta-1)^2,$$
$$y(x) \sim \mathcal{N}(x,1),$$
$$f(y) = \left(y-\frac{5}{2}\right)^2.$$

Before we apply Theorem 1 to the graph, here is what the problem actually looks like. First, the objective \(F(\theta) = \mathbb{E}_{y \sim p(y|x(\theta))}[f(y)]\) is just an ordinary one-dimensional deterministic function.

True objective to be minimized

The true gradient of the objective is also just an ordinary function. You can see three zero-crossings at approximately -0.6, 1, and 2.6, corresponding to two local minima and a local maximum of the objective function.

True gradient of objective
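
For reference, the closed form follows from \(\mathbb{E}[(y-c)^2] = \textrm{Var}(y) + (\mathbb{E}[y]-c)^2\) with \(\textrm{Var}(y) = 1\) and \(\mathbb{E}[y] = x(\theta)\):

$$F(\theta) = \left((\theta-1)^2 - \frac{5}{2}\right)^2 + 1, \qquad \frac{\partial F}{\partial \theta} = 4(\theta-1)\left((\theta-1)^2 - \frac{5}{2}\right).$$

The gradient vanishes at \(\theta = 1\) and at \(\theta = 1 \pm \sqrt{5/2} \approx 1 \pm 1.58\). At the latter two points \(x(\theta) = 5/2\) and \(F\) attains its minimum value 1, while \(F(1) = 29/4\) is a local maximum.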

For this simple example we can find a closed-form expression for \(F(\theta)\), but for general stochastic computation graphs we cannot evaluate \(F(\theta)\); we can only draw sample values \(\hat{F}_1, \hat{F}_2, \dots\) which are unbiased estimates of the true \(F(\theta)\). By averaging a number of samples, say 100, we can improve the accuracy of our estimates. To minimize \(F(\theta)\) over \(\theta\), our goal is to sample unbiased gradients as well. The unbiased sample gradients look as follows, for a single sample (shown in green) and for averages of 100 samples (shown in red), evaluated at 100 equispaced points along the \(\theta\) axis shown.

Sample gradient of objective
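
Here is a small Python sketch of this sampling setup, applied to the objective itself: a single unbiased sample \(\hat{F}\) versus an average of 100 samples, compared with the closed form that happens to be available for this toy problem. The function names, the seed and the \(\theta\) range \([-1, 3]\) are assumptions; the post does not state the plotted range.

```python
import random


def sample_objective(theta, rng):
    """One unbiased sample F_hat of F(theta) for problem (1)."""
    x = (theta - 1.0) ** 2   # deterministic node x(theta)
    y = rng.gauss(x, 1.0)    # stochastic node y ~ N(x, 1)
    return (y - 2.5) ** 2    # cost node f(y)


def averaged_objective(theta, n, rng):
    """Average of n independent samples: still unbiased, variance divided by n."""
    return sum(sample_objective(theta, rng) for _ in range(n)) / n


def exact_objective(theta):
    # Closed form for this toy problem only: E[(y - 5/2)^2] = 1 + (x - 5/2)^2.
    x = (theta - 1.0) ** 2
    return (x - 2.5) ** 2 + 1.0


rng = random.Random(0)
thetas = [-1.0 + 4.0 * i / 99 for i in range(100)]  # 100 equispaced points on [-1, 3]
for theta in thetas[::33]:
    print(f"{theta:+.2f}  single={sample_objective(theta, rng):7.3f}  "
          f"avg100={averaged_objective(theta, 100, rng):7.3f}  "
          f"exact={exact_objective(theta):7.3f}")
```

Averaging \(n\) independent samples keeps the estimate unbiased and divides its variance by \(n\); the same holds for the averaged gradient samples shown in red above.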

To derive the unbiased gradient estimate we apply Theorem 1. From the summation (A) we get only one term, because our graph contains only one stochastic node, namely \(y\), and \(\theta \prec^D y\) holds through \(x\). We get no term from (B), because there is no path from \(\theta\) to \(f\) through deterministic nodes only: the path \(\theta \to x \to y \to f\) passes through the stochastic node \(y\). Therefore we have

$$\frac{\partial}{\partial \theta} \mathbb{E}_{y \sim p(y|x(\theta))}[f(y)] = \mathbb{E}_{y \sim p(y|x(\theta))}\left[\frac{\partial}{\partial \theta} \log p(y|x(\theta)) \hat{f}\right].$$

For the score term we differentiate the log-density of the Normal distribution with respect to \(x\) and apply the chain rule through \(x(\theta)\):

\begin{eqnarray} \frac{\partial x}{\partial \theta} \frac{\partial}{\partial x} \log p(y|x(\theta)) & = & \frac{\partial x}{\partial \theta} \frac{\partial}{\partial x} \left[ - \frac{(y-x(\theta))^2}{2} - \frac{1}{2} \log 2\pi \right]\nonumber\\ & = & \frac{\partial x}{\partial \theta} (y-x(\theta))\nonumber\\ & = & 2(\theta - 1)(y - x(\theta)).\nonumber \end{eqnarray}

So the overall unbiased gradient estimator is

$$\mathbb{E}\left[\frac{\partial}{\partial \theta} \log p(y|x(\theta)) \hat{f}\right] = \mathbb{E}[2(\theta-1)(\hat{y}-\hat{x}) \hat{f}].$$

The expression inside the last expectation is the gradient estimate from a single sample realization.
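
The estimator is straightforward to implement. The Python sketch below draws one sample per call and averages 100 calls, as in the plot above; the function names are hypothetical, and the closed-form gradient \(4(\theta-1)((\theta-1)^2 - 5/2)\), which follows from \(F(\theta) = ((\theta-1)^2 - 5/2)^2 + 1\), is included only for comparison.

```python
import random


def sample_gradient(theta, rng):
    """Single-sample estimate 2(theta - 1)(y - x) f(y) for problem (1)."""
    x = (theta - 1.0) ** 2   # deterministic node
    y = rng.gauss(x, 1.0)    # stochastic node y ~ N(x, 1)
    f = (y - 2.5) ** 2       # cost node, sampled value f_hat
    return 2.0 * (theta - 1.0) * (y - x) * f


def averaged_gradient(theta, n, rng):
    return sum(sample_gradient(theta, rng) for _ in range(n)) / n


def exact_gradient(theta):
    # Derivative of the closed-form objective ((theta - 1)^2 - 5/2)^2 + 1.
    return 4.0 * (theta - 1.0) * ((theta - 1.0) ** 2 - 2.5)


rng = random.Random(0)
for theta in [-1.0, 0.0, 1.0, 2.0, 3.0]:
    print(f"{theta:+.1f}  single={sample_gradient(theta, rng):9.3f}  "
          f"avg100={averaged_gradient(theta, 100, rng):8.3f}  "
          f"exact={exact_gradient(theta):7.3f}")
```

At \(\theta = 1\) every single-sample estimate is exactly zero, because the factor \(\partial x/\partial\theta = 2(\theta - 1)\) vanishes there; elsewhere the single-sample estimates are very noisy, which is why averaging helps.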

Variational Bayesian Neural Networks

One important application of being able to compute gradients of expectation objectives is approximate variational Bayesian posterior inference over neural network parameters.

The original pioneering work applying variational Bayes (aka mean field inference) to neural network learning is the 1993 paper by Hinton and van Camp. This line of work has recently seen a revival, in particular through the appearance of stochastic variational inference methods around 2011, including a paper by Alex Graves. Many works followed this lead, for example Kingma and Welling, Rezende et al., ICML 2014, Blundell et al., ICML 2015, and Mnih and Gregor. They use gradient estimators of varying quality, and the SCG paper provides a nice overview of the bigger picture.

In any case, here is a visualization of prototypical variational Bayes learning for feedforward neural networks. A normal feedforward neural network training objective yields the following computation graph, without any stochastic nodes.

Feedforward neural network training objective computation graph

Here we have a fixed weight vector \(w\) with a regularizer \(R(w)\). We have \(n\) training instances, and each input \(x_i\) produces a network output \(P_i(x_i,w)\), for example a distribution over class labels. Together with a known ground-truth label \(y_i\) this yields a loss \(\ell_i(P_i,y_i)\), for example the cross-entropy loss. If we use a likelihood-based loss and a regularizer derived from a prior, i.e. \(R(w)=-\log P(w)\), the training objective becomes regularized maximum likelihood estimation:

$$F(w) = -\log P(w) - \sum_{i=1}^n \log P(y_i|x_i;w).$$

The variational Bayes training objective yields the following slightly extended stochastic computation graph.

Variational Bayes neural network training objective stochastic computation graph

Here \(w\) is still the network parameter, but it is now a random vector, \(w \sim Q(w|\theta)\), and \(\theta\) becomes the parameter we would like to learn. The additional cost node \(H\) arises from the entropy of the approximating posterior distribution \(Q\): its cost is \(\log Q(w|\theta)\), whose expectation is the negative entropy. (An interesting detail: in principle we would not need an arrow \(w \to H\) when the entropy \(H(Q)\) can be computed analytically. However, if we allow this arrow, then we can use a Monte Carlo approximation of the entropy for approximating families which do not have an analytic entropy expression.) The training objective becomes:

$$F(\theta) = \mathbb{E}_{w \sim Q(w|\theta)}\left[-\log P(w) + \log Q(w|\theta) - \sum_{i=1}^n \log P(y_i|x_i;w)\right].$$

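The stochastic computation graph rules can now be used to derive the unbiased gradient estimate. Since the cost node \(H\) also depends on \(\theta\) directly, Theorem 1 additionally yields a term (B) equal to \(\mathbb{E}_{w \sim Q(w|\theta)}[\frac{\partial}{\partial \theta} \log Q(w|\theta)]\); this term is zero because \(\int \frac{\partial}{\partial \theta} Q(w|\theta) \,\textrm{d}w = \frac{\partial}{\partial \theta} \int Q(w|\theta) \,\textrm{d}w = 0\), so only part (A) remains: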

$$\frac{\partial}{\partial \theta} F(\theta) = \mathbb{E}_{w \sim Q(w|\theta)}\left[ \frac{\partial}{\partial \theta} \log Q(w|\theta) \left( -\log P(w) + \log Q(w|\theta) - \sum_{i=1}^n \log P(y_i|x_i;w) \right)\right].$$

This is now quite practical: the expectation can be approximated by simple Monte Carlo, using samples of \(w\) drawn from the current approximating posterior \(Q(w|\theta)\). Because the gradient estimate is unbiased, we can improve the approximation by running standard stochastic gradient methods.
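
To see the estimator work end to end, here is a Python sketch on a toy problem that is an assumption, not from the post: a single weight \(w\) with prior \(P(w) = \mathcal{N}(0,1)\), likelihood \(y_i \sim \mathcal{N}(w x_i, 1)\) on two hypothetical data points, and \(Q(w|\theta) = \mathcal{N}(\mu, \sigma^2)\) with \(\theta = (\mu, \sigma)\). For simplicity only the \(\mu\)-component of the gradient is estimated and \(\sigma\) is held fixed. `grad_mu_estimate` averages the score times the bracketed cost from the formula above, and `fit_mu` runs plain stochastic gradient descent with illustrative settings. For this conjugate model the exact gradient is \(\mu - \sum_i x_i(y_i - \mu x_i)\), so the optimum is \(\mu = \sum_i x_i y_i / (1 + \sum_i x_i^2) = 5/6\).

```python
import math
import random

# Hypothetical toy data: one weight w, y_i ~ N(w * x_i, 1).
XS = [1.0, 2.0]
YS = [1.0, 2.0]


def log_normal(v, mean, std):
    """log N(v; mean, std^2)."""
    return -0.5 * ((v - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def cost(w, mu, sigma):
    """-log P(w) + log Q(w|theta) - sum_i log P(y_i|x_i; w) for one sample w."""
    neg_log_prior = -log_normal(w, 0.0, 1.0)   # prior P(w) = N(0, 1)
    log_q = log_normal(w, mu, sigma)            # entropy node: log Q(w|theta)
    neg_log_lik = -sum(log_normal(y, w * x, 1.0) for x, y in zip(XS, YS))
    return neg_log_prior + log_q + neg_log_lik


def grad_mu_estimate(mu, sigma, n, rng):
    """Monte Carlo estimate of dF/dmu: mean of d/dmu log Q(w|theta) * cost(w)."""
    total = 0.0
    for _ in range(n):
        w = rng.gauss(mu, sigma)
        score_mu = (w - mu) / sigma ** 2  # d/dmu log N(w; mu, sigma^2)
        total += score_mu * cost(w, mu, sigma)
    return total / n


def grad_mu_exact(mu):
    # Closed form for this conjugate toy model; sigma drops out of dF/dmu.
    return mu - sum(x * (y - mu * x) for x, y in zip(XS, YS))


def fit_mu(mu, sigma, steps, n, lr, rng):
    """Plain stochastic gradient descent on mu with sigma held fixed."""
    for _ in range(steps):
        mu -= lr * grad_mu_estimate(mu, sigma, n, rng)
    return mu


rng = random.Random(0)
print(grad_mu_estimate(0.0, 1.0, 50_000, rng), grad_mu_exact(0.0))  # exact value: -5.0
print(fit_mu(0.0, 1.0, 200, 500, 0.05, rng))  # should end up near 5/6
```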

Additional Applications

The paper contains a large number of machine learning applications, but there are many others. Here is one I find useful.

Experimental design stochastic computation graph

Experimental design. In Bayesian experimental design we make a choice that influences our future measurements, and we would like to make this choice so as to maximize future expected utility or minimize expected loss. For this we use a model of how our choices relate to the information we will capture and to how valuable that information will be. Because this is just decision theory and the idea is very general, let me be more concrete. Let us assume the objective function

$$\mathbb{E}_{z \sim p(z)}[\mathbb{E}_{x \sim p(x|z,\theta)}[\ell(\tilde{z}(x,\theta), z)]].$$

Here \(\theta\) is our design parameter and \(z\) is the true state we are interested in, with prior \(p(z)\). The measurement process produces \(x \sim p(x|z,\theta)\). We have an estimator \(\tilde{z}(x,\theta)\) and a loss function \(\ell\) which compares the estimated value against the true state. The full objective function is then the expected loss of our estimator \(\tilde{z}\) as a function of the design parameters \(\theta\). The above expression looks a bit convoluted, but this structure appears frequently when we can control what type of information is collected. One example application: \(z\) could represent user behaviour and \(\theta\) some subset of questions we could ask that user to learn more about their behaviour. We then assume a model \(p(x|z,\theta)\) of how the user would provide answers \(x\) given questions \(\theta\) and behaviour \(z\). This allows us to build an estimator \(\tilde{z}(x,\theta)\). The design objective then tries to find the most informative set of questions to ask.
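
To see both terms of Theorem 1 at work, here is a toy instance of this objective in Python. Everything concrete in it is an assumption for illustration: a scalar state \(z \sim \mathcal{N}(0,1)\), a measurement \(x \sim \mathcal{N}(z, \theta^2)\) whose noise level \(\theta\) plays the role of the design parameter, the posterior-mean estimator \(\tilde{z}(x,\theta) = x/(1+\theta^2)\), and squared loss. For this model the expected loss is \(\theta^2/(1+\theta^2)\), so the exact gradient \(2\theta/(1+\theta^2)^2\) is available for comparison (in this toy the optimum is simply \(\theta \to 0\), so it only serves to check the estimator). Because \(\theta\) enters the stochastic node \(x\) directly, Theorem 1 gives a term (A); because \(\theta\) also reaches the loss through the deterministic estimator \(\tilde{z}\), it gives a term (B) as well.

```python
import random

# Hypothetical toy model:
#   z ~ N(0, 1)                            true state with prior p(z)
#   x ~ N(z, theta^2)                      measurement; theta is the noise std
#   z_tilde(x, theta) = x / (1 + theta^2)  posterior-mean estimator
#   loss = (z_tilde - z)^2                 expected loss: theta^2 / (1 + theta^2)


def design_gradient_sample(theta, rng):
    """Single-sample estimate of d/dtheta E[loss] following Theorem 1."""
    z = rng.gauss(0.0, 1.0)
    x = rng.gauss(z, theta)
    z_tilde = x / (1.0 + theta ** 2)
    loss = (z_tilde - z) ** 2
    # Term (A): score of the stochastic node x times the sampled downstream cost.
    score_x = -1.0 / theta + (x - z) ** 2 / theta ** 3  # d/dtheta log N(x; z, theta^2)
    term_a = score_x * loss
    # Term (B): theta -> z_tilde -> loss is deterministic; differentiate with x, z fixed.
    dz_tilde = -2.0 * theta * x / (1.0 + theta ** 2) ** 2
    term_b = 2.0 * (z_tilde - z) * dz_tilde
    return term_a + term_b


def exact_gradient(theta):
    return 2.0 * theta / (1.0 + theta ** 2) ** 2


rng = random.Random(0)
n = 200_000
theta = 1.0
print(sum(design_gradient_sample(theta, rng) for _ in range(n)) / n, exact_gradient(theta))
```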

Acknowledgements. I thank Michael Schober for discussions about the paper and Nicolas Heess for feedback on this article.