Variational Autoencoders

March 5, 2023

[This is an article, which is something I’ve optimized for readability and transmission of ideas, as opposed to notes, which are the seed of an idea and something I’ve written quickly.]

I’ve been wanting to write a primer on variational autoencoders for some time. So many papers and blog posts have been written on this topic that I am at this point very late to the party. Nevertheless, I will put this out for the exercise of it. Perhaps I will have a slightly different way of thinking about it that someone finds illuminating.

The variational autoencoder (VAE) was introduced in the paper Auto-Encoding Variational Bayes by Kingma & Welling (2014). This topic, along with the more general topic of variational inference, was a pain point for me when I was teaching myself machine learning, and I imagine they are pain points for many others. There is now an abundance of literature which explains and re-explains variational inference and the VAE. Whether it succeeds is for the reader to decide. I’ll add one more to the pile just for fun.

Some VAE lit:

Hidden Variable Models

$$
\newcommand{\0}{\mathrm{false}}
\newcommand{\1}{\mathrm{true}}
\newcommand{\mb}{\mathbb}
\newcommand{\mc}{\mathcal}
\newcommand{\mf}{\mathfrak}
\newcommand{\and}{\wedge}
\newcommand{\or}{\vee}
\newcommand{\es}{\emptyset}
\newcommand{\a}{\alpha}
\newcommand{\t}{\theta}
\newcommand{\T}{\Theta}
\newcommand{\o}{\omega}
\newcommand{\O}{\Omega}
\newcommand{\x}{\xi}
\newcommand{\z}{\zeta}
\newcommand{\fa}{\forall}
\newcommand{\ex}{\exists}
\newcommand{\X}{\mc{X}}
\newcommand{\Y}{\mc{Y}}
\newcommand{\Z}{\mc{Z}}
\newcommand{\P}{\Phi}
\newcommand{\y}{\psi}
\newcommand{\p}{\phi}
\newcommand{\l}{\lambda}
\newcommand{\s}{\sigma}
\newcommand{\pr}{\times}
\newcommand{\B}{\mb{B}}
\newcommand{\N}{\mb{N}}
\newcommand{\R}{\mb{R}}
\newcommand{\E}{\mb{E}}
\newcommand{\e}{\varepsilon}
\newcommand{\set}[1]{\left\{#1\right\}}
\newcommand{\par}[1]{\left(#1\right)}
\newcommand{\tup}{\par}
\newcommand{\brak}[1]{\left[#1\right]}
\newcommand{\vtup}[1]{\left\langle#1\right\rangle}
\newcommand{\abs}[1]{\left\lvert#1\right\rvert}
\newcommand{\inv}[1]{{#1}^{-1}}
\newcommand{\ceil}[1]{\left\lceil#1\right\rceil}
\newcommand{\df}{\overset{\mathrm{def}}{=}}
\newcommand{\kl}[2]{D_{\text{KL}}\left(#1\ \| \ #2\right)}
\DeclareMathOperator*{\argmin}{argmin}
\DeclareMathOperator*{\argmax}{argmax}
\newcommand{\d}{\mathrm{d}}
\newcommand{\L}{\mc{L}}
\newcommand{\M}{\mc{M}}
\newcommand{\S}{\mc{S}}
\newcommand{\U}{\mc{U}}
\newcommand{\Er}{\mc{E}}
\newcommand{\ht}{\hat{\t}}
\newcommand{\hp}{\hat{\p}}
\newcommand{\D}{\mc{D}}
\newcommand{\H}{\mc{H}}
\newcommand{\softmax}{\text{softmax}}
\newcommand{\up}[1]{^{(#1)}}
$$

The VAE is one way of tractably fitting a hidden variable model to data. This is useful for generative modeling (we care about $x$) and representation learning (we care about $z$).

A hidden variable model is a parametrized joint distribution $p_\t(x,z)$, with parameters $\t\in\T$, on two variables, $x\in\X$ and $z\in\Z$, where we have a dataset $X = (x_1,x_2,\dots,x_n)$ of only $x$-examples (modeled as i.i.d.). The canonical example is that we have a dataset of images, e.g. $\X = \R^{w\times h\times c}$ is a space of images (multidimensional array of pixel values) with width $w$, height $h$ and $c$ color channels (e.g. $c=3$ for RGB).

Usually, hidden variable models are fit to $X$ by marginal likelihood maximization, i.e. by finding the parameters that maximize the probability of the data:

$$\begin{align}
\t^* &= \argmax_\t p_\t(X) \\
&= \argmax_\t \prod_{i=1}^n p_\t(x_i)
\end{align}$$

where

$$\begin{aligned}
p_\t(x) = \int_\Z p_\t(x, z)\ \d z\,.
\end{aligned}$$
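To make the marginal likelihood concrete, here is a minimal Python/NumPy sketch for a hypothetical toy model in which $z$ is discrete, so the integral over $\Z$ becomes a finite sum. The mixture weights `pi`, component means `mu`, and dataset `X` are made-up illustration values, not anything from the VAE paper; log-sum-exp is used only for numerical stability.

```python
import numpy as np

# Hypothetical toy hidden variable model (illustration only):
# z in {0, 1, 2} with prior p(z) = pi, and p(x | z) = N(x | mu[z], 1).
pi = np.array([0.5, 0.3, 0.2])
mu = np.array([-2.0, 0.0, 3.0])

def log_normal(x, mean, std=1.0):
    """Log density of N(x | mean, std^2), elementwise."""
    return -0.5 * np.log(2 * np.pi * std**2) - (x - mean) ** 2 / (2 * std**2)

def log_marginal(x):
    """log p(x) = log sum_z p(z) p(x | z), via log-sum-exp for numerical stability."""
    log_joint = np.log(pi) + log_normal(x, mu)  # log p(x, z) for each value of z
    top = log_joint.max()
    return top + np.log(np.sum(np.exp(log_joint - top)))

def log_likelihood(X):
    """log p(X) = sum_i log p(x_i), using the i.i.d. assumption."""
    return sum(log_marginal(x) for x in X)

X = np.array([-2.1, 0.3, 2.9, 3.2])
print(log_likelihood(X))  # the quantity maximized over theta (here theta = (pi, mu))
```

Maximizing this over `pi` and `mu` is exactly the $\argmax$ above; the difficulty in the VAE setting is that $z$ is continuous and high dimensional, so no such finite sum exists.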

TODO: why maximum marginal likelihood? What does this accomplish? Why hidden variable?

Common use-cases of hidden variable models:

  1. Generative modeling
    1. Data generation - The fit model, $p_{\t^*}(x)$, is a “good” data generator if sampling $\hat{x} \sim p_{\t^*}(x)$ produces data that tends to “look like” the training data $X$. We may also consider the more quantitative objective of maximizing training likelihood without overfitting as measured by the likelihood on an evaluation dataset.
    2. Probability estimation - For some datapoint $\hat{x}$, we hope that the probability density/mass $p_{\t^*}(\hat{x})$ is a meaningful quantification of how likely/plausible $\hat{x}$ is in some objective sense, which might be useful for downstream applications. If the true data generating process that produced $X$ is well defined, then we desire that $p_{\t^*}(\hat{x})$ closely match the true probability of $\hat{x}$.
  2. Representation learning
    1. Efficient representations - We want to work in the latent $z$-space instead of raw data space $x$ in downstream tasks. We sample (encode) $\hat{z}\sim p_\t(z\mid x=\hat{x})$ and call $\hat{z}$ a representation of $\hat{x}$. If $\hat{x}$ is high dimensional (e.g. images) and $\hat{z}$ is lower dimensional, then it is computationally efficient to work with $\hat{z}$ in place of $\hat{x}$.
    2. Semantic “disentangling” - We desire that the marginal latent distribution $p_\t(z)$ be simple (e.g. Gaussian or uniform) while the most semantically salient parts of $x$-space are all accumulated into the high-density parts (modes) of $p_\t(z)$. This offers ease of access to semantically meaningful data through $z$-space, whereas the semantically meaningful parts of $x$-space are scattered around very chaotically.
  3. “Classical” statistics - see #Classical Use-Case

(Note that none of these objectives are well-posed. I am purposefully being hand-wavy and vague because that is how they currently stand. I do think these objectives should be made well-posed. Perhaps I will say more about that in future posts.)

In general, $x$ and $z$ might each be vectors of many variables, and their joint distribution $p_\t(x,z)$ may factor to form a Bayesian network. The method that makes the VAE tractable may be applied to any Bayesian network (Kingma et al. actually introduces the general case, called variational inference/Bayes, and then gives the VAE as a special case). We have a VAE when the Bayesian network has the simple form:

$$
p_\t(x,z) = p_\t(x\mid z)p_\t(z)
$$

where our hidden variable model is defined in two parts: $p_\t(z)$ and $p_\t(x\mid z)$ - both parametrized distributions (we suppose that $\t$ is a vector of parameters which may be partitioned into separate parameters for each distribution). This case is called a variational autoencoder because this two-variable form is reminiscent of autoencoders (AEs). The VAE loss is often framed as being a stochastic version of the AE reconstruction loss (indeed Kingma et al. promotes this perspective), but I actually find that to be a flawed way of thinking about the VAE (more on that in a follow up post). For that reason, in this post I will present the VAE as a hidden variable model and not mention the stochastic AE perspective.
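The two-part factorization also tells us how to generate data: sample $z$ from $p_\t(z)$, then $x$ from $p_\t(x\mid z)$ (ancestral sampling). The sketch below assumes a hypothetical linear-Gaussian decoder (`W`, `b`, noise scale `s`) standing in for a neural network; the linear choice is only so that the marginal $p_\t(x)$ has a known closed form to compare the samples against.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical parameters theta: a linear "decoder" standing in for a neural network.
W = np.array([[1.0, 0.5], [-0.3, 2.0], [0.0, 1.0]])  # maps z in R^2 to the mean of x in R^3
b = np.array([0.5, -1.0, 2.0])
s = 0.1  # fixed observation noise standard deviation

def sample_joint(n):
    """Ancestral sampling from p(x, z) = p(x | z) p(z)."""
    z = rng.standard_normal((n, 2))               # z ~ p(z) = N(0, I)
    x_mean = z @ W.T + b                          # mean of p(x | z)
    x = x_mean + s * rng.standard_normal((n, 3))  # x ~ N(W z + b, s^2 I)
    return x, z

x, z = sample_joint(100_000)
# For this linear-Gaussian model, p(x) = N(b, W W^T + s^2 I), so these should roughly match.
print(x.mean(axis=0), b)
print(np.cov(x, rowvar=False))
print(W @ W.T + s**2 * np.eye(3))
```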

Classical Use-Case

We believe we know something about the data generating process (DGP), which generates both observed and unobserved (latent) data, and we want to estimate unknown DGP parameters from the data we observe.

I don’t include this use case in the list above because I don’t think that it describes the goals of AI/ML/DL. When working with perceptual (read: high-dimensional) data like images, we don’t start from a place of knowing something about the DGP (having domain knowledge). While our data was generated by some physical process, in practice it is almost certainly not i.i.d. across the datapoints, and the structure of the unobserved data in the process is going to be very different from “a neural network that maps $x$ to $z$”. There is no sense in which the parameter space of whatever model we are using contains the true DGP, in all its glory.

Proxy Optimization

This section derives what is usually called “variational inference”, but I am calling it “proxy inference.” I would like to write a post about why I disagree with the label “variational” here but that might take a while. I want to understand the long and complex history behind how the word “variational” makes its way into machine learning. Nevertheless, the crux of my disagreement with the phrase is that the derivation of the variational autoencoder has nothing to do with the calculus of variations, which is the mathematics of optimization over infinite dimensional function spaces. Optimization over parametrized function spaces, which are finite dimensional, is just regular optimization for which regular calculus is sufficient. The absence of calculus of variations from the VAE will be clear in the derivation below.

All the math below defines and manipulates loss functions, which are functions we seek to minimize. (If we are using probability mass functions, as opposed to probability densities, then these loss functions have the nicer property of being non-negative and lower bounded by $0$. I find this to be conceptually easier to work with. The usual VAE expositions (including Kingma et al.) give objective functions which are maximized, and which are non-positive and upper bounded by $0$ if we are using probability mass functions.)

Here, our loss function is

$$
\S(\t;\ X) = \sum_{i=1}^n \S(\t;\ x_i)
$$

which is defined in terms of the element-wise loss function
$$\S(\t;\ x) = \log \big(1/p_\t(x)\big)\,.$$

The quantity $\log \big(1/p_\t(x)\big)$ is known as the surprisal (also self-information) of $x$ under $p_\t$, which, for a probability mass function, is a non-negative number and equal to $0$ iff $p_\t(x)=1$. Notationally, $\S(\t;\ X)$ and $\S(\t;\ x)$ are functions of the parameters $\t$ which we are optimizing over, with the semicolons distinguishing the auxiliary data input $x$ or $X$ which is held fixed during optimization.

Note that $\log \big(1/p_\t(x)\big)$ is bounded below by $0$ if $p_\t(x)$ is a probability mass function, but is unbounded below if $p_\t(x)$ is a probability density function.
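A quick numerical illustration of the two cases, as a Python/NumPy sketch (the die and the Gaussian scale are arbitrary choices): a probability mass is at most $1$, so its surprisal is non-negative, while a density can exceed $1$, giving a negative surprisal. Base-2 logarithms are used so the units are bits.

```python
import numpy as np

def surprisal_bits(p):
    """Surprisal log2(1 / p) of an outcome with probability mass or density p."""
    return np.log2(1.0 / p)

# Probability mass: each face of a fair 4-sided die has p = 1/4.
print(surprisal_bits(0.25))  # 2.0 bits

# Probability density: N(0, sigma^2) evaluated at its mean is 1 / sqrt(2 pi sigma^2),
# which exceeds 1 for small sigma.
sigma = 0.01
density_at_mean = 1.0 / np.sqrt(2 * np.pi * sigma**2)
print(density_at_mean)                  # about 39.9
print(surprisal_bits(density_at_mean))  # negative
```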

Let’s use the same hidden variable model from above, defined in two parts $p_\t(x\mid z)$ and $p_\t(z)$:
$$p_\t(x) = \int_\Z p_\t(z)p_\t(x\mid z)\ \d z$$

Let’s suppose this integral is intractable to compute and so our loss function is intractable to optimize (otherwise we can do exact likelihood maximization and we are done). The following exposition will find an alternative proxy loss that we might be able to work with in place of $\S(\t;\ X)$.

Let’s first rewrite $\S(\t;\ x)$ in an equivalent form,
$$\begin{align}
\S(\t;\ x) &= \log \par{\frac{1}{p_\t(x)}} \\
&= \log \par{\frac{p_\t(z \mid x)}{p_\t(x \mid z)p_\t(z)}} \\
&= \log \par{\frac{p_\t(z \mid x)}{p_\t(x,z)}}\,. \tag{1} \label{S}
\end{align}$$

(Note that while $z$ appears here as a free variable, the numerator and denominator are such that the fraction is constant w.r.t. choice of $z$, and so this form of $\S(\t;\ x)$ does not depend on $z$ in that sense.)
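This constancy is easy to check numerically. The sketch below assumes a hypothetical three-component toy model where $z$ is discrete, so $p_\t(x)$ and $p_\t(z\mid x)$ can be computed exactly, and evaluates expression $\eqref{S}$ separately at every value of $z$.

```python
import numpy as np

# Hypothetical discrete toy model: z in {0, 1, 2}, prior pi, and p(x | z) = N(x | mu[z], 1).
pi = np.array([0.5, 0.3, 0.2])
mu = np.array([-2.0, 0.0, 3.0])

def normal_pdf(x, mean):
    """Density of N(x | mean, 1)."""
    return np.exp(-0.5 * (x - mean) ** 2) / np.sqrt(2 * np.pi)

x = 0.7
joint = pi * normal_pdf(x, mu)   # p(x, z) for z = 0, 1, 2
marginal = joint.sum()           # p(x)
posterior = joint / marginal     # p(z | x)

# Expression (1) evaluated separately at every value of z.
per_z = np.log(posterior / joint)
print(per_z)                     # the same number three times
print(np.log(1.0 / marginal))    # equal to each entry above: the surprisal S(theta; x)
```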

To obtain our proxy loss, let’s define a proxy distribution $q_\p(z\mid x)$ to use in place of $p_\t(z\mid x) = p_\t(x \mid z)p_\t(z)/p_\t(x)$ which is also intractable to calculate, with new parameters $\p\in\Phi$ for the proxy distribution. We want to choose $q_\p(z\mid x)$ so that it is tractable to work with directly - both for sampling from and calculating explicit probabilities.

(Note that the conditional notation $q_\p(z\mid x)$, defined as $q_\p(z, x)/q_\p(x)$, implies the existence of $q_\p(x)$, but we are not interested in that quantity. The conditional notation is convenient for expressing a probability distribution which depends on some auxiliary input $x$).

In expression $\eqref{S}$, the probability $p_\t(z \mid x)$ is the sole source of intractability ($p_\t(x \mid z)$ and $p_\t(z)$ are chosen by us to define the model and so we presumably choose them to be tractable to deal with). Swapping out $p_\t(z \mid x)$ for $q_\p(z\mid x)$ in expression $\eqref{S}$, we have

$$
\log \par{\frac{q_\p(z \mid x)}{p_\t(x,z)}}
$$

as our candidate proxy loss, comprised of only tractable quantities. However, this new expression now depends on both $x$ and $z$, because the numerator and denominator are not likely to perfectly cancel out the dependence on $z$ like in expression $\eqref{S}$.

We want to get rid of the dependence on the free variable $z$ in our proxy loss, because we have no $z$-observations. Let’s just try something here (of course it will end up working out): take the expectation over $z \sim q_\p(z\mid x)$. Now we have
$$\begin{align}
\U(\t,\p;\ x) = \E_{z\sim q_\p(z \mid x)}\brak{ \log \par{\frac{q_\p(z \mid x)}{p_\t(x,z)}} } \tag{2}\label{U} \\
\S(\t;\ x) = \E_{z\sim q_\p(z \mid x)}\brak{ \log \par{\frac{p_\t(z \mid x)}{p_\t(x,z)}} } \tag{3}\label{SE}
\end{align}$$

where $\eqref{U}$ defines the new quantity $\U(\t,\p;\ x)$ and $\eqref{SE}$ is another equivalent form of $\S(\t;\ x)$. We will see that $\U(\t,\p;\ x)$ is a suitable proxy loss and approximation to $\S(\t;\ x)$.

Define the loss approximation error as
$$\begin{align}
\Er(\t,\p;\ x) &= \U(\t,\p;\ x) - \S(\t;\ x)\,.
\end{align}$$

Doing a bit of math (which I leave as an exercise to the reader), we find the equivalent form

$$\begin{align}
\Er(\t,\p;\ x) &= \int_{\Z} q_\p(z\mid x)\log\par{\frac{q_\p(z\mid x)}{p_\t(z\mid x)}}\ \d z \\
&= \kl{q_\p(z \mid x)}{p_\t(z \mid x)}\,. \tag{4}\label{KL1}
\end{align}$$

where $\eqref{KL1}$ is the KL-divergence from $q_\p(z\mid x)$ to $p_\t(z \mid x)$ w.r.t. $z$. KL-divergence is always non-negative, and $0$ iff $q_\p(z\mid x) = p_\t(z \mid x)$ for all $z$ (wherever both distributions have support). For this reason, KL-divergence is often used as a comparison function between distributions (not quite a distance function because it is not symmetric).
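One way to work the exercise above: subtract $\eqref{SE}$ from $\eqref{U}$. Both are expectations under the same $q_\p(z\mid x)$, so they combine inside a single expectation, where the $p_\t(x,z)$ factors cancel:

$$\begin{align}
\Er(\t,\p;\ x) &= \U(\t,\p;\ x) - \S(\t;\ x) \\
&= \E_{z\sim q_\p(z \mid x)}\brak{ \log \par{\frac{q_\p(z \mid x)}{p_\t(x,z)}} - \log \par{\frac{p_\t(z \mid x)}{p_\t(x,z)}} } \\
&= \E_{z\sim q_\p(z \mid x)}\brak{ \log \par{\frac{q_\p(z \mid x)}{p_\t(z \mid x)}} } \\
&= \int_{\Z} q_\p(z\mid x)\log\par{\frac{q_\p(z\mid x)}{p_\t(z\mid x)}}\ \d z\,.
\end{align}$$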

Viewing KL-divergence as a comparison of distributions, we see that $\Er(\t,\p;\ x)$ is the approximation error due to using $q_\p(z\mid x)$ in place of $p_\t(z \mid x)$ (the ordering of the arguments to the KL-divergence has different effects on how $q_\p(z\mid x)$ will fit $p_\t(z \mid x)$; https://blog.evjang.com/2016/08/variational-bayes.html explains this nicely), while simultaneously being the approximation error due to using $\U(\t,\p;\ x)$ as a proxy loss in place of $\S(\t;\ x)$. How elegant!

We see now that $\U(\t,\p;\ x)$ as an approximation to $\S(\t;\ x)$ is as good as our approximation $q_\p(z\mid x)$ to $p_\t(z \mid x)$. We seek now to minimize $\U(\t,\p;\ x)$ simultaneously w.r.t. $\t$ and $\p$, where $\t$ controls the primary objective $\S(\t;\ x)$ and $\p$ controls the approximation $q_\p(z\mid x)$. More on the behavior of this dual optimization in #Optimization Peculiarities.

Since $\Er(\t,\p;\ x)$ is always non-negative, we have
$$
\U(\t,\p;\ x) \geq \S(\t;\ x)
$$

and so $\U(\t,\p;\ x)$ is an upper bound to $\S(\t;\ x)$, with equality iff $q_\p(z\mid x) = p_\t(z \mid x)$ for all $z$ (this is why I chose the letter $\U$). Call $\U(\t,\p;\ x)$ the Surprisal Upper BOund (SUBO). (This is my pun on the usual Evidence Lower BOund (ELBO). The usual formulation is that we are maximizing the log-likelihood, also called the evidence, $\log p_\t(x) = -\S(\t;\ x)$, giving us the lower bound to the log-likelihood, $\L(\t,\p;\ x) = -\U(\t,\p;\ x)$. We keep the error term non-negative, defining it as $\Er(\t,\p;\ x) = \log p_\t(x) - \L(\t,\p;\ x)$, and everything else follows in more or less the same way. I’ve chosen to make this sign flip because I find it more intuitive, especially when dealing with a non-negative loss and upper bound in the case of probability mass functions, because then these are Shannon information quantities and we can reason about bits.)
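The decomposition $\U = \Er + \S$ and the bound $\U \geq \S$ can be sanity-checked on a hypothetical discrete toy model where every quantity is computable exactly. In this Python/NumPy sketch, `q` is an arbitrary made-up proxy distribution over three values of $z$. The expectation defining the SUBO is a finite sum here, which is precisely what is not available in the intractable case.

```python
import numpy as np

# Hypothetical discrete toy model (z in {0, 1, 2}) standing in for p_theta.
pi = np.array([0.5, 0.3, 0.2])
mu = np.array([-2.0, 0.0, 3.0])
x = 0.7
joint = pi * np.exp(-0.5 * (x - mu) ** 2) / np.sqrt(2 * np.pi)  # p(x, z)
posterior = joint / joint.sum()                                  # p(z | x)
S = np.log(1.0 / joint.sum())                                    # surprisal of x

def subo(q):
    """U = E_{z ~ q}[log(q(z) / p(x, z))], computed exactly by summing over z."""
    return np.sum(q * np.log(q / joint))

def kl(q, p):
    """KL(q || p) for discrete distributions with full support."""
    return np.sum(q * np.log(q / p))

q = np.array([0.2, 0.5, 0.3])          # an arbitrary proxy q_phi(z | x)
print(subo(q) - S, kl(q, posterior))   # equal: the gap E is the KL-divergence
print(subo(q) >= S)                    # True: U is an upper bound on S
print(np.isclose(subo(posterior), S))  # True: the bound is tight when q = p(z | x)
```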

Finally, we can straightforwardly apply our proxy loss to the entire dataset, where $\U(\t,\p;\ X) = \sum_{i=1}^n \U(\t,\p;\ x_i)$ and $\Er(\t,\p;\ X) = \sum_{i=1}^n \Er(\t,\p;\ x_i)$.

To recap, we have
$$
\U(\t,\p;\ X) = \Er(\t,\p;\ X) + \S(\t;\ X)\,, \tag{5}\label{Usum}
$$
with $\S(\t;\ X)$ as our primary loss, $\U(\t,\p;\ X)$ as an upper bound to our loss which we will minimize as a proxy, and $\Er(\t,\p;\ X)$ the non-negative gap between them and simultaneously the approximation error of $q_\p(z\mid x_i)$ in place of $p_\t(z \mid x_i)$ on $x_i \in X$.

Rewritten as the sum $\eqref{Usum}$, we may view $\U(\t,\p;\ X)$ as a dual objective: simultaneously minimize the loss $\S(\t;\ X)$ and minimize the approximation error $\Er(\t,\p;\ X)$ of $q_\p(z\mid x)$ to $p_\t(z \mid x)$ (for $x \in X$). Both $\S(\t;\ X)$ and $\Er(\t,\p;\ X)$ are intractable to calculate by themselves, but we shall see in #Tractable Proxy Optimization that optimization of $\U(\t,\p;\ X)$ can be made tractable. How neat!

Optimization Peculiarities

We want to jointly minimize our proxy loss $\U(\t,\p;\ X)$ w.r.t. $\t$ and $\p$. How this optimization affects the target loss function $\S(\t;\ X)$, which we actually care about, is a bit tricky. Let’s briefly investigate the relationship between proxy and target during optimization of the proxy.

Without getting into the details of how we will perform the minimization of $\U(\t,\p;\ X)$ just yet (left to the next section, #Tractable Proxy Optimization), let’s discuss some of the consequences of this minimization. For now, all we need to assume is that our minimization process is numeric and performed iteratively in steps $t=1,2,3,\dots$ (yes we will be using stochastic gradient descent in practice).

First consider the case where we take a minimization step w.r.t. $\p$ while holding $\t$ fixed, so that we update $\p_t$ to $\p_{t+1}$ where

$$\U(\t,\p_{t+1};\ X) < \U(\t,\p_t;\ X)\,.$$

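Looking at the decomposition $\U(\t,\p;\ X)=\Er(\t,\p;\ X) + \S(\t;\ X)$ from $\eqref{Usum}$, we see that only $\Er(\t,\p;\ X)$ depends on $\p$, and so $\Er(\t,\p;\ X)$ must decrease, $\Er(\t,\p_{t+1};\ X) < \Er(\t,\p_{t};\ X)$, while $\S(\t;\ X)$ stays constant. Then $q_{\p_{t+1}}(z\mid x)$ has moved closer to $p_\t(z \mid x)$ in aggregate over $x\in X$, in the sense that the summed KL-divergence has decreased (an individual term $\Er(\t,\p;\ x_i)$ may still increase if $\p$ is shared across datapoints).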

Next consider a minimization step w.r.t. $\t$ while holding $\p$ fixed, so that we update $\t_t$ to $\t_{t+1}$ where

$$\U(\t_{t+1},\p;\ X) < \U(\t_t,\p;\ X)\,.$$

Both $\S(\t;\ X)$ and $\Er(\t,\p;\ X)$ depend on $\t$, but now we don’t know in which direction each may individually change. It may be the case that $\S(\t_{t+1};\ X) < \S(\t_{t};\ X)$, but at the cost of pushing $p_{\t_{t+1}}(z\mid x)$ further away from $q_\p(z\mid x)$, thus raising $\Er(\t_{t+1},\p;\ X)$. It may also be the case that $p_{\t_{t+1}}(z\mid x)$ is brought closer to $q_\p(z\mid x)$ so that $\Er(\t_{t+1},\p;\ X) < \Er(\t_{t},\p;\ X)$, but at the cost of raising $\S(\t_{t+1};\ X)$. In more fortuitous situations both quantities go down in tandem.

One thing we can do (and this is done in practice) is alternate optimization steps between $\t$ and $\p$, where we first update $\t_t$ to $\t_{t+1}$ given fixed $\p_t$, and then we update $\p_t$ to $\p_{t+1}$ given fixed $\t_{t+1}$. The latter keeps the upper bound $\U(\t,\p;\ X)$ tight by bringing down the error $\Er(\t,\p;\ X)$ after each $\t$-step. I don’t know that this is the case, but it seems plausible enough to conjecture that when $\Er(\t,\p;\ X)$ is small compared to $\S(\t;\ X)$, minimization steps on $\U(\t,\p;\ X)$ w.r.t. $\t$ are likely to cause $\S(\t;\ X)$ to decrease (unless $\S(\t;\ X)$ is near its minimum). Thus being aggressive on $\p$ optimization for each $\t$-step is a good thing. (Note that if our model class $q_\p(z\mid x)$ is not expressive enough to get close to $p_\t(z \mid x)$, then we are hosed out of the gate.)

However, the downside to this strategy is that we may spend a lot of optimization time meandering, where $\U(\t,\p;\ X)$ goes down while $\S(\t;\ X)$ does not change very much. This is especially a problem if we employ early stopping (stopping the optimization process before convergence to avoid overfitting the training data, as measured on an evaluation dataset). All we are guaranteed is that if we globally minimize $\U(\t,\p;\ X)$ w.r.t. $\t$ and $\p$ jointly, then $\S(\t;\ X)$ will be globally minimized, provided the proxy family $q_\p(z\mid x)$ can exactly match $p_\t(z \mid x)$ at a minimizer of $\S(\t;\ X)$. In practice with neural networks and gradient descent, we have no idea how far we are from the global minimum. Since we don’t have a way to estimate the quantities $\Er(\t,\p;\ X)$ and $\S(\t;\ X)$ directly, we don’t know how good our upper bound is during optimization.
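The alternating scheme can be illustrated in the extreme case where the $\p$-step is exact, i.e. $q$ is set to the true posterior so that $\Er(\t,\p;\ X)=0$ after every $\p$-step. With this choice the procedure is essentially (generalized) expectation-maximization rather than a VAE with an amortized encoder, and any $\t$-step that does not increase $\U$ cannot increase $\S$ either, since $\S(\t_{t+1};\ X) \leq \U(\t_{t+1},\p_t;\ X) \leq \U(\t_t,\p_t;\ X) = \S(\t_t;\ X)$. This is consistent with, but does not settle, the conjecture above about small nonzero $\Er$. The sketch assumes a hypothetical two-component Gaussian mixture with known weights and unknown means, synthetic data, a per-datapoint categorical $q$, and an arbitrary step size that is small enough for this quadratic objective.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical synthetic data from two clusters; theta is the pair of component means.
X = np.concatenate([rng.normal(-2.0, 1.0, 50), rng.normal(3.0, 1.0, 50)])
log_prior = np.log(np.array([0.5, 0.5]))  # fixed p(z), not learned here

def log_joint(mu):
    """log p_theta(x_i, z): rows index datapoints, columns index z."""
    return log_prior - 0.5 * np.log(2 * np.pi) - 0.5 * (X[:, None] - mu[None, :]) ** 2

def surprisal(mu):
    """S(theta; X) = sum_i log(1 / p_theta(x_i)), via log-sum-exp over z."""
    lj = log_joint(mu)
    top = lj.max(axis=1, keepdims=True)
    return -np.sum(top[:, 0] + np.log(np.exp(lj - top).sum(axis=1)))

def subo(mu, q):
    """U(theta, phi; X) for an explicit per-datapoint categorical q[i, z]."""
    return np.sum(q * (np.log(q) - log_joint(mu)))

def exact_phi_step(mu):
    """Set q to the exact posterior p_theta(z | x_i), which makes the gap E zero."""
    lj = log_joint(mu)
    q = np.exp(lj - lj.max(axis=1, keepdims=True))
    return q / q.sum(axis=1, keepdims=True)

mu = np.array([-0.5, 0.5])  # initial theta
q = exact_phi_step(mu)
history = [surprisal(mu)]
for t in range(50):
    # theta-step: one gradient step on U with q held fixed.
    grad_mu = -np.sum(q * (X[:, None] - mu[None, :]), axis=0)
    mu = mu - 0.005 * grad_mu
    # phi-step: exact, so afterwards U equals S.
    q = exact_phi_step(mu)
    history.append(surprisal(mu))
    assert np.isclose(subo(mu, q), history[-1])

print(mu)  # the means should drift towards the two clusters
print(all(a >= b - 1e-9 for a, b in zip(history, history[1:])))  # S never increased
```

With a learned encoder in place of `exact_phi_step`, the gap is generally not zero after a $\p$-step, which is where the uncertainty discussed above comes in.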

Probability Density

If $p_\t(x)$ is a probability density, then $\S(\t;\ X)$ is unbounded below. This poses a problem, since minimizing $\U(\t,\p;\ X)=\Er(\t,\p;\ X) + \S(\t;\ X)$ might forever push $\S(\t;\ X)$ lower and lower without reducing the error $\Er(\t,\p;\ X)$.

In Kingma et al., two choices for $p_\t(x \mid z)$ were explored: Gaussian and Bernoulli. In the case of Gaussian, a fixed (diagonal) covariance matrix is chosen. This results in upper bounding $p_\t(x \mid z)$, and by extension $p_\t(x)$ (holding the prior $p_\t(z)$ fixed), and thus lower bounding $\S(\t;\ X)$.

In the case of Bernoulli, $p_\t(x \mid z)$ is a probability mass and so $p_\t(x)$ is upper bounded at $1$.

Thus, it is implied in Kingma et al. that we ought to constrain the form of $p_\t(x \mid z)$ so that it is upper bounded.
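To see the bound concretely, here is a Python/NumPy sketch for a Gaussian $p_\t(x\mid z)$ with mean $\mu$ and covariance $\s^2 I$ on $d$-dimensional $x$. With $\s$ fixed, the surprisal $\log(1/p_\t(x\mid z))$ is smallest when the mean equals $x$, where it equals $\frac{d}{2}\log(2\pi\s^2)$. Since $p_\t(x) = \int_\Z p_\t(z)p_\t(x\mid z)\ \d z \leq \sup_z p_\t(x\mid z)$, the same constant lower-bounds $\S(\t;\ x)$. If $\s$ were learned instead, it could shrink towards $0$ and the surprisal would be unbounded below. The values of `x` and `sigma` are arbitrary.

```python
import numpy as np

def gaussian_surprisal(x, mean, sigma):
    """log(1 / N(x | mean, sigma^2 I)) for a d-dimensional x."""
    d = x.size
    return 0.5 * d * np.log(2 * np.pi * sigma**2) + np.sum((x - mean) ** 2) / (2 * sigma**2)

x = np.array([0.2, -0.4, 1.0])
sigma = 0.5  # fixed decoder standard deviation

# With sigma fixed, the surprisal is minimized by a perfect mean (mean == x).
lower_bound = 0.5 * x.size * np.log(2 * np.pi * sigma**2)
print(gaussian_surprisal(x, x, sigma), lower_bound)  # equal
print(gaussian_surprisal(x, x + 0.1, sigma))         # any other mean does worse

# If sigma were learned, a perfect mean plus sigma -> 0 has no lower bound.
for s in [1e-1, 1e-3, 1e-6]:
    print(s, gaussian_surprisal(x, x, s))
```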

Tractable Proxy Optimization

Even with all the work we did above, $\U(\t,\p;\ x)$ is still intractable to calculate and optimize because of the pesky expectation over $q_\p(z\mid x)$. Unless we can make the optimization of $\U(\t,\p;\ x)$ tractable, we have not gotten anywhere with the introduction of this proxy loss. So now on to tractability…

Let’s rewrite $\U(\t,\p;\ x)$ in an equivalent form (another exercise for the reader):

$$\begin{align}
\U(\t,\p;\ x) &= \kl{q_\p(z\mid x)}{p_\t(z)} + \E_{z\sim q_\p(z\mid x)}\brak{ \log\par{1/p_\t(x \mid z)} } \\
&= \mc{D}(\t,\p;\ x) + \E_{z\sim q_\p(z\mid x)}\brak{ \log\par{1/p_\t(x \mid z)} } \,. \tag{6}\label{UDE}
\end{align}$$

where I define $\mc{D}(\t,\p;\ x) = \kl{q_\p(z\mid x)}{p_\t(z)}$ for convenience. Suppose that in $\eqref{UDE}$, $\mc{D}(\t,\p;\ x)$ can be computed exactly (in practice we will define $q_\p$ to be something like $q_\p(z \mid x) = \mc{N}(z \mid \mu_\p(x), \s_\p(x))$, a Gaussian distribution with parameters $\mu_\p(x)$ and $\s_\p(x)$ which are complicated functions of $x$, e.g. neural networks; in this case, if $p_\t(z)$ is also Gaussian, then $\kl{q_\p(z \mid x)}{p_\t(z)}$ is the KL-divergence of two Gaussians, for which we can symbolically calculate a closed-form expression), but that the expectation term $\E_{z\sim q_\p(z\mid x)}\brak{ \log\par{1/p_\t(x \mid z)} }$ is still intractable.
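As a sketch of that closed form, assume a diagonal Gaussian $q_\p(z\mid x)$ with mean $\mu$ and per-dimension standard deviations $\s_j$, and a standard normal prior $p_\t(z) = \mc{N}(z\mid 0, I)$. Then the KL-divergence is $\frac{1}{2}\sum_j \par{\s_j^2 + \mu_j^2 - 1 - \log \s_j^2}$. The Python/NumPy code below compares it against a Monte Carlo estimate; `mu` and `sigma` are made-up values standing in for $\mu_\p(x)$ and $\s_\p(x)$.

```python
import numpy as np

rng = np.random.default_rng(0)

def kl_diag_gaussian_to_std_normal(mu, sigma):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over dimensions."""
    return 0.5 * np.sum(sigma**2 + mu**2 - 1.0 - 2.0 * np.log(sigma))

def log_normal(z, mu, sigma):
    """Log density of N(mu, diag(sigma^2)) at each row of z."""
    return np.sum(-0.5 * np.log(2 * np.pi * sigma**2) - (z - mu) ** 2 / (2 * sigma**2), axis=-1)

# Hypothetical encoder outputs mu_phi(x), sigma_phi(x) for one datapoint x.
mu = np.array([0.5, -1.0, 0.0])
sigma = np.array([0.8, 0.3, 1.5])

closed_form = kl_diag_gaussian_to_std_normal(mu, sigma)
# Monte Carlo check: E_{z ~ q}[log q(z) - log p(z)].
z = mu + sigma * rng.standard_normal((200_000, 3))
mc = np.mean(log_normal(z, mu, sigma) - log_normal(z, 0.0, np.ones(3)))
print(closed_form, mc)  # should agree to roughly two decimal places
```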

If we are using a stochastic gradient descent (SGD)-based optimizer, all we actually need is the gradient $\nabla_{\t,\p}\ \U(\t,\p;\ X)$. We don’t even need to calculate $\U(\t,\p;\ X)$ itself. Furthermore, it is sufficient if we have some way of producing noisy gradients (a gradient estimator) of $\U(\t,\p;\ X)$ w.r.t. $\t,\p$ instead of exact gradients, so long as the expectation of the noisy gradients equals the true gradient (the estimator is unbiased).

Consider the Monte Carlo estimator of $\U(\t,\p;\ x)$, defined by randomly drawing a sample $Z = (z_1,\dots,z_m)$ i.i.d. from $q_\p(z\mid x)$, notated as $Z \overset{\text{iid}}\sim q_\p(z\mid x)$, and passing it into the estimator function

$$\begin{align}
\hat{\U}(\t,\p;\ x, Z) &= \mc{D}(\t,\p;\ x) + \frac{1}{\abs{Z}}\sum_{z \in Z} \log\par{1/p_\t(x \mid z)}\,.
\end{align}$$

Then the Monte Carlo estimator of $\U(\t,\p;\ X)$ is

$$\begin{align}
Z_i &\overset{\text{iid}}\sim q_\p(z\mid x_i)\,, \quad i=1,\dots,n\,, \\
\hat{\U}(\t,\p;\ X, Z_1,\dots,Z_n) &= \sum_{i=1}^n \hat{\U}(\t,\p;\ x_i, Z_i)\,.
\end{align}$$

From here on, I will work with MC estimators of functions of single datapoints $x$. The MC estimators for functions of $X$ are in all cases the sum of the element-wise MC estimators.

Note that in practice we calculate MC estimators on minibatches of data, $\tilde{X}$, instead of $X$, where $\tilde{X}$ consists of $b < n$ datapoints randomly drawn from $X$ (e.g. uniformly with replacement); multiplying the minibatch sum by $n/b$ keeps the estimate of the full-dataset sum unbiased.
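Here is a minimal sketch of the minibatch Monte Carlo estimator $\hat{\U}$, assuming a hypothetical 1-D linear-Gaussian model $p_\t(z)=\mc{N}(z\mid 0,1)$, $p_\t(x\mid z)=\mc{N}(x\mid wz, s^2)$, and a per-datapoint Gaussian proxy $q(z\mid x_i)=\mc{N}(z\mid m_i, r_i^2)$ with made-up parameters (no encoder network). That model is chosen only because $\U$ then has a closed form, so we can check that averaging many one-sample, two-datapoint estimates (each rescaled by $n/b$) recovers it.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 1-D linear-Gaussian model: p(z) = N(0, 1), p(x | z) = N(x | w z, s^2).
w, s = 2.0, 0.5
X = np.array([1.0, -0.5, 2.5, 0.3])
# Hypothetical per-datapoint proxy q(z | x_i) = N(m_i, r_i^2); no encoder network here.
m = np.array([0.4, -0.2, 1.0, 0.1])
r = np.array([0.3, 0.4, 0.2, 0.5])

def kl_term(m, r):
    """D = KL(N(m, r^2) || N(0, 1)) in closed form."""
    return 0.5 * (r**2 + m**2 - 1.0 - 2.0 * np.log(r))

def neg_log_lik(x, z):
    """log(1 / p(x | z)) for the Gaussian likelihood."""
    return 0.5 * np.log(2 * np.pi * s**2) + (x - w * z) ** 2 / (2 * s**2)

def u_hat(idx, n_samples):
    """MC estimate of U(theta, phi; X) from minibatch X[idx] (last axis), rescaled by n / b."""
    b = idx.shape[-1]
    eps = rng.standard_normal(idx.shape + (n_samples,))
    z = m[idx][..., None] + r[idx][..., None] * eps  # z ~ q(z | x_i)
    per_point = kl_term(m[idx], r[idx]) + neg_log_lik(X[idx][..., None], z).mean(axis=-1)
    return len(X) / b * per_point.sum(axis=-1)

# Exact U, available here only because the toy model is linear-Gaussian:
# E_q[(x - w z)^2] = (x - w m)^2 + w^2 r^2.
exact_u = np.sum(kl_term(m, r) + 0.5 * np.log(2 * np.pi * s**2)
                 + ((X - w * m) ** 2 + w**2 * r**2) / (2 * s**2))

# 200k minibatches of b = 2 drawn uniformly with replacement, one z-sample per datapoint.
idx = rng.choice(len(X), size=(200_000, 2))
estimates = u_hat(idx, n_samples=1)
print(exact_u, estimates.mean())  # the average of the estimates should be close to exact_u
```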

Gradient Estimators

The Monte Carlo estimator $\widehat{\nabla_\t\ \U}(\t,\p;\ x, Z)$ of the gradient $\nabla_{\t}\ \U(\t,\p;\ x)$ is just the gradient of $\hat{\U}(\t,\p;\ x, Z)$,

$$\begin{align}
\widehat{\nabla_\t\ \U}(\t,\p;\ x, Z) &= \nabla_\t\brak{ \hat{\U}(\t,\p;\ x, Z) } \\
&= \nabla_\t \mc{D}(\t,\p;\ x) + \frac{1}{\abs{Z}}\sum_{z \in Z} \nabla_\t [\log\par{1/p_\t(x \mid z)}]\,.
\end{align}$$

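Since $q_\p(z\mid x)$ does not depend on $\t$, the gradient $\nabla_\t$ can be moved inside the expectation in $\eqref{UDE}$, so this gradient estimator is unbiased for the same reason the MC estimator $\hat{\U}$ is.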
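A small numerical check of this unbiasedness, under a hypothetical 1-D toy model: $p_\t(z) = \mc{N}(z\mid 0,1)$, $p_\t(x\mid z) = \mc{N}(x\mid wz, s^2)$ with $\t = w$, and a fixed Gaussian proxy $q_\p(z\mid x) = \mc{N}(z\mid m, r^2)$. Because the prior here has no $\t$-dependence, $\nabla_\t \mc{D}$ is zero and only the sum contributes; the exact gradient is available in closed form for comparison.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 1-D model: p(z) = N(0, 1) (no theta), p(x | z) = N(x | w z, s^2), theta = w.
w, s, x = 2.0, 0.5, 1.0
m, r = 0.4, 0.3  # hypothetical proxy q_phi(z | x) = N(m, r^2), held fixed

def grad_w_neg_log_lik(z):
    """d/dw of log(1 / p(x | z)) = -(x - w z) z / s^2."""
    return -(x - w * z) * z / s**2

z = m + r * rng.standard_normal(100_000)  # Z ~ q_phi(z | x)
mc_grad = grad_w_neg_log_lik(z).mean()    # MC gradient estimator (the D term has zero w-gradient)
# Exact gradient: -(x E[z] - w E[z^2]) / s^2, with E[z] = m and E[z^2] = m^2 + r^2.
exact_grad = -(x * m - w * (m**2 + r**2)) / s**2
print(exact_grad, mc_grad)
```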

However, the Monte Carlo estimator $\widehat{\nabla_\p\ \U}(\t,\p;\ x, Z)$ of the gradient $\nabla_{\p}\ \U(\t,\p;\ x)$ cannot be derived in the same manner, because the expectation $\E_{z\sim q_\p(z\mid x)}\brak{ \log\par{1/p_\t(x \mid z)} }$ depends on $\p$ through the distribution the expectation is taken over, $q_\p(z\mid x)$. This dependency is not explicitly present in the MC estimator, which replaces $\E_{z\sim q_\p(z\mid x)}[\dots]$ with an average over instances of $z$ sampled from $q_\p(z\mid x)$. We cannot “pass the gradient” through the operation of sampling from a distribution, which, as written, is not a differentiable function of $\p$.

One contribution of Kingma et al. is a method for deriving a Monte Carlo gradient estimator $\widehat{\nabla_\p\ \U}(\t,\p;\ x, Z)$ which is useful to us. There was a previously existing method to do the same but it is not useful (estimator variance is too high). I’ll introduce the old way first, and then Kingma’s solution.

Log-Derivative Trick

The log-derivative trick is

$$\frac{\d}{\d x}f(x) = f(x)\ \frac{\d}{\d x}\log(f(x))\,.$$

in which we’ve moved the derivative, in some sense, off of $f(x)$ (this is easy to verify yourself with the chain rule). This allows us to move derivatives inside expectations even when the distribution of the expectation depends on the differentiated variable:

$$\begin{align}
& \frac{\d}{\d \t}\E_{x\sim p_\t(x)} [f(x)] \\
=\ & \int \frac{\d}{\d \t}[p_\t(x)]\ f(x)\ \d x \\
=\ & \int p_\t(x)\ \frac{\d}{\d \t}[\log p_\t(x)]\ f(x)\ \d x \\
=\ & \E_{x\sim p_\t(x)} \brak{ \frac{\d}{\d \t}[\log p_\t(x)]\ f(x) }\,.
\end{align}$$

Applying this trick to the expectation in $\U(\t,\p;\ x)$, we have

$$
\nabla_\p\ \E_{z\sim q_\p(z\mid x)}\brak{ \log\par{1/p_\t(x \mid z)} } = \E_{z\sim q_\p(z\mid x)}\big[ \nabla_\p\brak{ \log q_\p(z\mid x) }\ \log\par{1/p_\t(x \mid z)} \big]\,,
$$

which we can Monte Carlo estimate as

$$\begin{align}
\frac{1}{\abs{Z}}\sum_{z \in Z} \nabla_\p\brak{ \log q_\p(z\mid x) }\ \log\par{1/p_\t(x \mid z)}\,, \quad Z \overset{\text{iid}}\sim q_\p(z\mid x)\,.
\end{align}$$

However, Kingma et al. reports that this estimator is known to have very high variance and is not useful.
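To make the estimator concrete, here is a Python/NumPy sketch on a hypothetical 1-D toy problem where the exact gradient is known: $p_\t(x\mid z) = \mc{N}(x\mid wz, s^2)$ and $q_\p(z\mid x) = \mc{N}(z\mid m, r^2)$, differentiating only w.r.t. the mean $m$ (treated as $\p$). The score is $\nabla_m \log q_\p(z\mid x) = (z-m)/r^2$. The $\mc{D}$ term is left out since it is computed exactly. Printing the per-sample variance gives a feel for how noisy each term is; all numeric values are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 1-D model: p(x | z) = N(x | w z, s^2); proxy q_phi(z | x) = N(m, r^2) with phi = m.
w, s, x = 2.0, 0.5, 1.0
m, r = 0.4, 0.3

def neg_log_lik(z):
    """log(1 / p(x | z)), the quantity inside the expectation."""
    return 0.5 * np.log(2 * np.pi * s**2) + (x - w * z) ** 2 / (2 * s**2)

def score(z):
    """d/dm log q_phi(z | x) for q = N(m, r^2)."""
    return (z - m) / r**2

z = m + r * rng.standard_normal(100_000)  # Z ~ q_phi(z | x)
per_sample = score(z) * neg_log_lik(z)    # log-derivative (score-function) terms
exact = -w * (x - w * m) / s**2           # exact gradient, available for this toy model
print(per_sample.mean(), exact)           # MC estimate of the gradient w.r.t. m vs exact
print(per_sample.var())                   # per-sample variance of the estimator
```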

One might note that this is exactly the same form as the REINFORCE (policy gradient) method for RL, where here $z$ is the action, $x$ is the observation, $q_\p(z\mid x)$ is the policy, and $\log\par{1/p_\t(x \mid z)}$ is the reward (strictly, a cost, since we are minimizing it). I do wonder if modern variance reduction techniques for policy gradient methods might work if applied here (e.g. using a baseline). But we should expect that we can do better than RL methods, since our “reward” function is known to us in closed form and we can differentiate through it. RL methods are developed precisely for the situation where the reward function is unknown or non-differentiable.

Reparametrization Trick

Kingma et al. introduces the reparametrization trick to produce an alternative gradient estimator with lower variance.

Suppose we are able to define a function $g_\p(\e;\ x)$, differentiable w.r.t. $\p$, and a distribution $p(\e)$ such that drawing samples $\hat{z} \sim q_\p(z \mid x)$ is equivalent to drawing $\hat{\e} \sim p(\e)$ and then letting $\hat{z} = g_\p(\hat{\e};\ x)$. The canonical example of this is the case of $q_\p(z \mid x)$ being Gaussian, but Kingma et al. explains how this reparametrization can be done for a whole host of distributions.


Gaussian case:
Supposing $q_\p(z \mid x) = \mc{N}(z \mid \mu_\p(x), \s_\p(x))$, with $\s_\p(x)$ the standard deviation,
then we have

$$\begin{align}
p(\e) &= \mc{N}(\e \mid 0, 1) \\
g_\p(\e;\ x) &= \mu_\p(x) + \s_\p(x)\e\,.
\end{align}$$
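A quick check that this reparametrization produces the right distribution, as a Python/NumPy sketch with made-up encoder outputs `mu_x` and `sigma_x` for a 2-dimensional $z$ (applied elementwise, i.e. a diagonal Gaussian): pushing standard normal draws through $g_\p$ should reproduce the target mean and standard deviation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical encoder outputs for one datapoint x: mean and standard deviation of q_phi(z | x).
mu_x = np.array([1.0, -2.0])
sigma_x = np.array([0.5, 3.0])

def g(eps, mu, sigma):
    """Reparametrization g_phi(eps; x) = mu_phi(x) + sigma_phi(x) * eps (elementwise)."""
    return mu + sigma * eps

eps = rng.standard_normal((200_000, 2))  # eps ~ N(0, I)
z = g(eps, mu_x, sigma_x)                # distributed as N(mu_x, diag(sigma_x^2))
print(z.mean(axis=0), z.std(axis=0))     # approximately mu_x and sigma_x
```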


Applying the reparametrization trick to the expectation from above, we have

$$\begin{align}
\nabla_\p\ \E_{z\sim q_\p(z\mid x)}\brak{ \log\par{1/p_\t(x \mid z)} } &= \nabla_\p\ \E_{\e\sim p(\e)}\brak{ \log\par{1/p_\t(x \mid z=g_\p(\e;\ x))} } \\
&= \E_{\e\sim p(\e)}\brak{ \nabla_\p\brak{ \log\par{1/p_\t(x \mid z=g_\p(\e;\ x))} }}
\end{align}$$

which gives us our MC estimator

$$\begin{align}
\widehat{\nabla_\p\ \U}(\t,\p;\ x, E) &= \nabla_\p \mc{D}(\t,\p;\ x) + \frac{1}{\abs{E}}\sum_{\e \in E} \nabla_\p \Big[ \log\big( 1/p_\t(x \mid z=g_\p(\e;\ x)) \big) \Big] \,, \\
E &= (\e_1,\dots,\e_m) \overset{\text{iid}}\sim p(\e)\,.
\end{align}$$

The trick is that we’ve reparametrized the expectation so that it is taken over a distribution that does not depend on $\p$, and we transform what’s inside the expectation through the function $g_\p(\e;\ x)$, which transforms $p(\e)$ into $q_\p(z\mid x)$. Now the gradient w.r.t. $\p$ can be moved inside the expectation and taken through $g_\p(\e;\ x)$ with the chain rule, and that gradient can be Monte Carlo estimated.

Kingma et al. reports this estimator has much lower variance and makes the whole enterprise of $\U(\t,\p;\ x)$ minimization tractable. This is the breakthrough of the paper.
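On a hypothetical 1-D toy problem ($p_\t(x\mid z) = \mc{N}(x\mid wz, s^2)$, $q_\p(z\mid x) = \mc{N}(z\mid m, r^2)$, gradient w.r.t. the mean $m$ only, with the $\mc{D}$ term omitted because it is exact and shared by both estimators), the sketch below computes the reparametrized estimator and the log-derivative estimator from the same draws of $\e$. Both should average to the exact gradient. On this particular example the reparametrized per-sample variance comes out smaller, which illustrates, but does not prove, the general claim.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 1-D toy setup: p(x | z) = N(x | w z, s^2), q_phi(z | x) = N(m, r^2), phi = m.
w, s, x = 2.0, 0.5, 1.0
m, r = 0.4, 0.3

def neg_log_lik(z):
    """log(1 / p(x | z))."""
    return 0.5 * np.log(2 * np.pi * s**2) + (x - w * z) ** 2 / (2 * s**2)

def d_neg_log_lik_dz(z):
    """d/dz log(1 / p(x | z))."""
    return -w * (x - w * z) / s**2

eps = rng.standard_normal(100_000)  # E ~ p(eps) = N(0, 1)
z = m + r * eps                     # z = g_phi(eps; x)

# Reparametrization: by the chain rule, d/dm f(g_phi(eps)) = f'(z) * dg/dm = f'(z) * 1.
reparam = d_neg_log_lik_dz(z)
# Log-derivative (score-function) estimator on the same samples, for comparison.
score_fn = (z - m) / r**2 * neg_log_lik(z)

exact = -w * (x - w * m) / s**2
print(exact, reparam.mean(), score_fn.mean())  # all approximately equal (both unbiased)
print(reparam.var(), score_fn.var())           # per-sample variances of the two estimators
```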

Question: Can we derive expressions for the variance of these two estimators? Why is the latter estimator lower variance than the former?

Conclusion

To recap:

Supposing we have a hidden variable model $p_\t(x,z)$ which we want to fit to data $X$ by minimizing the marginal surprisal $\log(1/p_\t(x))$ (or maximizing the marginal log-likelihood), but where $p_\t(x)$ is intractable to calculate or optimize directly, we devised an upper bound $\U(\t,\p;\ x)$. We still cannot calculate or optimize $\U(\t,\p;\ x)$ directly, but we can calculate a stochastic gradient estimator for it which is unbiased and, importantly, low variance. We use this estimator to perform stochastic gradient descent on $\U(\t,\p;\ x)$, in order to indirectly minimize $\log(1/p_\t(x))$.
