Here we’ll look at how latent-variable models can be manipulated to yield an unbiased estimator of the marginal log-likelihood. We’ll implement the recently proposed SUMO (Stochastically Unbiased Marginalization Objective) in Jax, highlighting how Jax differs from standard autodiff frameworks in some major aspects. Section 1 concerns the derivation of the SUMO objective. Section 2 walks through the Jax implementation, while Section 3 contains the obligatory toy problem. This post did turn out longer than expected, so you may want to take a breather between sections. If you are only interested in the practical part, skip to Section 2 or check out the associated repository on Github.
The SUMO (Luo et. al., 2020, [1]) estimator makes use of the importance-weighted ELBO (IW-ELBO), examined in an earlier post. Briefly, we are interested in finding a lower bound of the log-marginal likelihood $\log p(x)$. Under certain conditions, the tightness of this bound improves as the variance of some unbiased estimator $R$ of the marginal probability decreases. $R$ depends on the latent variables through some proposal distribution $q(z)$: $\E{q}{R} = p(x)$. The regular VAE takes $R = p(x,z)/q(z)$, forming the normal ELBO $\E{q}{\log p(x,z) - \log q(z)}$, but a natural way to reduce the sample variance is through the $K$-sample mean, forming the IW-ELBO:
\[\begin{align} R_K &= \frac{1}{K} \sum_{k=1}^K \frac{p(x, z_k)}{q(z_k)}, \quad z_k \sim q \\ \textrm{IW-ELBO}_K(x) &= \log \frac{1}{K} \sum_{k=1}^K \frac{p(x, z_k)}{q(z_k)}, \quad z_k \sim q \end{align}\]Here, to be consistent with the notation in [1], we define the IW-ELBO without the expectation under $q$. One can show [2] that the IW-ELBO is monotonically non-decreasing in expectation and, as the variance of $R_K$ vanishes with $K \rightarrow \infty$, converges to the true log-marginal likelihood:
\[\begin{equation} \E{q}{\textrm{IW-ELBO}_{K+1}(x)} \geq \E{q}{\textrm{IW-ELBO}_{K}(x)}, \quad \lim_{K \rightarrow \infty}\E{q}{\textrm{IW-ELBO}_{K}(x)} = \log p(x) \end{equation}\]So it appears we can only obtain an unbiased estimator of $\log p(x)$ after doing an infinite amount of computation.
SUMO circumvents this difficulty by randomized truncation of a specially constructed infinite series - falling under the class of the so-called ‘Russian roulette estimators’ [3]. Where exactly the specific form of the estimator comes from in the paper is a bit mysterious, so we’ll motivate it heuristically below. For a general discussion, let $\Delta_k$ be the $k$th term of some series that converges to some quantity of interest $p^* = \lim_{K \rightarrow \infty} \left( \sum_{k=1}^K \Delta_k\right)$. We want to truncate the series at some finite stopping time $K$, while retaining unbiasedness.
This is possible if we treat $K$ as a discrete random variable with mass function $p(K)$ supported on the natural numbers. The core idea is to draw $K \sim p(K)$, evaluate the first $K$ terms in the sum (each appropriately weighted to account for truncation), and stop. Taking the expectation of the partial sum w.r.t. $p(K)$ and shuffling things around in the resulting double sum1:
\[\begin{align} \E{p(K)}{\sum_{k=1}^K \Delta_k(x)} &= \sum_{K=1}^{\infty} p(K) \sum_{k=1}^K \Delta_k(x) \\ &= \sum_{K=1}^{\infty} \sum_{k=1}^{\infty} p(K)\Delta_k(x) \mathbb{I}(K \geq k)\\ &= \sum_{k=1}^{\infty} \Delta_k(x) \sum_{K=1}^{\infty} p(K)\mathbb{I}(K \geq k) \\ &= \sum_{k=1}^{\infty} \Delta_k(x) \mathbb{P}(K \geq k) \end{align}\]Where $\mathbb{P}(K \geq k)$ is the complement of the CDF. This suggests we can obtain an unbiased estimator of the infinite series by first sampling the stopping time $K$, and then evaluating $K$ appropriately weighted terms2:
\[\begin{equation} \sum_{k=1}^{\infty} \Delta_k(x) = \E{K \sim p(K)}{\sum_{k=1}^K \frac{\Delta_k(x)}{\mathbb{P}(K \geq k)}} \end{equation}\]To connect this back to lower bounds on $\log p(x)$, note that the following telescoping series converges to $\log p(x)$ in expectation in the limit:
\[\begin{equation} \log p(x) = \E{q}{\textrm{IW-ELBO}_1(x) + \lim_{K \rightarrow \infty} \left[ \sum_{k=1}^K \left( \textrm{IW-ELBO}_{k+1}(x) - \textrm{IW-ELBO}_k(x)\right)\right]} \end{equation}\]Let $\hat{\Delta}_k(x)= \E{q}{\textrm{IW-ELBO}_{k+1}(x)} - \E{q}{\textrm{IW-ELBO}_k(x)}$ then apply the Russian roulette estimator to the sequence of partial sums:
\[\begin{align} \log p(x) &= \E{q}{\textrm{IW-ELBO}_1(x)} + \sum_{k=1}^{\infty} \hat{\Delta}_k(x) \\ &= \E{q}{\textrm{IW-ELBO}_1(x)} + \E{K \sim p(K)}{\sum_{k=1}^K \frac{\hat{\Delta}_k(x)}{\mathbb{P}(K \geq k)}} \\ &= \E{K \sim p(K)}{\E{q}{\textrm{IW-ELBO}_1(x) + \sum_{k=1}^K \frac{\textrm{IW-ELBO}_{k+1}(x) - \textrm{IW-ELBO}_k(x)}{\mathbb{P}(K \geq k)}}} \end{align}\]Luo et. al. [1] define the interior of the expectation to be the Stochastically Unbiased Marginalization Objective (SUMO) estimator, which is equal to the marginal log-probability in expectation over $K$ and $q(z)$:
\[\begin{equation} \textrm{SUMO}(x) = \textrm{IW-ELBO}_1(x) + \sum_{k=1}^K \frac{\textrm{IW-ELBO}_{k+1}(x) - \textrm{IW-ELBO}_k(x)}{\mathbb{P}(K \geq k)}, \quad z_k \sim q(z), K \sim p(K) \end{equation}\]Defining $\Delta_k(x)= \textrm{IW-ELBO}_{k+1}(x) - \textrm{IW-ELBO}_k(x)$, we arrive at the slightly more palatable form:
\[\begin{equation} \textrm{SUMO}(x) = \log \frac{p(x,z_1)}{q(z_1)} + \sum_{k=1}^K \frac{\Delta_k(x)}{\mathbb{P}(K \geq k)}, \quad z_k \sim q(z), K \sim p(K) \end{equation}\]As Lyne et. al. [3] (Appendix A.2.) demonstrate, the variance of the generic Russian roulette estimator can be potentially infinite - there is a tradeoff between computation time (controlled by the expected value of the stopping time and hence the form of the PMF $p(K)$), and the variance of the resulting estimator. Luo et. al. [1] propose that the variance of SUMO can be minimized through optimization of the parameters $\phi$ of an amortized encoder $q_{\phi}(z \vert x)$, noting that:
\[\begin{align} \nabla_{\phi}\mathbb{V}\left[\textrm{SUMO}(x)\right] &= \nabla_{\phi}\E{q_{\phi}, \, p(K)}{\textrm{SUMO}(x)^2} - \nabla_{\phi}\left(\E{q_{\phi}, \, p(K)}{\textrm{SUMO}(x)}\right)^2 \\ &= \nabla_{\phi}\E{q_{\phi}, \, p(K)}{\textrm{SUMO}(x)^2} \\ &= \E{p(\epsilon), \, p(K)}{\nabla_{\phi}\textrm{SUMO}(x)^2} \end{align}\]Where $\nabla_{\phi}\left(\E{q_{\phi}, \, p(K)}{\textrm{SUMO}(x)}\right)^2= \nabla_{\phi} \left(\log p(x)\right)^2 = 0$ by virtue of SUMO being an unbiased estimator of $\log p(x)$, and we apply reparameterization w.r.t. $z$ in the last line to pull the gradient into the expectation. So here the latent variables are doing double duty - used to obtain a sequence of lower bounds on $\log p(x)$, and to minimize the variance of the resulting estimator.
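Before moving on to the implementation, here is a quick numerical sanity check of the randomized-truncation idea on a toy series; the geometric stopping distribution and the helper below are purely illustrative and not part of SUMO:

```python
import numpy as np

# Estimate the geometric series sum_{k>=1} 0.5^k = 1 with random truncation.
np_rng = np.random.default_rng(0)
p_stop = 0.5  # stopping time K ~ Geometric(p_stop), so P(K >= k) = (1 - p_stop)^(k - 1)

def russian_roulette_estimate():
    K = np_rng.geometric(p_stop)
    ks = np.arange(1, K + 1)
    # Each term is reweighted by 1 / P(K >= k) to correct for the truncation.
    return np.sum(0.5 ** ks / (1.0 - p_stop) ** (ks - 1))

estimates = np.array([russian_roulette_estimate() for _ in range(200_000)])
print(estimates.mean())  # close to 1.0, despite every draw using finitely many terms
```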
Jax is a machine learning framework described as the spiritual successor of autograd. Here’s the elevator pitch: `numpy` + automatic differentiation, plus:

- `jax.grad` for automatic differentiation.
- `jax.vmap` for automatic vectorization.
- `jax.pmap` for parallelization across multiple devices.
- `jax.jit` for just-in-time compilation via XLA.

All the above primitives can be arbitrarily composed as well. From my experience, Jax was easy to come to grips with, sharing much of its minimalist interface with `numpy`/`scipy`. Furthermore, it’s blazingly fast out-of-the-box - switching from CPU to GPU/TPU requires almost no work, and it is straightforward to fuse your code into performant kernels or parallelize across multiple devices.
However, Jax adopts a functional programming model, as opposed to an object-oriented one, in order for function transformations to work nicely. This imposes certain constraints that may be unnatural to users accustomed to machine learning frameworks written in Python (Torch/Tensorflow). Python functions subject to transformation/compilation in Jax must be functionally pure: all the input data is passed through the function parameters, and all the results are output through the returned function values - there should be no dependence on global state or additional ‘side effects’ beyond the return values. This can be somewhat vexing to those of us accustomed to calling `loss.backward()` in Torch - in fact, there’s a very long list of common gotchas in Jax associated with this purity constraint, but we’ll see that this is not too big a barrier once you get used to it.
This section is quite long, but implementing and optimizing the SUMO estimator will force us to deal with a surprisingly large cross-section of Jax, together with a large number of gotchas (as of v0.1.69, at the time of writing).
First we need a way to get the importance-weighted ELBO terms that appear in the Russian roulette estimator. This part is a modified version of the full discussion in an earlier post about the importance-weighted ELBO. We’ll start off with the most `numpy`/`torch`-like part of the code, defining convenience functions to evaluate the log-density of a diagonal-covariance Gaussian and to sample from a diagonal Gaussian using reparameterization.
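A minimal sketch of what these helpers might look like (the names are mine, not necessarily those used in the repository):

```python
import jax.numpy as jnp
from jax import random

def diag_gaussian_logpdf(x, mean, logvar):
    # Log-density of a diagonal-covariance Gaussian, summed over dimensions.
    return jnp.sum(-0.5 * (jnp.log(2 * jnp.pi) + logvar
                           + (x - mean) ** 2 / jnp.exp(logvar)))

def diag_gaussian_sample(rng, mean, logvar):
    # Reparameterized sample: z = mu + sigma * eps, with eps ~ N(0, I).
    eps = random.normal(rng, mean.shape)
    return mean + jnp.exp(0.5 * logvar) * eps
```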
Note that in Jax we explicitly pass a PRNG state, represented above by `rng`, into any function where randomness is required, such as when sampling $\mathcal{N}(0,\mathbb{I})$ for reparameterization. The short reason for this is that a standard PRNG silently updates the state used to generate pseudo-randomness, resulting in a hidden ‘side effect’ that is problematic for Jax’s transformation and compilation functions, which only work on functionally pure Python functions. Instead, Jax explicitly passes the PRNG state as an argument and splits it whenever we require more instances of randomness. Check out the documentation for more details.
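As a tiny illustration of this pattern:

```python
from jax import random

rng = random.PRNGKey(42)            # the explicit PRNG state
rng, subkey = random.split(rng)     # split off a fresh key before each use
eps = random.normal(subkey, (10,))  # the same subkey always yields the same draw
```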
Here we are performing amortized inference - instead of learning variational parameters per-datapoint, we train separate neural networks to output the parameters of the approximate posterior $q(z \vert x)$ and conditional model $p(x \vert z)$. These are called the encoder/decoder, respectively. The variational parameters are then the collective parameters of these networks - shared between all datapoints. Using a standard Gaussian for the prior $p(z)$ and diagonal-covariance Gaussian distribution for the conditional and approximate posterior, the necessary samples to compute the IW-ELBO are obtained through reparameterization - see Section 3.2. of this earlier post for more details.
\[\begin{align} (\mu_z, \Sigma_z) &= \textrm{Encoder}(x) \\ z \sim q(z \vert x) \quad &\equiv \quad \epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{1}), \; z = \mu_z + \Sigma_z^{1/2} \epsilon \\ (\mu_x, \Sigma_x) &= \textrm{Decoder}(z) \\ x \sim p(x \vert z) \quad &\equiv \quad \epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{1}), \; x = \mu_x + \Sigma_x^{1/2} \epsilon \\ \end{align}\]Due to the functional programming paradigm, defining the encoder/decoder networks in Jax
is slightly different from Torch/Tensorflow, etc. The core idea is that the constructor for each layer in the network expects an argument specifying the output dimension and returns:

- `init_fun`, which initializes the layer parameters.
- `apply_fun`, which defines the forward pass of the layer.

The Jax-native `stax` library provides a straightforward way to compose layers into a single model, wrapping them together according to a user-defined structure and returning an `init_fun`, which initializes parameters for all layers in the model, and an `apply_fun`, which defines the forward pass of the model as the composition of the forward computations for each layer. See the documentation for more details. Here both the amortized encoder/decoder are defined as single-hidden-layer networks, where the output of each is split into two pieces - the mean and the logarithm of the diagonal of the covariance matrix over the respective space. We’ll refer to the latter as the log-variance.
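A sketch of how the two networks might be assembled with `stax`; the helper, the layer sizes, and the `FanOut`/`parallel` construction of the two output heads are assumptions of this sketch, and in newer Jax releases the module is exposed as `jax.example_libraries.stax`:

```python
from jax import random
from jax.experimental import stax  # jax.example_libraries.stax in newer Jax

data_dim, latent_dim, hidden_dim = 2, 2, 64

def gaussian_mlp(hidden_dim, out_dim):
    # Single-hidden-layer network whose output is split into two heads:
    # the mean and the log-variance, each of dimension `out_dim`.
    return stax.serial(
        stax.Dense(hidden_dim), stax.Relu,
        stax.FanOut(2),
        stax.parallel(stax.Dense(out_dim), stax.Dense(out_dim)),
    )

encoder_init, encoder_forward = gaussian_mlp(hidden_dim, latent_dim)
decoder_init, decoder_forward = gaussian_mlp(hidden_dim, data_dim)

# init_fun(rng, input_shape) -> (output_shape, init_params)
rng = random.PRNGKey(0)
rng, enc_rng, dec_rng = random.split(rng, 3)
_, encoder_params = encoder_init(enc_rng, (-1, data_dim))
_, decoder_params = decoder_init(dec_rng, (-1, latent_dim))
```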
The initialization functions accept:

- `rng`: A PRNG key used to initialize the parameters.
- `input_shape`: A tuple specifying the dimensions of the network input.

They return:

- `output_shape`: A tuple specifying the dimensions of the network output.
- `init_params`: A tuple holding the initial network parameters.

Next we define the log of the summand of the IW-ELBO estimator - i.e. $\log \frac{p(x, z_k)}{q(z_k \vert x)}, z_k \sim q$:
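A hedged sketch of this log-summand, reusing the helpers and networks sketched above (the name `iw_estimator` anticipates the discussion below; the exact signature is an assumption):

```python
def iw_estimator(rng, enc_params, dec_params, x):
    # Encode x into the parameters of q(z | x) and draw z by reparameterization.
    mu_z, logvar_z = encoder_forward(enc_params, x)
    z = diag_gaussian_sample(rng, mu_z, logvar_z)
    # log p(x | z) + log p(z) - log q(z | x), with a standard Gaussian prior p(z).
    mu_x, logvar_x = decoder_forward(dec_params, z)
    log_pxz = diag_gaussian_logpdf(x, mu_x, logvar_x)
    log_pz = diag_gaussian_logpdf(z, jnp.zeros_like(z), jnp.zeros_like(z))
    log_qz = diag_gaussian_logpdf(z, mu_z, logvar_z)
    return log_pxz + log_pz - log_qz
```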
Note that when applying the forward pass through the encoder and decoder, we have to explicitly pass the parameters of each network as arguments - this is in line with Jax’s functional programming interface: functions should only rely on their arguments, not global state. If you don’t like this functional interface, the good news is that there are two projects - flax and trax - which provide a more object-oriented interface for network construction (these are a lot more readable/hackable, IMO).
The manner in which we collect the summands for different samples $z_k \sim q(z \vert x)$ represents another significant point of departure from other autodiff frameworks. We have to arrange the summands into the estimator:
\[\begin{align} \textrm{IW-ELBO}_K(x) &= \log \frac{1}{K} \sum_{k=1}^K \frac{p(x, z_k)}{q(z_k \vert x)}, \quad z_k \sim q \end{align}\]First, we split the `rng` state into `num_samples` new states, one for each importance sample. Next we make use of `jax.vmap` - one of the ‘killer features’ of Jax. This function vectorizes a function over a given axis of the input, so the function is evaluated in parallel across that axis. A trivial example of this would be representing a matrix-vector operation as a vectorized form of dot products between each row of the matrix and the given vector, but `vmap` allows us to (in most cases) vectorize more complicated functions where it is not obvious how to manually ‘batch’ the computation. For brevity, you can read more about `vmap` via the documentation or this previous post.
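Putting this together, a sketch of the resulting function might look as follows (the name and signature are assumptions of this sketch):

```python
from jax import vmap
from jax.scipy.special import logsumexp

def iw_elbo(rng, enc_params, dec_params, x, num_samples):
    # One fresh PRNG key per importance sample.
    rngs = random.split(rng, num_samples)
    # Vectorize the single-sample estimator over the leading axis of `rngs`.
    vec_iw_estimator = vmap(iw_estimator, in_axes=(0, None, None, None))
    iw_log_summand = vec_iw_estimator(rngs, enc_params, dec_params, x)
    # log (1/K) sum_k p(x, z_k) / q(z_k | x), computed stably via logsumexp.
    return logsumexp(iw_log_summand) - jnp.log(num_samples)
```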
The return result of `vmap` is the vectorized version of `iw_estimator`, which we call to get a vector, `iw_log_summand`, representing the log-summands $\log p(x \vert z_k) + \log p(z_k) - \log q(z_k \vert x)$ for each importance sample $z_k$. This will be a one-dimensional array with the number of elements given by the number of importance samples $K$. Finally, for numerical stability, we take the $\text{logsumexp}$ of the log-summands and subtract $\log K$ to give the final $\textrm{IW-ELBO}_K$. Note that parallelization across the different importance-sample summands was near-automatic here thanks to `vmap` - we didn’t have to keep track of a cumbersome batch dimension anywhere.
The SUMO estimator is pleasingly simple to implement, here are the basic steps:

1. Sample a stopping time $K \sim p(K)$.
2. Evaluate the log-summands $\log \frac{p(x, z_k)}{q(z_k \vert x)}$ for the importance samples.
3. Compute the $\textrm{IW-ELBO}_k$ terms from the log-summands.
We already did steps 2 and 3 above; let’s focus on step 1 - the authors propose the following tail distribution for the number of terms $\mathcal{K}$ to evaluate in the series: $\mathbb{P}(\mathcal{K} \geq k) = \frac{1}{k}$. This has CDF $F_{\mathcal{K}}(k) = \mathbb{P}(\mathcal{K} \leq k) = 1 - \frac{1}{k+1}$ and so by the inverse CDF transform $\left \lfloor{\frac{u}{1-u}}\right \rfloor \sim p(\mathcal{K})$, where $u \sim \text{Uniform}[0,1]$3. To reduce the possibility of sampling large stopping times, they flatten the tails of the CDF by heavily penalizing the probability of $k > 80$:
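A minimal sketch of such a sampler; the clip to a maximum number of terms below is a crude stand-in for the paper's flattened tail (and for the minimum-number-of-terms handling), not the exact scheme from [1]:

```python
import numpy as np  # plain numpy: the stopping time must stay a concrete Python int

def sample_stopping_time(min_terms=1, max_terms=80):
    # Inverse-CDF draw for the tail P(K >= k) = 1/k via floor(u / (1 - u)).
    u = np.random.uniform()
    k = int(np.floor(u / (1.0 - u)))
    return int(np.clip(k, min_terms, max_terms))
```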
We can tie this all together in the following function:
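Here is one possible sketch of such a function. It assumes the IW-ELBO terms are built cumulatively from a shared set of $K+1$ importance samples and weighted by $1/\mathbb{P}(K \geq k) = k$; the `jit` decorator and its `static_argnums` argument are discussed in the technicalities below.

```python
from functools import partial
from jax import jit

@partial(jit, static_argnums=(4,))
def sumo(rng, enc_params, dec_params, x, K):
    # Draw K + 1 importance samples and evaluate their log-summands in parallel.
    rngs = random.split(rng, K + 1)
    log_w = vmap(iw_estimator, in_axes=(0, None, None, None))(
        rngs, enc_params, dec_params, x)
    # Cumulative IW-ELBO_k for k = 1, ..., K + 1, sharing the same samples.
    iwelbos = jnp.stack([logsumexp(log_w[:k]) - jnp.log(k)
                         for k in range(1, K + 2)])
    deltas = iwelbos[1:] - iwelbos[:-1]   # IW-ELBO_{k+1} - IW-ELBO_k
    inv_tail = jnp.arange(1, K + 1)       # 1 / P(K >= k) = k
    return iwelbos[0] + jnp.sum(inv_tail * deltas)
```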
`jit` Technicalities

What’s this business with the decorator? When Jax analyzes a function to compile using `jax.jit`, it passes an abstract value in lieu of an actual array for each argument, called a tracer value. These are used to characterize what the function does on abstract Python objects - in most cases an array of floats - in order to be compiled using XLA. Tracer values share some functionality with `ndarray` - you can call `.shape` on these, for example - but forcing them to commit to specific numerical values will throw an error. Their only purpose is to characterize the behaviour of the function on a restricted set of possible inputs in order to be `jit`-compiled.
Tracer values are passed for each argument to the `jit`-transformed function, with the exception of those arguments identified as `static_argnums` by `jit`. These remain regular values and are treated as constants during compilation - with the caveat that when these arguments change, the function must be `jit`-compiled again. Hence, as the name suggests, the `static_argnums` should not change too often, to avoid incurring the overhead of repeated compilation. Functions such as `random.split` and `lax.iota` do not accept tracer values as arguments, so we are forced to keep the sampled stopping time $K$ and the minimum number of terms $m$ as regular Python values. This means, unfortunately, that each batch is restricted to use the same value of $K$ if we want to use `jit`. To counteract this, we can use a small batch size. For a more detailed discussion, you can read the Jax FAQ and the discussion within this Github issue.
Here we’ll look at an interesting application of SUMO to approximate sampling via density matching - the flip side of the coin to density estimation. We are given some complex target distribution $p^*(x) \propto \exp\left(-U(x)\right)$, where $U(x)$ is some energy function. Often no exact sampling procedure exists, and we would like to generate samples from the target distribution $p^*(x)$ using the learned model $p_{\theta}$ as a surrogate.
With the analytical form of the target density the reverse-KL objective can be efficiently optimized for this purpose:
\[\begin{align} \mathcal{L}(\theta) = \kl{p_{\theta}(x)}{p^*(x)} &= \E{p_{\theta}(x)}{\log \frac{p_{\theta}(x)}{p^*(x)}} \\ &= \E{p_{\theta}(x)}{\log p_{\theta}(x) - \log p^*(x)} \\ &= -\mathbb{H}(p_{\theta}) + \E{p_{\theta}(x)}{U(x)} + \text{const.} \end{align}\]Sampling from the surrogate $p_{\theta}$ may be achieved by sampling from the prior or approximate posterior in latent space and subsequently sampling from the conditional model:
\[\begin{equation} x \sim p_{\theta}(x) \equiv z \sim q_{\lambda}(z), \: \: x \sim p_{\theta}(x \vert z) \end{equation}\]Note the presence of an entropy maximization term in the reverse-KL objective - this translates to minimization of $\log p_{\theta}(x)$ for samples drawn from the surrogate $p_{\theta}$. If we only have access to a lower bound of $\log p_{\theta}$, as in the case of the standard ELBO, then we are minimizing a lower bound - of course, this is in the morally wrong direction. A decrease in the objective can be achieved by deterioration of the tightness of the bound - i.e. an increase in bias rather than actual minimization of $\log p_{\theta}(x)$. The target distribution $p^*$ in this case is defined as (originally proposed by Neal, 2003 [4]):
\[p^*(x_0,x_1) = \mathcal{N}\left(x_0 \vert 0, \exp(2 x_1)\right) \cdot \mathcal{N}(x_1 \vert 0, 1.35^2)\]The distribution is shaped like a funnel (see Section 3.3.), with a sharp neck in $x_0$, as the variance is exponential in $x_1$. Note we can straightforwardly sample from $p^*$ by reparameterization: $x_1 = 1.35 \cdot \epsilon_1$, then $x_0 = \exp(x_1) \cdot \epsilon_0$, where $\epsilon_0, \epsilon_1 \sim \mathcal{N}(0,1)$ independently - but we’ll see how SUMO fares.
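For concreteness, the target log-density can be expressed with the diagonal-Gaussian helper from earlier (the function name is mine):

```python
def funnel_log_prob(x):
    # log p*(x0, x1) = log N(x0 | 0, exp(2 x1)) + log N(x1 | 0, 1.35^2),
    # written in terms of the log-variances 2 * x1 and 2 * log(1.35).
    x0, x1 = x[0], x[1]
    return (diag_gaussian_logpdf(x0, 0.0, 2.0 * x1)
            + diag_gaussian_logpdf(x1, 0.0, 2.0 * jnp.log(1.35)))
```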
First we’ll define a function that allows us to sample from the surrogate model. Here we sample from the latent prior and then pass this through the amortized decoder to yield samples $x \sim p_{\theta}$.
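A sketch of such a sampler, assuming a standard Gaussian prior over the latents and reusing the decoder defined earlier (the function name is hypothetical):

```python
def sample_from_model(rng, dec_params, num_samples):
    # z ~ N(0, I), then x ~ p_theta(x | z) via the reparameterized decoder.
    rng_z, rng_x = random.split(rng)
    z = random.normal(rng_z, (num_samples, latent_dim))
    mu_x, logvar_x = vmap(decoder_forward, in_axes=(None, 0))(dec_params, z)
    eps = random.normal(rng_x, mu_x.shape)
    return mu_x + jnp.exp(0.5 * logvar_x) * eps
```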
Next we specify a function, `reverse_kl`, that evaluates the reverse-KL objective for a single example between the target `log_prob` and our model estimate `log_px_estimator`. The function `batch_reverse_kl` applies the vectorizing operator to the single-sample `reverse_kl` function so we can operate over a batch of samples. Again, the objective functions are made functionally pure by requiring the model parameters as arguments.
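A minimal sketch of the pair, with SUMO playing the role of `log_px_estimator` and the funnel density that of `log_prob` (the signatures are assumptions):

```python
def reverse_kl(rng, enc_params, dec_params, x, K):
    # Single-example estimate of log p_theta(x) - log p*(x).
    return sumo(rng, enc_params, dec_params, x, K) - funnel_log_prob(x)

def batch_reverse_kl(rng, enc_params, dec_params, xs, K):
    # Vectorize the single-sample objective over the batch, one PRNG key per example.
    rngs = random.split(rng, xs.shape[0])
    kls = vmap(reverse_kl, in_axes=(0, None, None, 0, None))(
        rngs, enc_params, dec_params, xs, K)
    return jnp.mean(kls)
```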
Jax natively provides a suite of optimizers for use. In order for the optimization procedure to be functionally pure, optimizers in Jax are defined as an `(init_fun, update_fun, get_params)` triple of functions:

- `init_fun` sets the initial optimizer state from the initial values of the model parameters.
- `update_fun` updates the internal state of the optimizer, based on the gradients of the objective function with respect to the model parameters.
- `get_params` returns the updated model parameters extracted from the internal optimizer state.

The end result is that the model parameters are passed into their respective optimizers, and the updated model parameters are returned after taking some gradient step. We’ll use the ubiquitous Adam optimizer, defining separate optimizers for the amortized encoder and decoder, as sketched below.
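A sketch of the optimizer setup (the step sizes are placeholders; at the time of writing the optimizers lived under `jax.experimental.optimizers`, exposed as `jax.example_libraries.optimizers` in later releases):

```python
from jax.experimental import optimizers  # jax.example_libraries.optimizers in newer Jax

# Separate Adam optimizers for the amortized encoder and decoder.
enc_opt_init, enc_opt_update, enc_get_params = optimizers.adam(step_size=1e-3)
dec_opt_init, dec_opt_update, dec_get_params = optimizers.adam(step_size=1e-3)

# Optimizer state is created from (and carries) the initial model parameters.
enc_opt_state = enc_opt_init(encoder_params)
dec_opt_state = dec_opt_init(decoder_params)
```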
Armed with this knowledge, we can package together gradient calculation and parameter updates for the encoder and decoder into separate functions. Note below the parameters of the encoder are optimized to minimize the variance of the SUMO estimator.
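Below is a hedged sketch of what these functions might look like. The loss definitions, their names, and the choice to draw each batch from the surrogate inside the loss (so that pathwise gradients reach the decoder) are assumptions of this sketch rather than the exact code from the repository.

```python
from jax import grad

def reverse_kl_loss(rng, enc_params, dec_params, K, batch_size):
    # Sample x ~ p_theta inside the loss so gradients flow through the decoder,
    # then evaluate the batched reverse-KL estimate from above.
    rng_x, rng_kl = random.split(rng)
    xs = sample_from_model(rng_x, dec_params, batch_size)
    return batch_reverse_kl(rng_kl, enc_params, dec_params, xs, K)

def sumo_variance_loss(rng, enc_params, dec_params, K, batch_size):
    # Encoder objective: the mean squared SUMO estimate, whose gradient w.r.t.
    # the encoder parameters matches the gradient of the estimator's variance.
    rng_x, rng_s = random.split(rng)
    xs = sample_from_model(rng_x, dec_params, batch_size)
    rngs = random.split(rng_s, batch_size)
    sumos = vmap(sumo, in_axes=(0, None, None, 0, None))(
        rngs, enc_params, dec_params, xs, K)
    return jnp.mean(sumos ** 2)

def update(step, rng, enc_opt_state, dec_opt_state, K, batch_size):
    enc_params, dec_params = enc_get_params(enc_opt_state), dec_get_params(dec_opt_state)
    # argnums picks out which positional argument to differentiate with respect to.
    dec_grads = grad(reverse_kl_loss, argnums=2)(rng, enc_params, dec_params, K, batch_size)
    enc_grads = grad(sumo_variance_loss, argnums=1)(rng, enc_params, dec_params, K, batch_size)
    return (enc_opt_update(step, enc_grads, enc_opt_state),
            dec_opt_update(step, dec_grads, dec_opt_state))
```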
Note that when taking gradients, the `argnums` argument to `jax.grad` indicates the argument with respect to which the gradient of the function should be computed. Finally, we write a succinct training loop4. A callback function is included to plot the learned densities periodically:
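A bare-bones version of such a loop, with the plotting callback and logging omitted and hyperparameters chosen arbitrarily for this sketch:

```python
rng = random.PRNGKey(0)
num_steps, batch_size = 5_000, 64

for step in range(num_steps):
    rng, step_rng = random.split(rng)
    K = sample_stopping_time()  # a concrete Python int, shared across the batch
    enc_opt_state, dec_opt_state = update(
        step, step_rng, enc_opt_state, dec_opt_state, K, batch_size)
```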
If you are interested in the full codebase, you may find it in the repository associated with this post.
Approximating the log-evidence of our surrogate model using SUMO and minimizing the reverse-KL objective allows our modest approximate sampler to learn the contours of the density, even handling the problematic neck of the funnel relatively well, excepting very low values in $y$. The plots below show the contours of the learned density with samples overlaid in blue. Training using SUMO is significantly more stable than using the IW-ELBO estimator for a similar expected amount of compute - for low values of $K$, the KL objective optimized under the IW-ELBO model eventually turns negative and veers toward instability, indicating that the model is optimizing the bias rather than the actual objective.
This was just a toy application for an unbiased estimator of $\log p_{\theta}(x)$. In the paper [1], they apply this to more useful problems such as classic density estimation and entropy regularization in reinforcement learning.
Jax has some very interesting design choices. In particular, automatic vectorization through `vmap` helps with mental overhead a lot once you don’t have to carry around a cumbersome batch dimension. Personally, there’s not enough of a QoL improvement at the moment over Torch for me to port my existing libraries to Jax, but I do enjoy tinkering with it, and its minimalism lends itself well to quick proof-of-concept sketches (provided you are well-acquainted with the sharp edges). One nice use I found for it is a quick way to check the correctness of Jacobian determinants computed by hand. Next in this unofficial series we might implement some interesting, possibly continuous, normalizing flow variants, which could be really easy, or troublesome, depending on the behaviour of `jax.jit`.
[1] Luo, Yucen, et al. SUMO: Unbiased Estimation of Log Marginal Probability for Latent Variable Models. International Conference on Learning Representations (2020). https://openreview.net/forum?id=SylkYeHtwr
[2] Burda, Yuri, Grosse, Roger, and Salakhutdinov, Ruslan. Importance Weighted Autoencoders. International Conference on Learning Representations (2016).
[3] Lyne, Anne-Marie, et al. On Russian Roulette Estimates for Bayesian Inference with Doubly-Intractable Likelihoods. Statistical Science. Institute of Mathematical Statistics (2015).
[4] Neal, Radford M. Slice Sampling. Annals of Statistics 31 (3): 705–67 (2003).
You made it!
We are abandoning any pretense of formality here, ignoring all prerequisite conditions on $p(K)$ and $\Delta_k$ - see Appendix A.2. of [3] if you have an aversion to mathematical sin. ↩
I realize this doesn’t look quite kosher, but do the same shuffling around of terms in the original derivation with the new estimator to convince yourself. ↩
To see this, note $F_{\mathcal{K}}\left(F^{-1}_{\mathcal{K}}(k)\right) = 1 - \frac{1}{F^{-1}_{\mathcal{K}}(k)+1} = k \implies F^{-1}_{\mathcal{K}}(k) = \frac{k}{1-k}$, then take the floor to map to an integer. ↩
Yes, the stopping time K should really be passed as an argument all the way through to the SUMO function, but we omit this here for readability. ↩