
Commit

update: codes
kozistr committed Oct 12, 2020
1 parent 620b931 commit 4e27ef0
Showing 3 changed files with 22 additions and 10 deletions.
2 changes: 1 addition & 1 deletion awesome_gans/data.py
@@ -36,7 +36,7 @@ def load_dataset(self, use_label: bool = False):
ds = tfds.load(name=self.dataset, split='train', as_supervised=use_label, shuffle_files=True)
ds = ds.map(lambda x: self.preprocess_image(x['image']), tf.data.experimental.AUTOTUNE)
ds = ds.cache()
- ds = ds.shuffle(50000)
+ ds = ds.shuffle(self.bs * 16)
ds = ds.batch(self.bs, drop_remainder=True)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
return ds
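
Note: the shuffle buffer now scales with the batch size (self.bs * 16) instead of using a fixed 50000 elements, so buffer memory stays proportional to the configured batch. A minimal sketch of the same pipeline shape, assuming a standalone batch_size and a hypothetical preprocess() helper (the dataset name and helper are illustrative, not part of this commit):

import tensorflow as tf
import tensorflow_datasets as tfds

batch_size = 64

def preprocess(image):
    # scale uint8 pixels to [-1, 1], a common GAN input convention
    return tf.cast(image, tf.float32) / 127.5 - 1.0

ds = tfds.load(name='cifar10', split='train', as_supervised=False, shuffle_files=True)
ds = ds.map(lambda x: preprocess(x['image']), tf.data.experimental.AUTOTUNE)
ds = ds.cache()
ds = ds.shuffle(batch_size * 16)  # buffer tied to batch size rather than a fixed 50000
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)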
12 changes: 12 additions & 0 deletions awesome_gans/losses.py
@@ -1,6 +1,17 @@
import tensorflow as tf


+ @tf.function
+ def discriminator_wgan_loss(real: tf.Tensor, fake: tf.Tensor):
+     return tf.reduce_mean(fake) - tf.reduce_mean(real)


+ @tf.function
+ def generator_wgan_loss(fake: tf.Tensor):
+     return -tf.reduce_mean(fake)


+ @tf.function
def discriminator_loss(loss_func: str, real: tf.Tensor, fake: tf.Tensor, use_ra: bool = False):
real_loss: float = 0.0
fake_loss: float = 0.0
@@ -31,6 +42,7 @@ def discriminator_loss(loss_func: str, real: tf.Tensor, fake: tf.Tensor, use_ra:
return loss


+ @tf.function
def generator_loss(loss_func: str, real: tf.Tensor, fake: tf.Tensor, use_ra: bool = False):
fake_loss: float = 0.0
real_loss: float = 0.0
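
The two new helpers implement the standard WGAN objectives: the critic minimizes mean(D(fake)) - mean(D(real)) and the generator minimizes -mean(D(fake)). A quick usage sketch with dummy critic scores (the tensors below are illustrative, not from the repository):

import tensorflow as tf
from awesome_gans.losses import discriminator_wgan_loss, generator_wgan_loss

d_real = tf.constant([0.8, 1.2, 0.5])    # critic scores on real samples
d_fake = tf.constant([-0.3, 0.1, -0.6])  # critic scores on generated samples

d_loss = discriminator_wgan_loss(d_real, d_fake)  # mean(fake) - mean(real)
g_loss = generator_wgan_loss(d_fake)              # -mean(fake)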
18 changes: 9 additions & 9 deletions awesome_gans/wgan/model.py
@@ -16,7 +16,7 @@
from tensorflow.keras.models import Model
from tqdm import tqdm

- from awesome_gans.losses import discriminator_loss, generator_loss
+ from awesome_gans.losses import discriminator_loss, generator_loss, discriminator_wgan_loss, generator_wgan_loss
from awesome_gans.optimizers import build_optimizer
from awesome_gans.utils import merge_images, save_image

@@ -103,15 +103,15 @@ def train_discriminator(self, x: tf.Tensor):
d_fake = self.discriminator(x_fake, training=True)
d_real = self.discriminator(x, training=True)

- d_loss = discriminator_loss(self.d_loss, d_real, d_fake)
+ d_loss = discriminator_wgan_loss(d_real, d_fake)

- gradient = gt.gradient(d_loss, self.discriminator.trainable_variables)
- self.d_opt.apply_gradients(zip(gradient, self.discriminator.trainable_variables))
+ gradients = gt.gradient(d_loss, self.discriminator.trainable_variables)
+ self.d_opt.apply_gradients(zip(gradients, self.discriminator.trainable_variables))

for var in self.discriminator.trainable_variables:
var.assign(tf.clip_by_value(var, -self.grad_clip, self.grad_clip))

- return d_loss
+ return d_loss

@tf.function
def train_generator(self):
@@ -120,12 +120,12 @@ def train_generator(self):
x_fake = self.generator(z, training=True)
d_fake = self.discriminator(x_fake, training=True)

- g_loss = generator_loss(self.g_loss, d_fake, d_fake)
+ g_loss = generator_wgan_loss(d_fake)

- gradient = gt.gradient(g_loss, self.generator.trainable_variables)
- self.g_opt.apply_gradients(zip(gradient, self.generator.trainable_variables))
+ gradients = gt.gradient(g_loss, self.generator.trainable_variables)
+ self.g_opt.apply_gradients(zip(gradients, self.generator.trainable_variables))

- return g_loss
+ return g_loss

def load(self) -> int:
return 0
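
With these changes the WGAN model calls the WGAN-specific losses directly instead of dispatching through the generic loss_func string, and it keeps the original WGAN weight clipping on the critic after each update. A self-contained sketch of the critic update, assuming generator, discriminator, d_opt, z_dim, bs, and clip_value are provided by the caller (the critic_step name and clip_value default are illustrative, not the repository's API):

import tensorflow as tf

@tf.function
def critic_step(generator, discriminator, d_opt, x, bs, z_dim, clip_value=0.01):
    z = tf.random.normal((bs, z_dim))
    with tf.GradientTape() as gt:
        x_fake = generator(z, training=True)
        d_fake = discriminator(x_fake, training=True)
        d_real = discriminator(x, training=True)
        d_loss = tf.reduce_mean(d_fake) - tf.reduce_mean(d_real)
    gradients = gt.gradient(d_loss, discriminator.trainable_variables)
    d_opt.apply_gradients(zip(gradients, discriminator.trainable_variables))
    # original WGAN: clip critic weights so the critic stays (roughly) K-Lipschitz
    for var in discriminator.trainable_variables:
        var.assign(tf.clip_by_value(var, -clip_value, clip_value))
    return d_loss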
