Sticking the Landing: Simple, Lower-Variance Gradient Estimators for Variational Inference - 2017

Details

Title: Sticking the Landing: Simple, Lower-Variance Gradient Estimators for Variational Inference
Author(s): Roeder, Geoffrey and Wu, Yuhuai and Duvenaud, David
Link(s): http://arxiv.org/abs/1703.09194

Rough Notes

The reparameterization trick rewrites a sample of the latent variables \(\mathbf{z}\) as a deterministic function of the variational parameters \(\phi\) and parameter-free noise, which generally gives better ELBO gradient estimates than pure score-function (REINFORCE) estimators. However, the resulting total-derivative gradient still contains a score function term, the same term that the REINFORCE estimator is built on. The main claim of this paper is that this term can be removed, giving a lower-variance gradient estimator in many cases.
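As a concrete illustration, here is a minimal sketch of the reparameterization trick (my own code, not from the paper), assuming a diagonal-Gaussian variational family; the names `sample_z` and `phi` are just illustrative:

```python
import jax
import jax.numpy as jnp

def sample_z(phi, eps):
    """Reparameterization t(eps, phi) for a diagonal Gaussian q_phi(z|x):
    all randomness lives in eps ~ N(0, I), which does not depend on phi."""
    mu, log_sigma = phi
    return mu + jnp.exp(log_sigma) * eps

phi = (jnp.zeros(2), jnp.zeros(2))                    # (mu, log_sigma)
eps = jax.random.normal(jax.random.PRNGKey(0), (2,))  # parameter-free noise
z = sample_z(phi, eps)                                # differentiable with respect to phi
```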

For a joint distribution \(p(\mathbf{x},\mathbf{z})\) that factorizes as \(p(\mathbf{x}|\mathbf{z})\,p(\mathbf{z})\), with variational approximation \(q_\phi(\mathbf{z}|\mathbf{x})\) to the posterior, the ELBO \(\mathcal{L}(\phi)\) can be written in 3 ways:

  • \(\mathbf{E}_{\mathbf{z}\sim q}[\log p(\mathbf{x}|\mathbf{z}) + \log p(\mathbf{z}) - \log q_\phi(\mathbf{z}|\mathbf{x})]\) (1)
  • \(\mathbf{E}_{\mathbf{z}\sim q}[\log p(\mathbf{x}|\mathbf{z}) + \log p(\mathbf{z})] + \mathbb{H}[q_\phi]\) (2)
  • \(\mathbf{E}_{\mathbf{z} \sim q}[\log p(\mathbf{x}|\mathbf{z})] - KL(q_\phi(\mathbf{z}|\mathbf{x})||p(\mathbf{z}))\) (3)

The entropy term in (2) and the KL divergence term in (3) suggest that, for appropriate choices of prior and variational family, parts of those expressions are available in closed form. This paper argues that (1) is nonetheless sometimes preferable even when such analytic expressions exist, because its Monte Carlo estimator can have lower variance. Specifically, when \(q_\phi(\mathbf{z}|\mathbf{x})=p(\mathbf{z}|\mathbf{x})\), the variance of the fully Monte Carlo estimator \(\hat{\mathcal{L}}_{MC}\) of (1) is exactly 0, because for every sample \(\mathbf{z}\), \(\hat{\mathcal{L}}_{MC} = \log p(\mathbf{x},\mathbf{z}) - \log q_\phi(\mathbf{z}|\mathbf{x}) = \log p(\mathbf{z}|\mathbf{x}) + \log p(\mathbf{x}) - \log p(\mathbf{z}|\mathbf{x}) = \log p(\mathbf{x})\).

Hence, if we believe the variational approximation is good, the estimator using (1) has lower variance than the partly-analytic forms (2) and (3).
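As a sanity check of the zero-variance claim (my own toy example, not from the paper): take \(p(\mathbf{z})=\mathcal{N}(0,1)\) and \(p(\mathbf{x}|\mathbf{z})=\mathcal{N}(\mathbf{z},1)\), so the exact posterior is \(\mathcal{N}(x/2, 1/2)\) and the marginal is \(p(x)=\mathcal{N}(0,2)\); with \(q_\phi\) set to this posterior, every single-sample evaluation of (1) equals \(\log p(x)\):

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

x = 1.3
eps = jax.random.normal(jax.random.PRNGKey(0), (5,))
mu, sigma = x / 2.0, jnp.sqrt(0.5)                  # exact posterior parameters
z = mu + sigma * eps                                # reparameterized samples from q = p(z|x)

elbo_hat = (norm.logpdf(x, loc=z, scale=1.0)        # log p(x|z)
            + norm.logpdf(z)                        # log p(z), standard normal
            - norm.logpdf(z, loc=mu, scale=sigma))  # log q(z|x)

print(elbo_hat)                                     # identical value for every sample,
print(norm.logpdf(x, scale=jnp.sqrt(2.0)))          # equal to log p(x)
```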

When it comes to gradients, however, the gradient of the fully Monte Carlo estimator (1) with respect to the variational parameters \(\phi\) is not 0 even when the variational approximation is exact. First, note that the reparameterization trick rewrites a sample \(\mathbf{z}\) as a function of a random sample \(\epsilon\) that is independent of the variational parameters \(\phi\), i.e. \(\mathbf{z} = t(\epsilon, \phi)\). With this, the total derivative of the integrand in (1) with respect to \(\phi\) is

\[ \nabla_\phi \left[\log p(\mathbf{x}|\mathbf{z}) + \log p(\mathbf{z}) - \log q_\phi(\mathbf{z}|\mathbf{x})\right], \qquad \mathbf{z} = t(\epsilon, \phi) \]
\[ = \nabla_\mathbf{z}\left[\log p(\mathbf{x}|\mathbf{z}) + \log p(\mathbf{z}) - \log q_\phi(\mathbf{z}|\mathbf{x})\right]\nabla_\phi t(\epsilon, \phi) - \nabla_\phi \log q_\phi(\mathbf{z}|\mathbf{x}) \]

The last term above is the score function; the rest is the path derivative. The path derivative is 0 when the variational approximation is exact, since the bracketed term is then \(\log p(\mathbf{x})\), which is constant in \(\mathbf{z}\). The score function, however, is not identically zero (nor constant) as a function of \(\mathbf{z}\), so there is some variance even when the variational approximation is exact.

Fortunately, the score function is 0 in expectation, so the term can be dropped without biasing stochastic-gradient-based optimization; and as the variational approximation approaches the true posterior, the variance of the resulting (path-derivative) gradient of the ELBO approaches 0.
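A quick Monte Carlo check of the zero-mean property (again my own sketch, with a one-dimensional Gaussian \(q_\phi\); the name `log_q` is just illustrative):

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def log_q(phi, z):
    """log q_phi(z) for a one-dimensional Gaussian, phi = (mu, log_sigma)."""
    mu, log_sigma = phi
    return norm.logpdf(z, loc=mu, scale=jnp.exp(log_sigma))

phi = (jnp.array(0.7), jnp.array(-0.3))
z = phi[0] + jnp.exp(phi[1]) * jax.random.normal(jax.random.PRNGKey(1), (100_000,))

# Per-sample score function grad_phi log q_phi(z); its Monte Carlo average is ~0.
score = jax.vmap(jax.grad(log_q), in_axes=(None, 0))(phi, z)
print(jax.tree_util.tree_map(jnp.mean, score))
```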

However, this path derivative estimator may have higher variance when the score function is positively correlated with the remaining terms in the total-derivative estimator: since the score function has zero mean, it acts as a control variate there, and removing it hurts when \(2\,\mathrm{Cov}(\text{path},\text{score}) > \mathrm{Var}(\text{score})\). (#ASK)

During implementation, we evaluate the ELBO in (1) with a stopped-gradient copy of \(\phi\) inside \(\log q_\phi(\mathbf{z}|\mathbf{x})\), while keeping the dependence of \(\mathbf{z} = t(\epsilon, \phi)\) on \(\phi\), and then compute the gradient of this ELBO with respect to \(\phi\); the score function term then drops out and only the path derivative remains.
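A minimal sketch of that recipe in JAX (my own code, assuming the same toy model \(p(\mathbf{z})=\mathcal{N}(0,1)\), \(p(\mathbf{x}|\mathbf{z})=\mathcal{N}(\mathbf{z},1)\) and the hypothetical name `elbo_stl`; the only essential part is `stop_gradient` on the copy of \(\phi\) used inside \(\log q_\phi\)):

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def elbo_stl(phi, eps, x):
    mu, log_sigma = phi
    z = mu + jnp.exp(log_sigma) * eps                # z = t(eps, phi): keeps the path derivative
    mu_c, log_sigma_c = jax.lax.stop_gradient(phi)   # phi treated as a constant inside log q
    log_q = norm.logpdf(z, loc=mu_c, scale=jnp.exp(log_sigma_c))
    log_joint = norm.logpdf(x, loc=z, scale=1.0) + norm.logpdf(z)  # toy log p(x|z) + log p(z)
    return log_joint - log_q

# Differentiating w.r.t. phi now yields only the path derivative;
# the score function term vanishes because log q sees a stopped copy of phi.
stl_grad = jax.grad(elbo_stl)
```

Calling `stl_grad(phi, eps, x)` with reparameterized noise `eps` then gives the lower-variance gradient estimate discussed above.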
