GANs using PyTorch Lightning
Training a Generative Adversarial Network using PyTorch Lightning

Goal

This example covers how to train a Generative Adversarial Network (GAN) using PyTorch Lightning on the CIFAR-10 image dataset. We will cover:
    1. What are GANs?
    2. The model
    3. Training
    4. Visualizing results

What are GANs?

Generative Adversarial Networks, or GANs, are an approach to generative modeling using deep learning methods such as convolutional neural networks (https://arxiv.org/abs/1406.2661).
In this approach, the generative model is pitted against an adversary: a discriminative model that learns to determine whether a sample image comes from the model distribution or from the data distribution.
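Formally, the two networks play a minimax game: the discriminator D is trained to assign high probability to real samples and low probability to generated ones, while the generator G is trained to fool it. The objective from the paper linked above is

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]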

The model

PyTorch Lightning Bolts contains many state-of-the-art pre-trained model recipes. For this example we will use the basic GAN script to train a simple generative adversarial network; check out the code here. The dataset argument accepts any standard image dataset; in this example we choose CIFAR-10.
from argparse import ArgumentParser

import torch
from pytorch_lightning import LightningModule, seed_everything, Trainer
from torch.nn import functional as F

from pl_bolts.models.gans.basic.components import Discriminator, Generator


class GAN(LightningModule):
    """
    Vanilla GAN implementation.

    Example::

        from pl_bolts.models.gans import GAN

        m = GAN()
        Trainer(gpus=2).fit(m)

    Example CLI::

        # mnist
        python basic_gan_module.py --gpus 1

        # imagenet
        python basic_gan_module.py --gpus 1 --dataset 'imagenet2012'
        --data_dir /path/to/imagenet/folder/ --meta_dir ~/path/to/meta/bin/folder
        --batch_size 256 --learning_rate 0.0001
    """
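If you want to try the recipe outside of Grid first, the sketch below fits the Bolts GAN on CIFAR-10 with a local Lightning Trainer. It assumes pl_bolts is installed and that GAN and CIFAR10DataModule accept the arguments shown; the exact signatures may vary between Bolts versions.

# Minimal local-training sketch (assumed API; constructor arguments
# may differ between pl_bolts versions).
from pytorch_lightning import Trainer, seed_everything
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.gans import GAN

seed_everything(1234)

dm = CIFAR10DataModule(data_dir=".", batch_size=64)   # downloads CIFAR-10 if needed
model = GAN(*dm.size())                               # (channels, height, width)
trainer = Trainer(gpus=1, max_epochs=5)
trainer.fit(model, dm)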

Training

To train this model on Grid, we are going to use the web application. Log in to Grid and start a new Run; in the GitHub repo box, paste the script link. Choose any CPU or GPU and the Lightning framework. The learning rate and dataset can be specified as script arguments.
For this example, we used these values:
--learning_rate "uniform(1e-5, 1e-1, 3)" and --dataset "['cifar10']"
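Under the hood these flags are ordinary script arguments. A hypothetical sketch of how the script might parse them with ArgumentParser follows; the actual parser in basic_gan_module.py may define different names and defaults.

# Hypothetical argument parsing (flag names mirror the values above;
# the real parser in basic_gan_module.py may differ).
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--dataset", default="cifar10", type=str)
parser.add_argument("--learning_rate", default=0.0002, type=float)
parser.add_argument("--batch_size", default=64, type=int)
parser.add_argument("--gpus", default=1, type=int)
args = parser.parse_args()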
Next, watch the model train and generate metrics and artifacts; download them as necessary.

Visualizing results

The basic GAN recipe logs metrics such as loss, which can be visualized in the web interface. Below you can see how the generator loss and discriminator loss are computed:
def generator_loss(self, x):
    # sample noise
    z = torch.randn(x.shape[0], self.hparams.latent_dim, device=self.device)
    y = torch.ones(x.size(0), 1, device=self.device)

    # generate images
    generated_imgs = self(z)

    D_output = self.discriminator(generated_imgs)

    # ground truth result (ie: all real)
    g_loss = F.binary_cross_entropy(D_output, y)

    return g_loss

def discriminator_loss(self, x):
    # train discriminator on real
    b = x.size(0)
    x_real = x.view(b, -1)
    y_real = torch.ones(b, 1, device=self.device)

    # calculate real score
    D_output = self.discriminator(x_real)
    D_real_loss = F.binary_cross_entropy(D_output, y_real)
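The snippet above is cut off after the real-sample term. In the standard GAN formulation the discriminator loss also scores generated (fake) samples; a sketch of how that continuation typically looks is below. The variable names and exact lines are assumptions, not copied verbatim from the Bolts module.

    # -- continuation sketch (assumed, not verbatim from pl_bolts) --
    # train discriminator on fake: sample noise and generate images
    z = torch.randn(b, self.hparams.latent_dim, device=self.device)
    x_fake = self(z)
    y_fake = torch.zeros(b, 1, device=self.device)

    # calculate fake score (ground truth: all fake)
    D_output = self.discriminator(x_fake)
    D_fake_loss = F.binary_cross_entropy(D_output, y_fake)

    # total discriminator loss is the sum of the real and fake terms
    return D_real_loss + D_fake_loss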
This model will train for a long time and is only meant as an example; you can stop the experiment at any time.
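Beyond the loss curves, it can also be useful to inspect samples from the trained generator directly. A rough sketch follows, assuming a trained model object, a latent_dim hyperparameter, and a 3x32x32 CIFAR-10 output shape (all assumptions about the recipe's API):

# Sketch (assumed API): sample a grid of images from a trained generator.
import torch
import torchvision

model.eval()
with torch.no_grad():
    z = torch.randn(16, model.hparams.latent_dim)
    samples = model(z)                                  # generated images

# reshape to (N, C, H, W) for CIFAR-10 in case the generator returns flat vectors
grid = torchvision.utils.make_grid(samples.view(-1, 3, 32, 32), nrow=4)
torchvision.utils.save_image(grid, "samples.png")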