#013 Conditional Generative Adversarial Networks (CGANs)

Highlights: In the previous posts, we studied Deep Convolutional Generative Adversarial Networks (DCGANs). Conventional DCGAN architectures are trained in an unsupervised and unconditional manner, so no labels are involved in the training process.

In this post, we will see how we can train a GAN model on a labeled dataset using Conditional Generative Adversarial Networks (CGANs). So, let’s begin.

Tutorial overview:

  1. What are Conditional GANs?
  2. Initializing and Defining a CGAN Model in PyTorch
  3. Training the CGAN Model

1. What are Conditional GANs?

As mentioned earlier, DCGANs are unsupervised and don’t make use of any labels for training. This means that we have no control over the type of images that are being generated by the GAN model, even though it is capable of creating fresh realistic samples for a given dataset.

Now, what do we do if we want our GAN model to generate only a certain type of image and not any random image?

We can try training our GAN and then sampling a random noise vector \(Z \) from a standard normal distribution. This vector can be fed to the generator, which produces an output image that resembles some image in the dataset. For the MNIST Handwritten Digit dataset, for example, it may be an 8 or any other digit. The point is that we have no control over which one: the generation process itself cannot be steered toward a particular class.

This is the problem that a Conditional GAN solves. Have a look at how it’s represented in the following image.

There are two key motivations for making use of class label information in a GAN model:

  1. Improve the GAN: Class labels and other additional information associated with the input images can be used to improve the GAN itself. This improvement shows up as more stable and faster training and higher-quality generated images.
  2. Targeted Image Generation: Class labels can also be used to deliberately generate images of a specific type.

To achieve this, a GAN can be trained so that both the generator and the discriminator models are conditioned on the class label. When the trained generator is then used to create images in the domain, we can request images of a given type or class label.

Have a look at the image below, where we can see targeted image generation with a CGAN compared to a standard GAN when applied to the MNIST Handwritten Digit dataset.

More formally, any GAN can be extended to a conditional model if both the generator and the discriminator are conditioned on some extra information \(y \). Conditioning is done by supplying \(y \) as an additional input layer to both the discriminator and the generator.
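With this conditioning, the familiar two-player minimax objective of the original GAN simply becomes class-conditional. This is the formulation from the original CGAN paper by Mirza and Osindero:

\[ \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}\left[\log D(x \mid y)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z \mid y))\right)\right] \]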

For example, in the MNIST Handwritten Digit dataset, individual handwritten digits such as the number 9 can be generated. In CIFAR-10, specific object photos such as ‘frogs’ can be generated. And, in the Fashion MNIST dataset, specific pieces of apparel such as ‘dress’ can be generated.

Now that we have learned the fundamentals of CGANs, let us try and implement our knowledge by training a CGAN model in PyTorch.

2. Initializing and Defining a CGAN Model in PyTorch

The first thing we need to do is to import all the necessary libraries for creating our CGAN model.

import random
import matplotlib.pyplot as plt
import numpy as np
 
import torch 
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
 
from torchvision.utils import make_grid
from torchvision.utils import save_image

To get the same results every time, we will set a random seed, and we will also define our training device, a CPU or a GPU.

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # This will generate a random seed
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
 
# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
print(device)

For the data transformations, we will convert the images into tensors and normalize them from the range [0, 1] to [-1, 1], which matches the tanh output range of the generator we will define later.

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

We will use the famous MNIST Handwritten Digit dataset for our CGAN model; the torchvision package of the PyTorch framework makes loading such datasets very convenient.

Now, we create a data_loader and set the batch size to 128.

dataset = torchvision.datasets.MNIST("./data", download=True, transform=transform)
 
data_loader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                          shuffle=True, num_workers=2)

Let us explore the dataset a bit by showing one batch of images.

def show_images(images):
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(make_grid(images.detach(), nrow=22).permute(1, 2, 0))
 
def show_batch(dl):
    for images, _ in dl:
        show_images(images)
        break
 
show_batch(data_loader)

As we can see below, the above code displays one batch of 128 images showing random digits from 0 to 9.

Next, we will define some hyperparameters like the image size, the dimension of features, and latent space size, among others.

image_shape = (1, 28, 28)
image_dim = int(np.prod(image_shape))
latent_dim = 100
feature_dim = 128
 
n_classes = 10
embedding_dim = 50

This embedding_dim variable will be used for a so-called embedding layer. Don’t worry, we will get into further details on this later on.
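To build some intuition right away, here is a minimal sketch of what an embedding layer does, using the imports and hyperparameters defined above: it maps an integer class label to a learned dense vector of size embedding_dim.

label_embed = nn.Embedding(n_classes, embedding_dim)   # 10 classes -> 50-dim vectors
label = torch.tensor([6])                              # the digit class "6"
vector = label_embed(label)                            # a learned, trainable representation
print(vector.shape)                                    # torch.Size([1, 50])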

Weight Initialization

Now, we also need to initialize the weights of the models by sampling from a Gaussian distribution, as usual.

# custom weights initialization called on generator and discriminator
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)

Defining the Generator Model

We can start building our Generator model now. It will have the same layers as a normal GAN generator. The only difference is that it will accept not only a latent vector but also a label as input.

This label will be passed into an embedding layer. An embedding layer is related to one-hot encoding of the classes, the difference being that it maps each class to a longer, learned vector that can carry more information about that class. Because these vectors are learned during training, classes that look similar, such as the digits 6 and 9, can end up with similar embeddings.

Then, we will concatenate this embedding with the latent features and pass the result through the network. The embedding is projected to a single 7×7 feature map, so the concatenation adds one extra channel. This is why the first transposed convolution accepts feature_dim+1 = 129 input channels.
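As a quick sanity check of these shapes, here is a minimal sketch with a hypothetical batch size of 4, using random tensors in place of the real layer outputs:

N = 4                                                  # hypothetical batch size
latent_output = torch.randn(N, 128, 7, 7)              # reshaped latent features
label_output = torch.randn(N, 1, 7, 7)                 # reshaped label embedding (one channel)
concat = torch.cat((latent_output, label_output), dim=1)
print(concat.shape)                                    # torch.Size([4, 129, 7, 7]) = feature_dim + 1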

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        self.label_embed  = nn.Embedding(n_classes, embedding_dim)
        self.label_linear = nn.Linear(embedding_dim, 7*7)
 
        self.latent_input = nn.Linear(latent_dim, 128*7*7)
        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)
 
        # 7x7 -> 16x16; the input has feature_dim+1 channels after concatenation
        self.conv_transpose1 = nn.ConvTranspose2d(feature_dim+1, feature_dim, 4, 2)
        self.batch_norm1 = nn.BatchNorm2d(feature_dim)
 
        # 16x16 -> 34x34
        self.conv_transpose2 = nn.ConvTranspose2d(feature_dim, feature_dim, 4, 2)
        self.batch_norm2 = nn.BatchNorm2d(feature_dim)
 
        # Final output: 34x34 -> 28x28, a single grayscale channel
        self.final_conv = nn.Conv2d(feature_dim, 1, 7)
        self.tanh = nn.Tanh()
 
    def forward(self, inputs):
        noise_vector, label = inputs
 
        label_output = self.label_embed(label)
        label_output = self.label_linear(label_output)
        label_output = label_output.view(-1, 1, 7, 7)
 
 
        latent_output = self.latent_input(noise_vector)
        latent_output = self.leaky_relu(latent_output)
        latent_output = latent_output.view(-1, 128, 7, 7)
 
        concat = torch.cat((latent_output, label_output), dim=1)
 
        image = self.conv_transpose1(concat)
        image = self.batch_norm1(image)
        image = self.leaky_relu(image)
 
        image = self.conv_transpose2(image)
        image = self.batch_norm2(image)
        image = self.leaky_relu(image)
 
        image = self.final_conv(image)
        image = self.tanh(image)
 
        return image
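As a quick sanity check, we can verify the generator's output shape with random inputs (a sketch, assuming the hyperparameters defined earlier; the labels have shape (N, 1), just as they will during training):

gen = Generator()
z = torch.randn(2, latent_dim)
labels = torch.tensor([[3], [7]])        # shape (2, 1)
fake = gen((z, labels))
print(fake.shape)                        # torch.Size([2, 1, 28, 28])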

Defining the Discriminator Model

The discriminator is a normal discriminator model similar to that of any other GAN, except for one key difference: it also has an embedding layer, just like the generator.

The embedding layer in the discriminator accepts the number of classes in the dataset, 10 in our case, and maps each label to a vector of size 50. This embedding is then passed into a fully connected layer that projects it to a 28×28 map.

The result is concatenated with the input image and passed into an otherwise normal convolutional network. Since our images are grayscale, they have only 1 channel; due to the concatenation of the embedding map with the image, the first convolutional layer accepts a 2-channel input.

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.label_embed  = nn.Embedding(n_classes, embedding_dim)
        self.label_linear = nn.Linear(embedding_dim, 1*28*28)
 
        self.conv1 = nn.Conv2d(2, feature_dim, 3, 2)
        self.batch_norm1 = nn.BatchNorm2d(feature_dim)
        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)
 
        self.conv2 = nn.Conv2d(feature_dim, feature_dim, 3, 2)
        self.batch_norm2 = nn.BatchNorm2d(feature_dim)
 
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(4608, 1)    # 128 * 6 * 6 features after two stride-2 convolutions
        self.sigmoid = nn.Sigmoid() 
 
    def forward(self, inputs):
        img, label = inputs
        label_output = self.label_embed(label)
        label_output = self.label_linear(label_output)
 
        label_output = label_output.view(-1, 1, 28, 28)
 
        concat = torch.cat((img, label_output), dim=1)
 
        output = self.conv1(concat)
        output = self.batch_norm1(output)
        output = self.leaky_relu(output)
 
        output = self.conv2(output)
        output = self.batch_norm2(output)
        output = self.leaky_relu(output)
 
        output = self.flatten(output)
        output = self.dropout(output)
        output = self.fc(output)
        output = self.sigmoid(output)
        return output
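Again, a quick sanity check sketch: for a batch of images and labels, the discriminator should return one probability per image.

disc = Discriminator()
imgs = torch.randn(2, 1, 28, 28)
labels = torch.tensor([[3], [7]])
out = disc((imgs, labels))
print(out.shape)                         # torch.Size([2, 1]), values in (0, 1)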

Defining the Loss Function

For the loss function, we will use the standard adversarial loss, which here is simply PyTorch’s Binary Cross-Entropy loss (nn.BCELoss). We will wrap it in two helper functions: one for calculating the loss of the generator and one for the loss of the discriminator.

adversarial_loss = nn.BCELoss()
 
def generator_loss(fake_output, label):
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss
 
def discriminator_loss(output, label):
    disc_loss = adversarial_loss(output, label)
    return disc_loss
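The optimizers below need instantiated models, so before defining them we create the generator and the discriminator, move them to the training device, and apply the weight initialization from earlier (a step implied by the rest of the code):

generator = Generator().to(device)
discriminator = Discriminator().to(device)
 
# Apply the custom Gaussian weight initialization defined above
generator.apply(weights_init)
discriminator.apply(weights_init)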

The optimizer will be a normal Adam optimizer with a learning rate of 0.0002.

learning_rate = 0.0002 
G_optimizer = torch.optim.Adam(generator.parameters(), lr = learning_rate)
D_optimizer = torch.optim.Adam(discriminator.parameters(), lr = learning_rate)

We have successfully initialized and defined our CGAN model. It’s time to train the model.

3. Training the CGAN Model

The most interesting part of every GAN is the training. First, we train the discriminator on both real data and fake data produced by the generator.

We want our discriminator to be only slightly better at telling real from fake images, so that it keeps pushing the generator to produce images that look as real as possible.

After the discriminator, we train the generator as the second model. After each epoch, we will save a sample image as well as both models, so that we don’t have to repeat the training process from scratch.

num_epochs = 50
 
for epoch in range(num_epochs): 
   
    for index, (real_images, labels) in enumerate(data_loader):
        real_images = real_images.to(device)
        labels = labels.to(device)
        labels = labels.unsqueeze(1).long()   # shape (batch, 1) for the embedding layers
 
        real_target = torch.ones(real_images.size(0), 1).to(device)
        fake_target = torch.zeros(real_images.size(0), 1).to(device)
 
        # Train the discriminator: first on real images with real targets...
        D_optimizer.zero_grad()
        D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)
    
        noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)
        
        # ...then on fake images (detached so the generator is not updated here)
        generated_image = generator((noise_vector, labels))
        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)
 
        D_total_loss = (D_real_loss + D_fake_loss) / 2
      
        D_total_loss.backward()
        D_optimizer.step()
 
        # Train the generator: it is rewarded when the discriminator labels its fakes as real
        G_optimizer.zero_grad()
        G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
 
        G_loss.backward()
        G_optimizer.step()
    
    save_image(generated_image.data[:50], f'sample_{epoch}.png', nrow=5, normalize=True)
     
    torch.save(generator.state_dict(), f'generator_epoch_{epoch}.pth')
    torch.save(discriminator.state_dict(), f'discriminator_epoch_{epoch}.pth')
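To keep an eye on training progress, it can also help to print both losses at the end of every epoch. A minimal sketch, placed inside the epoch loop above after the batch loop:

    print(f"Epoch [{epoch+1}/{num_epochs}]  D loss: {D_total_loss.item():.4f}  G loss: {G_loss.item():.4f}")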

Now comes the main question. How do we generate images for some specific class?

It is easier than it looks! All we need to do is create a random latent vector and pass it into the generator together with a tensor containing the class label we want to generate.

For example, let’s generate the number 8 from a random latent vector and save it under the name “test.png”.

noise_vector = torch.randn(1, latent_dim, device=device)
 
generated_image = generator((noise_vector, torch.tensor([8]).to(device)))
save_image(generated_image.data[0], 'test.png', nrow=5, normalize=True)

As you can see, our model generated the target image with the number 8.
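Since we saved the generator’s weights after every epoch, we can also reload a checkpoint later instead of retraining. A short sketch, assuming the file names used in the training loop above:

generator = Generator().to(device)
generator.load_state_dict(torch.load('generator_epoch_49.pth', map_location=device))
generator.eval()    # switch batch norm layers to evaluation mode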

Summary

In this post, you have learned how to build and train a CGAN and studied the difference between a standard DCGAN and a CGAN. We showed how to generate an image of a specific class by conditioning the model on class labels. Going ahead, we shall continue to explore conditional GANs and study a famous type of Conditional GAN along with its implementation in PyTorch.