Variational Inference

This post follows the lecture from Pieter Abbeel's 2024 Deep Unsupervised Learning course.

Variational Inference as Importance Sampling

Let \(p_θ(X|Z)\) be a forward model from latent variables \(Z\) to observed variables \(X\). Let \(p_Z(Z)\) be an (initially fixed) prior distribution over the latent variables. Let \(\mathcal{D} = \{x^{(1)}, \ldots, x^{(n)}\}\) be a dataset of observed variables.

Our goal is to find a \(\theta\) that maximizes the likelihood of the data, \(p_θ(\mathcal{D})\). Each observation has its own latent variable, which we marginalize out separately.

\[ \begin{aligned} \theta^* &= \arg \max_\theta p_θ(\mathcal{D}) \\ &= \arg \max_\theta \prod_{i=1}^n p_θ(x^{(i)}) &&\text{Probabilities multiply over independent observations}\\ &= \arg \max_\theta \prod_{i=1}^n \sum_{z \in Z} p_Z(Z=z) p_θ(x^{(i)} | Z=z) &&\text{Summing over all possible values of $Z$ for each observation}\\ &= \arg \max_\theta \sum_{i=1}^n \log \sum_{z \in Z} p_Z(Z=z) p_θ(x^{(i)} | Z=z) &&\text{$\log$ is monotonic, so $\theta^*$ is unchanged}\\ \end{aligned} \]

To find \(\theta^*\), we need to compute the inner sum over all possible values of \(Z\). This is generally intractable. However, each inner sum is an expectation, \(\sum_{z} p_Z(z) p_θ(x^{(i)}|z) = \mathbb{E}_{z \sim p_Z}[p_θ(x^{(i)}|z)]\), so we can approximate it by drawing \(K\) samples \(z^{(i)}_1, \ldots, z^{(i)}_K \sim p_Z(Z)\) for each observation and averaging \(p_θ(x^{(i)}|z^{(i)}_k)\) over them. Our estimate of the optimal \(\theta\) is then

\[ \theta^* \approx \arg \max_\theta \sum_{i=1}^n \log \frac{1}{K} \sum_{k=1}^K p_θ(x^{(i)} | z^{(i)}_k) \]
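To make this estimator concrete, here is a minimal sketch in Python (standard library only) on a hypothetical toy model that is not from the lecture: a standard normal prior \(z \sim \mathcal{N}(0, 1)\) and a Gaussian forward model \(x | z \sim \mathcal{N}(z, \sigma^2)\) with \(\sigma = 0.5\) and \(θ\) held fixed. The model is chosen only because its exact marginal, \(x \sim \mathcal{N}(0, 1 + \sigma^2)\), lets us check the Monte Carlo estimate. The function names and the three data points are illustrative.

```python
import math
import random

# Hypothetical toy model (an illustrative assumption, not from the lecture):
#   prior           z ~ N(0, 1)
#   forward model   x | z ~ N(z, SIGMA^2)
#   exact marginal  x ~ N(0, 1 + SIGMA^2), used only to check the estimate
SIGMA = 0.5


def log_normal(x, mean, var):
    """Log density of N(mean, var) at x."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)


def log_mean_exp(log_values):
    """Numerically stable log((1/K) * sum_k exp(v_k))."""
    m = max(log_values)
    return m + math.log(sum(math.exp(v - m) for v in log_values)) - math.log(len(log_values))


def log_px_prior_sampling(x, K, rng):
    """Estimate log p(x) by log (1/K) sum_k p(x | z_k) with z_k ~ p_Z."""
    zs = [rng.gauss(0.0, 1.0) for _ in range(K)]
    return log_mean_exp([log_normal(x, z, SIGMA ** 2) for z in zs])


def log_likelihood_estimate(data, K, rng):
    """Sum over observations of the per-example estimate (the objective above)."""
    return sum(log_px_prior_sampling(x, K, rng) for x in data)


def log_likelihood_exact(data):
    """Exact log-likelihood under the toy model's closed-form marginal."""
    return sum(log_normal(x, 0.0, 1.0 + SIGMA ** 2) for x in data)


data = [0.3, -1.2, 0.8]
rng = random.Random(0)
print("estimate:", log_likelihood_estimate(data, K=10_000, rng=rng))
print("exact:   ", log_likelihood_exact(data))
```

Working in log space and using `log_mean_exp` avoids underflow when the individual likelihoods \(p_θ(x^{(i)}|z^{(i)}_k)\) are tiny, which is exactly the regime discussed next.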

Now suppose that, for a sampled (observed, latent) pair \((x^{(i)}, z^{(i)}_k)\), the posterior probability \(p_θ(z^{(i)}_k|x^{(i)})\) is close to zero. Let's interpret this situation as a causal process. \(p_θ(x^{(i)}|z^{(i)}_k)\) is the forward model, which can be read as "supposing the latent \(z\) really is \(z^{(i)}_k\), these are all the different (probability-weighted) ways that \(x\) could come to be." Then \(p_θ(z^{(i)}_k|x^{(i)})\) is the causally backwards model, which can be read as "supposing the observed \(x\) really is \(x^{(i)}\), this is the probability that \(z^{(i)}_k\), one of the possible causes of \(x\), actually was the cause of \(x\)." If \(p_θ(z^{(i)}_k|x^{(i)})\) is close to zero, then we've sampled a \(z\) that is essentially unrelated to \(x^{(i)}\). By Bayes' rule, \(p_θ(z|x) \propto p_θ(x|z)\,p_Z(z)\), and since \(z^{(i)}_k\) was drawn from the prior, a near-zero posterior typically means \(p_θ(x^{(i)}|z^{(i)}_k)\) is also near zero. Such a sample contributes almost nothing to the sum above or to its gradient, so it doesn't help us learn \(p_θ(x|z)\).

How can we do a better job of sampling \(z\)s that are likely to be the cause of the observed \(x\), and therefore likely to help us learn \(p_θ(x|z)\)? Ideally we would sample from the posterior \(p_θ(z|x)\), which concentrates on exactly those \(z\)s. However, the posterior is typically intractable, so instead we learn a tractable approximation to it: a variational distribution \(q_\phi(z|x) \approx p_θ(z|x)\). Because we are now sampling from \(q_\phi(z|x)\) rather than \(p_Z(Z)\), we must reweight each sample by \(p_Z(z)/q_\phi(z|x)\) so that the average still estimates the same quantity. This reweighting is called importance sampling.

\[ \begin{aligned} \theta^* &\approx \arg \max_\theta \sum_{i=1}^n \log \frac{1}{K} \sum_{k=1}^K p_θ(x^{(i)} | z^{(i)}_k) &&\text{Sampling $z^{(i)}_k$ from $p_Z(Z)$}\\ &\approx \arg \max_\theta \sum_{i=1}^n \log \frac{1}{K} \sum_{k=1}^K \frac{ p_Z(z^{(i)}_k)}{q_\phi(z^{(i)}_k|x^{(i)})} p_θ(x^{(i)}|z^{(i)}_k) &&\text{Now sampling $z^{(i)}_k$ from $q_\phi(z|x^{(i)})$}\\ &\approx \arg \max_\theta \sum_{i=1}^n \log \underset {z^{(i)}_k \sim q_\phi(z|x^{(i)})} {\mathbb{E}} \left[ \frac{p_Z(z^{(i)}_k)}{q_\phi(z^{(i)}_k|x^{(i)})} p_θ(x^{(i)}|z^{(i)}_k) \right] &&\text{The expectation that the $K$-sample average estimates}\\ \end{aligned} \]
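The sketch below contrasts the two estimators on a hypothetical toy model (an assumption for illustration): prior \(z \sim \mathcal{N}(0, 1)\), forward model \(x | z \sim \mathcal{N}(z, \sigma^2)\) with \(\sigma = 0.5\), and an observation \(x = 3\) in the tail of the prior. The proposal \(q(z|x) = \mathcal{N}(2.3, 0.25)\) is a hand-picked, deliberately imperfect stand-in for a learned \(q_\phi\); the exact posterior for this model is \(\mathcal{N}(x/(1+\sigma^2), \sigma^2/(1+\sigma^2)) = \mathcal{N}(2.4, 0.2)\). Repeating each \(K = 100\) estimate 200 times, the importance-sampled estimate should have a much smaller spread than the prior-sampled one.

```python
import math
import random
import statistics

# Hypothetical toy model: z ~ N(0, 1), x | z ~ N(z, SIGMA^2).
SIGMA = 0.5


def log_normal(x, mean, var):
    """Log density of N(mean, var) at x."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)


def log_mean_exp(log_values):
    """Numerically stable log((1/K) * sum_k exp(v_k))."""
    m = max(log_values)
    return m + math.log(sum(math.exp(v - m) for v in log_values)) - math.log(len(log_values))


def estimate_with_prior(x, K, rng):
    """log (1/K) sum_k p(x | z_k), with z_k ~ p_Z(z)."""
    return log_mean_exp([log_normal(x, rng.gauss(0.0, 1.0), SIGMA ** 2) for _ in range(K)])


def estimate_with_q(x, K, rng, q_mean, q_var):
    """log (1/K) sum_k [p_Z(z_k) / q(z_k | x)] p(x | z_k), with z_k ~ q(z | x) = N(q_mean, q_var)."""
    log_weights = []
    for _ in range(K):
        z = rng.gauss(q_mean, math.sqrt(q_var))
        log_weights.append(
            log_normal(z, 0.0, 1.0)            # log p_Z(z)
            - log_normal(z, q_mean, q_var)     # - log q(z | x): the importance weight
            + log_normal(x, z, SIGMA ** 2)     # + log p(x | z)
        )
    return log_mean_exp(log_weights)


x = 3.0                                    # an observation in the tail of the prior
exact = log_normal(x, 0.0, 1.0 + SIGMA ** 2)
q_mean, q_var = 2.3, 0.25                  # hand-picked q; the exact posterior is N(2.4, 0.2)

rng = random.Random(0)
prior_runs = [estimate_with_prior(x, 100, rng) for _ in range(200)]
q_runs = [estimate_with_q(x, 100, rng, q_mean, q_var) for _ in range(200)]
print(f"exact log p(x):      {exact:.3f}")
print(f"sampling from prior: mean {statistics.mean(prior_runs):.3f}, std {statistics.stdev(prior_runs):.3f}")
print(f"sampling from q:     mean {statistics.mean(q_runs):.3f}, std {statistics.stdev(q_runs):.3f}")
```

Both estimators target the same quantity; the difference is that samples from \(q\) land where \(p_θ(x|z)\) is non-negligible, so far fewer of them are wasted.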

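This gives us a tractable objective function, but we haven't specified how to obtain \(q_\phi(z|x)\), our estimate of the backwards model \(p_θ(z|x)\). We can use a neural network to parameterize \(q_\phi(z|x)\) and then optimize the parameters of the neural network to minimize the KL divergence between \(q_\phi(z|x)\) and \(p_θ(z|x)\). This KL divergence contains the intractable \(p_θ(z|x)\) itself, so it cannot be evaluated directly; the decomposition derived in the next section relates it to quantities we can compute.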

\[ \begin{aligned} \phi^* &= \arg \min_\phi \text{D}_{\text{KL}}(q_\phi(z|x) || p_θ(z|x)) &&\text{Minimizing the KL divergence}\\ \end{aligned} \]

We can alternate between optimizing \(\theta\) and \(\phi\). Improving \(q_\phi(z|x)\) gives us a better estimate of the latent variables corresponding to an observed sample, and improving \(p_θ(x|z)\) gives us a better forward model from the sampled latent variables to the observed variables. The hope is that, at the end of this process, we will have learned forward and backward models that are close to the true generative process of the data.

Interpreting the terms of variational inference

In a previous post I illustrated how KL Divergence can be interpreted as the average difference in surprise between a true distribution and an agent's estimate of that distribution. Surprise is defined as the negative log-probability, \(\mathbb{S}(P(x)) := -\log P(x)\). Our KL term measuring the goodness of our estimate of the backwards model \(p_θ(z|x)\) can be expanded as follows:

\[ \begin{aligned} \text{D}_{\text{KL}}(q_\phi(z|x) || p_θ(z|x)) &:= \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \mathbb{S}(p_θ(z|x)) - \mathbb{S}(q_\phi(z|x)) \right] &&\text{KL divergence is relative surprise} \\ \end{aligned} \]
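As a quick sanity check of this definition, the sketch below estimates a KL divergence by averaging the difference in surprise over samples from \(q\), and compares it with the standard closed form for two univariate Gaussians. The distributions are hypothetical: \(q = \mathcal{N}(2.3, 0.25)\) stands in for \(q_\phi(z|x)\) and \(p = \mathcal{N}(2.4, 0.2)\) stands in for \(p_θ(z|x)\).

```python
import math
import random


def log_normal(x, mean, var):
    """Log density of N(mean, var) at x."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)


def surprise(log_prob):
    """S(P(x)) = -log P(x), given log P(x)."""
    return -log_prob


def kl_as_relative_surprise(m_q, v_q, m_p, v_p, n, rng):
    """Monte Carlo estimate of E_{z~q}[S(p(z)) - S(q(z))] = D_KL(q || p)."""
    total = 0.0
    for _ in range(n):
        z = rng.gauss(m_q, math.sqrt(v_q))
        total += surprise(log_normal(z, m_p, v_p)) - surprise(log_normal(z, m_q, v_q))
    return total / n


def kl_gaussians(m_q, v_q, m_p, v_p):
    """Closed-form D_KL(N(m_q, v_q) || N(m_p, v_p)) for univariate Gaussians."""
    return 0.5 * (math.log(v_p / v_q) + (v_q + (m_q - m_p) ** 2) / v_p - 1.0)


rng = random.Random(0)
print(kl_as_relative_surprise(2.3, 0.25, 2.4, 0.2, n=100_000, rng=rng))
print(kl_gaussians(2.3, 0.25, 2.4, 0.2))  # about 0.0384
```

Note that the samples come from \(q\), not \(p\): the expectation in the definition is taken under the estimate, which is what makes this KL usable when only \(q_\phi\) is easy to sample.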

The true backwards model \(p_θ(z|x)\) can be expanded by Bayes' rule:

\[ \begin{aligned} p_θ(z|x) &= \frac{p_θ(x|z) p_Z(z)}{p_θ(x)} &&\text{Bayes' rule}\\ \end{aligned} \]

Substituting this into our KL divergence term, we get:

\[ \begin{aligned} \text{D}_{\text{KL}}(q_\phi(z|x) || p_θ(z|x)) &= \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \mathbb{S} \left(\frac{p_θ(x|z) p_Z(z)}{p_θ(x)}\right) - \mathbb{S}(q_\phi(z|x)) \right] &&\text{Substituting Bayes' rule} \\ &= \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \mathbb{S}(p_θ(x|z)) + \mathbb{S}(p_Z(z)) - \mathbb{S}(p_θ(x)) - \mathbb{S}(q_\phi(z|x)) \right] &&\text{Surprise is a log}\\ &= \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \mathbb{S}(p_θ(x|z)) + \mathbb{S}(p_Z(z)) - \mathbb{S}(q_\phi(z|x)) \right] - \mathbb{S}(p_θ(x)) && \text{$p_θ(x)$ is a constant w.r.t $z$} \\ \mathbb{S}(p_θ(x)) &= \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \mathbb{S}(p_θ(x|z)) + \mathbb{S}(p_Z(z)) - \mathbb{S}(q_\phi(z|x)) \right] - \text{D}_{\text{KL}}(q_\phi(z|x) || p_θ(z|x)) && \text{Rearranging} \\ \underbrace{\mathbb{S}(p_θ(x))}_{\text{surprise of model evidence}} &= \underbrace{\underbrace{\underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \mathbb{S}(p_θ(x|z)) \right]}_{\text{Reconstruction Loss}} + \underbrace{D_{\text{KL}}(q_\phi(z|x) || p_Z(z))}_{\text{Variational Regularization}}}_{\text{Variational Free Energy = -ELBO}} - \underbrace{\text{D}_{\text{KL}}(q_\phi(z|x) || p_θ(z|x))}_{\text{KL divergence}} \end{aligned} \]
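Every quantity in the last line is available in closed form for a hypothetical linear-Gaussian toy model (an illustrative assumption, not from the lecture): prior \(p_Z(z) = \mathcal{N}(0, 1)\) and forward model \(p_θ(x|z) = \mathcal{N}(z, \sigma^2)\) with \(\sigma^2 = 0.25\), so that \(p_θ(x) = \mathcal{N}(0, 1 + \sigma^2)\) and \(p_θ(z|x) = \mathcal{N}\big(x/(1+\sigma^2), \sigma^2/(1+\sigma^2)\big)\). The sketch below checks numerically that the surprise of the evidence equals reconstruction loss plus variational regularization minus the KL divergence, for several arbitrary Gaussian choices \(q(z|x) = \mathcal{N}(m, v)\).

```python
import math

# Hypothetical linear-Gaussian toy model (not from the lecture):
#   p_Z(z) = N(0, 1),  p(x|z) = N(z, SIGMA2)
#   => p(x) = N(0, 1 + SIGMA2),  p(z|x) = N(x / (1 + SIGMA2), SIGMA2 / (1 + SIGMA2))
SIGMA2 = 0.25


def kl_gaussians(m_q, v_q, m_p, v_p):
    """Closed-form D_KL(N(m_q, v_q) || N(m_p, v_p))."""
    return 0.5 * (math.log(v_p / v_q) + (v_q + (m_q - m_p) ** 2) / v_p - 1.0)


def decomposition_terms(x, m, v):
    """Each term of the last line of the derivation, for q(z|x) = N(m, v)."""
    surprise_of_evidence = 0.5 * math.log(2 * math.pi * (1 + SIGMA2)) + x ** 2 / (2 * (1 + SIGMA2))
    # E_{z~q}[-log N(x; z, SIGMA2)], using E_q[(x - z)^2] = (x - m)^2 + v
    reconstruction = 0.5 * math.log(2 * math.pi * SIGMA2) + ((x - m) ** 2 + v) / (2 * SIGMA2)
    regularization = kl_gaussians(m, v, 0.0, 1.0)
    kl_to_posterior = kl_gaussians(m, v, x / (1 + SIGMA2), SIGMA2 / (1 + SIGMA2))
    return surprise_of_evidence, reconstruction, regularization, kl_to_posterior


x = 3.0
for m, v in [(0.0, 1.0), (2.3, 0.25), (2.4, 0.2), (-1.0, 3.0)]:
    s, rec, reg, kl = decomposition_terms(x, m, v)
    print(f"q = N({m}, {v}): S(p(x)) = {s:.6f}, rec + reg - KL = {rec + reg - kl:.6f}, KL = {kl:.4f}")
```

The left column stays fixed while the individual terms move with \(q\): a poor \(q\) inflates the free energy and the KL term by the same amount.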

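We can interpret several of the terms in this equation. The reconstruction loss is the forward model's surprise at \(x\), averaged over latents drawn from \(q_\phi(z|x)\); it is small when the latents that \(q_\phi\) proposes actually explain \(x\). The variational regularization term penalizes \(q_\phi(z|x)\) for straying from the prior \(p_Z(z)\). Together these two terms form the variational free energy, which is the negative ELBO. The final term is the KL divergence between our approximate backwards model and the true one; since it is non-negative, the variational free energy is always at least as large as the surprise of the model evidence.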

We can also rearrange these terms without using the surprise function \(\mathbb{S}\) at all. Starting from the second-to-last line of the derivation above:

\[ \begin{aligned} \mathbb{S}(p_θ(x)) &= \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \mathbb{S}(p_θ(x|z)) + \mathbb{S}(p_Z(z)) - \mathbb{S}(q_\phi(z|x)) \right] - \text{D}_{\text{KL}}(q_\phi(z|x) || p_θ(z|x)) \\ -\log p_θ(x) &= \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ -\log p_θ(x|z) -\log p_Z(z) + \log q_\phi(z|x) \right] - \text{D}_{\text{KL}}(q_\phi(z|x) || p_θ(z|x)) \\ \log p_θ(x) &= \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \log p_θ(x|z) + \log p_Z(z) - \log q_\phi(z|x) \right] + \text{D}_{\text{KL}}(q_\phi(z|x) || p_θ(z|x)) \\ \underbrace{\log p_θ(x)}_{\text{log evidence}} &= \underbrace{\text{D}_{\text{KL}}(q_\phi(z|x) || p_θ(z|x))}_{\text{KL divergence}} + \underbrace{\underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \log \frac{ p_Z(z)}{ q_\phi(z|x)} p_θ(x|z)\right]}_{\text{Evidence Lower Bound (ELBO)}} \end{aligned} \]

The last equation says that the log evidence is the sum of the KL divergence and the Evidence Lower Bound (ELBO). The name follows directly from this equation: the KL divergence is always non-negative, so the ELBO is always less than or equal to the log evidence, with equality exactly when the KL divergence is zero. We can interpret this equation in terms of the alternating optimization process described above. We cannot evaluate \(\text{D}_{\text{KL}}(q_\phi(z|x) || p_θ(z|x))\) directly, because it contains the intractable posterior, but we can estimate the ELBO by sampling from \(q_\phi(z|x)\). With \(p_θ\) fixed, \(\log p_θ(x)\) does not depend on \(\phi\), so adjusting \(q_\phi\) to maximize the ELBO is the same as minimizing the KL divergence. With \(q_\phi\) fixed, adjusting \(θ\) to maximize the ELBO pushes up a lower bound on the log evidence, which is what we ultimately want to maximize. Both steps therefore maximize the same objective, the ELBO.
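Here is a minimal sketch of the alternating procedure on a hypothetical conjugate toy model, chosen because both steps have closed-form solutions: prior \(z \sim \mathcal{N}(0, 1)\), forward model \(x | z \sim \mathcal{N}(z + θ, \sigma^2)\) with known \(\sigma^2 = 0.25\) and a single unknown offset \(θ\), and one Gaussian \(q_i = \mathcal{N}(m_i, v_i)\) per observation in place of a neural network \(q_\phi\). The \(\phi\)-step sets each \(q_i\) to the exact posterior, which drives the KL term to zero; the \(θ\)-step maximizes the summed ELBO with \(q\) held fixed. In this conjugate case the procedure coincides with the EM algorithm; with neural networks, each step would instead be one or more gradient steps on the ELBO. Since \(p_θ(x) = \mathcal{N}(θ, 1 + \sigma^2)\) here, the maximum-likelihood \(θ\) is the sample mean, which gives us something to check against. The data values are made up.

```python
import math

SIGMA2 = 0.25  # hypothetical known variance of the forward model x | z ~ N(z + theta, SIGMA2)


def kl_to_standard_normal(m, v):
    """D_KL(N(m, v) || N(0, 1))."""
    return 0.5 * (v + m ** 2 - 1.0 - math.log(v))


def elbo(x, theta, m, v):
    """Closed-form ELBO: E_{z~N(m,v)}[log p_theta(x|z)] - D_KL(q || p_Z)."""
    expected_log_lik = (-0.5 * math.log(2 * math.pi * SIGMA2)
                        - ((x - theta - m) ** 2 + v) / (2 * SIGMA2))
    return expected_log_lik - kl_to_standard_normal(m, v)


def log_evidence(x, theta):
    """Exact log p_theta(x), since x ~ N(theta, 1 + SIGMA2) under this model."""
    return -0.5 * math.log(2 * math.pi * (1 + SIGMA2)) - (x - theta) ** 2 / (2 * (1 + SIGMA2))


def exact_posterior(x, theta):
    """p_theta(z | x) = N((x - theta) / (1 + SIGMA2), SIGMA2 / (1 + SIGMA2))."""
    return (x - theta) / (1 + SIGMA2), SIGMA2 / (1 + SIGMA2)


def alternate(data, theta=0.0, steps=100):
    """Alternate phi-steps and theta-steps; record the summed ELBO after each theta-step."""
    history = []
    for _ in range(steps):
        # phi-step (theta fixed): the ELBO-maximizing q_i is the exact posterior, so KL -> 0.
        qs = [exact_posterior(x, theta) for x in data]
        # theta-step (q fixed): setting d/dtheta sum_i ELBO_i = 0 gives theta = mean(x_i - m_i).
        theta = sum(x - m for x, (m, _) in zip(data, qs)) / len(data)
        history.append(sum(elbo(x, theta, m, v) for x, (m, v) in zip(data, qs)))
    return theta, history


data = [1.9, 2.4, 3.1, 2.0]
theta, history = alternate(data)
print(f"theta = {theta:.6f}, sample mean = {sum(data) / len(data):.6f}")
print(f"final ELBO = {history[-1]:.6f}, log evidence = {sum(log_evidence(x, theta) for x in data):.6f}")
```

Because each half-step maximizes the ELBO in one block of variables, the recorded ELBO never decreases, and once \(q\) equals the posterior the ELBO and the log evidence coincide.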

Jensen's inequality

To summarize the objective in importance sampling, we have

\[ \begin{aligned} p_θ(\mathcal{D}) &= \prod_i \sum_{z \in Z} p_Z(z) p_θ(x^{(i)}|z) \\ &= \prod_i \underset {z \sim p_Z(z)} {\mathbb{E}} \left[ p_θ(x^{(i)}|z) \right] \\ &= \prod_i \underset {z \sim q_\phi(z|x^{(i)})} {\mathbb{E}} \left[ \frac{p_Z(z)}{q_\phi(z|x^{(i)})} p_θ(x^{(i)}|z) \right] \\ \underbrace{\log p_θ(\mathcal{D})}_{\text{log evidence}} &= \log \prod_i \underset {z \sim q_\phi(z|x^{(i)})} {\mathbb{E}} \left[ \frac{p_Z(z)}{q_\phi(z|x^{(i)})} p_θ(x^{(i)}|z) \right] \\ &= \sum_i \log \underset {z \sim q_\phi(z|x^{(i)})} {\mathbb{E}} \left[ \frac{p_Z(z)}{q_\phi(z|x^{(i)})} p_θ(x^{(i)}|z) \right] \\ \end{aligned} \]

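Note how the ELBO term is an expectation of a log, whereas in our importance sampling objective above we had a log of an expectation. Because \(\log\) is concave, Jensen's inequality gives \(\mathbb{E}[\log W] \leq \log \mathbb{E}[W]\) for any positive random variable \(W\). Applying it with \(W = \frac{p_Z(z)}{q_\phi(z|x^{(i)})} p_θ(x^{(i)}|z)\) and \(z \sim q_\phi(z|x^{(i)})\):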

\[ \underbrace{\underset {z \sim q_\phi(z|x^{(i)})} {\mathbb{E}} \left[ \log \frac{ p_Z(z)}{ q_\phi(z|x^{(i)})} p_θ(x^{(i)}|z)\right]}_{\text{Evidence Lower Bound (ELBO)}} \leq \underbrace{\log \underset {z^{(i)}_k \sim q_\phi(z|x^{(i)})} {\mathbb{E}} \left[ \frac{p_Z(z^{(i)}_k)}{q_\phi(z^{(i)}_k|x^{(i)})} p_θ(x^{(i)}|z^{(i)}_k) \right]}_{\text{Objective in importance sampling}} = \log p_θ(X = x^{(i)}) \]
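The inequality also holds sample by sample: for any finite set of positive weights, the average of their logs never exceeds the log of their average. The sketch below illustrates this on a hypothetical Gaussian toy model (prior \(\mathcal{N}(0, 1)\), forward model \(x | z \sim \mathcal{N}(z, 0.25)\), observation \(x = 3\)) with a hand-picked, imperfect \(q(z|x) = \mathcal{N}(2.3, 0.25)\). It draws one set of log importance weights and forms both the ELBO estimate (mean of the log weights) and the importance-sampling estimate (log of the mean weight). The gap between the exact log evidence and the ELBO should be close to \(\text{D}_{\text{KL}}(q || p_θ(z|x))\), where the posterior here is \(\mathcal{N}(2.4, 0.2)\).

```python
import math
import random

SIGMA2 = 0.25  # hypothetical forward-model variance: x | z ~ N(z, SIGMA2), z ~ N(0, 1)


def log_normal(x, mean, var):
    """Log density of N(mean, var) at x."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)


def kl_gaussians(m_q, v_q, m_p, v_p):
    """Closed-form D_KL(N(m_q, v_q) || N(m_p, v_p))."""
    return 0.5 * (math.log(v_p / v_q) + (v_q + (m_q - m_p) ** 2) / v_p - 1.0)


def log_importance_weights(x, q_mean, q_var, n, rng):
    """log [p_Z(z) p(x|z) / q(z|x)] for n samples z ~ q(z|x) = N(q_mean, q_var)."""
    out = []
    for _ in range(n):
        z = rng.gauss(q_mean, math.sqrt(q_var))
        out.append(log_normal(z, 0.0, 1.0) + log_normal(x, z, SIGMA2) - log_normal(z, q_mean, q_var))
    return out


x, q_mean, q_var = 3.0, 2.3, 0.25
log_w = log_importance_weights(x, q_mean, q_var, n=100_000, rng=random.Random(0))

elbo_estimate = sum(log_w) / len(log_w)                                  # mean of log w
m = max(log_w)
is_estimate = m + math.log(sum(math.exp(v - m) for v in log_w) / len(log_w))  # log of mean w
exact = log_normal(x, 0.0, 1.0 + SIGMA2)                                 # log p(x)
kl_to_posterior = kl_gaussians(q_mean, q_var, x / (1 + SIGMA2), SIGMA2 / (1 + SIGMA2))

print(f"ELBO estimate  {elbo_estimate:.4f}")
print(f"IS estimate    {is_estimate:.4f}")
print(f"log evidence   {exact:.4f}")
print(f"exact - ELBO   {exact - elbo_estimate:.4f}  vs  KL(q || posterior) {kl_to_posterior:.4f}")
```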

This is another way to see that the ELBO is a lower bound on the log evidence. The decomposition \(\log p_θ(x) = \text{D}_{\text{KL}}(q_\phi(z|x) || p_θ(z|x)) + \text{ELBO}\) from the previous section shows that the KL divergence is zero if and only if the ELBO equals the log evidence. For the "if" direction there is also a direct argument: suppose that \(q_\phi(z|x) = p_θ(z|x)\). Then the KL divergence is zero, and Bayes' rule shows that the ELBO equals the log evidence.

\[ \begin{aligned} \text{ELBO} &:= \underset {z \sim q_\phi(z|x^{(i)})} {\mathbb{E}} \left[ \log \frac{p_Z(z)p_θ(x^{(i)}|z)}{q_\phi(z|x^{(i)})} \right] \\ &= \underset {z \sim p_θ(z|x^{(i)})} {\mathbb{E}} \left[ \log \frac{p_Z(z)p_θ(x^{(i)}|z)}{p_θ(z|x^{(i)})} \right] && \text{Assuming $q_\phi(z|x^{(i)}) = p_θ(z|x^{(i)})$}\\ &= \underset {z \sim p_θ(z|x^{(i)})} {\mathbb{E}} \left[ \log p_θ(x^{(i)}) \right] && \text{Using Bayes' rule: } p_θ(x^{(i)}) = \frac{p_θ(x^{(i)}|z)p_Z(z)}{p_θ(z|x^{(i)})}\\ &= \underbrace{\log p_θ(x^{(i)})}_{\text{log evidence}} && \text{Because $\log p_θ(x^{(i)})$ does not depend on $z$} \end{aligned} \]
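A numerical illustration of this argument, on a hypothetical linear-Gaussian toy model whose posterior is known exactly (prior \(\mathcal{N}(0, 1)\), forward model \(x | z \sim \mathcal{N}(z, \sigma^2)\) with \(\sigma^2 = 0.25\), observation \(x = 3\)): when \(q\) is the exact posterior, every log importance weight \(\log \frac{p_Z(z)p_θ(x|z)}{q(z|x)}\) equals \(\log p_θ(x)\), whatever \(z\) was drawn. The ELBO estimate then has zero variance, even from a single sample.

```python
import math
import random

SIGMA2 = 0.25  # hypothetical forward-model variance


def log_normal(x, mean, var):
    """Log density of N(mean, var) at x."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)


x = 3.0
post_mean, post_var = x / (1 + SIGMA2), SIGMA2 / (1 + SIGMA2)  # exact p(z | x) for this model
rng = random.Random(0)

log_weights = []
for _ in range(5):
    z = rng.gauss(post_mean, math.sqrt(post_var))              # z ~ q = p(z | x)
    log_weights.append(log_normal(z, 0.0, 1.0)                 # log p_Z(z)
                       + log_normal(x, z, SIGMA2)              # + log p(x | z)
                       - log_normal(z, post_mean, post_var))   # - log q(z | x)

print(log_weights)                          # five values, equal up to rounding
print(log_normal(x, 0.0, 1.0 + SIGMA2))     # log p(x)
```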

Variational Free Energy

The variational free energy is a concept from statistical physics that carries into variational inference by analogy. In physics, the Helmholtz free energy \(F = U - TS\) represents the amount of energy available to do work, where \(U\) is internal energy, \(T\) is temperature, and \(S\) is entropy.

In the context of variational inference, we can define the variational free energy \(F\) as follows. \(\mathbb{S}\) is the surprise function, the negative log.

\[ \begin{aligned} \text{Internal Energy U} &:= \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \mathbb{S}(p_θ(x|z)) \right] \\ \text{Entropy Term TS} &:= \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \mathbb{S}(q_\phi(z|x)) - \mathbb{S}(p_Z(z)) \right] \\ &= -D_{\text{KL}}(q_\phi(z|x) || p_Z(z)) \\ F &= \underbrace{\underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \mathbb{S}(p_θ(x|z)) \right]}_{\text{Internal Energy U}} - \underbrace{(-D_{\text{KL}}(q_\phi(z|x) || p_Z(z)))}_{\text{Entropy Term TS}} \\ &= -\text{ELBO} \end{aligned} \]

In this analogy the free energy is the difference of two terms:

  1. Internal Energy (U): The term \(\underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \mathbb{S}(p_θ(x|z)) \right]\) represents how poorly our model reconstructs the data (the reconstruction loss). As the reconstruction of \(x\) becomes more accurate, the internal energy decreases. In variational inference, we try to minimize the internal energy (i.e. reconstruction loss).

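  2. Entropy Term (TS): The negative KL divergence \(-D_{\text{KL}}(q_\phi(z|x) || p_Z(z))\) plays the role of the entropy term. Expanding the definition, it equals \(\underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \mathbb{S}(q_\phi(z|x)) \right] - \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \mathbb{S}(p_Z(z)) \right]\): the entropy of \(q_\phi(z|x)\) minus its cross-entropy with the prior. It is zero when \(q_\phi(z|x)\) equals the prior and becomes more negative the further \(q_\phi\) moves away from it. A common alternative split moves the prior into the energy, \(U = \underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \mathbb{S}(p_θ(x|z) p_Z(z)) \right]\), which leaves exactly the entropy of \(q_\phi(z|x)\) as the \(TS\) term with \(T = 1\).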

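Minimizing the free energy is equivalent to maximizing the ELBO, which in our context means finding the best trade-off between reconstructing \(x\) accurately from latents drawn from \(q_\phi(z|x)\) (low internal energy) and keeping \(q_\phi(z|x)\) close to the prior \(p_Z(z)\) (a small variational regularization term).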

Another way to express the free energy starts with our ELBO equation, using \(F = -\text{ELBO}\):

\[ \begin{aligned} \underbrace{\log p_θ(x)}_{\text{log evidence}} &= \underbrace{\text{D}_{\text{KL}}(q_\phi(z|x) || p_θ(z|x))}_{\text{KL divergence}} + \underbrace{\underset {z \sim q_\phi(z|x)} {\mathbb{E}} \left[ \log \frac{ p_Z(z)}{ q_\phi(z|x)} p_θ(x|z)\right]}_{\text{Evidence Lower Bound (ELBO)}} \\ \underbrace{\log p_θ(x)}_{\text{log evidence}} &= \underbrace{\text{D}_{\text{KL}}(q_\phi(z|x) || p_θ(z|x))}_{\text{KL divergence}} - F \\ \underbrace{F}_{\text{free energy}} &= \underbrace{\text{D}_{\text{KL}}(q_\phi(z|x) || p_θ(z|x))}_{\text{KL divergence}} + \underbrace{\mathbb{S}(p_θ(x))}_{\text{surprise of model evidence}} \\ \end{aligned} \]

Here the imperative to minimize free energy mirrors the alternating objective from our importance sampling section. We seek a \(q_\phi(z|x)\) that minimizes the KL divergence between the true backwards model \(p_θ(z|x)\) and our estimate \(q_\phi(z|x)\). Simultaneously, we adjust \(p_θ\) to minimize the surprise of the model evidence \(p_θ(x)\). Because the KL term is non-negative, \(F\) is an upper bound on \(\mathbb{S}(p_θ(x))\), so driving \(F\) down also drives down this bound on how surprised the model is by the data.
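To see the free-energy view in action, the sketch below runs plain gradient descent on \(F\) with respect to the parameters of a Gaussian \(q(z|x) = \mathcal{N}(m, v)\), with the forward model held fixed, on a hypothetical linear-Gaussian toy model (prior \(\mathcal{N}(0, 1)\), forward model \(x | z \sim \mathcal{N}(z, \sigma^2)\) with \(\sigma^2 = 0.25\), observation \(x = 3\)). For this model \(F = U - TS\) and its gradients have closed forms, derived by hand for this sketch. As \(F\) decreases, \(q\) should approach the exact posterior \(\mathcal{N}(x/(1+\sigma^2), \sigma^2/(1+\sigma^2))\) and \(F\) should approach \(\mathbb{S}(p_θ(x))\), since the remaining gap is exactly the KL divergence. The step size and iteration count are arbitrary choices.

```python
import math

SIGMA2 = 0.25  # hypothetical forward-model variance: x | z ~ N(z, SIGMA2), z ~ N(0, 1)


def free_energy(x, m, v):
    """F = U - TS = E_q[S(p(x|z))] + D_KL(q || p_Z) for q = N(m, v)."""
    internal_energy = 0.5 * math.log(2 * math.pi * SIGMA2) + ((x - m) ** 2 + v) / (2 * SIGMA2)
    kl_to_prior = 0.5 * (v + m ** 2 - 1.0 - math.log(v))
    return internal_energy + kl_to_prior


def minimize_free_energy(x, m=0.0, log_v=0.0, lr=0.1, steps=500):
    """Gradient descent on F over (m, log v); optimizing log v keeps the variance positive."""
    for _ in range(steps):
        v = math.exp(log_v)
        grad_m = -(x - m) / SIGMA2 + m                 # dF/dm
        grad_log_v = v / (2 * SIGMA2) + v / 2 - 0.5    # dF/d(log v) = v * dF/dv
        m -= lr * grad_m
        log_v -= lr * grad_log_v
    return m, math.exp(log_v)


x = 3.0
m, v = minimize_free_energy(x)
surprise_of_evidence = 0.5 * math.log(2 * math.pi * (1 + SIGMA2)) + x ** 2 / (2 * (1 + SIGMA2))
print(f"q = N({m:.4f}, {v:.4f}); exact posterior = N({x / (1 + SIGMA2):.4f}, {SIGMA2 / (1 + SIGMA2):.4f})")
print(f"F = {free_energy(x, m, v):.6f}; S(p(x)) = {surprise_of_evidence:.6f}")
```

Starting from \(q = p_Z\) (\(m = 0\), \(v = 1\)), the free energy begins well above \(\mathbb{S}(p_θ(x))\) and descends to it as \(q\) settles on the posterior.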