Variational Inference Optimization

In my last post, I described the problem of variational inference as maximizing the ELBO, which simultaneously minimizes the KL divergence between the approximate posterior and the prior and maximizes the expected log likelihood of the data. In this post, I'm going to outline some algorithms used for optimizing the ELBO.

Variational Inference (recap)

Generative process

Suppose we have some stochastic generative process \(p_\rho(x | \zeta)\) that maps some latent variables \(\zeta\) to some observed variables \(x\). We will start by assuming that each data sample \(x^{(i)} \in \mathcal{D}\) (indexed by \(i\)) in the dataset we have access to is independent of every other sample, and that they are all generated in the following way: first, the generative process selects a latent variable \(\zeta^{(i)} \sim p_\rho(\zeta)\), and then the data is generated from the forward process \(p_\rho(x^{(i)} | \zeta^{(i)})\). The joint distribution of the latent and observed variables in this process is given by the chain rule of probability: \(p_\rho(x, \zeta) = p_\rho(x | \zeta)p_\rho(\zeta)\). Note that everything up to this point is a generative process that we are going to attempt to model. It is not the model itself. If we could observe the latent variables, our dataset would be \(\mathcal{D}^{+} \doteq \{(x^{(i)}, \zeta^{(i)}) \sim p_\rho(x, \zeta) \mid i = 1 .. N\}\). However, since \(\zeta\) is latent, our real dataset is \(\mathcal{D} \doteq \{x^{(i)} : (x^{(i)}, \zeta^{(i)}) \sim p_\rho(x, \zeta), i = 1 .. N\}\).
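
To make this concrete, here is a minimal sketch of one such hypothetical process, with a scalar Gaussian latent and a Gaussian observation model. The specific distributions and parameter values are illustrative choices of mine, not anything implied by the setup above.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 1000

# A hypothetical generative process p_rho: the "true" parameters below are
# unknown to the modeler, who only ever gets to see the x samples.
zeta = rng.normal(loc=0.0, scale=1.0, size=N)      # zeta^(i) ~ p_rho(zeta)
x = rng.normal(loc=2.0 * zeta + 1.0, scale=0.5)    # x^(i) ~ p_rho(x | zeta^(i))

dataset = x    # D contains only the observations; zeta stays hidden
```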

Generative process with time dependency in the latent space

There is no time dependency in the generative process as it is, but we could imagine that the latent variables \(\zeta\) at some time \(t\) are actually generated from the state of the latent variables at earlier times. If we added the Markov property to \(\zeta\), this would be a Hidden Markov Model (HMM), but let's not eliminate the possibility of long range temporal dependencies just yet. We can imagine an agent at time \(t\) who has access to past observations and wants to make accurate predictions of future observations. The agent can use the past observed variables to infer the state of the past latent variables that most likely generated them. Then the agent can use these inferred latent states to infer the state of the latent variables in the future, and finally from these predict what observations will be generated in the future. If we divide time up into everything before the present time \(t\) and everything after \(t\), we can symbolize the observed and latent variables as \(x_{t^-}\), \(x_{t^+}\), \(\zeta_{t^-}\), and \(\zeta_{t^+}\), respectively. Then we could write the steps of this inference process as:

\[ x_{t^-} \rightarrow \zeta_{t^-} \rightarrow \zeta_{t^+} \rightarrow x_{t^+} \]

Generative model

From the above inference process, we can see a few challenges. First, each of these three arrows represents a stochastic inference which the agent has to learn. Second and most importantly, the agent has no access to the latent variables \(\zeta\). To solve the second problem first, we are going to create latent variables \(z\) that may not correspond to the 'true' latent variables \(\zeta\) at all. Since \(\zeta\) is never observed, we can never peek behind the curtain and see how close our \(z\) is to the true \(\zeta\). The only objective we really have is to maximize the correctness of our agent's predictions about the future. Presumably the closer our \(z\) is to the hypothesized \(\zeta\), the better our agent would be at predicting the future, but we could imagine that for some particularly important \(x\)s, heuristics that create variables in \(z\) with no correspondence to \(\zeta\), but which enable better predictions, are ultimately more useful to the agent. The fact that we are optimizing for future prediction is important, because that will help us when we consider choosing the right loss function for optimization.

We have replaced \(\zeta\) with \(z\) in the conversion from process to model. Let's also replace \(p_\rho(x | \zeta)\) with \(p_\theta(x | z)\) in the generative process, and \(p_\rho(\zeta)\) with \(p_z(z_{t^+} | z_{t^-})\). Then the inference process would become:

\[ x_{t^-} \overset{q_\phi}{\longrightarrow} z_{t^-} \overset{p_z}{\longrightarrow} z_{t^+} \overset{p_\theta}{\longrightarrow} x_{t^+} \]

Note that \(q_\phi\) is a stochastic map from \(x_{t^-}\) to \(z_{t^-}\), which says: "given all the past observations I've made, what is the probability distribution over past latent variables". This is the inverse of the generative process, which in English says: "given all the past latent variables, what is the probability distribution over past observations". Our models \(p_z\) and \(p_\theta\) already carry this information, so by Bayes' rule we can express our desire that these be inverses of each other; we will make this precise below.

Suppose that all the time dependencies in the latent space are temporally bounded, so that we can go back far enough to some \(t=0\) where \(p_z(z_{t^+} | z_{0^+}, z_{0^-}) = p_z(z_{t^+} | z_{0^+})\). In other words, the latents we are concerned about are conditionally independent of the values of the latents at all times before \(t=0\), conditioned on their values after \(t=0\). Then we can introduce a stationary global prior \(p_\omicron(z_{0^-})\). Our joint generative model can then be written \(p(x, z) = p_\omicron(z_{0^-}) p_z(z_{0^+} | z_{0^-}) p_\theta(x | z)\). In words, the joint distribution is factored into the product of a stationary prior, a latent transition model, and a conditional likelihood.

Back to stationary latents

So far we've been assuming that there is some temporal model in the latent space, which I find helpful for thinking about the ultimate motivation (to me) of variational inference: making accurate predictions about future observations. However, if we make a common assumption that there is no temporal component, and that the latent variables are IID, then we can replace \(p_\omicron(z_{0^-}) p_z(z_{0^+} | z_{0^-})\) with a single stationary \(p_z(z)\). For the rest of this post, we will assume that the latents are IID.

Variational Distribution

If we take the log of the marginal likelihood \(p(x)\), we can perform an algebraic manipulation to get:

\[ \begin{align} \ln p(x) &= \ln \frac{p(z, x)}{p(z|x)} \\ &= \underset{z \sim q_\phi(z|x)}{\mathbb{E}} \left[ \ln \frac{p(z, x)}{p(z|x)} \right] &&\text{the argument is constant in } z \\ &= \underset{z \sim q_\phi(z|x)}{\mathbb{E}} \left[ \ln \frac{q_\phi(z|x)}{q_\phi(z|x)} \frac{p_z(z) p_\theta(x | z) }{p(z|x)} \right] \\ &= \underset{z \sim q_\phi(z|x)}{\mathbb{E}} \left[ \ln \frac{q_\phi(z|x)}{p(z|x)} + \ln \frac{p_z(z) p_\theta(x | z) }{q_\phi(z|x)} \right] \\ &= D_{KL}(q_\phi(z|x) || p(z|x)) + \underbrace{\underset{z \sim q_\phi(z|x)}{\mathbb{E}} \left[ \ln \frac{p_z(z) p_\theta(x | z) }{q_\phi(z|x)} \right]}_{\text{ELBO}} \end{align} \]

Here \(q_\phi(z|x)\) is our approximate posterior, which we can think of as a stochastic map from observed variables to latent variables that approximates the posterior of the generative model, \(q_\phi(z|x) \approx p(z|x)\).

In order to maximize the log likelihood of the forward model, we use the variational distribution \(q_\phi(z|x)\) to approximate the posterior \(p(z|x)\), and then optimize the ELBO, which jointly minimizes the KL divergence between the approximate posterior and the prior, and maximizes the expected log likelihood of the data.
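
As a concrete sketch of what the ELBO looks like numerically, here is a Monte Carlo estimate for a single datapoint under an assumed toy model: a standard normal prior \(p_z\), a Gaussian likelihood \(p_\theta(x|z) = \mathcal{N}(\theta z, 1)\), and a Gaussian variational distribution \(q_\phi(z|x)\). These choices (and the use of PyTorch) are mine for illustration; the algebra above does not depend on them.

```python
import torch
from torch.distributions import Normal

def elbo_estimate(x, mu_q, sigma_q, theta, K=32):
    """Monte Carlo ELBO for one datapoint x, under the toy model
    p_z = N(0, 1), p_theta(x|z) = N(theta * z, 1), q_phi(z|x) = N(mu_q, sigma_q)."""
    x = torch.as_tensor(x, dtype=torch.float32)
    q = Normal(mu_q, sigma_q)
    z = q.sample((K,))                               # z_k ~ q_phi(z|x)
    log_pz = Normal(0.0, 1.0).log_prob(z)            # ln p_z(z_k)
    log_px_z = Normal(theta * z, 1.0).log_prob(x)    # ln p_theta(x | z_k)
    log_q = q.log_prob(z)                            # ln q_phi(z_k | x)
    return (log_pz + log_px_z - log_q).mean()        # average of ln [p_z p_theta / q_phi]

print(elbo_estimate(1.3, mu_q=torch.tensor(0.0), sigma_q=torch.tensor(1.0), theta=2.0))
```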

Optimization

Gradient ascent on the ELBO

If we can compute the gradient of the ELBO wrt parameters \(\phi\) and \(\theta\), we can use gradient ascent to optimize the ELBO. If we symbolize the ELBO as \(L(\phi, \theta)\), the optimization problem is:

\[ \phi^*, \theta^* = \underset{\phi, \theta}{\text{argmax }} L(\phi, \theta) \]
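
The argmax is approached iteratively; with a step size \(\eta\), the gradient ascent updates are:

\[ \phi \leftarrow \phi + \eta \nabla_\phi L(\phi, \theta), \qquad \theta \leftarrow \theta + \eta \nabla_\theta L(\phi, \theta) \]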

To compute the gradient wrt \(\theta\), we can expand the ELBO, and push the gradient through the expectation:

\[ \begin{align} \nabla_{\theta} L(\phi, \theta) &= \nabla_{\theta} \underset{z \sim q_\phi(z|x)}{\mathbb{E}} \left[ \ln \frac{p_z(z) p_\theta(x | z) }{q_\phi(z|x)} \right] \\ &= \underset{z \sim q_\phi(z|x)}{\mathbb{E}} \left[ \nabla_{\theta} \ln \frac{p_z(z) p_\theta(x | z) }{q_\phi(z|x)} \right] \\ &= \underset{z \sim q_\phi(z|x)}{\mathbb{E}} \left[ \nabla_{\theta} \ln p_\theta(x | z) \right] \\ \end{align} \]

We can estimate this gradient by sampling \(z^{(i)}_k \sim q_\phi(z|x^{(i)})\) from our approximate posterior, where \(i\) is the index of the data sample \(x^{(i)}\), and \(k\) is the index of the sample from the approximate posterior (between 1 and \(K\)). Then we just compute \(\nabla_{\theta} \ln p_\theta(x^{(i)} | z^{(i)}_k)\) for each sample, and take the average.
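
Here is what that estimator looks like in the toy Gaussian model from the earlier sketch, letting autograd compute \(\nabla_{\theta} \ln p_\theta(x | z_k)\). Again, the specific model and the use of PyTorch are illustrative assumptions.

```python
import torch
from torch.distributions import Normal

# Monte Carlo estimate of grad_theta ELBO in the toy model
# p_theta(x|z) = N(theta * z, 1), q_phi(z|x) = N(mu_q, sigma_q).
theta = torch.tensor(1.0, requires_grad=True)
x = torch.tensor(1.3)
mu_q, sigma_q = torch.tensor(0.0), torch.tensor(1.0)

K = 64
z = Normal(mu_q, sigma_q).sample((K,))           # z_k ~ q_phi(z|x); no gradient flows through sampling
log_px_z = Normal(theta * z, 1.0).log_prob(x)    # ln p_theta(x | z_k)
log_px_z.mean().backward()                       # averages grad_theta ln p_theta(x|z_k) over the K samples
print(theta.grad)
```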

However, computing the gradient of the ELBO wrt \(\phi\) is a bit more complicated. This gradient is given by:

\[ \nabla_{\phi} L(\phi, \theta) = \nabla_{\phi} \underset{z \sim q_\phi(z|x)}{\mathbb{E}} \left[ \ln \frac{p_z(z) p_\theta(x | z) }{q_\phi(z|x)} \right] \]

The paper by Schulman et al. (2015) was helpful for me to understand exactly when gradients can be pushed through expectations. Because the parameters we are computing the gradient wrt appear both in the distribution the expectation is over and inside the expectation, we have to resort to some mathematical tricks. There are two main strategies (also described in Yuge (Jimmy) Shi's blog post and in this Pyro tutorial): the pathwise derivative estimator and the score function estimator.

Pathwise derivative estimator

This is also called the reparameterization trick. If the distribution can be reparameterized as a deterministic function of the parameters and a fixed noise variable, so that the parameters and the randomness take two independent computational paths, then we can rewrite the expectation to be over the fixed noise variable, and we can push the gradient through.
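
For a Gaussian \(q_\phi(z|x) = \mathcal{N}(\mu, \sigma)\) (my running assumption), the reparameterization is \(z = \mu + \sigma \epsilon\) with \(\epsilon \sim \mathcal{N}(0, 1)\). A minimal sketch:

```python
import torch

# Reparameterization sketch for q_phi(z|x) = N(mu, sigma), phi = (mu, log_sigma):
# sample the fixed noise eps first, then build z as a deterministic function
# of (phi, eps) so that gradients can flow back to phi.
mu = torch.tensor(0.0, requires_grad=True)
log_sigma = torch.tensor(0.0, requires_grad=True)

eps = torch.randn(64)                  # eps_k ~ N(0, 1), independent of phi
z = mu + log_sigma.exp() * eps         # z_k = mu + sigma * eps_k, differentiable in phi

# Any Monte Carlo estimate of E_q[f(z)] built from this z, e.g. f(z).mean(),
# can now be backpropagated directly to mu and log_sigma.
# (torch.distributions.Normal(mu, log_sigma.exp()).rsample((64,)) does the same thing.)
```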

Score function estimator

This is also called REINFORCE, or the likelihood ratio estimator. In statistics, the gradient of the log of a function with respect to its parameters is called the score function. Applying the chain rule of differential calculus, we can write:

\[ \nabla_{\phi} \ln f(\phi) = \frac{\nabla_{\phi} f(\phi) }{f(\phi)} \]

To see how this works when taking the gradient of the ELBO:

\[ \begin{align} \nabla_{\phi} L(\phi, \theta) &= \nabla_{\phi} \underset{z \sim q_\phi(z|x)}{\mathbb{E}} \left[ \ln \frac{p_z(z) p_\theta(x | z) }{q_\phi(z|x)} \right] \\ &= \nabla_{\phi} \underset{z \sim q_\phi(z|x)}{\mathbb{E}} \left[ \ln p_z(z) p_\theta(x | z) - \ln q_\phi(z|x) \right] \\ &= \nabla_{\phi} \underset{z \sim q_\phi(z|x)}{\mathbb{E}} \ln p_z(z) p_\theta(x | z) - \nabla_{\phi} \underset{z \sim q_\phi(z|x)}{\mathbb{E}} \ln q_\phi(z|x) \\ \end{align} \]

The first term doesn't have \(\phi\) inside the expectation, so we can push the gradient through using the following identity:

\[ \begin{align} \nabla_{\phi} \underset{z \sim q_\phi(z|x)}{\mathbb{E}} f(z) &= \nabla_\phi \sum_z q_\phi(z|x) f(z) \\ &= \sum_z \nabla_\phi q_\phi(z|x) f(z) && \text{linearity of the gradient} \\ &= \sum_z q_\phi(z|x) \frac{\nabla_\phi q_\phi(z|x) }{q_\phi(z|x)} f(z) && \text{multiply by 1} \\ &= \sum_z q_\phi(z|x) \nabla_\phi \ln q_\phi(z|x) f(z) && \text{use the score function identity} \\ &= \underset{z \sim q_\phi(z|x)}{\mathbb{E}} \nabla_\phi \ln q_\phi(z|x) f(z) && \text{definition of expectation} \\ \end{align} \]

We can use this to push the gradient through so long as \(f(z)\) is not a function of \(\phi\). This will handle the first term. I'm not going to write out the derivation, but the second term can be algebraically manipulated to yield the following identity:

\[ \nabla_\phi \underset{z \sim q_\phi(z|x)}{\mathbb{E}} \ln q_\phi(z|x) = \underset{z \sim q_\phi(z|x)}{\mathbb{E}} \ln q_\phi(z|x) \nabla_\phi \ln q_\phi(z|x) \]

Putting both these terms together, we get:

\[ \nabla_\phi L(\phi, \theta) = \underset{z \sim q_\phi(z|x)}{\mathbb{E}} \left[ \ln \frac{p_z(z) p_\theta(x|z)}{q_\phi(z|x)} \nabla_\phi \ln q_\phi(z|x) \right] \]

This can be estimated by sampling \(z^{(i)}_k \sim q_\phi(z|x^{(i)})\) from our approximate posterior, where \(i\) is the index of the data sample \(x^{(i)}\), and \(k\) is the index of the sample from the approximate posterior (between 1 and \(K\)). Then we just compute \(\ln \frac{p_z(z^{(i)}_k) p_\theta(x^{(i)} | z^{(i)}_k) }{q_\phi(z^{(i)}_k|x^{(i)})} \nabla_\phi \ln q_\phi(z^{(i)}_k|x^{(i)})\) for each sample, and take the average.
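
Below is a sketch of this estimator in the same toy Gaussian model. With autograd it is convenient to build a surrogate objective whose gradient matches the expression above: detach the weight \(\ln \frac{p_z p_\theta}{q_\phi}\) so that only the \(\ln q_\phi\) factor is differentiated. The surrogate trick and the toy model are my implementation choices, not something prescribed above.

```python
import torch
from torch.distributions import Normal

# Score function (REINFORCE) estimator of grad_phi ELBO in the toy model
# p_z = N(0,1), p_theta(x|z) = N(theta*z, 1), q_phi(z|x) = N(mu, sigma),
# with phi = (mu, log_sigma).
mu = torch.tensor(0.0, requires_grad=True)
log_sigma = torch.tensor(0.0, requires_grad=True)
theta, x = 2.0, torch.tensor(1.3)

K = 256
q = Normal(mu, log_sigma.exp())
z = q.sample((K,))                      # plain sampling: no reparameterization needed
log_q = q.log_prob(z)                   # ln q_phi(z_k|x), differentiable in phi
log_w = Normal(0.0, 1.0).log_prob(z) + Normal(theta * z, 1.0).log_prob(x) - log_q

# The gradient of this surrogate is the Monte Carlo average of
# log_w * grad_phi ln q_phi(z_k|x), i.e. the estimator above.
surrogate = (log_w.detach() * log_q).mean()
surrogate.backward()
print(mu.grad, log_sigma.grad)
```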

Pathwise derivative estimator vs score function estimator

The score function estimator is more general than the pathwise derivative estimator, because the score function estimator can handle distributions that cannot be reparameterized (including discrete ones). However, the variance of the gradient estimate is higher for the score function estimator, meaning that your stochastic gradient ascent might be a little too stochastic.

Extensions

The name of the game for ELBO gradient estimators is minimizing variance. The score function estimator and the pathwise derivative estimator both give the correct gradient in expectation, but more can be done to minimize the variance. There are two main strategies for variance reduction: baseline subtraction and Rao-Blackwellization.
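
Baseline subtraction is the easier of the two to show in a few lines. Continuing with the toy model and the score function surrogate from above (both illustrative assumptions): subtracting a constant \(b\) from the weight leaves the gradient unbiased, because \(\mathbb{E}_{q}[\nabla_\phi \ln q_\phi(z|x)] = 0\), but it can substantially reduce the variance. Rao-Blackwellization (analytically integrating out part of the expectation) is not shown here.

```python
import torch
from torch.distributions import Normal

# Score function estimator with a simple baseline, same toy model as before.
mu = torch.tensor(0.0, requires_grad=True)
log_sigma = torch.tensor(0.0, requires_grad=True)
theta, x = 2.0, torch.tensor(1.3)

q = Normal(mu, log_sigma.exp())
z = q.sample((256,))
log_q = q.log_prob(z)
log_w = Normal(0.0, 1.0).log_prob(z) + Normal(theta * z, 1.0).log_prob(x) - log_q

# Use the batch mean of the weights as the baseline b. (A leave-one-out or
# learned baseline avoids the slight correlation this introduces, but the
# batch mean is enough to see the idea.)
baseline = log_w.mean()
surrogate = ((log_w - baseline).detach() * log_q).mean()
surrogate.backward()
```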

Importance Weighted Autoencoder (IWAE)

Everything above optimized the ELBO. Per Cremer, Morris and Duvenaud (2017), the IWAE gradient estimator "optimizes the standard variational lower bound, but using a more complex distribution." To get the IWAE gradient estimator, we can expand the likelihood (not the log likelihood as we did to get the ELBO):

\[ \begin{align} p(x) &= \sum_z p(x, z) \\ &= \sum_z p_z(z) p_\theta(x | z) \\ &= \underset{z \sim p_z(z)}{\mathbb{E}} p_\theta(x | z) \\ &= \underset{z \sim q_\phi(z|x)}{\mathbb{E}}\left[ \frac{p_z(z)}{q_\phi(z|x)} p_\theta(x | z) \right] \end{align} \]

We can then maximize this expression with respect to the parameters \(\phi\) and \(\theta\) by sampling \(K\) samples \(z_k \sim q_\phi(z|x)\) and averaging the importance weights inside the log. Keeping the log outside of this average (as detailed in Jimmy Shi's blog post) gives a tighter bound on the log likelihood than the standard ELBO.

\[ \begin{align} \nabla_{\phi, \theta} \ln p(x) \approx \nabla_{\phi, \theta} \ln \left( \frac{1}{K} \sum_k \frac{p_z(z_k) p_\theta(x | z_k)}{q_\phi(z_k|x)} \right) \\ \end{align} \]
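
As a final sketch, here is the \(K\)-sample IWAE bound in the same toy Gaussian model, computed with a log-sum-exp for numerical stability and with reparameterized sampling so that gradients reach \(\phi\) as well as \(\theta\). As before, the specific model is an illustrative assumption.

```python
import math

import torch
from torch.distributions import Normal

def iwae_bound(x, mu, log_sigma, theta, K=16):
    """ln( (1/K) sum_k p_z(z_k) p_theta(x|z_k) / q_phi(z_k|x) ) for the toy model
    p_z = N(0,1), p_theta(x|z) = N(theta*z, 1), q_phi(z|x) = N(mu, sigma)."""
    q = Normal(mu, log_sigma.exp())
    z = q.rsample((K,))                                   # reparameterized, so grads reach phi
    log_w = (Normal(0.0, 1.0).log_prob(z)
             + Normal(theta * z, 1.0).log_prob(x)
             - q.log_prob(z))                             # log importance weights
    return torch.logsumexp(log_w, dim=0) - math.log(K)    # log-mean-exp of the weights
```

With \(K = 1\) this reduces to a single-sample estimate of the ELBO; increasing \(K\) tightens the bound, and backpropagating through it gives gradients for both \(\phi\) and \(\theta\).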

Conclusion

This post outlined the problem of variational inference as maximizing the ELBO. Many other algorithms can be seen as modifications of this goal. If we could identify the parameters of the variational and generative models that maximize the ELBO, we would be further along in our goal of building an agent that takes historical observations and learns the latent structure of the generative process, so that from new observations it can infer the latent state of the world and predict the temporal evolution of that latent state and the future observations it implies.