#012 Understanding Latent Space in Generators

Highlights: Hello and welcome. In the previous posts, we covered the architecture, implementation, and training of three DCGAN models on different datasets. By now, you should have a solid understanding of Generative Adversarial Networks in general.

In this post, we will go deeper into the workings of the generator in a typical GAN model. More specifically, we will explore the latent space that serves as the input to the generator. So, let’s begin with our post.

Tutorial overview:

  1. Latent Space & its Mathematics
  2. Latent Space in DCGAN Generator for CelebA Dataset
  3. Interpolating Between Different Faces

1. Latent Space & its Mathematics

The generator produces images by learning a mapping from points in a latent space to images. Let’s start by understanding the mathematics involved in this.

As mentioned earlier, the generator takes a point from the latent space as input and generates a new image.

In most cases, a point in the latent space is a vector of 100 numbers, drawn from a standard normal distribution. In the training phase, the generator learns how to map each such point to an image. Every time the model is re-trained, it will learn a different mapping.

Different images are therefore created by sampling different random points (latent vectors) from the latent space.

Interestingly, GANs also let us experiment: we can perform arithmetic operations on latent vectors to generate new faces.

For instance, if we take the latent vector that generated a smiling man and subtract the latent vector of a neutral man, we are left with a vector that encodes the smile. If we add this result to the latent vector of a neutral woman, we get a smiling woman. Have a look at the pictorial representation of this in the following image.

Note that this arithmetic is always performed on the latent vectors themselves, not on the generated images.
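To make this concrete, here is a minimal sketch of the idea, assuming 100-dimensional latent vectors drawn from a standard normal distribution. The vectors below are random placeholders for illustration only; in practice, you would use the vectors that actually generated the corresponding faces:

import torch

# Hypothetical latent vectors (random placeholders, for illustration only)
smiling_man   = torch.randn(100)
neutral_man   = torch.randn(100)
neutral_woman = torch.randn(100)

# Subtracting the neutral man isolates a "smile" direction ...
smile_direction = smiling_man - neutral_man

# ... which, added to a neutral woman, should yield a smiling woman
smiling_woman = neutral_woman + smile_direction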

2. Latent Space in DCGAN Generator for CelebA Dataset

Let’s take the example of the DCGAN model that we built and trained in the previous post on the CelebA dataset. We will perform some arithmetic operations on latent vectors for this model.

The first thing we need to do is create our generator model and load the parameters that we saved while training it in the previous post.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
 
feature_map       = 64
latent_space_size = 100
 
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
 
        self.convTranspose1 = nn.ConvTranspose2d( latent_space_size, feature_map * 8, 4, 1, 0, bias=False)
        self.convTranspose2 = nn.ConvTranspose2d( feature_map * 8  , feature_map * 4, 4, 2, 1, bias=False)
        self.convTranspose3 = nn.ConvTranspose2d( feature_map * 4  , feature_map * 2, 4, 2, 1, bias=False)
        self.convTranspose4 = nn.ConvTranspose2d( feature_map * 2  , feature_map * 1, 4, 2, 1, bias=False)
        self.convTranspose5 = nn.ConvTranspose2d( feature_map * 1  , 3              , 4, 2, 1, bias=False)
 
        self.bnorm1 = nn.BatchNorm2d(feature_map * 8)
        self.bnorm2 = nn.BatchNorm2d(feature_map * 4)
        self.bnorm3 = nn.BatchNorm2d(feature_map * 2)
        self.bnorm4 = nn.BatchNorm2d(feature_map * 1)
 
    def forward(self, x):
        x = self.convTranspose1(x)
        x = F.relu(self.bnorm1(x))
 
        x = self.convTranspose2(x)
        x = F.relu(self.bnorm2(x))
 
        x = self.convTranspose3(x)
        x = F.relu(self.bnorm3(x))
 
        x = self.convTranspose4(x)
        x = F.relu(self.bnorm4(x))
 
        x = self.convTranspose5(x)
        x = torch.tanh(x)
        return x
 
# Create the generator
generator = Generator().to('cpu')
 
generator.load_state_dict(torch.load("/content/DCGAN_celeba_g.pth", map_location=torch.device('cpu')))
generator.eval()  # put the BatchNorm layers into inference mode

Next, we will use a helper function to generate N latent vectors of size 100.

# Helper function for creating random latent vectors
def generate_latent_vectors_z(N, random_state=42):
    torch.manual_seed(random_state)
    z = torch.randn(N, latent_space_size, 1, 1)
    return z

The images that our model generates will lie in the range [-1, 1] (the output of the tanh layer). Before displaying them, we need to un-normalize them back to [0, 1]. This is pretty straightforward in PyTorch. We take the mean and std that we used for the normalization process (0.5 and 0.5 per channel) and invert the transform: since x_norm = (x - mean) / std, the inverse is x = x_norm * std + mean, which is itself a Normalize transform with mean' = -mean/std and std' = 1/std, as shown in the code below.

inv_normalize = transforms.Normalize(
    mean=[-0.5/0.5, -0.5/0.5, -0.5/0.5],
    std=[1/0.5, 1/0.5, 1/0.5])
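As a quick sanity check (illustrative only, not part of the original pipeline), normalizing a tensor and then applying inv_normalize should recover the original values:

# Illustrative check: inv_normalize inverts the forward normalization
x = torch.rand(3, 8, 8)                                     # a fake 3-channel image in [0, 1]
x_norm = transforms.Normalize([0.5] * 3, [0.5] * 3)(x)      # maps [0, 1] to [-1, 1]
print(torch.allclose(inv_normalize(x_norm), x, atol=1e-6))  # expected: True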

Then, we will construct a helper function to plot our generated images onto a matplotlib figure.

def plot_generated(examples, n_rows, n_cols=None):
  # Plot a grid of generated images; defaults to a square n_rows x n_rows grid
  if n_cols is None:
    n_cols = n_rows
  plt.figure(figsize=(15, 8))
  for i in range(n_rows * n_cols):
    plt.subplot(n_rows, n_cols, 1 + i)
    plt.axis('off')

    img = inv_normalize(examples[i])
    plt.imshow(img.cpu().permute(1, 2, 0).detach().numpy())

In order to test our model, we will generate 100 images and plot them. From these 100 images, we will choose 3 images for each of 3 categories: smiling woman, neutral woman, and neutral man.

n = 100
latent_points = generate_latent_vectors_z(n)
with torch.no_grad():  # we only need forward passes for visualization
    X = generator(latent_points)
plot_generated(X, 10)

From the above results, we can choose the best candidates in the following way.

smiling_woman_idx = [79, 87, 88]
neutral_woman_idx = [58, 70, 89]
neutral_man_idx   = [20, 32, 39]

As mentioned earlier, we take 3 images from each of the 3 categories and compute the average latent vector for each category.

def average_points(points, idx):
  # Select the chosen latent vectors and flatten them to shape (len(idx), 100)
  vectors = points[idx]
  vectors = vectors.view(len(idx), 100)

  # Mean latent vector of the category
  avg_vector = torch.mean(vectors, axis=0)

  # Stack the originals and their average, restoring the (N, 100, 1, 1) shape
  all_vectors = torch.vstack([vectors, avg_vector])
  all_vectors = all_vectors.view(all_vectors.shape[0], 100, 1, 1)
  return all_vectors
 
# average vectors
smiling_woman = average_points(latent_points, smiling_woman_idx)
neutral_woman = average_points(latent_points, neutral_woman_idx)
neutral_man = average_points(latent_points, neutral_man_idx)
 
all_vectors = torch.vstack((smiling_woman, neutral_woman, neutral_man))
with torch.no_grad():
    images = generator(all_vectors)

If we plot our images now, we get the following output.
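With the plotting helper defined above, a grid like this can be produced along the following lines (a hypothetical layout of three rows of four: each category's three samples plus their average):

plot_generated(images, 3, 4)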

Notice that in the first row, we can see the images of women smiling. In the second row, we can see the images of women with neutral expressions, and in the third, we can see men with neutral expressions.

The last image in each of the three rows is generated from the average of the 3 latent vectors in that category.

Moving ahead, we can now perform the arithmetic operation that we have been talking about:

smiling_woman - neutral_woman + neutral_man = smiling_man

# The last ([-1]) entry of each stack is the category's average latent vector
result_vector = (smiling_woman[-1] - neutral_woman[-1]) + neutral_man[-1]
result_vector = result_vector.unsqueeze(0)

Once we obtain our final latent vector, we use it to generate a new image.

with torch.no_grad():
    image = generator(result_vector)
plot_generated(image, 1, 1)

Voila! Our arithmetic has done the trick! By combining existing latent vectors, we have obtained a new one that generates exactly the image we were looking for.

3. Interpolating Between Different Faces

Let’s take another example. Here, we will sample two random latent vectors, corresponding to two random faces. Then, we will generate the intermediate images that map from one latent vector to the other. Have a look at the image below.

This is fairly easy in PyTorch. We generate two different latent vectors using the helper function that we built.

latent_points = generate_latent_vectors_z(2)

We want to interpolate from one latent vector to the other in a number of steps. What does this mean?

  • In the beginning, we take 100% of the first latent vector and 0% of the second
  • In the 2nd step, we take 90% of the first and 10% of the second
  • In the 3rd step, we take 80% of the first and 20% of the second, and so on
  • In the final step, we take 0% of the first and 100% of the second

The helper function below implements this linear interpolation.

def interpolate_points(p1, p2, n_steps=10):
  ratios = np.linspace(0, 1, num=n_steps)
 
  p1 = p1.view(-1)
  p2 = p2.view(-1)
 
  vectors = []
  for ratio in ratios:
    v = (1.0 - ratio) * p1 + ratio * p2
    vectors.append(v.view(-1, 1, 1).numpy())
 
  return torch.from_numpy(np.array(vectors))

Now, we can pass our two latent vectors into this function and get a series of new latent vectors that we can use to generate faces.

interpolated = interpolate_points(latent_points[0], latent_points[1])
with torch.no_grad():
    X = generator(interpolated)

plot_generated(X, 1, len(interpolated))

Well, this was one way of interpolating the points. We can do this using a different technique as well.

Spherical Linear Interpolation

SLERP, or Spherical Linear Interpolation, is mathematically more involved, but it serves the same purpose. Instead of moving along a straight line, it interpolates along an arc on the hypersphere, which suits the Gaussian prior well, since high-dimensional Gaussian vectors concentrate near a sphere.
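For reference, given two latent vectors p1 and p2 and an interpolation ratio t between 0 and 1, SLERP computes:

slerp(t; p1, p2) = sin((1 - t) * ω) / sin(ω) * p1 + sin(t * ω) / sin(ω) * p2

where ω = arccos( p1 · p2 / (‖p1‖ ‖p2‖) ) is the angle between the two vectors. When the vectors are parallel, sin(ω) is zero and we fall back to plain linear interpolation, which is exactly what the function below does.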

def slerp(val, low, high):
  # Angle between the two vectors, clipped for numerical stability
  omega = np.arccos(np.clip(np.dot(low/np.linalg.norm(low), high/np.linalg.norm(high)), -1, 1))
  so = np.sin(omega)

  # If the vectors are (anti-)parallel, fall back to linear interpolation
  if so == 0:
    return (1.0-val) * low + val * high

  return np.sin((1.0-val)*omega) / so * low + np.sin(val*omega) / so * high

Next, we will use this slerp() function inside our interpolate_points() function.

def interpolate_points(p1, p2, n_steps=10):
  ratios = np.linspace(0, 1, num=n_steps)

  p1 = p1.view(-1)
  p2 = p2.view(-1)

  vectors = []
  for ratio in ratios:
    v = slerp(ratio, p1.numpy(), p2.numpy())
    vectors.append(v.reshape(-1, 1, 1))

  # Cast back to float32 so the input matches the generator's weights
  return torch.from_numpy(np.array(vectors)).float()

Same as before, we generate two latent vectors and apply our interpolate_points() function to them. This gives us 10 latent vectors, from which we generate 10 images. The output can be seen below the code.

n = 2
latent_points = generate_latent_vectors_z(n)

interpolated = interpolate_points(latent_points[0], latent_points[1])
with torch.no_grad():
    X = generator(interpolated)

plot_generated(X, 1, 10)

In the image above, you can see the 10 images that we generated from the 10 interpolated latent vectors using Spherical Linear Interpolation.

Summary

Wonderful! After learning about DCGANs and various model implementations on different datasets, and the important topic of Latent Spaces in GANs, we are now ready to move ahead and study the other types of Generative Adversarial Networks. In the next post, we are going to talk about Conditional Generative Adversarial Networks.