Skip to content

Commit

Permalink
refactor: the codes
Browse files Browse the repository at this point in the history
  • Loading branch information
kozistr committed Oct 12, 2020
1 parent bde98cb commit 620b931
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
6 changes: 4 additions & 2 deletions awesome_gans/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ def parse_args():
parser.add_argument('--n_channels', type=int, default=3, help='number of channel of image')
parser.add_argument('--root_path', type=str, default='./', help='root path')
parser.add_argument(
'--dataset', type=str, default='cifar10',
'--dataset',
type=str,
default='cifar10',
choices=['mnist', 'cifar10', 'cifar100', 'div2k'],
help='type of dataset'
help='type of dataset',
)
parser.add_argument('--mnist_path', type=str, default='mnist')
parser.add_argument('--fashion_mnist_path', type=str, default='fashion-mnist')
Expand Down
2 changes: 1 addition & 1 deletion awesome_gans/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def load_dataset(self, use_label: bool = False):
ds = tfds.load(name=self.dataset, split='train', as_supervised=use_label, shuffle_files=True)
ds = ds.map(lambda x: self.preprocess_image(x['image']), tf.data.experimental.AUTOTUNE)
ds = ds.cache()
ds = ds.shuffle(50000, reshuffle_each_iteration=True)
ds = ds.shuffle(50000)
ds = ds.batch(self.bs, drop_remainder=True)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
return ds
1 change: 0 additions & 1 deletion awesome_gans/wgan/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ def get_config():
parser.add_argument('--epochs', default=50, type=int, help='epochs to train')
parser.add_argument('--global_steps', default=5e4, type=int, help='iterations to train')
parser.add_argument('--n_feats', default=64, type=int, help='number of convolution filters')
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate for generic somethings')
parser.add_argument('--d_lr', default=1e-4, type=float, help='learning rate of discriminator')
parser.add_argument('--g_lr', default=1e-4, type=float, help='learning rate of generator')
parser.add_argument(
Expand Down
7 changes: 3 additions & 4 deletions awesome_gans/wgan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,14 @@ def train(self, dataset: tf.data.Dataset):
for epoch in range(start_epoch, self.epochs):
loader = tqdm(dataset, desc=f'[*] Epoch {epoch} / {self.epochs}')
for n_iter, batch in enumerate(loader):
d_loss = 0.0
for _ in range(self.n_critics):
d_loss += self.train_discriminator(batch)
d_loss = self.train_discriminator(batch)

g_loss = self.train_generator()

loader.set_postfix(
d_loss=f'{d_loss / self.n_critics:.6f}',
g_loss=f'{g_loss:.6f}',
d_loss=f'{d_loss:.5f}',
g_loss=f'{g_loss:.5f}',
)

# saving the generated samples
Expand Down

0 comments on commit 620b931

Please sign in to comment.