Deep Image Colorization
Mar. 20th, 2021 · 14 min read

Introduction

Although this was a school project, it was open-ended rather than a fixed assignment, so I am able to share much of the code here! Alongside this project page, you can also find more details in the report that was submitted for this work.

Model Overview

I opted to implement the model as detailed in Deep Koalarization: Image Colorization using CNNs and Inception-Resnet-v2. The main points of the network are as follows:

  • CNN encoder for initial feature extraction
  • Additional feature extractor using a pre-trained Inception-ResNet-V2
  • Fusion layer to combine both sets of features
  • Decoder to upsample and estimate the output from the fused features

Model Overview

I enjoy this network because it combines the benefits of transfer learning with an additional bare network for fine-tuning, since the domain transfer will never be perfect with a pre-trained encoder. The paper linked above goes into much more detail and was a great guide for my implementation; if you are interested, I advise you to read it.

Problem Statement

The problem of image colorization is quite an interesting one that simplifies into an elegant self-supervised problem. Given a grayscale image, can we estimate the additional color channels required to give a full-color image? This sounds difficult at first, but with some tweaks to the images it's quite easy to formulate as a simple problem.

In the RGB color space, it's hard to imagine how you might split the original image into a set of components, but if you convert the image to the LAB color space the problem can easily be self-supervised. In LAB, the L channel is the grayscale component and the AB channels encode the color information. Now if we convert our dataset images to LAB, we easily get our input L channel and our ground-truth AB channels! Therefore, the only dataset requirement is full-color images. With the problem simplified as such, we are looking to estimate a function f such that:

f : X_L → (X_a, X_b)

where f will be our large convolutional net.

Dataset

I opted to use the Places2 dataset, as scenery and typical views from the human perspective are well captured in it, which matches the context of image colorization. Since the dataset is very large and I had limited hardware to train on, I opted to subset it into smaller splits.

I first created folders to store the subsets.
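A minimal sketch of that setup might look like the following (the directory layout and split names are placeholders):

```python
import os

# Hypothetical directory layout for the three subsets
for split in ["train", "val", "test"]:
    os.makedirs(os.path.join("data", split), exist_ok=True)
```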

I then split images from the validation set into specified counts for each split.
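A sketch of that split, assuming the Places2 validation images sit in a single source folder (the folder name and counts below are placeholders):

```python
import os
import shutil

src_dir = "places2/val_256"                       # assumed source folder
split_counts = {"train": 8000, "val": 1000, "test": 1000}  # placeholder counts

files = sorted(os.listdir(src_dir))
start = 0
for split, count in split_counts.items():
    # Copy a contiguous block of images into each subset folder
    for fname in files[start:start + count]:
        shutil.copy(os.path.join(src_dir, fname), os.path.join("data", split, fname))
    start += count
```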

Now that I have folders containing my data, I needed to work on a PyTorch dataset for data loading. There are a few things to consider when creating the dataloader.

  • The input image is the grayscale channel (L channel in LAB space) so we'll need to retrieve that component after converting the RGB dataset images
  • The two other components (AB in LAB) are our ground truth output so we'll also retrieve those as such
  • The pre-trained Inception-ResNet requires a different input image size than the base encoder
    • We'll need to resize the dataset images to have each variant

All things considered, the following code achieves all of these goals and helps us bring all the necessary data into PyTorch.
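A sketch of such a Dataset, assuming a 224×224 input for the base encoder, a 299×299 input for the Inception branch, and normalization of L by 100 and AB by 128:

```python
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
from skimage.color import rgb2lab

class ColorizationDataset(Dataset):
    """Return the L channel at two sizes (base encoder and Inception input)
    plus the normalized AB channels as the ground truth."""

    def __init__(self, root, encoder_size=224, inception_size=299):
        self.paths = [os.path.join(root, f) for f in sorted(os.listdir(root))]
        self.encoder_size = encoder_size
        self.inception_size = inception_size

    def __len__(self):
        return len(self.paths)

    def _lab(self, img, size):
        # Resize, then convert RGB -> LAB (L in [0, 100], AB roughly in [-128, 127])
        img = img.resize((size, size), Image.BILINEAR)
        return rgb2lab(np.asarray(img))

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")

        lab = torch.from_numpy(self._lab(img, self.encoder_size)).permute(2, 0, 1).float()
        L = lab[:1] / 100.0      # grayscale input for the base encoder
        ab = lab[1:] / 128.0     # ground-truth color channels

        lab_inc = torch.from_numpy(self._lab(img, self.inception_size)).permute(2, 0, 1).float()
        # The pre-trained network expects 3 channels, so the grayscale plane is repeated
        L_inc = (lab_inc[:1] / 100.0).repeat(3, 1, 1)

        return L, L_inc, ab
```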

Then it's as simple as retrieving the relevant categories of images and passing them through the PyTorch DataLoader.
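For example (the folder paths and batch size are assumptions carried over from the sketches above):

```python
from torch.utils.data import DataLoader

train_loader = DataLoader(ColorizationDataset("data/train"), batch_size=32, shuffle=True)
val_loader = DataLoader(ColorizationDataset("data/val"), batch_size=32, shuffle=False)
test_loader = DataLoader(ColorizationDataset("data/test"), batch_size=32, shuffle=False)
```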

Now we have all of our required dataloaders to move forward with model construction.

Model Construction

Here I will discuss how I constructed the entire model as specified in the original paper.

Base Encoder

Here is where I construct the layers for the base encoder block. These layers are not pre-trained and help refine results specific to this dataset, as opposed to the dataset the pre-trained net was trained on. This is a simple CNN, so we can model it as a sequential stack wrapped in a Module-API class.
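A sketch of such an encoder (the channel counts and strides follow the Deep Koalarization paper; treat them as assumptions rather than the exact implementation):

```python
import torch.nn as nn

class Encoder(nn.Module):
    """Stack of 3x3 convolutions; the strided layers halve the spatial size three times."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.Conv2d(128, 128, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(),
            nn.Conv2d(256, 256, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(256, 512, 3, padding=1), nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1), nn.ReLU(),
            nn.Conv2d(512, 256, 3, padding=1), nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)
```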

Pre-Trained Encoder

The pre-trained encoder is quite simple: ImageNet-pre-trained models are readily available for PyTorch, so all we need to do is instantiate the network.
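As one possible sketch, a pre-trained Inception-ResNet-v2 can be pulled from the timm package and used purely as a frozen feature extractor (the use of timm and the freezing step are assumptions on my part, not necessarily how the original code did it):

```python
import timm

# num_classes=0 drops the classification head, leaving the pooled feature embedding
inception = timm.create_model("inception_resnet_v2", pretrained=True, num_classes=0)
inception.eval()
for p in inception.parameters():
    p.requires_grad = False  # keep the pre-trained weights frozen
```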

Fusion

The fusion layer combines the features from the base encoder and the pre-trained encoder into a single input for the decoder network.
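A sketch of the fusion layer, assuming a 256-channel encoder output and a 1536-dimensional Inception-ResNet-v2 embedding:

```python
import torch
import torch.nn as nn

class Fusion(nn.Module):
    """Tile the global embedding across every spatial location of the encoder
    features, concatenate along the channel axis, and mix with a 1x1 convolution."""

    def __init__(self, encoder_channels=256, embed_dim=1536, out_channels=256):
        super().__init__()
        self.conv = nn.Conv2d(encoder_channels + embed_dim, out_channels, kernel_size=1)
        self.relu = nn.ReLU()

    def forward(self, features, embedding):
        b, _, h, w = features.shape
        tiled = embedding.view(b, -1, 1, 1).expand(-1, -1, h, w)
        fused = torch.cat([features, tiled], dim=1)
        return self.relu(self.conv(fused))
```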

Decoder

The decoder is very similar to the encoder but in reverse: the number of channels decreases as the feature map sizes increase through the upsampling layers. This is where the fused features from the base encoder and the pre-trained encoder are used to estimate the AB channels.
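A sketch of the decoder under the same assumptions (channel counts follow the paper; the tanh output matches AB channels normalized to [-1, 1]):

```python
import torch.nn as nn

class Decoder(nn.Module):
    """Alternate convolutions with upsampling until the output matches the
    input resolution; the final layers predict the 2 AB channels."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding=1), nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 2, 3, padding=1), nn.Tanh(),  # AB in [-1, 1]
            nn.Upsample(scale_factor=2),
        )

    def forward(self, x):
        return self.net(x)
```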

Whole Network

Lastly, we just need to connect all the prior blocks into our whole network!
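Putting the pieces together, a sketch of the full module might look like this (it assumes the Encoder, Fusion, and Decoder classes sketched above, with the frozen feature extractor passed in):

```python
import torch
import torch.nn as nn

class ColorizationNet(nn.Module):
    """Wire the base encoder, frozen feature extractor, fusion, and decoder together."""

    def __init__(self, feature_extractor):
        super().__init__()
        self.encoder = Encoder()
        self.feature_extractor = feature_extractor  # e.g. the frozen Inception-ResNet-v2
        self.fusion = Fusion()
        self.decoder = Decoder()

    def forward(self, L, L_inception):
        features = self.encoder(L)
        with torch.no_grad():
            embedding = self.feature_extractor(L_inception)
        fused = self.fusion(features, embedding)
        return self.decoder(fused)
```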

Training

Here I will go over decisions made for training, results of training, and some issues I came across.

Optimizer

I chose to use the Adam optimizer with the following parameters:

Parameter        Value
Learning Rate    0.0012
Weight Decay     1e-6

In PyTorch that looks like this:
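A minimal sketch, assuming model is the full colorization network from above:

```python
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=0.0012, weight_decay=1e-6)
```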

Criterion

As we will be comparing the expected AB results with the ground truth we can use a simple MSE loss for this problem.
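In PyTorch that is simply:

```python
import torch.nn as nn

criterion = nn.MSELoss()
```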

Training Loop

Nothing too special here. Given the dataloaders we made earlier, it's easy to enumerate over the loader and proceed with moving data to the GPU, running the forward pass, computing the loss, and running an optimizer step. Sadly, given my hardware limitations, I was not able to compute the validation loss during training, which made backtracking to the best model impossible; I just had to assume training was at a reasonable point based on the training loss.
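A sketch of such a loop, using the pieces defined above (the epoch count and device handling are assumptions):

```python
import torch

num_epochs = 20  # placeholder budget
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    for L, L_inc, ab in train_loader:
        L, L_inc, ab = L.to(device), L_inc.to(device), ab.to(device)

        optimizer.zero_grad()
        pred_ab = model(L, L_inc)       # forward pass
        loss = criterion(pred_ab, ab)   # MSE against ground-truth AB
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f"epoch {epoch}: train loss {running_loss / len(train_loader):.4f}")
```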

Loss Results

The training loss was very interesting for this model; it had a shape I've never seen before. In hindsight, this suggests the learning rate was too large, but I didn't know better back then.

Training Loss

At the end of training, a checkpoint is saved for use in inference later on.
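For example (the filename is a placeholder):

```python
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}, "colorization_checkpoint.pt")
```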

And a little helper function to retrieve the model later.
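Something along these lines, assuming the checkpoint format sketched above:

```python
def load_model(path="colorization_checkpoint.pt"):
    model = ColorizationNet(inception)
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model
```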

After that we can test some images!
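A sketch of the inference step, undoing the LAB normalization and converting back to RGB (it assumes the same /100 and /128 scaling as the dataset sketch and CPU tensors):

```python
import torch
from skimage.color import lab2rgb

def colorize(model, L, L_inc):
    """Predict AB, undo normalization, and stitch the channels back into an RGB image."""
    with torch.no_grad():
        ab = model(L.unsqueeze(0), L_inc.unsqueeze(0)).squeeze(0)
    lab = torch.cat([L * 100.0, ab * 128.0], dim=0).permute(1, 2, 0).cpu().numpy()
    return lab2rgb(lab)
```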

Results

Below is a set of sample images from the test set that were passed through the model. The left column shows the input grayscale images, the center column the predicted images, and the right column the original images. As we can see, this hardly passes the human eye, but the performance isn't too bad.

Test Set Images

It's interesting to start characterizing the performance. For example, consider the first row of images of the building. The model does quite well on the sky and building color, as buildings don't come in too many colors, so it's easier to be confident about what the color should be. On the other hand, people's clothing varies greatly in color, so it's much more difficult to tell what color it should be.

Loss Problem

The MSE loss promotes predicting unsaturated colors for objects whose color is harder to determine. An intuitive way to think about this is to consider two objects that look identical in grayscale but vary in true color. The model could choose to be confident and pick a saturated color, but if it's wrong, it's VERY wrong. So to minimize MSE, the model is conservative and chooses median colors to represent ambiguous objects. We see this as a fairly vibrant sky, whereas the people are drab and gray.

Gray People

Another way to think about this is to consider the possible colors for a given object: objects whose colors are well-defined come out more vibrant, whereas ambiguous ones do not.

Object Colors

Either way, the solution is a loss that allows for many possibilities for a given object instead of dictating a single correct answer. Such losses exist and are detailed here, but I did not implement them.

Conclusion

As always, let me know if you have any questions or suggestions. I hope you enjoyed the post!