
Variational Autoencoders 3: Training, Inference and comparison with other models


Variational Autoencoders 1: Overview
Variational Autoencoders 2: Maths
Variational Autoencoders 3: Training, Inference and comparison with other models

Recall that the backbone of VAEs is the following equation:

\log P\left(X\right) - \mathcal{D}\left[Q\left(z\vert X\right)\vert\vert P\left(z\vert X\right)\right] = E_{z\sim Q}\left[\log P\left(X\vert z\right)\right] - \mathcal{D}\left[Q\left(z\vert X\right) \vert\vert P\left(z\right)\right]

In order to use gradient descent for the right hand side, we need a tractable way to compute it:

  • The first part E_{z\sim Q}\left[\log P\left(X\vert z\right)\right] is tricky, because a good approximation of the expectation requires passing multiple samples of z through f, which is expensive. However, we can take just one sample of z, pass it through f, and use the result as an estimate of E_{z\sim Q}\left[\log P\left(X\vert z\right)\right] . We are doing stochastic gradient descent over different samples X from the training set anyway, so the extra noise from a single-sample estimate is tolerable.
  • The second part \mathcal{D}\left[Q\left(z\vert X\right) \vert\vert P\left(z\right)\right] is even trickier. By design, we fix P\left(z\right) to be the standard normal distribution \mathcal{N}\left(0,I\right) (see part 1 for why). Therefore, we need a way to parameterize Q\left(z\vert X\right) so that the KL divergence is tractable.
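To get a feel for the single-sample estimate in the first bullet, here is a small sketch in plain Python. Everything in it is a toy assumption rather than part of any real VAE: the decoder f is the identity, P(X|z) has unit variance, Q is a one-dimensional Gaussian, and the names log_lik, one_sample_estimate and exact_expectation are made up for illustration. Each single-sample estimate is noisy, but their average lands close to the exact expectation.

```python
import random


def log_lik(x, z):
    # Toy log P(X|z) up to an additive constant.
    # Assumption: the decoder f is the identity and P(X|z) has unit variance.
    return -0.5 * (x - z) ** 2


def one_sample_estimate(x, mu, std, rng):
    # Draw a single z ~ Q = N(mu, std^2) and evaluate log P(X|z) at it.
    z = rng.gauss(mu, std)
    return log_lik(x, z)


def exact_expectation(x, mu, std):
    # For z ~ N(mu, std^2): E[-(x - z)^2 / 2] = -((x - mu)^2 + std^2) / 2.
    return -0.5 * ((x - mu) ** 2 + std ** 2)


rng = random.Random(0)
x, mu, std = 1.0, 0.3, 0.8

# Each entry is one noisy single-sample estimate of the expectation.
estimates = [one_sample_estimate(x, mu, std, rng) for _ in range(50000)]
average = sum(estimates) / len(estimates)

print("average of single-sample estimates:", average)
print("exact expectation:                 ", exact_expectation(x, mu, std))
```

In real training we never average like this for a fixed X. The averaging happens implicitly across the many gradient steps.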

Here comes perhaps the most important approximation of VAEs. Since P\left(z\right) is standard Gaussian, it is convenient to make Q\left(z\vert X\right) Gaussian as well. One popular choice is a Gaussian with mean \mu\left(X\right) and diagonal covariance \sigma\left(X\right)I, i.e. Q\left(z\vert X\right) = \mathcal{N}\left(z;\mu\left(X\right), \sigma\left(X\right)I\right), where \mu\left(X\right) and \sigma\left(X\right) are two vectors computed by a neural network. Here \sigma\left(X\right)I is shorthand for the diagonal matrix whose diagonal holds the entries of \sigma\left(X\right), i.e. the per-dimension variances. This is the original formulation of VAEs in section 3 of this paper.

This parameterization is preferred because the KL divergence now becomes closed-form:

\displaystyle \mathcal{D}\left[\mathcal{N}\left(\mu\left(X\right), \sigma\left(X\right)I\right)\vert\vert P\left(z\right)\right] = \frac{1}{2}\left[\mathrm{tr}\left(\sigma\left(X\right)I\right) +\left(\mu\left(X\right)\right)^T\left(\mu\left(X\right)\right) - k - \log \det \left(\sigma\left(X\right)I\right) \right]

Here k is the dimensionality of z. Although this looks like magic, it is quite natural once you apply the definition of KL divergence to two normal distributions. Doing so will teach you a bit of calculus.
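The closed form is short enough to write out directly. Below is a minimal sketch in plain Python, assuming the diagonal of the covariance of Q is passed in as a list of per-dimension variances. The function name kl_to_standard_normal is made up for illustration.

```python
import math


def kl_to_standard_normal(mu, var):
    """KL[ N(mu, diag(var)) || N(0, I) ] for a diagonal Gaussian.

    mu  : list of means, one per latent dimension
    var : list of variances, i.e. the diagonal of the covariance of Q
    """
    if len(mu) != len(var):
        raise ValueError("mu and var must have the same length")
    k = len(mu)                                # dimensionality of z
    trace = sum(var)                           # trace of the covariance
    mu_sq = sum(m * m for m in mu)             # mu^T mu
    log_det = sum(math.log(v) for v in var)    # log det of a diagonal matrix
    return 0.5 * (trace + mu_sq - k - log_det)


# Q equal to the prior: the divergence vanishes.
print(kl_to_standard_normal([0.0, 0.0], [1.0, 1.0]))

# Shifting one mean by 1 (variances untouched) adds 1/2 * 1^2.
print(kl_to_standard_normal([1.0, 0.0], [1.0, 1.0]))
```

Two sanity checks are visible here: the divergence is zero exactly when Q equals the prior, and it grows as the mean moves away from the origin or the variances move away from 1.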

So we have all the ingredients. We use a feedforward net to predict \mu\left(X\right) and \sigma\left(X\right) given an input sample X drawn from the training set. With those vectors, we can compute the KL divergence and \log P\left(X\vert z\right), which, in terms of optimization, translates into something similar to \Vert X - f\left(z\right)\Vert^2.

It is worth pausing here for a moment to see what we just did. Basically we used a constrained Gaussian (with diagonal covariance matrix) to parameterize Q. Moreover, by using \Vert X - f\left(z\right)\Vert^2 as one of the training criteria, we implicitly assume P\left(X\vert z\right) to be Gaussian as well. So although the maths that leads to VAEs is generic and beautiful, at the end of the day, to make things tractable, we end up with these severe approximations. Whether they are good enough depends entirely on the practical application.
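To see why the squared error amounts to a Gaussian assumption, suppose P\left(X\vert z\right) = \mathcal{N}\left(X; f\left(z\right), s^2 I\right). The fixed scalar s and the dimensionality d of X are symbols introduced only for this sketch. Taking the negative log gives:

```latex
-\log P\left(X\vert z\right)
  = \frac{1}{2s^2}\left\Vert X - f\left(z\right)\right\Vert^2
  + \frac{d}{2}\log\left(2\pi s^2\right)
```

The second term does not depend on the network, so maximizing \log P\left(X\vert z\right) is the same as minimizing \Vert X - f\left(z\right)\Vert^2 up to the scale 1/\left(2s^2\right), which sets how strongly reconstruction is weighted against the KL term.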

There is an important detail though. Once we have \mu\left(X\right) and \sigma\left(X\right) from the encoder, we will need to sample z from a Gaussian distribution parameterized by those vectors. z is needed for the decoder to reconstruct \hat{X}, which will then be optimized to be as close to X as possible via gradient descent. Unfortunately, the “sample” step is not differentiable, therefore we will need a trick called reparameterization, where we don’t sample z directly from \mathcal{N}\left(\mu\left(X\right), \sigma\left(X\right)I\right), but first sample z' from \mathcal{N}\left(0, I\right), and then compute z = \mu\left(X\right) + \left(\sigma\left(X\right)I\right)^{1/2}z', i.e. each component of z' is scaled by the square root of the corresponding entry of \sigma\left(X\right) and shifted by the mean. The randomness now lives entirely in z', which does not depend on the network, so the whole computation is differentiable and we can apply gradient descent as usual.
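Here is a small sketch of the trick in plain Python, again assuming the variances are given as a list. The function name reparameterize and the numbers are made up for illustration. It checks two things: the reparameterized samples have the mean and variance we asked for, and, with the noise held fixed, z is an ordinary function of the mean whose derivative we can estimate by finite differences.

```python
import math
import random


def reparameterize(mu, var, eps):
    # z = mu + sqrt(var) * eps, elementwise.
    # eps is noise drawn from N(0, I); it does not depend on mu or var.
    return [m + math.sqrt(v) * e for m, v, e in zip(mu, var, eps)]


rng = random.Random(0)
mu, var = [1.5, -2.0], [0.25, 4.0]

# 1) Samples built this way follow N(mu, diag(var)).
draws = [
    reparameterize(mu, var, [rng.gauss(0.0, 1.0) for _ in mu])
    for _ in range(20000)
]
mean0 = sum(z[0] for z in draws) / len(draws)
var0 = sum((z[0] - mean0) ** 2 for z in draws) / len(draws)
print("sample mean / variance of z[0]:", mean0, var0)

# 2) With eps fixed, z is a deterministic function of mu and var,
#    so derivatives exist. Analytically dz/dmu = 1.
eps = [0.7, -1.2]
h = 1e-6
z_base = reparameterize(mu, var, eps)
z_plus = reparameterize([mu[0] + h, mu[1]], var, eps)
dz_dmu = (z_plus[0] - z_base[0]) / h
print("finite-difference dz[0]/dmu[0]:", dz_dmu)
```

In a real implementation an autodiff framework computes these derivatives for you. The point is only that gradients can flow through \mu\left(X\right) and \sigma\left(X\right) once the sampling is moved into z'.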

The cool thing is that during inference (i.e. generation), you won’t need the encoder to compute \mu\left(X\right) and \sigma\left(X\right) at all! Remember that during training, we try to pull Q to be close to P\left(z\right) (which is standard normal), so during inference, we can just inject \epsilon \sim \mathcal{N}\left(0, I\right) directly into the decoder and get a sample of X. This is how we can leverage the power of “generation” from VAEs.
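As a sketch, generation is just "sample noise, run the decoder". In the plain-Python snippet below, toy_decoder is a hypothetical stand-in for a trained decoder f (a fixed linear map with made-up weights), and generate is an illustrative helper, not an API from any library.

```python
import random


def toy_decoder(z):
    # Hypothetical stand-in for a trained decoder f: a fixed linear map
    # from a 2-d latent space to a 3-d data space. Weights are made up.
    weights = [[1.0, 0.0], [0.5, -1.0], [0.0, 2.0]]
    bias = [0.1, 0.2, 0.3]
    return [
        sum(w * zi for w, zi in zip(row, z)) + b
        for row, b in zip(weights, bias)
    ]


def generate(decoder, latent_dim, n_samples, rng):
    # No encoder involved: draw z ~ N(0, I) and decode it.
    samples = []
    for _ in range(n_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
        samples.append(decoder(z))
    return samples


rng = random.Random(0)
samples = generate(toy_decoder, latent_dim=2, n_samples=5, rng=rng)
for x in samples:
    print(x)
```

With a real model you would swap toy_decoder for the trained network. Nothing else changes.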

There are various extensions to VAEs like Conditional VAEs and so on, but once you understand the basics, everything else is just nuts and bolts.

To sum up the series, this is the conceptual graph of VAEs during training, compared to some other models. Of course there are many details in those graphs that are left out, but you should get a rough idea about how they work.

[Figure: conceptual graphs of VAEs during training, compared with other models]

In the case of VAEs, I added the additional cost term in blue to highlight it. The cost term for the other models, except GANs, is the usual L2 norm \Vert X - \hat{X}\Vert^2.

GSN is an extension of the Denoising Autoencoder with explicit hidden variables; however, it requires forming a fairly complicated Markov chain. We may have another post for it.

With this diagram, hopefully you will see how lame GANs are. They are even simpler than the humble RBM. However, the simplicity of GANs is what makes them so powerful, while the complexity of VAEs makes them quite an effort just to understand. Moreover, VAEs make quite a few severe approximations, which might explain why samples generated from VAEs are far less realistic than those from GANs.

That’s quite enough for now. Next time we will switch to another topic I’ve been looking into recently.


