#011 Developing a DCGAN for CelebA Dataset

Highlights: In the previous two chapters, we built two distinct DCGAN models, one for the MNIST Handwritten Digit dataset and the other for the CIFAR-10 dataset.

In this chapter, we will build yet another DCGAN model. However, this time, we will use a different dataset known as the CelebA dataset. Let’s start by learning a bit about the CelebA dataset and its features.

Tutorial overview:

  1. Downloading CelebA Dataset
  2. Initializing and Defining the DCGAN Model
  3. Training the DCGAN Model

1. Downloading CelebA Dataset

The Large-scale CelebFaces Attributes (CelebA) Dataset is a collection of over 200,000 images of celebrities. Each image is annotated with 40 binary attributes that tell us what is happening in it. Is the person wearing a hat? Is the person wearing glasses? Is the person smiling or not? And so on.
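If you are curious about these annotations, you can peek at the attribute file that ships with the dataset. This is just a side note, not needed for the rest of the tutorial; the sketch below assumes you downloaded list_attr_celeba.csv from the Kaggle page, where each row is one image and each attribute column holds 1 (present) or -1 (absent).

import pandas as pd
 
# Hypothetical peek at the CelebA attribute annotations (file name assumed from the Kaggle release)
attrs = pd.read_csv('list_attr_celeba.csv')
print(attrs[['Smiling', 'Eyeglasses', 'Wearing_Hat']].head())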

You can download this dataset easily from Kaggle. Or, you can simply run the following command in a terminal or in Google Colab.

!wget 'https://docs.google.com/uc?export=download&id=1q0iNpIa-sRq8k1LPgXGXX6YvQFPwQz5U&confirm=t' -O img_align_celeba.zip

The above command will download a zip file named img_align_celeba.zip. Let’s create a folder where we will unzip all of our images. We’ll name this folder celeba_dataset.

!mkdir celeba_dataset

Once the folder is created, we’ll simply unzip the zip file here.

!unzip -q img_align_celeba.zip -d celeba_dataset/img_align_celeba

Now, inside this celeba_dataset folder, you will find another folder called img_align_celeba where all our images with their respective filenames (e.g., “202599.jpg”) will have been extracted. This type of folder structure is necessary for working with the PyTorch image loader.
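If you want to double-check the layout before loading, a quick sanity check could look like the sketch below (paths assumed from the unzip command above):

import os
 
# ImageFolder treats each subdirectory of the root as one class,
# so the images must sit one level below the root we pass to it
root = 'celeba_dataset/img_align_celeba'
subdirs = os.listdir(root)
print(subdirs)                                          # the subdirectory holding the images
print(len(os.listdir(os.path.join(root, subdirs[0]))))  # number of extracted images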

Let us use this dataset and start building our DCGAN model.

2. Initializing and Defining the DCGAN Model

The first thing we do is import all the required libraries, as shown in the code below.

import random
import matplotlib.pyplot as plt
import numpy as np
 
import torch 
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
 
# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # This will generate a random seed
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
 
# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
print(device)

We have set a fixed value for the parameter manualSeed in the code above. If you keep it as is, you will get the same results as ours. Alternatively, you can set a different value manually or generate a random seed by uncommenting the line below it (refer to the code above).

Due to the proper structuring of folders, we will be able to load our data easily. The structuring was an important step because we will use the ImageFolder dataset class in our code, which requires subdirectories in the dataset’s root folder. The structure we have created meets this requirement. In the code below, note also that the Normalize transform with a mean and standard deviation of 0.5 per channel scales the pixel values from [0, 1] to [-1, 1], matching the output range of the generator’s tanh() activation.

Now, we are ready to create the dataset. Let’s initialize the data_loader object and run the program so that we can visualize some of our training data.

transform = transforms.Compose([transforms.Resize(64),
                               transforms.CenterCrop(64),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
 
# We can use an ImageFolder dataset the way we have it set up.
# Create the dataset
dataset = datasets.ImageFolder(root="celeba_dataset/img_align_celeba",
                               transform=transform)
# Create the dataloader
data_loader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                         shuffle=True, num_workers=2)

Have a look at how we sample images from the dataset. The resulting image grid can be seen below the code.

# Helper function that takes a whole batch of images and shows them as one combined image
def imshow(img):
  img = img / 2 + 0.5   # unnormalize from [-1, 1] back to [0, 1] for display
  npimg = img.numpy()
  plt.imshow(np.transpose(npimg, (1, 2, 0)))
  plt.show()
 
# We will take one batch of data and display it as an image grid
data_iter = iter(data_loader)
images, labels = next(data_iter)
 
imshow(torchvision.utils.make_grid(images))

Weight Initialization

With our input parameters set and the dataset prepared, we can now get into the implementation. We will start with the weight initialization strategy, and then talk about the generator, discriminator, loss function, and training loop in detail.

In the DCGAN paper, the authors specify that all model weights should be randomly initialized from a Normal distribution with mean=0 and stdev=0.02. To meet these criteria, our init_normal function takes an initialized model as input and reinitializes all convolutional, convolutional-transpose, and batch normalization layers. This function is applied to the models immediately after initialization.

# Function that reinitializes the convolutional and batch-norm layers with normally distributed weights
def init_normal(m):
    classname = m.__class__.__name__
 
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

Our goal is to convert a latent vector, z, into an RGB image of the same size as the training images (\(3 \times 64 \times 64 \)).

In practice, this is accomplished through a series of strided two-dimensional Convolutional Transpose layers, each paired with a 2D Batch Normalization layer and a ReLU activation.

Then, the output of the generator is fed through a tanh() function to map it to the input data range of [-1,1]. A diagram of the generator from the DCGAN paper is shown below.
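Before writing the model, it helps to verify how these layers upsample the \(1 \times 1 \) latent input to a \(64 \times 64 \) image. Here is a minimal sketch using the standard ConvTranspose2d output-size formula with the same kernel, stride, and padding values as our generator below.

# Output size of nn.ConvTranspose2d (no output_padding or dilation):
# out = (in - 1) * stride - 2 * padding + kernel_size
def convT_out(size, kernel=4, stride=2, padding=1):
    return (size - 1) * stride - 2 * padding + kernel
 
size = 1                                     # latent vector z is 100 x 1 x 1
size = convT_out(size, stride=1, padding=0)  # first layer: 1 -> 4
for _ in range(4):
    size = convT_out(size)                   # 4 -> 8 -> 16 -> 32 -> 64
print(size)                                  # 64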

Defining the Generator Model

Now that we understand the architecture, we can write the code for the generator model. Note that we will apply the init_normal function immediately after initialization.

feature_map       = 64
latent_space_size = 100
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
 
        self.convTranspose1 = nn.ConvTranspose2d( latent_space_size, feature_map * 8, 4, 1, 0, bias=False)
        self.convTranspose2 = nn.ConvTranspose2d( feature_map * 8  , feature_map * 4, 4, 2, 1, bias=False)
        self.convTranspose3 = nn.ConvTranspose2d( feature_map * 4  , feature_map * 2, 4, 2, 1, bias=False)
        self.convTranspose4 = nn.ConvTranspose2d( feature_map * 2  , feature_map * 1, 4, 2, 1, bias=False)
        self.convTranspose5 = nn.ConvTranspose2d( feature_map * 1  , 3              , 4, 2, 1, bias=False)
 
        self.bnorm1 = nn.BatchNorm2d(feature_map * 8)
        self.bnorm2 = nn.BatchNorm2d(feature_map * 4)
        self.bnorm3 = nn.BatchNorm2d(feature_map * 2)
        self.bnorm4 = nn.BatchNorm2d(feature_map * 1)
 
    def forward(self, x):
        x = self.convTranspose1(x)
        x = F.relu(self.bnorm1(x))
 
        x = self.convTranspose2(x)
        x = F.relu(self.bnorm2(x))
 
        x = self.convTranspose3(x)
        x = F.relu(self.bnorm3(x))
 
        x = self.convTranspose4(x)
        x = F.relu(self.bnorm4(x))
 
        x = self.convTranspose5(x)
        x = torch.tanh(x)
        return x
 
# Create the generator
generator = Generator().to(device)

Next, we create a helper function that, when called, generates N random latent vectors, each of size 100.

# Helper function for creating a random latent vector
def generate_latent_vectors_z(N):
    z = torch.randn(N, 100, 1, 1, device=device)
    return z

We can even test the generator at this early stage, without any training, simply by calling it on a latent vector.

# Generate one random latent vector of 100 elements and transform it into an image
z = generate_latent_vectors_z(1)
random_noise = generator(z)
 
plt.imshow(random_noise[0].cpu().permute(1, 2, 0).detach().numpy() * 0.5 + 0.5)  # rescale to [0, 1] for display

Defining the Discriminator Model

The next step is to build the discriminator, D. It is a binary classification network that takes an image as input and outputs a scalar probability that the input image is real (as opposed to fake).

Here, D takes a \(3 \times 64 \times 64 \) input image, processes it through a series of Conv2d, BatchNorm2d, and LeakyReLU layers, and outputs the final probability through a Sigmoid activation function.
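Mirroring the generator, each strided convolution halves the spatial size, taking the \(64 \times 64 \) input down to a single value. A quick sketch with the standard Conv2d output-size formula and the same layer settings as our discriminator below:

# Output size of nn.Conv2d (no dilation):
# out = (in - kernel_size + 2 * padding) // stride + 1
def conv_out(size, kernel=4, stride=2, padding=1):
    return (size - kernel + 2 * padding) // stride + 1
 
size = 64                                    # input image is 3 x 64 x 64
for _ in range(4):
    size = conv_out(size)                    # 64 -> 32 -> 16 -> 8 -> 4
size = conv_out(size, stride=1, padding=0)   # last layer: 4 -> 1
print(size)                                  # a single probability per image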

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(3              , feature_map    , 4, 2, 1, bias=False)
        self.conv2 = nn.Conv2d(feature_map    , feature_map * 2, 4, 2, 1, bias=False)
        self.conv3 = nn.Conv2d(feature_map * 2, feature_map * 4, 4, 2, 1, bias=False)
        self.conv4 = nn.Conv2d(feature_map * 4, feature_map * 8, 4, 2, 1, bias=False)
        self.conv5 = nn.Conv2d(feature_map * 8, 1              , 4, 1, 0, bias=False)
 
        self.bnorm1 = nn.BatchNorm2d(feature_map*2)
        self.bnorm2 = nn.BatchNorm2d(feature_map*4)
        self.bnorm3 = nn.BatchNorm2d(feature_map*8)
 
        self.LeakyReLU = nn.LeakyReLU(0.2, inplace=True)
        self.sigmoid = nn.Sigmoid()
 
    def forward(self, x):
        x = self.LeakyReLU(self.conv1(x))
 
        x = self.conv2(x)
        x = self.LeakyReLU(self.bnorm1(x))
 
        x = self.conv3(x)
        x = self.LeakyReLU(self.bnorm2(x))
 
        x = self.conv4(x)
        x = self.LeakyReLU(self.bnorm3(x))
 
        x = self.conv5(x)
        x = self.sigmoid(x)
        return x
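Just like with the generator, we can sanity-check the untrained discriminator by passing a random image-shaped tensor through it. This is a minimal sketch, not part of the original pipeline:

# Quick sanity check: the discriminator maps a 3 x 64 x 64 image to one probability
discriminator = Discriminator().to(device)
test_input = torch.randn(1, 3, 64, 64, device=device)
print(discriminator(test_input).view(-1))  # a single value in (0, 1)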

Defining the Loss Function

Now that we have set up the generator and the discriminator, we can specify how they learn through the loss functions and optimizers.

For the loss function, we will be using the Binary Cross Entropy Loss.
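For reference, given a predicted probability \(x \) and a target label \(y \in \{0, 1\} \), the BCE loss is defined as \(\ell(x, y) = -\left[ y \log x + (1 - y) \log(1 - x) \right] \), so it heavily penalizes confident wrong predictions.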

Then, we set up two separate optimizers, one for the discriminator and one for the generator. Both are Adam optimizers with a learning rate of 0.0002 and, following the DCGAN paper, a momentum term \(\beta_1 \) of 0.5 to stabilize training.

# Initialize BCELoss function
criterion = nn.BCELoss()
 
# Models
generator =     Generator().to(device)
discriminator = Discriminator().to(device)
 
# Use the modules apply function to recursively apply the initialization
generator.apply(init_normal)
discriminator.apply(init_normal)
 
# Optimizers (learning rate and beta1 follow the DCGAN paper)
generator_optimizer     = torch.optim.Adam(generator.parameters(),     lr=0.0002, betas=(0.5, 0.999))
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

We have successfully initialized and defined our DCGAN model, including the generator, the discriminator, the loss function, and the optimizers. Now, we can start to train our model. We will first train the discriminator and then the generator.

3. Training the DCGAN Model

Similar to how we trained the DCGAN models for the MNIST and CIFAR-10 datasets in the previous two posts, we will train this model in two stages:

  1. Training the Discriminator
  2. Training the Generator


Training the Discriminator

The goal of training the discriminator is to maximize the probability of correctly classifying a given input as real or fake.

We begin by constructing a batch of real samples from the training set. Then, we forward pass through the discriminator, calculate the loss, and calculate the gradients in a backward pass. 

After this, we will construct a batch of fake samples with the current generator, forward pass this batch through the discriminator, calculate the loss, and accumulate the gradients with a backward pass.

Now, with the gradients accumulated from both the all-real and all-fake batches, we can optimize the Discriminator.
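In terms of the original GAN objective, this procedure maximizes \(\log D(x) + \log(1 - D(G(z))) \) with respect to the discriminator’s parameters, where \(x \) is a real image and \(G(z) \) is a generated one.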

Training the Generator

As stated in the original paper, we want to train the generator by minimizing \(\log(1 - D(G(z))) \) so that it produces better fake images. However, as Goodfellow pointed out, this objective does not provide sufficient gradients, especially early in the learning process.

To fix this, we instead maximize \(\log D(G(z)) \). In our code, we accomplish this by classifying the generator output from the discriminator step with the discriminator, computing the generator’s loss using real labels as ground truth, computing the generator’s gradients in a backward pass, and finally, updating the generator’s parameters with an optimizer step.

# Training Loop
print("Starting Training Loop...")
epochs = 20
N_samples = 128
 
# For each epoch
for epoch in range(epochs):
    # For each batch in the dataloader
    for idx, (images, _) in enumerate(data_loader):
        discriminator.zero_grad()
 
        if idx % 100 == 0:
            print(f"{epoch}/{epochs} epoch | {idx}/{len(data_loader)} batch \n")
 
        # Move a batch of real images to the device
        real_data = images.to(device)
        real_data_label = torch.ones(real_data.shape[0]).to(device)
 
        # Training of Discriminator
        # Forward pass real batch through discriminator
        output = discriminator(real_data).view(-1)
        errD_real = criterion(output, real_data_label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()
 
        # Create fake data with a generator
        z = generate_latent_vectors_z(len(real_data)).to(device)
        fake_data = generator(z)
        fake_data_label = torch.zeros(fake_data.shape[0]).to(device)
 
        # Forward pass fake data through discriminator
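        # Note: detach() stops gradients from flowing back into the generator here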
        output = discriminator(fake_data.detach()).view(-1)
        errD_fake = criterion(output, fake_data_label)
        errD_fake.backward()
 
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update discriminator
        discriminator_optimizer.step()
 
        # Training of Generator
        generator.zero_grad()
        real_data_label = torch.ones(fake_data.shape[0]).to(device) # fake labels are real for generator cost
        output = discriminator(fake_data).view(-1)
        errG = criterion(output, real_data_label)
        errG.backward()
 
        # Update Generator
        generator_optimizer.step()

When the network finishes training, we will save the generator and discriminator models.

torch.save(generator.state_dict(), '/content/DCGAN_celeba_g.pth')
torch.save(discriminator.state_dict(), '/content/DCGAN_celeba_d.pth')
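Later, you can load these weights back into fresh model instances; a minimal sketch, assuming the same paths as above:

# Reload the saved generator for inference
generator = Generator().to(device)
generator.load_state_dict(torch.load('/content/DCGAN_celeba_g.pth', map_location=device))
generator.eval()  # switch the batch-norm layers to evaluation mode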

Finally, we can now use our model to generate fake images with the help of the code snippet below.

# Generate a batch of random latent vectors and transform them into fake images
z = generate_latent_vectors_z(16)
fake_data = generator(z)
plt.imshow(fake_data[10].cpu().permute(1, 2, 0).detach().numpy() * 0.5 + 0.5)  # rescale to [0, 1] for display

Notice the result above. These images may not be perfect, but to get more realistic images, the models need to be trained for many more epochs. You can experiment by setting the number of epochs to a larger value and see how much the results improve.

Another experiment you can perform is to adapt this model to a different dataset, possibly changing the size of the images as well as the model architecture.

Summary

Great! In no time, you have studied and built three interesting Deep Convolutional Generative Adversarial Network (DCGAN) models for three popular datasets: the MNIST Handwritten Digit dataset, the CIFAR-10 dataset, and the CelebA dataset.

In the next post, we will go into detail about the workings of the generator model in DCGANs, especially the input latent space, its mathematics, and, of course, a bit of code.