These notes are based on CMU Prof. Bhiksha Raj’s lecture video on generative models and variational autoencoders, and on his lecture slides, specifically the section about expectation maximization (EM) and variational inference. Prof. Raj gives a great intuitive explanation of the principles behind EM, so I highly recommend watching the lecture video. But the overall topic is complex and has many moving pieces, so here I will try to condense it to the key principles behind how it works. I hope these notes will be useful both to me and to others in the future.

Image space and latent space

Let \(x\) be a vector in a space of images, and \(z\) be a vector in a latent space of image classifications.

For example, if each image has \(n^2=n \cdot n\) pixels and each pixel has 3 numbers representing RGB values, then \(x\) could be a vector of \(3n^2\) elements, where each element is a number.
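As a concrete illustration of the \(3n^2\) count, here is a minimal Python/NumPy sketch. The image size \(n=4\) and the random pixel values are placeholders, not real data.

```python
import numpy as np

# A hypothetical n x n RGB image: n rows, n columns, 3 colour channels.
n = 4
image = np.random.default_rng(0).integers(0, 256, size=(n, n, 3))

# Flatten into a single vector x with 3 * n**2 elements.
x = image.reshape(-1)

print(x.shape)  # (48,)
```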

The conditional distribution \(P(x \mid z)\) is a probability distribution over \(x\), with nonzero probability for all \(x\) that satisfy the constraint \(z\). The peak of this distribution is at \(x\) that strongly satisfies \(z\); the tail of the distribution is at \(x\) that weakly satisfies \(z\).

For example, suppose \(z_0\) represents the category “rose”. Then an image \(x_1\) of a rose would strongly satisfy \(z_0\), an image \(x_2\) of a begonia would weakly satisfy \(z_0\), and an image \(x_3\) of a cat would not satisfy \(z_0\) at all. This means that \(P(x_1 \mid z_0) > P(x_2 \mid z_0) > P(x_3 \mid z_0)\).

\(P(x,z)\) is the joint probability density of \(x\) and \(z\). If we sample \((x,z) \sim P(x,z)\), then each \((x,z)\) sample would be an image together with its latent classification. Both the image and its classification would be randomly chosen, with more probable image/classification pairs sampled more frequently.

If we instead want to generate images with a specific classification (e.g., “rose”), we need to fix \(z\) and randomly sample \(x \sim P(x \mid z)\).
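The following small Python/NumPy sketch makes the difference between sampling the joint and sampling the conditional concrete. It assumes a hypothetical toy model in which \(x\) is a 2-D vector (a stand-in for an image), \(z\) is one of three categories, \(P(z)\) is a fixed categorical distribution, and \(P(x \mid z)\) is a Gaussian whose mean depends on \(z\). All the numbers are made up. Joint samples are drawn by first drawing \(z\) and then drawing \(x\) given that \(z\) (ancestral sampling). Conditional samples skip the first step and fix \(z\).

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for images: x is a 2-D vector, z is one of 3 categories.
# All numbers here are made up for illustration.
prior = np.array([0.5, 0.3, 0.2])                          # P(z)
means = np.array([[0.0, 0.0], [5.0, 5.0], [-5.0, 5.0]])   # mean of P(x | z)
sigma = 1.0                                                # shared isotropic std dev

def sample_joint(num):
    """Draw (x, z) ~ P(x, z) by ancestral sampling: z ~ P(z), then x ~ P(x | z)."""
    z = rng.choice(len(prior), size=num, p=prior)
    x = means[z] + sigma * rng.standard_normal((num, 2))
    return x, z

def sample_given(z0, num):
    """Fix the category z0 and draw x ~ P(x | z0)."""
    return means[z0] + sigma * rng.standard_normal((num, 2))

x, z = sample_joint(10_000)
print(np.bincount(z, minlength=3) / len(z))  # empirical category frequencies, close to prior
roses = sample_given(1, 5)                    # five samples from category 1 only
```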

Confusion: do we need to know \(P(z)\)?

We know that the conditional probability is given by \(P(x \mid z) = P(x,z)/P(z)\). So it appears that in order to sample from the conditional probability, we not only need to know the joint density \(P(x,z)\), but also need to know \(P(z)\). But if we have already decided to generate rose images and have already chosen \(z=z_0\), what does \(P(z)\) mean? And what if \(P(z)\) is difficult to evaluate (as it usually is)?

A quick way to think about this is to realize that \(P(x \mid z)\) is a probability density of \(x\), so it must normalize to \(1\) when we integrate over all values of \(x\). The \(P(z)\) in the denominator is only there to ensure this normalization; once we fix \(z=z_0\), it is just a constant. If we use a sampling method that does not require normalization (e.g., MCMC), then we don’t care about this denominator \(P(z)\), and a model for either \(P(x,z)\) or \(P(x \mid z)\) is sufficient. (But for a deeper dive into the subtleties here, see this post.)
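Here is a minimal sketch of that idea using random-walk Metropolis, one of the simplest MCMC methods, on a hypothetical 1-D toy joint density. The prior, the means, and the step size are invented for illustration. The acceptance test only uses the ratio \(P(x',z_0)/P(x,z_0)\), in which the factor \(P(z_0)\) cancels. The chain therefore samples from \(P(x \mid z_0)\) without ever computing \(P(z_0)\).

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical toy joint density P(x, z) for 1-D x and 2 categories.
prior = np.array([0.7, 0.3])
means = np.array([-2.0, 3.0])

def log_joint(x, z):
    """log P(x, z) = log P(z) + log N(x; means[z], 1), including all constants."""
    return np.log(prior[z]) - 0.5 * (x - means[z]) ** 2 - 0.5 * np.log(2 * np.pi)

def metropolis_given_z(z0, num_steps, step=1.0, x0=0.0):
    """Random-walk Metropolis targeting P(x | z0).

    Only the ratio P(x', z0) / P(x, z0) is used, so the normaliser P(z0)
    cancels and never has to be computed.
    """
    x = x0
    samples = np.empty(num_steps)
    for t in range(num_steps):
        proposal = x + step * rng.standard_normal()
        log_ratio = log_joint(proposal, z0) - log_joint(x, z0)
        if np.log(rng.uniform()) < log_ratio:   # accept with prob min(1, ratio)
            x = proposal
        samples[t] = x
    return samples

chain = metropolis_given_z(z0=1, num_steps=20_000)
print(chain[2_000:].mean())  # close to means[1], the mean of P(x | z0=1)
```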

So we start with a parametrized model \(P(x,z;\theta)\), designed to be flexible enough that the observed data distribution can be modeled well with appropriate settings for \(\theta\). To reiterate: this is our hypothesis for the probability density, fully specified once the parameter vector \(\theta\) is specified.

Also, to clarify: the ground truth distribution is \(P(x,z)\) (note that it does not have the \(\theta\) parameters), but unfortunately we do not know the ground truth. The model to be learned is \(P(x,z;\theta)\), which is our estimate/hypothesis for the ground truth. We apply maximum likelihood estimation (MLE) to learn the optimal \(\theta\), by maximizing the sum of the log-likelihoods of the observed data.

Not having all the data is ok

If we have observations of vectors \(x\) and \(z\), then we can optimize \(\theta\) to maximize the sum of the log-likelihoods, where the log-likelihood is the log of the parametrized model \(P(x,z;\theta)\) evaluated at the observed values for \(x\) and \(z\).

But since we only observe \(x\) and don’t have observations of \(z\), we need to use the log-likelihood of the marginal density \(P(x;\theta)\) and feed the observed \(x\) values into that. We get \(P(x;\theta)\) by integrating over \(z\):

\[\begin{align} P(x;\theta) &= \int dz\, P(x,z;\theta) \end{align}\]
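When \(z\) is discrete, the integral becomes a sum, \(P(x;\theta) = \sum_z P(x,z;\theta)\). The sketch below assumes a hypothetical two-component 1-D Gaussian mixture with made-up parameters as the model \(P(x,z;\theta)\). It computes \(\log P(x;\theta)\) with the log-sum-exp trick to avoid underflow. Summing over the observed points gives the quantity that MLE maximizes.

```python
import numpy as np

# Hypothetical parametrized model P(x, z; theta) with 1-D x and discrete z.
# theta = (weights, means, stds) of a 2-component Gaussian mixture.
weights = np.array([0.4, 0.6])
means = np.array([-1.0, 2.0])
stds = np.array([0.5, 1.0])

def log_joint(x):
    """log P(x, z; theta) for every z; returns shape (len(x), num_components)."""
    x = np.asarray(x, dtype=float)[:, None]
    return (np.log(weights)
            - 0.5 * ((x - means) / stds) ** 2
            - np.log(stds) - 0.5 * np.log(2 * np.pi))

def log_marginal(x):
    """log P(x; theta) = log sum_z P(x, z; theta), computed stably (log-sum-exp)."""
    lj = log_joint(x)
    m = lj.max(axis=1, keepdims=True)
    return (m + np.log(np.exp(lj - m).sum(axis=1, keepdims=True)))[:, 0]

data = np.array([-1.2, 0.3, 2.5, 1.9])
evidence = log_marginal(data).sum()   # sum_i log P(x_i; theta)
print(evidence)
```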

Learning the model

So we need to perform the following optimization:

\[\begin{align} \theta_{opt} & = \mathrm{argmax}_{\theta} \sum_i \log P(x_i;\theta) \\ & = \mathrm{argmax}_{\theta} \sum_i \log \int dz\, P(x_i,z;\theta), \end{align}\]

where the sum over \(i\) is the sum over the observed data \(x_i\).

The log of an integral can be difficult to maximize. Taking a derivative with respect to \(\theta\) puts a copy of the integral into the denominator, and that integral depends on the specific \(x_i\). So we would need to evaluate this integral for each \(x_i\), and the integral could be over a high-dimensional \(z\)-space. (For some models this integral is actually tractable; for example, when \(z\) takes only a few discrete values, as in a Gaussian mixture, it reduces to a short sum. So there is some subtlety here.) But in general this is difficult to maximize, so we use the following inequality:

\[\begin{align} \sum_i \log \int dz\, P(x_i,z;\theta) &\geq \sum_i \int dz\, Q(z;x_i,\phi) \log \frac{P(x_i,z;\theta)}{Q(z;x_i,\phi)} \tag{1}\label{eq:1} \\ &:= \sum_i J(x_i;\theta,\phi), \end{align}\]

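where \(Q(z;x_i,\phi)\) is an arbitrary probability density of \(z\) (i.e., \(Q(z;x_i,\phi)\) is normalized). The inequality follows from Jensen’s inequality: after multiplying and dividing the integrand by \(Q\), the integral is an expectation under \(Q\), and because the log is concave, the log of an expectation is at least the expectation of the log. We would like to pick a \(Q()\) that maximizes \(J\), which is the same as picking a \(\phi\) that maximizes \(J\). (Note that we use slightly different notation from the Raj lecture: \(z\) here instead of \(h\), and \(x_i\) instead of \(o\).) \(J\) is called the ELBO, or evidence lower bound: each \(J(x_i;\theta,\phi)\) is a lower bound on \(\log P(x_i;\theta)\), so their sum is a lower bound on the evidence, which is the sum of the log-likelihoods.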
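A quick numerical check of Eq. \(\eqref{eq:1}\) may help. The sketch assumes a hypothetical 1-D, three-component Gaussian mixture with discrete \(z\), so the integral over \(z\) is a sum. It draws random normalized \(Q\)’s from a Dirichlet distribution. For a single observation \(x\), each of them gives a \(J\) no larger than \(\log P(x;\theta)\).

```python
import numpy as np

rng = np.random.default_rng(2)

# Hypothetical 1-D, 3-component Gaussian mixture standing in for P(x, z; theta).
weights = np.array([0.2, 0.5, 0.3])
means = np.array([-3.0, 0.0, 4.0])

def log_joint(x):
    """log P(x, z; theta) for z = 0, 1, 2 (unit-variance components)."""
    return np.log(weights) - 0.5 * (x - means) ** 2 - 0.5 * np.log(2 * np.pi)

def elbo(x, q):
    """J = sum_z Q(z) [log P(x, z; theta) - log Q(z)] for a discrete density q."""
    return np.sum(q * (log_joint(x) - np.log(q)))

x = 1.0
log_evidence = np.log(np.exp(log_joint(x)).sum())   # log P(x; theta)

# Any normalised Q gives a lower bound (Jensen's inequality).
for _ in range(5):
    q = rng.dirichlet(np.ones(3))
    print(f"J = {elbo(x, q):.4f} <= log P(x) = {log_evidence:.4f}")
```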

We now have two knobs to tune: \(J\)’s parameters include \(\theta\), which are parameters (e.g., means and variances of various Gaussians) that specify the joint probability density model \(P(x,z;\theta)\); and \(\phi\), which are parameters that specify the \(Q(z;x,\phi)\) probability density of \(z\).

There are a lot of symbols flying around here, but if we zoom out, we see that the overall task is to adjust \(\theta\) and \(\phi\) to make the ELBO \(J\), or more precisely the summed ELBO \(\sum_i J(x_i;\theta,\phi)\), as large as possible. If we do that successfully, then \(P(x,z;\theta)\) would be an accurate maximum likelihood model of the ground truth \(P(x,z)\).

Who said “Greed is Good”?

When we have multiple sets of parameters to adjust in an optimization problem, a common strategy is a “greedy” alternating one (also known as coordinate ascent). Here, that means we fix the \(\theta\) parameters and optimize the \(\phi\) parameters, then fix the \(\phi\) parameters and optimize the \(\theta\) parameters, and rinse and repeat until we are close enough. We can see why this works by realizing that, for any fixed \(\theta\), it is always possible to find a function \(Q_{opt}(z,x)\) that converts Eq. \(\eqref{eq:1}\) from an inequality to an equality (which means that \(J\) is maximized over \(Q\) for that \(\theta\)).

This means that if we could adjust \(\phi\) to make \(Q(z;x,\phi)\) very close to such a \(Q_{opt}(z,x)\), then that is the \(\phi\) that maximizes \(J\) for the current \(\theta\). Call that value \(J_1\); because the bound is now tight, \(J_1\) equals the evidence at the current \(\theta\). If we then fix \(\phi\) and maximize \(J\) further over \(\theta\), the resulting \(J\) will be greater than or equal to \(J_1\). So each such \(\phi\)-\(\theta\) iteration step monotonically increases \(J\). And since \(J\) is still a lower bound on the evidence at the new \(\theta\), the evidence cannot decrease either.

Why is it always possible to find such a \(Q_{opt}(z,x)\)? Because we can set \(Q_{opt}(z,x) = P(z \mid x;\theta)\). Remember that if we fix the value of \(\theta\), then \(P(x,z;\theta)\) is just a joint probability density of the two vectors \(x\) and \(z\), and that uniquely determines the conditional probability \(P(z \mid x;\theta) = P(x,z;\theta)/P(x;\theta)\). If we plug \(Q_{opt}(z,x_i) = P(z \mid x_i;\theta)\) into the RHS of Eq. \(\eqref{eq:1}\), the log term becomes \(\log \left[P(x_i,z;\theta)/P(z \mid x_i;\theta)\right] = \log P(x_i;\theta)\). Since that is independent of \(z\), it can be factored out of the integral, and what remains is \(\int dz\, Q_{opt}(z,x_i) = 1\), which is just the normalization of a probability density. So the RHS is exactly equal to the LHS.
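To see the equality numerically, this sketch uses a hypothetical toy mixture: 1-D \(x\), three unit-variance Gaussian components, discrete \(z\), and made-up parameters. It sets \(Q\) to the posterior \(P(z \mid x;\theta)\) and checks two things. First, the log term takes the same value \(\log P(x;\theta)\) for every \(z\). Second, the ELBO collapses to the evidence.

```python
import numpy as np

# Hypothetical toy model: 1-D x, 3 unit-variance Gaussian components.
weights = np.array([0.2, 0.5, 0.3])
means = np.array([-3.0, 0.0, 4.0])

def log_joint(x):
    """log P(x, z; theta) for z = 0, 1, 2."""
    return np.log(weights) - 0.5 * (x - means) ** 2 - 0.5 * np.log(2 * np.pi)

def elbo(x, q):
    """J = sum_z Q(z) [log P(x, z; theta) - log Q(z)]."""
    return np.sum(q * (log_joint(x) - np.log(q)))

x = 1.0
log_evidence = np.log(np.exp(log_joint(x)).sum())   # log P(x; theta)

# Q_opt(z) = P(z | x; theta) = P(x, z; theta) / P(x; theta).
posterior = np.exp(log_joint(x) - log_evidence)

# log [P(x, z) / P(z | x)] = log P(x) for every z, so J collapses to log P(x).
print(np.log(np.exp(log_joint(x)) / posterior))   # three copies of log P(x)
print(elbo(x, posterior), log_evidence)            # equal up to rounding
```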

Just fill in the blanks

And we get a very nice intuitive interpretation if we play with the RHS of Eq. \(\eqref{eq:1}\) and consider the step where we fix \(\phi\) and optimize over \(\theta\). Fixing \(\phi\) means we fix \(Q(z;x_i,\phi) \approx Q_{opt}(z,x_i) = P(z \mid x_i;\theta_{old})\), where \(\theta_{old}\) is the value of \(\theta\) at which the preceding \(\phi\) step was carried out:

\[\begin{align} \mathrm{argmax}_{\theta} \sum_i \int dz\, Q(z;x_i,\phi) \log \frac{P(x_i,z;\theta)}{Q(z;x_i,\phi)} &\approx \mathrm{argmax}_{\theta} \sum_i \int dz\, Q_{opt}(z,x_i) \log \frac{P(x_i,z;\theta)}{Q_{opt}(z,x_i)} \\ &= \mathrm{argmax}_{\theta} \sum_i \int dz\, Q_{opt}(z,x_i) \log P(x_i,z;\theta) \\ &= \mathrm{argmax}_{\theta} \sum_i \int dz\, P(z \mid x_i;\theta_{old}) \log P(x_i,z;\theta), \end{align}\]

where we dropped \(Q_{opt}()\) from the denominator of the log term. It was computed at \(\theta_{old}\) and is held fixed during this step, so the term \(-\int dz\, Q_{opt} \log Q_{opt}\) does not depend on the \(\theta\) being optimized and does not change the argmax. Now we see that we are really evaluating the log-likelihood of the full joint probability density \(P(x,z;\theta)\) of both \(x\) and \(z\), and “filling in the missing values of \(z\)” by weighting each possible \(z\) value with its conditional probability given \(x_i\) under the current hypothesis \(\theta_{old}\).
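This \(\theta\) step can be made concrete with a hypothetical 1-D Gaussian mixture with discrete \(z\). The weights \(P(z \mid x_i)\), evaluated at the previous parameter values and often called responsibilities, are computed once and then held fixed. The maximizer of \(\sum_i \sum_z P(z \mid x_i) \log P(x_i,z;\theta)\) then has a well-known closed form: each component’s weight, mean, and variance are the responsibility-weighted versions of the usual MLE formulas. The synthetic data, the initial parameters, and the function names are illustrative assumptions, not from the lecture.

```python
import numpy as np

def log_joint(x, weights, means, stds):
    """log P(x_i, z=k; theta) for a 1-D Gaussian mixture, shape (N, K)."""
    x = x[:, None]
    return (np.log(weights) - 0.5 * ((x - means) / stds) ** 2
            - np.log(stds) - 0.5 * np.log(2 * np.pi))

def e_step(x, weights, means, stds):
    """Responsibilities r[i, k] = P(z=k | x_i; theta_old): the 'filled-in' z weights."""
    lj = log_joint(x, weights, means, stds)
    lj -= lj.max(axis=1, keepdims=True)      # stabilise before exponentiating
    r = np.exp(lj)
    return r / r.sum(axis=1, keepdims=True)

def m_step(x, r):
    """Closed-form argmax over theta of sum_i sum_k r[i, k] log P(x_i, z=k; theta)."""
    nk = r.sum(axis=0)                       # effective number of points per component
    weights = nk / len(x)
    means = (r * x[:, None]).sum(axis=0) / nk
    stds = np.sqrt((r * (x[:, None] - means) ** 2).sum(axis=0) / nk)
    return weights, means, stds

def weighted_objective(x, r, weights, means, stds):
    """The responsibility-weighted joint log-likelihood maximized in the theta step."""
    return np.sum(r * log_joint(x, weights, means, stds))

rng = np.random.default_rng(3)
x = np.concatenate([rng.normal(-2, 0.5, 200), rng.normal(3, 1.0, 300)])
theta_old = (np.array([0.5, 0.5]), np.array([-1.0, 1.0]), np.array([1.0, 1.0]))

r = e_step(x, *theta_old)        # weights for each possible z, fixed during the M-step
theta_new = m_step(x, r)
```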

Summary of using EM to learn a generative model

Whew, that was still pretty complicated! Here is a final summary:

  1. We focus on an example of a generative model: a joint probability density over \(x\), the space of images, and \(z\), the latent space of classifications or categories.
  2. If we had the conditional probability \(P(x\mid z)\), then we could pick a \(z\) and generate samples of \(x\) that fit that category.
  3. Unfortunately, we do not know this conditional probability. Instead, we create a parametrized model \(P(x,z;\theta)\), and use a training set of images \(x_i\) to learn the parameters \(\theta\) and the relationship between \(x\) and the latent \(z\). We do this via maximum likelihood estimation, MLE, by adjusting \(\theta\) to maximize the evidence (sum of log-likelihoods).
  4. Unfortunately, the straightforward form of MLE requires observations of both \(x\) and \(z\), but we only have \(x\), so we have a missing-data problem. We therefore marginalize out \(z\) by integrating over it. Unfortunately, that results in an ugly log of an integral, which is difficult to handle.
  5. So instead we optimize the ELBO, which is a lower bound on the evidence; we choose the ELBO because it leads to some nice mathematical cleverness that makes things work out. We now have two sets of parameters to optimize over, \(\theta\) and \(\phi\). The name “variational inference” comes about because \(\phi\) is a set of knobs that adjust the function \(Q()\), and optimizing over functions is reminiscent of the calculus of variations (from physics: Euler-Lagrange equations, etc.), even though here we do not use the calculus of variations at all.
  6. And it turns out there is a clever “greedy” way to do the optimization, where we optimize alternately over the parameters \(\phi\) and \(\theta\).
  7. And it turns out that this amounts to filling in the missing data (see step 4 above) with our best guess of \(z\) given the observed \(x\). The “guess” is not a single most likely value: every possible value of \(z\) is weighted by its conditional probability \(P(z \mid x;\theta)\) under the current model. We do this iteratively. We have some estimate of the likely values of \(z\) given \(x\), then we use that to update our parametrized model for the joint density of \(x\) and \(z\), then we use that to update our estimate of the likely values of \(z\) given \(x\), and so on.
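Putting the steps together, here is a minimal EM loop for a hypothetical 1-D, two-component Gaussian mixture on synthetic data. The data-generating parameters and the initialization are made up. Each iteration sets \(Q\) to the current posterior over \(z\) (the \(\phi\) step) and then applies the closed-form \(\theta\) update for this model. The recorded evidence \(\sum_i \log P(x_i;\theta)\) should never decrease, which is the monotonicity argument from the “greedy” section in action.

```python
import numpy as np

def log_joint(x, w, mu, sd):
    """log P(x_i, z=k; theta) for a 1-D Gaussian mixture, shape (N, K)."""
    x = x[:, None]
    return np.log(w) - 0.5 * ((x - mu) / sd) ** 2 - np.log(sd) - 0.5 * np.log(2 * np.pi)

def em(x, w, mu, sd, num_iters=50):
    """Alternate the two steps and record the evidence sum_i log P(x_i; theta)."""
    history = []
    for _ in range(num_iters):
        # E-step (the phi step, done exactly here): Q(z; x_i) = P(z | x_i; theta).
        lj = log_joint(x, w, mu, sd)
        m = lj.max(axis=1, keepdims=True)
        log_px = m[:, 0] + np.log(np.exp(lj - m).sum(axis=1))   # log P(x_i; theta)
        history.append(log_px.sum())
        r = np.exp(lj - log_px[:, None])
        # M-step (theta step): maximize sum_i sum_z Q log P(x_i, z; theta) in closed form.
        nk = r.sum(axis=0)
        w = nk / len(x)
        mu = (r * x[:, None]).sum(axis=0) / nk
        sd = np.sqrt((r * (x[:, None] - mu) ** 2).sum(axis=0) / nk)
    return (w, mu, sd), np.array(history)

rng = np.random.default_rng(4)
x = np.concatenate([rng.normal(-2, 0.5, 200), rng.normal(3, 1.0, 300)])
(w, mu, sd), history = em(x, np.array([0.5, 0.5]), np.array([-1.0, 1.0]), np.array([1.0, 1.0]))
print(history[:3], history[-1])   # evidence at the start and at the end
print(w, mu, sd)                  # fitted mixture parameters
```

Here the \(\phi\) step is exact because the posterior over a discrete \(z\) is easy to compute. When it is not, \(Q\) has to be restricted to a parametrized family, and that step becomes approximate.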

EM is all about filling in the missing \(z\)’s in the latent space by “making them up” according to the current hypothesis \(P(z \mid x;\theta)\), running MLE on the filled-in data, and then rinsing and repeating.
