#014 Pix2Pix Generative Adversarial Networks
Highlights: In the previous post, we introduced a way of training GAN models on a labeled dataset, known as Conditional Generative Adversarial Networks (CGANs).
CGANs have had a big impact on the field of AI and are quickly gaining popularity. One of the best-known conditional GANs in use today is the Image2Image GAN, also called the Pix2Pix GAN.
In this post, we will study Pix2Pix GAN in detail. Let’s begin.
Tutorial overview:
- What is a Pix2Pix GAN?
- The architecture of a Pix2Pix GAN
- U-Net Generator Model
- PatchGAN Discriminator Model
- Loss Functions in a Pix2Pix GAN Model
- Implementing a Pix2Pix GAN Model
1. What is a Pix2Pix GAN?
A Pix2Pix GAN is a variation of the Conditional GAN. Recall that a CGAN takes a label as an additional input. For instance, if we input the label “Shoe”, the model will generate an image of a shoe.
In the Pix2Pix GAN, the label condition is dropped. Instead, the generator is conditioned on an input image. Training therefore requires a dataset of image pairs, in which each input image has a corresponding target image.
For example, if we want to transform a satellite image into a Google Maps-style image, or a segmented house into a real-world house, we can use a Pix2Pix GAN. Have a look at the image below to understand this better.
2. The architecture of a Pix2Pix GAN
A Pix2Pix GAN is also known as an Image2Image Translation GAN, because it translates an input image into a corresponding generated (fake) output image.
The architecture of a Pix2Pix GAN is quite similar to other GANs with a generator and a discriminator.
The discriminator model is provided with both the source input image and the newly generated image. It must determine if the generated image is a plausible transformation of the source image. Have a look at the following image to see the overall architecture of a Pix2Pix GAN.
3. U-Net Generator Model
Unlike the GANs we have seen so far, the Pix2Pix generator does not take a random vector as input; it takes an image.
For this reason, the Pix2Pix GAN uses a U-Net architecture rather than a plain encoder-decoder: a U-Net is an encoder-decoder with added skip connections. The input image passes through the encoder layers, which downsample it until it reaches the bottleneck layer. The decoder then upsamples that representation over several layers to produce the final output image. This can be seen in the image below.
The key highlight of a U-Net model is its ability to make skip connections between layers of the same size.
In the first few layers of convolutions, low-level features such as edges or blobs are captured, and in the subsequent layers, high-level features come into play. These high-level features are constructed from lots of low-level features. The skip connections give U-Net the power to combine the low-level features with the high-level features. This can be seen pictorially in the image below.
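The first layer of the encoder has the same spatial size as the last layer of the decoder, and the two are merged by a skip connection that concatenates their feature maps along the channel dimension. The same pairing is repeated for every encoder layer and its mirrored decoder layer inside the U-Net.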
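To make the merge concrete, here is a minimal NumPy sketch of a single skip connection. It assumes the channels-first (N, C, H, W) layout that PyTorch uses, and the function name `skip_connect` is hypothetical. It simply concatenates a decoder feature map with the encoder feature map of the same spatial size along the channel axis, which is what `torch.cat((x, skip_input), 1)` does in the generator code later in this post.

```python
import numpy as np

def skip_connect(decoder_feat: np.ndarray, encoder_feat: np.ndarray) -> np.ndarray:
    """Concatenate decoder and encoder feature maps along the channel axis.

    Both inputs use the (N, C, H, W) layout; batch and spatial sizes must match.
    """
    same_batch = decoder_feat.shape[0] == encoder_feat.shape[0]
    same_spatial = decoder_feat.shape[2:] == encoder_feat.shape[2:]
    if not (same_batch and same_spatial):
        raise ValueError("batch and spatial sizes must match for a skip connection")
    return np.concatenate([decoder_feat, encoder_feat], axis=1)

# A 2x2 decoder map with 512 channels merged with a 2x2 encoder map with 512 channels
decoder_out = np.zeros((1, 512, 2, 2))
encoder_out = np.ones((1, 512, 2, 2))
merged = skip_connect(decoder_out, encoder_out)
print(merged.shape)  # (1, 1024, 2, 2)
```

The channel count doubles after each merge, which is why most decoder blocks in a U-Net expect twice as many input channels as the previous block produced.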
4. PatchGAN Discriminator Model
The goal of the discriminator is the same as in the standard GAN. It needs to estimate the likelihood of an image being real or fake.
The difference between a PatchGAN and a normal image GAN discriminator is that, instead of classifying the whole image as real or fake, it classifies small patches across the image. It slides convolutionally over the whole image and classifies each N×N patch as real or fake.
Have a look at the following image. The output of this network is a single feature map of real/fake predictions, one per patch, which are averaged into a single score. In the original Pix2Pix work, patches of size 70×70 were found to be very effective.
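Where does a 70×70 patch come from? Each output value of a PatchGAN only "sees" a limited window of the input, its receptive field, which is fixed by the kernel sizes and strides of the convolution stack (padding does not change it). The sketch below computes it with the standard recurrence. The 70×70 configuration (five 4×4 convolutions with strides 2, 2, 2, 1, 1) follows the common description of the original Pix2Pix discriminator; treat it as an assumption rather than a specification. The helper name is hypothetical.

```python
def receptive_field(layers):
    """Receptive field of one output unit for a stack of (kernel_size, stride) conv layers.

    Walks the layers in forward order: each layer widens the field by
    (kernel_size - 1) times the product of all earlier strides.
    """
    field = 1
    jump = 1  # distance in input pixels between adjacent units of the current layer
    for kernel_size, stride in layers:
        field += (kernel_size - 1) * jump
        jump *= stride
    return field

# Five 4x4 convolutions with strides 2, 2, 2, 1, 1 (the classic 70x70 PatchGAN)
print(receptive_field([(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]))  # 70

# The discriminator built later in this post: four stride-2 convs, then one stride-1 conv
print(receptive_field([(4, 2), (4, 2), (4, 2), (4, 2), (4, 1)]))  # 94
```

So the discriminator implemented in this post judges somewhat larger, overlapping 94×94 patches; the PatchGAN idea is the same.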
5. Loss Functions in a Pix2Pix GAN Model
The PatchGAN discriminator is trained the same way as a regular GAN discriminator: it minimizes the negative log-likelihood (binary cross-entropy) of correctly classifying real and fake images. However, because the discriminator tends to learn faster than the generator, we halve its loss to slow down its training:
$$ \text { New Discriminator Loss }=0.5 \times \text { Discriminator Loss } $$
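As a worked example, the sketch below computes this halved discriminator loss in plain Python using binary cross-entropy, treating each PatchGAN output as the probability that its patch is real. The function names are hypothetical, and the implementation later in this post uses a mean squared error criterion instead, but the halving works the same way.

```python
import math

def bce(probs, label):
    """Mean binary cross-entropy between predicted probabilities and a constant label (1 = real, 0 = fake)."""
    eps = 1e-12  # keeps log() away from 0
    total = 0.0
    for p in probs:
        p = min(max(p, eps), 1.0 - eps)
        total += -(label * math.log(p) + (1 - label) * math.log(1.0 - p))
    return total / len(probs)

def discriminator_loss(patch_probs_real, patch_probs_fake):
    """Halved sum of the real-pair and fake-pair losses."""
    loss_real = bce(patch_probs_real, 1.0)  # real pairs should be scored as real
    loss_fake = bce(patch_probs_fake, 0.0)  # generated pairs should be scored as fake
    return 0.5 * (loss_real + loss_fake)

# A discriminator that outputs 0.5 for every patch is maximally unsure:
print(round(discriminator_loss([0.5] * 4, [0.5] * 4), 4))  # 0.6931, i.e. ln 2
```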
The generator, on the other hand, is trained with two losses: the adversarial loss, computed with the discriminator, and the L1 loss, i.e., the mean absolute pixel difference between the generated image and the expected target image. Together, they form a so-called composite loss function.
The adversarial loss evaluates whether the generator can produce images that are plausible in the target domain. For example, it estimates whether a generated Google Maps image looks like a real Google Maps image.
In contrast, the L1 loss regularizes the generator to output images that are plausible translations of the specific source image.
We usually want the L1 loss to have more influence on training than the adversarial loss. For this reason, we multiply the L1 loss by a hyperparameter λ, usually set to 10 or 100, which gives it more weight than the adversarial loss:
$$ \text { Generator Loss }=\text { Adversarial Loss }+\lambda \times \text { L1 Loss } $$
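The composite loss can be sketched the same way. This plain-Python version assumes a binary cross-entropy adversarial term (the generator wants every patch scored as real) and flattens images into lists of pixel values; the function name and toy numbers are hypothetical.

```python
import math

def generator_loss(patch_probs_fake, generated, target, lam=100.0):
    """Composite generator loss: adversarial term + lam * mean absolute pixel error."""
    eps = 1e-12
    # Adversarial term: cross-entropy against the "real" label (1) for every patch
    adversarial = sum(-math.log(max(p, eps)) for p in patch_probs_fake) / len(patch_probs_fake)
    # L1 term: mean absolute difference between generated and target pixels
    l1 = sum(abs(g - t) for g, t in zip(generated, target)) / len(target)
    return adversarial + lam * l1

# Discriminator unsure (0.5), generated pixels off by 0.1 on average
loss = generator_loss([0.5, 0.5], [0.1, 0.9], [0.0, 1.0])
print(round(loss, 4))  # 10.6931
```

With λ = 100, the pixel term dominates as soon as the output drifts from the target, which is the intended weighting.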
6. Implementing a Pix2Pix GAN Model
Now that we have a fair understanding of the architecture and basics of a Pix2Pix GAN model, let us put what we have learned into practice with PyTorch. To begin with, we will select a suitable dataset for our example.
Choosing a Dataset
Many paired datasets suitable for Pix2Pix are available online. Here are a few of them:
- Satellite and Maps
- Black & White and Colours
- Product Sketches and Product Photographs
For our example, we will use the Satellite and Maps dataset. This dataset consists of satellite images of New York and their corresponding Google Maps formats, as can be seen in the image below.
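We can use this dataset to convert satellite images into map images or vice versa. In the code below, the left half of each image file serves as the source and the right half as the target.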
Downloading the Dataset
To download this dataset, we run a single command, which saves it as a compressed archive named maps.tar.gz. We then need to extract it and arrange the files in a specific folder structure.
```
!wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/maps.tar.gz -O /content/maps.tar.gz
```
The following commands extract the archive and create an images folder with a maps subfolder, where we will store our data.
```
!tar -xvf maps.tar.gz
!mkdir images/
!mkdir images/maps
```
With the folders in place, we can now move the train and val splits into images/maps.
```
!mv -v /content/maps/train /content/images/maps
!mv -v /content/maps/val /content/images/maps
```
Initializing and Defining the Model
Let us start by importing the libraries that we will be using.
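```python
import os
import numpy as np
import glob
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
from torchsummary import summary
from torchvision.utils import save_image

# Run on the GPU if one is available (the models below are moved to this device)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```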
We set the number of epochs to a low value to keep training time short. GANs often need hundreds or even thousands of epochs of training, so we can’t expect perfect images. The images are \(256\times256 \) RGB images, so they have 3 channels.
```python
# Training details
n_epochs = 20
lr = 0.0002  # Learning rate

# Dataset details
img_height = 256
img_width = 256
channels = 3
```
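Each file in the maps dataset stores the two paired views side by side in a single image, so this is not a simple classification dataset that PyTorch can load out of the box. Instead, we write a custom Dataset class that lists all the image files, splits each one into its left half (the source) and its right half (the target), and returns the pair.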
```python
class MAPDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms_=None, mode="train"):
        self.transform = transforms.Compose(transforms_)
        self.files = sorted(glob.glob(os.path.join(root, mode) + "/*.*"))
        if mode == "train":
            # Also train on a "test" split if one exists (we only moved train/ and val/, so this adds nothing here)
            self.files.extend(sorted(glob.glob(os.path.join(root, "test") + "/*.*")))

    def __getitem__(self, index):
        img = Image.open(self.files[index % len(self.files)]).convert("RGB")
        w, h = img.size
        # Each file holds both views side by side: left half = source, right half = target
        img_A = img.crop((0, 0, w // 2, h))
        img_B = img.crop((w // 2, 0, w, h))
        img_A = self.transform(img_A)
        img_B = self.transform(img_B)
        return {"source": img_A, "target": img_B}

    def __len__(self):
        return len(self.files)
```
Next, we define the transformations: resizing, conversion to tensors, and normalization. We then wrap our datasets in DataLoader objects, one for training and one for validation.
Notice that the training loader uses a batch size of 1, so the model receives one image pair at a time, while the validation loader uses a batch size of 10.
```python
# Configure dataloaders
transforms_ = [
    transforms.Resize((img_height, img_width)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

dataloader = DataLoader(
    MAPDataset("/content/images/maps", transforms_=transforms_),
    batch_size=1,
    shuffle=True,
)

val_dataloader = DataLoader(
    MAPDataset("/content/images/maps/", transforms_=transforms_, mode="val"),
    batch_size=10,
    shuffle=True,
)
```
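Normalizing with a mean of 0.5 and a standard deviation of 0.5 per channel maps the [0, 1] pixel values produced by ToTensor to [-1, 1], which matches the range of the tanh activation at the end of the generator defined below. A small plain-Python sketch of the per-value formula and its inverse (the helper names are hypothetical):

```python
def normalize(x, mean=0.5, std=0.5):
    """Same per-value formula that transforms.Normalize applies: (x - mean) / std."""
    return (x - mean) / std

def denormalize(y, mean=0.5, std=0.5):
    """Invert the normalization to get back to [0, 1] for display or saving."""
    return y * std + mean

for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))  # 0.0 -> -1.0, 0.5 -> 0.0, 1.0 -> 1.0

print(denormalize(-1.0), denormalize(1.0))  # 0.0 1.0
```

This is also why sample_images, defined below, passes normalize=True to save_image: it rescales the values into [0, 1] before writing the file.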
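Following common best practices for GANs, we initialize the convolution weights from a normal distribution with a mean of 0 and a standard deviation of 0.02, and the batch normalization weights from a normal distribution with a mean of 1 and a standard deviation of 0.02, with their biases set to 0.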
```python
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        # Matches both Conv2d and ConvTranspose2d: weights ~ N(0, 0.02)
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        # BatchNorm scale ~ N(1, 0.02), shift = 0
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
```
Now, let’s create the discriminator model.
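The discriminator takes in a candidate target image (either the real target or a generated one) together with the source image, concatenates them along the channel dimension, and predicts a map of patch-wise real/fake scores. The original PatchGAN is designed so that each output prediction corresponds to a \(70\times70 \) patch of the input; with the four stride-2 convolutions used below, each output value actually covers a \(94\times94 \) region, but the principle is the same.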
```python
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Input: candidate target and source image concatenated along channels (3 + 3 = 6)
        self.conv1 = nn.Conv2d(3 * 2, 64, 4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 64 * 2, 4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64 * 2, 64 * 4, 4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(64 * 4, 64 * 8, 4, stride=2, padding=1)
        # Final 1-channel conv: one real/fake score per patch
        self.conv6 = nn.Conv2d(64 * 8, 1, 4, padding=1, bias=False)
        self.bnorm1 = nn.BatchNorm2d(64 * 2)
        self.bnorm2 = nn.BatchNorm2d(64 * 4)
        self.bnorm3 = nn.BatchNorm2d(64 * 8)
        self.zeropad = nn.ZeroPad2d((1, 0, 1, 0))
        self.leakyRelu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, imgs):
        # imgs is a tuple: (candidate_target, source)
        x = torch.cat((imgs[0], imgs[1]), 1)
        x = self.leakyRelu(self.conv1(x))  # no batch norm on the first layer
        x = self.leakyRelu(self.bnorm1(self.conv2(x)))
        x = self.leakyRelu(self.bnorm2(self.conv3(x)))
        x = self.leakyRelu(self.bnorm3(self.conv4(x)))
        x = self.zeropad(x)
        x = self.conv6(x)  # raw patch scores, no sigmoid
        return x


discriminator = Discriminator().to(device)
discriminator.apply(weights_init_normal);
```
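Before moving on, it helps to confirm that the discriminator's output map has the shape that the `patch` tuple in the training loop later in this post expects. The plain-Python sketch below applies the standard convolution output-size formula, floor((size + 2 * padding - kernel) / stride) + 1, layer by layer for one spatial dimension of a 256×256 input; the helper names are hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def discriminator_output_size(size=256):
    """Trace one spatial dimension through the Discriminator layers defined above."""
    for _ in range(4):  # conv1..conv4: kernel 4, stride 2, padding 1 -> halves the size
        size = conv_out(size, kernel=4, stride=2, padding=1)
    size += 1  # ZeroPad2d((1, 0, 1, 0)) adds one row on top and one column on the left
    size = conv_out(size, kernel=4, stride=1, padding=1)  # conv6
    return size

print(discriminator_output_size(256))  # 16
```

Each 256×256 input pair therefore yields a 1×16×16 grid of patch scores, which matches patch = (1, img_height // 2 ** 4, img_width // 2 ** 4).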
Let us pause here for a moment to recall the encoder-decoder structure of our Pix2Pix generator. It first downsamples the input image and then upsamples it again.
The U-Net is known for its skip connections between each encoding layer and the corresponding decoding layer, which we can see in the following image. They combine information from the early encoder layers with information in the decoder, which means they combine low-level and high-level features.
Now that we understand how the generator works, we can build it in PyTorch. First, we build the downsampling block and then the upsampling block. Finally, we combine them into a single Generator class.
```python
class UNetDown(nn.Module):
    """Encoder block: 4x4 conv with stride 2 (halves H and W), optional BatchNorm, LeakyReLU."""

    def __init__(self, in_size, out_size, batchnorm=True):
        super(UNetDown, self).__init__()
        layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        if batchnorm:
            layers.append(nn.BatchNorm2d(out_size))
        layers.append(nn.LeakyReLU(0.2))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UNetUp(nn.Module):
    """Decoder block: 4x4 transposed conv with stride 2 (doubles H and W), BatchNorm, optional Dropout, ReLU."""

    def __init__(self, in_size, out_size, dropout=False):
        super(UNetUp, self).__init__()
        layers = [
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_size),
        ]
        if dropout:
            layers.append(nn.Dropout(0.5))
        layers.append(nn.ReLU(inplace=True))
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.model(x)
        # Skip connection: concatenate with the encoder output of the same size
        x = torch.cat((x, skip_input), 1)
        return x
```
```python
class Generator(nn.Module):
    def __init__(self, out_channels=3):
        super(Generator, self).__init__()
        # Encoder: 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1 (bottleneck)
        self.down1 = UNetDown(3, 64, batchnorm=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512)
        self.down5 = UNetDown(512, 512)
        self.down6 = UNetDown(512, 512)
        self.down7 = UNetDown(512, 512)
        self.down8 = UNetDown(512, 512, batchnorm=False)

        # Decoder: input channels double after up1 because of the skip concatenation
        self.up1 = UNetUp(512, 512, dropout=True)
        self.up2 = UNetUp(1024, 512, dropout=True)
        self.up3 = UNetUp(1024, 512, dropout=True)
        self.up4 = UNetUp(1024, 512, dropout=True)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)

        # Final block: upsample 128 -> 256, pad, and map 128 channels to out_channels
        self.upsampleLayer = nn.Upsample(scale_factor=2)
        self.zeropad = nn.ZeroPad2d((1, 0, 1, 0))
        self.final_conv = nn.Conv2d(128, out_channels, 4, padding=1)

    def forward(self, x):
        # Encoder
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)
        # Decoder with skip connections to the mirrored encoder layers
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)
        x = self.upsampleLayer(u7)
        x = self.zeropad(x)
        x = torch.tanh(self.final_conv(x))  # outputs in [-1, 1], like the normalized targets
        return x


generator = Generator().to(device)
generator.apply(weights_init_normal);
```
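Because every skip connection concatenates channels, the input sizes of the decoder blocks are easy to get wrong. The plain-Python sketch below traces the spatial size and channel count through the generator above for a 256×256 input, using the standard Conv2d and ConvTranspose2d output-size formulas, and confirms where the 1024, 512, 256, and 128 input channels come from. The helper names are hypothetical.

```python
def conv_down(h):
    """Conv2d(kernel 4, stride 2, padding 1): halves the spatial size."""
    return (h + 2 - 4) // 2 + 1

def conv_up(h):
    """ConvTranspose2d(kernel 4, stride 2, padding 1): doubles the spatial size."""
    return (h - 1) * 2 - 2 + 4

def trace_generator(size=256):
    """Return the input channels of up1..up7, the final conv's input channels, and the output size."""
    encoder = []  # (channels, height) of d1..d8
    h = size
    for c in [64, 128, 256, 512, 512, 512, 512, 512]:
        h = conv_down(h)
        encoder.append((c, h))

    channels = encoder[-1][0]  # d8 feeds up1 directly
    up_inputs = []
    up_outputs = [512, 512, 512, 512, 256, 128, 64]
    for out_c, (skip_c, skip_h) in zip(up_outputs, reversed(encoder[:-1])):
        up_inputs.append(channels)
        h = conv_up(h)
        if h != skip_h:
            raise ValueError("skip connection needs matching spatial sizes")
        channels = out_c + skip_c  # torch.cat((x, skip_input), 1)

    h = h * 2                 # nn.Upsample(scale_factor=2)
    h = h + 1                 # ZeroPad2d((1, 0, 1, 0))
    h = (h + 2 - 4) // 1 + 1  # final_conv: kernel 4, stride 1, padding 1
    return up_inputs, channels, h

up_inputs, final_in_channels, out_size = trace_generator(256)
print(up_inputs)          # [512, 1024, 1024, 1024, 1024, 512, 256]
print(final_in_channels)  # 128
print(out_size)           # 256
```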
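For the loss, as mentioned earlier, we need two losses: the adversarial loss and the L1 loss. We also define λ (lambda_pixel in the code) with a value of 100, which gives the L1 loss more importance than the adversarial loss. Note that criterion_GAN is a mean squared error loss (the least-squares GAN variant) rather than binary cross-entropy; it plays the same role of pushing patch scores toward 1 for real pairs and 0 for fake ones.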
```python
# Loss functions
criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()

# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel = 100
```
For the optimizers, we use Adam for both models, with the learning rate defined earlier and betas of (0.5, 0.999).
```python
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
```
Every 100 batches, we save a sample image to check our model’s progress. The following function takes one batch from our validation set (a random one, since the validation loader shuffles its data) and passes it through the generator. It then stacks the input images, the generated images, and the target images vertically and saves them as a single grid.
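```python
def sample_images(batches_done, epoch):
    # Take one (shuffled) validation batch of 10 image pairs
    imgs = next(iter(val_dataloader))
    real_A = imgs["source"].to(device)
    real_B = imgs["target"].to(device)
    with torch.no_grad():  # no gradients needed just to visualize
        fake_B = generator(real_A)
    # Stack source, generated, and target vertically (dim -2 is the image height)
    img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -2)
    save_image(img_sample, f"/content/Epoch:{str(epoch).zfill(3)}~batch:{str(batches_done).zfill(4)}.png", nrow=5, normalize=True)
```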
Our Pix2Pix GAN model is now initialized and its hyperparameters are defined. Now comes the interesting part: training. Let’s see how we train our model.
Training the Pix2Pix GAN Model
As with the other GAN models that we trained in previous posts, the training process is divided into two important parts:
- Training the Generator
- Training the Discriminator
We iterate batch by batch and store the source and target images as real_A and real_B. real_A is the image that we feed into the generator, and we expect the generator to return an image that looks like real_B.
We also need two label tensors, one filled with ones and one filled with zeros, each with the same shape as the discriminator's patch output. A value of 1 marks a patch as real and a value of 0 marks it as fake.
We pass the source image into the generator, whose goal is to trick the discriminator into thinking its output is a real image. We then pass the generated image into the discriminator together with the source image and compute the adversarial loss between the prediction and the tensor of ones. We also compute the pixel-wise L1 loss between the generated image and the real target image.
Next, we combine the two losses into the total generator loss, compute the gradients, and update the generator's parameters.
```python
# Calculate output of image discriminator (PatchGAN): 1 x 16 x 16 for 256 x 256 inputs
patch = (1, img_height // 2 ** 4, img_width // 2 ** 4)

for epoch in range(n_epochs):
    for i, batch in enumerate(dataloader):
        # Model inputs
        real_A = batch["source"].to(device)
        real_B = batch["target"].to(device)

        # Adversarial ground truths: one label per patch
        valid = torch.ones((real_A.size(0), *patch)).to(device)
        fake = torch.zeros((real_A.size(0), *patch)).to(device)

        # Train Generator
        optimizer_G.zero_grad()

        # GAN loss: the generator wants its output judged as real
        fake_B = generator(real_A)
        pred_fake = discriminator((fake_B, real_A))
        loss_GAN = criterion_GAN(pred_fake, valid)
        # Pixel-wise loss
        loss_pixel = criterion_pixelwise(fake_B, real_B)
        # Total loss
        loss_G = loss_GAN + (lambda_pixel * loss_pixel)

        loss_G.backward()
        optimizer_G.step()
```
The next step is to train the discriminator. We pass the real target image together with the real source image into the discriminator and get a prediction, which we call the real prediction.
Comparing it with the tensor of ones gives the “real loss”. We repeat the same steps with the generated image and the tensor of zeros to get the “fake loss”; the generated image is detached so that this step does not update the generator. We average the real and fake losses, compute the gradients, and finally update the discriminator.
```python
        # (continues inside the inner loop above)
        # Train Discriminator
        optimizer_D.zero_grad()

        # Real loss
        pred_real = discriminator((real_B, real_A))
        loss_real = criterion_GAN(pred_real, valid)

        # Fake loss (detach so no gradients flow back into the generator)
        pred_fake = discriminator((fake_B.detach(), real_A))
        loss_fake = criterion_GAN(pred_fake, fake)

        # Total loss (0.5 to slow down the process of training)
        loss_D = 0.5 * (loss_real + loss_fake)

        loss_D.backward()
        optimizer_D.step()
```
Every 100 batches, we save the images our generator produces, and after every epoch, we save checkpoints of both models.
```python
        # (still inside the inner loop) save a sample every 100 batches
        if i % 100 == 1:
            sample_images(i, epoch)

    # Save model checkpoints once per epoch (outer loop)
    torch.save(generator.state_dict(), f"/content/generator_{epoch}.pth")
    torch.save(discriminator.state_dict(), f"/content/discriminator_{epoch}.pth")
```
Keep in mind that 20 epochs is a very small number for a network like this. To achieve the best results, training typically needs to run for hundreds or even thousands of epochs.
Have a look at the image below, which shows some of the images our generator produces by the end of training.
In the image above, we can see that the results are on the right track. Our model has learned to distinguish roads, buildings, and grass. What is still missing are finer details, such as the amount of grass or the finish of the roads.
Summary
Now that you have learned how Pix2Pix GAN models work, you can practice building your own models with different datasets and by modifying hyperparameters such as the number of epochs and the learning rate.
In the following posts, we will study more types of GANs and see how their implementations are similar or different from what we have learned so far.