How GANs Work
Generative adversarial networks (GANs) are a class of neural network models capable of generating data. Because of their capabilities, they have become a prominent research topic in deep learning. In a few years, GANs have progressed from producing blurry digits to creating images that look like realistic human faces.
GANs are a type of generative model. That means they can produce entirely new, valid data, where "valid" means outputs that could plausibly pass as instances of the target data.
For example, suppose we want to generate new images to train an image classification network. For such an application, we want the synthesized images to look as realistic as possible and to be similar in style to other training images.
GANs consist of two adversarial networks: a generator and a discriminator. The generator is trained to create realistic images from random noise input. The discriminator is trained to classify whether an image is real or fake.
Their real power comes from adversarial training. The generator's weights are updated based on the discriminator's predictions on its fake images: the generator is trained to produce images that the discriminator cannot reliably label as fake. As the generated images become more realistic, the discriminator is forced to get better at distinguishing real from fake, which in turn pushes the generator to improve further. Both networks improve through this feedback loop.
Technically, the discriminator's loss measures classification error on real versus fake images. The generator's loss depends on how well it can "fool" the discriminator, i.e., the discriminator's errors on fake images. Thus, the generator aims to maximize the discriminator's error on fake images.
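In the standard formulation (Goodfellow et al., 2014), this adversarial game is captured by a single minimax objective, where $D(x)$ is the discriminator's probability that $x$ is real and $G(z)$ is the image the generator produces from noise $z$:

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

The discriminator maximizes this value by assigning high probability to real images and low probability to generated ones; the generator minimizes it by making $D(G(z))$ approach 1.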
GANs therefore establish a feedback loop where the generator trains the discriminator and the discriminator trains the generator. The following diagram illustrates this.
Training a GAN in PyTorch to Generate Digits
We will demonstrate how to build and train a GAN using PyTorch. The MNIST dataset contains 60,000 training examples: 28x28-pixel grayscale images of digits 0–9. This dataset is suitable for our use case and is a common proof-of-concept dataset in machine learning.
We start with the imports; everything we need comes from PyTorch and torchvision.

```python
import torch
from torch import nn, optim
import torchvision
import torchvision.transforms as transforms
```
Next, prepare a DataLoader for the training data. Note that the digit labels (0–9) that ship with MNIST are not needed here: for a vanilla GAN, the only targets we use are the real/fake labels we create ourselves during training.
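One way to set this up is sketched below; normalizing pixels to [-1, 1] is an assumption that matches the tanh output of the generator defined later.

```python
# Scale pixel values to [-1, 1] so they match the generator's tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_set = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
data_loader = torch.utils.data.DataLoader(train_set, batch_size=100, shuffle=True)
```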
Now we can define the networks. Start with the discriminator. The discriminator classifies whether an image is real, so it is an image classification network. The input matches standard MNIST size: 28x28 pixels. We flatten the image into a 784-length vector. The output is a single value indicating whether the image is a real MNIST digit.
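A minimal fully connected discriminator might look like the following sketch; the layer sizes are illustrative choices, not prescribed values.

```python
class Discriminator(nn.Module):
    """Maps a flattened 28x28 image to a single real/fake probability."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # probability that the input image is real
        )

    def forward(self, x):
        # Flatten (batch, 1, 28, 28) images into 784-length vectors.
        return self.model(x.view(x.size(0), 784))
```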
Next is the generator. The generator creates images from pure noise. In this example, the generator takes a 100-length vector of random noise and outputs a 784-length vector, which can be reshaped to 28x28 pixels.
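A matching generator sketch, again with illustrative layer sizes; the tanh output keeps pixel values in [-1, 1], consistent with the normalized training data.

```python
class Generator(nn.Module):
    """Maps a 100-length noise vector to a flattened 28x28 image."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh(),  # outputs in [-1, 1], matching the normalized data
        )

    def forward(self, z):
        return self.model(z)
```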
To set up training we need:
- a loss function
- optimizers for each network
- the number of training epochs
- batch size
If running on GPU, PyTorch requires explicitly moving models to the GPU.
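Putting the setup together might look like this; the learning rate, epoch count, and optimizer choice are assumed starting points, not tuned values.

```python
# Assumed hyperparameters; reasonable starting points, not tuned values.
num_epochs = 50
# (a batch size of 100 was set when the DataLoader was created)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

discriminator = Discriminator().to(device)
generator = Generator().to(device)

# Binary cross-entropy matches the real-vs-fake classification task.
criterion = nn.BCELoss()
d_optimizer = optim.Adam(discriminator.parameters(), lr=2e-4)
g_optimizer = optim.Adam(generator.parameters(), lr=2e-4)
```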
Now the training loop. A typical PyTorch training loop has an outer loop over epochs and an inner loop over batches. For GAN training, we update both the generator and the discriminator within the same loop. The code below trains a GAN in PyTorch; the steps are described afterward.
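Below is a minimal sketch of such a loop, reusing the names from the earlier snippets; the numbered comments match the steps described after the code.

```python
for epoch in range(num_epochs):
    for images, _ in data_loader:  # the MNIST digit labels are unused
        images = images.to(device)
        batch_size = images.size(0)

        # Step 1: real images, with target label 1 (real).
        real_targets = torch.ones(batch_size, 1, device=device)

        # Step 2: 100-length random noise vectors, one per image in the batch.
        noise = torch.randn(batch_size, 100, device=device)

        # Step 3: produce fake images, with target label 0 (fake).
        fake_images = generator(noise)
        fake_targets = torch.zeros(batch_size, 1, device=device)

        # Step 4: train the discriminator on both batches. detach() keeps
        # the discriminator's loss from updating the generator's weights.
        d_optimizer.zero_grad()
        d_loss = criterion(discriminator(images), real_targets) + \
                 criterion(discriminator(fake_images.detach()), fake_targets)
        d_loss.backward()
        d_optimizer.step()

        # Step 5: fresh fake images, and the updated discriminator's
        # predictions on them.
        noise = torch.randn(batch_size, 100, device=device)
        predictions = discriminator(generator(noise))

        # Step 6: train the generator; the all-ones target rewards it for
        # making the discriminator output 1 on fake images.
        g_optimizer.zero_grad()
        g_loss = criterion(predictions, real_targets)
        g_loss.backward()
        g_optimizer.step()

    print(f"epoch {epoch + 1}: d_loss={d_loss.item():.4f}, g_loss={g_loss.item():.4f}")
```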
Steps in the training loop:
1. Prepare a batch of real images for the discriminator. These inputs are real MNIST images, and the corresponding target labels are all ones, where 1 indicates real.
2. Create input vectors for the generator to produce fake images. The generator expects 100-length random noise vectors; images.size(0) is the batch size.
3. Pass the random noise through the generator to produce fake images. These fake images, together with the real images from step 1, are used to train the discriminator. The target labels for fake images are zeros, where 0 indicates fake.
4. Train the discriminator using both fake and real images. The total discriminator loss is the loss on fake images plus the loss on real images.
5. With the discriminator updated, generate predictions on fake images and backpropagate the loss through the generator so its weights update according to how well it fooled the discriminator:
   a. Generate a fresh batch of fake images for prediction.
   b. Use the discriminator to predict on the fake images and collect the outputs.
6. Train the generator using the discriminator's predictions. Use a target of all ones for the generator, because the generator's objective is to make the discriminator output 1 for its images. A generator loss of zero would mean the discriminator predicted 1 for every fake image.
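Once training finishes, generating new digits is just a forward pass through the generator. A quick way to inspect the results, reusing the names from the sketches above (the file name is an arbitrary choice):

```python
# Generate a grid of new digits from random noise.
generator.eval()
with torch.no_grad():
    samples = generator(torch.randn(16, 100, device=device))

samples = samples.view(-1, 1, 28, 28)  # reshape 784-length vectors to images
samples = (samples + 1) / 2            # map [-1, 1] back to [0, 1] for saving
torchvision.utils.save_image(samples, "generated_digits.png", nrow=4)
```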