Z-Forcing: Training Stochastic Recurrent Networks
Intro
A new training procedure for recurrent VAEs is proposed. Recall that for VAEs, we model a joint distribution over observations $x$ and latent variables $z$, and assume that $z$ is involved in the generation of $x$. This distribution is parameterized by $\theta$. Maximizing the marginal log-likelihood $\log p_{\theta}(x)$ with respect to $\theta$ is intractable because it requires integrating over $z$. Instead, we introduce a variational distribution $q_{\phi}(z|x)$ and maximize a lower bound on the marginal log-likelihood: the ELBO.
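For reference, the standard form of this lower bound is written out here. The notation $\mathcal{L}$ and the choice to write the prior as $p_{\theta}(z)$ are assumptions of this note, not taken from the paper:

$$
\log p_{\theta}(x) \;\geq\; \mathcal{L}(x; \theta, \phi) = \mathbb{E}_{q_{\phi}(z|x)}\left[\log p_{\theta}(x|z)\right] - \mathrm{KL}\left(q_{\phi}(z|x) \,\|\, p_{\theta}(z)\right)
$$

The first term rewards reconstructing $x$ from sampled latents, and the KL term keeps the variational distribution close to the prior.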
Stochastic recurrent networks
When applying VAEs to sequences, it has been proposed to use recurrent networks for the recognition network (aka inference network aka variational posterior) and the generation network (aka decoder aka conditional probability of the next observation given previous observations and latents). These probabilistic models can be autoregressive (in this paper, the authors use LSTMs, with MLPs predicting the parameters of Gaussian distributions). It is common to model these conditional distributions with Gaussians for continuous variables or categoricals for discrete variables.
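To make "predicting the parameters of Gaussian distributions" concrete, here is a minimal sketch in plain Python. It assumes a single linear layer standing in for the MLP and a diagonal Gaussian parameterized by mean and log-variance; the function names (`linear`, `gaussian_head`, `sample`, `gaussian_log_prob`) are hypothetical and not from the paper.

```python
import math
import random


def linear(weights, bias, x):
    """Affine map: one output per row of `weights`."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weights, bias)]


def gaussian_head(h, w_mu, b_mu, w_logvar, b_logvar):
    """Map a hidden state h to the mean and log-variance of a diagonal Gaussian.

    Assumption: one linear layer per parameter stands in for the MLP.
    """
    mu = linear(w_mu, b_mu, h)
    logvar = linear(w_logvar, b_logvar, h)
    return mu, logvar


def sample(mu, logvar, rng):
    """Reparameterized sample: mu + sigma * eps, with eps ~ N(0, 1)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, logvar)]


def gaussian_log_prob(x, mu, logvar):
    """Log-density of x under a diagonal Gaussian, summed over dimensions."""
    return sum(-0.5 * (math.log(2 * math.pi) + lv + (xi - m) ** 2 / math.exp(lv))
               for xi, m, lv in zip(x, mu, logvar))


# Usage: a 2-dimensional hidden state mapped to a 1-dimensional Gaussian.
mu, logvar = gaussian_head([1.0, 2.0], [[1.0, 0.0]], [0.5], [[0.0, 0.0]], [0.0])
z = sample(mu, logvar, random.Random(0))
```

Predicting the log-variance rather than the variance keeps the variance positive without any explicit constraint.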
Usually, the prior over latent variables is also learned with a parametric model.
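With a learned prior, the KL term of the ELBO compares two predicted Gaussians at each time step instead of comparing the posterior to a fixed standard normal. A minimal sketch, assuming both distributions are diagonal Gaussians given as mean and log-variance lists (the function name `gaussian_kl` is hypothetical):

```python
import math


def gaussian_kl(mu_q, logvar_q, mu_p, logvar_p):
    """KL(q || p) between two diagonal Gaussians, summed over dimensions."""
    total = 0.0
    for mq, lq, mp, lp in zip(mu_q, logvar_q, mu_p, logvar_p):
        # Per-dimension closed form:
        # 0.5 * (log(var_p / var_q) + (var_q + (mu_q - mu_p)^2) / var_p - 1)
        total += 0.5 * (lp - lq + (math.exp(lq) + (mq - mp) ** 2) / math.exp(lp) - 1.0)
    return total


# Usage: posterior and prior with unit variance, means one unit apart.
kl = gaussian_kl([1.0], [0.0], [0.0], [0.0])
```

When the posterior and the prior coincide, this quantity is zero, so the term only penalizes the posterior for carrying information the prior could not predict.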
If I’m not mistaken, learning the parameters of these parametric models on a training data set, and then using them at test time for fast inference, is referred to as amortized variational inference, which appears to have correlates in our cognition.
Z-forcing
Strong autoregressive decoders overpower the latent variables $z$: the decoder can model the data well while largely ignoring $z$, which prevents the conditional probability distribution (CPD) from learning complex multi-modal distributions. To mitigate this, the authors add an auxiliary cost to the training objective. An extra parametric model, $p_{\eta}(b | z)$, is introduced that “forces” the latents to be predictive of the hidden states $b$ of the “backward network” (the inference network).
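The sketch below shows how such an auxiliary term could enter a per-step training loss. It is an illustration under stated assumptions, not the authors' implementation: $p_{\eta}(b | z)$ is taken to be a unit-variance Gaussian centered on a prediction `b_pred` computed from $z$, and the weight `aux_weight` and both function names are hypothetical.

```python
import math


def aux_log_likelihood(b, b_pred):
    """log N(b; b_pred, I): unit-variance Gaussian log-density of the backward state b.

    Assumption: b_pred is the output of some network applied to z.
    """
    return sum(-0.5 * (math.log(2 * math.pi) + (bi - pred) ** 2)
               for bi, pred in zip(b, b_pred))


def z_forcing_loss(recon_ll, kl, aux_ll, aux_weight=1.0):
    """Negative of (ELBO + weighted auxiliary term) for one time step.

    recon_ll: log-likelihood of the observation under the decoder
    kl:       KL between the approximate posterior and the prior
    aux_ll:   log p(b | z), the auxiliary term
    """
    return -(recon_ll - kl + aux_weight * aux_ll)


# Usage: a prediction that matches the backward state exactly.
aux_ll = aux_log_likelihood([0.2, -0.1], [0.2, -0.1])
loss = z_forcing_loss(recon_ll=-1.0, kl=0.5, aux_ll=aux_ll, aux_weight=0.5)
```

Minimizing this loss pushes $z$ to carry enough information to predict $b$, which is the sense in which the latents are "forced" to be used.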
Experiments
They validate the approach on speech modeling (TIMIT, Blizzard) and language modeling. The metric is average log-likelihood (LL). On Sequential MNIST, z-forcing is competitive with “deeper” recurrent generative models like PixelRNN.
Takeaways
It’s always worth asking whether increasing the complexity of an approach (here, adding an extra network and an auxiliary cost) is worth the effort compared with simpler approaches that get almost the same performance. The results on TIMIT and Blizzard are pretty convincing. The authors also suggest incorporating the auxiliary loss with PixelRNN/CNN in future work.