🚀 Variants of GANs, implemented as simply as possible in TensorFlow2: GAN, DCGAN, LSGAN, WGAN, WGAN-GP, DRAGAN, etc.

marload/GANs-TensorFlow2

Generative Adversarial Nets in TensorFlow2

GANs-TensorFlow2 is a repository that implements a variety of popular Generative Adversarial Network (GAN) algorithms in TensorFlow2. Its focus is easy-to-understand code, so if you are a student or researcher studying GANs, I think it is a good place to start. Each algorithm lives in a single Python script, so you don't have to jump between files to study a specific algorithm. The repository is actively updated, and new GAN algorithms will continue to be added.

Algorithms


GAN

Paper Generative Adversarial Networks
Author Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, Yoshua Bengio
Published NIPS 2014

Animation Results

Loss Function

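import tensorflow as tf

def get_loss_fn():
    # Both arguments are discriminator outputs after a sigmoid, i.e. probabilities
    # in (0, 1); the 1e-10 term guards against log(0).
    def d_loss_fn(real_logits, fake_logits):
        return -tf.reduce_mean(tf.math.log(real_logits + 1e-10) + tf.math.log(1. - fake_logits + 1e-10))

    # Non-saturating generator loss: maximize log D(G(z)) rather than minimize log(1 - D(G(z))).
    def g_loss_fn(fake_logits):
        return -tf.reduce_mean(tf.math.log(fake_logits + 1e-10))

    return d_loss_fn, g_loss_fn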
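To show how these two functions plug into training, here is a minimal sketch of a single TensorFlow2 training step. The toy models, `NOISE_DIM`, `DATA_DIM`, the Adam settings, and `train_step` are hypothetical and not taken from `GAN/GAN.py`. The one requirement carried over from the loss above is that the discriminator ends in a sigmoid, because `d_loss_fn` and `g_loss_fn` take logarithms of its outputs.

```python
import tensorflow as tf

# Loss functions as in the GAN section: inputs are sigmoid probabilities.
def get_loss_fn():
    def d_loss_fn(real_logits, fake_logits):
        return -tf.reduce_mean(tf.math.log(real_logits + 1e-10)
                               + tf.math.log(1. - fake_logits + 1e-10))

    def g_loss_fn(fake_logits):
        return -tf.reduce_mean(tf.math.log(fake_logits + 1e-10))

    return d_loss_fn, g_loss_fn

NOISE_DIM = 16  # hypothetical latent size
DATA_DIM = 2    # hypothetical toy data dimension

# Hypothetical toy models; the repository's real architectures differ.
generator = tf.keras.Sequential([
    tf.keras.Input(shape=(NOISE_DIM,)),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(DATA_DIM),
])
discriminator = tf.keras.Sequential([
    tf.keras.Input(shape=(DATA_DIM,)),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),  # probabilities, not logits
])

# Assumed optimizer settings (DCGAN-style Adam), not taken from the repository.
g_opt = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
d_opt = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
d_loss_fn, g_loss_fn = get_loss_fn()

def train_step(real_batch):
    noise = tf.random.normal([tf.shape(real_batch)[0], NOISE_DIM])
    # One tape per network so each optimizer only receives its own gradients.
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake_batch = generator(noise, training=True)
        real_out = discriminator(real_batch, training=True)
        fake_out = discriminator(fake_batch, training=True)
        d_loss = d_loss_fn(real_out, fake_out)
        g_loss = g_loss_fn(fake_out)
    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
    d_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))
    g_opt.apply_gradients(zip(g_grads, generator.trainable_variables))
    return d_loss, g_loss

# Toy "real" data: a Gaussian blob centered at 3.
real = tf.random.normal([64, DATA_DIM], mean=3.0)
d_loss, g_loss = train_step(real)
```

Both networks are updated once per step here; in practice the step runs in a loop over a `tf.data` pipeline.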

Getting Started

$ python GAN/GAN.py

DCGAN

Paper Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks
Author Alec Radford, Luke Metz, Soumith Chintala
Published ICLR 2016

Animation Results

Loss Function

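import tensorflow as tf

def get_loss_fn():
    # from_logits=True: the discriminator returns raw scores and the loss applies the sigmoid internally.
    criterion = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    def d_loss_fn(real_logits, fake_logits):
        # Real samples are labeled 1, generated samples 0.
        real_loss = criterion(tf.ones_like(real_logits), real_logits)
        fake_loss = criterion(tf.zeros_like(fake_logits), fake_logits)
        return real_loss + fake_loss

    def g_loss_fn(fake_logits):
        # The generator wants its samples classified as real (label 1).
        return criterion(tf.ones_like(fake_logits), fake_logits)

    return d_loss_fn, g_loss_fn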
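DCGAN's cross-entropy loss is the same objective as the GAN loss above, only computed from raw logits. The sketch below checks this numerically on a few made-up logit values: applying a sigmoid and then the log-based GAN formulas gives the same numbers, up to the `1e-10` epsilon. The helper names `dcgan_d_loss`, `gan_d_loss`, etc. are hypothetical. Working from logits with `from_logits=True` is usually preferred because it avoids taking the log of a saturated sigmoid.

```python
import tensorflow as tf

criterion = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# DCGAN-style losses on raw logits.
def dcgan_d_loss(real_logits, fake_logits):
    real_loss = criterion(tf.ones_like(real_logits), real_logits)
    fake_loss = criterion(tf.zeros_like(fake_logits), fake_logits)
    return real_loss + fake_loss

def dcgan_g_loss(fake_logits):
    return criterion(tf.ones_like(fake_logits), fake_logits)

# GAN-style losses on sigmoid probabilities.
def gan_d_loss(real_probs, fake_probs):
    return -tf.reduce_mean(tf.math.log(real_probs + 1e-10)
                           + tf.math.log(1. - fake_probs + 1e-10))

def gan_g_loss(fake_probs):
    return -tf.reduce_mean(tf.math.log(fake_probs + 1e-10))

# Made-up discriminator logits for four real and four generated samples.
real_logits = tf.constant([[2.0], [0.5], [-0.5], [1.0]])
fake_logits = tf.constant([[-2.0], [-1.0], [0.3], [0.0]])

d_from_logits = dcgan_d_loss(real_logits, fake_logits)
d_from_probs = gan_d_loss(tf.sigmoid(real_logits), tf.sigmoid(fake_logits))
g_from_logits = dcgan_g_loss(fake_logits)
g_from_probs = gan_g_loss(tf.sigmoid(fake_logits))

print(float(d_from_logits), float(d_from_probs))  # equal up to float rounding
print(float(g_from_logits), float(g_from_probs))
```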

Getting Started

$ python DCGAN/DCGAN.py

LSGAN

Paper Least Squares Generative Adversarial Networks
Author Xudong Mao, Qing Li, Haoran Xie, Raymond Y.K. Lau, Zhen Wang, Stephen Paul Smolley
Published ICCV 2017

Animation Results

Loss Function

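import tensorflow as tf

def get_loss_fn():
    # Least-squares loss on raw discriminator outputs (no sigmoid).
    criterion = tf.keras.losses.MeanSquaredError()

    def d_loss_fn(real_logits, fake_logits):
        # Push scores for real samples toward 1 and for generated samples toward 0.
        real_loss = criterion(tf.ones_like(real_logits), real_logits)
        fake_loss = criterion(tf.zeros_like(fake_logits), fake_logits)
        return real_loss + fake_loss

    def g_loss_fn(fake_logits):
        # The generator pushes the scores of its samples toward 1.
        return criterion(tf.ones_like(fake_logits), fake_logits)

    return d_loss_fn, g_loss_fn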
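As a worked example with made-up discriminator scores: each squared error is measured against the target label (1 for real, 0 for fake for the discriminator; 1 for fake for the generator) and averaged over the batch. The LSGAN paper writes each term with a factor of 1/2; leaving it out, as this code does, only rescales the losses.

```python
import tensorflow as tf

criterion = tf.keras.losses.MeanSquaredError()

def d_loss_fn(real_logits, fake_logits):
    real_loss = criterion(tf.ones_like(real_logits), real_logits)
    fake_loss = criterion(tf.zeros_like(fake_logits), fake_logits)
    return real_loss + fake_loss

def g_loss_fn(fake_logits):
    return criterion(tf.ones_like(fake_logits), fake_logits)

# Raw scores for two real and two generated samples (made-up values).
real_scores = tf.constant([[1.0], [0.0]])  # one on target (1), one off by 1
fake_scores = tf.constant([[0.0], [1.0]])  # one on target (0), one off by 1

print(float(d_loss_fn(real_scores, fake_scores)))  # (0 + 1)/2 + (0 + 1)/2 = 1.0
print(float(g_loss_fn(fake_scores)))               # ((0-1)^2 + (1-1)^2) / 2 = 0.5
```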

Getting Started

$ python LSGAN/LSGAN.py

WGAN

Paper Wasserstein GAN
Author Martin Arjovsky, Soumith Chintala, Léon Bottou
Publish arXiv 2017

Animation Results

Loss Function

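import tensorflow as tf

def get_loss_fn():
    # The critic outputs unbounded scores (no sigmoid). Maximizing the gap between
    # real and fake scores approximates the Wasserstein-1 distance, provided the
    # critic is kept Lipschitz (the WGAN paper clips its weights for this).
    def d_loss_fn(real_logits, fake_logits):
        return tf.reduce_mean(fake_logits) - tf.reduce_mean(real_logits)

    # The generator tries to raise the critic's score on its samples.
    def g_loss_fn(fake_logits):
        return -tf.reduce_mean(fake_logits)

    return d_loss_fn, g_loss_fn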
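The loss alone does not keep the critic Lipschitz. In the WGAN paper this is enforced by clipping every critic weight to a small range [-c, c] after each update (c = 0.01 in the paper's experiments, with RMSProp at learning rate 5e-5). A minimal sketch of that step follows; the toy critic, `clip_critic_weights`, and `critic_step` are hypothetical names, and this README does not show how `WGAN/WGAN.py` itself applies the constraint.

```python
import tensorflow as tf

CLIP_VALUE = 0.01  # clipping constant c from the WGAN paper's experiments

def clip_critic_weights(critic, clip_value=CLIP_VALUE):
    # Project every trainable weight of the critic back into [-c, c].
    for var in critic.trainable_variables:
        var.assign(tf.clip_by_value(var, -clip_value, clip_value))

# Hypothetical toy critic with a linear (unbounded) output.
critic = tf.keras.Sequential([
    tf.keras.Input(shape=(2,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1),
])
critic_opt = tf.keras.optimizers.RMSprop(learning_rate=5e-5)

def critic_step(real_batch, fake_batch):
    with tf.GradientTape() as tape:
        # Same critic loss as d_loss_fn above.
        d_loss = tf.reduce_mean(critic(fake_batch)) - tf.reduce_mean(critic(real_batch))
    grads = tape.gradient(d_loss, critic.trainable_variables)
    critic_opt.apply_gradients(zip(grads, critic.trainable_variables))
    clip_critic_weights(critic)  # enforce the constraint after every update
    return d_loss

real = tf.random.normal([32, 2], mean=2.0)
fake = tf.random.normal([32, 2])
loss = critic_step(real, fake)
max_abs_weight = max(float(tf.reduce_max(tf.abs(v))) for v in critic.trainable_variables)
print(max_abs_weight)  # never larger than CLIP_VALUE
```

The paper also trains the critic for several steps (5) per generator step; weight clipping is the part WGAN-GP later replaces with a gradient penalty.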

Getting Started

$ python WGAN/WGAN.py

WGAN-GP

Paper Improved Training of Wasserstein GANs
Author Ishaan Gulrajani, Faruk Ahmed, Martin Arjovsky, Vincent Dumoulin, Aaron Courville
Published NIPS 2017

Animation Results

Loss Function

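import tensorflow as tf

def get_loss_fn():
    # Same Wasserstein critic/generator losses as WGAN; the Lipschitz constraint
    # comes from the gradient penalty below instead of weight clipping.
    def d_loss_fn(real_logits, fake_logits):
        return tf.reduce_mean(fake_logits) - tf.reduce_mean(real_logits)

    def g_loss_fn(fake_logits):
        return -tf.reduce_mean(fake_logits)

    return d_loss_fn, g_loss_fn

def gradient_penalty(discriminator, real_images, fake_images):
    # The penalty is on the critic (discriminator), evaluated at random points
    # on straight lines between real and generated images.
    real_images = tf.cast(real_images, tf.float32)
    fake_images = tf.cast(fake_images, tf.float32)
    batch_size = tf.shape(real_images)[0]  # taken from the input, no global BATCH_SIZE
    alpha = tf.random.uniform([batch_size, 1, 1, 1], 0., 1.)
    diff = fake_images - real_images
    inter = real_images + (alpha * diff)
    with tf.GradientTape() as tape:
        tape.watch(inter)
        predictions = discriminator(inter)
    gradients = tape.gradient(predictions, [inter])[0]
    # Per-sample L2 norm of the input gradient; the small epsilon keeps the
    # gradient of sqrt finite when the norm is zero.
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]) + 1e-12)
    return tf.reduce_mean((slopes - 1.) ** 2)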
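The penalty is added to the critic loss with a weight λ (λ = 10 in the WGAN-GP paper), and it must be computed on the critic, since it measures the critic's gradient with respect to image inputs. The sketch below uses a hypothetical linear critic D(x) = <w, x> as a sanity check: its input gradient is w at every point, so every slope equals ||w|| and the penalty must be (||w|| - 1)^2 no matter which interpolation points are sampled. `linear_critic` and `LAMBDA_GP` are illustrative names, and the batch size is read from the input rather than from a global.

```python
import tensorflow as tf

LAMBDA_GP = 10.0  # penalty weight used in the WGAN-GP paper

def gradient_penalty(discriminator, real_images, fake_images):
    real_images = tf.cast(real_images, tf.float32)
    fake_images = tf.cast(fake_images, tf.float32)
    batch_size = tf.shape(real_images)[0]
    # Random points on the lines between real and fake samples (NHWC images).
    alpha = tf.random.uniform([batch_size, 1, 1, 1], 0., 1.)
    inter = real_images + alpha * (fake_images - real_images)
    with tf.GradientTape() as tape:
        tape.watch(inter)
        predictions = discriminator(inter)
    gradients = tape.gradient(predictions, inter)
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]) + 1e-12)
    return tf.reduce_mean((slopes - 1.) ** 2)

# Hypothetical linear critic: gradient w.r.t. x is w everywhere.
w = tf.ones([1, 2, 2, 1])  # ||w|| = sqrt(4) = 2

def linear_critic(x):
    return tf.reduce_sum(x * w, axis=[1, 2, 3])

real = tf.random.normal([8, 2, 2, 1])
fake = tf.random.normal([8, 2, 2, 1])

gp = gradient_penalty(linear_critic, real, fake)  # (2 - 1)^2 = 1
wasserstein_term = tf.reduce_mean(linear_critic(fake)) - tf.reduce_mean(linear_critic(real))
d_loss = wasserstein_term + LAMBDA_GP * gp        # full WGAN-GP critic loss
print(float(gp))
```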

Getting Started

$ python WGAN-GP/WGAN-GP.py

DRAGAN

Paper On Convergence and Stability of GANs
Author Naveen Kodali, Jacob Abernethy, James Hays, Zsolt Kira
Published ICLR 2018

Animation Results

Loss Function

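import tensorflow as tf

def get_loss_fn():
    # Note: this snippet uses Wasserstein-style critic losses. The DRAGAN paper
    # pairs its penalty with the standard (cross-entropy) GAN loss instead.
    def d_loss_fn(real_logits, fake_logits):
        return tf.reduce_mean(fake_logits) - tf.reduce_mean(real_logits)

    def g_loss_fn(fake_logits):
        return -tf.reduce_mean(fake_logits)

    return d_loss_fn, g_loss_fn

def gradient_penalty(discriminator, real_images):
    # Unlike WGAN-GP, the penalty is applied only around real data points.
    real_images = tf.cast(real_images, tf.float32)

    def _interpolate(a):
        # Perturb real samples: b = a + 0.5 * std(a) * U(0, 1) noise.
        beta = tf.random.uniform(tf.shape(a), 0., 1.)
        b = a + 0.5 * tf.math.reduce_std(a) * beta
        # One interpolation coefficient per sample, broadcast over the other axes.
        shape = [tf.shape(a)[0]] + [1] * (a.shape.ndims - 1)
        alpha = tf.random.uniform(shape, 0., 1.)
        inter = a + alpha * (b - a)
        inter.set_shape(a.shape)
        return inter

    x = _interpolate(real_images)
    with tf.GradientTape() as tape:
        tape.watch(x)
        predictions = discriminator(x, training=True)
    grad = tape.gradient(predictions, x)
    # Per-sample L2 norm of the gradient with respect to the perturbed inputs.
    slopes = tf.norm(tf.reshape(grad, [tf.shape(grad)[0], -1]), axis=1)
    return tf.reduce_mean((slopes - 1.) ** 2)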
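As far as we understand the DRAGAN paper, its penalty is paired with the standard cross-entropy GAN loss rather than the Wasserstein-style loss shown above. The sketch below combines the two in that way. `LAMBDA_DRAGAN = 10.0` is an assumed weight, `dragan_penalty`, `dragan_d_loss`, and `linear_discriminator` are hypothetical names, and batches are assumed to be NHWC images. The same linear-critic check as for WGAN-GP applies: with ||w|| = 3 the penalty is (3 - 1)^2 = 4.

```python
import tensorflow as tf

LAMBDA_DRAGAN = 10.0  # assumed penalty weight; tune per dataset

bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def dragan_penalty(discriminator, real_images):
    real_images = tf.cast(real_images, tf.float32)
    # Perturb real samples by up to half their standard deviation (uniform noise).
    beta = tf.random.uniform(tf.shape(real_images), 0., 1.)
    perturbed = real_images + 0.5 * tf.math.reduce_std(real_images) * beta
    alpha = tf.random.uniform([tf.shape(real_images)[0], 1, 1, 1], 0., 1.)
    x = real_images + alpha * (perturbed - real_images)
    with tf.GradientTape() as tape:
        tape.watch(x)
        predictions = discriminator(x)
    grad = tape.gradient(predictions, x)
    slopes = tf.norm(tf.reshape(grad, [tf.shape(grad)[0], -1]), axis=1)
    return tf.reduce_mean((slopes - 1.) ** 2)

def dragan_d_loss(discriminator, real_images, fake_images):
    # Standard GAN discriminator loss on logits, plus the local penalty.
    real_logits = discriminator(real_images)
    fake_logits = discriminator(fake_images)
    gan_loss = (bce(tf.ones_like(real_logits), real_logits)
                + bce(tf.zeros_like(fake_logits), fake_logits))
    return gan_loss + LAMBDA_DRAGAN * dragan_penalty(discriminator, real_images)

# Hypothetical linear discriminator: gradient w.r.t. x is w everywhere.
w = tf.ones([1, 3, 3, 1])  # ||w|| = sqrt(9) = 3

def linear_discriminator(x):
    return tf.reduce_sum(x * w, axis=[1, 2, 3])[:, None]  # [batch, 1] logits

real = tf.random.normal([8, 3, 3, 1])
fake = tf.random.normal([8, 3, 3, 1])
penalty = dragan_penalty(linear_discriminator, real)  # (3 - 1)^2 = 4
d_loss = dragan_d_loss(linear_discriminator, real, fake)
print(float(penalty))
```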

Getting Started

$ python DRAGAN/DRAGAN.py
