We have all come across articles and posts about AI producing human-like speech or generating images of non-existent people that are difficult to distinguish from real ones. These AI systems are built upon generative adversarial networks (GANs), which Facebook AI research director Yann LeCun called “the most interesting idea in the last 10 years in ML”. GANs were introduced in 2014 in a paper by Ian Goodfellow and other researchers at the University of Montreal, including Yoshua Bengio.

GANs’ potential for both good and evil is huge, because they can learn to mimic any distribution of data. That is, GANs can be taught to create worlds eerily similar to our own in any domain: images, music, speech, prose. They are robot artists in a sense, and their output is impressive - poignant even. But they can also be used to generate fake media content and are the technology underpinning Deepfakes.

To understand GANs, we first need to understand generative algorithms and their counterpart, discriminative algorithms.

Background: Generative Models v/s Discriminative Models

Discriminative Models

Discriminative models classify the input data; i.e. given the features of an instance of data, they predict the label or category to which the data belongs. For example, given all the words in an email (the data instance), a discriminative model would predict whether the email is spam or not_spam. Here spam is one of the labels, and the bag of words gathered from the email is the set of features that constitutes the input data. When this problem is expressed mathematically, the label is called $Y$ and the features are called $X$. The formulation $P(Y|X)$ means “the probability of $Y$ given $X$”, which in this case translates to “the probability that an email is spam given the words it contains”.

So discriminative models map features to labels. They are concerned solely with that correlation, and they directly learn the conditional distribution $P(Y|X)$.

Generative Models

Generative models, on the other hand, learn the joint probability distribution $P(X,Y)$ by explicitly learning $P(X|Y)$ and $P(Y)$, and then use it to compute $P(Y|X)$ through Bayes’ rule.

Given a model of the joint distribution $P(X,Y)$, the marginal distributions of $X$ and $Y$ can be calculated as:

$$
P(X) = \sum_{y} P(X, Y = y) \quad \text{and} \quad P(Y) = \int_{x} P(Y, X = x)\,dx
$$

Here $X$ is treated as continuous, hence the integral over it, and $Y$ as discrete, hence the sum over it.

Either conditional distribution can be computed as:

$$
P(Y|X) = \frac{P(X,Y)}{P(X)}
$$

That is how $P(Y|X)$ can be computed in generative models.
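The computation above can be illustrated with a minimal sketch. The joint table below is made up for illustration (a single binary feature, `has_offer` or `no_offer`, and the two labels from the email example), so the numbers are assumptions rather than measurements. Because both variables are discrete here, the marginals are plain sums.

```python
# Hypothetical joint distribution P(X, Y); the four probabilities sum to 1.
joint = {
    ("has_offer", "spam"): 0.30,
    ("has_offer", "not_spam"): 0.05,
    ("no_offer", "spam"): 0.10,
    ("no_offer", "not_spam"): 0.55,
}

def marginal_x(joint, x):
    """P(X = x), obtained by summing the joint over all labels y."""
    return sum(p for (xi, _), p in joint.items() if xi == x)

def marginal_y(joint, y):
    """P(Y = y), obtained by summing the joint over all feature values x."""
    return sum(p for (_, yi), p in joint.items() if yi == y)

def posterior(joint, y, x):
    """P(Y = y | X = x) = P(X = x, Y = y) / P(X = x)."""
    return joint[(x, y)] / marginal_x(joint, x)

# The same posterior via Bayes' rule: P(X|Y) * P(Y) / P(X).
likelihood = joint[("has_offer", "spam")] / marginal_y(joint, "spam")
via_bayes = likelihood * marginal_y(joint, "spam") / marginal_x(joint, "has_offer")

print(posterior(joint, "spam", "has_offer"))  # about 0.857
print(via_bayes)                              # same value, up to rounding
```

Both routes give the same number because Bayes’ rule is just the joint distribution divided by the marginal of the features.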

One way to think about generative models from the GAN perspective is that they do the opposite of discriminative models. Instead of predicting labels given the features, generative models attempt to predict features given a certain label, i.e. $P(X|Y)$. In the email example, a generative model tries to answer this: assuming the given email is spam, how likely are these features? While discriminative models care about the relation between $Y$ and $X$, generative models care about “how you get $X$”.

Formal Definitions:

  • A discriminative model learns a function that maps the input data $X$ to some desired output class label $Y$. In probabilistic terms, it directly learns the conditional distribution $P(Y|X)$.
  • A generative model tries to learn the joint probability of the input data and labels simultaneously, i.e. $P(X,Y)$. This is converted to $P(Y|X)$ for classification via Bayes’ rule, but the generative ability can be used for something else as well, such as creating new $(X,Y)$ samples.

The word "generative" in generative models does not mean that the model generates actual new data in addition
to the dataset. It refers to the nature of the theoretical model, in the sense that the generative approach
assumes that any sample of data is generated from some distribution, and it tries to estimate this
distribution. Once the distribution is estimated, the model can be used to actually generate instances following
this distribution.

Both kinds of model are used in supervised learning, where one wants to learn a rule that maps input $X$ to output $Y$. Some argue that discriminative models are better because they directly model the quantity we care about, i.e. $Y$, and hence spend no effort on modelling the input $X$. However, generative models have their own advantages, such as the capability of dealing with missing data. Since generative models are concerned with $P(X,Y)$ and $P(X)$ at the same time in order to predict $P(Y|X)$, they have fewer degrees of freedom than discriminative models. So generative models tend to be more robust and less prone to overfitting than discriminative models.

Generative Models are hard

As explained earlier, generative models are better than discriminative models in certain respects, but they are also harder to train. Generative models tackle a more difficult task than analogous discriminative models: they have to model more.

A generative model for images might capture correlations like “things that look like boats are probably going to appear near things that look like water” and “eyes are unlikely to appear on foreheads”. These are very complicated distributions.

In contrast, a discriminative model might learn the difference between “sailboat” or “not sailboat” by just looking for a few tell-tale patterns. It could ignore many of the correlations that the generative models must get right.

Discriminative models try to draw boundaries in the data space, while generative models try to model how data is placed throughout the space. For example, the following diagrams show discriminative and generative models of handwritten digits:

The discriminative model tries to tell the difference between handwritten 0’s and 1’s by drawing a line in the data space. If it gets the line right, it can distinguish 0’s from 1’s without ever having to model exactly where the instances are placed in the data space on either side of the line.

In contrast, the generative model tries to produce convincing 1’s and 0’s by generating digits that fall close to their real counterparts in the data space. It has to model the distribution throughout the data space. GANs offer an effective way to train such rich models to resemble a real distribution.

Generative Adversarial Networks

Let’s dive into GANs and how they work.

The main idea behind GANs is to have two competing neural network models. One takes noise as input and generates samples (hence it is called the generator), whereas the other model (called the discriminator) receives samples from both the generator and the training data, and has to distinguish between the two sources. These two networks play a continuous game in which the generator learns to produce more and more realistic samples, and the discriminator learns to get better and better at distinguishing generated data from real data.

Let’s take an example. Say we are going to generate handwritten numerals like those found in the MNIST dataset, which is taken from the real world. The goal of the discriminator, when shown an instance from the true MNIST dataset, is to recognize it as authentic.

Meanwhile, the generator creates new, synthetic images that look like MNIST digits and passes them to the discriminator. It does so in the hope that they will be deemed authentic by the discriminator even though they are fake. The goal of the generator is to generate handwritten digits that are not caught by the discriminator as fake. The goal of the discriminator is to identify images coming from the generator as fake.

Here are the steps a GAN takes:

  • The generator takes in random numbers and returns an image.
  • This generated image is fed into the discriminator alongside a stream of images taken from the actual, ground-truth dataset.
  • The discriminator takes in both real and fake images and returns probabilities for whether each image is real or fake.

The discriminator takes real and fake images one after another (not simultaneously) and calculates the probability of each being real or fake.

So there is a double feedback loop (or two ways through which weights are updated):

  • The discriminator is in a feedback loop with the ground truth of the images, i.e. the discriminator’s weights are updated based on how well it detects real images as real.
  • The generator is in a feedback loop with the discriminator, i.e. the generator’s weights are updated based on how well the discriminator detects generated images as fake.

The analogy often used to understand GANs is that the generator is like a forger trying to produce counterfeit material, and the discriminator is like the police trying to detect the forged items. This setup may also seem somewhat reminiscent of reinforcement learning, where the generator receives a reward signal from the discriminator letting it know whether the generated data is accurate or not. The key difference with GANs, however, is that we can backpropagate gradients from the discriminator network back to the generator network, so the generator knows how to adapt its parameters in order to produce output data that can fool the discriminator. Let’s see how the loss is propagated from the discriminator network to the generator network.

Discriminator

The discriminator in a GAN is simply a classifier. It tries to distinguish real data from the data generated by the generator. It could use any network architecture appropriate to the type of data it’s classifying. In the MNIST example above, the discriminator network could be a standard convolutional network that categorizes the images fed to it: a binary classifier labelling images as real or fake.

As explained earlier, discriminator’s training data comes from two sources:

  • Real data instances taken from the real world. The discriminator uses these instances as positive examples during training.
  • Fake data instances created by the generator. The discriminator uses these instances as negative examples during training.

The discriminator connects to two loss functions: the discriminator loss and the generator loss. During discriminator training, the discriminator ignores the generator loss and uses only the discriminator loss. The procedure is:

  1. The discriminator classifies both real data and fake data (from the generator).
  2. The discriminator loss penalizes the discriminator for misclassifying a real instance as fake or a fake instance as real.
  3. The discriminator updates its weights through backpropagation from the discriminator loss through the discriminator network.

Generator

The generator part of GAN learns to create fake data by incorporating feedback from the discriminator. It learns to make the discriminator classify its output as real.

In its most basic form, the generator of a GAN takes random noise as its input and transforms this noise into a meaningful output. By introducing noise, we can get the GAN to produce a wide variety of data, sampling from different places in the target distribution. Experiments suggest that the distribution of the noise doesn’t matter much, so we can choose something that’s easy to sample from, like a normal distribution. For convenience, the space from which the noise is sampled usually has a smaller dimension than the output space.

A neural net is trained by altering its weights to reduce the error. In GANs, however, the generator is not directly connected to the loss that we’re trying to affect. The generator feeds into the discriminator net, and the discriminator produces the output from which the loss is computed. Backpropagation adjusts each weight in the right direction by calculating the weight’s impact on the output. But the impact of a generator weight depends on the impact of the discriminator weights it feeds into. So this extra chunk of discriminator network is included in the backpropagation that trains the generator: backpropagation starts at the output of the GAN and flows back through the discriminator into the generator.

At the same time, we don’t want the discriminator weights to change during generator training.

So the generator is trained in the following way:

  1. Sample random noise
  2. Produce generator output from sampled random noise
  3. Get discriminator “Real” or “Fake” classification for generator output
  4. Calculate loss from discriminator classification
  5. Backpropagate through both the discriminator and generator to obtain gradients.
  6. Use gradients to change only the generator weights.
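The six steps above can be sketched on a deliberately tiny model. Everything in this example is an assumption made for illustration: the generator is the one-dimensional map `G(z) = a*z + b`, the discriminator is the logistic unit `D(x) = sigmoid(w*x + c)`, and the gradients are written out by hand with the chain rule instead of using a deep learning library. The point is only to show the gradient flowing through the discriminator weight `w` while `w` and `c` themselves are left unchanged.

```python
import math

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def generator_loss(a, b, w, c, noise):
    """-mean(log D(G(z))) with G(z) = a*z + b and D(x) = sigmoid(w*x + c)."""
    return -sum(math.log(sigmoid(w * (a * z + b) + c)) for z in noise) / len(noise)

def generator_gradients(a, b, w, c, noise):
    """Steps 2-5: backpropagate through the discriminator into the generator."""
    grad_a = grad_b = 0.0
    for z in noise:
        x = a * z + b                # step 2: generator output
        d = sigmoid(w * x + c)       # step 3: discriminator's "real" probability
        dloss_dlogit = d - 1.0       # step 4: derivative of -log(sigmoid(t)) w.r.t. t
        dloss_dx = dloss_dlogit * w  # step 5: gradient passes through discriminator weight w
        grad_a += dloss_dx * z / len(noise)
        grad_b += dloss_dx / len(noise)
    return grad_a, grad_b

def generator_step(a, b, w, c, noise, lr=0.1):
    """Step 6: only the generator weights (a, b) are updated; (w, c) are read-only."""
    grad_a, grad_b = generator_gradients(a, b, w, c, noise)
    return a - lr * grad_a, b - lr * grad_b

noise = [-1.0, 1.0]  # step 1, fixed here instead of sampled so the run is reproducible
a, b, w, c = 1.0, 0.0, 1.0, 0.0
loss_before = generator_loss(a, b, w, c, noise)
a, b = generator_step(a, b, w, c, noise)
loss_after = generator_loss(a, b, w, c, noise)
print(loss_before, loss_after)  # the loss should go down after the step
```

In a real GAN the same idea is usually expressed by letting an autodiff framework compute the gradients and passing only the generator’s parameters to the optimizer.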

GAN Training

We saw above how the generator and the discriminator are trained individually. A GAN as a whole is trained by training both its generator and its discriminator network. Training proceeds in an alternating way:

  1. The discriminator trains for one or more epochs.
  2. The generator trains for one or more epochs.
  3. Steps 1 and 2 are repeated to continue training the generator and discriminator.

When the discriminator network is trained, the generator network is kept constant. Similarly, the discriminator network is kept constant when the generator network is trained. It’s this back and forth of training that allows GANs to tackle otherwise intractable generative problems.
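The alternating schedule can be sketched as a small loop. This is a toy under stated assumptions, not a reference implementation: the “real” data is drawn from a normal distribution with mean 4, the generator is `G(z) = a*z + b`, the discriminator is `D(x) = sigmoid(w*x + c)`, and each phase runs a few gradient steps on fresh mini-batches rather than full epochs. The function names `d_phase`, `g_phase` and `train` are hypothetical.

```python
import math
import random

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def d_phase(gen, disc, real_batch, noise_batch, lr):
    """One discriminator update; the generator (a, b) is held constant."""
    a, b = gen
    w, c = disc
    fake_batch = [a * z + b for z in noise_batch]
    grad_w = grad_c = 0.0
    for x in real_batch:   # real samples are positives: loss term -log D(x)
        d = sigmoid(w * x + c)
        grad_w += (d - 1.0) * x / len(real_batch)
        grad_c += (d - 1.0) / len(real_batch)
    for x in fake_batch:   # fake samples are negatives: loss term -log(1 - D(x))
        d = sigmoid(w * x + c)
        grad_w += d * x / len(fake_batch)
        grad_c += d / len(fake_batch)
    return (w - lr * grad_w, c - lr * grad_c)

def g_phase(gen, disc, noise_batch, lr):
    """One generator update; the discriminator (w, c) is held constant."""
    a, b = gen
    w, c = disc
    grad_a = grad_b = 0.0
    for z in noise_batch:  # loss term -log D(G(z)), chained through w
        d = sigmoid(w * (a * z + b) + c)
        grad_a += (d - 1.0) * w * z / len(noise_batch)
        grad_b += (d - 1.0) * w / len(noise_batch)
    return (a - lr * grad_a, b - lr * grad_b)

def train(rounds, d_steps=1, g_steps=1, lr=0.05, batch=64, seed=0):
    rng = random.Random(seed)
    gen, disc = (1.0, 0.0), (0.0, 0.0)
    for _ in range(rounds):
        for _ in range(d_steps):   # step 1: train the discriminator
            real = [rng.gauss(4.0, 1.0) for _ in range(batch)]
            noise = [rng.gauss(0.0, 1.0) for _ in range(batch)]
            disc = d_phase(gen, disc, real, noise, lr)
        for _ in range(g_steps):   # step 2: train the generator
            noise = [rng.gauss(0.0, 1.0) for _ in range(batch)]
            gen = g_phase(gen, disc, noise, lr)
    return gen, disc               # step 3 is the outer loop

gen, disc = train(rounds=1, d_steps=3)
print(gen, disc)
```

After one round the discriminator has learned that larger values look more “real”, and the generator’s offset `b` has moved in that direction. Nothing in this sketch guarantees convergence over many rounds.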

As the generator improves with training, the discriminator’s performance gets worse, which hampers the generator’s training

As the generator improves with training, the discriminator’s performance gets worse because it can no longer easily tell the difference between real and fake. If the generator succeeds perfectly, the discriminator has 50% accuracy. This progression poses a problem for the GAN as a whole: the discriminator’s feedback to the generator gets less meaningful over time. If the GAN continues training past the point where the discriminator is giving completely random feedback, the generator starts to train on junk feedback, and its quality may collapse.

GAN Mathematics

Two Divergences

Before going into detail about how a GAN is trained mathematically, let’s first review two metrics for quantifying the similarity between two probability distributions.

Kullback-Leibler Divergence

This divergence measures how one probability distribution $p$ diverges from a second, expected probability distribution $q$. It is given by:

$$
D_{KL}(p\|q) = \int_{x} p(x) \log\frac{p(x)}{q(x)}\,dx
$$

Intuitively, $D_{KL}$ achieves its minimum of zero when $p(x) = q(x)$ everywhere. The formula also shows that KL divergence is asymmetric. More information about it being asymmetric is given here

Fun Fact: Kullback-Leibler divergence is not same as Leibler-Kullback divergence!! ;)

In regions where $p(x)$ is close to zero but $q(x)$ is significantly non-zero, the effect of $q$ is disregarded. This can give misleading results when measuring the similarity between two equally important distributions.
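The asymmetry is easy to see numerically. The sketch below uses a discrete version of the formula, so the integral becomes a sum, and the two small distributions are made up for illustration. It assumes `q[i] > 0` wherever `p[i] > 0`; terms with `p[i] == 0` are skipped because they contribute nothing.

```python
import math

def kl_divergence(p, q):
    """Discrete D_KL(p || q) = sum_i p[i] * log(p[i] / q[i])."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p = [0.5, 0.5]
q = [0.9, 0.1]

forward = kl_divergence(p, q)  # about 0.511
reverse = kl_divergence(q, p)  # about 0.368
print(forward, reverse, kl_divergence(p, p))
```

Swapping the arguments changes the result, and the divergence of a distribution from itself is zero.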

Jensen-Shannon Divergence

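It is another measure of similarity between two probability distributions, bounded by $[0,1]$ when the logarithm is taken in base 2 (or by $[0, \log 2]$ with the natural logarithm). JS divergence is symmetric and smoother than KL divergence. It is built upon KL divergence and can be written as: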

$$
D_{JS}(p\|q) = \frac{1}{2}D_{KL}\left(p \,\Big\|\, \frac{p+q}{2}\right) + \frac{1}{2}D_{KL}\left(q \,\Big\|\, \frac{p+q}{2}\right)
$$
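A short sketch of the definition, again in discrete form with made-up distributions. It uses the natural logarithm, so under that assumption the upper bound is $\log 2$ rather than $1$.

```python
import math

def kl_divergence(p, q):
    """Discrete D_KL(p || q); terms with p[i] == 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def js_divergence(p, q):
    """D_JS(p || q) built from two KL terms against the mixture (p + q) / 2."""
    m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

p = [0.5, 0.5]
q = [0.9, 0.1]
print(js_divergence(p, q), js_divergence(q, p))  # identical: JS is symmetric
print(js_divergence([1.0, 0.0], [0.0, 1.0]))     # disjoint supports reach the bound log(2)
```

The mixture `m` is non-zero wherever `p` or `q` is, which is why the JS divergence stays finite even when the two distributions do not overlap.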

Let us now understand how GAN is trained mathematically.

We learnt in earlier sections how a GAN can be seen as an interplay between two different models, i.e. the generator and the discriminator. As a result, each model has its own loss function. Let’s define the loss function for each of the two models, but first some notation.

Notations and Loss Functions

  • $x$: real data sample
  • $z$: latent vector for the generator model
  • $G(z)$: fake data from the generator
  • $D(x)$: discriminator’s evaluation of real data
  • $D(G(z))$: discriminator’s evaluation of fake data
  • $Error(a,b)$: error between $a$ and $b$

Discriminator Loss Function
The goal of the discriminator is to correctly label generated images as false ($0$) and real data samples as true ($1$). Therefore, the loss function of the discriminator would look something like this:

$$
L_{D} = Error(D(x), 1) + Error(D(G(z)), 0) \tag{1}
$$

Here, $Error$ refers to some function that tells us the distance or difference between its two arguments.

Generator Loss Function
The goal of the generator is to confuse the discriminator as much as possible such that it mislabels the generated images as being true. The loss function of generator could be defined as:

$$
L_{G} = Error(D(G(z)), 1) \tag{2}
$$

Remember that a loss function is meant to be minimized. Hence, the generator’s loss is minimized when the difference between $1$ (the label for real data) and the discriminator’s evaluation of generated data is as small as possible (optimally $0$).

BCE in GANs

A common loss function used in binary classification tasks is binary cross entropy. As a quick review, let’s see what cross entropy looks like:

$$
CE(p,q) = -\sum_{x \in \chi} p(x) \log q(x)
$$

where $p$ is the ground-truth label and $q$ is the predicted probability for data $x$ belonging to the dataset $\chi$. $CE(p,q)$ can be further simplified for binary classification problems (i.e. only two labels: $0$ and $1$), with $y$ the true label and $\hat{y}$ the predicted probability, and is then called binary cross entropy (BCE):

$$
BCE(y, \hat{y}) = -\left[\, y \log(\hat{y}) + (1-y) \log(1-\hat{y}) \,\right]
$$

More detail on Cross Entropy and its variants can be read here.
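As a minimal sketch, BCE for a single prediction can be written directly from the formula. The small `eps` clamp is an addition of this example (not part of the definition) that avoids taking the logarithm of exactly zero.

```python
import math

def bce(y, y_hat, eps=1e-12):
    """Binary cross entropy for one true label y (0 or 1) and one prediction y_hat."""
    y_hat = min(max(y_hat, eps), 1.0 - eps)  # numerical safety only
    return -(y * math.log(y_hat) + (1 - y) * math.log(1.0 - y_hat))

print(bce(1, 0.9))  # about 0.105: confident and correct, small loss
print(bce(0, 0.9))  # about 2.303: confident and wrong, large loss
```

Only one of the two terms is active for a given label: the first when $y = 1$, the second when $y = 0$.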

The $Error$ written above in $(1)$ and $(2)$ is this BCE. Binary cross entropy measures how different two distributions are in the context of binary classification, here deciding whether an input data point is real or fake. Applying BCE to the loss function in $(1)$:

$$
L_{D} = -\sum_{x \in \chi,\, z \in \zeta} \left[ \log(D(x)) + \log(1 - D(G(z))) \right] \tag{3}
$$

Same can be done for $(2)$ as well:

$$
L_{G} = -\sum_{z \in \zeta} \log(D(G(z))) \tag{4}
$$

Now there are two loss functions which would be minimized to train the generator and discriminator. For the generator loss function to be small, $D(G(z))$ should be close to $1$, since $log(1) = 0$.
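Equations $(3)$ and $(4)$ can be evaluated directly once the discriminator’s outputs are known. The sketch below assumes those outputs are already given as lists of probabilities, `d_real` for $D(x)$ and `d_fake` for $D(G(z))$, paired one to one as in the sum of $(3)$; the example values are made up.

```python
import math

def discriminator_loss(d_real, d_fake):
    """Equation (3): -sum(log D(x) + log(1 - D(G(z))))."""
    return -sum(math.log(dr) + math.log(1.0 - df) for dr, df in zip(d_real, d_fake))

def generator_loss(d_fake):
    """Equation (4): -sum(log D(G(z)))."""
    return -sum(math.log(df) for df in d_fake)

# A discriminator that is doing well: real scored 0.9, fake scored 0.1.
print(discriminator_loss([0.9], [0.1]))  # about 0.211, low for the discriminator
print(generator_loss([0.1]))             # about 2.303, high for the generator

# A generator that fools the discriminator: fake scored 0.9.
print(generator_loss([0.9]))             # about 0.105
```

The two losses pull in opposite directions on $D(G(z))$, which is the adversarial part of the setup.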

Minor Caveats from Original Paper

The original paper by Ian Goodfellow presents a slightly different version of the two loss functions derived above:

$$
\underset{D}{\max}\{\log(D(x)) + \log(1 - D(G(z)))\} \tag{5}
$$

The objective above is the discriminator’s, and on closer inspection it is very similar to equation $(3)$. The only differences between equations $(3)$ and $(5)$ are the sign and whether the expression is minimized or maximized. In $(3)$, the loss function is framed to be minimized, whereas the original formulation flips the sign and frames the objective to be maximized. Hence, we get equation $(5)$ as the objective for the discriminator to maximize.

The generator competes against the discriminator, striving to make the discriminator think the generated data is real. So it tries to minimize the expression in equation $(5)$, and the generator’s objective is:

$$
\underset{G}{\min}\{\log(D(x)) + \log(1 - D(G(z)))\} \tag{6}
$$

In the original paper, these two objectives are combined into the following equation:

$$
L = \underset{G}{\min}\; \underset{D}{\max}\{\log(D(x)) + \log(1 - D(G(z)))\} \tag{7}
$$

Model Optimization

Now that the loss functions for the generator and discriminator have been defined, it’s time to leverage mathematics to solve the optimization problem, i.e. to find the parameters for the generator and discriminator such that the loss functions are optimized. Equation $(7)$ will be used for further calculations.

Training the Discriminator

When training a GAN, one model is trained at a time. In other words, when training the discriminator, the generator is fixed. In the min-max loss function, the quantity of interest can be defined as a function of $G$ and $D$. Let’s call this the value function:

$$
V(G,D) = \mathbb{E}_{x \sim p_{data}}[\log(D(x))] + \mathbb{E}_{z \sim p_{z}}[\log(1 - D(G(z)))]
$$

$p_{data}$ and $p_{z}$ represent the distribution of the real data and the distribution of the latent noise $z$ fed to the generator, respectively. The expectations $\mathbb{E}$ are taken over these two distributions.

Let’s create a new variable, $y = G(z)$, with $p_{g}$ denoting the distribution of generated samples, and use this substitution to rewrite the value function:

$$
V(G,D) = \mathbb{E}_{x \sim p_{data}}[\log(D(x))] + \mathbb{E}_{y \sim p_{g}}[\log(1 - D(y))]
$$

$$
V(G,D) = \int_{x \in \chi} \left[ p_{data}(x)\log(D(x)) + p_{g}(x)\log(1 - D(x)) \right] dx \tag{8}
$$

The goal of the discriminator is to maximize this value function. Through a partial derivative of $V(G,D)$ with respect to $D(x)$, we see that the optimal discriminator, denoted as $D^{*}(x)$, occurs when

$$
\frac{p_{data}(x)}{D(x)} - \frac{p_{g}(x)}{1 - D(x)} = 0
$$

Rearranging it, we get

$$
D^{*}(x) = \frac{p_{data}(x)}{p_{data}(x) + p_{g}(x)} \tag{9}
$$

And this is the condition for the optimal discriminator! Note that the formula makes intuitive sense: if some sample $x$ is highly genuine, $p_{data}(x)$ is expected to be much larger than $p_{g}(x)$, which should be close to zero, in which case the optimal discriminator assigns a value close to $1$ to that sample. On the other hand, for a generated sample $x = G(z)$, the optimal discriminator assigns a value close to zero, since $p_{data}(G(z))$ should be close to zero.
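Equation $(9)$ can be checked numerically at a single point $x$. The sketch below takes two made-up density values, `p_data_x` and `p_g_x`, and searches a grid of candidate discriminator outputs for the one that maximizes the integrand of $(8)$. The grid search and its resolution are choices of this example, not part of the derivation.

```python
import math

def pointwise_value(d, p_data_x, p_g_x):
    """Integrand of (8) at one x: p_data(x) * log(D) + p_g(x) * log(1 - D)."""
    return p_data_x * math.log(d) + p_g_x * math.log(1.0 - d)

def best_d_on_grid(p_data_x, p_g_x, steps=1000):
    """Brute-force maximizer over D in {1/steps, ..., (steps - 1)/steps}."""
    grid = [i / steps for i in range(1, steps)]
    return max(grid, key=lambda d: pointwise_value(d, p_data_x, p_g_x))

def optimal_d(p_data_x, p_g_x):
    """Closed form from equation (9)."""
    return p_data_x / (p_data_x + p_g_x)

print(best_d_on_grid(0.6, 0.2), optimal_d(0.6, 0.2))  # both close to 0.75
print(best_d_on_grid(0.3, 0.3), optimal_d(0.3, 0.3))  # both 0.5 when the densities match
```

The second case is the situation described earlier: when the generator matches the data distribution, the best the discriminator can do is output 0.5.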

Training the Generator

To train the generator, the discriminator is assumed to be fixed, and we proceed with the analysis of the value function. Plugging the result obtained in $(9)$ into the value function:

$$
V(G, D^{*}) = \mathbb{E}_{x \sim p_{data}}[\log(D^{*}(x))] + \mathbb{E}_{x \sim p_{g}}[\log(1 - D^{*}(x))]
$$

$$
V(G, D^{*}) = \mathbb{E}_{x \sim p_{data}}\left[\log\frac{p_{data}(x)}{p_{data}(x) + p_{g}(x)}\right] + \mathbb{E}_{x \sim p_{g}}\left[\log\frac{p_{g}(x)}{p_{data}(x) + p_{g}(x)}\right] \tag{10}
$$

Doing a little trick in the above equation:

$$
V(G, D^{*}) = \mathbb{E}_{x \sim p_{data}}\left[\log\frac{p_{data}(x)}{p_{data}(x) + p_{g}(x)}\right] + \mathbb{E}_{x \sim p_{g}}\left[\log\frac{p_{g}(x)}{p_{data}(x) + p_{g}(x)}\right]
$$

$$
V(G, D^{*}) = -\log 4 + \mathbb{E}_{x \sim p_{data}}\left[\log p_{data}(x) - \log\frac{p_{data}(x) + p_{g}(x)}{2}\right] + \mathbb{E}_{x \sim p_{g}}\left[\log p_{g}(x) - \log\frac{p_{data}(x) + p_{g}(x)}{2}\right]
$$

In the above equation, $-\log 4$ is pulled out and the terms inside the expectations are adjusted to compensate, specifically by dividing each denominator by two. Thanks to this change, the Kullback-Leibler divergence can now be applied:

$$
V(G, D^{*}) = -\log 4 + D_{KL}\left(p_{data} \,\Big\|\, \frac{p_{data} + p_{g}}{2}\right) + D_{KL}\left(p_{g} \,\Big\|\, \frac{p_{data} + p_{g}}{2}\right) \tag{11}
$$

Jensen-Shannon divergence is defined as:

$$
J(P,Q) = \frac{1}{2}\left(D_{KL}(P \| R) + D_{KL}(Q \| R)\right)
$$

where $R = \frac{1}{2}(P+Q)$. This means the expression in $(11)$ can be expressed as a JS divergence:

$$
V(G, D^{*}) = -\log 4 + 2 \times D_{JS}(p_{data} \| p_{g}) \tag{12}
$$

Essentially, the loss function of a GAN quantifies the similarity between the generated data distribution $p_{g}$ and the real data distribution $p_{data}$ by the JS divergence when the discriminator is optimal. This aligns with the intuition that the generator learns the underlying distribution of the data from sampled training examples. The optimal generator $G$ is thus the one whose distribution $p_{g}$ mimics $p_{data}$ most convincingly.

The best generator $G^{*}$, the one that replicates the real data distribution, makes the JS divergence zero and therefore leads to the minimum $V(G^{*},D^{*}) = -\log 4$, which is consistent with the equations above.
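The identity in $(12)$ and the minimum of $-\log 4$ can both be checked numerically. The sketch below works with small discrete distributions, which are made up for illustration, so the expectations become sums, and it plugs the optimal discriminator from $(9)$ into the value function.

```python
import math

def kl_divergence(p, q):
    """Discrete D_KL(p || q); terms with p[i] == 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def js_divergence(p, q):
    m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

def value_at_optimal_d(p_data, p_g):
    """V(G, D*) with D*(x) = p_data(x) / (p_data(x) + p_g(x)), as in equation (10)."""
    total = 0.0
    for pd, pg in zip(p_data, p_g):
        if pd > 0:
            total += pd * math.log(pd / (pd + pg))
        if pg > 0:
            total += pg * math.log(pg / (pd + pg))
    return total

p_data = [0.5, 0.3, 0.2]
p_g = [0.2, 0.3, 0.5]

lhs = value_at_optimal_d(p_data, p_g)
rhs = -math.log(4.0) + 2.0 * js_divergence(p_data, p_g)
print(lhs, rhs)                            # equal up to rounding, as in (12)
print(value_at_optimal_d(p_data, p_data))  # -log(4) when the generator matches the data
```

Any mismatch between `p_g` and `p_data` adds a positive JS term, so the value sits above $-\log 4$ until the two distributions coincide.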

This concludes this blog post. Further reading could cover the problems in GAN training, viz. how hard it is to reach a Nash equilibrium, gradient saturation and mode collapse; these problems are largely tied to the GAN loss function. The next topic would be how these problems are circumvented and what a WGAN is. Later, one can read about the many variants of GANs and their applications in different domains.

Vanilla GAN Code Tutorial

The Vanilla GAN (from the original GAN paper by Ian Goodfellow) code can be found here in Jupyter notebook format and can be used to play around: Vanilla GAN Code

References

  1. A Beginner’s Guide to Generative Adversarial Networks (GANs)
  2. Stackoverflow
  3. Generative Adversarial Networks - Google Developers
  4. GAN Mathematics

More to Read

  1. Stackoverflow: Generative v/s Discriminative
  2. From GAN to WGAN