Generative Adversarial Networks, GANs


Before text-to-image generation and diffusion models took the world by storm, Generative Adversarial Networks (GANs) were among the most hyped generative models out there. The hype was not without reason: the quality of GAN-generated data was very impressive for its time. Generated images were crisp, rich in high-frequency detail, and reasonably diverse (see StyleGAN). Later, methods were discovered that turn GANs into powerful image-editing tools (see StyleGAN inversion), and generation could be conditioned well enough to produce specific desired results, among much else (see pix2pix). To this day, GANs hold their own in restricted-domain scenarios and retain some advantages over the latest powerful diffusion models.

I became aware of the astonishing results that GANs achieved during a conference and immediately fell victim to the hype myself. Of course, it later turned out that GANs were not a utopia. In fact, they proved extremely challenging to train and debug, and it took months of work to achieve (remotely) desirable results. So it seems like a good idea to go over GANs in this blog post, discussing their principles, their challenges, and some observations from my own experience that may ease the difficulty of dealing with them. We will start by introducing the principles of GANs and why they are hard to train. Then, we will go over some practical and empirical considerations. I may even follow this up with a post dedicated solely to GAN architectures.

Introduction: Principles of GANs

Let’s start with the goal of generative AI: to learn the distribution underlying a given dataset $\mathcal{D} = \left( x_1, x_2, \dots, x_n \right)$. This data can be images, tabular data, audio, graphs, you name it. One way of doing this is to adopt some assumptions and use parametric models for the density $p_x$. For instance, autoregressive models decompose the joint distribution of a data point into a product of conditional distributions, $p \left( x \right) = \prod_{i} p(x_i \mid x_{< i})$. One example of such a model is a Transformer that generates the next word in a sequence conditioned on the previous words. Alternatively, one might use a parametric neural model to maximize a variational lower bound on the log-likelihood of the data, like the VAE. Or one might use a series of invertible transformations, parametrized by neural networks, to map simple random variables to points from the data distribution (or to compute the density of those points). These are normalizing flows, which build on the change-of-variables formula for random variables.

GANs are different in the sense that they do not build any form of explicit or approximate density for the dataset. Instead, they learn to transform random noise into a point from the data distribution by applying a mapping parametrized by a neural network. The CDF of the induced distribution is given by integrating the noise density over the set of events $\lbrace z : G_{\theta}(z) \le x \rbrace$. The density is obtained by differentiating this integral, which is generally intractable.

\[\begin{aligned} z &\sim p_z \\ \hat{x} &= G_{\theta}(z) \\ P_{\theta}(X \le x) &= \int_{\lbrace z \,:\, G_{\theta}(z) \le x \rbrace} p_z(z)\,dz \end{aligned}\]

Hence, we say that GANs model the density only implicitly, rather than explicitly: we cannot compute the density at a given point. However, GANs do support fast sampling from the learned distribution. The question then becomes how to learn the transformation $G_\theta$.
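To make the implicit CDF concrete, here is a toy case where the integral can be checked by hand. It is a minimal sketch assuming a hypothetical one-dimensional affine "generator" $G_\theta(z) = a z + b$ with $a > 0$ and $z \sim \mathcal{N}(0, 1)$; the names `generator`, `empirical_cdf`, and `exact_cdf` are illustrative. For a real neural generator, the set $\lbrace z : G_\theta(z) \le x \rbrace$ has no closed form, and sampling, as in `empirical_cdf`, is all we can do.

```python
import math
import random


def generator(z, a=2.0, b=1.0):
    # Hypothetical toy "generator": an affine map G(z) = a * z + b with a > 0.
    return a * z + b


def empirical_cdf(x, n=100_000, seed=0):
    # Monte Carlo estimate of P(G(z) <= x) with z ~ N(0, 1).
    # Only samples of G(z) are used, never the density of G(z).
    rng = random.Random(seed)
    hits = sum(generator(rng.gauss(0.0, 1.0)) <= x for _ in range(n))
    return hits / n


def exact_cdf(x, a=2.0, b=1.0):
    # For this toy G, {z : G(z) <= x} = {z : z <= (x - b) / a}, so integrating
    # p_z over that set gives the standard normal CDF at (x - b) / a.
    u = (x - b) / a
    return 0.5 * (1.0 + math.erf(u / math.sqrt(2.0)))


for x in (-1.0, 1.0, 3.0):
    print(f"x={x:+.1f}  sampled={empirical_cdf(x):.3f}  exact={exact_cdf(x):.3f}")
```

The sampled and exact columns agree up to Monte Carlo error, but only the sampled one would be available for a neural-network generator.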

The principle is to learn from only two sets of data, real and generated, by comparing them. Let $q$ denote the real data distribution and $p_\theta$ the distribution induced by the generator. If we can measure how far apart these two distributions are, we can train a generator that closes the gap between them. Thus, we seek a measure $d(q, p_\theta)$ of this difference. We want it to be zero if and only if the real and generated distributions are equal. Moreover, it must be differentiable so that we can train the generator $G_\theta$. Finally, it should perform the comparison using only samples from the real and generated datasets; in other words, we want to learn from the samples alone, with no other information. It turns out that these criteria can be met by a measure built on the density ratio $r = \frac{q(x)}{p_\theta(x)}$. The two distributions are the same if and only if this ratio equals one everywhere $p_\theta$ is nonzero. Next, we modify this ratio a bit to obtain a trainable loss function that guides the generator $G_\theta$ toward generating realistic data.

Learning with Ratio Estimation

Suppose we pool real and generated samples and label each by its origin, writing $D(x) = 1$ for a sample drawn from $q$ and $D(x) = 0$ for a sample drawn from $p_\theta$; a perfect binary classifier $D$ would recover exactly these labels. With this labeling, we can rewrite the ratio of the two distributions as $r = \frac{P(x \mid D(x) = 1)}{P(x \mid D(x) = 0)}$. Applying Bayes’ rule to the numerator and the denominator yields

\[r = \left( \frac{P(D(x) = 1 \mid x) \cdot P(x)}{P(D(x) = 1)} \right) \left( \frac{P(D(x) = 0)}{P(D(x) = 0 \mid x) \cdot P(x)} \right)\]

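Assuming equal class priors $P(D(x) = 0) = P(D(x) = 1) = \pi$ (i.e., $\pi = \tfrac{1}{2}$, as many real as generated samples), the priors and $P(x)$ cancel. Writing $D(x)$ for the classifier's predicted probability $P(D(x) = 1 \mid x)$ in the last step, we are left with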

\[r = \frac{P(D(x) = 1 \mid x)}{P(D(x) = 0 \mid x)} = \frac{P(D(x) = 1 \mid x)}{1 - P(D(x) = 1 \mid x)} = \frac{D(x)}{1 - D(x)}\]
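The identity $r = D(x) / (1 - D(x))$ can be checked numerically when both densities are known. The sketch below assumes, purely for illustration, that the real distribution is $q = \mathcal{N}(0, 1)$ and the generated one is $p_\theta = \mathcal{N}(1, 1)$, and uses the Bayes-optimal classifier under equal priors, $D(x) = q(x) / (q(x) + p_\theta(x))$. In practice neither density is available, which is exactly why the discriminator has to be learned from samples.

```python
import math


def normal_pdf(x, mu, sigma=1.0):
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))


def q(x):
    # "Real" density, assumed known only for this illustration.
    return normal_pdf(x, 0.0)


def p_theta(x):
    # "Generated" density, likewise assumed known.
    return normal_pdf(x, 1.0)


def bayes_optimal_discriminator(x):
    # With equal priors, P(real | x) = q(x) / (q(x) + p_theta(x)).
    return q(x) / (q(x) + p_theta(x))


def ratio_from_discriminator(d):
    # r = D / (1 - D), the last step of the derivation above.
    return d / (1.0 - d)


for x in (-1.0, 0.5, 2.0):
    d = bayes_optimal_discriminator(x)
    print(f"x={x:+.1f}  D={d:.3f}  D/(1-D)={ratio_from_discriminator(d):.4f}  "
          f"q/p={q(x) / p_theta(x):.4f}")
```

At $x = 0.5$ the two densities are equal, so the discriminator outputs 0.5 and the recovered ratio is 1.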

All that remains is to obtain this discriminator by training it to distinguish real from generated data with binary cross-entropy. We parametrize it as a neural network $D_\phi$ and train it alongside the generator. The discriminator's objective is

\[\max_{\phi}\; \mathcal{V}_{\text{D}}(G_\theta, D_\phi) = \mathbb{E}_{x \sim q(x)}\left[\log D_\phi(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D_\phi(G_\theta(z))\right)\right]\]

The generator's objective is to produce samples good enough to fool this discriminator. The originally proposed generator objective minimizes the log-probability that the discriminator correctly labels generated samples as fake:

\[\min_{\theta}\; \mathcal{L}_{\text{G}}(G_\theta, D_\phi) = \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D_\phi(G_\theta(z))\right)\right]\]
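In code, both objectives are usually computed from the discriminator's raw logits $t$, with $D_\phi = \sigma(t)$, for numerical stability. The sketch below is plain Python over lists of logits for one minibatch; the function names are hypothetical, and the expectations are replaced by minibatch averages.

```python
import math


def log_sigmoid(t):
    # Numerically stable log(sigmoid(t)); note log(1 - sigmoid(t)) = log_sigmoid(-t).
    if t >= 0:
        return -math.log1p(math.exp(-t))
    return t - math.log1p(math.exp(t))


def discriminator_objective(real_logits, fake_logits):
    # Minibatch estimate of V_D = E_q[log D(x)] + E_z[log(1 - D(G(z)))].
    # The discriminator maximizes this quantity.
    real_term = sum(log_sigmoid(t) for t in real_logits) / len(real_logits)
    fake_term = sum(log_sigmoid(-t) for t in fake_logits) / len(fake_logits)
    return real_term + fake_term


def minimax_generator_loss(fake_logits):
    # Minibatch estimate of L_G = E_z[log(1 - D(G(z)))]; the generator minimizes it.
    return sum(log_sigmoid(-t) for t in fake_logits) / len(fake_logits)


# A discriminator that outputs D = 0.5 everywhere (logit 0) scores V_D = -log 4.
print(discriminator_objective([0.0, 0.0], [0.0, 0.0]), -math.log(4))
```

Binary cross-entropy with label 1 for real and 0 for generated samples is, up to how the two terms are averaged, the negative of $\mathcal{V}_{\text{D}}$, so minimizing it maximizes the discriminator objective.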

It can be shown that if the discriminator is optimal, minimizing $\mathcal{L}_{\text{G}}$ amounts to minimizing the Jensen-Shannon divergence (JSD) between the generated distribution and the real distribution. This is really nice, because it links learning from samples to a well-defined divergence between two distributions, and people have used this link to analyze the theoretical properties of GAN optimization. It means that, given an optimal discriminator and networks with enough capacity, the generator will close the gap between real and generated samples (inducing a distribution close to the real one). However, this is an if of mountainous size. We train the generator and discriminator simultaneously by alternating updates, and we never have access to an optimal discriminator. Instead, $D_\phi$ acts more like a learned, differentiable loss function that learns to tell real and fake samples apart and thereby supplies gradients to the generator. Without an optimal discriminator, we cannot claim to be minimizing the JSD throughout training. We hope that the discriminator and the generator both become good at their tasks, improving each other along the way. In practice, GANs have repeatedly been shown to generate high-quality data, even though their theoretical guarantees are weak.
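The link to the JSD can be verified numerically on a small discrete space. Plugging the optimal discriminator $D^*(x) = q(x) / (q(x) + p_\theta(x))$ into the discriminator objective gives $\mathcal{V}_{\text{D}}(G_\theta, D^*) = 2\,\mathrm{JSD}(q \,\|\, p_\theta) - \log 4$, the standard result from the original GAN analysis; since $\mathcal{L}_{\text{G}}$ differs from $\mathcal{V}_{\text{D}}$ only by a term that does not depend on $\theta$, minimizing it minimizes the JSD. The sketch below checks the identity for two made-up distributions over three outcomes; the distributions and function names are illustrative.

```python
import math


def kl(p, m):
    # KL(p || m) for discrete distributions; terms with p_i = 0 contribute 0.
    return sum(pi * math.log(pi / mi) for pi, mi in zip(p, m) if pi > 0)


def jsd(q, p):
    # Jensen-Shannon divergence: average KL to the mixture m = (q + p) / 2.
    m = [(qi + pi) / 2 for qi, pi in zip(q, p)]
    return 0.5 * kl(q, m) + 0.5 * kl(p, m)


def value_at_optimal_discriminator(q, p):
    # V_D(G, D*) = sum_x q(x) log D*(x) + p(x) log(1 - D*(x)),
    # with the optimal discriminator D*(x) = q(x) / (q(x) + p(x)).
    total = 0.0
    for qi, pi in zip(q, p):
        if qi + pi == 0:
            continue  # outcome never occurs under either distribution
        if qi > 0:
            total += qi * math.log(qi / (qi + pi))
        if pi > 0:
            total += pi * math.log(pi / (qi + pi))
    return total


q_real = [0.5, 0.3, 0.2]  # "real" distribution (made up)
p_gen = [0.2, 0.2, 0.6]   # "generated" distribution (made up)
print(value_at_optimal_discriminator(q_real, p_gen), 2 * jsd(q_real, p_gen) - math.log(4))
```

For distributions with disjoint supports, such as `jsd([1.0, 0.0], [0.0, 1.0])`, the JSD equals $\log 2$ regardless of the distributions involved, so it carries no information about how far apart they are, a hint of the vanishing-gradient problem discussed next.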

Challenges in Training GANs

As one might rightfully guess from the GAN loss function definitions and the training procedure, they are really hard to train. We are going to go over some of the more prominent challenges before jumping into practical considerations that can help people surmount these issues.

Vanishing Gradients and Oscillatory Dynamics

We have already established that, under an optimal discriminator $D$, training the generator minimizes the JS divergence between the implicitly learned distribution and the real one. However, here lies the first problem. As the discriminator approaches optimality, it confidently rejects generated samples ($D_\phi(G_\theta(z)) \to 0$), and the gradients of the minimax generator loss vanish. So the generator cannot learn from an optimal discriminator at all, which defeats the purpose of the setup. In the early stages of training, when the generator still produces garbage, the discriminator can easily become too strong and cause vanishing gradients, and it has been shown that this can happen in the middle of training as well. As a result, the standard minimax loss leads to slow training and is extremely sensitive to hyperparameters and random initialization.

To circumvent this issue, people use the non-saturating GAN loss (NSGAN) instead. Rather than minimizing the probability that a generated sample is classified as fake, NSGAN maximizes the probability that it is classified as real. This simple change of perspective yields the alternative generator loss $-\log D_\phi(G_\theta(z))$, which does not saturate even when the discriminator becomes strong. However, this loss function leads to a different problem: NSGAN's gradients have higher variance, resulting in significant oscillations during training. This is why it is crucial to monitor the generated samples during training, as their quality can swing between good and bad the entire time. Principled monitoring and validation are necessary to catch the quality problems that can result from this oscillatory training.
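The difference between the two generator losses shows up in their derivatives with respect to the discriminator's logit $t$ on a generated sample, where $D_\phi(G_\theta(z)) = \sigma(t)$. By the chain rule, the generator's gradient is this factor times $\partial t / \partial \theta$, so when the factor vanishes, so does the generator's learning signal. The sketch assumes a sigmoid output, as in the objectives above; the function names are illustrative.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def minimax_grad(logit):
    # d/dt of the minimax generator loss log(1 - sigmoid(t)); equals -sigmoid(t).
    return -sigmoid(logit)


def nonsaturating_grad(logit):
    # d/dt of the non-saturating loss -log(sigmoid(t)); equals -(1 - sigmoid(t)).
    return -(1.0 - sigmoid(logit))


for t in (-10.0, -5.0, 0.0, 5.0):
    print(f"D(G(z))={sigmoid(t):.5f}  minimax={minimax_grad(t):+.5f}  "
          f"non-saturating={nonsaturating_grad(t):+.5f}")
```

When the discriminator confidently rejects a sample ($D_\phi(G_\theta(z)) \approx 0$), the minimax derivative is almost zero while the non-saturating one stays close to $-1$; the two agree at $D_\phi(G_\theta(z)) = 0.5$.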

Aside from NSGAN, a plethora of other variants emerged, such as least squares GAN (LSGAN), WGAN, and WGAN-GP, each claiming advantages over the others. The truth is that some are easier to optimize and some are harder, and equally good results can be achieved even with the original minimax loss, given enough time and budget. On top of that, many regularization techniques and training schedules have been proposed to mitigate GAN convergence issues. We will come back to the empirical side of this in the Practical Considerations section. For a more theoretical analysis of the NSGAN and minimax losses, please refer to Arjovsky et al., which nicely explains why GANs tend to learn only some parts of the target distribution and why the loss functions are unstable.
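Many of these variants differ mainly in their loss terms. The sketch below shows the LSGAN and WGAN losses computed from raw discriminator (critic) outputs on a minibatch, assuming one common LSGAN target choice (0 for generated, 1 for real, and 1 as the generator's target). It deliberately omits the Lipschitz constraint that WGAN needs (weight clipping in WGAN, a gradient penalty in WGAN-GP), since that requires automatic differentiation; the function names are hypothetical.

```python
def mean(xs):
    return sum(xs) / len(xs)


def lsgan_losses(real_out, fake_out):
    # LSGAN with targets 0 (generated), 1 (real), 1 (generator's target):
    #   D minimizes 0.5 * E[(D(x) - 1)^2] + 0.5 * E[D(G(z))^2]
    #   G minimizes 0.5 * E[(D(G(z)) - 1)^2]
    d_loss = 0.5 * mean([(o - 1.0) ** 2 for o in real_out]) + 0.5 * mean([o ** 2 for o in fake_out])
    g_loss = 0.5 * mean([(o - 1.0) ** 2 for o in fake_out])
    return d_loss, g_loss


def wgan_losses(real_out, fake_out):
    # WGAN: the critic outputs unbounded scores (no sigmoid).
    #   critic minimizes E[D(G(z))] - E[D(x)]
    #   G minimizes -E[D(G(z))]
    # The Lipschitz constraint on the critic is not included here.
    d_loss = mean(fake_out) - mean(real_out)
    g_loss = -mean(fake_out)
    return d_loss, g_loss


print(lsgan_losses([0.9, 1.1], [0.1, -0.1]))
print(wgan_losses([2.0, 4.0], [1.0]))
```

Because only the loss functions change, swapping between variants is cheap to try, which fits the observation that no variant is a clear winner.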

Mode Collapse

Mode collapse in GANs collectively refers to a lack of variety in the generated samples. The samples themselves might look very good, but they may all be too similar to one another. For instance, when trained on handwritten digits with ten modes, the generator might fail to produce some of the digits. This can happen when the GAN fails to learn all modes of the distribution or when the generator maps many different noise vectors to (nearly) the same output. In principle, the generator's task is to produce samples that look maximally real to the discriminator. Hence, a natural solution for the generator is to produce the single sample that the current discriminator finds most convincing, ignoring the input noise entirely. If the discriminator does not learn to penalize this behavior, the generator will keep producing similar samples from that single mode. A closely related issue is mode hopping, where the generator alternates between modes over time. For example, it might first generate only the digit 1, later switch to generating only the digit 7, and so on. This happens when the discriminator learns to punish samples from the current mode, but the generator simply moves on to another mode instead of covering all of them, and the cycle continues.
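One simple way to quantify mode collapse on a toy problem is to count how many known modes receive generated samples. The sketch below assumes a hypothetical 2-D mixture of eight Gaussians on a circle of radius 2 (a common toy benchmark) and hand-made "healthy" and "collapsed" sample sets standing in for generator outputs; for digits, the analogous check would classify generated samples with a pre-trained digit classifier and count the classes that appear. Running it at regular intervals also makes mode hopping visible as the set of covered modes changes over time.

```python
import math
import random

# Hypothetical toy target: 8 modes evenly spaced on a circle of radius 2.
MODES = [
    (2.0 * math.cos(2 * math.pi * k / 8), 2.0 * math.sin(2 * math.pi * k / 8))
    for k in range(8)
]


def mode_counts(samples, radius=0.5):
    # Assign each sample to its nearest mode; samples farther than `radius`
    # from every mode are treated as low quality and not counted.
    counts = [0] * len(MODES)
    for s in samples:
        k = min(range(len(MODES)), key=lambda i: math.dist(s, MODES[i]))
        if math.dist(s, MODES[k]) <= radius:
            counts[k] += 1
    return counts


def modes_covered(samples, radius=0.5, min_count=1):
    # Number of modes that received at least `min_count` samples.
    return sum(c >= min_count for c in mode_counts(samples, radius))


def noisy_copies(centers, rng, std=0.05):
    # Stand-in for generator output: small Gaussian noise around given centers.
    return [(cx + rng.gauss(0, std), cy + rng.gauss(0, std)) for cx, cy in centers]


rng = random.Random(0)
healthy = noisy_copies([rng.choice(MODES) for _ in range(1000)], rng)
collapsed = noisy_copies([MODES[3]] * 1000, rng)
print(modes_covered(healthy), modes_covered(collapsed))
```

The collapsed samples would all look "good" individually, yet they cover only one of the eight modes.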

To detect mode collapse, it is worth monitoring the outputs during training. If it starts happening, there are some methods you can try. Feeding minibatch statistics, such as the minibatch standard deviation, to the discriminator as an extra input might encourage the generator to produce batches with realistic statistics and more variety, thereby counteracting mode collapse. Increasing the batch size and the discriminator capacity may help as well. One very good option is packing, which requires few changes to the architecture (see PacGAN). Essentially, the discriminator is modified to assign a single label to a pair (or, more generally, a group) of samples instead of to a single sample. The PacGAN paper shows that this heavily penalizes a generator that ignores modes. In other words, if the generator places almost no mass on regions where the real data has mass, that is, on the set $\lbrace x \mid q(x) > 0 \text{ but } p_\theta(x) \approx 0 \rbrace$, it gets penalized. Formally, the discriminator learns to distinguish between the product distributions $q^{2}$ and $p_\theta^{2}$. This technique may also improve the quality of the generated data, so it is worth trying.
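Both ideas are small changes to what the discriminator sees. The sketch below operates on plain lists of feature vectors and is a simplified version: the minibatch standard deviation is averaged into a single scalar and appended to every sample (in image GANs it is usually appended as an extra feature map), and packing concatenates groups of $m$ samples from the same source, with $m = 2$ matching the pairs described above. The function names are hypothetical.

```python
import statistics


def minibatch_stddev_feature(batch):
    # batch: list of equal-length feature vectors, one per sample.
    # Compute each feature's std across the batch, average them into one
    # scalar, and append it to every sample as an extra discriminator input.
    # A collapsed batch (near-identical samples) yields a value near 0.
    per_feature_std = [statistics.pstdev(column) for column in zip(*batch)]
    avg_std = sum(per_feature_std) / len(per_feature_std)
    return [list(sample) + [avg_std] for sample in batch]


def pack(batch, m=2):
    # PacGAN-style packing: concatenate groups of m samples from the same
    # source (all real or all generated); the discriminator then outputs
    # one label per group instead of one per sample.
    if len(batch) % m != 0:
        raise ValueError("batch size must be divisible by the packing degree m")
    return [sum((list(batch[i + j]) for j in range(m)), []) for i in range(0, len(batch), m)]


same = [[1.0, 2.0], [1.0, 2.0], [1.0, 2.0], [1.0, 2.0]]    # collapsed batch
varied = [[0.0, 2.0], [2.0, 0.0], [0.0, 0.0], [2.0, 2.0]]  # diverse batch
print(minibatch_stddev_feature(same)[0][-1], minibatch_stddev_feature(varied)[0][-1])
print(pack(varied, m=2))
```

In both cases the discriminator's input dimension changes, so its first layer has to be sized accordingly.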

No Easy Evaluation

Unlike many other settings, where one has access to interpretable validation metrics, GAN training is an adversarial game. The generator and discriminator work against each other, which makes the loss curves highly oscillatory and almost uninterpretable. It is hard to detect overfitting or underfitting by looking at the loss dynamics alone. Still, the loss curves can contain valuable information. Healthy loss dynamics typically follow patterns you can find by searching for “good GAN loss curves” or something similar, and some intuition is involved as well. If the generator loss drops to 0, the discriminator is probably being fooled by garbage data. If the discriminator accuracy stays at 100% for a long time, the generator is likely not improving. In these cases, it is worth rethinking your architectures and tuning hyperparameters. Nevertheless, good evaluation metrics are undeniably still necessary for comparing GAN architectures and results.
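The two rules of thumb above can be turned into automatic warnings in a training loop. This is a minimal sketch assuming the non-saturating generator loss $-\log D_\phi(G_\theta(z))$ (which approaches 0 when the discriminator accepts generated samples) and per-step discriminator accuracies logged separately for real and generated samples; the function name, window size, and thresholds are arbitrary placeholders, not established values. A warning should prompt a look at the samples, not replace it.

```python
def loss_warnings(g_losses, d_real_acc, d_fake_acc, window=100, eps=1e-3, acc_threshold=0.99):
    """Heuristic warnings computed over the last `window` logged training steps.

    g_losses: non-saturating generator loss, -log D(G(z)), per step.
    d_real_acc, d_fake_acc: fraction of real / generated samples the
    discriminator classified correctly per step.
    The default thresholds are arbitrary placeholders.
    """
    messages = []
    recent_g = g_losses[-window:]
    if len(recent_g) == window and max(recent_g) < eps:
        messages.append("generator loss near 0: discriminator may be fooled by garbage, inspect samples")
    recent_acc = [min(r, f) for r, f in zip(d_real_acc[-window:], d_fake_acc[-window:])]
    if len(recent_acc) == window and min(recent_acc) > acc_threshold:
        messages.append("discriminator near 100% accuracy: generator may not be improving")
    return messages


print(loss_warnings([1e-4] * 100, [0.6] * 100, [0.4] * 100))
```

Requiring a full window of history keeps the checks from firing on short transient spikes, which are normal in oscillatory GAN training.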

People have come up with metrics such as the Inception Score (IS) and the Fréchet Inception Distance (FID) to quantify the quality of generated samples. Both rely on a pre-trained Inception network: IS scores the generated samples alone, while FID compares statistics of network features computed on generated and real samples. They are a good starting point for evaluating sample quality, and it is a good idea to compute and track FID or IS during training to gauge the GAN's progress. These metrics are also useful for comparing several GANs, which is necessary to identify a set of architectures and hyperparameters that actually works. However, they have some drawbacks related to how they are computed. Typically, they use networks trained on ImageNet. What if we want to generate data far from the ImageNet distribution? Will these metrics still serve their intended purpose? And what about domains that cannot make use of such pre-trained networks at all?
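For reference, FID fits a Gaussian to each feature set and computes the Fréchet distance $\lVert \mu_r - \mu_g \rVert^2 + \operatorname{Tr}\left(\Sigma_r + \Sigma_g - 2 (\Sigma_r \Sigma_g)^{1/2}\right)$. The sketch below is a simplification: it assumes diagonal covariances, in which case the trace term reduces to $\sum_i (\sigma_{r,i} - \sigma_{g,i})^2$, and it takes precomputed feature vectors as input. Standard FID uses full covariance matrices of Inception features, so this is not a drop-in replacement; the function name is hypothetical. Because the function does not care where the features come from, it makes the dependence on the feature extractor explicit.

```python
import statistics


def frechet_distance_diag(real_feats, fake_feats):
    """Fréchet distance between Gaussians fitted to two feature sets.

    Assumes *diagonal* covariances (a simplification; FID uses full ones).
    Each argument is a list of equal-length feature vectors, at least two
    per set. With diagonal covariances,
    Tr(S_r + S_g - 2 (S_r S_g)^(1/2)) = sum_i (sigma_r,i - sigma_g,i)^2.
    """
    mean_term = 0.0
    cov_term = 0.0
    for r_col, f_col in zip(zip(*real_feats), zip(*fake_feats)):
        mean_term += (statistics.fmean(r_col) - statistics.fmean(f_col)) ** 2
        cov_term += (statistics.stdev(r_col) - statistics.stdev(f_col)) ** 2
    return mean_term + cov_term


real = [[0.0, 1.0], [2.0, 1.0], [1.0, 3.0]]
fake = [[0.5, 1.0], [1.5, 2.0], [1.0, 2.0]]
print(frechet_distance_diag(real, fake))
```

Lower is better, and identical feature statistics give a distance of exactly 0.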

An alternative to these metrics is qualitative evaluation after, say, each epoch. In this case, one visualizes generated samples and monitors how they change over time. This is costly and time-consuming, but you get to verify the quality yourself. If resources permit, it is worth setting up a qualitative evaluation pipeline and sticking to it. The disadvantage is that comparing several GAN architectures this way becomes really hard.

Practical Considerations

With the plethora of GAN training techniques, losses, architectures, hyperparameters, and schedules available, it is hard to start using GANs without getting confused. The instability of training only makes the situation worse: instability and convergence issues cause people to spend significant time tuning architectures, often without reaching satisfactory results. Luckily for us, a large-scale study of GANs was performed in the paper “Are GANs Created Equal?”. It compares many hyperparameters, losses, and architectural choices, which helps answer the question “Which GAN should I use?”.

Essentially, the paper finds that no single loss function significantly outperforms the others. This means you should just start with the non-saturating GAN loss (NSGAN) and not even think about changing the loss until it is really necessary. The authors also find that there are clusters of hyperparameters that work well together and result in stable training, whereas other settings are a recipe for unstable training. Surprisingly, a set of hyperparameters that works on one dataset is likely to work on other datasets as well. This is really good news, as one can start with an architecture that has been shown to work and apply it to one's own data and task. Of course, this does not mean that one should stop tuning hyperparameters and blindly use whatever is out there. Instead, it serves as a hint on where to start, which can save a lot of time. The authors also compare the time and resources required to tune each GAN model they analyze, and find that NSGAN is a good starting point that is easier to tune than something like WGAN-GP. In short, it is a good idea to familiarize yourself with architectures that are known to work (pix2pix, DCGAN, StyleGAN, etc.); you may find insights and tricks that prove very useful.

Finally, here are some short tips that have been particularly useful in my own experience of creating and training GANs.

  1. Start with an architecture that has already been shown to work for the task at hand. For unconditional generation, start with something like DCGAN, pick the hyperparameters its authors used, and work your way from there. For conditional generation, look into pix2pix, which has a nice, clean codebase that you can extend.
  2. Start by training your GAN on small and limited data to debug issues. If your GAN can’t fit easy distributions first, it is unlikely that it can learn complex ones. Pick a small subset of the original data and start from there. Try to find a set of hyperparameters that result in stable training before moving on.
  3. Start with the regular NSGAN loss and change the loss function only when absolutely necessary.
  4. Set up a good evaluation pipeline and stick to it. At the very least, visualize the generated results once in a while to check for issues. Monitor the loss dynamics and try to detect failure modes (search for “GAN failure modes” for further reading).
  5. If you work on paired translation tasks such as image-to-image translation, keep a separate validation set and check how your GAN translates samples it has not seen during training. This can help detect overfitting and reveal whether the GAN has started to ignore your input conditioning. In paired setups, it is also a good idea to try auxiliary loss functions (such as the L1 reconstruction loss used in pix2pix) that push the generated samples closer to their true counterparts.
  6. Finally, exercise patience. You need time to tune hyperparameters. GANs won’t just magically work even if you choose a previously tested architecture.
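As a concrete version of tips 1 and 2, the sketch below records the training settings reported in the DCGAN paper as a starting configuration and adds a helper for drawing a small, fixed debugging subset. The class and function names are hypothetical; Adam's $\beta_2$ is not specified in the DCGAN paper, so the common library default is assumed; and the values are a place to start tuning from, not a guarantee of stable training on your data.

```python
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class DCGANStartingPoint:
    # Training settings reported in the DCGAN paper (Radford et al., 2015).
    batch_size: int = 128
    learning_rate: float = 2e-4    # Adam, used for both networks
    adam_beta1: float = 0.5        # reduced from 0.9 to stabilize training
    adam_beta2: float = 0.999      # not given in the paper; library default assumed
    init_std: float = 0.02         # weights drawn from N(0, 0.02^2)
    leaky_relu_slope: float = 0.2  # LeakyReLU slope in the discriminator
    latent_dim: int = 100          # dimensionality of the noise vector z


def debug_subset(dataset, size=1000, seed=0):
    # Tip 2: draw a small, fixed (seeded) subset to get training stable first.
    items = list(dataset)
    return random.Random(seed).sample(items, min(size, len(items)))


config = DCGANStartingPoint()
print(config)
print(len(debug_subset(range(60_000), size=500)))
```

Seeding the subset keeps it identical across runs, so changes in training behavior can be attributed to hyperparameters rather than to different data.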