Training a Generative Adversarial Network using PyTorch Lightning
This example covers how to train a Generative Adversarial Network (GAN) using PyTorch Lightning on the CIFAR-10 image dataset.
What are GANs?
Generative Adversarial Networks, or GANs, are an approach to generative modeling that uses deep learning methods such as convolutional neural networks. See the original paper: https://arxiv.org/abs/1406.2661
In this approach, a generative model is pitted against an adversary: a discriminative model that learns to determine whether a sample image comes from the model distribution or from the data distribution.
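For reference, the paper linked above frames this contest as a two-player minimax game over a value function V(D, G), where D(x) is the discriminator's estimated probability that x is a real image and G(z) maps a noise vector z to a generated sample:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
$$

The discriminator pushes V up by scoring real images near 1 and generated ones near 0; the generator pushes V down by making D(G(z)) large. The same paper notes that, in practice, the generator is often trained to maximize log D(G(z)) instead, because that gives stronger gradients early in training.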
PyTorch Lightning Bolts contains many state-of-the-art pre-trained model recipes. For this example, we use the Bolts basic GAN script to train a basic generative adversarial network; its components live in the pl_bolts.models.gans.basic package. The dataset argument lets the script train on any standard image dataset; in this example we choose cifar10.
from argparse import ArgumentParser

import torch
from pytorch_lightning import LightningModule, seed_everything, Trainer
from torch.nn import functional as F
from pl_bolts.models.gans.basic.components import Discriminator, Generator
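The script reads its settings from the command line, which is why ArgumentParser is imported. The sketch below shows how flags such as --dataset and --learning_rate could be parsed. The flag names mirror the Grid arguments used in this example, but build_parser, the --latent_dim flag, the defaults and the choices list are hypothetical, chosen for illustration; this is not the Bolts script's actual parser.

```python
from argparse import ArgumentParser


def build_parser() -> ArgumentParser:
    # Hypothetical parser: flag names follow the Grid arguments in this example;
    # defaults and choices are illustrative assumptions, not the Bolts values.
    parser = ArgumentParser(description="Train a basic GAN (sketch)")
    parser.add_argument("--dataset", default="cifar10",
                        choices=["mnist", "cifar10"],
                        help="standard image dataset to train on")
    parser.add_argument("--learning_rate", type=float, default=1e-4,
                        help="optimizer learning rate")
    parser.add_argument("--latent_dim", type=int, default=100,
                        help="size of the noise vector z fed to the generator")
    return parser


if __name__ == "__main__":
    # Parse an explicit argument list instead of sys.argv so the sketch runs as-is.
    args = build_parser().parse_args(["--dataset", "cifar10", "--learning_rate", "0.001"])
    print(args.dataset, args.learning_rate, args.latent_dim)
```

Passing an explicit list to parse_args keeps the example self-contained; in a real script you would call parse_args() with no arguments so it reads the flags Grid supplies.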
To train this model on Grid, we use the web application. Log in to Grid and start a new Run, then paste the link to the script into the GitHub repo box. Choose any CPU or GPU instance and the Lightning framework. The learning rate and dataset can be passed as script arguments.
For this example, we used these values:
--learning_rate "uniform(1e-5, 1e-1, 3)" and --dataset "['cifar10']"
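Read as a search space, uniform(1e-5, 1e-1, 3) asks for three learning rates drawn from the interval [1e-5, 1e-1], while the single-element list fixes the dataset to cifar10, so the sweep should launch one run per sampled learning rate. The sketch below imitates that expansion locally. expand_sweep is a hypothetical helper, and it assumes the values are sampled uniformly at random; that is our reading of the syntax, not a description of Grid's internals.

```python
import random
from itertools import product


def expand_sweep(lr_low, lr_high, n_samples, datasets, seed=0):
    """Hypothetical local imitation of a sweep: sample n_samples learning rates
    uniformly from [lr_low, lr_high] and pair each one with every dataset."""
    rng = random.Random(seed)  # seeded so the sketch is reproducible
    lrs = [rng.uniform(lr_low, lr_high) for _ in range(n_samples)]
    # One argument list per (learning_rate, dataset) combination, i.e. one run each.
    return [
        ["--learning_rate", f"{lr:.6g}", "--dataset", ds]
        for lr, ds in product(lrs, datasets)
    ]


for argv in expand_sweep(1e-5, 1e-1, 3, ["cifar10"]):
    print(" ".join(argv))
```

One design note: a plain uniform draw over [1e-5, 1e-1] lands above 1e-2 about 90% of the time, so small learning rates are rarely tried; log-uniform sampling is a common alternative when the range spans several orders of magnitude.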
Next, watch the model train as it generates metrics and artifacts, which you can download as needed.
The basic GAN recipe logs metrics such as loss, which can be visualized in the web interface. Here you can see the generator loss and the discriminator loss:
def generator_loss(self, x):
    # sample noise: one latent vector of size latent_dim per image in the batch
    z = torch.randn(x.shape[0], self.hparams.latent_dim, device=self.device)
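Both loss curves come from binary cross-entropy on the discriminator's output probabilities. The plain-Python sketch below shows the usual arrangement, using made-up probabilities in place of network outputs: the generator is scored against target 1 on its fakes (the non-saturating form suggested in the original GAN paper), and the discriminator is scored against target 1 on real images and 0 on fakes. The helper names are hypothetical, and summing the two discriminator terms is one common choice assumed here, not a statement about the Bolts implementation.

```python
import math


def bce(probs, target, eps=1e-12):
    """Mean binary cross-entropy of probabilities against a constant 0/1 target."""
    total = 0.0
    for p in probs:
        p = min(max(p, eps), 1 - eps)  # clamp to avoid log(0)
        total += -(target * math.log(p) + (1 - target) * math.log(1 - p))
    return total / len(probs)


def generator_loss(d_on_fake):
    # The generator wants D to label its samples as real (target 1).
    return bce(d_on_fake, 1.0)


def discriminator_loss(d_on_real, d_on_fake):
    # The discriminator wants real -> 1 and fake -> 0; the two terms are summed.
    return bce(d_on_real, 1.0) + bce(d_on_fake, 0.0)


# Illustrative discriminator outputs (probabilities), not real training data.
d_real = [0.9, 0.8]
d_fake = [0.2, 0.1]
print(round(generator_loss(d_fake), 4))
print(round(discriminator_loss(d_real, d_fake), 4))
```

With these numbers the discriminator is winning (it scores fakes low), so its loss is small while the generator's loss is large; as training progresses the two curves typically pull against each other rather than both falling to zero.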