To view the full neural network model and training scripts, please visit my GitHub repo here:
Recently, the gender swap lens from Snapchat became very popular on the internet. Generative Adversarial Networks have been a buzzword since 2016, but this is the first time ordinary people got to experience the power of GANs. What's even more extraordinary about this lens is its great real-time performance, which makes it feel just like looking into a magic mirror. Although we can't know the exact algorithm behind this viral lens, it's most likely a CycleGAN, introduced in 2017 by Jun-Yan Zhu, Taesung Park, Phillip Isola, and Alexei A. Efros in this paper. In this article, I'm going to show you how to implement a gender swap effect with TensorFlow 2.0, just like Snapchat does.
First of all, I want to quickly go over the basics of Generative Adversarial Networks (GANs) for readers who are not familiar with them. In some scenarios, we want to generate an image that belongs to a particular domain. For example, we'd like to draw a random interior design photo. To ask the computer to generate such an image, we need a mathematical representation of the interior design domain space. Assume there's a function F and a random input number x. We want y = F(x) to always be very close to our target domain Y. However, this target domain lives in such a high-dimensional space that no human being can figure out explicit rules to define it. A GAN is a network that, by playing a minimax game between two AI agents, can eventually find an approximate representation F of our target domain Y.
So how does a GAN accomplish this? The trick is to break the problem down into two parts: 1. We need a generator to keep making new images out of some random number. 2. We need a discriminator to give the generator feedback about how good the generated image is. The generator here is like a young artist who has no idea how to paint but wants to forge a masterpiece, and the discriminator is a judge who can tell what's wrong with the new painting. The judge doesn't need to know how to paint himself. However, as long as he's good at telling the difference between a good painting and a bad one, our young painter can surely benefit from his feedback. So we use Deep Learning to build a good judge and use it to train a good painter in the meantime.
To train a good judge, we feed both authentic and generated images to our discriminator. Since we know beforehand which is authentic and which is fake, the discriminator can update its weights by comparing its decision against the truth. The generator, in turn, looks at the decision from the discriminator. If the discriminator seems more agreeable with the fake image, it indicates that the generator is heading in the right direction, and vice versa. The tricky part is that we can't just train a judge without a painter, or a painter without a judge. They learn from each other and try to beat each other by playing this minimax game. If we train the discriminator too much without training the generator, the discriminator will become too dominant, and the generator won't ever have a chance to improve because every move is a losing move.
Eventually, both the generator and the discriminator will be really good at their jobs, and by then we can take the generator out to perform the task independently. But what on earth does this GAN have to do with gender swap? In fact, gender swap is a more straightforward problem than the one I just mentioned above. Instead of generating a random image of a target domain, we can skip the random number step and use the given image as input. Let's say we want to convert a male face to a female face. We are looking for a function F that, given a male face x, outputs a value y that's very close to the real female version y_0 of that face.
The GAN approach sounds clear, but we haven't discussed one crucial caveat. To train the discriminator, we need both real and fake images. We get the fake image y from the generator, but where do we get the real image for that specific male face? We can't just use a random female face, because we want to preserve some common traits when we swap the gender, and a face from a different person would ruin that. But it's also really hard to get paired training data. You could go out to find boy-girl twins and take pictures of them, or ask a professional makeup artist to 'turn' a man into a woman. Both are very expensive. So is there a way for the model to learn the most important facial differences between men and women from an unpaired dataset?
Fortunately, scientists have discovered ways to utilize unpaired training data. One of the most famous models is called CycleGAN. The main idea behind CycleGAN is that, instead of using paired data to train the discriminator, we can form a cycle where two generators work together to convert an image back and forth. More specifically, generator A2B first converts an image from domain A into domain B, and then generator B2A takes that output as input and generates another image back in domain A. We then set a goal to make sure the second image (the reconstructed image) looks as close as possible to the original input. For example, if generator A2B first converts a horse to a zebra, and generator B2A then converts that zebra back to a horse, the newly generated horse should look identical to the very original horse. This way, the generators learn not to make trivial changes, but only the critical differences between the two domains; otherwise, they probably won't be able to convert the image back. With this goal set up, we can now use unpaired images as training data.
In reality, we need two cycles here. Since we are training generator A2B and generator B2A together, we have to make sure both generators are improving over time; otherwise, we will still have trouble reconstructing a good image. Moreover, as we discussed above, improving a generator means we need to improve the discriminator in the meantime. In the cycle A2B2A (A -> B -> A), we use discriminator A to decide whether the reconstructed image is in domain A, so discriminator A gets trained. Likewise, we also need a cycle B2A2B so that discriminator B can be trained as well. If both discriminator A and discriminator B are well trained, our generators A2B and B2A can improve too!
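The two cycles can be sketched with toy stand-in functions. To be clear, the lambdas below are purely illustrative placeholders, not real networks: the point is only that the reconstruction at the end of each cycle should match the image that entered it.

```python
# Toy sketch of the two cycle-consistency checks. The "generators" here are
# invertible stand-in functions, not neural networks.
gen_a2b = lambda x: x + 10.0   # pretend mapping: domain A -> domain B
gen_b2a = lambda x: x - 10.0   # pretend mapping: domain B -> domain A

def cycle_loss(reconstructed, real):
    # Scalar stand-in for the L1 (MAE) cycle loss
    return abs(reconstructed - real)

real_a, real_b = 1.0, 11.0
loss_a2b2a = cycle_loss(gen_b2a(gen_a2b(real_a)), real_a)  # cycle A -> B -> A
loss_b2a2b = cycle_loss(gen_a2b(gen_b2a(real_b)), real_b)  # cycle B -> A -> B
print(loss_a2b2a + loss_b2a2b)  # 0.0: perfectly inverse generators give zero cycle loss
```

With real networks the two mappings are only approximately inverse, so the cycle losses stay positive and push both generators to preserve everything except the domain-specific differences.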
There's another great article here about CycleGAN for further reading. Now that you have the main idea of this network, let's dive into some details.
It's recommended here that Adam is the best optimizer for GAN training. Although I don't know the reason behind it, the linear learning rate decay from the original paper looks quite effective during training. The rate stays at 0.0002 for the first 100 epochs, then linearly decays to 0 over the next 100 epochs. Here, total_batches is the number of mini-batches per epoch, because our learning rate scheduler counts each mini-batch as a step.
gen_lr_scheduler = LinearDecay(LEARNING_RATE, EPOCHS * total_batches, DECAY_EPOCHS * total_batches)
dis_lr_scheduler = LinearDecay(LEARNING_RATE, EPOCHS * total_batches, DECAY_EPOCHS * total_batches)
optimizer_gen = tf.keras.optimizers.Adam(gen_lr_scheduler, BETA_1)
optimizer_dis = tf.keras.optimizers.Adam(dis_lr_scheduler, BETA_1)
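LinearDecay is not a built-in Keras schedule. Here is a minimal plain-Python sketch of what it does, matching the LinearDecay(initial_lr, total_steps, decay_start_step) call above; the real version would subclass tf.keras.optimizers.schedules.LearningRateSchedule so Adam can query it per step, and its exact signature may differ.

```python
class LinearDecay:
    """Constant learning rate, then linear decay to zero (sketch)."""

    def __init__(self, initial_lr, total_steps, decay_start_step):
        self.initial_lr = initial_lr
        self.total_steps = total_steps
        self.decay_start_step = decay_start_step

    def __call__(self, step):
        if step < self.decay_start_step:
            return self.initial_lr
        # Interpolate linearly from initial_lr down to 0 over the remaining steps
        remaining = self.total_steps - step
        decay_span = self.total_steps - self.decay_start_step
        return self.initial_lr * remaining / decay_span

schedule = LinearDecay(2e-4, 200, 100)  # e.g. 200 "steps" with decay starting at 100
print(schedule(50))    # 0.0002 (constant phase)
print(schedule(150))   # 0.0001 (halfway through the decay)
print(schedule(200))   # 0.0 (fully decayed)
```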
CycleGAN uses a fairly standard generator structure. It first encodes the input image into a feature matrix by applying 2D convolutions, which extract valuable feature information at both local and global scales.
Then, six or nine ResNet blocks are used to transform the features from the encoder into features in the target domain. As we know, the skip connection in a ResNet block helps gradients flow back to earlier layers, which makes sure the deeper layers can still learn something. If you are not familiar with ResNet, please refer to this paper.
Finally, a few transposed convolution (deconvolution) layers are used as a decoder. The decoder upsamples the features in the target domain back into an actual image in the target domain.
Unlike the ideas from the VGG and Inception networks, it's recommended for a GAN to use a larger convolution kernel size like 7×7 so that it can pick up broader information instead of just focusing on details. This makes sense because when we reconstruct an image, not only the details matter but also the overall pattern. Also, reflection padding is used here to improve the quality around the image border.
def make_generator_model(n_blocks):
    # 6 residual blocks
    # c7s1-64,d128,d256,R256,R256,R256,R256,R256,R256,u128,u64,c7s1-3
    # 9 residual blocks
    # c7s1-64,d128,d256,R256,R256,R256,R256,R256,R256,R256,R256,R256,u128,u64,c7s1-3
    model = tf.keras.Sequential()

    # Encoding
    model.add(ReflectionPad2d(3, input_shape=(256, 256, 3)))
    model.add(tf.keras.layers.Conv2D(64, (7, 7), strides=(1, 1), padding='valid', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.ReLU())
    model.add(tf.keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.ReLU())
    model.add(tf.keras.layers.Conv2D(256, (3, 3), strides=(2, 2), padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.ReLU())

    # Transformation
    for i in range(n_blocks):
        model.add(ResNetBlock(256))

    # Decoding
    model.add(tf.keras.layers.Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.ReLU())
    model.add(tf.keras.layers.Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.ReLU())
    model.add(ReflectionPad2d(3))
    model.add(tf.keras.layers.Conv2D(3, (7, 7), strides=(1, 1), padding='valid', activation='tanh'))

    return model
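The generator references two custom layers, ReflectionPad2d and ResNetBlock, which aren't shown above. Here is one possible implementation sketch; the versions in my repo may differ in detail (for instance, in the choice of normalization layer).

```python
import tensorflow as tf

class ReflectionPad2d(tf.keras.layers.Layer):
    """Pads height and width with reflected pixel values instead of zeros (sketch)."""

    def __init__(self, padding, **kwargs):
        super().__init__(**kwargs)
        self.padding = padding

    def call(self, x):
        p = self.padding
        # Pad only the spatial dimensions of the NHWC tensor
        return tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], mode='REFLECT')

class ResNetBlock(tf.keras.layers.Layer):
    """Two 3x3 convs with reflection padding, plus a skip connection (sketch)."""

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.body = tf.keras.Sequential([
            ReflectionPad2d(1),
            tf.keras.layers.Conv2D(filters, (3, 3), padding='valid', use_bias=False),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
            ReflectionPad2d(1),
            tf.keras.layers.Conv2D(filters, (3, 3), padding='valid', use_bias=False),
            tf.keras.layers.BatchNormalization(),
        ])

    def call(self, x):
        # The skip connection keeps gradients flowing through deep stacks
        return x + self.body(x)
```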
There are three types of loss we care about here:
- To calculate the GAN loss, we measure the L2 distance (MSE) between the discriminator's prediction and the target label (all ones for a real image, all zeros for a fake one).
- To calculate the cyclic loss, we measure the L1 distance (MAE) between the reconstructed image from the cycle and the original real image.
- To calculate the identity loss, we measure the L1 distance (MAE) between the identity-mapped image and the original real image.
The GAN loss is the typical loss we use in GANs, and I won't discuss it much here. The interesting parts are the cyclic loss and the identity loss. The cyclic loss measures how good the reconstructed image is, which helps both generators capture the essential style difference between the two domains. The identity loss is optional, but it helps keep the generator from making unnecessary changes. The way it works is that applying generator A2B to a real B image shouldn't change anything, since the image is already the desired outcome. According to the authors, this mitigates some weird issues like background color changes.
mse_loss = tf.keras.losses.MeanSquaredError()
mae_loss = tf.keras.losses.MeanAbsoluteError()

def calc_gan_loss(prediction, is_real):
    # Typical GAN loss to set objectives for generator and discriminator
    if is_real:
        return mse_loss(prediction, tf.ones_like(prediction))
    else:
        return mse_loss(prediction, tf.zeros_like(prediction))

def calc_cycle_loss(reconstructed_images, real_images):
    # Cycle loss to make sure reconstructed image looks real
    return mae_loss(reconstructed_images, real_images)

def calc_identity_loss(identity_images, real_images):
    # Identity loss to make sure generator won't do unnecessary change
    # Ideally, feeding a real image to generator should generate itself
    return mae_loss(identity_images, real_images)
To combine all the losses, we also need to assign a weight to each loss to indicate its importance. In the paper, the authors propose two lambda parameters, which weight the cycle loss by 10 and the identity loss by 5. Note the usage of GradientTape here: we record the gradients and apply gradient descent for both generators together. Here, real_a is the real image from domain A, real_b is the real image from domain B, fake_a2b is the image generated from domain A to domain B, and fake_b2a is the image generated from domain B to domain A.
@tf.function
def train_generator(images_a, images_b):
    real_a = images_a
    real_b = images_b

    with tf.GradientTape() as tape:
        # Feeding real B to generator A2B (and real A to B2A) should change nothing
        identity_a2b = generator_a2b(real_b, training=True)
        identity_b2a = generator_b2a(real_a, training=True)
        loss_identity_a2b = calc_identity_loss(identity_a2b, real_b)
        loss_identity_b2a = calc_identity_loss(identity_b2a, real_a)

        # Generate fake images, then reconstruct them through the opposite generator
        fake_a2b = generator_a2b(real_a, training=True)
        fake_b2a = generator_b2a(real_b, training=True)
        recon_b2a = generator_b2a(fake_a2b, training=True)
        recon_a2b = generator_a2b(fake_b2a, training=True)

        # Generator A2B tries to trick Discriminator B that the generated image is B
        loss_gan_gen_a2b = calc_gan_loss(discriminator_b(fake_a2b, training=True), True)
        # Generator B2A tries to trick Discriminator A that the generated image is A
        loss_gan_gen_b2a = calc_gan_loss(discriminator_a(fake_b2a, training=True), True)
        loss_cycle_a2b2a = calc_cycle_loss(recon_b2a, real_a)
        loss_cycle_b2a2b = calc_cycle_loss(recon_a2b, real_b)

        # Total generator loss
        loss_gen_total = loss_gan_gen_a2b + loss_gan_gen_b2a \
            + (loss_cycle_a2b2a + loss_cycle_b2a2b) * 10 \
            + (loss_identity_a2b + loss_identity_b2a) * 5

    trainable_variables = generator_a2b.trainable_variables + generator_b2a.trainable_variables
    gradient_gen = tape.gradient(loss_gen_total, trainable_variables)
    optimizer_gen.apply_gradients(zip(gradient_gen, trainable_variables))

    return fake_a2b, fake_b2a, {'loss_gen_total': loss_gen_total}
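The weighted sum inside the tape can be isolated into a small helper as a sanity check. The function name and the sample loss values below are illustrative only:

```python
def total_generator_loss(gan_a2b, gan_b2a, cyc_a2b2a, cyc_b2a2b,
                         id_a2b, id_b2a,
                         lambda_cycle=10.0, lambda_identity=5.0):
    # GAN terms unweighted, cycle terms x10, identity terms x5, per the paper
    return (gan_a2b + gan_b2a
            + lambda_cycle * (cyc_a2b2a + cyc_b2a2b)
            + lambda_identity * (id_a2b + id_b2a))

# Example with made-up loss values:
print(total_generator_loss(1.0, 1.0, 0.25, 0.25, 0.5, 0.5))  # 2 + 10*0.5 + 5*1.0 = 12.0
```

The large lambda on the cycle term reflects how central cycle consistency is: it is the only signal that ties the two generators together on unpaired data.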
Similar to other GANs, the discriminator consists of several 2D convolution layers that extract features from the generated image. However, to help the generator produce a high-resolution image, CycleGAN uses a technique called PatchGAN to create a fine-grained decision matrix instead of a single decision value. Each value in this 32×32 decision matrix maps to a patch of the generated image and indicates how real that patch is.
In fact, we don't actually crop patches of the input image in the implementation. We just use a final convolution layer to do the job for us: each unit of its output has a receptive field covering one patch of the input, so the convolution effectively behaves like cropping.
def make_discriminator_model():
    # C64-C128-C256-C512
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Conv2D(64, (4, 4), strides=(2, 2), padding='same', input_shape=(256, 256, 3)))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    model.add(tf.keras.layers.Conv2D(128, (4, 4), strides=(2, 2), padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    model.add(tf.keras.layers.Conv2D(256, (4, 4), strides=(2, 2), padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    model.add(tf.keras.layers.Conv2D(512, (4, 4), strides=(1, 1), padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # This last conv layer is the PatchGAN
    # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/39#issuecomment-305575964
    # https://github.com/phillipi/pix2pix/blob/master/scripts/receptive_field_sizes.m
    model.add(tf.keras.layers.Conv2D(1, (4, 4), strides=(1, 1), padding='same'))
    return model
The loss functions for the discriminators are much more straightforward. Just like in typical GANs, we tell the discriminator to treat real images as real and generated images as fake. So we have two losses for each discriminator, loss_real and loss_fake, both with an equal effect on the final loss. In calc_gan_loss, we are comparing two matrices. Usually, the output of a discriminator is just one value between 0 and 1. However, as mentioned above, we use the PatchGAN technique, so the discriminator produces one decision per patch, which forms a 32×32 decision matrix.
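A quick back-of-the-envelope calculation confirms the 32×32 output size and also reveals how big a patch each decision covers. The helpers below are illustrative; the receptive_field_sizes.m script linked in the code comments does the same math.

```python
# (kernel, stride) for each conv in the discriminator: C64-C128-C256-C512 + PatchGAN head
layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]

def receptive_field(layers):
    # Walk backwards through the stack: rf_in = rf_out * stride + (kernel - stride)
    rf = 1
    for kernel, stride in reversed(layers):
        rf = rf * stride + (kernel - stride)
    return rf

def output_size(size, layers):
    # With 'same' padding each layer maps size -> ceil(size / stride)
    for _, stride in layers:
        size = -(-size // stride)  # integer ceil division
    return size

print(receptive_field(layers))   # 70 -> each decision judges a 70x70 input patch
print(output_size(256, layers))  # 32 -> the 32x32 decision matrix
```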
@tf.function
def train_discriminator(images_a, images_b, fake_a2b, fake_b2a):
    real_a = images_a
    real_b = images_b

    with tf.GradientTape() as tape:
        # Discriminator A should classify real_a as A
        loss_gan_dis_a_real = calc_gan_loss(discriminator_a(real_a, training=True), True)
        # Discriminator A should classify generated fake_b2a as not A
        loss_gan_dis_a_fake = calc_gan_loss(discriminator_a(fake_b2a, training=True), False)
        # Discriminator B should classify real_b as B
        loss_gan_dis_b_real = calc_gan_loss(discriminator_b(real_b, training=True), True)
        # Discriminator B should classify generated fake_a2b as not B
        loss_gan_dis_b_fake = calc_gan_loss(discriminator_b(fake_a2b, training=True), False)

        # Total discriminator loss
        loss_dis_a = (loss_gan_dis_a_real + loss_gan_dis_a_fake) * 0.5
        loss_dis_b = (loss_gan_dis_b_real + loss_gan_dis_b_fake) * 0.5
        loss_dis_total = loss_dis_a + loss_dis_b

    trainable_variables = discriminator_a.trainable_variables + discriminator_b.trainable_variables
    gradient_dis = tape.gradient(loss_dis_total, trainable_variables)
    optimizer_dis.apply_gradients(zip(gradient_dis, trainable_variables))

    return {'loss_dis_total': loss_dis_total}
Now that we have defined both the models and the loss functions, we can put them together and start training. By default, eager mode is enabled in TensorFlow 2.0, so we don't have to build a graph ourselves. However, if you look carefully, you might notice that both the discriminator and generator training functions are decorated with tf.function. This is the new mechanism introduced by TensorFlow 2.0 to replace the old tf.Session(). With this decorator, all operations inside the function are converted into a graph, so the performance can be much better than in the default eager mode. To learn more about tf.function, please refer to this article.
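As a tiny standalone illustration (not part of the training code), decorating a Python function with tf.function traces it into a graph on the first call and reuses the compiled graph on later calls with the same input signature:

```python
import tensorflow as tf

@tf.function
def mse(y_true, y_pred):
    # Traced into a TensorFlow graph on the first call with these shapes/dtypes
    return tf.reduce_mean(tf.square(y_true - y_pred))

a = tf.constant([1.0, 2.0])
b = tf.constant([1.0, 4.0])
print(mse(a, b).numpy())  # 2.0
```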
One thing worth mentioning is that, instead of feeding the generated image to the discriminator directly, we actually use an image pool. Each time, the image pool randomly decides whether to give the discriminator the newly generated image or a generated image from past steps. The benefit is that the discriminator can learn from older cases too, giving it a sort of memory of the tricks the generator has used. Unfortunately, we can't use this random image pool in graph mode at the moment, so we need to move the images back to the CPU when selecting a random image from the pool, which does introduce some cost.
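The image pool itself isn't shown in the snippets. A common implementation, following the history-buffer strategy from the original CycleGAN code, looks roughly like this (a sketch; the version in my repo may differ):

```python
import random

class ImagePool:
    """Buffer of previously generated images that the discriminator can revisit."""

    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, image):
        if self.pool_size == 0:
            return image  # pool disabled: always use the fresh image
        if len(self.images) < self.pool_size:
            self.images.append(image)  # still filling the pool
            return image
        if random.random() < 0.5:
            # Swap: hand back a stored old image, keep the new one for later
            idx = random.randrange(self.pool_size)
            old = self.images[idx]
            self.images[idx] = image
            return old
        return image
```

So on each call the discriminator sees either the freshly generated image or, with 50% probability once the pool is full, a random image from earlier steps.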
The model illustrated in this article is trained on my own GTX 1080 home computer, so it's a bit slow. On an instance with a 16 GB V100 GPU and 64 GB of RAM, though, you should be able to set the mini-batch size to 4, and the trainer can process one epoch of 260 mini-batches in 3 minutes on the horse2zebra dataset. So it takes about 10 hours to fully train a horse2zebra model. If you reduce the image resolution and shrink some network parameters correspondingly, training could be faster. The final generators are about 44 MB each.
def train_step(images_a, images_b, epoch, step):
    fake_a2b, fake_b2a, gen_loss_dict = train_generator(images_a, images_b)

    # Draw the discriminator inputs from the image pools instead of
    # always using the freshly generated images
    fake_b2a_from_pool = fake_pool_b2a.query(fake_b2a)
    fake_a2b_from_pool = fake_pool_a2b.query(fake_a2b)

    dis_loss_dict = train_discriminator(images_a, images_b, fake_a2b_from_pool, fake_b2a_from_pool)

def train(dataset, epochs):
    # dataset is assumed to yield (images_a, images_b) pairs, e.g. a zipped tf.data.Dataset
    for epoch in range(checkpoint.epoch + 1, epochs + 1):
        for step, (images_a, images_b) in enumerate(dataset):
            train_step(images_a, images_b, epoch, step)
To see the full training script, please visit my repo here.
Let's look at some inference results on a few datasets. Among those, horse2zebra and monet2photo are the original datasets from the paper, and the CelebA dataset is from here.
Horse -> Zebra
Zebra -> Horse
Monet -> Photo
Photo -> Monet
Male -> Female
Female -> Male
We successfully mapped a male face to a female face, but to use it in a production environment, we need a pipeline to orchestrate lots of other steps. Snapchat may have its own optimizations or models, but here's the procedure that I think would help improve the final result of our CycleGAN:
- Run face detection to find a bounding box and keypoints for the most dominant face in the picture.
- Extend the bounding box a little bit bigger to match the training dataset distribution.
- Crop the picture with this extended bounding box and run CycleGAN over it.
- Patch the generated image back into the original picture.
- Overlay some hair, eyeliner, or a beard on top of the new face based on the keypoints we got from the detection step.
I got my inspiration mostly from this great article, which explains how this pipeline works in detail.
Lastly, I want to throw out some questions I have. I don't know the answers, but I hope those who have experience building similar products can share their opinions in the comments below.
- The CycleGAN model turns out to be 44 MB; with quantization it could shrink to 12 MB, but that's still too large. What are effective methods to make such models usable on mobile and embedded devices?
- The output image resolution isn't great and loses much of the sharpness. How can we generate a bigger image, such as 1024×1024, without blowing up the model size? Would a super-resolution model help in this case?
- How do we know if a model is thoroughly trained and has converged? The loss isn't a good metric here, but we also don't know what the "best" output looks like. How can we measure the similarity between two styles?