#022 PyTorch – DeepLab v2 Semantic Segmentation in PyTorch
Highlights: Semantic segmentation is an important subject in Computer Vision that enables a model to label specific regions of an image according to what is shown in them. DeepLab is a state-of-the-art semantic segmentation model by Google, and its many versions form a whole family of algorithms.
In this tutorial post, we will introduce the DeepLab algorithm and focus on DeepLab v2, which introduced three well-known advances in the field of semantic segmentation. Towards the end, we'll put this theory into practice with PyTorch code. So, let's begin!
Tutorial Overview:
- Introduction to DeepLab Algorithm
- Atrous (Dilated) Convolution
- Multi-Scale Image Representations using Atrous Spatial Pyramid Pooling
- Fully-Connected Conditional Random Fields for Accurate Boundary Recovery
- DeepLab v2: PyTorch Code
1. Introduction to DeepLab Algorithm
When DeepLab v2 was published in 2017, it achieved state-of-the-art results in semantic segmentation. The paper that proposed it presented novel solutions to the main challenges of the time.
One of these challenges arose when Deep Convolutional Neural Networks (DCNNs) were applied to semantic segmentation, where the goal is to predict a label map for the output image, that is, a class label for every pixel. This type of problem is known as a dense prediction task.
In contrast, classification tasks aim to recognize which object is present in the image. When working with classification tasks, our goal is to compress the input image (a large number of input pixels) into meaningful and interpretable features. As a result, the feature maps shrink considerably as they pass through the network; in other words, spatial resolution is lost. This downsampling happens mainly in the max-pooling (and strided convolution) layers.
Such a loss of spatial resolution is very undesirable for semantic segmentation, because we want an output of the same size as the input. Moreover, without accurate spatial resolution, precisely segmenting edges and contours is difficult. DeepLab offered successful solutions to the following challenges:
- Challenge 1: Reduced feature resolution, addressed with atrous (dilated) convolution.
- Challenge 2: Objects at multiple scales, addressed with Atrous Spatial Pyramid Pooling (ASPP).
- Challenge 3: Poor localization of object boundaries, refined with fully connected Conditional Random Fields (CRFs).
Accordingly, the model first applies atrous convolution, then uses spatial pyramids of atrous convolutions, and finally refines the output with fully connected CRFs to improve the segmentation boundaries.
Let us move ahead with further explanations on the building blocks of the DeepLab model.
2. Atrous (Dilated) Convolution
Have a look at the image below.
As we can see, our goal is to perform semantic segmentation. On the left, we have an input image defined by its width and height; it can be an RGB image with 3 channels. We would like to obtain an output of exactly the same size, in which every pixel is assigned a class label (e.g., dog or cat).
So, instead of RGB channels, the output has one channel per segmentation class. The problem with CNNs is that convolutional and max-pooling layers reduce the size of the feature maps and, with it, the spatial resolution. For instance, the feature maps may end up 32 times smaller than the input image in each spatial dimension. Although this is not a problem for classification tasks, it is a big challenge for semantic segmentation. In other words, we need a way to circumvent the downsampling introduced by max-pooling layers. But how?
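To make the 32× figure concrete, here is a minimal NumPy sketch of our own (not taken from the paper or from [1]). It assumes five stride-2 pooling stages, as in VGG-style backbones, and uses a hypothetical helper `max_pool_2x2`; each pass halves both spatial dimensions, so the total reduction is 2^5 = 32.

```python
import numpy as np

def max_pool_2x2(x):
    """2x2 max pooling with stride 2 (hypothetical helper; assumes even H and W)."""
    h, w = x.shape
    # Group pixels into non-overlapping 2x2 blocks and keep the maximum of each block
    return x.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))

feature_map = np.random.rand(224, 224)
sizes = [feature_map.shape]
for _ in range(5):  # five stride-2 stages
    feature_map = max_pool_2x2(feature_map)
    sizes.append(feature_map.shape)

print(sizes)
# [(224, 224), (112, 112), (56, 56), (28, 28), (14, 14), (7, 7)]
```

A 224×224 input ends up as a 7×7 map, far too coarse to draw accurate object boundaries.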
The DeepLab authors compared a classical convolutional block with an alternative approach, as shown in the image below. In the top block (blue arrows), the image is first downsampled by a factor of 2 (max pooling with stride = 2); we refer to this downsampled image as low resolution. Next, we apply a convolutional filter with kernel size = 7 and stride = 1. Finally, the processed image (feature map) is upsampled by a factor of 2, which gives the result in the upper right corner.
In the lower block (red arrows), we apply atrous (dilated) convolution instead. This time, we process the original, higher-resolution image directly. We use the same filter as above, with kernel size = 7 and stride = 1, but the filter is upsampled: its rate (also known as dilation) is set to 2, which means that zeros are inserted between neighboring filter coefficients. The result is a processed image (feature map) of the same size as the original image.
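The filter upsampling described above can be sketched in NumPy. The helper `dilate_kernel` below is our own illustrative function, not part of PyTorch: it inserts rate − 1 zeros between neighboring taps, so a k×k kernel grows to an effective size of k + (k − 1)(rate − 1), while the number of non-zero weights stays k².

```python
import numpy as np

def dilate_kernel(kernel, rate):
    """Insert (rate - 1) zeros between neighboring kernel taps (illustrative helper)."""
    k_h, k_w = kernel.shape
    eff_h = k_h + (k_h - 1) * (rate - 1)
    eff_w = k_w + (k_w - 1) * (rate - 1)
    dilated = np.zeros((eff_h, eff_w), dtype=kernel.dtype)
    # The original coefficients land on every rate-th row and column
    dilated[::rate, ::rate] = kernel
    return dilated

kernel = np.arange(1, 50, dtype=float).reshape(7, 7)  # a 7x7 filter with non-zero taps
dilated = dilate_kernel(kernel, rate=2)
print(dilated.shape)              # (13, 13)
print(np.count_nonzero(dilated))  # 49
```

Convolving with this 13×13 zero-filled kernel gives the same result as a 7×7 convolution with dilation 2; deep learning frameworks simply skip the multiplications by zero.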
The simple code below is adapted from [1], whose author explored the diagram above by replicating the experiment on standard images.
Let's start by importing the necessary libraries.
```python
import numpy as np
import torch
from torch.nn import Conv2d
from PIL import Image
import matplotlib.pyplot as plt
import requests
from scipy.ndimage import zoom as resize
```
```python
url_image = "https://raw.githubusercontent.com/maticvl/dataHacker/master/DATA/Frame1.jpg"
# Download the image, convert it to grayscale ("L") and resize it to 100x100
im = Image.open(requests.get(url_image, stream=True).raw).convert("L")
im = im.resize((100, 100))
plt.imshow(im, cmap="gray")
plt.show()
im = np.array(im)  # 100x100 uint8 array
```
The code above loads the image, converts it to grayscale, and resizes it to 100×100. Next, we define two PyTorch filters that match the diagram above: 1) a standard convolution, conv_1, and 2) an atrous (dilated) convolution, conv_2.
```python
# 1) Standard convolution: 7x7 filter, no dilation
kernel_size = 7
dilation = 1
stride = 1
padding = 3  # (kernel_size - 1) // 2 keeps the spatial size unchanged
conv_1 = Conv2d(in_channels=1, out_channels=1, kernel_size=kernel_size,
                stride=stride, dilation=dilation, padding=padding)

# 2) Atrous (dilated) convolution: the same 7x7 filter with rate (dilation) 2
kernel_size = 7
dilation = 2
stride = 1
padding = 6  # NOTE: dilation * (kernel_size - 1) // 2, larger than before
conv_2 = Conv2d(in_channels=1, out_channels=1, kernel_size=kernel_size,
                stride=stride, dilation=dilation, padding=padding)

# Use the same weights and bias in both layers
conv_2.load_state_dict(conv_1.state_dict())
```
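Why padding = 6 for the dilated filter? The PyTorch `Conv2d` documentation gives the output size as H_out = ⌊(H_in + 2·padding − dilation·(kernel_size − 1) − 1) / stride⌋ + 1. The small helper below is our own (not a PyTorch function) and simply evaluates that formula for both filters; at stride 1, choosing padding = dilation·(kernel_size − 1)/2 keeps the spatial size unchanged.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Spatial output size of a convolution, following the formula in the PyTorch Conv2d docs."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

# conv_1: standard 7x7 filter applied to the 50x50 downsampled image
print(conv_output_size(50, kernel_size=7, padding=3, dilation=1))   # 50
# conv_2: the same 7x7 filter with dilation 2 applied to the full 100x100 image
print(conv_output_size(100, kernel_size=7, padding=6, dilation=2))  # 100
# With padding=3, the dilated filter would shrink the output
print(conv_output_size(100, kernel_size=7, padding=3, dilation=2))  # 94
```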
Then, we process the input image with the following code and plot the images.
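```python
# Path 1: downsample -> standard convolution -> upsample
# 1) Downsample 100x100 -> 50x50
im1 = resize(im, (0.5, 0.5)).reshape(1, 1, 50, 50)  # B, C, H, W
# 2) Apply the standard convolution to the low-resolution image
input_image = torch.Tensor(im1)
output_image = conv_1(input_image).detach().numpy()[0, 0]
# 3) Upsample back to 100x100
upsampled = resize(output_image, (2, 2))

# Path 2: atrous convolution applied directly to the full-resolution image
im2 = im.reshape(1, 1, 100, 100)  # B, C, H, W
input_image = torch.Tensor(im2)
atrous_output = conv_2(input_image).detach().numpy()[0, 0]

# Plot both results side by side
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(atrous_output, cmap="gray")
axes[0].set_title("Atrous convolution")
axes[1].imshow(upsampled, cmap="gray")
axes[1].set_title("Downsample, convolve, upsample")
plt.show()
```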
Here, we show the original image (above) as well as the two output images after processing with convolutional filters (below). On the left, we see the output of an atrous convolution, whereas, on the right, we can see the output of the standard process with image downsampling, convolution and upsampling.
Observe that the atrous convolution produces much cleaner edges and lines. For example, in the image on the right, we can see lines perpendicular to the road markings that do not exist in the original image; they are introduced by the downsample, convolve, upsample path.
1D Atrous (Dilated) Convolution
To gain better insight into atrous convolution, we will also explore it in 1D. It is defined by the formula below, where \(w \) is a convolutional filter of length \(K \) and \(r \) is the rate. When \(r=1 \), the formula reduces to a standard 1D convolution.
$$ y[i]=\sum_{k=1}^{K} x[i+r \cdot k] w[k] $$
```python
import numpy as np
import matplotlib.pyplot as plt

# Noisy sinusoid as the input signal
x = np.random.randn(100) + np.sin(2 * np.pi * 0.1 * np.arange(100))
r = 2  # r = rate
K = 3  # length of a 1D mean (averaging) filter
w = np.ones(K) / K
stride = 1
```
Now, we will show two ways of implementing this simple formula. Feel free to experiment with the parameter \(r \) and observe the similarities.
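```python
# Version 1: implement y[i] = sum_k x[i*stride + r*k] * w[k] directly
# (Python indexes k from 0 to K-1, whereas the formula counts from 1 to K)
n_out = len(x) - r * (K - 1)  # positions where the dilated filter fits inside x (stride = 1)
y = []
for i in range(n_out):
    temp = 0
    for k in range(0, K):
        temp += x[i * stride + r * k] * w[k]
    y.append(temp)
y = np.array(y)
plt.plot(x)
plt.plot(y)
plt.show()
```

```python
# Version 2: atrous convolution as a standard convolution with a zero-filled filter
w2 = np.zeros(r * (K - 1) + 1)  # length of the dilated filter
w2[::r] = w                     # original taps on every r-th position, zeros in between
y_atrous = []
for i in range(n_out):
    temp = 0
    for k in range(len(w2)):
        temp += x[i + k] * w2[k]
    y_atrous.append(temp)
y_atrous = np.array(y_atrous)
print(np.allclose(y, y_atrous))  # True: both versions give the same output
```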
Notice that both versions generate the same output signal. The first version reads every \(r \)-th sample of \(x \), as if we had first downsampled the signal to a lower resolution. The second version keeps the signal at full resolution and instead stretches the filter by inserting \(r-1 \) zeros between neighboring coefficients.
The image below shows how 1D atrous convolution can be represented pictorially.
We have two options. At the top, a downsampled (low-resolution) signal is convolved with a kernel of size 3. In this way, we obtain so-called sparse features. The connecting lines can be viewed as multiplications with filter coefficients, so each output element is obtained from three multiplications and one summation.
At the bottom, on the other hand, we have an example of atrous (dilated) convolution. Here, the original signal keeps its high resolution: it has twice as many elements as the downsampled one. Instead of downsampling the signal, we upsample the filter and fill the new positions with zeros.
Let's interpret this further. The dilated filter spans \(r(K-1)+1 \) samples of \(x \) (5 samples for \(K=3 \) and \(r=2 \)), roughly twice as many as before. However, since every second filter coefficient is now zero, only \(K \) samples of the original high-resolution signal are actually used. This is why the image shows no connections where a multiplication by zero would occur.
As above, the connecting lines represent element-wise multiplications between input signal elements and filter coefficients. Note again that in both cases the filter has 3 non-zero coefficients, and the stride is 1! Don't forget that!
In the atrous approach, the rate is 2, the same factor we used to downsample the original signal in the top image. In the end, the second approach lets us process high-resolution feature maps/signals and extract dense features.
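The relationship between the two pictures can be checked numerically. In this NumPy sketch (the function name `atrous_conv1d` is our own), the dense output of an atrous convolution at rate r, taken at every r-th position, matches what a standard convolution produces on the r-times downsampled signal. The atrous version additionally computes all the in-between outputs, which is exactly what makes its features dense.

```python
import numpy as np

def atrous_conv1d(x, w, rate):
    """y[i] = sum_k x[i + rate*k] * w[k], keeping only fully valid positions."""
    K = len(w)
    n_out = len(x) - rate * (K - 1)
    return np.array([sum(x[i + rate * k] * w[k] for k in range(K)) for i in range(n_out)])

rng = np.random.default_rng(0)
x = rng.standard_normal(100)
w = np.ones(3) / 3  # kernel of size 3, as in the figure
r = 2

dense = atrous_conv1d(x, w, rate=r)        # atrous convolution on the high-resolution signal
sparse = atrous_conv1d(x[::r], w, rate=1)  # standard convolution on the downsampled signal

print(len(dense), len(sparse))          # 96 48
print(np.allclose(dense[::r], sparse))  # True
```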
3. Multi-Scale Image Representations using Atrous Spatial Pyramid Pooling
In this part, we will discuss the detection and segmentation of objects of different sizes. DCNNs already handle objects at different scales quite well, simply by being trained on datasets that contain objects of varying sizes, but further improvements are still possible. In particular, explicitly accounting for object scale can improve how well the network processes both large and small objects.
The approach adopted in this paper is fairly simple and builds upon the R-CNN spatial pyramid pooling method. More specifically, the authors apply several parallel atrous convolutions with different sampling rates to the same feature map. The features extracted at each rate are processed in separate branches and then fused to generate the final output. This method is called Atrous Spatial Pyramid Pooling (ASPP) and is shown in the image below.
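Here is a minimal 1D NumPy sketch of the ASPP idea. It assumes "same" zero padding and uses summation as the fusion step, matching the `_ASPP.forward` method shown later in this post; the helper names and the random kernels are illustrative only. With 3-tap kernels, the rates 6, 12, 18 and 24 give effective fields of view of 13, 25, 37 and 49 samples.

```python
import numpy as np

def atrous_conv1d_same(x, w, rate):
    """Atrous convolution with zero padding so that the output has the same length as x."""
    K = len(w)
    pad = rate * (K - 1) // 2
    x_pad = np.pad(x, pad)
    return np.array([sum(x_pad[i + rate * k] * w[k] for k in range(K)) for i in range(len(x))])

def aspp_1d(x, kernels, rates):
    """Parallel atrous branches, one per rate, fused by summation."""
    return sum(atrous_conv1d_same(x, w, r) for w, r in zip(kernels, rates))

rates = [6, 12, 18, 24]
rng = np.random.default_rng(0)
kernels = [rng.standard_normal(3) for _ in rates]  # one 3-tap kernel per branch
x = rng.standard_normal(128)

y = aspp_1d(x, kernels, rates)
print(y.shape)                                 # (128,)
print([3 + (3 - 1) * (r - 1) for r in rates])  # [13, 25, 37, 49]
```

Because every branch preserves the input length, the branch outputs can be added element by element; each branch "sees" the same position with a different amount of context.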
4. Fully-Connected Conditional Random Fields for Accurate Boundary Recovery
As we discussed at the beginning of the post, CNNs can be successfully applied to object classification tasks. In other words, we can easily and accurately predict the class label: is it a dog or a cat in that image? However, even when the class labels are correct, the segmentation output still suffers from inaccurate object borders. Have a look at the example below.
To improve this, the authors combined the DCNN with fully connected Conditional Random Fields (CRFs). This resulted in remarkably accurate localization of object boundaries and, consequently, accurate semantic segmentation results. Back in 2017, it recovered object boundaries at a level of detail well beyond existing methods. The fully connected CRF uses the following energy function:
$$ E(\boldsymbol{x})=\sum_{i} \theta_{i}\left(x_{i}\right)+\sum_{i j} \theta_{i j}\left(x_{i}, x_{j}\right) $$
Here, \(\boldsymbol{x} \) is the label assignment for pixels.
In addition, the researchers defined a unary potential \(\theta_{i}\left(x_{i}\right)=-\log P\left(x_{i}\right) \), where \(P\left(x_{i}\right) \) is the label assignment probability at pixel \(i \) as computed by the DCNN.
In other words, the first term is the negative log-probability of each pixel's label under the DCNN, the same quantity that the cross-entropy loss penalizes during training. The second term, on the other hand, models the interplay between all pixel pairs \((i, j) \).
The pairwise potential \(\theta_{ij}\left(x_{i}, x_{j}\right) \) is defined by the following formula:
$$ \begin{gathered}\theta_{i j}\left(x_{i}, x_{j}\right)=\mu\left(x_{i}, x_{j}\right)\left[w_{1} \exp \left(-\frac{\left\|p_{i}-p_{j}\right\|^{2}}{2 \sigma_{\alpha}^{2}}-\frac{\left\|I_{i}-I_{j}\right\|^{2}}{2 \sigma_{\beta}^{2}}\right)\right. \\\left.+w_{2} \exp \left(-\frac{\left\|p_{i}-p_{j}\right\|^{2}}{2 \sigma_{\gamma}^{2}}\right)\right]\end{gathered} $$
Here, \(\mu\left(x_{i}, x_{j}\right)=1 \) if \(x_{i} \neq x_{j} \), and zero otherwise.
This means that \(\mu \) equals 1 only when pixels \(i \) and \(j \) are assigned different labels (\(x_{i} \neq x_{j} \)). Therefore, only pairs of pixels with different labels are penalized.
The expression in brackets is a weighted sum of two Gaussian kernels.
The first kernel should look familiar: we encountered a similar expression when we explored the bilateral filter.
The weights \(w_{1} \) and \(w_{2} \), together with the three bandwidths \(\sigma_{\alpha} \), \(\sigma_{\beta} \) and \(\sigma_{\gamma} \), can be viewed as hyperparameters.
Here, \(p_{i} \) and \(p_{j} \) are the positions of the two pixels, so \(\left\|p_{i}-p_{j}\right\| \) is the distance between them; it is small for pixels that are close to each other. Similarly, \(I_{i} \) and \(I_{j} \) are the 3-dimensional RGB color vectors of the pixels, and \(\left\|I_{i}-I_{j}\right\| \) is the Euclidean distance between their colors.
The larger the color difference, the larger this distance. However, because the squared distances enter the exponents with a minus sign, each Gaussian is largest when the two pixels are close in position and (for the first kernel) similar in color. Therefore, nearby pixels with similar colors are heavily penalized if they are assigned different labels. The second Gaussian kernel repeats the position term alone: it depends only on the distance between pixels and controls the smoothness of the label assignment.
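To make the energy concrete, here is a naive NumPy sketch that evaluates \(E(\boldsymbol{x}) \) for a tiny image by looping over all pixel pairs. The function names and the hyperparameter values (w1, w2, s_alpha, s_beta, s_gamma) are hypothetical placeholders chosen for illustration, not the values tuned in the paper. We also assume that each unordered pixel pair is counted once in the pairwise sum, and the O(N²) loop is only practical for a handful of pixels.

```python
import numpy as np

def pairwise_potential(label_i, label_j, p_i, p_j, I_i, I_j,
                       w1=5.0, w2=3.0, s_alpha=10.0, s_beta=20.0, s_gamma=3.0):
    """theta_ij: Potts compatibility times an appearance + smoothness kernel (placeholder values)."""
    if label_i == label_j:  # mu(x_i, x_j) = 0 for identical labels
        return 0.0
    d_pos = np.sum((p_i - p_j) ** 2)  # squared distance between pixel positions
    d_col = np.sum((I_i - I_j) ** 2)  # squared distance between RGB colors
    appearance = w1 * np.exp(-d_pos / (2 * s_alpha**2) - d_col / (2 * s_beta**2))
    smoothness = w2 * np.exp(-d_pos / (2 * s_gamma**2))
    return appearance + smoothness

def crf_energy(labels, probs, image, **kernel_params):
    """E(x) = sum_i -log P(x_i) + sum over pixel pairs of theta_ij (each pair counted once)."""
    h, w = labels.shape
    n = h * w
    flat_labels = labels.ravel()
    positions = np.array([(r, c) for r in range(h) for c in range(w)], dtype=float)
    colors = image.reshape(n, 3).astype(float)
    flat_probs = probs.reshape(n, -1)
    # Unary term: negative log-probability of the chosen label at each pixel
    energy = -np.sum(np.log(flat_probs[np.arange(n), flat_labels]))
    # Pairwise term over all pixel pairs (fully connected)
    for i in range(n):
        for j in range(i + 1, n):
            energy += pairwise_potential(flat_labels[i], flat_labels[j], positions[i], positions[j],
                                         colors[i], colors[j], **kernel_params)
    return energy

image = np.zeros((2, 3, 3))      # uniform black 2x3 RGB image
probs = np.full((2, 3, 2), 0.5)  # the DCNN is undecided between 2 labels
smooth = np.zeros((2, 3), dtype=int)
noisy = smooth.copy()
noisy[0, 1] = 1                  # flip a single pixel
print(crf_energy(smooth, probs, image) < crf_energy(noisy, probs, image))  # True
```

Since the unary term is identical for both labelings here, the isolated flipped pixel is penalized purely by the pairwise term, which is how the CRF favors coherent regions.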
This covers the main idea of fully connected CRFs. In practice, this energy is minimized approximately with mean-field inference. Let us now see how optimizing this function improves the results. Check out the iterative process below. You'll be amazed by the results!
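For intuition about that iterative process, here is a naive mean-field sketch for the Potts model that we wrote for illustration. It is not the efficient high-dimensional filtering implementation used in practice: it builds the full pixel-to-pixel kernel matrix, so it only scales to toy images. The function name and the hyperparameter defaults are arbitrary placeholders. Because \(\mu \) is a Potts penalty, each update reduces to a softmax over \(\log P + KQ \), where \(K \) holds the pairwise kernel values.

```python
import numpy as np

def dense_crf_mean_field(probs, image, n_iters=5,
                         w1=5.0, w2=3.0, s_alpha=10.0, s_beta=20.0, s_gamma=3.0):
    """Naive mean-field inference for a fully connected CRF with a Potts compatibility."""
    h, w, n_labels = probs.shape
    n = h * w
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    pos = np.stack([rows.ravel(), cols.ravel()], axis=1).astype(float)
    col = image.reshape(n, 3).astype(float)
    # Squared distances between all pairs of pixel positions and colors
    d_pos = ((pos[:, None, :] - pos[None, :, :]) ** 2).sum(-1)
    d_col = ((col[:, None, :] - col[None, :, :]) ** 2).sum(-1)
    kernel = (w1 * np.exp(-d_pos / (2 * s_alpha**2) - d_col / (2 * s_beta**2))
              + w2 * np.exp(-d_pos / (2 * s_gamma**2)))
    np.fill_diagonal(kernel, 0.0)               # a pixel does not interact with itself
    log_p = np.log(probs.reshape(n, n_labels))  # equals minus the unary potential
    q = probs.reshape(n, n_labels).copy()       # initialize Q with the DCNN output
    for _ in range(n_iters):
        # Potts model: agreeing with the neighbors' current beliefs lowers the energy
        logits = log_p + kernel @ q
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        q = np.exp(logits)
        q /= q.sum(axis=1, keepdims=True)
    return q.reshape(h, w, n_labels)

image = np.zeros((1, 6, 3))              # a uniform 1x6 RGB strip
probs = np.tile([0.9, 0.1], (1, 6, 1))   # the DCNN favors label 0 everywhere...
probs[0, 2] = [0.4, 0.6]                 # ...except at one noisy pixel
refined = dense_crf_mean_field(probs, image)
print(probs.argmax(-1))    # [[0 0 1 0 0 0]]
print(refined.argmax(-1))  # [[0 0 0 0 0 0]]
```

The noisy pixel sits in a region of identical color, so its neighbors' confident beliefs outweigh its own weak preference and its label is corrected.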
5. DeepLab v2: PyTorch Code
The original paper explores a large number of architectures, based on both ResNet-101 and VGG-16. Here, we share the architecture as illustrated in [2], because we liked its text-based definition.
The most interesting part is the ASPP layer that can be represented as the following image:
Our PyTorch implementation is based on the work presented in [3]. That repository also includes the post-processing step with fully connected CRFs. Here, we will not explore the CRF implementation and instead present only the ResNet-based network.
The implementation consists of two important blocks. The first is a stack of (dilated) residual convolutional layers, which is rather straightforward. The second is the slightly more advanced ASPP module.
Below is the DeepLabV2 class; as we can see, it is built from ResNet blocks.
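```python
import torch.nn as nn

# NOTE: _Stem and _ResLayer (the ResNet stem and residual stages) come from [3] and are omitted here.


class DeepLabV2(nn.Sequential):
    """
    DeepLab v2: Dilated ResNet + ASPP
    Output stride is fixed at 8
    """

    def __init__(self, n_classes, n_blocks, atrous_rates):
        super(DeepLabV2, self).__init__()
        # Channel widths: [64, 128, 256, 512, 1024, 2048]
        ch = [64 * 2 ** p for p in range(6)]
        self.add_module("layer1", _Stem(ch[0]))
        # Last two arguments of _ResLayer in [3]: stride, dilation
        self.add_module("layer2", _ResLayer(n_blocks[0], ch[0], ch[2], 1, 1))
        self.add_module("layer3", _ResLayer(n_blocks[1], ch[2], ch[3], 2, 1))
        # Dilation replaces further striding, so the resolution stays at 1/8
        self.add_module("layer4", _ResLayer(n_blocks[2], ch[3], ch[4], 1, 2))
        self.add_module("layer5", _ResLayer(n_blocks[3], ch[4], ch[5], 1, 4))
        self.add_module("aspp", _ASPP(ch[5], n_classes, atrous_rates))
```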
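The ASPP block is implemented with parallel Conv2d layers, one per rate. Each rate is passed as the dilation parameter (with matching padding, so that every branch preserves the spatial size), and the branch outputs are summed.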
```python
class _ASPP(nn.Module):
    """
    Atrous spatial pyramid pooling (ASPP)
    """

    def __init__(self, in_ch, out_ch, rates):
        super(_ASPP, self).__init__()
        # One 3x3 atrous branch per rate; padding=rate keeps the spatial size unchanged
        for i, rate in enumerate(rates):
            self.add_module(
                "c{}".format(i),
                nn.Conv2d(in_ch, out_ch, 3, 1, padding=rate, dilation=rate, bias=True),
            )
        for m in self.children():
            nn.init.normal_(m.weight, mean=0, std=0.01)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Fuse the branches by summing their class score maps
        return sum([stage(x) for stage in self.children()])


...


if __name__ == "__main__":
    model = DeepLabV2(
        n_classes=21, n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24]
    )
```
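The `__main__` block above instantiates the model for 21 classes (the 20 PASCAL VOC classes plus background), with the ResNet-101 block configuration [3, 4, 23, 3] and ASPP rates 6, 12, 18 and 24.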
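Where does the fixed output stride of 8 come from? The short sketch below tallies the stride and dilation passed to each stage in `DeepLabV2.__init__`. It assumes that the stem downsamples by 4, as a standard ResNet stem does (a stride-2 7×7 convolution followed by stride-2 max pooling), and that the last two `_ResLayer` arguments are stride and dilation, as in [3]. The tally itself is our illustration, not code from [3].

```python
# (stage name, stride, dilation) as passed to _Stem / _ResLayer in DeepLabV2.__init__;
# the stem's total stride of 4 is an assumption based on the standard ResNet stem.
stages = [
    ("layer1 (stem)", 4, 1),
    ("layer2", 1, 1),
    ("layer3", 2, 1),
    ("layer4", 1, 2),
    ("layer5", 1, 4),
]

output_stride = 1
for name, stride, dilation in stages:
    output_stride *= stride
    print(f"{name:14s} cumulative stride = {output_stride:2d}, dilation = {dilation}")

print("output stride:", output_stride)  # output stride: 8
# Channel widths produced by ch = [64 * 2 ** p for p in range(6)]
print([64 * 2 ** p for p in range(6)])  # [64, 128, 256, 512, 1024, 2048]
```

A standard ResNet-101 would reach an output stride of 32. Replacing the stride-2 downsampling of the last two stages with dilations 2 and 4 keeps the feature maps 4 times larger in each dimension while preserving the receptive field of the pretrained filters.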
Finally, with these basic blocks, we hope that we managed to convey the message and ideas of the DeepLab v2 architecture. What is your next segmentation project? Do let us know!
Semantic Segmentation Using DeepLab with PyTorch
- DeepLab is a model proposed by Google to solve semantic segmentation problems
- DeepLab v2 was introduced in 2017 with significant improvements
- DeepLab was made to tackle the challenges of Deep Convolutional Neural Networks (DCNNs)
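- The first challenge, reduced feature resolution, is tackled using atrous convolution
- The second challenge, objects at multiple scales, is handled with Atrous Spatial Pyramid Pooling (ASPP)
- The third challenge, poor boundary localization, is handled using fully connected Conditional Random Fields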
Summary
That's all folks! This post presented an overview of DeepLab v2, and our goal was to present the main ideas of this model. We did not present the full code because, shortly after this version, the authors developed DeepLab v3, which is also available through the official PyTorch Hub [4]. We will explore this model in the next post. Let us know how you like our content. If you like it, you'll love what we're doing on Twitter and YouTube. Follow our channels and wait for us to show up next with another interesting topic. Take care 🙂
References:
- [1] https://github.com/Machine-Learning-Tokyo/__init__/blob/main/session_02/DeepLabv2.ipynb
- [2] https://gist.github.com/BassyKuo/7b3645c3744e3b833de49cb7fa460427
- [3] https://github.com/kazuto1011/deeplab-pytorch
- [4] https://colab.research.google.com/github/pytorch/pytorch.github.io/blob/master/assets/hub/pytorch_vision_deeplabv3_resnet101.ipynb