A brief primer on gradient-based optimization

To find the maximum a posteriori (MAP) estimates for our probabilistic models, we’ll use gradient-based optimization, the core set of techniques underlying effectively all of deep learning. We’ve already seen a related idea in the notes on frequentist inference. There, we proved that the maximum likelihood estimator \(\hat\pi\) for the parameter of a Bernoulli distribution is the sample mean. We did this in a few steps.

Review: Analytically computing an MLE

First, we noted that, assuming that \(X_i \sim \text{Bernoulli}(\pi)\):

\[\begin{align*}\hat\pi(\mathbf{x}) &= \arg_\pi\max p_{X_1, X_2, \ldots}(\mathbf{x}; \pi)\\ &= \arg_\pi\max \prod_{i=1}^N p_{X_i}(x_i; \pi)\\ &= \arg_\pi\max \prod_{i=1}^N \text{Bern}(x_i; \pi)\\ &= \arg_\pi\max \prod_{i=1}^N \pi^{x_i}(1-\pi)^{1-x_i}\\\end{align*}\]

And because logarithms are monotone increasing:

\[\begin{align*}\hat\pi(\mathbf{x}) &= \arg_\pi\max \mathcal{L}_\mathbf{x}(\pi)\\ &= \arg_\pi\max\log\mathcal{L}_\mathbf{x}(\pi) \\ &= \arg_\pi\max \log\prod_{i=1}^N \pi^{x_i}(1-\pi)^{1-x_i}\\ &= \arg_\pi\max \sum_{i=1}^N \log\left( \pi^{x_i}(1-\pi)^{1-x_i}\right)\\ &= \arg_\pi\max \sum_{i=1}^N \left[x_i \log \pi + (1-x_i)\log(1-\pi)\right]\\\end{align*}\]

Converting this expression to a sum allows us to more easily take the derivative:

\[\begin{align*}\frac{\mathrm{d}}{\mathrm{d}\pi}\log\mathcal{L}_\mathbf{x}(\pi) &= \frac{\mathrm{d}}{\mathrm{d}\pi}\sum_{i=1}^N \left[x_i \log \pi + (1-x_i)\log(1-\pi)\right]\\ &= \sum_{i=1}^N \frac{\mathrm{d}}{\mathrm{d}\pi} \left[x_i \log \pi + (1-x_i)\log(1-\pi)\right]\\ &= \sum_{i=1}^N \left[\frac{x_i}{\pi} - \frac{1-x_i}{1-\pi}\right]\\ &= \sum_{i=1}^N \frac{x_i -\pi}{\pi(1-\pi)}\\ &= \sum_{i=1}^N \left[\frac{x_i}{\pi(1-\pi)} - \frac{1}{1-\pi}\right]\\ &= \left[\frac{1}{\pi(1-\pi)}\sum_{i=1}^N x_i\right] - \frac{N}{1-\pi}\end{align*}\]

And once we have the derivative, we can use it to compute the \(\arg_\pi\max\) by setting it to zero.

\[\begin{align*}\left[\frac{1}{\pi(1-\pi)}\sum_{i=1}^N x_i\right] - \frac{N}{1-\pi} &= 0 \\ \frac{1}{\pi(1-\pi)}\sum_{i=1}^N x_i &= \frac{N}{1-\pi} \\ \sum_{i=1}^N x_i &= N\pi \\ \frac{\sum_{i=1}^N x_i}{N} &= \pi \\ \end{align*}\]

So \(\hat\pi(\mathbf{x}) = \frac{\sum_{i=1}^N x_i}{N}\).
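As a quick sanity check on this result, here is a minimal sketch that compares the closed-form estimate against a brute-force search over a grid of candidate values of \(\pi\). The data and the function name `bernoulli_log_likelihood` are made up for illustration; nothing here comes from the earlier notes.

```python
import numpy as np

def bernoulli_log_likelihood(pi, x):
    """Log-likelihood of i.i.d. Bernoulli(pi) observations x (0s and 1s)."""
    x = np.asarray(x, dtype=float)
    return np.sum(x * np.log(pi) + (1.0 - x) * np.log(1.0 - pi))

# Hypothetical toy data: 7 successes out of 10 trials.
x = np.array([1, 0, 1, 1, 0, 1, 1, 1, 0, 1])

# Evaluate the log-likelihood on a fine grid strictly inside (0, 1),
# so that we never take log(0).
grid = np.linspace(0.001, 0.999, 999)
log_liks = np.array([bernoulli_log_likelihood(pi, x) for pi in grid])

pi_grid = grid[np.argmax(log_liks)]  # grid point with the highest log-likelihood
pi_closed_form = x.mean()            # the analytic MLE derived above

print(pi_grid, pi_closed_form)
```

Up to the resolution of the grid, the two estimates should agree. The grid search only works here because there is a single bounded parameter, which is part of why we need something better for larger models.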

Gradient-based optimization

Gradient-based optimization relies on basically the same idea. The main difference is that, for some model with log-posterior \(\log p(\boldsymbol\theta\mid\mathbf{x})\), we need to find the \(\arg_\boldsymbol\theta\max \log p(\boldsymbol\theta\mid\mathbf{x})\) over all the parameters \(\boldsymbol\theta\) simultaneously. One way to do this is to compute the gradient \(\nabla \log p(\boldsymbol\theta\mid\mathbf{x})\), which is effectively the multivariate generalization of the derivative: the vector of partial derivatives with respect to each parameter. We can then find the setting of the parameters \(\boldsymbol\theta\) at which the gradient is \(\mathbf{0}\) (the zero vector).
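To make the gradient condition concrete, here is a small sketch with two parameters. As a stand-in for a log-posterior, it uses the log-likelihood of a normal distribution parameterized by \((\mu, \log\sigma)\), whose maximizer is known in closed form. The central-difference gradient, the toy data, and all function names are assumptions made for illustration.

```python
import numpy as np

def gaussian_log_likelihood(theta, x):
    """Log-likelihood of i.i.d. normal data, with theta = (mu, log_sigma)."""
    mu, log_sigma = theta
    sigma = np.exp(log_sigma)
    return np.sum(-log_sigma - 0.5 * np.log(2.0 * np.pi)
                  - (x - mu) ** 2 / (2.0 * sigma ** 2))

def numerical_gradient(f, theta, eps=1e-5):
    """Central-difference approximation to the gradient of f at theta."""
    theta = np.asarray(theta, dtype=float)
    grad = np.zeros_like(theta)
    for j in range(theta.size):
        step = np.zeros_like(theta)
        step[j] = eps  # perturb one parameter at a time
        grad[j] = (f(theta + step) - f(theta - step)) / (2.0 * eps)
    return grad

x = np.array([1.2, 0.4, 2.1, 1.7, 0.9])  # hypothetical toy data

def objective(theta):
    return gaussian_log_likelihood(theta, x)

# Known closed-form maximizer: the sample mean and the (biased) sample
# standard deviation, which is what x.std() computes by default.
theta_hat = np.array([x.mean(), np.log(x.std())])

grad_at_hat = numerical_gradient(objective, theta_hat)
grad_elsewhere = numerical_gradient(objective, np.array([0.0, 0.0]))

print(grad_at_hat)     # both components should be approximately zero
print(grad_elsewhere)  # clearly nonzero away from the maximizer
```

In practice, the gradient would usually come from automatic differentiation rather than finite differences; the finite-difference version is used here only because it needs nothing beyond NumPy.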

There are a couple of obstacles to finding this optimal \(\boldsymbol\theta\).

Obstacle 1: Computing the optimal parameters

First, we generally can’t analytically compute \(\boldsymbol\theta\) such that \(\nabla \log p(\boldsymbol\theta \mid \mathbf{x}) = \mathbf{0}\), as we did for computing our maximum likelihood estimator \(\hat\pi\).1 To deal with this issue, we often turn to iterative methods such as gradient ascent/descent. I’ll describe things in terms of gradient ascent, which uses \(\nabla \log p(\boldsymbol\theta \mid \mathbf{x})\) to maximize \(\log p(\boldsymbol\theta \mid \mathbf{x})\), but most descriptions you will find tend to use gradient descent, which uses \(-\nabla \log p(\boldsymbol\theta \mid \mathbf{x})\) to minimize \(-\log p(\boldsymbol\theta \mid \mathbf{x})\). The thing we are minimizing in gradient descent is often called the loss, and I will use that terminology for \(-\log p(\boldsymbol\theta \mid \mathbf{x})\) in my code.

In vanilla gradient ascent, we start with some initial parameters \(\bar{\boldsymbol\theta}_0\) and iteratively modify those parameters by taking steps of size \(\eta\) in the direction of the gradient at the current parameters \(\bar{\boldsymbol\theta}_{i-1}\):

\[\bar{\boldsymbol\theta}_i \equiv \bar{\boldsymbol\theta}_{i-1} + \eta \cdot \nabla \log p(\bar{\boldsymbol\theta}_{i-1} \mid \mathbf{x})\]

The parameter \(\eta\) is often called the learning rate.

If the learning rate \(\eta\) is small enough, following the direction of the gradient will typically lead us toward a \(\bar{\boldsymbol\theta}_i\) where \(\nabla \log p(\bar{\boldsymbol\theta}_i \mid \mathbf{x})\) is close to zero.2
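The update rule can be sketched for the Bernoulli example from the review, where we already know the answer. This is a minimal illustration that applies the update to the log-likelihood rather than to a log-posterior; the learning rate, number of steps, starting point, and function names are arbitrary choices of mine. The sketch also does nothing to keep \(\pi\) inside \((0, 1)\), which only works because the steps here are small. A real implementation would typically optimize an unconstrained reparameterization instead.

```python
import numpy as np

def grad_log_likelihood(pi, x):
    """Derivative of the Bernoulli log-likelihood with respect to pi."""
    return np.sum((x - pi) / (pi * (1.0 - pi)))

def gradient_ascent(x, pi_init=0.5, learning_rate=0.01, n_steps=200):
    """Vanilla gradient ascent: repeatedly step in the gradient's direction."""
    pi = pi_init
    for _ in range(n_steps):
        pi = pi + learning_rate * grad_log_likelihood(pi, x)
    return pi

# Hypothetical toy data: 7 successes out of 10 trials.
x = np.array([1, 0, 1, 1, 0, 1, 1, 1, 0, 1])

pi_hat = gradient_ascent(x)
print(pi_hat, x.mean())  # the iterate should end up very close to the sample mean
```

With a learning rate that is too large, the same loop can overshoot, leave \((0, 1)\), and produce invalid values, which is one reason the learning rate usually needs tuning.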

Obstacle 2: Non-convexity in the posterior

The second obstacle is that, in many cases, there is no unique \(\boldsymbol\theta\) such that \(\nabla \log p(\boldsymbol\theta \mid \mathbf{x}) = \mathbf{0}\), because \(\log p(\boldsymbol\theta \mid \mathbf{x})\) is often not concave (equivalently, the loss \(-\log p(\boldsymbol\theta \mid \mathbf{x})\) is not convex). In such cases of non-convexity, there can be multiple local maxima of \(\log p(\boldsymbol\theta \mid \mathbf{x})\), or equivalently, multiple local minima of \(-\log p(\boldsymbol\theta \mid \mathbf{x})\).
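To see why multiple local maxima matter, consider a one-dimensional toy objective with two peaks. The function below is not a real log-posterior; it is a hypothetical stand-in chosen only because its gradient is easy to write down. Running gradient ascent from two different starting points lands on two different stationary points, only one of which is the global maximum.

```python
def objective(theta):
    """Hypothetical non-concave stand-in for a log-posterior (two peaks)."""
    return -(theta ** 2 - 1.0) ** 2 + 0.3 * theta

def grad_objective(theta):
    """Derivative of the toy objective with respect to theta."""
    return -4.0 * theta ** 3 + 4.0 * theta + 0.3

def gradient_ascent(theta_init, learning_rate=0.02, n_steps=500):
    theta = theta_init
    for _ in range(n_steps):
        theta = theta + learning_rate * grad_objective(theta)
    return theta

theta_left = gradient_ascent(-2.0)   # starts to the left of both peaks
theta_right = gradient_ascent(2.0)   # starts to the right of both peaks

# Both end points have a gradient near zero, but they are different
# points, and the objective is higher at one of them than at the other.
print(theta_left, objective(theta_left))
print(theta_right, objective(theta_right))
```

Nothing in the gradient at either end point tells us which one is the better solution: both look equally "done" from the optimizer's point of view.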

What we usually want to find is a global maximum for \(\log p(\boldsymbol\theta \mid \mathbf{x})\), where the global maxima are a subset of the local maxima. The tricky thing is that we usually can’t know for sure whether something is a global maximum. This problem has no general solution, but there are approaches that seem to work well for finding empirically good solutions in the presence of non-convexity, even if we can’t be sure they are global maxima. Specifically, a very common method is to use some form of (mini-batch) stochastic gradient ascent/descent.

The basic idea behind stochastic gradient ascent is to make gradient updates against randomly or pseudo-randomly selected subsets of the data, rather than the whole dataset at once. This approach is often implemented by (pseudo-)randomly shuffling the data, partitioning it into “minibatches” of size \(m\), then cycling through the minibatches and updating the parameters by following the gradient for that minibatch. One cycle through the minibatches is often termed an epoch.
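The shuffle, partition, and cycle procedure can be sketched as follows, again using the Bernoulli log-likelihood as a simple stand-in for a log-posterior. The simulated data, batch size, learning rate, and function name are all assumptions made for illustration. One design choice worth flagging: the sketch averages the gradient over each minibatch rather than summing it, which amounts to rescaling the learning rate by the batch size.

```python
import numpy as np

def minibatch_gradient_ascent(x, batch_size=50, learning_rate=0.01,
                              n_epochs=20, pi_init=0.5, seed=0):
    """Minibatch stochastic gradient ascent on the Bernoulli log-likelihood."""
    rng = np.random.default_rng(seed)
    pi = pi_init
    for _ in range(n_epochs):
        # Reshuffle at the start of every epoch.
        shuffled = rng.permutation(x)
        # Partition into minibatches and update once per minibatch.
        for start in range(0, len(shuffled), batch_size):
            batch = shuffled[start:start + batch_size]
            grad = np.mean((batch - pi) / (pi * (1.0 - pi)))
            pi = pi + learning_rate * grad
    return pi

# Hypothetical simulated data: 1000 draws from Bernoulli(0.7).
data_rng = np.random.default_rng(1)
x = data_rng.binomial(1, 0.7, size=1000)

pi_sgd = minibatch_gradient_ascent(x)
print(pi_sgd, x.mean())  # close to, but generally not exactly, the sample mean
```

Because each update sees only part of the data, the final value jitters around the full-data maximizer instead of settling on it exactly. Shrinking the learning rate over time is a common way to reduce that jitter.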

Under certain assumptions, this approach will provably get you close to at least a local maximum of \(\log p(\boldsymbol\theta \mid \mathbf{x})\) for the entire dataset. But it has the added benefit that, because the shape of the log-posterior is different for each minibatch (potentially by quite a lot), if you were stuck at a bad local maximum, you can get away from it and, hopefully, move toward a better local maximum.

What we’ll see when we go to reimplement the model developed by White and Rawlins (2016) in the next section is that minibatch gradient descent tends to work empirically better than standard “batch” gradient descent, in the sense of finding a parameterization where \(\log p(\boldsymbol\theta \mid \mathbf{x})\) is larger.

References

White, Aaron Steven, and Kyle Rawlins. 2016. “A Computational Model of S-Selection.” Edited by Mary Moroney, Carol-Rose Little, Jacob Collard, and Dan Burgdorf. Semantics and Linguistic Theory 26 (October): 641–63. https://doi.org/10.3765/salt.v26i0.3819.

Footnotes

  1. In fact, we already encountered such a case when fitting the two parameters of a negative binomial distribution. The maximum likelihood estimator of the negative binomial distribution’s parameters cannot be computed analytically.

  2. Unfortunately, the point may not be a maximum, but rather a saddle point. Methods using the Hessian, a generalization of the second derivative, can help deal with this (if it can be computed): at a strict local maximum the gradient is zero and the Hessian is negative definite, whereas at a saddle point the gradient is zero but the Hessian has both positive and negative eigenvalues.