[This is a note, which is the seed of an idea and something I’ve written quickly, as opposed to articles that I’ve optimized for readability and transmission of ideas.]
$$
\newcommand{\0}{\mathrm{false}}
\newcommand{\1}{\mathrm{true}}
\newcommand{\mb}{\mathbb}
\newcommand{\mc}{\mathcal}
\newcommand{\mf}{\mathfrak}
\newcommand{\and}{\wedge}
\newcommand{\or}{\vee}
\newcommand{\es}{\emptyset}
\newcommand{\a}{\alpha}
\newcommand{\t}{\theta}
\newcommand{\T}{\Theta}
\newcommand{\o}{\omega}
\newcommand{\O}{\Omega}
\newcommand{\x}{\xi}
\newcommand{\z}{\zeta}
\newcommand{\fa}{\forall}
\newcommand{\ex}{\exists}
\newcommand{\X}{\mc{X}}
\newcommand{\Y}{\mc{Y}}
\newcommand{\Z}{\mc{Z}}
\newcommand{\P}{\Phi}
\newcommand{\y}{\psi}
\newcommand{\p}{\phi}
\newcommand{\l}{\lambda}
\newcommand{\s}{\sigma}
\newcommand{\pr}{\times}
\newcommand{\B}{\mb{B}}
\newcommand{\N}{\mb{N}}
\newcommand{\R}{\mb{R}}
\newcommand{\E}{\mb{E}}
\newcommand{\e}{\varepsilon}
\newcommand{\set}[1]{\left\{#1\right\}}
\newcommand{\par}[1]{\left(#1\right)}
\newcommand{\tup}{\par}
\newcommand{\brak}[1]{\left[#1\right]}
\newcommand{\vtup}[1]{\left\langle#1\right\rangle}
\newcommand{\abs}[1]{\left\lvert#1\right\rvert}
\newcommand{\inv}[1]{{#1}^{-1}}
\newcommand{\ceil}[1]{\left\lceil#1\right\rceil}
\newcommand{\df}{\overset{\mathrm{def}}{=}}
\newcommand{\kl}[2]{D_{\text{KL}}\left(#1\ \| \ #2\right)}
\DeclareMathOperator*{\argmin}{argmin}
\DeclareMathOperator*{\argmax}{argmax}
\newcommand{\d}{\mathrm{d}}
\newcommand{\L}{\mc{L}}
\newcommand{\M}{\mc{M}}
\newcommand{\S}{\mc{S}}
\newcommand{\U}{\mc{U}}
\newcommand{\Er}{\mc{E}}
\newcommand{\ht}{\hat{\t}}
\newcommand{\hp}{\hat{\p}}
\newcommand{\D}{\mc{D}}
\newcommand{\H}{\mc{H}}
\newcommand{\softmax}{\text{softmax}}
\newcommand{\up}[1]{^{(#1)}}
$$
I was trying to understand bits-back coding since it appears in a paper I’m reading, Variational Lossy Autoencoder. However, the explanation in the paper is confusing, so I did some searching around for another resource on the topic. I found this video lecture by Pieter Abbeel (apparently he too once found this topic confusing: https://twitter.com/pabbeel/status/1104816262176096256), which illuminated it for me: Week 4 CS294-158 Deep Unsupervised Learning (2/20/19) @ 5679s.
I’ll explain bits-back coding briefly here in case anyone finds it helpful.
Bits-Back Coding
Suppose there is a sender that wants to communicate with a receiver. The sender wants to send a message $x_0\in\X$. The sender and receiver already share knowledge of a probability distribution $p(x)$.
Note that we can always convert any distribution $p$ into a pair of functions, an encoder and a decoder, between $\X$ and $\B^*$ (the set of all finite binary strings), such that the binary encoding $\hat{x}$ of $x$ has length $\ell(\hat{x})\approx -\lg p(x)$ bits and $x$ can be fully recovered from $\hat{x}$ by the decoder (lossless coding). This encoding length can be achieved in practice using, for example, arithmetic coding (AC).
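As a concrete illustration, here is a minimal sketch (my own, not from the lecture) of one such encoder/decoder pair, a Shannon code, in which every codeword has length $\ceil{-\lg p(x)}$, i.e. within one bit of the ideal; arithmetic coding achieves essentially the ideal length on whole sequences. The function names and the toy distribution are illustrative choices.

```python
import math

def shannon_code(p):
    """Build a Shannon code for a discrete distribution p: each symbol x gets a
    prefix-free codeword of length ceil(-lg p(x)), within one bit of -lg p(x)."""
    symbols = sorted(p, key=p.get, reverse=True)  # decreasing probability
    code, cum = {}, 0.0
    for x in symbols:
        length = math.ceil(-math.log2(p[x]))
        # codeword = first `length` bits of the binary expansion of the
        # cumulative probability mass preceding x
        code[x] = "".join(str(int(cum * 2 ** (i + 1)) % 2) for i in range(length))
        cum += p[x]
    return code

def decode(bits, code):
    """Losslessly recover a symbol sequence by greedy prefix matching."""
    inv = {v: k for k, v in code.items()}
    out, buf = [], ""
    for b in bits:
        buf += b
        if buf in inv:
            out.append(inv[buf])
            buf = ""
    return out

p = {"a": 0.5, "b": 0.25, "c": 0.125, "d": 0.125}
code = shannon_code(p)              # {'a': '0', 'b': '10', 'c': '110', 'd': '111'}
msg = ["a", "b", "a", "d", "c"]
encoded = "".join(code[x] for x in msg)
assert decode(encoded, code) == msg                      # lossless
assert all(len(code[x]) == -math.log2(p[x]) for x in p)  # dyadic p: lengths are exact
```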
Now suppose that $p(x)$ is intractable to compute directly, making the corresponding encoder and decoder computationally intractable as well. How can the sender still send $x_0$ efficiently?
If we can extend our distribution to a new variable $z\in\Z$ so that we have a tractable joint distribution $p(x,z)$ where $p(x) = \int p(x,z)\ \d z$, then the sender can encode the tuple $(x_0,z_0)$, for some choice of $z_0$, at a cost of $-\lg p(x_0,z_0)$ bits. We can minimize the cost of $z_0$ by first randomly drawing $z_0$ from the marginal $p(z)$, and then encoding $x_0$ w.r.t. $p(x\mid z=z_0)$. The receiver (who also knows $p(x,z)$) first decodes $\hat{z}_0$ to $z_0$ using $p(z)$ and then decodes $\hat{x}_0$ to $x_0$ using $p(x\mid z=z_0)$. This requires that $p(z)$ and $p(x\mid z)$ be tractable.
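Concretely, the cost of this two-part code decomposes (by the chain rule) as
$$-\lg p(z_0) - \lg p(x_0\mid z=z_0) = -\lg p(x_0,z_0) = -\lg p(x_0) - \lg p(z_0\mid x=x_0)\,,$$
so relative to the ideal (but intractable) $-\lg p(x_0)$ bits for $x_0$ alone, we pay an overhead of $-\lg p(z_0\mid x=x_0)$ bits for $z_0$. These are exactly the bits that the scheme below gets back.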
So long as we need to send $z_0$, we cannot avoid spending bits on it. The insight of bits-back coding is that, if the sender has a sequence of messages to send after $x_0$, they can use the same bits $\hat{z}_0$ to simultaneously send $z_0$ and part of the subsequent messages, because we don’t actually care which $z_0\in\Z$ we choose. So in that sense, the sender and receiver get bits back.
Let’s go through the communication protocol in detail. Suppose the sender has some subsequent message $y_0\in\Y$ (imagine a queue of messages starting with $x_0,y_0,\dots$) which has already been encoded to $\hat{y}_0$ (e.g. using some $p(y)$ shared with the receiver). If $\hat{y}_0$ has a low deficiency of randomness (i.e. it is random-looking; see Li & Vitányi), then it is as good as using a “random” sequence. Note that sampling $z_0 \sim p(z)$ is equivalent to decoding: pass a random binary sequence (equivalently, a real number in $[0,1]$) through the inverse CDF of $p(z)$. So we can decode $\hat{y}_0$ to some $z_0$ via the decoder for $p(z)$.
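To make the sampling-by-decoding equivalence concrete, here is a minimal sketch assuming a small discrete $p(z)$ with dyadic probabilities (so the interval arithmetic is exact); `sample_from_bits` and the toy distribution are my own illustrative choices.

```python
def sample_from_bits(bits, p):
    """Draw one symbol from the discrete distribution p by feeding bits into its
    inverse CDF: narrow a dyadic interval until it fits inside one symbol's
    slice of [0, 1). Returns the symbol and the number of bits consumed."""
    lo, hi, k = 0.0, 1.0, 0
    while True:
        cum = 0.0
        for z, pz in p.items():
            if cum <= lo and hi <= cum + pz:
                return z, k          # interval pinned down: z is the sample
            cum += pz
        mid = (lo + hi) / 2          # consume one more bit (assumes bits is long enough)
        lo, hi = (mid, hi) if bits[k] == "1" else (lo, mid)
        k += 1

# A random-looking bit string behaves like a uniform real in [0, 1), so the
# symbols it yields are distributed according to p.
p = {"z1": 0.5, "z2": 0.25, "z3": 0.25}
print(sample_from_bits("0110", p))   # ('z1', 1): consumed 1 bit  = -lg p(z1)
print(sample_from_bits("1101", p))   # ('z3', 2): consumed 2 bits = -lg p(z3)
```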
The sender protocol (a code sketch follows the list):
- Encode $y_0$ to $\hat{y}_0$ using some encoder (doesn’t really matter what).
- Decode the first $k\approx -\lg p(z_0\mid x=x_0)$ bits of $\hat{y}_0$ to $z_0$ using the decoder for $p(z\mid x=x_0)$, where $k$ is the number of bits the decoder consumes to produce $z_0$.
- In other words, $\hat{y}_0$ is the input to the sampling function that draws $z_0 \sim p(z\mid x=x_0)$.
- Assume that $\hat{y}_0$ is at least as long as the number of bits $k$ required to sample from $p(z\mid x=x_0)$.
- Let $\hat{y}_0\up{>k}$ be the remaining bits of $\hat{y}_0$ left after consuming the first $k$.
- Encode $z_0$ to $\hat{z}_0$ via $p(z)$.
- Encode $x_0$ to $\hat{x}_0$ via $p(x \mid z=z_0)$.
- Send $(\hat{x}_0,\hat{z}_0,\hat{y}_0\up{>k})$, for a total transmission length of $-\lg p(x_0,z_0)+\ell(\hat{y}_0)-k$ bits.
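Here is a runnable sketch of the sender’s steps, reusing `shannon_code` from the earlier sketch. The toy joint $p(x,z)$, the alphabets, the queued $y$-messages, and the helper `decode_one` are all illustrative choices of mine; $p(z\mid x)$ is chosen dyadic so that any long-enough $\hat{y}_0$ has a decodable prefix, per the assumption above.

```python
def decode_one(bits, code):
    """Decode a single symbol from the front of a bit string; return the symbol
    and the number of bits consumed."""
    inv = {v: k for k, v in code.items()}
    buf = ""
    for i, b in enumerate(bits):
        buf += b
        if buf in inv:
            return inv[buf], i + 1
    raise ValueError("bit string too short to decode a symbol")

# Toy tractable joint p(x,z) = p(x) p(z|x); p(z|x) is dyadic so its Shannon
# code is complete (every long-enough bit string has a decodable prefix).
X, Z = ["a", "b"], ["z1", "z2", "z3", "z4"]
p_x = {"a": 0.5, "b": 0.5}
p_z_given_x = {"a": {"z1": 0.5, "z2": 0.25, "z3": 0.125, "z4": 0.125},
               "b": {"z1": 0.125, "z2": 0.125, "z3": 0.25, "z4": 0.5}}
p_joint = {(x, z): p_x[x] * p_z_given_x[x][z] for x in X for z in Z}
p_z = {z: sum(p_joint[(x, z)] for x in X) for z in Z}
p_x_given_z = {z: {x: p_joint[(x, z)] / p_z[z] for x in X} for z in Z}

# Codes shared by sender and receiver (all built from the tractable joint).
code_z = shannon_code(p_z)
code_x_given_z = {z: shannon_code(p_x_given_z[z]) for z in Z}
code_z_given_x = {x: shannon_code(p_z_given_x[x]) for x in X}
code_y = shannon_code({"y0": 0.25, "y1": 0.25, "y2": 0.25, "y3": 0.25})

x0 = "a"                                  # the message to send
y_queue = ["y2", "y0", "y3", "y1"]        # subsequent messages waiting in the queue

y_hat = "".join(code_y[y] for y in y_queue)     # encode the y-queue to y_hat
z0, k = decode_one(y_hat, code_z_given_x[x0])   # "sample" z0 by decoding y_hat w.r.t. p(z|x0)
z_hat = code_z[z0]                              # encode z0 via p(z)
x_hat = code_x_given_z[z0][x0]                  # encode x0 via p(x|z=z0)
payload = (x_hat, z_hat, y_hat[k:])             # transmit (x-hat, z-hat, rest of y-hat)
```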
The receiver protocol:
- Receive $(\hat{x}_0,\hat{z}_0,\hat{y}_0\up{>k})$.
- Decode $\hat{z}_0$ to $z_0$ using $p(z)$.
- Decode $\hat{x}_0$ to $x_0$ using $p(x\mid z=z_0)$.
- Encode $z_0$ to $\hat{y}_0\up{\leq k}$ using $p(z\mid x=x_0)$.
- Decode $\hat{y}_0=\hat{y}_0\up{\leq k}\hat{y}_0\up{>k}$ to $y_0$ using the corresponding $y$-decoder.
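And the receiver’s steps, continuing directly from the sender sketch above (same shared codes); the assertion is a sanity check that everything round-trips, and the print compares the actual bit count to the idealized one (the small gap is the ceiling overhead of a symbol code versus arithmetic coding).

```python
import math

rx_x_hat, rx_z_hat, rx_y_tail = payload
rx_z0, _ = decode_one(rx_z_hat, code_z)                 # decode z-hat using p(z)
rx_x0, _ = decode_one(rx_x_hat, code_x_given_z[rx_z0])  # decode x-hat using p(x|z=z0)
rx_y_head = code_z_given_x[rx_x0][rx_z0]                # re-encode z0 via p(z|x0): the bits back
rx_y_queue = decode(rx_y_head + rx_y_tail, code_y)      # recover the full y-queue

assert (rx_x0, rx_z0, rx_y_queue) == (x0, z0, y_queue)  # x0, z0 and all the y's recovered

total = sum(len(s) for s in payload)                    # 10 bits in this toy run
ideal = -math.log2(p_x[x0]) + len(y_hat)                # -lg p(x0) + l(y_hat) = 9
print(f"sent {total} bits, idealized cost {ideal} bits")
```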
We have $-\lg p(x_0,z_0)-k \approx -\lg p(x_0,z_0) + \lg p(z_0\mid x=x_0) = -\lg p(x_0)$. So for the cost of $-\lg p(x_0,z_0)+\ell(\hat{y}_0)-k \approx -\lg p(x_0) + \ell(\hat{y}_0)$ bits sent, we’ve transmitted $(x_0,y_0,z_0)$: $x_0$ costs its ideal code length under the intractable $p(x)$, $y_0$ costs what it would have cost anyway, and it is as if $z_0$ did not consume any bits at all.
Variational Lossy Autoencoder
https://arxiv.org/abs/1611.02731
I was trying to understand the argument in Variational Lossy Autoencoder (Chen 2017), which invokes bits-back coding (BBC). Now that I understand BBC, I feel BBC is not actually necessary for the argument at all.
Recasting the argument without BBC:
Recall from my Variational Autoencoders note that to fit a VAE we minimize the proxy loss (a surprisal upper bound)
$$\begin{align}
\U(\t,\p;\ x) &= \E_{z\sim q_\p(z \mid x)}\brak{ \log \par{\frac{q_\p(z \mid x)}{p_\t(x,z)}} } \\
&= \log(1/p_\t(x)) + \kl{q_\p(z\mid x)}{p_\t(z\mid x)}\,.
\end{align}$$
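As a quick sanity check of this identity, here is a tiny discrete example where both forms can be evaluated exactly and agree; the model, the numbers, and the variable names are arbitrary choices of mine.

```python
import math

# Toy discrete model p_theta(z) p_theta(x|z) and approximate posterior q_phi(z|x),
# all evaluated at a single fixed observation x.
p_z     = {"z1": 0.3, "z2": 0.7}     # p_theta(z)
p_x_g_z = {"z1": 0.9, "z2": 0.2}     # p_theta(x|z) at the fixed x
q_z_g_x = {"z1": 0.5, "z2": 0.5}     # q_phi(z|x) at the fixed x

p_x = sum(p_z[z] * p_x_g_z[z] for z in p_z)             # p_theta(x) by marginalizing
p_z_g_x = {z: p_z[z] * p_x_g_z[z] / p_x for z in p_z}   # true posterior p_theta(z|x)

# First form: E_{z~q}[ log( q(z|x) / p_theta(x,z) ) ]
U = sum(q_z_g_x[z] * math.log(q_z_g_x[z] / (p_z[z] * p_x_g_z[z])) for z in p_z)
# Second form: log(1/p_theta(x)) + KL( q(z|x) || p_theta(z|x) )
kl = sum(q_z_g_x[z] * math.log(q_z_g_x[z] / p_z_g_x[z]) for z in p_z)
assert abs(U - (math.log(1 / p_x) + kl)) < 1e-9
```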
Suppose the decoder $p_\t(x\mid z)$ has the ability to ignore $z$, so that $p_\t(x\mid z) = p_\t(x)$ for all $x$ and $z$. Then $p_\t(z\mid x) = p_\t(z)$. Also suppose that $q_\p(z\mid x) = p_\t(z)$ is achievable for some $\p$; then $\kl{q_\p(z\mid x)}{p_\t(z\mid x)}=0$ and $\U(\t,\p;\ x)=\log(1/p_\t(x))$.
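The second claim is just Bayes’ rule with the ignore-$z$ assumption substituted in:
$$p_\t(z\mid x) = \frac{p_\t(x\mid z)\,p_\t(z)}{p_\t(x)} = \frac{p_\t(x)\,p_\t(z)}{p_\t(x)} = p_\t(z)\,.$$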
Now suppose that the decoder $p_\t(x\mid z)$ is so expressive that, for any data distribution $p^*(x)$, there exists $\t$ such that $p_\t(x\mid z)\approx p^*(x)$ for all $x,z$ (no dependence on $z$), with some small approximation error (i.e. we are using a universal function approximator). Then we can minimize $\U(\t,\p;\ x)$ with $p_\t(x\mid z) = p_\t(x)$ and $q_\p(z\mid x) = p_\t(z)$ (the configuration above), simply by maximizing the $p_\t(x\mid z)$ data-likelihood independently of $z$. However, this need not be the only solution, and there is no a priori reason for the optimization to arrive at this ignore-$z$ configuration.
However, if we additionally suppose that $q_\p(z\mid x)$ is not very expressive, so that for most target posteriors $p_\t(z\mid x)$ there is no parameter $\p$ s.t. $\kl{q_\p(z\mid x)}{p_\t(z\mid x)}$ is small, while (as before) the decoder $p_\t(x\mid z)$ is arbitrarily expressive, then the optimal solution is a decoder $p_\t(x\mid z)$ that ignores $z$, i.e. $p_\t(x\mid z) = p_\t(x)$ with $q_\p(z\mid x) = p_\t(z)$, thereby minimizing both the KL term and $\log(1/p_\t(x))$.
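To spell out why this is optimal: averaging the bound over the data distribution $p^*$ gives
$$\E_{x\sim p^*}\brak{\U(\t,\p;\ x)} = \E_{x\sim p^*}\brak{\log\frac{1}{p_\t(x)}} + \E_{x\sim p^*}\brak{\kl{q_\p(z\mid x)}{p_\t(z\mid x)}} \ge H(p^*)\,,$$
where $H(p^*)$ is the entropy of the data distribution. Up to the decoder’s approximation error, the ignore-$z$ configuration attains this lower bound, so any configuration whose average KL stays bounded away from zero cannot beat it.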
Chen et al. write that they believe $\kl{q_\p(z\mid x)}{p_\t(z\mid x)}$ will remain irreducibly large in VAEs for the foreseeable future (circa 2017), even with more expressive approximate posteriors $q_\p(z\mid x)$, e.g. using normalizing flows (Kingma 2017). Therefore, they argue that $p_\t(x\mid z)$ should be made less expressive (e.g. by only allowing local receptive fields conditioned on global information $z$) to force $z$ to be utilized, i.e. the reduction in $\log(1/p_\t(x))$ from using $z$ will outweigh the KL cost when $p_\t(x\mid z)$ must depend on $z$.