Introduction to Generative Networks

2022, Sep 25    

This post came about while I was writing another article, on classifying plant diseases using synthetic data.

Disclaimer: Complex math backs much of machine learning/deep learning. For this post I’m going to do my best to leave out the formulas and focus on implementation and logic.

What Is a ‘Generative Adversarial Network’?

Short and sweet: a Generative Adversarial Network (GAN) is a deep learning architecture whose output is synthetic data. This is the “generative” part of the name.

The word “adversarial” means that the architecture has opposing processes built into it, which compete with each other in a sort of rivalry.

The “network” piece simply means the algorithm is built from neural networks: a GAN is a particular style of neural network architecture.

It is common for the data generated by a GAN to be images (pixel values), but GANs can also be used to synthesize speech. In this post we focus on visual applications.

Applications of a GAN

You may have seen some recent articles about “AI” creating art that is on par with what a human artist can create. Due to the realism (or surrealism) of the generated images, this is one of the more publicly known applications of generative models such as GANs. You can play with image generation for free on craiyon.

While generating art is fun (and also tends to irritate the art community a bit…), GANs have also been used to create deepfakes. Unless you’ve been living under a rock lately, you probably know that deepfakes are generated faces that are commonly superimposed onto video.

As you might imagine, these deepfakes are also being used maliciously, leading to research into deepfake detection methods.

However, GANs (and other generative architectures) are also very useful within the machine learning process itself…

Cats vs Dogs: Your Hypothetical Invention

A common issue when preparing data for machine learning/deep learning is a lack of quality data. This becomes even more critical for neural networks: since a goal of deep learning is to automatically identify/extract important features, the network needs plenty of varied examples to learn those features from.

Let’s assume that you have a bunch of cats and dogs running rampant in your house. One day you decide to move the cats’ food/water/litter boxes to a different room, and put in a cat door so they can come and go freely. However, you also have a problem dog that just loves eating cat food (more than his/her own food, because why not).

As an aspiring ML developer, you decide to automate the cat door, using a camera to detect whether a cat or a dog is at the door. If it’s a cat, the door opens and lets the cat through. If it detects a dog, the door stays shut (saving you money on extra cat food).

You have the hardware in place and are ready to train your first neural network. You start by taking a picture of each dog/cat and using those images as your training data, with binary labels (0 or 1: cat or dog). After your first training iteration you notice the model’s accuracy is only 8%, which is even worse than flipping a coin. oof

You notice that your invention is letting more dogs through the door than cats. This isn’t right, so you decide to take more pictures of each cat/dog and re-train. Now the model performance is up to 15% accuracy.

Now you’ve taken hundreds of photos of each cat/dog and your accuracy is up to 60% - still only slightly better than a coin-flip guess. You’re exhausted from taking picture after picture, and chasing down each animal for picture-day retakes isn’t an easy task.

Augmenting Your Dataset Using GANs

Frustrated that your invention isn’t working as intended, you remember stumbling across this cool person’s blog post about something called Generative Adversarial Networks. You recall that they are a neural network architecture commonly used to generate new, synthetic images.

Cue the “a-ha!” moment.

Sure, the generated images might not look 100% realistic, but they might be close enough to fool your classification network into thinking they are new images of your furry friends.

You leave your current cat/dog classification architecture as-is, and insert a step in data preparation that first trains a GAN on the images of your cats/dogs. Once the GAN has been trained sufficiently, you feed its generator fresh random noise, and it produces new images that resemble your pets. (So that each new image still has a cat or dog label, you could train one GAN per class, or use a conditional GAN that is told which class to generate.)

You have now expanded the image dataset for each pet from hundreds to thousands. Using this new mixed dataset of real and synthetic images, you re-train your classification network and notice an average accuracy of 88%. Significant improvement! Even your dog high-fives you. nice
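If you’re wondering what that extra data-preparation step could look like in code, here is a minimal sketch. It assumes you already have one trained generator per class; the `generators` dictionary and the stand-in lambdas below are hypothetical placeholders, not a real library API, and the “images” are just short lists of numbers. Each generator is asked for new samples, which are mixed in with the real images under the same label.

```python
import random


def augment_dataset(real_by_label, generators, per_label, seed=0):
    """Mix real samples with synthetic ones from a per-class generator.

    real_by_label: {label: [image, ...]} of real, labeled images.
    generators: {label: callable(rng) -> image}; hypothetical stand-ins for
        trained GAN generators, one per class.
    per_label: how many synthetic images to add for each label.
    Returns a shuffled list of (image, label, is_synthetic) tuples.
    """
    rng = random.Random(seed)
    dataset = []
    for label, images in real_by_label.items():
        dataset += [(img, label, False) for img in images]       # real photos
        generate = generators[label]
        dataset += [(generate(rng), label, True) for _ in range(per_label)]
    rng.shuffle(dataset)
    return dataset


# Toy stand-ins: "images" are 4-pixel lists and the "generators" return noise.
real = {0: [[0.1] * 4 for _ in range(3)],   # 0 = cat
        1: [[0.9] * 4 for _ in range(3)]}   # 1 = dog
fake_gens = {label: (lambda rng: [rng.random() for _ in range(4)]) for label in real}

mixed = augment_dataset(real, fake_gens, per_label=10)
print(len(mixed))                                       # 26
print(sum(1 for _, _, synthetic in mixed if synthetic))  # 20
```

Keeping the `is_synthetic` flag is optional, but it makes it easy to measure the classifier’s accuracy on real photos only, so the synthetic images can’t flatter your results.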

Note: Don’t confuse GANs with simpler data augmentation methods, such as flipping, skewing, or cropping. A GAN generates brand-new data from patterns it learned from the training data, rather than simply modifying existing images.

So, How Do GANs Work?

If you stuck around after that lengthy example of your new billion-dollar invention, you might be curious about how exactly these GAN things work.

A GAN is actually made up of two separate, purpose-built neural networks, often referred to as the “Generator” and the “Discriminator”. Each of these is a functional neural network on its own, but they are built and trained in a way that makes them compete with each other.

Generator

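The Generator network’s job is to generate fake/synthetic images. It works by taking random noise as input, such as a randomly pixelated 10x10 image. (Many implementations actually start from a vector of random numbers that the first layer reshapes into a small grid, but picturing it as a tiny noise image works fine.) This noise serves no purpose other than giving the Generator a random starting point, which is also why different noise inputs can yield different generated images.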

The network is designed to repeatedly upsample this noise image until its dimensions match those of our dataset images. For example, if your images of cats and dogs are 500x500 pixels each, the Generator upsamples the random noise from 10x10 until it reaches 500x500. This also means the Generator has to produce values for many new pixels along the way: 100 starting values grow into 250,000 pixels (per color channel).
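To make the shape arithmetic concrete, here’s a tiny pure-Python sketch. It uses plain nearest-neighbour upsampling (repeating pixels), with factors 5, 5 and 2 that I picked so 10x10 lands exactly on 500x500. This only illustrates how the size grows: a real Generator would use learned layers (for example transposed convolutions) to decide the new pixel values instead of copying them.

```python
import random


def upsample_nearest(image, factor):
    """Nearest-neighbour upsampling: repeat every pixel `factor` times along both axes."""
    out = []
    for row in image:
        wide = [px for px in row for _ in range(factor)]  # stretch the row horizontally
        out.extend(list(wide) for _ in range(factor))     # repeat it vertically
    return out


rng = random.Random(42)
noise = [[rng.random() for _ in range(10)] for _ in range(10)]  # 10x10 random "noise image"

image = noise
for factor in (5, 5, 2):  # 10 -> 50 -> 250 -> 500
    image = upsample_nearest(image, factor)
    print(len(image), "x", len(image[0]))
# prints: 50 x 50, then 250 x 250, then 500 x 500
```

With pure copying, every output pixel is just one of the original 100 noise values; the learned layers are what turn that growing grid into something cat-shaped.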

Keep in mind that before any training, the output will still be an image of random pixelated noise, not anything resembling your cat/dog.

Discriminator

The goal of the Discriminator network is to determine whether an image is real (from your dataset) or fake (produced by the Generator).

The Discriminator works in roughly the reverse manner of the Generator. Rather than upsampling a small noise image, it takes a full-size image (either a real photo or the Generator’s output) as input and gradually breaks it down. Its final output is a single number between 0 and 1, which can be read as the probability that the image is real; thresholding it (e.g., at 0.5) gives the binary real/fake decision.
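Here’s the mirror image as a sketch: average pooling shrinks a 500x500 image back down to 10x10, and a single logistic (sigmoid) unit turns the remaining 100 values into one score between 0 and 1. The weights are random and untrained, and average pooling stands in for the learned convolution layers a real Discriminator would use, so the score itself means nothing yet; the point is the shape of the computation.

```python
import math
import random


def avg_pool(image, factor):
    """Shrink a square image by averaging each factor x factor block of pixels."""
    n = len(image) // factor
    return [
        [
            sum(image[r * factor + i][c * factor + j]
                for i in range(factor) for j in range(factor)) / factor ** 2
            for c in range(n)
        ]
        for r in range(n)
    ]


def discriminate(image, weights, bias):
    """Toy Discriminator: downsample 500x500 -> 10x10, then one logistic unit."""
    for factor in (2, 5, 5):                        # 500 -> 250 -> 50 -> 10
        image = avg_pool(image, factor)
    features = [px for row in image for px in row]  # 100 numbers
    logit = sum(w * x for w, x in zip(weights, features)) + bias
    return 1.0 / (1.0 + math.exp(-logit))           # score in (0, 1): "how real?"


rng = random.Random(0)
image = [[rng.random() for _ in range(500)] for _ in range(500)]  # stand-in 500x500 image
weights = [rng.uniform(-0.1, 0.1) for _ in range(100)]            # untrained, random
score = discriminate(image, weights, bias=0.0)
print("real" if score >= 0.5 else "fake")  # meaningless until the weights are trained
```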

Generator + Discriminator (Adversarial Network)

The overall “adversarial network” is nothing more than the Generator and Discriminator combined into the same training process. Together these networks become adversarial in nature, following a process loosely similar to the one below:

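  1. Sample 20 real cat/dog images
  2. Generator: turn 20 random noise inputs into 20 fake images, send them to the Discriminator
  3. Discriminator: predict real/fake on the 20 fake images
  4. Discriminator: train the Discriminator on the combined real and fake images with their labels (e.g., real = 1, fake = 0)
  5. Generator: use the Discriminator’s predictions from step 3 to train the Generator to produce less obviously fake images (it improves when the Discriminator mistakes its images for real ones)
  6. Repeat steps 1 - 5, sampling a new batch each time.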
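Those steps can be sketched end to end with a deliberately tiny GAN in plain Python. Instead of images, the “real data” here are single numbers drawn from a bell curve centred at 0. The Generator is one learnable offset b added to random noise (it starts out producing numbers around 5), and the Discriminator is a one-input logistic regression. These stand-ins, the learning rates, and the real = 1 / fake = 0 label convention are my assumptions for illustration; a real image GAN would replace both parts with deep networks, but the alternating loop is the same idea.

```python
import math
import random


def sigmoid(t):
    """Numerically stable logistic function."""
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def train_toy_gan(steps=4000, batch=20, lr_d=0.1, lr_g=0.01, seed=0):
    """Alternate Discriminator and Generator updates on 1-D toy data.

    Real data ~ Normal(0, 1). Generator: fake = b + noise, with b starting at 5.
    Discriminator: D(x) = sigmoid(w * x + c), its estimate that x is real.
    Returns the final offset b and the value of b after every step.
    """
    rng = random.Random(seed)
    b = 5.0
    w, c = 0.0, 0.0
    history = []
    for _ in range(steps):
        # 1. Sample a batch of real data.
        real = [rng.gauss(0.0, 1.0) for _ in range(batch)]
        # 2. Generator turns random noise into fake samples.
        fake = [b + rng.gauss(0.0, 1.0) for _ in range(batch)]
        # 3-4. Discriminator predicts on real + fake (labels: real = 1, fake = 0)
        #      and takes a gradient step on binary cross-entropy.
        grad_w = grad_c = 0.0
        for x, y in zip(real + fake, [1.0] * batch + [0.0] * batch):
            err = sigmoid(w * x + c) - y  # d(loss)/d(logit)
            grad_w += err * x
            grad_c += err
        w -= lr_d * grad_w / (2 * batch)
        c -= lr_d * grad_c / (2 * batch)
        # 5. Generator step: score fresh fakes with the updated Discriminator and
        #    move b so they look "real" (target label 1, non-saturating loss).
        grad_b = 0.0
        for _ in range(batch):
            x = b + rng.gauss(0.0, 1.0)
            grad_b += (sigmoid(w * x + c) - 1.0) * w  # chain rule, dx/db = 1
        b -= lr_g * grad_b / batch
        history.append(b)
    return b, history


final_b, history = train_toy_gan()
print(round(final_b, 2))  # compare with the real data's mean of 0 (b started at 5)
```

Notice that the Generator never looks at a real sample directly: the only thing pulling b toward the real data is the gradient it gets back through the Discriminator. That indirect feedback is the adversarial relationship from the steps above, in miniature.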

Now you have a network that (hopefully) improves each time the Discriminator predicts real/fake. The Discriminator gets better at telling fake images from real ones, and the Generator gets better at creating realistic images to fool the Discriminator.

Such a beautiful relationship.

(notice my art skills showing the noisy image looking more like a cat, to show progression…)

Looking Forward

This is a longer post than I initially intended. Hopefully you didn’t find this a complete waste of your time. Even if you aren’t an ML practitioner, now you have a better understanding of what is driving all these generated images/videos, and where they can be beneficial (or malicious).

This was also meant to be a precursor to my next post, which focuses on plant disease classification. In that post I use a custom-trained GAN to generate synthetic images. There you will find much more detail on how GANs work, including the code (TensorFlow) that makes it all work, if that’s your kind of thing.

If you have any questions, comments, or find any errors or bland jokes/sarcasm in this post, please reach out!