This post follows the lecture from Pieter Abbeel's 2024 Deep Unsupervised Learning course.
Variational Inference as Importance Sampling
Let \(p_θ(X|Z)\) be a forward model from latent variables \(Z\) to observed variables \(X\). Let \(p_Z(Z)\) be an (initially fixed) prior distribution over the latent variables. Let \(\mathcal{D} = \{x^{(1)}, \ldots, x^{(n)}\}\) be a dataset of observed variables.
Our goal is to find a \(\theta\) that maximizes the likelihood of the data, \(p_θ(\mathcal{D}) = \prod_{i=1}^n p_θ(x^{(i)})\).
To find \(\theta^*\), we need to compute the marginal likelihood \(p_θ(x^{(i)}) = \sum_{z} p_Z(z) p_θ(x^{(i)}|z)\), a sum over all possible values of \(Z\). This sum is generally intractable, so we can approximate it by sampling: for each data point, draw \(K\) samples \(z^{(i)}_1, \ldots, z^{(i)}_K \sim p_Z(Z)\) and average the likelihood over them. Then our estimate of the optimal \(\theta\) is
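\[
\theta^* \approx \arg\max_\theta \sum_{i=1}^n \log \left( \frac{1}{K} \sum_{k=1}^K p_θ(x^{(i)} | z^{(i)}_k) \right), \qquad z^{(i)}_k \sim p_Z(Z).
\]
This is the standard Monte Carlo estimate of the log-likelihood, using the prior as the sampling distribution.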
Now consider the case where, for a given pair of observed and latent variables \((x^{(i)}, z^{(i)}_k)\), the value of \(p_θ(z^{(i)}_k|x^{(i)})\) is close to zero, and hence the gradient contribution of that sample to the log-likelihood is close to zero. Let's interpret this situation as a causal process. \(p_θ(x^{(i)}|z^{(i)}_k)\) is the forward model, which can be interpreted as "supposing the latent \(z\) really is \(z^{(i)}_k\), these are all the different (probability-weighted) ways that \(x\) could come to be." Then \(p_θ(z^{(i)}_k|x^{(i)})\) is the causally backwards model, which can be interpreted as "supposing the observed \(x\) really is \(x^{(i)}\), this is the probability that \(z^{(i)}_k\), which is one of the possible causes of \(x\), actually was the cause." If \(p_θ(z^{(i)}_k|x^{(i)})\) is close to zero, then we've essentially sampled a \(z\) that isn't a cause of \(x^{(i)}\); it's totally unrelated, our gradient is close to zero, and the sample doesn't help us learn \(p_θ(x^{(i)}|z)\).
How can we do a better job of sampling \(z\)s that are more likely to be the cause of the observed \(x\), and therefore more likely to help us learn \(p_θ(x|z)\)? We can sample from \(p_θ(z|x)\), the posterior distribution over the latent causes of the observed \(x\). However, this is typically intractable, so we can learn a tractable approximation to it: a variational distribution \(q_\phi(z|x) \approx p_θ(z|x)\). Because we're sampling from a different distribution than \(p_Z(Z)\), we just need to reweight the samples. This reweighting is called 'importance sampling'.
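Concretely, the marginal likelihood can be rewritten as an expectation under \(q_\phi\):
\[
p_θ(x) = \sum_{z} p_Z(z)\, p_θ(x|z) = \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \frac{p_Z(z)}{q_\phi(z|x)}\, p_θ(x|z) \right],
\]
so we can sample \(z_k \sim q_\phi(z|x)\) and weight each sample by \(p_Z(z_k)/q_\phi(z_k|x)\) to get an unbiased estimate of the same quantity.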
This gives us a tractable objective function, but we haven't specified how to obtain \(q_\phi(z|x)\). This is our estimate of the backwards model \(p_θ(z|x)\). We can use a neural network to parameterize \(q_\phi(z|x)\) and then optimize the parameters of the neural network to minimize the KL divergence between \(q_\phi(z|x)\) and \(p_θ(z|x)\).
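In symbols, the objective for the variational parameters is
\[
\phi^* = \arg\min_\phi D_{\text{KL}}(q_\phi(z|x) || p_θ(z|x)).
\]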
We can alternate between optimizing \(\theta\) and \(\phi\). \(q_\phi(z|x)\) gives us a better estimate of the latent variables corresponding to an observed sample, and \(p_θ(x|z)\) gives us a better estimate of the forward model from the sampled latent variables to the observed variables. The hope is that, at the end of this process, we will have learned forward and backward models that are close to the true generative process of the data.
Interpreting the Terms of Variational Inference
In a previous post I illustrated how KL divergence can be interpreted as the average difference in surprise between a true distribution and an agent's estimate of that distribution. Surprise is defined as the negative log probability: \(\mathbb{S}(P(x)) := -\log P(x)\). Our KL term measuring the goodness of our estimate of the backwards model \(p_θ(z|x)\) can be expanded as follows:
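\[
D_{\text{KL}}(q_\phi(z|x) || p_θ(z|x)) = \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \log \frac{q_\phi(z|x)}{p_θ(z|x)} \right] = \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \mathbb{S}(p_θ(z|x)) - \mathbb{S}(q_\phi(z|x)) \right]
\]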
The true backwards model \(p_θ(z|x)\) can be expanded by Bayes' rule:
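\[
p_θ(z|x) = \frac{p_Z(z)\, p_θ(x|z)}{p_θ(x)}
\]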
Substituting this into our KL divergence term, we get:
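\[
D_{\text{KL}}(q_\phi(z|x) || p_θ(z|x)) = \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \mathbb{S}(p_θ(x|z)) \right] + D_{\text{KL}}(q_\phi(z|x) || p_Z(z)) - \mathbb{S}(p_θ(x))
\]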
We can interpret several of the terms in this equation.
- \(\mathbb{S}(p_θ(x))\) is the surprise of the model evidence. The model evidence is the probability of observing the data given the model, and we are trying to find the model parameters that maximize this. Another way of thinking about it is that we are trying to minimize the surprise of the evidence; given our model, future observations should be minimally surprising. The model evidence can be expanded as \(p_θ(x) = \sum_{z \in Z} p_Z(z) p_θ(x|z)\), so the surprise of the model evidence is \(\mathbb{S}(p_θ(x)) = -\log(\sum_{z \in Z} p_Z(z) p_θ(x|z))\).
- \(\underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \mathbb{S}(p_θ(x|z)) \right]\) is the reconstruction loss. Here the entire backwards-forwards process can be seen. Given an observation \(x^{(i)}\), we sample some latent variable that is likely to be the cause of the observation. Then we use the forward model to generate the observation. We want to minimize the surprise of the reconstruction.
- \(D_{\text{KL}}(q_\phi(z|x) || p_Z(z))\) is the variational regularization term. This term forces the distribution \(q_\phi(z|x)\) to stay as close as possible to the prior distribution \(p_Z(Z)\).
We can also rearrange these terms without the surprise \(\mathbb{S}\) function at all. Starting from the expanded KL divergence above:
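\[
\log p_θ(x) = D_{\text{KL}}(q_\phi(z|x) || p_θ(z|x)) + \underbrace{\underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \log p_θ(x|z) \right] - D_{\text{KL}}(q_\phi(z|x) || p_Z(z))}_{\text{ELBO}}
\]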
The last equation says that the log evidence is the sum of the KL divergence and the Evidence Lower Bound (ELBO). The reason it's called a lower bound on the evidence should be clear from the equation it is in: the KL divergence is always non-negative, so the ELBO is always less than or equal to the log evidence. It is only equal to the log evidence when the KL divergence is zero. We can interpret this equation in terms of our alternating optimization process described above. We seek to maximize the log evidence. We fix \(q_\phi(z|x)\), then find a \(\theta\) that maximizes the ELBO; since the ELBO is a lower bound, pushing it up should push up the log evidence. Then we fix \(p_θ\) and adjust \(q_\phi\) to minimize the KL divergence, tightening the bound.
Jensen's Inequality
To summarize the objective in importance sampling, we have
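\[
\log p_θ(x) = \log \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \frac{p_Z(z)}{q_\phi(z|x)}\, p_θ(x|z) \right]
\]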
Note how the ELBO term is an expectation of a log, whereas in our importance sampling objective above we had a log of an expectation. By Jensen's inequality (which applies because \(\log\) is concave), \(\mathbb{E}[\log X] \leq \log \mathbb{E}[X]\).
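Applying this to the importance-sampling form of the log evidence gives
\[
\log p_θ(x) = \log \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \frac{p_Z(z)\, p_θ(x|z)}{q_\phi(z|x)} \right] \geq \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \log \frac{p_Z(z)\, p_θ(x|z)}{q_\phi(z|x)} \right] = \text{ELBO}
\]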
This is another way to see that the ELBO is a lower bound on the log evidence. We could use the decomposition of the log evidence above to see that the KL divergence \(D_{\text{KL}}(q_\phi(z|x) || p_θ(z|x)) = 0\) if and only if the ELBO is equal to the log evidence. Another simple way to see this is to first suppose that \(q_\phi(z|x) = p_θ(z|x)\). The KL divergence would then be zero, and using Bayes' rule we can see that the ELBO is equal to the log evidence.
Variational Free Energy
The variational free energy is a concept from statistical physics that carries into variational inference by analogy. In physics, the Helmholtz free energy \(F = U - TS\) represents the amount of energy available to do work, where \(U\) is internal energy, \(T\) is temperature, and \(S\) is entropy.
In the context of variational inference, we can define the variational free energy \(F\) as follows (where \(\mathbb{S}\) is the surprise function, the negative log probability):
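\[
F := \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \mathbb{S}(p_θ(x|z)) \right] + D_{\text{KL}}(q_\phi(z|x) || p_Z(z)) = -\text{ELBO}
\]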
In this analogy the free energy is the difference of two terms:
- Internal Energy (U): The term \(\underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \mathbb{S}(p_θ(x|z)) \right]\) represents how poorly our model reconstructs the data (the reconstruction loss). As the reconstruction of \(x\) becomes more accurate, the internal energy decreases. In variational inference, we try to minimize the internal energy (i.e. the reconstruction loss).
- Entropy Term (TS): The negative KL divergence \(-D_{\text{KL}}(q_\phi(z|x) || p_Z(z))\) represents the entropy term. The KL divergence measures how much information is lost when using \(q_\phi(z|x)\) to approximate \(p_Z(z)\), so its negative can be interpreted as a form of entropy.
Minimizing the free energy is equivalent to maximizing the ELBO, which in our context means finding the best trade-off between:
- Accurately reconstructing the data (minimizing the reconstruction loss)
- Keeping the approximate posterior close to the prior (minimizing the KL divergence)
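To make this trade-off concrete, here is a minimal sketch (in plain NumPy) of a Monte Carlo estimate of the negative ELBO, i.e. the variational free energy, for a toy model chosen purely for illustration: a standard normal prior \(p_Z(z) = \mathcal{N}(0, 1)\), a forward model \(p_θ(x|z) = \mathcal{N}(\theta z, \sigma^2)\), and a Gaussian variational posterior \(q_\phi(z|x) = \mathcal{N}(\mu, s^2)\). The model and the function names are just illustrative choices.

```python
import numpy as np

def gaussian_log_pdf(x, mean, std):
    # log N(x; mean, std^2)
    return -0.5 * np.log(2 * np.pi * std ** 2) - 0.5 * ((x - mean) / std) ** 2

def negative_elbo(x, q_mean, q_std, theta, obs_std=1.0, num_samples=10_000, seed=0):
    # Variational free energy F = E_{z~q}[ -log p_theta(x|z) ] + KL(q || p_Z)
    # for the toy model p_Z(z) = N(0, 1), p_theta(x|z) = N(theta * z, obs_std^2),
    # q_phi(z|x) = N(q_mean, q_std^2).
    rng = np.random.default_rng(seed)
    # Sample z ~ q_phi(z|x) via the reparameterization z = mean + std * eps
    z = q_mean + q_std * rng.standard_normal(num_samples)
    # Reconstruction term: Monte Carlo estimate of E_q[ S(p_theta(x|z)) ]
    reconstruction = -np.mean(gaussian_log_pdf(x, theta * z, obs_std))
    # Regularization term: KL(N(q_mean, q_std^2) || N(0, 1)) in closed form
    kl = 0.5 * (q_std ** 2 + q_mean ** 2 - 1.0 - np.log(q_std ** 2))
    return reconstruction + kl

# A q that puts its mass on plausible causes of x has lower free energy
# than one that puts its mass on implausible causes.
print(negative_elbo(x=2.0, q_mean=1.0, q_std=0.7, theta=1.0))   # lower F
print(negative_elbo(x=2.0, q_mean=-2.0, q_std=0.7, theta=1.0))  # higher F
```

The first call should report a noticeably lower free energy than the second, because its \(q\) places its mass near latent values that plausibly caused \(x = 2\).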
Another way to express the free energy starts from our ELBO equation:
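\[
F = -\text{ELBO} = D_{\text{KL}}(q_\phi(z|x) || p_θ(z|x)) - \log p_θ(x) = D_{\text{KL}}(q_\phi(z|x) || p_θ(z|x)) + \mathbb{S}(p_θ(x))
\]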
Here the imperative to minimize the free energy mirrors our original objective from the importance sampling section. We seek a \(q_\phi(z|x)\) that minimizes the KL divergence between the true backwards model \(p_θ(z|x)\) and our estimate \(q_\phi(z|x)\). Simultaneously, we adjust \(p_θ\) to minimize the surprise of the model evidence \(p_θ(x)\).