Human Pose Estimation with Stacked Hourglass Network and TensorFlow

For full source code, please go to my deep-vision repo on GitHub. I really appreciate your ⭐STAR⭐ that supports my efforts.

Humans are good at making different poses, and good at understanding them too. This makes body language such an essential part of our daily communication, work, and entertainment. Unfortunately, poses have so much variance that it's not an easy task for a computer to recognize a pose from a picture…until we have deep learning!

With a deep neural network, the computer can learn a generalized pattern of human poses and predict joint locations accordingly. The Stacked Hourglass Network is just such a network, and I'm going to show you how to use it to make a simple human pose estimation. Although first introduced in 2016, it's still one of the most important networks in the pose estimation field and is widely used in lots of applications. Whether you want to build software to track a basketball player's actions, or make a body language classifier based on a person's pose, this should be a handy hands-on tutorial for you.

Network Architecture


To put it simply, the Stacked Hourglass Network (HG) is a stack of hourglass modules. It got this name because the shape of each hourglass module closely resembles an hourglass, as we can see from the picture below:

The idea behind stacking multiple HG (Hourglass) modules, instead of forming one giant encoder-decoder network, is that each HG module produces a full heat-map for joint prediction. Thus, each subsequent HG module can learn from the joint predictions of the previous one.

Why would a heat-map help human pose estimation? This is a pretty common technique nowadays. Unlike facial keypoints, human pose data has lots of variance, which makes it hard for training to converge if we simply regress the joint coordinates. Smart researchers came up with the idea of using a heat-map to represent a joint location in an image. This preserves the location information, and then we just need to find the peak of the heat-map and use that as the joint location (plus some minor adjustment, since the heat-map is coarse). For a 256×256 input image, our heat-map will be 64×64.
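To make this concrete, here is a minimal sketch (not the repo's actual post-processing code) of recovering a joint coordinate from a 64×64 heat-map by taking its peak and scaling back up; the function name and the stride of 4 are assumptions based on the 256×256-to-64×64 ratio described above:

```python
import tensorflow as tf

def heatmap_to_coord(heatmap, stride=4):
    """Recover a joint location from a single 64x64 heat-map.

    heatmap: 2D tensor (64, 64) for one joint.
    stride: input-to-heat-map down-scale factor (256 / 64 = 4).
    """
    # flatten, find the index of the peak, then convert back to (x, y)
    flat_idx = tf.argmax(tf.reshape(heatmap, [-1]))
    y = flat_idx // heatmap.shape[1]
    x = flat_idx % heatmap.shape[1]
    # scale coarse heat-map coordinates back to input-image coordinates
    return int(x) * stride, int(y) * stride

# a fake heat-map with a peak at x=20, y=30 in heat-map space
hm = tf.scatter_nd([[30, 20]], [1.0], [64, 64])
print(heatmap_to_coord(hm))  # (80, 120)
```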

In addition, we also calculate the loss for each intermediate prediction, which lets us supervise not only the final output but all HG modules effectively. This was a brilliant design at the time, because pose estimation relies on the relationships among the different areas of the human body. For example, without seeing the location of the body, it's hard to tell whether an arm is the left or the right arm. By using a full prediction as the next module's input, we force the network to pay attention to other joints while predicting a new joint location.

Hourglass Module

So what does this HG (Hourglass) module look like? Let's take a look at another diagram from the original paper:

In the diagram, each box is a residual block plus some additional operations like pooling. If you are not familiar with residual blocks and the bottleneck structure, I'd recommend reading some ResNet articles first. In general, an HG module is an encoder-decoder architecture, where we downsample the features first and then upsample them to recover the information and form a heat-map. Each encoder layer has a connection to its decoder counterpart, and we can stack as many layers as we want. In the implementation, we usually use recursion and let the HG module repeat itself.

I understand that it may still seem too “convoluted” here, so it might be easier just to read the code. Here's a piece of code copied from my Stacked Hourglass implementation in the deep-vision repo on GitHub:

def HourglassModule(inputs, order, filters, num_residual):
    """One Hourglass Module. Usually we stack multiple of them together.

    order: The remaining order for HG modules to call themselves recursively.
    filters: Number of filters for the bottleneck blocks.
    num_residual: Number of residual layers for this HG module.
    """
    # Upper branch
    up1 = BottleneckBlock(inputs, filters, downsample=False)

    for i in range(num_residual):
        up1 = BottleneckBlock(up1, filters, downsample=False)

    # Lower branch
    low1 = MaxPool2D(pool_size=2, strides=2)(inputs)
    for i in range(num_residual):
        low1 = BottleneckBlock(low1, filters, downsample=False)

    low2 = low1
    if order > 1:
        low2 = HourglassModule(low1, order - 1, filters, num_residual)
        for i in range(num_residual):
            low2 = BottleneckBlock(low2, filters, downsample=False)

    low3 = low2
    for i in range(num_residual):
        low3 = BottleneckBlock(low3, filters, downsample=False)

    up2 = UpSampling2D(size=2)(low3)

    return up2 + up1

This module looks like an onion, so let's start from the outermost layer. up1 goes through several bottleneck blocks and is added to up2 at the end. This represents the two big boxes on the left and top, plus the right-most plus sign. This whole flow stays up in the air, so we call it the upper branch. There's also a lower branch: low1 goes through some pooling and bottleneck blocks, then enters another, smaller Hourglass module! On the diagram, that's the second layer of the big onion, and this is also why we use recursion here. We keep nesting HG modules until the order reaches 1, where we just have a single bottleneck instead of another HG module. This final layer is the three tiny boxes in the middle of the diagram.

If you are familiar with image classification networks, it's clear that the author borrows heavily from the idea of skip connections. This repeating pattern connects the corresponding layers of the encoder and decoder, instead of having just one flow of features. This not only helps the gradient pass through but also lets the network consider features from different scales when decoding.

Intermediate Supervision

Now we have an Hourglass module, and we know that the whole network consists of multiple modules like this. But how exactly do we stack them together? Here comes the final piece of the network: intermediate supervision.

As you can see from the diagram above, when the HG module produces an output, we split it into two paths. The top path includes some more convolutions to further process the features, which then go to the next HG module. The interesting thing happens on the bottom path. Here we use the output of that convolution layer as an intermediate heat-map result (blue box) and then calculate the loss between this intermediate heat-map and the ground-truth heat-map. In other words, if we have 4 HG modules, we need to calculate four losses in total: 3 for the intermediate results and 1 for the final result.
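As a sketch (assuming we already have the list of per-module heat-map outputs; the function below is illustrative, not the repo's training code), intermediate supervision just accumulates the same loss for every HG module's output:

```python
import tensorflow as tf

def intermediate_loss(intermediate_heatmaps, ground_truth):
    """Sum the loss over every HG module's heat-map output.

    intermediate_heatmaps: list of tensors, one (64, 64, 16) heat-map
        per HG module (4 modules -> 4 loss terms).
    ground_truth: the (64, 64, 16) ground-truth heat-map.
    """
    loss = 0.0
    for output in intermediate_heatmaps:
        # every intermediate prediction is supervised by the same target
        loss += tf.math.reduce_mean(tf.math.square(ground_truth - output))
    return loss

# with 4 stacked HG modules we accumulate 4 loss terms
gt = tf.ones([64, 64, 16])
outputs = [tf.zeros([64, 64, 16])] * 4
print(float(intermediate_loss(outputs, gt)))  # 4.0
```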

Prepare the Data

MPII Dataset

Once we've finished the code for the Stacked Hourglass network, it's time to think about what kind of data we'd like to train it on. If you have your own dataset, that's great. But here I'd like to mention an open dataset for beginners who want something to train on first: the MPII Dataset (from the Max Planck Institute for Informatics). You can find the download link here.

Although this dataset is mostly used for single-person pose estimation, it does provide joint annotations for multiple people in the same image. For each person, it gives the coordinates of 16 joints, such as the left ankle or right shoulder.

However, the original dataset annotation is in Matlab format, which is really hard to use nowadays. An alternative is the preprocessed JSON-format annotation provided by Microsoft here. The Google Drive link is here. After you download this JSON annotation, you'll see a list with elements like this:

    {
        "joints_vis": [...],
        "joints": [...],
        "image": "005808361.jpg",
        "scale": 4.718488,
        "center": [...]
    }

“joints_vis” indicates the visibility of each joint. In recent datasets, we usually need to differentiate occluded joints from visible joints, but in MPII we only care about whether the joint is in the view of the image: 1 -> in the view, 0 -> out of the view. “joints” is a list of joint coordinates, and they follow this order: 0 - r ankle, 1 - r knee, 2 - r hip, 3 - l hip, 4 - l knee, 5 - l ankle, 6 - pelvis, 7 - thorax, 8 - upper neck, 9 - head top, 10 - r wrist, 11 - r elbow, 12 - r shoulder, 13 - l shoulder, 14 - l elbow, 15 - l wrist.
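For readability when indexing the “joints” array, it can help to keep this order in a constant (the list below is just a convenience for this tutorial, not part of the dataset itself):

```python
# MPII joint order as listed in the annotation (index -> joint name);
# the name MPII_JOINTS is my own convenience constant
MPII_JOINTS = [
    'r_ankle', 'r_knee', 'r_hip', 'l_hip', 'l_knee', 'l_ankle',
    'pelvis', 'thorax', 'upper_neck', 'head_top',
    'r_wrist', 'r_elbow', 'r_shoulder', 'l_shoulder', 'l_elbow', 'l_wrist',
]

print(MPII_JOINTS[9])               # head_top
print(MPII_JOINTS.index('pelvis'))  # 6
```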


The less clear parts are “scale” and “center”. Sometimes we have more than one person in the image, so we need to crop out the one we are interested in. Unlike the MSCOCO dataset, MPII doesn't give us the bounding box of the person. Instead, it gives us a center coordinate and a rough scale of the person. Neither value is precise, but together they represent the general location of a person in the image. Note that you'll need to multiply “scale” by 200px to get the true height of the person. But what about the width? Unfortunately, the dataset doesn't really specify it. And the body may be aligned somewhat horizontally, which makes the width way larger than the height. One example I saw was a curling player crawling on the ground; if you only use the height to crop, you could end up leaving his arms out. After some experiments, here's my proposal for cropping the image:

# avoid invisible keypoints whose value are <= 0
masked_keypoint_x = tf.boolean_mask(keypoint_x, keypoint_x > 0)
masked_keypoint_y = tf.boolean_mask(keypoint_y, keypoint_y > 0)

# find left-most, top-most, bottom-most, and right-most keypoints
keypoint_xmin = tf.reduce_min(masked_keypoint_x)
keypoint_xmax = tf.reduce_max(masked_keypoint_x)
keypoint_ymin = tf.reduce_min(masked_keypoint_y)
keypoint_ymax = tf.reduce_max(masked_keypoint_y)

# add a padding according to human body height
xmin = keypoint_xmin - tf.cast(body_height * margin, dtype=tf.int32)
xmax = keypoint_xmax + tf.cast(body_height * margin, dtype=tf.int32)
ymin = keypoint_ymin - tf.cast(body_height * margin, dtype=tf.int32)
ymax = keypoint_ymax + tf.cast(body_height * margin, dtype=tf.int32)

# make sure the crop is valid
effective_xmin = xmin if xmin > 0 else 0
effective_ymin = ymin if ymin > 0 else 0
effective_xmax = xmax if xmax < img_width else img_width
effective_ymax = ymax if ymax < img_height else img_height
effective_height = effective_ymax - effective_ymin
effective_width = effective_xmax - effective_xmin

In short, we filter out invisible joints first, then calculate the coordinates of the left-most, top-most, bottom-most, and right-most of the 16 joints. These four coordinates give us a region that at least includes all available joint annotations. Then I pad this region by a proportion of the person's height, which is calculated from the “scale” field. Lastly, we make sure that the crop doesn't go out of the image border.
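Putting it together, here's a hedged sketch of the remaining steps: derive the body height from the “scale” field (times 200px, as noted above) and apply the final crop. The helper name crop_person and the sample numbers are my own:

```python
import tensorflow as tf

def crop_person(image, effective_xmin, effective_ymin,
                effective_width, effective_height):
    """Crop the padded person region out of the image.

    image: a (height, width, 3) tensor.
    The effective_* values are the clamped crop bounds computed
    as in the snippet above.
    """
    return tf.image.crop_to_bounding_box(
        image, effective_ymin, effective_xmin,
        effective_height, effective_width)

# "scale" from the annotation times 200px gives the body height
scale = 4.718488
body_height = int(scale * 200)  # 943

image = tf.zeros([1080, 1920, 3])
crop = crop_person(image, 100, 50, 600, 900)
print(crop.shape)  # (900, 600, 3)
```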


Another important thing to know about the ground-truth data is the Gaussian. When we create the ground-truth heat-map, we don't just assign 1 to the joint coordinate and 0 to all other pixels. That would make the ground truth too sparse to learn from. If the model's prediction is just a few pixels off, we should still encourage that behavior to some degree.

How do we model this encouragement in our loss function? If you've taken a probability class before, you might remember the Gaussian distribution:

The center has the highest value, with gradually decreasing values in the area around it. This is exactly what we need. We draw such a Gaussian pattern onto our all-zero ground-truth canvas, as in the first figure below. And when you combine all 16 joints in one heat-map, it looks like the second figure below.

As you can see from the code below, we calculate the size of the patch first: when sigma is 1, the size is 7 and the center is (3, 3). Then we generate a meshgrid to represent the coordinates of each cell in this patch. Finally, we substitute them into the Gaussian formula.

sigma = 1
scale = 1
size = 6 * sigma + 1
x, y = tf.meshgrid(tf.range(0, size, 1), tf.range(0, size, 1), indexing='xy')

# the center of the gaussian patch should be 1
center_x = size // 2
center_y = size // 2

# generate this 7x7 gaussian patch
gaussian_patch = tf.cast(tf.math.exp(-(tf.square(x - center_x) + tf.math.square(y - center_y)) / (tf.math.square(sigma) * 2)) * scale, dtype=tf.float32)

Note that the final code to generate a Gaussian is more complicated than this, because it needs to handle some border cases. For the full code, please take a look at my repo.
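For intuition only, here's a naive version of the pasting step that places the patch on the all-zero canvas and ignores those border cases (the function name and the scatter approach are mine, not the repo's implementation):

```python
import tensorflow as tf

def draw_gaussian(heatmap_size, joint_x, joint_y, sigma=1):
    """Paste a Gaussian patch centered at a joint into an all-zero
    heat-map. Naive sketch: assumes the whole patch fits inside
    the canvas, unlike the full repo code.
    """
    size = 6 * sigma + 1
    x, y = tf.meshgrid(tf.range(size, dtype=tf.float32),
                       tf.range(size, dtype=tf.float32), indexing='xy')
    center = size // 2
    patch = tf.math.exp(-((x - center) ** 2 + (y - center) ** 2)
                        / (2.0 * sigma ** 2))
    # scatter the patch so its center lands on (joint_x, joint_y)
    indices = [[joint_y - center + i, joint_x - center + j]
               for i in range(size) for j in range(size)]
    updates = tf.reshape(patch, [-1])
    return tf.scatter_nd(indices, updates, [heatmap_size, heatmap_size])

hm = draw_gaussian(64, joint_x=20, joint_y=30)
print(float(hm[30, 20]))  # 1.0 at the joint itself
```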

Loss Function

So far, we've discussed the network architecture and the data to use. With those, we can make a forward pass on some training data to get a feel for the output. But modern deep learning is about back-propagation and gradient descent, which require us to calculate the loss between the ground truth and the prediction. So let's get to it.

Fortunately, the loss function for Stacked Hourglass is pretty simple: you just take the Mean Square Error between the two heat-maps, which can be done in one line of code (the vanilla version below). However, in reality, I found it still kind of hard for the model to converge, and it learned to cheat by predicting all zeros to reach a local optimum. My solution (the improved version) is to assign a bigger weight to foreground pixels (the Gaussians we drew) to make it hard for the network to simply ignore these non-zero values. I chose 82 here because there are about 82 times more background pixels than foreground pixels for a 7×7 patch in a 64×64 heat-map.

# vanilla version
loss += tf.math.reduce_mean(tf.math.square(labels - output))
# improved version
weights = tf.cast(labels > 0, dtype=tf.float32) * 81 + 1
loss += tf.math.reduce_mean(tf.math.square(labels - output) * weights)


So far, we've discussed the network, the data, and the optimization goal (the loss). This should be sufficient for you to start your own training. But even once we've finished training and have a model, we're still not done. One shortcoming of using a heat-map, compared with regressing coordinates directly, is granularity. For example, with a 256×256 input, we get a 64×64 heat-map to represent the key-point location. The four-times down-scale doesn't seem too bad. However, we usually first resize a bigger image, such as 720×480, into this 256×256 input, and in that scenario a 64×64 heat-map is too coarse. To alleviate this problem, researchers came up with an interesting idea: instead of just using the pixel with the max value, we also take into consideration the neighbor pixel with the largest value. Since that neighbor also has a high value, it suggests that the actual key-point location might be a bit towards the direction of the neighbor pixel. Sounds familiar, right? It's pretty much like gradient descent, which also points toward the optimal solution.
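Here's a small sketch of that refinement trick. The quarter-pixel shift value and the function name are assumptions based on common hourglass implementations, not something specified in the paper:

```python
import numpy as np

def refine_peak(heatmap, stride=4, shift=0.25):
    """Refine the peak location by shifting a quarter pixel toward
    the larger horizontal/vertical neighbor, then scale back to
    input-image coordinates.
    """
    y, x = np.unravel_index(np.argmax(heatmap), heatmap.shape)
    fx, fy = float(x), float(y)
    h, w = heatmap.shape
    if 0 < x < w - 1:
        # move toward whichever horizontal neighbor is larger
        fx += shift * np.sign(heatmap[y, x + 1] - heatmap[y, x - 1])
    if 0 < y < h - 1:
        fy += shift * np.sign(heatmap[y + 1, x] - heatmap[y - 1, x])
    return fx * stride, fy * stride

hm = np.zeros((64, 64), dtype=np.float32)
hm[30, 20] = 1.0
hm[30, 21] = 0.6  # right neighbor is larger -> shift a bit right
print(refine_peak(hm))  # (81.0, 120.0)
```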

Above is a prediction example from our network. The top image shows all the joint locations. The bottom one is the skeleton, drawn by linking those joints together. Although the result looks pretty decent, I have to admit that this is an easy example. In reality, there are lots of twisted poses and occluded joints, which bring significant challenges to our network. For example, when there is only one foot in the image, the network can confuse itself by assigning both the left and right foot to the same location. How do we address this? Since this is more a topic about improvement, I'll leave it for you to think about first and write another article to discuss it in the future.


Congratulations, you've reached the end of this tutorial. If you've followed everything we discussed above, you should now have a solid understanding of the theory and the major challenges. To start coding it in TensorFlow, I'd suggest you clone/fork my repo, follow the instructions to prepare the dataset, and give it a run. If you run into any problems, please leave a GitHub issue so that I can take a look. And again, if you like my article or my repo, please ⭐star⭐ my repo; that will be the biggest support for me.

Ten Reasons Why TensorFlow Sucks

TensorFlow is the most important machine learning framework on this earth so far, no doubt. But I believe that, for most developers, TensorFlow is also one of the most hated frameworks in the world. I've been struggling with it for a while, and wanna share some of my thoughts:

  1. In TensorFlow, essentially you are trying to write a statically typed program (Tensor) in a dynamically typed language (Python). This gives you endless runtime errors from type mismatches.

  2. For TF users, Eager mode brings more troubles than solutions because of the weird hybrid. You could have code running perfectly in eager mode, and suddenly everything breaks when you add tf.function. Perhaps only Francois Chollet knows where the boundary between eager mode and graph mode is.

  3. For TF engineers, this confusing hybrid exponentially increases the number of bugs. They may as well just create a new TensorFlow and ditch 1.x altogether.

  4. The documentation is self-contradictory everywhere. Some critical explanations are missing, some describe outdated implementations, and some are just purely irresponsible.

  5. The curse of Keras, low-level APIs, and Estimator. If you don't have a powerful community like JavaScript's to make all the alternatives great, you should just be opinionated and focus on one.

  6. Do you know the Tor browser? Yes, the one used to visit the dark web. I guess TensorFlow learned from it by deeply hiding the real error inside the stack trace. Hooray! TF bugs are untraceable, just like Tor.

  7. Distributed training over multiple GPUs? Isn't that a solved problem since 2017? Nope, TensorFlow is aiming for 2050 to provide you with a convenient distributed training API. Until then, let it all stay experimental and incompatible with other code for a while.

  8. tf.data and TF Records sound attractive? In fact, they are poisoned graph-mode candies that can bring you infinite trouble. You have to spend hours thinking up a “creative” TensorFlow graph solution for a transformation that's very simple in native Python (or, in other words, PyTorch).

  9. The breaking changes in APIs and design invalidate most third-party TensorFlow tutorials and articles written before 2019.

  10. And last but not least, you hate TensorFlow but you still have to use it, because you work on some production models and the existing workflow relies heavily on TensorFlow. Or maybe you are seduced by TF Lite or tfjs? Fuck life.

Alright. Let’s get back to the work now.

Dive Really Deep into YOLO v3: A Beginner’s Guide

For full source code, please go to my deep-vision repo on GitHub. I really appreciate your STAR that supports my efforts.


When a self-driving car runs on a road, how does it know where other vehicles are in the camera image? When an AI radiologist reads an X-ray, how does it know where the lesion (abnormal tissue) is? Today, I will walk through a fascinating algorithm that can identify the category of a given image and also locate the region of interest. There are plenty of algorithms introduced in recent years to address object detection with a deep learning approach, such as R-CNN, Faster-RCNN, and Single Shot Detector. Among those, I'm most interested in a model called YOLO – You Only Look Once. This model attracts me so much, not only because of its funny name, but also because of some practical design choices that truly make sense to me. In 2018, the latest version of this model, V3, was released, and it achieved many new state-of-the-art results. Because I'd programmed some GANs and image classification networks before, and Joseph Redmon described it in the paper in a really easy-going way, I thought this detector would just be another stack of CNN and FC layers that magically works well.

But I was wrong.

Perhaps it's because I'm just dumber than the usual engineer, but I found it really difficult to translate this model from the paper into actual code. And even when I managed to do that over a couple of weeks (I gave up once and put it away for a few weeks), I found it even more difficult to make it work. There are quite a few blogs and GitHub repos about YOLO V3, but most of them just give a very high-level overview of the architecture, and somehow they just succeed. Even worse, the paper itself is so chill that it fails to provide many crucial implementation details, and I had to read the author's original C implementation (when was the last time I wrote C? Maybe in college?) to confirm some of my guesses. When there was a bug, I usually had no idea why it occurred, and I'd end up manually debugging it step by step, calculating those formulas with my little calculator.

Fortunately, I didn't give up this time and finally made it work. But in the meantime, I also felt strongly that there should be a more thorough guide on the internet to help dumb people like me understand every detail of this system. After all, if one single detail is wrong, the whole system goes south quickly. And I'm sure that if I don't write all this down, I'll forget it in a few weeks too. So, here I am, presenting you with “Dive Really Deep into YOLO V3: A Beginner's Guide”. I hope you'll like it.


Before getting into the network itself, I'll need to clarify some prerequisites first. As a reader, you are expected to:

  1. Understand the basics of Convolutional Neural Network and Deep Learning
  2. Understand the idea of object detection task
  3. Have curiosity about how the algorithm works internally

If you need help with the first two items, there are plenty of excellent resources, like the Udacity Computer Vision Nanodegree, the Coursera Deep Learning Specialization, and Stanford CS231N.
If you just want to quickly build something that detects objects on your custom dataset, check out the TensorFlow Object Detection API.


YOLO V3 is an improvement over the previous YOLO detection networks. Compared to prior versions, it features multi-scale detection, a stronger feature extractor network, and some changes in the loss function. As a result, this network can detect many more targets, from big to small. And, of course, just like other single-shot detectors, YOLO V3 runs quite fast and makes real-time inference possible on GPU devices. As a beginner to object detection, you might not have a clear picture of what these mean yet, but you will gradually understand them later in my post. For now, just remember that YOLO V3 is one of the best models in terms of real-time object detection as of 2019.

Network Architecture

First of all, let's talk about what this network looks like at a high level (although the network architecture is the least time-consuming part of the implementation). The whole system can be divided into two major components: the feature extractor and the detector; both are multi-scale. When a new image comes in, it goes through the feature extractor first, so that we can obtain feature embeddings at three (or more) different scales. Then, these features are fed into three (or more) branches of the detector to get bounding boxes and class information.


The feature extractor YOLO V3 uses is called Darknet-53. You might be familiar with the previous Darknet version from YOLO V2, which had only 19 layers. But that was a few years ago, and image classification networks have progressed a lot beyond mere deep stacks of layers. ResNet brought the idea of skip connections to help activations propagate through deeper layers without vanishing gradients. Darknet-53 borrows this idea and successfully extends the network from 19 to 53 layers, as we can see from the following diagram.

This is very easy to understand. Consider the layers in each rectangle as a residual block. The whole network is a chain of multiple blocks with some stride-2 Conv layers in between to reduce dimensions. Inside a block, there's just a bottleneck structure (1×1 followed by 3×3) plus a skip connection. If the goal were multi-class classification, as with ImageNet, an average pooling layer and a 1000-way fully connected layer plus softmax activation would be added.

However, in the case of object detection, we won't include this classification head. Instead, we append a “detection” head to this feature extractor. And since YOLO V3 is designed to be a multi-scale detector, we also need features from multiple scales. Therefore, the features from the last three residual blocks are all used in the later detection. In the diagram below, I'm assuming the input is 416×416, so the three scales of feature maps are 52×52, 26×26, and 13×13. Please note that if the input size is different, the output sizes will differ too.

Multi-scale Detector

Once we have the three feature vectors, we can feed them into the detector. But how should we structure this detector? Unfortunately, the author didn't bother to explain this part in his paper. But we can still take a look at the source code he published on GitHub. From the config file, we can see that multiple 1×1 and 3×3 Conv layers are used before a final 1×1 Conv layer forms the final output. For the medium and small scales, it also concatenates features from the previous scale. By doing so, small-scale detection can also benefit from the result of large-scale detection.

Assuming the input image is (416, 416, 3), the final outputs of the detectors will be in the shape of [(52, 52, 3, (4 + 1 + num_classes)), (26, 26, 3, (4 + 1 + num_classes)), (13, 13, 3, (4 + 1 + num_classes))]. The three items in the list represent detections at three scales. But what do the cells in this 52x52x3x(4+1+num_classes) matrix mean? Good question. This brings us to the most important notion in pre-2019 object detection algorithms: the anchor box (prior box).
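Just as a sanity check on those shapes, we can count how many candidate boxes the three scales produce for a 416×416 input (num_classes = 80 below is an assumption, as for the COCO dataset):

```python
# per-scale grid sizes, with 3 anchor boxes per grid cell
num_anchors_per_cell = 3
grids = [52, 26, 13]
total_boxes = sum(g * g * num_anchors_per_cell for g in grids)
print(total_boxes)  # 10647

# the last dimension of one scale's output, assuming 80 classes (COCO)
num_classes = 80
print(4 + 1 + num_classes)  # 85
```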

Anchor Box

The goal of object detection is to get a bounding box and its class. A bounding box is usually represented in a normalized xmin, ymin, xmax, ymax format. For example, 0.5 xmin and 0.5 ymin mean the top-left corner of the box is in the middle of the image. Intuitively, if we want to get a numeric value like 0.5, we are facing a regression problem. We might as well just have the network predict four values and use Mean Square Error to compare with the ground truth. However, due to the large variance of scale and aspect ratio of boxes, researchers found it really hard for the network to converge with this “brute force” way of getting a bounding box. Hence, in the Faster-RCNN paper, the idea of the anchor box was proposed.

An anchor box is a prior box that can have different pre-defined aspect ratios. These aspect ratios are determined before training by running K-means on the entire dataset. But where does the box anchor to? We need to introduce a new notion called the grid. In the “ancient” year of 2013, algorithms detected objects by sliding a window through the entire image and running image classification on each window. This was so inefficient that researchers proposed using a Conv net to compute the whole image all at once (technically, only when you run the convolution kernels in parallel). Since the convolution outputs a square matrix of feature values (like 13×13, 26×26, and 52×52 in YOLO), we define this matrix as a “grid” and assign anchor boxes to each cell of the grid. In other words, anchor boxes anchor to grid cells, and they share the same centroid. Once we've defined those anchors, we can determine how much the ground-truth box overlaps with each anchor box, pick the one with the best IOU, and couple them together. I guess you could also say that the ground-truth box anchors to this anchor box. In our later training, instead of predicting coordinates from the wild west, we can now predict offsets to these anchor boxes. This works because our ground-truth box should look like the anchor box we pick, so only subtle adjustment is needed, which gives us a great head start in training.

In YOLO v3, we have three anchor boxes per grid cell. And we have three scales of grids. Therefore, we will have 52x52x3, 26x26x3 and 13x13x3 anchor boxes for each scale. For each anchor box, we need to predict 3 things:

  1. The location offset against the anchor box: tx, ty, tw, th. This has 4 values.
  2. The objectness score to indicate if this box contains an object. This has 1 value.
  3. The class probabilities to tell us which class this box belongs to. This has num_classes values.

In total, we are predicting 4 + 1 + num_classes values for one anchor box, and that's why our network outputs a matrix in the shape of 52x52x3x(4+1+num_classes), as I mentioned before. tx, ty, tw, th aren't the real coordinates of the bounding box; they're just relative offsets with respect to a particular anchor box. I'll explain these three predictions in more detail in the Loss Function section.

The anchor box not only makes the detector implementation much harder and more error-prone, but also introduces an extra step before training if you want the best results. So, personally, I hate it very much and feel like the anchor box idea is more a hack than a real solution. In 2018 and 2019, researchers started to question the need for anchor boxes. Papers like CornerNet, Objects as Points, and FCOS all discussed the possibility of training an object detector from scratch without the help of anchor boxes.

Loss Function

With the final detection output, we can now calculate the loss against the ground-truth labels. The loss function consists of four parts (or five, if you split noobj and obj): centroid (xy) loss, width and height (wh) loss, objectness (obj and noobj) loss, and classification loss. Put together, the formula looks like this:

Loss = Lambda_Coord * Sum(Mean_Square_Error((tx, ty), (tx', ty')) * obj_mask)
     + Lambda_Coord * Sum(Mean_Square_Error((tw, th), (tw', th')) * obj_mask)
     + Sum(Binary_Cross_Entropy(obj, obj') * obj_mask)
     + Lambda_Noobj * Sum(Binary_Cross_Entropy(obj, obj') * (1 - obj_mask) * ignore_mask)
     + Sum(Binary_Cross_Entropy(class, class'))

It looks intimidating but let me break them down and explain one by one.

xy_loss =  Lambda_Coord * Sum(Mean_Square_Error((tx, ty), (tx', ty')) * obj_mask)

The first part is the loss for the bounding box centroid. tx and ty are the relative centroid locations from the ground truth. tx' and ty' are the centroid predictions from the detector directly. The smaller this loss is, the closer the centroids of the prediction and the ground truth are. Since this is a regression problem, we use mean square error here. Besides, if there's no object in a certain cell according to the ground truth, we don't need to include the loss of that cell in the final loss. Therefore we also multiply by obj_mask here. obj_mask is either 1 or 0, indicating whether there's an object or not. In fact, we could just use obj as obj_mask; obj is the objectness score that I will cover later. One thing to note is that we need to do some calculation on the ground truth to get tx and ty. So, let's see how to get these values first. As the author says in the paper:

bx = sigmoid(tx) + Cx
by = sigmoid(ty) + Cy

Here bx and by are the absolute values that we usually use as the centroid location. For example, bx = 0.5, by = 0.5 means that the centroid of this box is the center of the entire image. However, since we are going to compute the centroid off the anchor, our network actually predicts the centroid relative to the top-left corner of the grid cell. Why the grid cell? Because each anchor box is bound to a grid cell, and they share the same centroid, so the offset from the grid cell can represent the offset from the anchor box. In the formula above, sigmoid(tx) and sigmoid(ty) are the centroid location relative to the grid cell. For instance, sigmoid(tx) = 0.5 and sigmoid(ty) = 0.5 mean the centroid is the center of the current grid cell (but not of the entire image).

Cx and Cy represent the absolute location of the top-left corner of the current grid cell. So if the grid cell is the one in the second row and second column of a 13×13 grid, then Cx = 1 and Cy = 1. And if we add this grid cell location to the relative centroid location, we get the absolute centroid location: bx = 0.5 + 1 and by = 0.5 + 1. Certainly, the author won't bother to tell you that you also need to normalize this by dividing by the grid size, so the true bx would be 1.5/13 = 0.115. Now that we understand the formula above, we just need to invert it so that we can get tx from bx, in order to translate our original ground truth into the target label. Lastly, Lambda_Coord is the weight that the author introduced in the YOLO v1 paper, to put more emphasis on localization than on classification. The value he suggested is 5.
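The decoding above can be written out as a tiny helper (the function itself is illustrative; the formula and the worked numbers come from the paragraph above):

```python
import math

def decode_centroid(tx, ty, cx, cy, grid_size):
    """bx = (sigmoid(tx) + Cx) / grid_size, and the same for by."""
    sigmoid = lambda t: 1.0 / (1.0 + math.exp(-t))
    bx = (sigmoid(tx) + cx) / grid_size
    by = (sigmoid(ty) + cy) / grid_size
    return bx, by

# sigmoid(0) = 0.5, grid cell (Cx, Cy) = (1, 1) on a 13x13 grid,
# so bx = (0.5 + 1) / 13 = 0.115...
bx, by = decode_centroid(0.0, 0.0, 1, 1, 13)
print(round(bx, 3))  # 0.115
```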

wh_loss = Lambda_Coord * Sum(Mean_Square_Error((tw, th), (tw', th')) * obj_mask)

The next one is the width and height loss. Again, the author says:

bw = exp(tw) * pw
bh = exp(th) * ph

Here bw and bh are still the absolute width and height relative to the whole image. pw and ph are the width and height of the prior box (aka. anchor box; why are there so many names?). We take exp(tw) here because tw could be a negative number, but a width can’t be negative in the real world, so exp() makes it positive. And we multiply by the prior box width pw (and ph for height) because the prediction exp(tw) is based off the anchor box, so this multiplication gives us the real width. Same thing for the height. Similarly, we can invert the formula above to translate bw and bh into tw and th when we calculate the loss.
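That inversion is a one-liner; a hedged sketch (the helper name is mine):

```python
import numpy as np

# Invert bw = exp(tw) * pw and bh = exp(th) * ph to build the regression
# target from a ground-truth box and its matched anchor (prior) box.
def encode_wh(bw, bh, pw, ph):
    return np.log(bw / pw), np.log(bh / ph)
```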

obj_loss = Sum(Binary_Cross_Entropy(obj, obj') * obj_mask)
noobj_loss = Lambda_Noobj * Sum(Binary_Cross_Entropy(obj, obj') * (1 - obj_mask) * ignore_mask)

The third and fourth items are the objectness and non-objectness score losses. Objectness indicates how likely there is an object in the current cell. Unlike YOLO v2, we use binary cross-entropy instead of mean square error here. In the ground truth, objectness is always 1 for a cell that contains an object, and 0 for a cell that doesn’t. By measuring this obj_loss, we can gradually teach the network to detect a region of interest. In the meantime, we don’t want the network to cheat by proposing objects everywhere; hence, we need noobj_loss to penalize those false positive proposals. We get the false positives by masking the prediction with 1 - obj_mask. The ignore_mask is used to make sure we only penalize a box when it doesn’t have much overlap with the ground truth box. If it does, we tend to be softer, because it’s actually quite close to the answer. As we can see from the paper, “If the bounding box prior is not the best but does overlap a ground truth object by more than some threshold we ignore the prediction.” Since there are way more noobj cells than obj cells in our ground truth, we also need Lambda_Noobj = 0.5 to make sure the network won’t be dominated by cells that don’t have objects.
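Here is how the ignore mask could be computed, sketched in plain numpy with absolute (x1, y1, x2, y2) boxes. The function name and the 0.5 threshold are my assumptions, not the exact official code:

```python
import numpy as np

# A box is only punished as a false positive when its best IOU with every
# ground-truth box is below a threshold; otherwise the mask is 0 (ignored).
def compute_ignore_mask(pred_boxes, true_boxes, threshold=0.5):
    # pred_boxes: (N, 4), true_boxes: (M, 4), both as absolute (x1, y1, x2, y2)
    x1 = np.maximum(pred_boxes[:, None, 0], true_boxes[None, :, 0])
    y1 = np.maximum(pred_boxes[:, None, 1], true_boxes[None, :, 1])
    x2 = np.minimum(pred_boxes[:, None, 2], true_boxes[None, :, 2])
    y2 = np.minimum(pred_boxes[:, None, 3], true_boxes[None, :, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area_p = (pred_boxes[:, 2] - pred_boxes[:, 0]) * (pred_boxes[:, 3] - pred_boxes[:, 1])
    area_t = (true_boxes[:, 2] - true_boxes[:, 0]) * (true_boxes[:, 3] - true_boxes[:, 1])
    iou = inter / (area_p[:, None] + area_t[None, :] - inter)
    best_iou = iou.max(axis=1)
    # 1 = penalize as noobj, 0 = close enough to a ground truth, so ignore
    return (best_iou < threshold).astype(np.float32)
```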

class_loss = Sum(Binary_Cross_Entropy(class, class') * obj_mask)

The last loss is the classification loss. If there are 80 classes in total, class and class' will be one-hot encoded vectors with 80 values. In YOLO v3, it’s changed to multi-label classification instead of multi-class classification. Why? Because some datasets may contain labels that are hierarchical or related, e.g., woman and person, so each output cell could have more than one true class. Correspondingly, we apply binary cross-entropy to each class one by one and sum them up, because the classes are not mutually exclusive. And as with the other losses, we multiply by obj_mask so that we only count cells that have a ground truth object.
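A toy numpy illustration of the difference (the function is mine): with independent sigmoids, the label vector can contain several 1s, and we simply sum the per-class binary cross-entropies:

```python
import numpy as np

# Multi-label classification: an independent sigmoid per class, so "woman"
# and "person" can both be 1 for the same box; sum the per-class BCE terms.
def multilabel_bce(y_true, logits):
    p = 1.0 / (1.0 + np.exp(-logits))   # sigmoid per class
    eps = 1e-7
    p = np.clip(p, eps, 1 - eps)        # avoid log(0)
    return -np.sum(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))
```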

To fully understand how this loss works, I suggest you manually walk through it with a real network prediction and ground truth. Calculating the loss with your calculator (or tf.math) can really help you catch all the nitty-gritty details. I did that myself, and it helped me find lots of bugs. After all, the devil is in the details.


If I stop writing here, my post will just be like another “YOLO v3 Review” somewhere on the web. Once you digest the general idea of YOLO v3 from the previous section, we are now ready to go explore the remaining 90% of our YOLO v3 journey: Implementation.


At the end of September, Google finally released TensorFlow 2.0.0. This is a fascinating milestone for TF. Nevertheless, a new design doesn’t necessarily mean less pain for developers. I’ve been playing around with TF 2 since early 2019, because I always wanted to write TensorFlow code the way I do for PyTorch. If it weren’t for TensorFlow’s powerful production suite, like TF Serving, TF Lite, and TensorBoard, I guess many developers would not choose TF for new projects. Hence, if you don’t have a strong demand for production deployment, I would suggest you implement YOLO v3 in PyTorch or even MXNet. However, if you’ve made up your mind to stick with TensorFlow, please continue reading.

TensorFlow 2 officially made eager mode a first-tier citizen. To put it simply, instead of using TensorFlow-specific APIs to calculate in a graph, you can now leverage native Python code to run the graph in a dynamic mode: no more graph compilation, and much easier debugging and control flow. For cases where performance matters more, a handy @tf.function decorator is also provided to compile the code into a static graph. But the reality is, eager mode and tf.function are still buggy or poorly documented sometimes, which makes your life even harder in a complicated system like YOLO v3. Also, the Keras Model isn’t quite flexible, while the custom training loop is still quite experimental. Therefore, the best strategy for writing YOLO v3 in TF 2 is to start with a minimum working template first and gradually add more logic to this shell. By doing so, we can fail early and fix the bug before it hides too deep in a giant nested graph.


Aside from the framework choice, the most important thing for successful training is the dataset. In the paper, the author used the MSCOCO dataset to validate his idea. Indeed, this is a great dataset, and we should aim for good accuracy on this benchmark with our model. However, a big dataset like this could also hide bugs in your code. For example, if the loss is not dropping, how do you know whether it just needs more time to converge, or whether your loss function is wrong? Even with 8x V100 GPUs, training is still not fast enough to quickly iterate and fix things. Therefore, I recommend building a development set of tens of images to make sure your code looks “working” first. Another option is the VOC 2007 dataset, which has only about 2500 training images. To use the MSCOCO or VOC2007 dataset and create TF Records, you could refer to my helper scripts here: MSCOCO, VOC2007


Preprocessing stands for the operations that translate raw data into the proper input format for the network. For an image classification task, we usually just need to resize the image and one-hot encode the label. But things are a bit more complicated for YOLO v3. Remember I said the output of the network is like 52×52×3×(4+1+num_classes) and comes in three different scales? Since we need to calculate the delta between ground truth and prediction, we also need to format our ground truth into such a matrix first.

For each ground truth bounding box, we need to pick the best scale and anchor for it. For example, a tiny kite in the sky should go in the small-object scale (52×52). And if the kite looks more like a square in the image, we should also pick the most square-shaped anchor in that scale. In YOLO v3, the author provides 9 anchors for the 3 scales, and all we need to do is choose the one that matches our ground truth box best. When I implemented this, I thought I needed the coordinates of the anchor box as well to calculate IOU. In fact, you don’t. Since we just want to know which anchor fits our ground truth box best, we can simply assume all anchors and the ground truth box share the same centroid. With this assumption, the degree of matching is the overlapping area, which can be calculated by min width * min height.
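Here is that shortcut in code. The anchor list is the 9 COCO anchors from the YOLO v3 paper; dividing the shared-centroid overlap by the union to get an IOU-style score is my choice, so that a huge anchor doesn't trivially win by covering every small box:

```python
import numpy as np

# The 9 COCO anchors from the YOLO v3 paper, as (width, height) in pixels.
ANCHORS = np.array([(10, 13), (16, 30), (33, 23),
                    (30, 61), (62, 45), (59, 119),
                    (116, 90), (156, 198), (373, 326)], dtype=np.float32)

def best_anchor(box_w, box_h, anchors=ANCHORS):
    # All boxes share one centroid, so overlap = min(widths) * min(heights)
    inter = np.minimum(box_w, anchors[:, 0]) * np.minimum(box_h, anchors[:, 1])
    union = box_w * box_h + anchors[:, 0] * anchors[:, 1] - inter
    # Index 0-2 -> small-object scale, 3-5 -> medium, 6-8 -> large
    return int(np.argmax(inter / union))
```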

During the transformation, one could also apply some data augmentation to virtually increase the variety of the training set. Typical augmentations include random flipping, random cropping, and random translation. However, skipping these won’t block you from training a working detector, so I won’t cover much of this advanced topic.


After all these discussions, you finally get to run python and start your model training. And this is also when you meet most of your bugs. You can refer to my training script here when you are blocked. Meanwhile, I want to share some tips that were helpful during my own training.

NaN Loss

  1. Check your learning rate and make sure it’s not too high to explode your gradient.
  2. Check for 0 in binary cross-entropy, because ln(0) is undefined. You can clip the value to (epsilon, 1 - epsilon).
  3. Find an example and walk through your loss step by step to see which part goes to NaN. For example, if the width/height loss goes to NaN, it could be that the way you convert between tw and bw is wrong.

Loss remains high

  1. Try to increase your learning rate to see if the loss can drop faster. Mine starts at 0.01, but I’ve seen 1e-4 and 1e-5 work too.
  2. Visualize your preprocessed ground truth to see if it makes sense. One problem I had before is that my output grid is in [y][x] instead of [x][y], but my ground truth is reversed.
  3. Again, manually walk through your loss with a real example. I had a mistake of calculating cross-entropy between objectness and class probabilities.
  4. My loss also remains around 40 after 50 epochs of MSCOCO. However, the result isn’t that bad.
  5. Double-check the coordinate format throughout your code. YOLO requires xywh (centroid x, centroid y, width, and height), but most datasets come as x1y1x2y2 (xmin, ymin, xmax, ymax).
  6. Double-check your network architecture. Don’t get misled by the diagram from a post called “A Closer Look at YOLOv3 – CyberAILab”.
  7. tf.keras.losses.binary_crossentropy isn’t the sum of binary cross-entropy you need.
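For item 5 above, keeping one pair of conversion helpers and using them everywhere saves a lot of grief; a trivial sketch:

```python
# Convert between the (xmin, ymin, xmax, ymax) corners format most datasets
# ship and the (centroid x, centroid y, width, height) format YOLO expects.
def xyxy_to_xywh(x1, y1, x2, y2):
    return (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1

def xywh_to_xyxy(cx, cy, w, h):
    return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
```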

Loss is low, but the prediction is off

  1. Adjust lambda_coord or lambda_noobj based on your observations.
  2. If you are training on your own dataset and the dataset is relatively small (<30k images), you should initialize weights from a COCO-pretrained model first.
  3. Double-check your non max suppression code and adjust some thresholds (I’ll talk about NMS later).
  4. Make sure your obj_mask in the loss function isn’t mistakenly taking out necessary elements.
  5. Again and again, your loss function. When calculating the loss, it uses relative xywh within a cell (also called tx, ty, tw, th). When calculating the ignore mask and IOU, though, it uses absolute xywh in the whole image. Don’t mix them up.

Loss is low, but there’s no prediction

  1. If you are using a custom dataset, please check the distribution of your ground truth boxes first. The amount and quality of the boxes can really affect what the network learns (or cheats) to do.
  2. Predict on your training set to see if your model can overfit on the training set at least.

Multi-GPU training

Since an object detection network has so many parameters to train, it’s always better to have more computing power. For example, on the MSCOCO 2017 dataset, 8x Tesla V100 GPUs can finish a whole epoch in 10 minutes, while 1x V100 would need more than an hour.

However, TensorFlow 2.0 doesn’t have great support for multi-GPU training so far. To do that in TF, you’ll need to pick a training strategy like MirroredStrategy, as I did here, and wrap your dataset loader into a distributed version too. One caveat for distributed training is that the loss coming out of each batch should be divided by the global batch size, because we are going to reduce_sum over all GPU results. For example, if the local batch size is 8 and there are 8 GPUs, your batch loss should be divided by the global batch size of 64. Once you sum up the losses from all replicas, the final result will be the average loss of a single example.
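A minimal sketch of that scaling (GLOBAL_BATCH_SIZE and the helper are my naming; TensorFlow also provides tf.nn.compute_average_loss for exactly this purpose):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
# e.g. a local batch of 8 per replica
GLOBAL_BATCH_SIZE = 8 * strategy.num_replicas_in_sync

def replica_loss(per_example_loss):
    # Sum this replica's per-example losses, but divide by the GLOBAL batch
    # size, so summing across all replicas yields a per-example average.
    return tf.reduce_sum(per_example_loss) / GLOBAL_BATCH_SIZE
```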


The final component in this detection system is the post-processor. Usually, postprocessing is just about trivial things like replacing a machine-readable class id with human-readable class text. In object detection, though, we have one more crucial step to get the final human-readable results. This step is called non maximum suppression.

Let’s recall our objectness loss. When a false proposal has a large overlap with the ground truth, we won’t penalize it with noobj_loss. This encourages the network to predict close results so that we can train it more easily. Also, although not used in YOLO, when a sliding-window approach is used, multiple windows could predict the same object. To eliminate these duplicate results, smart researchers designed an algorithm called non maximum suppression (NMS).

The idea of NMS is quite simple: find the detection box with the best confidence, add it to the final result, and then eliminate all other boxes whose IOU with this best box is over a certain threshold. Next, choose the box with the best confidence among the remaining boxes and repeat, over and over, until nothing is left. In the code, since TensorFlow needs explicit shapes most of the time, we usually define a max number of detections and stop early once that number is reached. In YOLO v3, our classification is not mutually exclusive anymore, and one detection could have more than one true class. However, some existing NMS code doesn’t take that into consideration, so be careful when you use it.
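Here is a plain numpy sketch of that greedy loop (parameter defaults are my choice; for YOLO v3's multi-label output you would run it once per class):

```python
import numpy as np

# Greedy non maximum suppression: keep the best-scoring box, drop everything
# that overlaps it too much, repeat until empty or max_detections is reached.
def nms(boxes, scores, iou_threshold=0.5, max_detections=100):
    # boxes: (N, 4) as (x1, y1, x2, y2); scores: (N,)
    area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    order = np.argsort(scores)[::-1]  # best confidence first
    keep = []
    while order.size > 0 and len(keep) < max_detections:
        best = order[0]
        keep.append(best)
        # IOU of the best box against all remaining boxes
        x1 = np.maximum(boxes[best, 0], boxes[order[1:], 0])
        y1 = np.maximum(boxes[best, 1], boxes[order[1:], 1])
        x2 = np.minimum(boxes[best, 2], boxes[order[1:], 2])
        y2 = np.minimum(boxes[best, 3], boxes[order[1:], 3])
        inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
        iou = inter / (area[best] + area[order[1:]] - inter)
        order = order[1:][iou <= iou_threshold]  # drop heavy overlaps
    return keep
```

TensorFlow also ships tf.image.non_max_suppression, which does the same thing in-graph.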


YOLO v3 is a masterpiece in the rising era of Artificial Intelligence, and also an excellent summary of Convolutional Neural Network techniques and tricks of the 2010s. Although there are many turn-key solutions like Detectron out there to simplify the process of making a detector, the hands-on experience of coding such a sophisticated detector is a great learning opportunity for machine learning engineers, because merely reading the paper is far from enough. As Ray Dalio said about his philosophy, “Pain plus reflection equals progress.” I hope my article can be a lighthouse in your painful journey of implementing YOLO v3, and perhaps you can share your delightful progress with us later.


Gender Swap and CycleGAN in TensorFlow 2.0

To view the full neural network model and training scripts, please visit my Github repo here:


Recently, the gender swap lens from Snapchat became very popular on the internet. There have been many buzzwords about Generative Adversarial Networks since 2016, but this is the first time that ordinary people got to experience the power of GANs. What’s more extraordinary about this lens is its great real-time performance, which makes it feel just like looking into a magic mirror. Although we can’t know the exact algorithm behind this viral lens, it’s most likely a CycleGAN, which was introduced in 2017 by Jun-Yan Zhu, Taesung Park, Phillip Isola, and Alexei Efros in this paper. In this article, I’m going to show you how to implement a gender swap effect with TensorFlow 2.0, just like Snapchat does.


First of all, I want to quickly go over the basics of the Generative Adversarial Network (GAN) for readers who are not familiar with it. In some scenarios, we want to generate an image that belongs to a particular domain. For example, we’d like to draw a random interior design photo. To ask the computer to generate such an image, we need a mathematical representation of the interior design domain space. Assume there’s a function F and a random input number x. We want y = F(x) to always be very close to our target domain Y. However, this target domain lives in a very high-dimensional space, so no human being can figure out explicit rules to define it. A GAN is such a network: by playing a minimax game between two AI agents, it can eventually find an approximate representation F of our target domain Y.

So how does a GAN accomplish this? The trick is to break the problem into two parts: 1. We need a generator that keeps making new images out of random numbers. 2. We need a discriminator that gives the generator feedback about how good the generated image is. The generator here is like a young artist who has no idea how to paint but wants to fake a masterpiece, and the discriminator is a judge who can tell what’s wrong with the new painting. The judge doesn’t need to know how to paint himself; as long as he’s good at telling a good painting from a bad one, our young painter can benefit from his feedback for sure. So we use Deep Learning to build a good judge and use it to train a good painter in the meantime.

To train a good judge, we feed both authentic images and generated images to our discriminator. Since we know beforehand which is authentic and which is fake, the discriminator can update its weights by comparing its decision with the truth. The generator, in turn, looks at the decision from the discriminator. If the discriminator finds the fake image more agreeable, it indicates that the generator is heading in the right direction, and vice versa. The tricky part is that we can’t just train a judge without a painter, or a painter without a judge. They learn from each other and try to beat each other in this minimax game. If we train the discriminator too much without training the generator, the discriminator becomes too dominant, and the generator will never have a chance to improve because every move is a losing move.

Eventually, both the generator and the discriminator will be really good at their jobs, and by then we can take the generator out to perform the task independently. But what on earth does this GAN have to do with gender swap? In fact, ours is a more straightforward problem than the one I just described. Instead of generating a random image of a target domain, we can skip the random number and use a given image as input. Say we want to convert a male face to a female face. We are looking for a function F that, given a male face x, outputs a value y that’s very close to the real female version y_0 of that face.

The GAN approach sounds clear, but we haven’t discussed one crucial caveat. To train the discriminator, we need both true and false images. We get the false image y from the generator, but where do we get the true image for that specific male face? We can’t just use a random female face here, because we want to preserve some common traits when we swap the gender, and a face from a different person would ruin that. But it’s also really hard to get paired training data. You could go out to find real twin brothers and sisters and take pictures of them, or ask a professional dresser to ‘turn’ a man into a woman. Both are very expensive. So, is there a way for the model to learn the most important facial differences between men and women from an unpaired dataset?


Fortunately, scientists discovered ways to utilize unpaired training data. One of the most famous models is called CycleGAN. The main idea behind CycleGAN is that, instead of using paired data to train the discriminator, we can form a cycle where two generators work together to convert the image back and forth. More specifically, generator A2B first generates a domain B image from a domain A input, and then generator B2A uses that as input to generate another image back in domain A. We then set a goal that the second image (the reconstructed image) look as close as possible to the original input. For example, if generator A2B first converts a horse to a zebra, and generator B2A then converts that zebra back to a horse, the newly generated horse should look identical to the very original horse. This way, the generator learns to make not trivial changes but only the critical differences between the two domains; otherwise, it probably won’t be able to convert the image back. With this goal set up, we can now use unpaired images as training data.


In reality, we need two cycles. Since we are training generator A2B and generator B2A together, we have to make sure both generators are improving over time; otherwise, there will still be problems reconstructing a good image. Moreover, as discussed above, improving a generator means we need to improve its discriminator in the meantime. In the cycle A2B2A (A -> B -> A), we use discriminator A to decide if the reconstructed image is in domain A; thus, discriminator A gets trained. Likewise, we also need a cycle B2A2B so that discriminator B can be trained as well. If both discriminator A and discriminator B are well trained, our generators A2B and B2A can improve too!

There’s another great article here for CycleGAN for further reading. Now that you get the main idea of this network, let’s dive deep into some details.


It’s recommended here that Adam is the best optimizer for GAN training. Although I don’t know the reason behind it, the linear learning rate decay from the original paper looks quite effective during training: the rate remains at 0.0002 for the first 100 epochs, then linearly decays to 0 over the next 100 epochs. Here total_batches is the number of mini-batches in each epoch, because our learning rate scheduler considers each mini-batch as a step.

gen_lr_scheduler = LinearDecay(LEARNING_RATE, EPOCHS * total_batches, DECAY_EPOCHS * total_batches)
dis_lr_scheduler = LinearDecay(LEARNING_RATE, EPOCHS * total_batches, DECAY_EPOCHS * total_batches)
optimizer_gen = tf.keras.optimizers.Adam(gen_lr_scheduler, BETA_1)
optimizer_dis = tf.keras.optimizers.Adam(dis_lr_scheduler, BETA_1)
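LinearDecay is not a built-in Keras schedule; here is a minimal sketch of what it could look like, with argument names mirroring the call above (the exact class in my repo may differ):

```python
import tensorflow as tf

class LinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Hold initial_rate, then decay linearly to 0 over the final decay_steps."""

    def __init__(self, initial_rate, total_steps, decay_steps):
        super().__init__()
        self.initial_rate = initial_rate
        self.total_steps = total_steps
        self.decay_steps = decay_steps

    def __call__(self, step):
        hold_steps = self.total_steps - self.decay_steps
        # 0 while holding, then ramps linearly from 0 to 1 over decay_steps
        progress = (tf.cast(step, tf.float32) - hold_steps) / self.decay_steps
        progress = tf.clip_by_value(progress, 0.0, 1.0)
        return self.initial_rate * (1.0 - progress)
```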


Network Structure

CycleGAN uses a regular generator structure. It first encodes the input image into a feature matrix by applying 2D convolutions, which extracts valuable feature information at both local and global scales.

Then, six or nine ResNet blocks are used to transform the features from the encoder into features in the target domain. As we know, the skip connection in a ResNet block helps gradients flow back to previous layers, which makes sure the deeper layers can still learn something. If you are not familiar with ResNet, please refer to this paper.

Finally, a few deconvolution layers are used as the decoder. The decoder converts the features of the target domain into an actual image of the target domain by upsampling.

Unlike the ideas from the VGG and Inception networks, it’s recommended for a GAN to use a larger convolution kernel size like 7×7, so that it can pick up broader information instead of just focusing on details. This makes sense, because when we reconstruct an image, it’s not only the details that matter but also the overall pattern. Also, reflection padding is used here to improve the quality around the image border.

def make_generator_model(n_blocks):
    # 6 residual blocks
    # c7s1-64,d128,d256,R256,R256,R256,R256,R256,R256,u128,u64,c7s1-3
    # 9 residual blocks
    # c7s1-64,d128,d256,R256,R256,R256,R256,R256,R256,R256,R256,R256,u128,u64,c7s1-3
    # (instance normalization and activation layers between the convolutions
    # are omitted here for brevity; see the repo for the full model)
    model = tf.keras.Sequential()

    # Encoding
    model.add(ReflectionPad2d(3, input_shape=(256, 256, 3)))
    model.add(tf.keras.layers.Conv2D(64, (7, 7), strides=(1, 1), padding='valid', use_bias=False))
    model.add(tf.keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same', use_bias=False))
    model.add(tf.keras.layers.Conv2D(256, (3, 3), strides=(2, 2), padding='same', use_bias=False))

    # Transformation: n_blocks ResNet blocks
    # (ResnetBlock is a custom layer, like ReflectionPad2d)
    for i in range(n_blocks):
        model.add(ResnetBlock(256))

    # Decoding
    model.add(tf.keras.layers.Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same', use_bias=False))
    model.add(tf.keras.layers.Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', use_bias=False))
    model.add(ReflectionPad2d(3))
    model.add(tf.keras.layers.Conv2D(3, (7, 7), strides=(1, 1), padding='valid', activation='tanh'))

    return model

Loss Functions

There’re three types of loss we care about here:

  • To calculate the GAN loss, we measure the L2 distance (MSE) between the discriminator’s decision and the real/fake target label.
  • To calculate the cyclic loss, we measure the L1 distance (MAE) between the reconstructed image from the cycle and the original input image.
  • To calculate the identity loss, we measure the L1 distance (MAE) between the identity image and the input image.

GAN loss is the typical loss we use in GANs, and I won’t discuss it much here. The interesting parts are the cyclic loss and the identity loss. The cyclic loss measures how good the reconstructed image is, which helps both generators to catch the essential style difference between the two domains. The identity loss is optional, but it helps keep the generator from making unnecessary changes. The way it works is that applying generator A2B to a real B image shouldn’t change anything, as the input is already the desired outcome. According to the author, this mitigates some weird issues like background color changes.

def calc_gan_loss(prediction, is_real):
    # Typical GAN loss to set objectives for generator and discriminator
    if is_real:
        return mse_loss(prediction, tf.ones_like(prediction))
    else:
        return mse_loss(prediction, tf.zeros_like(prediction))

def calc_cycle_loss(reconstructed_images, real_images):
    # Cycle loss to make sure reconstructed image looks real
    return mae_loss(reconstructed_images, real_images)

def calc_identity_loss(identity_images, real_images):
    # Identity loss to make sure generator won't do unnecessary change
    # Ideally, feeding a real image to generator should generate itself
    return mae_loss(identity_images, real_images)

To combine all the losses, we also need to assign a weight to each loss to indicate its importance. In the paper, the author proposed two lambda parameters, which weight the cycle loss by 10 and the identity loss by 5. Note the usage of GradientTape here: we record the gradients and apply gradient descent for both generators together. Here, real_a is the truth image from domain A and real_b is the truth image from domain B; fake_a2b is the image generated from domain A to domain B, and fake_b2a is the image generated from domain B to domain A.

def train_generator(images_a, images_b):
    real_a = images_a
    real_b = images_b
    with tf.GradientTape() as tape:
        # Use real B to generate B should be identical
        identity_a2b = generator_a2b(real_b, training=True)
        identity_b2a = generator_b2a(real_a, training=True)
        loss_identity_a2b = calc_identity_loss(identity_a2b, real_b)
        loss_identity_b2a = calc_identity_loss(identity_b2a, real_a)

        # Run the two cycles: A -> B -> A and B -> A -> B
        fake_a2b = generator_a2b(real_a, training=True)
        recon_b2a = generator_b2a(fake_a2b, training=True)
        fake_b2a = generator_b2a(real_b, training=True)
        recon_a2b = generator_a2b(fake_b2a, training=True)

        # Generator A2B tries to trick Discriminator B that the generated image is B
        loss_gan_gen_a2b = calc_gan_loss(discriminator_b(fake_a2b, training=True), True)
        # Generator B2A tries to trick Discriminator A that the generated image is A
        loss_gan_gen_b2a = calc_gan_loss(discriminator_a(fake_b2a, training=True), True)
        loss_cycle_a2b2a = calc_cycle_loss(recon_b2a, real_a)
        loss_cycle_b2a2b = calc_cycle_loss(recon_a2b, real_b)

        # Total generator loss
        loss_gen_total = loss_gan_gen_a2b + loss_gan_gen_b2a \
            + (loss_cycle_a2b2a + loss_cycle_b2a2b) * 10 \
            + (loss_identity_a2b + loss_identity_b2a) * 5

    trainable_variables = generator_a2b.trainable_variables + generator_b2a.trainable_variables
    gradient_gen = tape.gradient(loss_gen_total, trainable_variables)
    optimizer_gen.apply_gradients(zip(gradient_gen, trainable_variables))

    # Return the generated images for the discriminator step, plus losses for logging
    return fake_a2b, fake_b2a, {'loss_gen_total': loss_gen_total}


Network Structure

Similar to other GANs, the discriminator consists of several 2D convolution layers that extract features from the generated image. However, to help the generator produce a high-resolution image, CycleGAN uses a technique called PatchGAN to create a fine-grained decision matrix instead of a single decision value. Each value in this 32×32 decision matrix maps to a patch of the generated image and indicates how real that patch is.

In fact, we don’t crop patches of the input image in the implementation. We just let a final convolution layer do the job for us: each value of the convolution output has a receptive field that covers a patch of the input, so the convolution essentially performs like cropping a patch.

def make_discriminator_model():
    # C64-C128-C256-C512
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Conv2D(64, (4, 4), strides=(2, 2), padding='same', input_shape=(256, 256, 3)))

    model.add(tf.keras.layers.Conv2D(128, (4, 4), strides=(2, 2), padding='same', use_bias=False))

    model.add(tf.keras.layers.Conv2D(256, (4, 4), strides=(2, 2), padding='same', use_bias=False))

    model.add(tf.keras.layers.Conv2D(512, (4, 4), strides=(1, 1), padding='same', use_bias=False))

    # This last conv net is the PatchGAN
    model.add(tf.keras.layers.Conv2D(1, (4, 4), strides=(1, 1), padding='same'))

    return model

Loss Function

The loss function for the discriminators is much more straightforward. Just like in typical GANs, we tell the discriminator to treat the truth image as real and the generated image as fake. So we have two losses for each discriminator, loss_real and loss_fake, both with an equal effect on the final loss. In calc_gan_loss, we are comparing two matrices: usually the output of a discriminator is just one value between 0 and 1, but since we use PatchGAN, as mentioned above, the discriminator produces one decision per patch, which forms a 32×32 decision matrix.

def train_discriminator(images_a, images_b, fake_a2b, fake_b2a):
    real_a = images_a
    real_b = images_b
    with tf.GradientTape() as tape:

        # Discriminator A should classify real_a as A
        loss_gan_dis_a_real = calc_gan_loss(discriminator_a(real_a, training=True), True)
        # Discriminator A should classify generated fake_b2a as not A
        loss_gan_dis_a_fake = calc_gan_loss(discriminator_a(fake_b2a, training=True), False)

        # Discriminator B should classify real_b as B
        loss_gan_dis_b_real = calc_gan_loss(discriminator_b(real_b, training=True), True)
        # Discriminator B should classify generated fake_a2b as not B
        loss_gan_dis_b_fake = calc_gan_loss(discriminator_b(fake_a2b, training=True), False)

        # Total discriminator loss
        loss_dis_a = (loss_gan_dis_a_real + loss_gan_dis_a_fake) * 0.5
        loss_dis_b = (loss_gan_dis_b_real + loss_gan_dis_b_fake) * 0.5
        loss_dis_total = loss_dis_a + loss_dis_b

    trainable_variables = discriminator_a.trainable_variables + discriminator_b.trainable_variables
    gradient_dis = tape.gradient(loss_dis_total, trainable_variables)
    optimizer_dis.apply_gradients(zip(gradient_dis, trainable_variables))

    # Return the losses for logging
    return {'loss_dis_a': loss_dis_a, 'loss_dis_b': loss_dis_b, 'loss_dis_total': loss_dis_total}


Now that we have defined both the models and the loss functions, we can put them together and start training. By default, eager mode is enabled in TensorFlow 2.0, so we don’t have to build the graph ourselves. However, if you are a careful person, you might have found that both the discriminator and generator training functions are decorated with a tf.function decorator. This is the new way introduced by TensorFlow 2.0 to replace the old tf.Session(): with this decorator, all operations within will be converted into a graph, so the performance can be much better than in the default eager mode. To learn more about tf.function, please refer to this article.

One thing to mention is that, instead of feeding the generated image to the discriminator directly, we actually use an image pool here. Each time, the image pool randomly decides whether to give the discriminator the newly generated image or a generated image from past steps. The benefit is that the discriminator can learn from older cases and sort of keeps a memory of the hacks the generator uses. Unfortunately, we can’t use this random image pool in graph mode at the moment, so we need to fall back to the CPU when selecting a random image from the pool, which indeed introduces some cost.
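The pool itself is simple; a sketch of what the query method could look like (the class name is my assumption, and a pool size of 50 follows the CycleGAN paper):

```python
import random

class ImagePool:
    """Buffer past generated images to stabilize discriminator training."""

    def __init__(self, capacity=50):
        self.capacity = capacity
        self.images = []

    def query(self, image):
        if len(self.images) < self.capacity:
            # Pool not full yet: remember the image and pass it through
            self.images.append(image)
            return image
        if random.random() < 0.5:
            # 50% chance: return an old image and store the new one in its place
            idx = random.randrange(self.capacity)
            old, self.images[idx] = self.images[idx], image
            return old
        # Otherwise: just return the newly generated image
        return image
```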

The model illustrated in this article was trained on my own GTX 1080 home computer, so it’s a bit slow. On a V100 16G GPU and 64G RAM instance, though, you should be able to set the mini-batch size to 4, and the trainer can process one epoch of 260 mini-batches in 3 minutes on the horse2zebra dataset. So it takes about 10 hours to fully train a horse2zebra model. If you reduce the image resolution and shrink some network parameters correspondingly, training could be faster. The final generator is about 44 MB each.

def train_step(images_a, images_b, epoch, step):
    # Train the generators first; they output the fake images and losses.
    fake_a2b, fake_b2a, gen_loss_dict = train_generator(images_a, images_b)

    # Query the image pools so the discriminators sometimes see older fakes.
    fake_b2a_from_pool = fake_pool_b2a.query(fake_b2a)
    fake_a2b_from_pool = fake_pool_a2b.query(fake_a2b)

    # Train the discriminators on real images and pooled fakes.
    dis_loss_dict = train_discriminator(images_a, images_b, fake_a2b_from_pool, fake_b2a_from_pool)

def train(dataset, epochs):
    for epoch in range(checkpoint.epoch+1, epochs+1):
        for (step, batch) in enumerate(dataset):
            train_step(batch[0], batch[1], epoch, step)

To see the full training script, please visit my repo here.


Let’s see some inference results on a few datasets. Among those, horse2zebra and monet2photo are the original datasets from the paper, and the CelebA dataset is from here.


Horse -> Zebra



Zebra -> Horse




Monet -> Photo



Photo -> Monet




Male -> Female



Female -> Male


Gender Swap

We successfully mapped a male face to a female face, but to use this in a production environment, we need a pipeline to orchestrate lots of other steps together. Snapchat may have its own optimizations or models, but here’s the procedure that I think would help improve the final result of our CycleGAN:

  1. Run face detection to find a bounding box and keypoints for the most dominant face in the picture.
  2. Extend the bounding box a little to better match the training dataset distribution.
  3. Crop the picture with this extended bounding box and run CycleGAN over it.
  4. Patch the generated image back into the original picture.
  5. Overlay some hair, eyeliner, or a beard on top of the new face picture based on the keypoints we got from the face detection step.

I got my inspiration mostly from this great article that explains how such a pipeline works in detail.


Lastly, I want to throw out some questions I have. I don’t know the answers to them, but I hope those who have experience building similar products can share their opinions in the comments below.

  • The CycleGAN model turns out to be 44 MB; with quantization it could shrink to 12 MB, but that’s still too large. What are effective methods to make it usable on mobile and embedded devices?
  • The output resolution isn’t great and loses much of the sharpness. How can we generate a bigger image, such as 1024×1024, without blowing up the model size? Would a super-resolution model help in this case?
  • How do we know if a model is thoroughly trained and has converged? The loss isn’t a good metric here, but we also don’t know what the “best” output looks like. How do we measure the similarity between the two styles?


  1. Understanding and Implementing CycleGAN in TensorFlow
  2. Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
  3. Image-to-Image Translation with Conditional Adversarial Networks
  4. Gender and Race Change on Your Selfie with Neural Nets

Advertising System Architecture: How to Build a Production Ad Platform


Digital advertising has become one of the most important revenue sources for lots of technology companies nowadays. Google, Facebook, Twitter and Snapchat, all these famous names monetize heavily on people’s attention. It may also be one of the most important new businesses of the 21st century. Unlike traditional advertising techniques, digital advertising unlocks far more reliable insights about campaign performance. Therefore, more and more advertisers are putting more budget into online ads, and more and more technology companies are offering advertising options on their platforms. However, how in the world do these ad engines work behind the scenes? What if we want to build our own ads system to serve ads on our platform as well? (To be honest, though, for a small business or individual developer, I would suggest using an existing solution like Google DFP.) Today, I’m going to show you the technical design of such a system, so maybe you can get some inspiration after reading.

System Overview

An ads system is no different from a live exchange. It exchanges information between two parties: advertisers and end users. In an ideal world, we would just need: 1) a web interface where advertisers or internal salespeople can manage orders; 2) a client application where the ads get displayed; 3) a backend system that maintains these orders and filters out which ad to present. Unfortunately, this design won’t scale well for a serious advertising business, due to a few constraints.

The first constraint is frequency. Advertisers set up their target audience and budget to reach as many people as possible. However, as a platform, you don’t want to hurt the user experience by showing too many ads either. Even if you are always suggesting useful information to users, it’s still a distraction from your primary service.

The second constraint is speed. Unlike posts, images, and videos, ads are a lot more dynamic. An advertising campaign may reach its budget at any time and may get canceled at any second. Still, to make ads a consistent part of our platform, we have to make delivery lightning fast. When the ads response time is longer than that of other services, we may even have to drop the advertising opportunity completely.

The third constraint is quality. Satisfying the first two constraints makes the ads engine run smoothly, but to maximize your profit you have to provide the right ad where it’s needed the most. Advertisers are usually willing to pay tens or hundreds of times more if you can get a click on their website or an install of their app. So, the more interactions you gain from users, the more money you make. Hence we need to maximize the probability of such interactions.

All these constraints make a simple information exchange problem much more complicated, and this is why we also need to build so many peripheral components: a ranking service to improve ad quality, a pacing service to control ad frequency, a low-latency delivery module to guarantee the response time, and so on.

The above image shows the overall architecture of a fully functioning ads system. In bigger companies like Google and Facebook, the system may have more enhancements to meet specific needs, but this architecture gives you a pretty good start. As you can see, advertising is a cycle between advertisers and end users. The system we build first collects campaign information from advertisers, then builds ad indices based on pacing and forecasting. With this ad index and the user profile from a targeting service, the ad server asks the ranking service to score candidate ads and find the most suitable one for the current ad request. Once the user finishes interacting with the ad, whether skipping or clicking it, we collect the metrics and send them back to the metrics collectors. A data pipeline behind these collectors aggregates the event history into more valuable real-time stats and business logs. The real-time stats, in turn, help the ad server and pacing service control delivery more accurately. The business logs are used by inventory forecasting and billing, so that advertisers can set up more campaigns to experiment with their marketing strategy.

Next, I will dive deep into each component in this ads system to discuss some more technical details.


Web Interface

As I mentioned above, an ads system is essentially an information exchange. So the first thing is to have a way to feed all the information into this exchange. Usually, we set up a web interface to help people manage their advertising campaigns. This web interface is typically a single-page JavaScript application that can handle complicated form input, large tables, and rich multimedia content. Just register an account at Google Ads, Facebook Ads Manager, or Snap Ads Manager, and you will get a rough idea of what this UI should look like.

Before I jump into the specific technical problems, let’s get familiar with the typical digital advertising hierarchy first. Usually, an advertiser creates a Campaign, which represents a marketing initiative. There’s also the notion of an Ad, the single unit of advertising that gets delivered to the audience. To provide fine-grained control, large advertising platforms introduce another layer called Ad Set or Line Item between Campaign and Ad, to group a bunch of similar Ads. Also, to make it possible for an advertiser to switch out the actual content of an Ad, there’s the notion of a Creative or Media, which is used to represent the content of an Ad. This abstraction helps us decouple the logic well. These four entities are the most important things needed to make an ad run, but there are also some auxiliary entities, just like on any other platform: Account, Payment Source, Predefined Audience, etc. For now, let’s focus on the main flow.

One of the biggest challenges an advertiser web UI faces is complex state management. For an application like this, there could be hundreds or even thousands of different states to track at the same time. For example, to make it possible to buy a simple video ad, the software needs to maintain all the hierarchy, metadata, and targeting information in a temporary place before the user commits the data. Each entity also has tons of fields that might interlock with other fields: the objective of a Campaign could affect the type of ad you can buy; the location targeting of an Ad Set could define a minimum daily budget; the type of an Ad could limit the options for its “call to action”.

Another big challenge is the large variety of UI components required by the business flow. For instance, take a look at the targeting section when creating a campaign. Location targeting needs a map component. The audience group requires a tree-select component. The duration setting is a datetime range selector. Furthermore, all these components could appear anywhere in your application, so you’d better reuse, or at least abstract, most of them.

The third challenge I would like to mention is the big-table experience. As we know, there are tens of fields that control how an advertising campaign runs. On top of that, tons of other columns in the table are used to report metrics for a given entity. Different advertisers may rely on different metrics to measure the performance of their ads, so your main table should be versatile enough to show tons of columns in any order or according to any preference.

Thankfully, Facebook open-sourced its UI framework React.js a few years ago. With the Flux philosophy and good encapsulation of components, I believe it’s by far the most comfortable way to build an advertiser web app. The Flux pattern addresses the headache of intertwined states, and JSX makes it easy to write reusable UI components. In addition, you could add Redux to your tech stack to make state transitions more predictable and maintainable.

With the proper tech stack, we can now divide this web UI into the following major areas:

  • Home Tables: Where all the entities like Campaign, Ad are shown. Further editing could be made from here too.
  • Creation Flow: A wizard form that helps advertisers (or internal ad ops) place an order step by step.
  • Stats and Reporting: Where advertisers can track ad performance, like impression counts, and also export the data for reference.
  • Asset Library: A place to manage all the multimedia content a user has uploaded. Most of the time, there are dedicated creative studios that help advertisers make the assets. But we can also offer some basic media editing web tools to help smaller businesses that don’t have a budget for professional creative services.
  • Billing and Payment: To manage payment sources and view billing information. This part is usually integrated with Braintree APIs to support more funding sources.
  • Account: To manage the role hierarchy and access control of multiple users. Usually, a big organization has multiple ad accounts, and different roles have different access.
  • Audience Manager: This might not be necessary at the beginning. However, it can be convenient for frequent ad buyers to define reusable audience groups so that they don’t need to apply complex targeting logic every time.
  • Onboarding: This may be the most critical part of the system at an early stage. A good onboarding experience can increase sign-ups significantly.

With a good web interface, the friction of buying an ad in our system can be reduced to a minimum. However, keep in mind that the graphical user interface isn’t the only entrance to our system. Next, let’s take a look at the Ads APIs (Application Programming Interfaces).


First of all, what are Ads APIs? Why do we need them? As you can see from the last section, our web interface needs to handle a very sophisticated form and help our clients (advertisers) manage their orders. To persist all these changes and also provide data for the UI, we need a service layer to do the CRUD operations. However, if that were the only functionality, we wouldn’t call it an API; it could just be like any other backend.

In fact, advertisers usually won’t put all their eggs in one basket. In addition to using the web interface we build for them, they will also consult with advertising agencies and spend part of their budget there. Those ad agencies usually have their own software to track marketing campaigns and have direct access to all the major digital advertising platforms. Therefore, the Ads APIs are also meant to be used by agencies. If we build our in-house solution on top of this external Ads API, we will be able to identify problems earlier than our API customers do.

To work with third-party agencies, the best bet is often to build RESTful APIs, as it’s almost the standard way for two unfamiliar parties to communicate. There are fancier solutions like gRPC and GraphQL, but RESTful guarantees the most compatibility here. Besides, it’s also easier to write public documentation for RESTful APIs because things are grouped by resources and methods. Take a look at Twitter’s API references:

Now that we have a general idea of how to structure our APIs, I want to briefly talk about four pillars in implementing them.

Campaign Management
In short, campaign management is the CRUD of all the advertising entities I listed in the last section. However, from a business perspective, it’s more complicated than just persisting some values. First, we need to deal with tons of validation for all sorts of business rules, so a proper design pattern like Facade and consistent abstraction using interfaces and inheritance are important. Second, to keep the workflow manageable, it’s also common to use a state machine to maintain the different statuses of campaigns, ad sets, and ads. Third, lots of ads operations are long-running or asynchronous. For example, whenever a new video gets uploaded, we need to trigger a build process to prepare different resolutions/encodings of the video so it’s usable on multiple platforms. These asynchronous jobs are usually controlled by a PubSub or TaskQueue system and integrate with the state machine I mentioned before. Fourth, your database should support transactions in a scalable way, because most operations in campaign management affect multiple tables.
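As a sketch of the state-machine idea, here is a minimal transition table for campaign status. The states and allowed transitions are illustrative, not an actual production rule set:

```python
# Illustrative status state machine for a campaign; each state maps to the
# set of states it may legally move to.
ALLOWED_TRANSITIONS = {
    "DRAFT": {"PENDING_REVIEW"},
    "PENDING_REVIEW": {"ACTIVE", "REJECTED"},
    "ACTIVE": {"PAUSED", "COMPLETED"},
    "PAUSED": {"ACTIVE", "COMPLETED"},
    "REJECTED": {"DRAFT"},
    "COMPLETED": set(),  # terminal state
}

class Campaign:
    def __init__(self):
        self.status = "DRAFT"

    def transition(self, new_status):
        # Reject anything the table does not explicitly allow.
        if new_status not in ALLOWED_TRANSITIONS[self.status]:
            raise ValueError(f"illegal transition {self.status} -> {new_status}")
        self.status = new_status
```

In a real system, each transition would also fire side effects (notifying the index publisher, kicking off review jobs, etc.) through the PubSub/TaskQueue layer mentioned above.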

Access Control
Since these APIs will eventually be used by external users, we need to be really careful here about AuthN and AuthZ. A typical cookie-session authentication system is acceptable, but more often a token secret is preferred for stateless RESTful APIs. We can either use third-party OAuth 2.0 or build our own. A simple JWT token exchange may not be secure enough, as we are talking about a real-money business here. Authorization is also a big part. In an ads system, there could be tens of different roles, like Account Manager, Creative Manager, Account Executive, etc. A creative manager may upload new creatives but not be allowed to create a new campaign. An account manager is allowed to create a new campaign, but only an account executive is allowed to activate it. Luckily, most entities in the ads system fall into some hierarchy: accounts belong to an organization, campaigns belong to an account, etc. Hence, we can adopt the chain-of-responsibility pattern and track access to a certain entity through the hierarchy chain.
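A minimal sketch of this hierarchy-based authorization check might look like the following, where `allow` records a role grant on any entity and `can` walks up the parent chain. The entity and grant shapes are made up for illustration:

```python
class Entity:
    """A node in the ads hierarchy: organization -> account -> campaign -> ..."""

    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent
        self.grants = {}  # user -> set of allowed actions on this subtree

    def allow(self, user, action):
        self.grants.setdefault(user, set()).add(action)

    def can(self, user, action):
        # Chain of responsibility: check this node, then defer to the parent.
        node = self
        while node is not None:
            if action in node.grants.get(user, set()):
                return True
            node = node.parent
        return False
```

A grant placed at the organization level then automatically applies to every account and campaign underneath it, which matches the hierarchy described above.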

Billing and Payment
In the beginning, you may rely on internal operations and direct sales of your ads, so lines of credit and coupons could be enough to handle payments. After all, the operations team can do it all in their ERP or even a spreadsheet. However, to accept a public payment source like a credit card, you will most likely need a third-party service to help you. Braintree and Stripe are both leading payment solutions for the enterprise. Another concern with external payment sources is the risk of abuse and spam; proper rate limits, anomaly alarms, and regular audits should be enforced to mitigate that risk.

Metrics and Reporting
Last but not least, the Ads API also handles reporting for both ad agencies and our web interface. The challenge of reporting can vary a lot based on the data warehouse and the way you collect metrics, so be careful when you design the metrics collection pipeline, especially for aggregated results. For example, some data might not be available when the granularity of a stats query is WEEKLY. One thing that remains the same for all reporting services, though, is that the QPS is usually higher than for other endpoints. You can support batch stats queries to reduce the number of requests, but still, people check metrics much more often than they actually place a new order.

There are still many more modules in a real production Ads API system, such as budget control, the content review pipeline, and so on. We can also incorporate machine learning models to do auto-tagging and expedite the review process. However, you should now have a good understanding of where to start.

Ad Index Publisher

Advertisers can now create a campaign and throw their own image or video ads into our system. This is great. From now on, we are stepping into the user’s side of the system. To determine which ads should be delivered to which user, the first thing we need to know is which ads are active at this moment. The easiest way to find out is to query the database and filter by the status field. However, that query usually takes too long and can’t meet our speed requirement. The database tables are usually structured for easy writes but not easy reads. For example, you could have four tables that record Campaign, Ad Set, Ad, and Creative. This is convenient when we want to update some values for a specific entity. However, when we serve ads, we have to query four tables just to know whether a Creative and all its parents are in active status.

To solve this issue, we need an index publisher to pre-calculate lots of useful indices and save time for later serving. It publishes the indices to some storage service, and the ad server loads them into memory periodically. One of the challenges in generating the live index is the various business rules we need to apply, like spend limits and ad scheduling. Also, those tables relate to each other and require very complex validation. To manage the dependencies here, we could introduce Spark or Dataflow, which often leads to a multi-stage data pipeline like this:

In general, we need to generate three types of index:

Live Index
This is an index which tells us all the live ads in the system. Also, it contains all the necessary information that the ad server needs to form an ad response, such as the resource location and ad metadata. Besides the primary index from id to metadata, we could also build some secondary indices based on targeting rules. The ad server uses these secondary indices to filter out irrelevant ads and only preserve the best candidates for auctions. I’ll discuss auction and filtering in the ad server section.
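To make the idea concrete, here is an illustrative (much simplified) shape for such a live index: a primary id-to-metadata map plus secondary indices keyed by targeting attributes for fast candidate filtering. All the field names and values here are hypothetical:

```python
# Hypothetical in-memory live index, as the ad server might load it.
live_index = {
    "ads": {
        "ad_1": {"creative_url": "https://cdn.example.com/ad_1.mp4",
                 "campaign_id": "c_9", "bid": 1.2},
    },
    # Secondary indices: targeting attribute -> list of eligible ad ids.
    "by_country": {"US": ["ad_1"]},
    "by_age_bucket": {"18-24": ["ad_1"]},
}

def candidates(country, age_bucket):
    # Intersect the secondary indices to get ads eligible for this request,
    # then resolve the ids back to full metadata via the primary index.
    ids = set(live_index["by_country"].get(country, [])) & \
          set(live_index["by_age_bucket"].get(age_bucket, []))
    return [live_index["ads"][i] for i in ids]
```

The real index would of course carry many more targeting dimensions and be rebuilt by the publisher pipeline rather than written by hand.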

Pacing Index
Another index we need to prepare concerns pacing status and factors. We intentionally separate this from the live index because pacing usually requires much more calculation, so we want it to be independent. It also makes our system more resilient, because we can still serve live ads when there’s a pacing issue.

Feature Index
This index contains ad features that will be used by the ranker later. We can also replace this index with a low-latency datastore like Cassandra or an in-memory database like Redis.


Before we get into the actual delivery of an ad, let’s consider a scenario first. An advertiser may want to run an ad over a month to get 100K impressions. However, we don’t want to exhaust all the impressions in the first few hours of the campaign; instead, we’d like to deliver them throughout the campaign’s lifetime. Also, from the end user’s perspective, we don’t want to overwhelm them with the same ad at one time while showing them nothing at another time, which causes so-called ad fatigue. The mechanism that controls this delivery process is called pacing. Pacing is like a north star for our ad system; the pacing index we generate will guide the direction of ad delivery later.

The simplest way to do this is to split the budget into hourly or per-minute chunks. If an ad exhausts its budget within the current minute, it gets filtered out from index publishing. However, this approach doesn’t offer much flexibility for finer-grained control, and delivery still bursts within each smaller period of time.
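The naive budget split might be sketched like this, assuming the budget is simply spread evenly over the campaign's lifetime:

```python
def minute_budget(total_budget, lifetime_minutes):
    # Naive pacing: spread the campaign budget evenly across every minute.
    return total_budget / lifetime_minutes

def keep_in_live_index(spent_this_minute, total_budget, lifetime_minutes):
    # Drop the ad from index publishing once this minute's slice is spent;
    # it re-enters the index when the next minute's slice opens up.
    return spent_this_minute < minute_budget(total_budget, lifetime_minutes)
```

The burstiness problem is visible right away: nothing stops the entire per-minute slice from being spent in the first second of the minute.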

One of the most traditional ways to control the pace is the PID controller.

For a detailed explanation of the PID controller, you can refer to Wikipedia. In short, by analyzing the difference between the desired state and the current state, this controller tells us how much input we should give to the delivery system. If pacing is lagging, the controller tells us to give a bigger input, which translates to a higher pacing factor. A higher pacing factor leads to a higher bid, thus beating everyone else to win the ad opportunity.

It’s easy to know the current state (current total impressions, current total clicks, etc.) by connecting to the metrics system. But what about the desired state? As we know, the X-axis of this PID controller is time. What about the Y-axis? To start simple, we can project a linear line from 0 to the total impressions (or clicks, depending on the configuration) the ad wants to reach. Instead of using the total number as the Y-axis of our PID controller, we use the rate (the number of delivered impressions per minute) to represent the desired state. If our current delivery rate is lower than the desired rate, the pacing factor needs to increase.

One thing to notice is that a PID controller can sometimes be too slow to start, and the fluctuation can be pretty significant. Therefore we could introduce some more multipliers into the formula. With this pacing factor, we can now implement a simple pacing service to ensure the smooth delivery of ads. There are also other pacing techniques that address particular problems. For example, we can maintain a local pacing factor for each ad server machine to balance out the differences in a distributed system. Moreover, a reach-and-frequency factor could be introduced to make sure the frequency requirement is satisfied.
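A bare-bones pacing PID controller could look like the following sketch. The gains and the mapping from correction to pacing factor are illustrative and would need tuning against real traffic:

```python
class PacingPID:
    """PID controller sketch for the pacing factor (gains are illustrative)."""

    def __init__(self, kp=0.5, ki=0.1, kd=0.05):
        self.kp, self.ki, self.kd = kp, ki, kd
        self.integral = 0.0
        self.prev_error = 0.0

    def update(self, desired_rate, actual_rate):
        # error > 0 means delivery is lagging behind the desired rate.
        error = desired_rate - actual_rate
        self.integral += error
        derivative = error - self.prev_error
        self.prev_error = error
        correction = (self.kp * error
                      + self.ki * self.integral
                      + self.kd * derivative)
        # Map the correction into a non-negative pacing factor around 1.0.
        return max(0.0, 1.0 + correction)
```

Called once per pacing interval with the desired and observed delivery rates, this returns the factor that later multiplies the bid in the auction.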


In this section, I’m going to discuss the heart of the ad engine: the auction. An auction is a process of buying and selling goods or services by offering them up for bid, taking bids, and then selling the item to the highest bidder. In auction-based ads systems, ads from different advertisers participate in an auction, and the ad with the highest bid wins and is shown to the user.

When we combine the estimated value with the incoming inventory (opportunity) information from an ad request, we can determine which ad to show for this opportunity. The trick to adding pacing to an auction system is to multiply the bid value by the pacing factor: the more urgent an ad is, the bigger the pacing factor, and therefore the higher the bid.

To build an auction house, the first thing you need to decide is which auction strategy to use. Let’s assume we are going to use the most common strategy, the second-price auction. In real-time bidding, the second-price auction gives the winner a chance to pay a little less than their originally submitted offer. Instead of having to pay the full price, the winning bidder pays the price offered by the second-highest bidder (or that plus $0.01), which is called the clearing price. This leads to an equilibrium in which all bidders are incentivized to bid their true value. Although advertisers give us a maximum bid value, we also need to consider other factors, like the probability of the event occurring and our pacing status. Using a formula like the one below, we can calculate the real bid price to use in the current auction.

Total Bid = k * EAV + EOV
Where k is the pacing factor, EAV is the estimated advertiser value, and EOV is the estimated organic value

Advertiser value is the actual return that an advertiser could gain from displaying this ad. Different advertising goal might have different ways to calculate EAV. If we want to optimize for clicking, then:

EAV_click = b_click * p_click
Where b_click is the max bid price for getting a click, and p_click is the probability of getting a click

Organic value is the benefit to the platform or user experience. To calculate the advertiser value, the simplest way is to multiply the probability of an event by the max bid price for that type of event. For example, if the probability of getting a click is 0.1 when we deliver this ad, and the advertiser is willing to pay $1 per click, the EAV is 0.1 × 1 = 0.1. The formula for organic value differs from platform to platform. For example, skipping an ad could mean a bad user experience, so the organic value could assign a negative weight to skips.

EOV = p_click * w_click + p_finish * w_finish + ... - p_skip * w_skip
Where p_event is the probability of such an event, and w_event is the weight that the event contributes to the final value
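Putting the two formulas together, a sketch of the bid computation might look like this. The event names and weights are made up for illustration:

```python
def estimated_advertiser_value(bids, probs):
    # EAV = sum of b_event * p_event over the events the advertiser pays for.
    return sum(bids[e] * probs[e] for e in bids)

def estimated_organic_value(probs, weights):
    # EOV = sum of p_event * w_event; negative weights penalize bad events
    # such as skips.
    return sum(probs[e] * weights[e] for e in weights)

def total_bid(pacing_factor, bids, probs, weights):
    # Total Bid = k * EAV + EOV, as in the formula above.
    return (pacing_factor * estimated_advertiser_value(bids, probs)
            + estimated_organic_value(probs, weights))
```

With the $1-per-click example above (p_click = 0.1, b_click = 1), EAV comes out to 0.1, and the organic term then shifts the total bid up or down.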

In a more complex auction system, more formulas can be designed to reflect the real bid price based on different business priorities. Bear in mind that the math you use determines the flavor of your ad engine. For guaranteed delivery, we could make up a high bid to have an ad win over all the others in the auction.

Note that the bid price here is only for the auction. Since we are adopting a second-price auction, the advertiser only pays the price offered by the second-highest bidder. Also, if we would like to penalize ads with a bad user experience, we can do:

price = Total Bid(second highest bidder) - EOV(winner)

By doing so, if the winning ad has great value to users, the price drops and its ROI increases, and vice versa.

So far, we’ve talked about the strategy to use when comparing different bidders and how to calculate the real bid price. The final piece of the puzzle is the actual auction engine that connects all of this. When a request comes into the auction engine, we first sort all the candidate ads and find the highest bidder for the current opportunity. Sometimes there are business reasons to group candidates into several priority groups; for instance, a first-party promotion could be more important than anything else. Naturally, the auction engine then becomes like a waterfall: the opportunity request falls through each priority group and runs the auction only within that group. Only when no suitable candidate is found does the request go to the next group. One caveat is that delivering ads has a cost as well (network bandwidth, computation resources, etc.), so we could also set a floor bid price to filter out candidates with a negligible bid.
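The waterfall described above can be sketched as follows; the floor price, the $0.01 increment, and the candidate representation are assumptions for illustration:

```python
def run_auction(priority_groups, floor_price=0.01):
    """Waterfall second-price auction sketch.

    priority_groups: list of groups (highest priority first); each group is
    a list of (ad_id, total_bid) candidates.
    Returns (winning ad_id, clearing price), or None if no group qualifies.
    """
    for group in priority_groups:
        # Drop candidates below the floor, then sort by bid, highest first.
        candidates = sorted(
            (c for c in group if c[1] >= floor_price),
            key=lambda c: c[1], reverse=True)
        if not candidates:
            continue  # waterfall: fall through to the next priority group
        winner_id, winner_bid = candidates[0]
        # Second-price: winner pays the runner-up's bid plus a cent,
        # never more than its own bid; alone in the group, it pays the floor.
        if len(candidates) > 1:
            price = min(winner_bid, candidates[1][1] + 0.01)
        else:
            price = floor_price
        return winner_id, price
    return None
```

Note that the higher-priority group wins even when a lower group contains bigger bids, which is exactly the first-party-promotion behavior described above.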

However, there’s a limitation here: we are only auctioning one item (like one ad slot) at a time. In some cases, we need to auction multiple items together because they are related. Here, we could implement the generalized second-price auction (GSP), a non-truthful auction mechanism for multiple items. Each bidder places a bid. The highest bidder gets the first slot, the second-highest the second slot, and so on; but the highest bidder pays the price bid by the second-highest bidder, the second-highest pays the price bid by the third-highest, and so on.
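A minimal GSP allocation might look like this sketch, where the i-th highest bidder wins slot i and pays the (i+1)-th highest bid:

```python
def gsp_auction(bids, num_slots):
    """Generalized second-price sketch.

    bids: mapping of bidder -> bid price.
    Returns a list of (bidder, price) pairs, one per filled slot.
    """
    ranked = sorted(bids.items(), key=lambda kv: kv[1], reverse=True)
    results = []
    for i in range(min(num_slots, len(ranked))):
        bidder, _ = ranked[i]
        # Each winner pays the next-highest bid; the last ranked bidder
        # pays nothing in this simplified version (a real system would
        # apply a floor price instead).
        price = ranked[i + 1][1] if i + 1 < len(ranked) else 0.0
        results.append((bidder, price))
    return results
```

With bids {a: 3, b: 2, c: 1} and two slots, a wins slot 1 paying b's bid, and b wins slot 2 paying c's bid.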


In general, the problem an ads ranking system tries to solve is this: given an opportunity to pick an ad for a user, deliver the best ad, taking into account the user’s past behavior, interests, and the ads’ historical performance, so as to maximize ROI for the advertisers. So how does it work? Remember the formula for calculating the real bid price of each ad? We need not only the ad’s max bid price but also the probability of the event and the weights of the organic values. In the most naive version of the ad server, the ranker could return a hardcoded score, and the auction engine would loop through all candidate ads and calculate the EOV and EAV super quickly. However, if we extract this ranker into a separate service, we can incorporate techniques like machine learning to improve performance and profit.

To build such a machine learning pipeline, we need to do data cleaning and aggregation first. This step, also called feature engineering, extracts useful information from the raw user and ad metrics. Common features include, but are not limited to:

  • Context features: time, location, device, session, etc
  • Demographic features: gender, age, nationality, language, etc
  • Engagement features: view count, click count, etc
  • Ad features: campaign id, brand id, account id, etc
  • Fatigue features: number of times the user has seen this ad, brand, etc
  • Content features: category, tags, object detection, etc

With the help of Dataflow/Spark/Flink, we can batch-transform all these features and put them into a feature store. The next step is training a model on these features. Unlike computer vision or NLP, the machine learning techniques in ads are usually more straightforward and more mature, such as sparse logistic regression and gradient-boosted decision trees. We can build a generic trainer using TensorFlow, XGBoost, or scikit-learn, and feed different config protobufs to the trainer to run experiments.
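As a toy example of such a trainer, here is a logistic regression click model with scikit-learn on a handful of made-up feature rows; a real pipeline would train on far more data and far richer features:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Tiny, fabricated training set. Columns (a few of the features listed
# above): hour of day, user age, past view count, times this ad was seen.
X = np.array([
    [9, 25, 100, 0],
    [21, 34, 5, 3],
    [14, 19, 50, 1],
    [23, 45, 2, 7],
])
y = np.array([1, 0, 1, 0])  # clicked or not

model = LogisticRegression().fit(X, y)

# The predicted click probability is exactly the p_click that feeds into
# the EAV formula during the auction.
p_click = model.predict_proba(X[:1])[0, 1]
```

The trainer would then serialize `model` to the remote model repository for the ranker to pick up, as described below.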

When the trainer finishes training and passes all the validation and evaluation, it publishes a model file to some remote repository and notifies the ranker to pick up this new model. Often a new model is only deployed to a small subset of the production fleet to verify the performance.

Because some of the input features come from real-time data, we also need to build a feature index on a timely basis. As I discussed earlier, this index is prepared by the ad index publisher. At runtime, the ranker looks up things like ad stats and user engagement metrics from this index and performs the inference.

In a production machine learning pipeline, a central scheduling system orchestrates all the tasks. The system maintains the job information, tracks the status of each task, and helps manage the DAG of the workflow. It’s not hard to write this service from the ground up, but using container operators in Airflow or Kubeflow is also a good option.


Forecasting, or Inventory Forecasting, is a way to predict the number of impressions for a particular targeting requirement at a given time. Although not necessary for a minimum viable version of the ad system, inventory forecasting can be useful in lots of aspects. It can give our sales team or web interface an idea about how much future inventory can be booked. It can also improve the pacing system by providing the actual and predicted traffic, thus serving more impressions during periods of plenty, and tightening down when traffic is scant. Furthermore, it can also be used in new feature testing and debugging.

A naive implementation would be to count the potential future inventory (e.g., impressions) by looking back at the historical data and categorizing it into different buckets. Then, we take the new ad (or ad set, depending on where the targeting rules are applied), find the buckets it falls into, and accumulate the inventory. However, the downside is that we need to maintain these categories separately from our production serving logic, which is not entirely scalable.
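The naive bucket approach can be sketched in a few lines. The bucket key here (country, occupation) and the data shapes are illustrative assumptions; a real system would bucket on many more dimensions:

```python
from collections import Counter

def build_buckets(historical_impressions):
    """Count past impressions per (country, occupation) bucket."""
    buckets = Counter()
    for imp in historical_impressions:
        buckets[(imp["country"], imp["occupation"])] += 1
    return buckets

def forecast_inventory(ad_targeting, buckets):
    """Sum inventory over every bucket the new ad's targeting matches."""
    return sum(
        count for (country, occupation), count in buckets.items()
        if country in ad_targeting["country"]
        and occupation in ad_targeting["occupation"]
    )

history = [
    {"country": "us", "occupation": "teacher"},
    {"country": "us", "occupation": "teacher"},
    {"country": "us", "occupation": "engineer"},
    {"country": "ca", "occupation": "teacher"},
]
buckets = build_buckets(history)
print(forecast_inventory({"country": ["us"], "occupation": ["teacher"]}, buckets))
# → 2
```

The scalability problem shows up right here: every targeting dimension the ad server supports has to be mirrored in the bucket key, and the two code paths drift apart over time.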

A more robust solution is to take the logic from the production serving path and simulate it. By simulation, I mean taking the new ad, pretending that it's a live ad, and feeding in the historical impressions to try to "serve" it. Since it uses the same logic as our production ad server, it more accurately reflects future behavior.

However, if the traffic is enormous, a full simulation could take too long to run. If that's the case, we need some techniques to improve performance. Generally, there are a few methods:

  • Downsample the historical requests: For example, only keep the impressions from 1/10th of the users. The reason we sample by user instead of by impression is that frequency capping affects delivery based on how many impressions each user has seen.
  • Turn off some server functionality: Some features of the ad server are not necessary in the simulation. If we turn off unimportant ones, for example by using a fixed weight instead of querying the ranking service, the runtime of each simulated auction is reduced.
  • Assign the inventory for existing ads separately: When we forecast the inventory for a new ad, we can reuse the inventory assignment for the other existing ads from a recent forecasting job. Furthermore, we could set up a separate job to prepare the assignment for existing ads.
  • Parallelize the simulation: Now that we have already assigned slots for existing ads, we only need to deal with one new ad at a time. Therefore, there won't be any interference from other ads, and we can simulate the delivery on multiple machines in parallel.
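Per-user downsampling from the first bullet can be done deterministically by hashing the user id, so the same users are kept across runs. A sketch under the assumption of a roughly 10% sampling rate (the data shape is illustrative):

```python
import hashlib

def keep_user(user_id, rate=0.1):
    """Deterministically keep roughly `rate` of users by hashing the user id.

    Sampling whole users (rather than random individual impressions)
    preserves per-user frequency-capping behavior in the simulation.
    """
    digest = hashlib.md5(user_id.encode()).hexdigest()
    return int(digest, 16) % 1000 < rate * 1000

# Hypothetical historical log: 10,000 users with one impression each.
impressions = [{"user_id": f"user-{i}", "ts": i} for i in range(10_000)]
sample = [imp for imp in impressions if keep_user(imp["user_id"])]

# Roughly 10% of users (and all of their impressions) survive.
print(len(sample) / len(impressions))
```

Because the hash is deterministic, rerunning the forecast on the same history yields the same sample, which keeps results comparable across forecasting jobs.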

This simulation design helps us understand how much inventory we have for a new ad, but let's also look at other benefits of simulation. First, the events generated from the simulated run can be used for event estimation: we can now predict how many clicks or app installs we could get from this ad. Second, we can further develop a reach and frequency booking system. Sometimes, an advertiser would like to book inventory that guarantees reaching X users Y times. If an ad has booked 100K impressions, we take these 100K impressions out of our future simulations to make sure we don't overbook.


Although machine learning can help us find potentially good matches between ads and users, advertisers often have their own opinions about where an ad should be delivered. Sometimes the target audience is really broad, as for brand-awareness advertisers like Coca-Cola. Sometimes the scope is a small niche market, like certain zip codes or certain hobbies. Thus, we need a structural definition of these targeting rules. One confusing part is the AND/OR logic and the INCLUDE/EXCLUDE semantics between targeting rules.

```json
{
    "demographics": [
        { "occupation": ["teacher"] }
    ],
    "geolocations": [
        { "country": ["us"] }
    ]
}
```

Although this structure is easy for humans to read and edit, it's not so easy for our ad server to determine whether there's a targeting match for a given targeting spec. A simple way to address this is to flatten the nested JSON blob into the flat list below, so that we can loop through it during candidate filtering to find whether there's a match.

```json
[
    {
        "operation": "equals",
        "value": "us",
        "key": "country",
        "group": "geolocations"
    },
    {
        "operation": "equals",
        "value": "teacher",
        "key": "occupation",
        "group": "demographics"
    }
]
```
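Matching this flat list against a user profile is then a simple loop. A sketch under the simplifying assumptions that only the "equals" operation exists and all rules are ANDed together (INCLUDE/EXCLUDE and OR groups would extend this, or be replaced by the boolean expression matching discussed next):

```python
def matches(rules, profile):
    """Return True if the profile satisfies every flattened targeting rule.

    Assumes AND semantics across rules and only the "equals" operation;
    a real matcher would also handle INCLUDE/EXCLUDE and OR groups.
    """
    for rule in rules:
        if rule["operation"] == "equals":
            if profile.get(rule["key"]) != rule["value"]:
                return False
        else:
            raise ValueError(f"unsupported operation: {rule['operation']}")
    return True

rules = [
    {"operation": "equals", "value": "us", "key": "country", "group": "geolocations"},
    {"operation": "equals", "value": "teacher", "key": "occupation", "group": "demographics"},
]
print(matches(rules, {"country": "us", "occupation": "teacher"}))   # → True
print(matches(rules, {"country": "us", "occupation": "engineer"}))  # → False
```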

Another way to address this is boolean expression matching. The need for fast boolean expression matching first arose in the design of high-throughput, low-latency PubSub systems. Later, researchers realized that e-commerce and digital advertising face the same issue. Some standard algorithms are K-Index [6], Interval Matching [7], and BE-Tree [8].

With the targeting spec above, we now understand who the target audience of a given ad is. However, we still need to associate it with the end user. To make this connection, a user profile is required when we match an ad with an ad opportunity. This user profile is compiled from all sorts of historical user data, such as purchase history and browsing history. (Yes, this is also why Facebook is so infamous on privacy: they profit from your data, but they won't tell you.) A targeting pipeline can infer user interests and demographics from this history. If you don't have enough data about the end user to make such a connection, you could also integrate with a third-party data company to get more insights.

Ad Server

So far, I’ve introduced almost all the components needed to move an ad from the advertiser to an end user. Putting these together, we can now build the final ad server. The ad server exposes endpoints for our clients (web or mobile), runs the auctions, calls the ranking service, fetches user profiles, and consumes the indices. To explain this in more detail, let’s take a look at what happens when a new ad request comes in.

Request Anatomy

  1. When the client starts, it first talks to our ad server to let us know there’s a new active user online. The reason for doing this is that it gives us time to load the user profile into an in-memory database or BigTable for faster lookups later. In the meantime, the initialization response can also guide the client on further actions.

  2. Next, when the client realizes it needs to show an ad soon, it sends an ad request to our ad server asking where to load the ad content from. If the ad is just text, we can return the metadata in the ad response. If the ad is a static file, a CDN location is returned. The load balancer can route the request to any healthy node/pod because they all hold the latest ad index.

  3. Given the context information carried in the request and the live ad index, we first filter out the ads that have no targeting match. Then, other filters like the frequency cap or budget cap are applied before the ad candidates are sent into the auctions.

  4. When the auctions finish, the ad server records the winner as served (or records an empty result if there’s no winner) in the metrics table. In the meantime, the ad server also sends the winner information back to the client so that it can start preparing the ad.

  5. Once the ad gets displayed, the client will send a tracking event back to our metrics collector.

Unlike the other modules I discussed above, the ad server is on the very frontline of our ad business. A small problem within the ad server can mean an immediate failure of ad delivery, which is essentially a loss of revenue. Therefore, reliability is the most important attribute here. In the case of a non-core module failure, we need to fall back to a backup plan. For example, if the ranking system fails to respond, our ad server needs a default model to estimate the ad value. In the case of a total system failure, a quick rollback and re-deployment are critical. This is usually not an easy job for an ad server because of the considerable memory initialization. While we optimize the initialization code path, we can also adopt a canary deployment strategy or keep some warm backup servers.

To serve ad requests at lightning speed, the trick is to load the index into memory. However, there’s usually a limit on memory size. The reasons for this limit vary: it could come from the cloud provider or the physical machines we have, or from a cost-efficiency requirement on the business side. If the index grows bigger than the largest memory size we have on each machine, we need to split the index into groups. One option is to split the index by region: an ad request from Europe only reaches the European cluster, which only holds the index of European ads. This approach looks natural at first, but it also imposes a hard regional limitation on the entire advertising workflow. Another alternative is to have a dispatcher query multiple machines at the same time, where each machine in the same group holds a different shard of the index.
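The dispatcher alternative can be sketched as a parallel fan-out over shards followed by a merge. The shard lookup here is a stand-in for an RPC, and the data shapes are illustrative:

```python
from concurrent.futures import ThreadPoolExecutor

def query_shard(shard, request):
    """Stand-in for an RPC to one index shard; returns its local candidates."""
    return [ad for ad in shard if ad["country"] == request["country"]]

def dispatch(shards, request):
    """Fan the request out to every shard in parallel and merge the results."""
    with ThreadPoolExecutor(max_workers=len(shards)) as pool:
        results = pool.map(lambda s: query_shard(s, request), shards)
    return [ad for partial in results for ad in partial]

# Two shards, each holding a different slice of the ad index.
shards = [
    [{"ad_id": 1, "country": "us"}, {"ad_id": 2, "country": "de"}],
    [{"ad_id": 3, "country": "us"}],
]
print(dispatch(shards, {"country": "us"}))
# → [{'ad_id': 1, 'country': 'us'}, {'ad_id': 3, 'country': 'us'}]
```

The merge step is where the two sharding strategies differ: a regional split needs no merge at all, while a dispatcher must wait for the slowest shard, so tail latency becomes the main engineering concern.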


Even when we successfully deliver the ad to our end users, it’s not the end of our journey. The last big piece of this puzzle is the metrics and stats service. Usually, there are two types of workload on this stats service. On the one hand, business analysts need to pull large amounts of historical data to find patterns, compile campaign reports, and generate invoices and bills. On the other hand, advertisers are highly interested in the real-time performance of their currently running ads.

These two different goals also require different infrastructure to ensure the highest efficiency. For business analytics, we usually store all the raw data in an OLAP database or cloud analytics software like BigQuery, and the tooling and data engineering teams can build more pipelines starting from there. For real-time stats, we often store the data in a time series database. We could either use existing solutions like OpenTSDB, InfluxDB, or Graphite, or build our own TSDB query engine on top of a scalable database like BigTable. Although different solutions have different focuses, the main concerns here are the granularity and the retention period of the data. Small granularity and low latency are the keys for such real-time stats.

Nevertheless, storing metrics in different places doesn’t mean that we need to build two different data pipelines. A typical design has a generic metrics collector as a frontend to receive all metrics and logs from the client or other services. Then, a message queue system like Kafka or PubSub streams these events to a streaming data processing application, such as Spark or Dataflow. The Spark application transforms and aggregates the raw events into our desired format and then fans out to different data warehouses for persistence. In our case, we route the final data to both the TSDB and the OLAP DB.
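The transform-and-aggregate step can be sketched as rolling raw tracking events up into per-window, per-ad counts. The event shape and window size here are illustrative assumptions; in production this logic would run inside Spark or Dataflow rather than a single process:

```python
from collections import defaultdict

def aggregate(raw_events, window_seconds=60):
    """Roll raw tracking events up into per-window, per-ad counts.

    The aggregated rows would then fan out to both the TSDB (real-time
    stats) and the OLAP store (business analytics).
    """
    counts = defaultdict(int)
    for event in raw_events:
        # Floor the timestamp to the start of its aggregation window.
        window = event["ts"] - event["ts"] % window_seconds
        counts[(window, event["ad_id"], event["type"])] += 1
    return dict(counts)

events = [
    {"ts": 1005, "ad_id": "ad-1", "type": "impression"},
    {"ts": 1042, "ad_id": "ad-1", "type": "impression"},
    {"ts": 1042, "ad_id": "ad-1", "type": "click"},
    {"ts": 1075, "ad_id": "ad-1", "type": "impression"},
]
print(aggregate(events))
# → {(960, 'ad-1', 'impression'): 1, (1020, 'ad-1', 'impression'): 2, (1020, 'ad-1', 'click'): 1}
```

The window size directly trades storage against granularity, which is exactly the retention/granularity concern raised above for real-time stats.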

Bear in mind that in this kind of data pipeline, one bad design or discrepancy can lead to serious issues or limitations for its downstream consumers. Although no one can anticipate all future requirements during the initial design, we should still ensure enough flexibility that the pipeline can be easily extended later. For instance, some stats aggregation tasks may require data spanning several days or weeks, or data from different stages of the transform output.


Congratulations! You are now ready to start building your own advertising platform! Remember, the design I lay out here is just the start of the journey. In a real-world scenario, it’s hard to ignore all the legacy systems and build a fresh new platform from scratch. Often, we need to make tradeoffs and add some redundancy due to the existing tech stack, politics, or business requirements. Moreover, building everything above is a huge project that usually requires dozens or hundreds of engineers. It’s not only about the software itself; it also requires strong infrastructure and tooling to ensure efficient deployment, scaling, and resource planning.

That being said, you can still work on one part of the system first if you can’t hire dozens or hundreds of engineers right away. Some start-ups specialize in campaign management, which only needs an excellent web interface and APIs. Others focus on ads ranking and use machine learning to maximize ROI.

If you like my overview of the ads system, feel free to share it with your friends and leave your comments. The ad tech is continually evolving, and I hope we can come up with more creative ideas through the discussion.

Disclaimer: This article is only about my personal opinion and has no association with any real companies or products.


  1. How Do First-Price and Second-Price Auctions Work in Online Advertising?
  2. Wikipedia: PID Controller
  3. Wikipedia: Generalized second-price auction
  4. Meet Michelangelo: Uber’s Machine Learning Platform
  5. Time Series Database (TSDB) Explained
  6. Indexing Boolean Expressions
  7. Efficiently Evaluating Complex Boolean Expressions
  8. BE-tree: an index structure to efficiently match boolean expressions over high-dimensional discrete space
  9. An Overview of Large-Scale Advertising System Architecture (大型广告系统架构概述)