Skip to content

Commit

Permalink
update: model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kozistr committed Oct 12, 2020
1 parent ac6d401 commit 93e3e4c
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions awesome_gans/wgan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,20 @@ def build_generator(self) -> tf.keras.Model:
return Model(inputs, x, name='generator')

@tf.function
def train_discriminator(self, x_real: tf.Tensor):
def train_discriminator(self, x: tf.Tensor):
z = tf.random.uniform((self.bs, self.z_dims))
with tf.GradientTape() as gt:
x_fake = self.generator(z, training=True)
d_fake = self.discriminator(x_fake, training=True)
d_real = self.discriminator(x_real, training=True)
d_real = self.discriminator(x, training=True)

d_loss = tf.reduce_mean(discriminator_loss(self.d_loss, d_real, d_fake))
d_loss = discriminator_loss(self.d_loss, d_real, d_fake)

gradient = gt.gradient(d_loss, self.discriminator.trainable_variables)
self.d_opt.apply_gradients(zip(gradient, self.discriminator.trainable_variables))
gradient = gt.gradient(d_loss, self.discriminator.trainable_variables)
self.d_opt.apply_gradients(zip(gradient, self.discriminator.trainable_variables))

for var in self.discriminator.trainable_variables:
var.assign(tf.clip_by_value(var, -self.grad_clip, self.grad_clip))

return d_loss

Expand All @@ -117,13 +120,10 @@ def train_generator(self):
x_fake = self.generator(z, training=True)
d_fake = self.discriminator(x_fake, training=True)

g_loss = tf.reduce_mean(generator_loss(self.g_loss, d_fake))

gradient = gt.gradient(g_loss, self.generator.trainable_variables)
self.g_opt.apply_gradients(zip(gradient, self.generator.trainable_variables))
g_loss = generator_loss(self.g_loss, d_fake, d_fake)

for var in self.discriminator.trainable_variables:
var.assign(tf.clip_by_value(var, -self.grad_clip, self.grad_clip))
gradient = gt.gradient(g_loss, self.generator.trainable_variables)
self.g_opt.apply_gradients(zip(gradient, self.generator.trainable_variables))

return g_loss

Expand All @@ -133,7 +133,7 @@ def load(self) -> int:
def train(self, dataset: tf.data.Dataset):
start_epoch: int = self.load()

z_samples = tf.random.normal((self.n_samples, self.z_dims))
z_samples = tf.random.uniform((self.n_samples, self.z_dims))

for epoch in range(start_epoch, self.epochs):
loader = tqdm(dataset, desc=f'[*] Epoch {epoch} / {self.epochs}')
Expand Down

0 comments on commit 93e3e4c

Please sign in to comment.