Human Pose Estimation with Stacked Hourglass Network and TensorFlow

For full source code, please go to https://github.com/ethanyanjiali/deep-vision/tree/master/Hourglass/tensorflow. I really appreciate your ⭐STAR⭐ that supports my efforts.

Humans are good at making different poses, and good at understanding these poses too. This makes body language such an essential part of our daily communication, work, and entertainment. Unfortunately, poses have so much variance that it's not an easy task for a computer to recognize a pose from a picture…until we have deep learning!

With a deep neural network, the computer can learn a generalized pattern of human poses and predict joint locations accordingly. The Stacked Hourglass Network is exactly that kind of network, and I'm going to show you how to use it to make a simple human pose estimator. Although first introduced in 2016, it's still one of the most important networks in the pose estimation area and is widely used in lots of applications. Whether you want to build software that tracks a basketball player's actions or make a body language classifier based on a person's pose, this should be a handy hands-on tutorial for you.

Network Architecture

Overview

To put it simply, the Stacked Hourglass Network (HG) is a stack of hourglass modules. It gets this name because the shape of each hourglass module closely resembles an hourglass, as we can see from the picture below:

The idea behind stacking multiple HG (Hourglass) modules, instead of forming one giant encoder and decoder network, is that each HG module produces a full heat-map for joint prediction. Thus, each subsequent HG module can learn from the joint predictions of the previous one.

Why does a heat-map help human pose estimation? This is a pretty common technique nowadays. Unlike facial keypoints, human pose data has lots of variance, which makes it hard to converge if we simply regress the joint coordinates. Smart researchers came up with the idea of using a heat-map to represent a joint location in an image. This preserves the location information, and then we just need to find the peak of the heat-map and use that as the joint location (plus some minor adjustment, since the heat-map is coarse). For a 256×256 input image, our heat-map will be 64×64.

In addition, we also calculate the loss for each intermediate prediction, which lets us supervise not only the final output but all HG modules effectively. This was a brilliant design at the time because pose estimation relies on the relationships among the different areas of the human body. For example, without seeing the location of the body, it's hard to tell whether an arm is the left or the right arm. By using a full prediction as the next module's input, we force the network to pay attention to other joints while predicting a new joint location.

Hourglass Module

So what does this HG (Hourglass) module itself look like? Let's take a look at another diagram from the original paper:

In the diagram, each box is a residual block plus some additional operations like pooling. If you are not familiar with the residual block and bottleneck structure, I'd recommend reading a ResNet article first. In general, an HG module is an encoder-decoder architecture, where we downsample the features first and then upsample them to recover the information and form a heat-map. Each encoder layer has a connection to its decoder counterpart, and we can stack as many layers as we want. In the implementation, we usually use recursion and let this HG module repeat itself.

I understand that it may still seem too "convoluted" here, so it might be easier just to read the code. Here's a piece of code copied from my Stacked Hourglass implementation in the GitHub deep-vision repo:

def HourglassModule(inputs, order, filters, num_residual):
    """
    One Hourglass Module. Usually we stack multiple of them together.
    https://github.com/princeton-vl/pose-hg-train/blob/master/src/models/hg.lua#L3

    inputs: Input feature tensor from the previous layer or module.
    order: The remaining order for HG modules to call itself recursively.
    filters: Number of filters used by the bottleneck blocks in this module.
    num_residual: Number of residual layers for this HG module.
    """
    # Upper branch
    up1 = BottleneckBlock(inputs, filters, downsample=False)

    for i in range(num_residual):
        up1 = BottleneckBlock(up1, filters, downsample=False)

    # Lower branch
    low1 = MaxPool2D(pool_size=2, strides=2)(inputs)
    for i in range(num_residual):
        low1 = BottleneckBlock(low1, filters, downsample=False)

    low2 = low1
    if order > 1:
        low2 = HourglassModule(low1, order - 1, filters, num_residual)
    else:
        for i in range(num_residual):
            low2 = BottleneckBlock(low2, filters, downsample=False)

    low3 = low2
    for i in range(num_residual):
        low3 = BottleneckBlock(low3, filters, downsample=False)

    up2 = UpSampling2D(size=2)(low3)

    return up2 + up1

This module looks like an onion, so let's start from the outermost layer first. up1 goes through a couple of bottleneck blocks and is added together with up2. This represents the two big boxes on the left and top, plus the right-most plus sign. This whole flow stays up in the air, so we call it the upper branch. There's also a lower branch: low1 goes through some pooling and bottleneck blocks, and then goes into another, smaller Hourglass module! In the diagram, that's the second layer of the big onion, and it's also why we use recursion here. We keep nesting HG modules until the fourth level, where we just have bottleneck blocks instead of another HG module. This innermost level is the three tiny boxes in the middle of the diagram.

If you are familiar with image classification networks, it's clear that the authors borrow the idea of skip connections very heavily. This repeating pattern connects the corresponding layers of the encoder and decoder, instead of having just one flow of features. It not only helps the gradients pass through but also lets the network consider features from different scales when decoding.

Intermediate Supervision

Now we have an Hourglass module, and we know that the whole network consists of multiple modules like this. But how exactly do we stack them together? Here comes the final piece of the network: intermediate supervision.

As you can see from the diagram above, when an HG module produces its output, we split it into two paths. The top path applies a few more convolutions to further process the features and then feeds them to the next HG module. The interesting thing happens on the bottom path. Here we use the output of that convolution layer as an intermediate heat-map result (blue box) and calculate the loss between this intermediate heat-map and the ground-truth heat-map. In other words, if we have 4 HG modules, we need to calculate four losses in total: 3 for the intermediate results and 1 for the final result.
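
To make this concrete, here is a rough sketch of how the HG modules can be chained with intermediate heat-map outputs. This is my own simplification, not the exact code from the repo: the 1×1 Conv2D branches stand in for the extra convolutions in the diagram, and it reuses the HourglassModule function from earlier.

from tensorflow.keras.layers import Conv2D

def StackedHourglassNetwork(inputs, num_stack=4, filters=256, num_heatmap=16):
    outputs = []
    x = inputs
    for i in range(num_stack):
        # one hourglass plus a 1x1 conv to process its features
        hg = HourglassModule(x, order=4, filters=filters, num_residual=1)
        hg = Conv2D(filters, 1, activation='relu')(hg)

        # intermediate heat-map prediction (the blue box in the diagram)
        heatmap = Conv2D(num_heatmap, 1)(hg)
        outputs.append(heatmap)

        if i < num_stack - 1:
            # re-inject both the features and the heat-map into the main branch
            x = x + Conv2D(filters, 1)(hg) + Conv2D(filters, 1)(heatmap)

    # the training loop then sums the loss over every element of `outputs`
    return outputs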

Prepare the Data

MPII Dataset

Once we've finished the code for the Stacked Hourglass network, it's time to think about what kind of data we'd like to use to train it. If you have your own dataset, that's great. But here I'd like to mention an open dataset for beginners who want something to train on first: the MPII dataset (from the Max Planck Institute for Informatics). You can find the download link here.

Although this dataset is mostly used for single-person pose estimation, it does provide joint annotations for multiple people in the same image. For each person, it gives the coordinates of 16 joints, such as the left ankle or the right shoulder.


However, the original dataset annotation is in Matlab format, which is really hard to use nowadays. An alternative is to use the preprocessed JSON-format annotations provided by Microsoft here. The Google Drive link is here. After you download this JSON annotation, you will see a list with elements like this:

{
    "joints_vis": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    "joints": [
        [804, 711], [816, 510], [908, 438], [1040, 454],
        [906, 528], [883, 707], [974, 446], [985, 253],
        [982.7591, 235.9694], [962.2409, 80.0306], [869, 214], [798, 340],
        [902, 253], [1067, 253], [1167, 353], [1142, 478]
    ],
    "image": "005808361.jpg",
    "scale": 4.718488,
    "center": [966, 340]
}

"joints_vis" indicates the visibility of each joint. In recent datasets, we usually need to differentiate occluded joints from visible joints, but in MPII we only care about whether the joint is in the view of the image: 1 -> in view, 0 -> out of view. "joints" is a list of joint coordinates, in the order: 0 - r ankle, 1 - r knee, 2 - r hip, 3 - l hip, 4 - l knee, 5 - l ankle, 6 - pelvis, 7 - thorax, 8 - upper neck, 9 - head top, 10 - r wrist, 11 - r elbow, 12 - r shoulder, 13 - l shoulder, 14 - l elbow, 15 - l wrist.

Cropping

The less clear parts are "scale" and "center". Sometimes we have more than one person in the image, so we need to crop out the one we are interested in. Unlike the MSCOCO dataset, MPII doesn't give us the bounding box of the person. Instead, it gives us a center coordinate and a rough scale of the person. Neither value is accurate, but together they still represent the general location of a person in the image. Note that you'll need to multiply "scale" by 200px to get the true height of the person. But how about the width? Unfortunately, the dataset doesn't really specify it. And the body may be aligned somewhat horizontally, which can make the width much larger than the height. One example I saw is a curling player crawling on the ground; if you only use the height to crop, you could end up leaving his arms out. After some experiments, here's my proposal for cropping the image:

# avoid invisible keypoints whose values are <= 0
masked_keypoint_x = tf.boolean_mask(keypoint_x, keypoint_x > 0)
masked_keypoint_y = tf.boolean_mask(keypoint_y, keypoint_y > 0)

# find left-most, top, bottom, and right-most keypoints
keypoint_xmin = tf.reduce_min(masked_keypoint_x)
keypoint_xmax = tf.reduce_max(masked_keypoint_x)
keypoint_ymin = tf.reduce_min(masked_keypoint_y)
keypoint_ymax = tf.reduce_max(masked_keypoint_y)

# add a padding according to human body height
xmin = keypoint_xmin - tf.cast(body_height * margin, dtype=tf.int32)
xmax = keypoint_xmax + tf.cast(body_height * margin, dtype=tf.int32)
ymin = keypoint_ymin - tf.cast(body_height * margin, dtype=tf.int32)
ymax = keypoint_ymax + tf.cast(body_height * margin, dtype=tf.int32)

# make sure the crop is valid
effective_xmin = tf.maximum(xmin, 0)
effective_ymin = tf.maximum(ymin, 0)
effective_xmax = tf.minimum(xmax, img_width)
effective_ymax = tf.minimum(ymax, img_height)
effective_height = effective_ymax - effective_ymin
effective_width = effective_xmax - effective_xmin

In short, we filter out invisible joints first and calculate the coordinates of the left-most, top-most, bottom-most, and right-most joints among the 16 joints. These four coordinates give us a region that includes at least all the available joint annotations. Then we pad this region by a proportion of the person's height, which is calculated from the "scale" field. Lastly, we make sure that this crop does not go outside the image border.

Gaussian

Another important thing to know about the ground-truth data is the Gaussian. When we create the ground-truth heat-map, we don't just assign 1 to the joint coordinate and 0 to all other pixels. That would make the ground truth too sparse to learn from. If the model prediction is just a few pixels off, we still want to give it some credit.

How do we model this encouragement in our loss function? If you took a probability class before, you might remember Gaussian distribution:

The center has the highest value, with gradually decreasing values in the area around it. This is exactly what we need. We draw such a Gaussian pattern onto our all-zero ground-truth canvas, like the first figure below. And when you combine all 16 joints in one heat-map, it looks like the second figure below.


In the code below, we calculate the size of the patch first: when sigma is 1, the size is 7 and the center is (3, 3). Then we generate a meshgrid to represent the coordinates of each cell in this patch. Finally, we substitute them into the Gaussian formula.

scale = 1
sigma = 1  # standard deviation of the Gaussian; sigma = 1 gives a 7x7 patch
size = 6 * sigma + 1
x, y = tf.meshgrid(tf.range(0, size, 1), tf.range(0, size, 1), indexing='xy')

# the center of the gaussian patch should be 1
center_x = size // 2
center_y = size // 2

# generate this 7x7 gaussian patch
gaussian_patch = tf.cast(tf.math.exp(-(tf.square(x - center_x) + tf.math.square(y - center_y)) / (tf.math.square(sigma) * 2)) * scale, dtype=tf.float32)

Note that the final code to generate a Gaussian is more complicated than this because it needs to handle some border cases. For full code, please take a look at my repo here: https://github.com/ethanyanjiali/deep-vision/blob/master/Hourglass/tensorflow/preprocess.py#L91

Loss Function

Up to this point, we've discussed the network architecture and the data to use. With those, we can run a forward pass on some training data to get a feel for the output. But modern deep learning is about back-propagation and gradient descent, which requires us to calculate the loss between the ground truth and the prediction. So let's get to it.

Fortunately, the loss function for Stacked Hourglass is pretty simple. You just take the Mean Squared Error between the two heat-maps, which can be done in one line of code (vanilla version). However, in practice I found it still kind of hard for the model to converge: it learned to cheat by predicting all zeros to reach a local optimum. My solution (improved version) is to assign a bigger weight to foreground pixels (the Gaussians we drew), making it harder for the network to simply ignore these non-zero values. I chose 82 here because there are roughly 82 times as many background pixels as foreground pixels for a 7×7 patch in a 64×64 heat-map.

# vanilla version
loss += tf.math.reduce_mean(tf.math.square(labels - output))
# improved version
weights = tf.cast(labels > 0, dtype=tf.float32) * 81 + 1
loss += tf.math.reduce_mean(tf.math.square(labels - output) * weights)

Predictions

So far, we've discussed the network, the data, and the optimization goal (loss). This should be sufficient for you to start your own training. But once we've finished training and obtained a model, we're still not done. One shortcoming of using heat-maps compared with regressing coordinates directly is granularity. For example, with a 256×256 input, we get a 64×64 heat-map to represent key-point locations. A four-times down-scale doesn't seem too bad. However, we usually first resize a bigger image, such as 720×480, into this 256×256 input. In that scenario, a 64×64 heat-map becomes quite coarse. To alleviate this problem, researchers came up with an interesting idea: instead of just using the pixel with the max value, we also take into consideration the neighboring pixel with the largest value. Since that neighbor also has a high value, it implies that the actual key-point location might lie a bit towards the direction of the neighbor pixel. Sounds familiar, right? It's pretty much like gradient descent, which also points towards the optimal solution.
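
Here's a minimal sketch of that refinement for a single joint, using a quarter-pixel offset (a common choice; the exact adjustment in my repo may differ):

import numpy as np

def refine_peak(heatmap, offset=0.25):
    # find the peak of a single 64x64 heat-map
    h, w = heatmap.shape
    y, x = np.unravel_index(np.argmax(heatmap), heatmap.shape)
    px, py = float(x), float(y)

    # nudge towards whichever horizontal / vertical neighbor is larger
    if 0 < x < w - 1:
        px += offset * np.sign(heatmap[y, x + 1] - heatmap[y, x - 1])
    if 0 < y < h - 1:
        py += offset * np.sign(heatmap[y + 1, x] - heatmap[y - 1, x])

    # scale from heat-map coordinates back to the 256x256 input
    return px * 4, py * 4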


Above is a prediction example from our network. The top one shows all the joint locations. The bottom one is the skeleton, drawn by linking those joints together. Although the result looks pretty decent, I have to admit that this is an easy example. In reality, there are lots of twisted poses and occluded joints, which bring significant challenges to our network. For example, when there is only one foot in the image, the network can confuse itself by assigning both the left and the right foot to the same location. How do we address this? Since this is more a topic about improvement, I'll leave it for you to think about first and write another article to discuss it in the future.

Conclusion

Congratulations, you've reached the end of this tutorial. If you understood everything we discussed above, you should now have a solid grasp of the theory and the major challenges. To start coding it in TensorFlow, I suggest you clone/fork my repo: https://github.com/ethanyanjiali/deep-vision/tree/master/Hourglass/tensorflow, follow the instructions to prepare the dataset, and give it a run. If you run into any problems, please open a GitHub issue so that I can take a look. And again, if you like my article or my repo, please ⭐star⭐ the repo; that will be the biggest support for me.

Dive Really Deep into YOLO v3: A Beginner’s Guide

For full source code, please go to https://github.com/ethanyanjiali/deep-vision/tree/master/YOLO/tensorflow. I really appreciate your STAR that supports my efforts.

Foreword

When a self-driving car runs on a road, how does it know where the other vehicles are in the camera image? When an AI radiologist reads an X-ray, how does it know where the lesion (abnormal tissue) is? Today, I will walk through a fascinating algorithm that can identify the category of objects in a given image and also locate the regions of interest. Plenty of algorithms have been introduced in recent years to address object detection with deep learning, such as R-CNN, Faster-RCNN, and the Single Shot Detector. Among those, I'm most interested in a model called YOLO – You Only Look Once. This model attracts me so much, not only because of its funny name but also because of some practical design choices that truly make sense to me. In 2018, the latest version of this model, V3, was released, and it achieved many new state-of-the-art results. Because I had programmed some GANs and image classification networks before, and because Joseph Redmon described it in the paper in a really easy-going way, I thought this detector would just be another stack of CNN and FC layers that magically works well.

But I was wrong.

Perhaps it's because I'm just dumber than the usual engineer, but I found it really difficult to translate this model from the paper into actual code. And even when I managed to do that after a couple of weeks (I gave up once and put it away for a few weeks), I found it even more difficult to make it work. There are quite a few blogs and GitHub repos about YOLO V3, but most of them just give a very high-level overview of the architecture, and somehow they just succeed. Even worse, the paper itself is so casual that it fails to provide many crucial implementation details, and I had to read the author's original C implementation (when was the last time I wrote C? Maybe in college?) to confirm some of my guesses. When there was a bug, I usually had no idea why it occurred. I ended up manually debugging it step by step and calculating those formulas with my little calculator.

Fortunately, I didn't give up this time and finally made it work. But in the meantime, I also felt strongly that there should be a more thorough guide out there on the internet to help dumb people like me understand every detail of this system. After all, if one single detail is wrong, the whole system goes south quickly. And I'm sure that if I don't write these down, I'll forget all of it in a few weeks too. So, here I am, presenting you this "Dive Really Deep into YOLO V3: A Beginner's Guide". I hope you'll like it.

Prerequisite

Before getting into the network itself, I need to clarify some prerequisites first. As a reader, you are expected to:

  1. Understand the basics of Convolutional Neural Networks and Deep Learning
  2. Understand the idea of the object detection task
  3. Have curiosity about how the algorithm works internally

If you need help with the first two items, there are plenty of excellent resources like the Udacity Computer Vision Nanodegree, the Coursera Deep Learning Specialization, and Stanford CS231N.
If you just want to quickly build something to detect objects in your custom dataset, check out the TensorFlow Object Detection API.

YOLO V3

YOLO V3 is an improvement over the previous YOLO detection networks. Compared to prior versions, it features multi-scale detection, a stronger feature extractor network, and some changes in the loss function. As a result, this network can detect many more targets, from big to small. And, of course, just like other single-shot detectors, YOLO V3 runs quite fast and makes real-time inference possible on GPU devices. As a beginner to object detection, you might not yet have a clear picture of what these mean, but you will gradually understand them later in this post. For now, just remember that YOLO V3 is one of the best models in terms of real-time object detection as of 2019.

Network Architecture

First of all, let's talk about what this network looks like at a high level (although the network architecture is the least time-consuming part of the implementation). The whole system can be divided into two major components: the feature extractor and the detector; both are multi-scale. When a new image comes in, it goes through the feature extractor first so that we can obtain feature embeddings at three (or more) different scales. Then, these features are fed into three (or more) branches of the detector to get bounding boxes and class information.

Darknet-53

The feature extractor YOLO V3 uses is called Darknet-53. You might be familiar with the previous Darknet-19 from YOLO V2, which has only 19 layers. But that was a few years ago, and image classification networks have progressed a lot beyond merely deep stacks of layers. ResNet brought the idea of skip connections to help activations propagate through deeper layers without vanishing gradients. Darknet-53 borrows this idea and successfully extends the network from 19 to 53 layers, as we can see from the following diagram.

This is very easy to understand. Consider the layers in each rectangle as a residual block. The whole network is a chain of multiple blocks, with some stride-2 Conv layers in between to reduce dimensions. Inside each block, there's just a bottleneck structure (1×1 followed by 3×3) plus a skip connection. If the goal were multi-class classification, as on ImageNet, an average pooling layer and a 1000-way fully connected layer plus softmax activation would be added.

However, in the case of object detection, we won't include this classification head. Instead, we are going to append a "detection" head to this feature extractor. And since YOLO V3 is designed to be a multi-scale detector, we also need features from multiple scales. Therefore, the features from the last three residual blocks are all used in the later detection. In the diagram below, I'm assuming the input is 416×416, so the three scale feature maps will be 52×52, 26×26, and 13×13. Please note that if the input size is different, the output sizes will differ too.

Multi-scale Detector

Once we have three feature maps, we can feed them into the detector. But how should we structure this detector? Unfortunately, the author didn't bother to explain this part in his paper. But we can still take a look at the source code he published on GitHub. From this config file, we can see that multiple 1×1 and 3×3 Conv layers are used before a final 1×1 Conv layer that forms the final output. For the medium and small scales, the detector also concatenates features from the previous scale. By doing so, small-scale detection can benefit from the results of large-scale detection.

Assuming the input image is (416, 416, 3), the final output of the detectors will have the shape [(52, 52, 3, (4 + 1 + num_classes)), (26, 26, 3, (4 + 1 + num_classes)), (13, 13, 3, (4 + 1 + num_classes))]. The three items in the list represent detections for the three scales. But what do the cells in this 52x52x3x(4+1+num_classes) matrix mean? Good question. This brings us to the most important notion in pre-2019 object detection algorithms: the anchor box (prior box).
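
Before digging into anchor boxes, here's a tiny sketch (my own illustration, with an assumed num_classes of 80) of how the raw conv output at one scale is usually reshaped so that each of the 3 boxes per cell gets its own (4 + 1 + num_classes) slice:

import tensorflow as tf

num_classes = 80  # assumption: MSCOCO-style classes
raw = tf.random.normal((1, 13, 13, 3 * (4 + 1 + num_classes)))  # output of the final 1x1 conv

# (batch, 13, 13, 3, 4 + 1 + num_classes): one slice per box
out = tf.reshape(raw, (-1, 13, 13, 3, 4 + 1 + num_classes))

# split into box offsets (tx, ty), (tw, th), objectness, and class scores
txty, twth, obj, classes = tf.split(out, (2, 2, 1, num_classes), axis=-1)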

Anchor Box

The goal of object detection is to get a bounding box and its class. A bounding box is usually represented in a normalized xmin, ymin, xmax, ymax format. For example, xmin = 0.5 and ymin = 0.5 mean the top-left corner of the box is at the center of the image. Intuitively, if we want to get a numeric value like 0.5, we are facing a regression problem. We might as well just have the network predict four values and use Mean Squared Error to compare them with the ground truth. However, due to the large variance of scale and aspect ratio of boxes, researchers found it really hard for the network to converge when using this "brute force" way to get a bounding box. Hence, in the Faster-RCNN paper, the idea of the anchor box was proposed.

An anchor box is a prior box that can have different pre-defined aspect ratios. These aspect ratios are determined before training by running K-means on the entire dataset. But where does the box anchor to? We need to introduce a new notion called the grid. In the "ancient" year of 2013, algorithms detected objects by sliding a window through the entire image and running image classification on each window. However, this is so inefficient that researchers proposed using a Conv net to compute the whole image all at once (technically, only when you run convolution kernels in parallel). Since the convolution outputs a square matrix of feature values (like 13×13, 26×26, and 52×52 in YOLO), we define this matrix as a "grid" and assign anchor boxes to each cell of the grid. In other words, anchor boxes anchor to the grid cells, and they share the same centroid. Once we've defined those anchors, we can determine how much the ground-truth box overlaps with each anchor box, pick the one with the best IOU, and couple them together. I guess you could also say that the ground-truth box anchors to this anchor box. In our later training, instead of predicting coordinates out of the wild west, we now predict offsets relative to these anchor boxes. This works because our ground-truth box should look like the anchor box we picked, and only a subtle adjustment is needed, which gives us a great head start in training.

In YOLO v3, we have three anchor boxes per grid cell, and we have three scales of grids. Therefore, we will have 52x52x3, 26x26x3, and 13x13x3 anchor boxes for the three scales respectively. For each anchor box, we need to predict 3 things:

  1. The location offset against the anchor box: tx, ty, tw, th. This has 4 values.
  2. The objectness score to indicate if this box contains an object. This has 1 value.
  3. The class probabilities to tell us which class this box belongs to. This has num_classes values.

In total, we are predicting 4 + 1 + num_classes values for one anchor box, and that's why our network outputs a matrix of shape 52x52x3x(4+1+num_classes), as I mentioned before. tx, ty, tw, th aren't the real coordinates of the bounding box; they are just relative offsets with respect to a particular anchor box. I'll explain these three predictions in more detail in the Loss Function section below.

The anchor box not only makes the detector implementation much harder and much more error-prone, but also introduces an extra step before training if you want the best result. So, personally, I hate it very much and feel like this anchor box idea is more of a hack than a real solution. In 2018 and 2019, researchers started to question the need for anchor boxes. Papers like CornerNet, Objects as Points, and FCOS all discuss the possibility of training an object detector from scratch without the help of anchor boxes.

Loss Function

With the final detection output, we can now calculate the loss against the ground-truth labels. The loss function consists of four parts (or five, if you split noobj and obj): centroid (xy) loss, width and height (wh) loss, objectness (obj and noobj) loss, and classification loss. Put together, the formula looks like this:

Loss = Lambda_Coord * Sum(Mean_Square_Error((tx, ty), (tx', ty')) * obj_mask)
    + Lambda_Coord * Sum(Mean_Square_Error((tw, th), (tw', th')) * obj_mask)
    + Sum(Binary_Cross_Entropy(obj, obj') * obj_mask)
    + Lambda_Noobj * Sum(Binary_Cross_Entropy(obj, obj') * (1 - obj_mask) * ignore_mask)
    + Sum(Binary_Cross_Entropy(class, class') * obj_mask)

It looks intimidating but let me break them down and explain one by one.

xy_loss =  Lambda_Coord * Sum(Mean_Square_Error((tx, ty), (tx', ty')) * obj_mask)

The first part is the loss for the bounding box centroid. tx and ty are the relative centroid locations from the ground truth. tx' and ty' are the centroid predictions from the detector directly. The smaller this loss is, the closer the centroids of the prediction and the ground truth are. Since this is a regression problem, we use mean squared error here. Besides, if there's no object in a certain cell according to the ground truth, we don't need to include the loss of that cell in the final loss. Therefore we also multiply by obj_mask here. obj_mask is either 1 or 0, indicating whether there's an object or not. In fact, we could just use obj as obj_mask, since obj is the objectness score that I will cover later. One thing to note is that we need to do some calculation on the ground truth to get tx and ty. So, let's see how to get this value first. As the author says in the paper:

bx = sigmoid(tx) + Cx
by = sigmoid(ty) + Cy

Here bx and by are the absolute values that we usually use as the centroid location. For example, bx = 0.5, by = 0.5 means that the centroid of this box is the center of the entire image. However, since we are going to compute the centroid off the anchor, our network actually predicts the centroid relative to the top-left corner of the grid cell. Why the grid cell? Because each anchor box is bound to a grid cell and they share the same centroid, so the offset from the grid cell can represent the offset from the anchor box. In the formula above, sigmoid(tx) and sigmoid(ty) are the centroid location relative to the grid cell. For instance, sigmoid(tx) = 0.5 and sigmoid(ty) = 0.5 mean the centroid is at the center of the current grid cell (but not of the entire image). Cx and Cy represent the absolute location of the top-left corner of the current grid cell. So if the grid cell is the one in the second row and second column of a 13×13 grid, then Cx = 1 and Cy = 1. If we add this grid cell location to the relative centroid location, we get the absolute centroid location bx = 0.5 + 1 and by = 0.5 + 1. Of course, the author doesn't bother to tell you that you also need to normalize this by dividing by the grid size, so the true bx would be 1.5/13 = 0.115. Now that we understand the formula above, we just need to invert it to get tx from bx so that we can translate our original ground truth into the target label. Lastly, Lambda_Coord is the weight that Joseph Redmon introduced in the YOLO v1 paper, to put more emphasis on localization instead of classification. The value he suggested is 5.

wh_loss = Lambda_Coord * Sum(Mean_Square_Error((tw, th), (tw', th')) * obj_mask)

The next one is the width and height loss. Again, the author says:

bw = exp(tw) * pw
bh = exp(th) * ph

Here bw and bh are still the absolute width and height relative to the whole image. pw and ph are the width and height of the prior box (a.k.a. anchor box; why are there so many names?). We take exp(tw) here because tw can be a negative number, but a width can't be negative in the real world, so exp() makes it positive. And we multiply by the prior box width pw and height ph because the prediction exp(tw) is relative to the anchor box, so the multiplication gives us the real width. The same goes for the height. Similarly, we can invert the formula above to translate bw and bh into tw and th when we calculate the loss.
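
To tie the two formulas together, here is a small sketch (my own, assuming the anchors are already normalized to the image size) of decoding the raw tx, ty, tw, th into bx, by, bw, bh at one scale:

import tensorflow as tf

def decode_boxes(txty, twth, anchors, grid_size):
    # txty, twth: (batch, grid, grid, 3, 2) raw outputs; anchors: (3, 2) normalized w/h
    grid = tf.meshgrid(tf.range(grid_size), tf.range(grid_size))   # Cx, Cy for every cell
    grid = tf.expand_dims(tf.stack(grid, axis=-1), axis=2)         # (grid, grid, 1, 2)
    grid = tf.cast(grid, tf.float32)

    # bx = (sigmoid(tx) + Cx) / grid_size, and likewise for by
    bxby = (tf.sigmoid(txty) + grid) / tf.cast(grid_size, tf.float32)
    # bw = exp(tw) * pw, bh = exp(th) * ph
    bwbh = tf.exp(twth) * anchors
    return bxby, bwbh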

obj_loss = Sum(Binary_Cross_Entropy(obj, obj') * obj_mask)
noobj_loss = Lambda_Noobj * Sum(Binary_Cross_Entropy(obj, obj') * (1 - obj_mask) * ignore_mask)

The third and fourth terms are the objectness and non-objectness score losses. Objectness indicates how likely there is an object in the current cell. Unlike YOLO v2, we use binary cross-entropy instead of mean squared error here. In the ground truth, objectness is always 1 for a cell that contains an object, and 0 for a cell that doesn't contain any object. By measuring this obj_loss, we gradually teach the network to detect a region of interest. In the meantime, we don't want the network to cheat by proposing objects everywhere. Hence, we need noobj_loss to penalize those false positive proposals. We get the false positives by masking the prediction with 1 - obj_mask. The ignore_mask is used to make sure we only penalize a prediction when the current box doesn't have much overlap with the ground-truth box. If it does, we tend to be softer because it's actually quite close to the answer. As the paper says, "If the bounding box prior is not the best but does overlap a ground truth object by more than some threshold we ignore the prediction." Since there are way more noobj cells than obj cells in our ground truth, we also need Lambda_Noobj = 0.5 to make sure the network won't be dominated by cells that don't have objects.
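
As a rough sketch of how these two terms might look in TensorFlow (my own simplified version; all tensors are assumed to have the shape (batch, grid, grid, 3, 1)):

import tensorflow as tf

def bce(y_true, y_pred, eps=1e-7):
    # element-wise binary cross-entropy, clipped to avoid log(0)
    y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)
    return -(y_true * tf.math.log(y_pred) + (1.0 - y_true) * tf.math.log(1.0 - y_pred))

def objectness_loss(obj, obj_pred, obj_mask, ignore_mask, lambda_noobj=0.5):
    obj_loss = tf.reduce_sum(bce(obj, obj_pred) * obj_mask)
    noobj_loss = lambda_noobj * tf.reduce_sum(bce(obj, obj_pred) * (1.0 - obj_mask) * ignore_mask)
    return obj_loss + noobj_loss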

class_loss = Sum(Binary_Cross_Entropy(class, class') * obj_mask)

The last loss is the classification loss. If there are 80 classes in total, class and class' will be one-hot encoded vectors with 80 values. In YOLO v3, this is changed to multi-label classification instead of multi-class classification. Why? Because some datasets may contain labels that are hierarchical or related, e.g. woman and person. So each output cell could have more than one true class. Correspondingly, we apply binary cross-entropy to each class one by one and sum them up, because they are not mutually exclusive. And as we did for the other losses, we also multiply by obj_mask so that we only count the cells that have a ground-truth object.

To fully understand how this loss works, I suggest you manually walk through it with a real network prediction and ground truth. Calculating the loss with your calculator (or tf.math) can really help you catch all the nitty-gritty details. I did that myself, and it helped me find lots of bugs. After all, the devil is in the details.

Implementation

If I stopped writing here, my post would just be like another "YOLO v3 Review" somewhere on the web. Once you've digested the general idea of YOLO v3 from the previous sections, you are ready to explore the remaining 90% of our YOLO v3 journey: the implementation.

Framework

At the end of September, Google finally released TensorFlow 2.0.0. This is a fascinating milestone for TF. Nevertheless, a new design doesn't necessarily mean less pain for developers. I've been playing around with TF 2 since very early 2019 because I always wanted to write TensorFlow code the way I did for PyTorch. If it weren't for TensorFlow's powerful production suite, such as TF Serving, TF Lite, and TensorBoard, I guess many developers wouldn't choose TF for new projects. Hence, if you don't have a strong need for production deployment, I would suggest implementing YOLO v3 in PyTorch or even MXNet. However, if you've made up your mind to stick with TensorFlow, please continue reading.

TensorFlow 2 officially made eager mode a first-class citizen. To put it simply, instead of using TensorFlow-specific APIs to calculate in a graph, you can now leverage native Python code to run the graph in a dynamic mode: no more graph compilation, and much easier debugging and control flow. For cases where performance matters more, a handy @tf.function decorator is also provided to help compile the code into a static graph. But, in reality, eager mode and tf.function are still sometimes buggy or poorly documented, which makes your life even harder in a complicated system like YOLO v3. Also, the Keras model isn't very flexible, while the custom training loop is still quite experimental. Therefore, the best strategy for writing YOLO v3 in TF 2 is to start with a minimal working template first and gradually add more logic to this shell. By doing so, we can fail early and fix bugs before they hide too deeply in a giant nested graph.

Dataset

Aside from the choice of framework, the most important thing for successful training is the dataset. In the paper, the author used the MSCOCO dataset to validate his idea. Indeed, this is a great dataset, and we should aim for good accuracy on this benchmark for our model. However, a big dataset like this can also hide bugs in your code. For example, if the loss isn't dropping, how do you know whether it just needs more time to converge or your loss function is wrong? Even with 8x V100 GPUs, training is still not fast enough to quickly iterate and fix things. Therefore, I recommend building a development set of tens of images to make sure your code looks "working" first. Another option is to use the VOC 2007 dataset, which has only 2500 training images. To use the MSCOCO or VOC2007 dataset and create TF Records, you can refer to my helper scripts here: MSCOCO, VOC2007.

Preprocessing

Preprocessing refers to the operations that translate raw data into the proper input format for the network. For an image classification task, we usually just resize the image and one-hot encode the label. But things are a bit more complicated for YOLO v3. Remember I said the output of the network is something like 52x52x3x(4+1+num_classes) and has three different scales? Since we need to calculate the delta between ground truth and prediction, we also need to format our ground truth into such matrices first.

For each ground-truth bounding box, we need to pick the best scale and anchor for it. For example, a tiny kite in the sky should go in the small-object scale (52×52). And if the kite looks more like a square in the image, we should also pick the most square-shaped anchor at that scale. In YOLO v3, the author provides 9 anchors for the 3 scales. All we need to do is choose the one that matches our ground-truth box best. When I implemented this, I thought I needed the coordinates of the anchor box to calculate the IOU. In fact, you don't. Since we just want to know which anchor fits our ground-truth box best, we can assume all anchors and the ground-truth box share the same centroid. Under this assumption, the degree of matching is the overlapping area, which can be calculated as min width * min height.
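
Here's a small sketch of that matching trick, assuming both the ground-truth box and the 9 anchors are given as normalized (width, height) pairs:

import tensorflow as tf

def best_anchor(box_wh, anchors):
    # box_wh: (2,), anchors: (9, 2); centroids are assumed to coincide
    intersection = tf.minimum(box_wh[0], anchors[:, 0]) * tf.minimum(box_wh[1], anchors[:, 1])
    union = box_wh[0] * box_wh[1] + anchors[:, 0] * anchors[:, 1] - intersection
    iou = intersection / union
    return tf.argmax(iou)  # 0-8: which scale and which anchor within that scale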

During the transformation, one can also add some data augmentation to virtually increase the variety of the training set. For example, typical augmentations include random flipping, random cropping, and random translation. However, these augmentations won't block you from training a working detector, so I won't cover this advanced topic here.

Training

After all these discussions, you finally have a chance to run python train.py and start your model training. This is also when you'll meet most of your bugs. You can refer to my training script here when you are blocked. Meanwhile, I want to share some tips that were helpful for my own training.

NaN Loss

  1. Check your learning rate and make sure it’s not too high to explode your gradient.
  2. Check for 0 in binary cross-entropy because ln(0) is not a number. You can clip the value to (epsilon, 1 - epsilon).
  3. Find an example and walk through your loss step by step. Find out which part of your loss goes to NaN. For example, if width/height loss went to NaN, it could be because the way you calculate from tw to bw is wrong.

Loss remains high

  1. Try to increase your learning rate to see if the loss can drop faster. Mine starts at 0.01, but I've seen 1e-4 and 1e-5 work too.
  2. Visualize your preprocessed ground truth to see if it makes sense. One problem I had before was that my output grid was in [y][x] order, but my ground truth was reversed, in [x][y].
  3. Again, manually walk through your loss with a real example. I made the mistake of calculating cross-entropy between objectness and class probabilities.
  4. My loss also stayed around 40 after 50 epochs on MSCOCO; however, the result wasn't that bad.
  5. Double-check the coordinate format throughout your code. YOLO requires xywh (centroid x, centroid y, width and height), but most datasets come as x1y1x2y2 (xmin, ymin, xmax, ymax).
  6. Double-check your network architecture. Don’t get misled by the diagram from a post called “A Closer Look at YOLOv3 – CyberAILab”.
  7. tf.keras.losses.binary_crossentropy isn’t the sum of binary cross-entropy you need.

Loss is low, but the prediction is off

  1. Adjust lambda_coord or lambda_noobj based on your observations.
  2. If you are training on your own dataset and the dataset is relatively small (<30k images), you should initialize weights from a COCO-pretrained model first.
  3. Double-check your non max suppression code and adjust some threshold (I’ll talk about NMS later).
  4. Make sure your obj_mask in the loss function isn’t mistakenly taking out necessary elements.
  5. Again and again, your loss function. When calculating loss, it uses relative xywh in a cell (also called tx, ty, tw, th). When calculating ignore mask and IOU, it uses absolute xywh in the whole image, though. Don’t mix them up.

Loss is low, but there's no prediction

  1. If you are using a custom dataset, please check the distribution of your ground-truth boxes first. The amount and quality of the boxes can really affect what the network learns (or cheats) to do.
  2. Predict on your training set to see if your model can overfit on the training set at least.

Multi-GPU training

Since an object detection network has so many parameters to train, it's always better to have more computing power. For example, on the MSCOCO 2017 dataset, 8x Tesla V100 GPUs can train a whole epoch in 10 minutes, while 1x V100 would need more than an hour.

However, TensorFlow 2.0 doesn't have great support for multi-GPU training so far. To do it in TF, you'll need to pick a training strategy such as MirroredStrategy, as I did here, and then wrap your dataset loader into a distributed version too. One caveat for distributed training is that the loss coming out of each batch should be divided by the global batch size, because we are going to reduce_sum over all GPU results. For example, if the local batch size is 8 and there are 8 GPUs, your batch loss should be divided by the global batch size of 64. Once you sum up the losses from all replicas, the final result will be the average loss of a single example.
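
As an illustration (a minimal sketch, assuming a local batch size of 8 and a per-example loss that has already been summed over the spatial dimensions), this division can be done with tf.nn.compute_average_loss:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 8 * strategy.num_replicas_in_sync

def scaled_loss(per_example_loss):
    # per_example_loss: shape (local_batch,); dividing by the GLOBAL batch size
    # makes the cross-replica sum equal to the true mean loss
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)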

Postprocessing

The final component in this detection system is the post-processor. Usually, postprocessing only involves trivial things like replacing a machine-readable class ID with human-readable class text. In object detection, though, we have one more crucial step to get final human-readable results. This is called non maximum suppression.

Let's recall our objectness loss. When a false proposal has a large overlap with the ground truth, we don't penalize it with noobj_loss. This encourages the network to predict close results so that we can train it more easily. Also, although not used in YOLO, when the sliding window approach is used, multiple windows can predict the same object. In order to eliminate these duplicate results, smart researchers designed an algorithm called non maximum suppression (NMS).

https://medium.com/analytics-vidhya/yolo-v3-theory-explained-33100f6d193
(Credit to Analytics Vidhya)

The idea of NMS is quite simple. Find the detection box with the best confidence first, add it to the final results, and then eliminate all other boxes that have an IOU over a certain threshold with this best box. Next, choose the box with the best confidence among the remaining boxes and do the same thing over and over until nothing is left. In the code, since TensorFlow needs explicit shapes most of the time, we usually define a max number of detections and stop early if that number is reached. In YOLO v3, our classification is no longer mutually exclusive, and one detection can have more than one true class. However, some existing NMS code doesn't take that into consideration, so be careful when you use it.
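
If you'd rather not write NMS by hand, TensorFlow ships a batched, class-wise version that handles the multi-label case. Here's a rough sketch of how it can be called; the tensor shapes and threshold values are just assumptions for illustration:

import tensorflow as tf

num_classes = 80
boxes = tf.random.uniform((1, 10647, 1, 4))          # (batch, num_boxes, 1, 4) in y1, x1, y2, x2
scores = tf.random.uniform((1, 10647, num_classes))  # objectness * class probability

nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = tf.image.combined_non_max_suppression(
    boxes,
    scores,
    max_output_size_per_class=100,
    max_total_size=100,
    iou_threshold=0.5,
    score_threshold=0.3)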

Conclusion

YOLO v3 is a masterpiece in the rising era of artificial intelligence, and also an excellent summary of convolutional neural network techniques and tricks of the 2010s. Although there are many turn-key solutions like Detectron out there to simplify the process of making a detector, hands-on experience coding such a sophisticated detector is a great learning opportunity for machine learning engineers, because merely reading the paper is far from enough. As Ray Dalio said about his philosophy, "Pain plus reflection equals progress." I hope my article can be a lighthouse in your painful journey of implementing YOLO v3, and perhaps you can share your delightful progress with us later.

References

Gender Swap and CycleGAN in TensorFlow 2.0

To view the full neural network model and training scripts, please visit my Github repo here:
https://github.com/ethanyanjiali/deep-vision/tree/master/CycleGAN/tensorflow

Background

Recently, the gender swap lens from Snapchat has become very popular on the internet. There have been many buzzwords around Generative Adversarial Networks since 2016, but this is the first time that ordinary people get to experience the power of GANs. What's more extraordinary about this lens is its great real-time performance, which makes it feel just like looking into a magic mirror. Although we can't know the exact algorithm behind this viral lens, it's most likely a CycleGAN, which was introduced in 2017 by Jun-Yan Zhu, Taesung Park, Phillip Isola, and Alexei Efros in this paper. In this article, I'm going to show you how to implement a gender swap effect with TensorFlow 2.0, just like Snapchat does.

GAN

First of all, I want to quickly go over the basics of the Generative Adversarial Network (GAN) to help readers who are not familiar with it. In some scenarios, we want to generate an image that belongs to a particular domain. For example, we'd like to draw a random interior design photo. To ask the computer to generate such an image, we need a mathematical representation of the interior design domain space. Assume there's a function F and a random input x. We want y = F(x) to always be very close to our target domain Y. However, this target domain lives in a very high dimensional space, so no human being can figure out explicit rules to define it. A GAN is a network that, by playing a minimax game between two agents, can eventually find an approximate representation F of our target domain Y.

So how does a GAN accomplish this? The trick is to break the problem down into two parts: 1. we need a generator to keep making new images out of some random numbers; 2. we need a discriminator to give the generator feedback about how good the generated image is. The generator here is just like a young artist who has no idea how to paint but wants to fake some masterpieces, and the discriminator is a judge who can tell what's wrong with a new painting. The judge doesn't need to know how to paint himself. However, as long as he's good at telling the difference between a good painting and a bad one, our young painter can surely benefit from his feedback. So we use deep learning to build a good judge and use it to train a good painter at the same time.

To train a good judge, we feed both authentic images and generated images to the discriminator. Since we know beforehand which is authentic and which is fake, the discriminator can update its weights by comparing its decision with the truth. The generator, in turn, looks at the decision from the discriminator. If the discriminator seems more agreeable with the fake image, it indicates that the generator is heading in the right direction, and vice versa. The tricky part is that we can't just train a judge without a painter, or a painter without a judge. They learn from each other and try to beat each other by playing this minimax game. If we train the discriminator too much without training the generator, the discriminator will become too dominant, and the generator won't ever have a chance to improve because every move is a losing move.

Eventually, both the generator and the discriminator will be really good at their jobs. By then, we can take the generator out to perform the task independently. However, what on earth does this GAN have to do with gender swap? In fact, it's a more straightforward problem than the one I just mentioned above. Now, instead of generating a random image of a target domain, we can skip the random number step and use the given image as input. Let's say we want to convert a male face to a female face. We are looking for a function F that, given a male face x, can output a value y that's very close to the real female version y_0 of that face.

The GAN approach sounds clear. But we haven't discussed one crucial caveat in the above approach. To train the discriminator, we need both true images and false images. We get the false image y from the generator, but where do we get the true image for that specific male face? We can't just use a random female face here, because we want to preserve some common traits when we swap the gender, and a face from a different person would ruin that. But it's also really hard to get paired training data. You could go out to find real twin brothers and sisters and take pictures of them, or you could ask a professional make-up artist to 'turn' a man into a woman. Both are very expensive. So is there a way for the model to learn the most important facial differences between man and woman from an unpaired dataset?

CycleGAN

Fortunately, scientists discovered ways to utilize unpaired training data. One of the most famous models is called CycleGAN. The main idea behind CycleGAN is that, instead of using paired data to train the discriminator, we form a cycle in which two generators work together to convert the image back and forth. More specifically, generator A2B first generates a domain B image from a domain A input, and then generator B2A uses that as input to generate another image back in domain A. We then set a goal to make sure the second image (the reconstructed image) looks as close as possible to the original input. For example, if generator A2B first converts a horse to a zebra, and generator B2A then converts that zebra back to a horse, the newly generated horse should look identical to the very original horse. In this way, the generators learn not to make trivial changes but only the critical differences between the two domains; otherwise, they probably won't be able to convert the image back. With this goal set up, we can now use unpaired images as training data.

CycleGAN

In reality, we need two cycles here. Since we train generator A2B and generator B2A together, we have to make sure both generators are improving over time; otherwise, they will still have trouble reconstructing a good image. Moreover, as we discussed above, improving a generator means we need to improve the corresponding discriminator in the meantime. In the cycle A2B2A (A -> B -> A), discriminator B judges whether the intermediate generated image looks like domain B, while the cycle-consistency loss checks that the reconstruction lands back in domain A. Likewise, we also need a cycle B2A2B so that discriminator A gets trained as well. If both discriminator A and discriminator B are well trained, it means our generators A2B and B2A can improve too!

There's another great article here about CycleGAN for further reading. Now that you have the main idea of this network, let's dive deep into some details.

Optimizer

It's recommended here that Adam is the best optimizer for GAN training. Although I don't know the reason behind this, the linear learning rate decay from the original paper looks quite effective during training: the learning rate remains at 0.0002 for the first 100 epochs, then linearly decays to 0 over the next 100 epochs. Here total_batches is the number of mini-batches per epoch, because our learning rate scheduler treats each mini-batch as a step.

gen_lr_scheduler = LinearDecay(LEARNING_RATE, EPOCHS * total_batches, DECAY_EPOCHS * total_batches)
dis_lr_scheduler = LinearDecay(LEARNING_RATE, EPOCHS * total_batches, DECAY_EPOCHS * total_batches)
optimizer_gen = tf.keras.optimizers.Adam(gen_lr_scheduler, BETA_1)
optimizer_dis = tf.keras.optimizers.Adam(dis_lr_scheduler, BETA_1)
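
LinearDecay is not a built-in Keras schedule. Below is a minimal sketch of how such a scheduler could be written (my own version, assuming the third argument is the number of steps over which to decay): it keeps the learning rate constant at first, then drops it linearly to zero over the last decay_steps steps.

import tensorflow as tf

class LinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, initial_lr, total_steps, decay_steps):
        super().__init__()
        self.initial_lr = initial_lr
        self.total_steps = total_steps
        self.decay_start = total_steps - decay_steps  # decay begins here

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        # factor is > 1 before decay starts, then falls linearly to 0 at total_steps
        factor = (self.total_steps - step) / float(self.total_steps - self.decay_start)
        return self.initial_lr * tf.clip_by_value(factor, 0.0, 1.0)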

Generator

Network Structure

CycleGAN uses a fairly standard generator structure. It first encodes the input image into a feature matrix by applying 2D convolutions, which extracts valuable feature information, both local and global.

Then, six or nine ResNet blocks are used to transform the features from the encoder into features of the target domain. As we know, the skip connection in a ResNet block helps gradients flow through deep stacks of layers, which makes sure the deeper layers can still learn something. If you are not familiar with ResNet, please refer to this paper.

Finally, a few deconvolution layers are used as the decoder. The decoder converts the features of the target domain into an actual image of the target domain by upsampling.

Unlike the VGG and Inception networks, it's recommended for a GAN to use a larger convolution kernel size like 7×7 so that it can pick up broader information instead of just focusing on details. It makes sense, because when we reconstruct an image, not only the details matter but also the overall pattern. Also, reflection padding is used here to improve the quality around the image border.

def make_generator_model(n_blocks):
    # 6 residual blocks
    # c7s1-64,d128,d256,R256,R256,R256,R256,R256,R256,u128,u64,c7s1-3
    # 9 residual blocks
    # c7s1-64,d128,d256,R256,R256,R256,R256,R256,R256,R256,R256,R256,u128,u64,c7s1-3
    model = tf.keras.Sequential()

    # Encoding
    model.add(ReflectionPad2d(3, input_shape=(256, 256, 3)))
    model.add(tf.keras.layers.Conv2D(64, (7, 7), strides=(1, 1), padding='valid', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.ReLU())

    model.add(tf.keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.ReLU())

    model.add(tf.keras.layers.Conv2D(256, (3, 3), strides=(2, 2), padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.ReLU())

    # Transformation
    for i in range(n_blocks):
        model.add(ResNetBlock(256))

    # Decoding
    model.add(tf.keras.layers.Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.ReLU())

    model.add(tf.keras.layers.Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.ReLU())

    model.add(ReflectionPad2d(3))
    model.add(tf.keras.layers.Conv2D(3, (7, 7), strides=(1, 1), padding='valid', activation='tanh'))

    return model
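
ReflectionPad2d above is a custom layer (Keras only ships ZeroPadding2D out of the box). Here is a minimal sketch of how such a layer can be implemented by wrapping tf.pad; the real layer in my repo may differ slightly:

import tensorflow as tf

class ReflectionPad2d(tf.keras.layers.Layer):
    def __init__(self, padding, **kwargs):
        super().__init__(**kwargs)
        self.padding = padding

    def call(self, inputs):
        p = self.padding
        # pad only the spatial (height, width) dimensions by reflecting the border pixels
        return tf.pad(inputs, [[0, 0], [p, p], [p, p], [0, 0]], mode='REFLECT')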

Loss Functions

There’re three types of loss we care about here:

  • To calculate the GAN loss, we measure the L2 distance (MSE) between the discriminator's prediction and its real/fake target (all ones for real images, all zeros for generated ones).
  • To calculate the cyclic loss, we measure the L1 distance (MAE) between the reconstructed image from the cycle and the truth image
  • To calculate the identity loss, we measure the L1 distance (MAE) between the identity image and the truth image

GAN loss is the typical loss we use in GANs, and I won't discuss it much here. The interesting parts are the cyclic loss and the identity loss. The cyclic loss measures how good the reconstructed image is, which helps both generators catch the essential style differences between the two domains. The identity loss is optional, but it helps keep the generator from making unnecessary changes. The way it works is that applying generator A2B to a real B image shouldn't change anything, since the input is already the desired outcome. According to the author, this mitigates some weird issues like background color changes.

def calc_gan_loss(prediction, is_real):
    # Typical GAN loss to set objectives for generator and discriminator
    if is_real:
        return mse_loss(prediction, tf.ones_like(prediction))
    else:
        return mse_loss(prediction, tf.zeros_like(prediction))

def calc_cycle_loss(reconstructed_images, real_images):
    # Cycle loss to make sure reconstructed image looks real
    return mae_loss(reconstructed_images, real_images)

def calc_identity_loss(identity_images, real_images):
    # Identity loss to make sure generator won't do unnecessary change
    # Ideally, feeding a real image to generator should generate itself
    return mae_loss(identity_images, real_images)

To combine all the losses, we also need to assign a weight to each one to indicate its importance. In the paper, the authors proposed two lambda parameters, which scale the cycle loss by 10 and the identity loss by 5. Note the usage of GradientTape here: we record the gradients and apply gradient descent to both generators together. Here, real_a is the truth image from domain A and real_b is the truth image from domain B; fake_a2b is the image generated from domain A to domain B, and fake_b2a is the image generated from domain B to domain A.

@tf.function
def train_generator(images_a, images_b):
    real_a = images_a
    real_b = images_b
    with tf.GradientTape() as tape:
        # Feeding a real image from the target domain to its generator should return it unchanged
        identity_a2b = generator_a2b(real_b, training=True)
        identity_b2a = generator_b2a(real_a, training=True)
        loss_identity_a2b = calc_identity_loss(identity_a2b, real_b)
        loss_identity_b2a = calc_identity_loss(identity_b2a, real_a)

        # Translate each image to the other domain, then translate it back to close the cycle
        fake_a2b = generator_a2b(real_a, training=True)
        fake_b2a = generator_b2a(real_b, training=True)
        recon_b2a = generator_b2a(fake_a2b, training=True)
        recon_a2b = generator_a2b(fake_b2a, training=True)

        # Generator A2B tries to trick Discriminator B that the generated image is B
        loss_gan_gen_a2b = calc_gan_loss(discriminator_b(fake_a2b, training=True), True)
        # Generator B2A tries to trick Discriminator A that the generated image is A
        loss_gan_gen_b2a = calc_gan_loss(discriminator_a(fake_b2a, training=True), True)
        # Reconstructed images should match the original real images
        loss_cycle_a2b2a = calc_cycle_loss(recon_b2a, real_a)
        loss_cycle_b2a2b = calc_cycle_loss(recon_a2b, real_b)

        # Total generator loss
        loss_gen_total = loss_gan_gen_a2b + loss_gan_gen_b2a \
            + (loss_cycle_a2b2a + loss_cycle_b2a2b) * 10 \
            + (loss_identity_a2b + loss_identity_b2a) * 5

    trainable_variables = generator_a2b.trainable_variables + generator_b2a.trainable_variables
    gradient_gen = tape.gradient(loss_gen_total, trainable_variables)
    optimizer_gen.apply_gradients(zip(gradient_gen, trainable_variables))

    # Return the generated images for the discriminator step, plus losses for logging
    gen_loss_dict = {
        'gen_a2b': loss_gan_gen_a2b,
        'gen_b2a': loss_gan_gen_b2a,
        'cycle_a2b2a': loss_cycle_a2b2a,
        'cycle_b2a2b': loss_cycle_b2a2b,
        'total': loss_gen_total,
    }
    return fake_a2b, fake_b2a, gen_loss_dict

Discriminator

Network Structure

Similar to other GANs, the discriminator consists of several 2D convolution layers that extract features from the generated image. However, to help the generator produce sharper, higher-quality images, CycleGAN uses a technique called PatchGAN, which creates a fine-grained decision matrix instead of a single decision value. Each value in this 32×32 decision matrix maps to a patch of the input image and indicates how real that patch is.

In fact, we don’t explicitly crop patches of the input image in the implementation. The convolution layers do the job for us: each value of the final output has a limited receptive field over the input, so it effectively behaves as if it were judging one patch (the quick check after the code below confirms the output shape).

def make_discriminator_model():
    # C64-C128-C256-C512
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Conv2D(64, (4, 4), strides=(2, 2), padding='same', input_shape=(256, 256, 3)))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))

    model.add(tf.keras.layers.Conv2D(128, (4, 4), strides=(2, 2), padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))

    model.add(tf.keras.layers.Conv2D(256, (4, 4), strides=(2, 2), padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))

    model.add(tf.keras.layers.Conv2D(512, (4, 4), strides=(1, 1), padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))

    # This last conv net is the PatchGAN
    # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/39#issuecomment-305575964
    # https://github.com/phillipi/pix2pix/blob/master/scripts/receptive_field_sizes.m
    model.add(tf.keras.layers.Conv2D(1, (4, 4), strides=(1, 1), padding='same'))

    return model
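As a quick sanity check (my own snippet, not part of the training script), running a random 256×256 input through this discriminator should indeed produce the 32×32 decision matrix described above, one realness score per patch:

import tensorflow as tf

discriminator = make_discriminator_model()
patch_scores = discriminator(tf.random.normal([1, 256, 256, 3]))
print(patch_scores.shape)  # (1, 32, 32, 1): one decision per image patch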

Loss Function

The loss functions for the discriminators are much more straightforward. Just like in typical GANs, we tell each discriminator to treat the truth image as real and the generated image as fake. So we have two losses for each discriminator, loss_real and loss_fake, and both contribute equally to the final loss. Note that in calc_gan_loss we are comparing two matrices: usually the output of a discriminator is just one value between 0 and 1, but since we use PatchGAN as mentioned above, the discriminator produces one decision per patch, which forms a 32×32 decision matrix.

@tf.function
def train_discriminator(images_a, images_b, fake_a2b, fake_b2a):
    real_a = images_a
    real_b = images_b
    with tf.GradientTape() as tape:

        # Discriminator A should classify real_a as A
        loss_gan_dis_a_real = calc_gan_loss(discriminator_a(real_a, training=True), True)
        # Discriminator A should classify generated fake_b2a as not A
        loss_gan_dis_a_fake = calc_gan_loss(discriminator_a(fake_b2a, training=True), False)

        # Discriminator B should classify real_b as B
        loss_gan_dis_b_real = calc_gan_loss(discriminator_b(real_b, training=True), True)
        # Discriminator B should classify generated fake_a2b as not B
        loss_gan_dis_b_fake = calc_gan_loss(discriminator_b(fake_a2b, training=True), False)

        # Total discriminator loss
        loss_dis_a = (loss_gan_dis_a_real + loss_gan_dis_a_fake) * 0.5
        loss_dis_b = (loss_gan_dis_b_real + loss_gan_dis_b_fake) * 0.5
        loss_dis_total = loss_dis_a + loss_dis_b

    trainable_variables = discriminator_a.trainable_variables + discriminator_b.trainable_variables
    gradient_dis = tape.gradient(loss_dis_total, trainable_variables)
    optimizer_dis.apply_gradients(zip(gradient_dis, trainable_variables))

    # Return discriminator losses for logging
    return {'dis_a': loss_dis_a, 'dis_b': loss_dis_b, 'total': loss_dis_total}

Training

Now that we have defined both models and their loss functions, we can put them together and start training. By default, eager mode is enabled in TensorFlow 2.0, so we don’t have to build the graph ourselves. However, if you look carefully, you might have noticed that both the discriminator and generator training functions are decorated with a tf.function decorator. This is the new mechanism introduced in TensorFlow 2.0 to replace the old tf.Session(). With this decorator, all operations inside the function are converted into a graph, so the performance can be much better than the default eager mode. To learn more about tf.function, please refer to this article.
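As a tiny illustration, separate from the training code, a tf.function-decorated function is traced into a graph on its first call, and that graph is reused on later calls with the same input signature:

import tensorflow as tf

@tf.function
def scaled_sum(x, y):
    # Traced into a TensorFlow graph the first time it is called
    return tf.reduce_sum(x) + 2.0 * tf.reduce_sum(y)

print(scaled_sum(tf.ones([2, 2]), tf.ones([3])))  # tf.Tensor(10.0, shape=(), dtype=float32)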

One thing to mention is that, instead of feeding the generated image to the discriminator directly, we actually use an image pool here. Each time, the image pool randomly decides whether to give the discriminator the newly generated image or a generated image from past steps. The benefit is that the discriminator learns from a wider range of samples and keeps a sort of memory of the tricks the generator has used. Unfortunately, we can’t run this random image pool inside the graph at the moment, so the pool query happens on the CPU, outside the tf.function-decorated training steps. This does introduce some overhead.
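The fake_pool_a2b and fake_pool_b2a objects used in train_step below are instances of such an image pool. Here is a minimal sketch of how one could be implemented; it handles one generated batch at a time, and the version in my repo may differ in details:

import random

class ImagePool:
    # Buffer of previously generated images. With 50% probability, query()
    # returns an old image from the buffer and stores the new one in its place,
    # so the discriminator also gets to see the generator's past outputs.
    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, image):
        if len(self.images) < self.pool_size:
            self.images.append(image)
            return image
        if random.random() > 0.5:
            index = random.randrange(self.pool_size)
            old_image = self.images[index]
            self.images[index] = image
            return old_image
        return image

Because the pool relies on Python-level randomness and list mutation, it runs eagerly outside the tf.function-decorated training steps, which is the extra cost mentioned above.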

The model illustrated in this article was trained on my own GTX 1080 home machine, so it’s a bit slow. On an instance with a V100 16GB GPU and 64GB RAM, though, you should be able to set the mini-batch size to 4, and the trainer can process one epoch of 260 mini-batches in about 3 minutes for the horse2zebra dataset. So it takes roughly 10 hours to fully train a horse2zebra model. If you reduce the image resolution and shrink some network parameters accordingly, the training could be faster. The final generators are about 44MB each.

def train_step(images_a, images_b, epoch, step):
    fake_a2b, fake_b2a, gen_loss_dict = train_generator(images_a, images_b)

    fake_b2a_from_pool = fake_pool_b2a.query(fake_b2a)
    fake_a2b_from_pool = fake_pool_a2b.query(fake_a2b)

    dis_loss_dict = train_discriminator(images_a, images_b, fake_a2b_from_pool, fake_b2a_from_pool)

def train(dataset, epochs):
    for epoch in range(checkpoint.epoch+1, epochs+1):
        for (step, batch) in enumerate(dataset):
            train_step(batch[0], batch[1], epoch, step)

To see the full training script, please visit my repo here.

Results

Let’s see some inference results on a few datasets. Among these, horse2zebra and monet2photo are the original datasets from the paper, and the CelebA dataset is from here.

horse2zebra

Horse -> Zebra
[sample translated image pairs omitted]

Zebra -> Horse
[sample translated image pairs omitted]

monet2photo

Monet -> Photo
[sample translated image pairs omitted]

Photo -> Monet
[sample translated image pairs omitted]

CelebA

Male -> Female
[sample translated image pairs omitted]

Female -> Male
[sample translated image pairs omitted]

Gender Swap

We successfully mapped a male face to a female face, but to use this in a production environment, we need a pipeline to orchestrate many other steps. Snapchat likely has its own optimizations and models, but here’s the procedure that I think would help improve the final result of our CycleGAN (a rough code sketch follows the list):

  1. Run face detection to find a bounding box and keypoints for the most dominant face in the picture.
  2. Extend the bounding box a little bit bigger to match the training dataset distribution.
  3. Crop the picture with this extended bounding box and run CycleGAN over it.
  4. Patch the generated image back into the original picture.
  5. Overlay some hair, eyeliner, and beard assets on top of the new face picture, based on the keypoints from the detection step.
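Below is a rough sketch of how such a pipeline could be wired together. Everything here is an assumption for illustration: detect_face and overlay_assets are hypothetical placeholders for whatever face-detection and compositing tools you plug in, and the box math is simplified.

import tensorflow as tf

def gender_swap(picture, detect_face, generator_a2b, overlay_assets, margin=0.2):
    # picture: HxWx3 uint8 NumPy array; detect_face and overlay_assets are
    # hypothetical callables standing in for your own detection/compositing code
    (top, left, height, width), keypoints = detect_face(picture)

    # Step 2: extend the bounding box so the crop resembles the training distribution
    top = max(0, int(top - height * margin))
    left = max(0, int(left - width * margin))
    bottom = min(picture.shape[0], int(top + height * (1 + 2 * margin)))
    right = min(picture.shape[1], int(left + width * (1 + 2 * margin)))

    # Step 3: crop, resize to the generator input size, and scale to [-1, 1] for tanh
    face = tf.image.resize(picture[top:bottom, left:right], (256, 256)) / 127.5 - 1.0
    swapped = generator_a2b(face[tf.newaxis, ...], training=False)[0]

    # Step 4: map the output back to [0, 255] and paste it over the original crop area
    swapped = tf.image.resize((swapped + 1.0) * 127.5, (bottom - top, right - left))
    result = picture.copy()
    result[top:bottom, left:right] = tf.cast(swapped, tf.uint8).numpy()

    # Step 5: overlay hair / eyeliner / beard assets guided by the facial keypoints
    return overlay_assets(result, keypoints)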

I got my inspiration mostly from this great article, which explains how such a pipeline works in detail.

Questions

Lastly, I want to throw out some questions I have. I don’t know the answers to them, but I hope those who have experience building similar products can share their opinions in the comments below.

  • The CycleGAN generator turns out to be 44MB; with quantization it could shrink to about 12MB, but that’s still too large. What are effective methods to make models like this usable on mobile and embedded devices?
  • The output image resolution isn’t great and loses much of the sharpness. How can we generate a bigger image, such as 1024×1024, without blowing up the model size? Would a super-resolution model help in this case?
  • How do we know whether a model is thoroughly trained and has converged? The loss isn’t a good metric here, and we also don’t know what the “best” output looks like. How do we measure the similarity between the two styles?

References

  1. Understanding and Implementing CycleGAN in TensorFlow
  2. Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
  3. Image-to-Image Translation with Conditional Adversarial Networks
  4. Gender and Race Change on Your Selfie with Neural Nets