-
Notifications
You must be signed in to change notification settings - Fork 268
/
train.py
335 lines (276 loc) · 15.8 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
print("started imports")
import sys
import argparse
import time
import cv2
import numpy as np
import wandb
from PIL import Image
import os
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
import torch
import torchvision.transforms as transforms
import torch.optim.lr_scheduler as scheduler
# custom imports
sys.path.append('./apex/')
from apex import amp
from network.AEI_Net import *
from network.MultiscaleDiscriminator import *
from utils.training.Dataset import FaceEmbedVGG2, FaceEmbed
from utils.training.image_processing import make_image_list, get_faceswap
from utils.training.losses import hinge_loss, compute_discriminator_loss, compute_generator_losses
from utils.training.detector import detect_landmarks, paint_eyes
from AdaptiveWingLoss.core import models
from arcface_model.iresnet import iresnet100
print("finished imports")
def _log_reference_swaps(G, netArc, device, epoch: int, iteration: int) -> None:
    """Log a 3x2 grid of face swaps on six fixed example pairs to wandb.

    Pairs ``source{i}.png`` / ``target{i}.png`` (i = 1..6) are swapped with the
    current generator so training dynamics can be followed on the same images.
    Results 1-3 are stacked vertically as the left column, 4-6 as the right.
    The generator is put in eval mode for the swaps and always restored to
    train mode afterwards.
    """
    G.eval()
    try:
        results = [
            get_faceswap(f'examples/images/training//source{i}.png',
                         f'examples/images/training//target{i}.png',
                         G, netArc, device)
            for i in range(1, 7)
        ]
        left_column = np.concatenate(results[:3], axis=0)
        right_column = np.concatenate(results[3:], axis=0)
        output = np.concatenate((left_column, right_column), axis=1)
        wandb.log({"our_images": wandb.Image(output, caption=f"{epoch:03}" + '_' + f"{iteration:06}")})
    finally:
        G.train()


def train_one_epoch(G: 'generator model',
                    D: 'discriminator model',
                    opt_G: "generator opt",
                    opt_D: "discriminator opt",
                    scheduler_G: "scheduler G opt",
                    scheduler_D: "scheduler D opt",
                    netArc: 'ArcFace model',
                    model_ft: 'Landmark Detector',
                    args: 'Args Namespace',
                    dataloader: torch.utils.data.DataLoader,
                    device: 'torch device',
                    epoch: int,
                    loss_adv_accumulated: float) -> float:
    """Run one epoch of alternating generator / discriminator updates.

    Each batch from ``dataloader`` is unpacked as
    ``(Xs_orig, Xs, Xt, same_person)``. ArcFace embeddings of ``Xs_orig``
    condition the generator ``G``. The generator is updated with
    ``compute_generator_losses``, then the discriminator with
    ``compute_discriminator_loss``; both backward passes go through apex
    ``amp.scale_loss``.

    Args:
        loss_adv_accumulated: running adversarial-loss statistic passed to and
            updated by ``compute_generator_losses``. When ``args.discr_force``
            is set, the discriminator only steps while it is below 4.0.

    Returns:
        The updated ``loss_adv_accumulated``, so the caller can carry it into
        the next epoch instead of resetting it.

    Side effects: periodically writes preview images (every ``args.show_step``
    iterations), prints/logs losses, saves G/D checkpoints every 5000
    iterations and logs reference swaps to wandb every 250 iterations.
    """
    for iteration, data in enumerate(dataloader):
        start_time = time.time()

        Xs_orig, Xs, Xt, same_person = data
        Xs_orig = Xs_orig.to(device)
        Xs = Xs.to(device)
        Xt = Xt.to(device)
        same_person = same_person.to(device)

        # get the identity embeddings of Xs (ArcFace is frozen)
        with torch.no_grad():
            embed = netArc(F.interpolate(Xs_orig, [112, 112], mode='bilinear', align_corners=False))

        diff_person = torch.ones_like(same_person)
        if args.diff_eq_same:
            # Ignore the same/different identity information entirely.
            same_person = diff_person

        # generator training
        opt_G.zero_grad()
        Y, Xt_attr = G(Xt, embed)
        Di = D(Y)
        ZY = netArc(F.interpolate(Y, [112, 112], mode='bilinear', align_corners=False))

        if args.eye_detector_loss:
            Xt_eyes, Xt_heatmap_left, Xt_heatmap_right = detect_landmarks(Xt, model_ft)
            Y_eyes, Y_heatmap_left, Y_heatmap_right = detect_landmarks(Y, model_ft)
            eye_heatmaps = [Xt_heatmap_left, Xt_heatmap_right, Y_heatmap_left, Y_heatmap_right]
        else:
            eye_heatmaps = None

        lossG, loss_adv_accumulated, L_adv, L_attr, L_id, L_rec, L_l2_eyes = compute_generator_losses(
            G, Y, Xt, Xt_attr, Di, embed, ZY, eye_heatmaps, loss_adv_accumulated,
            diff_person, same_person, args)

        with amp.scale_loss(lossG, opt_G) as scaled_loss:
            scaled_loss.backward()
        opt_G.step()
        if args.scheduler:
            scheduler_G.step()

        # discriminator training
        opt_D.zero_grad()
        lossD = compute_discriminator_loss(D, Y, Xs, diff_person)
        with amp.scale_loss(lossD, opt_D) as scaled_loss:
            scaled_loss.backward()
        # With discr_force, skip D updates while the generator's adversarial
        # loss is still high (D is already winning).
        if (not args.discr_force) or (loss_adv_accumulated < 4.):
            opt_D.step()
        if args.scheduler:
            scheduler_D.step()

        batch_time = time.time() - start_time

        if iteration % args.show_step == 0:
            images = [Xs, Xt, Y]
            if args.eye_detector_loss:
                Xt_eyes_img = paint_eyes(Xt, Xt_eyes)
                Yt_eyes_img = paint_eyes(Y, Y_eyes)
                images.extend([Xt_eyes_img, Yt_eyes_img])
            image = make_image_list(images)
            if args.use_wandb:
                wandb.log({"gen_images": wandb.Image(image, caption=f"{epoch:03}" + '_' + f"{iteration:06}")})
            else:
                # Reverse the channel axis for cv2.imwrite (assumes the image is RGB).
                cv2.imwrite('./images/generated_image.jpg', image[:, :, ::-1])

        if iteration % 10 == 0:
            print(f'epoch: {epoch} {iteration} / {len(dataloader)}')
            print(f'lossD: {lossD.item()} lossG: {lossG.item()} batch_time: {batch_time}s')
            print(f'L_adv: {L_adv.item()} L_id: {L_id.item()} L_attr: {L_attr.item()} L_rec: {L_rec.item()}')
            if args.eye_detector_loss:
                print(f'L_l2_eyes: {L_l2_eyes.item()}')
            print(f'loss_adv_accumulated: {loss_adv_accumulated}')
            if args.scheduler:
                print(f'scheduler_G lr: {scheduler_G.get_last_lr()} scheduler_D lr: {scheduler_D.get_last_lr()}')

        if args.use_wandb:
            if args.eye_detector_loss:
                wandb.log({"loss_eyes": L_l2_eyes.item()}, commit=False)
            wandb.log({"loss_id": L_id.item(),
                       "lossD": lossD.item(),
                       "lossG": lossG.item(),
                       "loss_adv": L_adv.item(),
                       "loss_attr": L_attr.item(),
                       "loss_rec": L_rec.item()})

        if iteration % 5000 == 0:
            torch.save(G.state_dict(), f'./saved_models_{args.run_name}/G_latest.pth')
            torch.save(D.state_dict(), f'./saved_models_{args.run_name}/D_latest.pth')
            torch.save(G.state_dict(), f'./current_models_{args.run_name}/G_' + str(epoch) + '_' + f"{iteration:06}" + '.pth')
            torch.save(D.state_dict(), f'./current_models_{args.run_name}/D_' + str(epoch) + '_' + f"{iteration:06}" + '.pth')

        if (iteration % 250 == 0) and args.use_wandb:
            # Track the swap quality on fixed example photos to follow training dynamics.
            _log_reference_swaps(G, netArc, device, epoch, iteration)

    return loss_adv_accumulated


def train(args, device):
    """Build models, optimizers and data, then run ``args.max_epoch`` epochs.

    Creates the AEI-Net generator and multiscale discriminator, and loads the
    frozen ArcFace identity model. When ``args.eye_detector_loss`` is set, it
    also loads the AdaptiveWingLoss FAN landmark detector. Both trainable
    networks are wrapped with apex ``amp``. Pretrained G/D weights are loaded
    optionally, and the (VGG2 or generic) face dataset is prepared.
    """
    # training params
    batch_size = args.batch_size
    max_epoch = args.max_epoch

    # initializing main models
    G = AEI_Net(args.backbone, num_blocks=args.num_blocks, c_id=512).to(device)
    D = MultiscaleDiscriminator(input_nc=3, n_layers=5, norm_layer=torch.nn.InstanceNorm2d).to(device)
    G.train()
    D.train()

    # initializing model for identity extraction; load on CPU first so the
    # checkpoint opens on any machine, then move to the requested device.
    netArc = iresnet100(fp16=False)
    netArc.load_state_dict(torch.load('arcface_model/backbone.pth', map_location=torch.device('cpu')))
    netArc = netArc.to(device)
    netArc.eval()

    if args.eye_detector_loss:
        model_ft = models.FAN(4, "False", "False", 98)
        checkpoint = torch.load('./AdaptiveWingLoss/AWL_detector/WFLW_4HG.pth', map_location=torch.device('cpu'))
        if 'state_dict' not in checkpoint:
            model_ft.load_state_dict(checkpoint)
        else:
            # Keep only the checkpoint entries whose keys exist in this model.
            model_weights = model_ft.state_dict()
            pretrained_weights = {k: v for k, v in checkpoint['state_dict'].items()
                                  if k in model_weights}
            model_weights.update(pretrained_weights)
            model_ft.load_state_dict(model_weights)
        model_ft = model_ft.to(device)
        model_ft.eval()
    else:
        model_ft = None

    opt_G = optim.Adam(G.parameters(), lr=args.lr_G, betas=(0, 0.999), weight_decay=1e-4)
    opt_D = optim.Adam(D.parameters(), lr=args.lr_D, betas=(0, 0.999), weight_decay=1e-4)

    G, opt_G = amp.initialize(G, opt_G, opt_level=args.optim_level)
    D, opt_D = amp.initialize(D, opt_D, opt_level=args.optim_level)

    if args.scheduler:
        scheduler_G = scheduler.StepLR(opt_G, step_size=args.scheduler_step, gamma=args.scheduler_gamma)
        scheduler_D = scheduler.StepLR(opt_D, step_size=args.scheduler_step, gamma=args.scheduler_gamma)
    else:
        scheduler_G = None
        scheduler_D = None

    if args.pretrained:
        try:
            G.load_state_dict(torch.load(args.G_path, map_location=torch.device('cpu')), strict=False)
            D.load_state_dict(torch.load(args.D_path, map_location=torch.device('cpu')), strict=False)
            print("Loaded pretrained weights for G and D")
        except FileNotFoundError:
            print("Not found pretrained weights. Continue without any pretrained weights.")

    if args.vgg:
        dataset = FaceEmbedVGG2(args.dataset_path, same_prob=args.same_person, same_identity=args.same_identity)
    else:
        dataset = FaceEmbed([args.dataset_path], same_prob=args.same_person)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True)

    # Track the accumulated adversarial loss so that, when discr_force=True,
    # the discriminator is trained only while it is below the threshold.
    # The value is carried across epochs rather than reset each epoch.
    loss_adv_accumulated = 20.

    for epoch in range(0, max_epoch):
        loss_adv_accumulated = train_one_epoch(G,
                                               D,
                                               opt_G,
                                               opt_D,
                                               scheduler_G,
                                               scheduler_D,
                                               netArc,
                                               model_ft,
                                               args,
                                               dataloader,
                                               device,
                                               epoch,
                                               loss_adv_accumulated)
def main(args):
    """Pick CUDA when available (warning on CPU fallback) and start training."""
    has_cuda = torch.cuda.is_available()
    if not has_cuda:
        print("cuda is not available. using cpu. check if it's ok")
    device = torch.device('cuda' if has_cuda else 'cpu')
    print("Starting traing")
    train(args, device=device)
def str2bool(value):
    """Parse a command-line boolean value for argparse.

    ``argparse`` with ``type=bool`` converts every non-empty string, including
    ``"False"``, to ``True``, so boolean flags could never be switched off.
    This converter accepts the usual spellings (case-insensitive) and raises
    ``argparse.ArgumentTypeError`` otherwise. Actual ``bool`` values (such as
    non-string defaults) are returned unchanged.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # dataset params
    parser.add_argument('--dataset_path', default='/VggFace2-crop/', help='Path to the dataset. If not VGG2 dataset is used, param --vgg should be set False')
    parser.add_argument('--G_path', default='./saved_models/G.pth', help='Path to pretrained weights for G. Only used if pretrained=True')
    parser.add_argument('--D_path', default='./saved_models/D.pth', help='Path to pretrained weights for D. Only used if pretrained=True')
    parser.add_argument('--vgg', default=True, type=str2bool, help='When using VGG2 dataset (or any other dataset with several photos for one identity)')

    # weights for loss
    parser.add_argument('--weight_adv', default=1, type=float, help='Adversarial Loss weight')
    parser.add_argument('--weight_attr', default=10, type=float, help='Attributes weight')
    parser.add_argument('--weight_id', default=20, type=float, help='Identity Loss weight')
    parser.add_argument('--weight_rec', default=10, type=float, help='Reconstruction Loss weight')
    parser.add_argument('--weight_eyes', default=0., type=float, help='Eyes Loss weight')

    # training params you may want to change
    parser.add_argument('--backbone', default='unet', const='unet', nargs='?', choices=['unet', 'linknet', 'resnet'], help='Backbone for attribute encoder')
    parser.add_argument('--num_blocks', default=2, type=int, help='Numbers of AddBlocks at AddResblock')
    parser.add_argument('--same_person', default=0.2, type=float, help='Probability of using same person identity during training')
    parser.add_argument('--same_identity', default=True, type=str2bool, help='Using simswap approach, when source_id = target_id. Only possible with vgg=True')
    parser.add_argument('--diff_eq_same', default=False, type=str2bool, help='Don\'t use info about where is defferent identities')
    parser.add_argument('--pretrained', default=True, type=str2bool, help='If using the pretrained weights for training or not')
    parser.add_argument('--discr_force', default=False, type=str2bool, help='If True Discriminator would not train when adversarial loss is high')
    parser.add_argument('--scheduler', default=False, type=str2bool, help='If True decreasing LR is used for learning of generator and discriminator')
    parser.add_argument('--scheduler_step', default=5000, type=int)
    parser.add_argument('--scheduler_gamma', default=0.2, type=float, help='It is value, which shows how many times to decrease LR')
    parser.add_argument('--eye_detector_loss', default=False, type=str2bool, help='If True eye loss with using AdaptiveWingLoss detector is applied to generator')

    # info about this run
    parser.add_argument('--use_wandb', default=False, type=str2bool, help='Use wandb to track your experiments or not')
    parser.add_argument('--run_name', required=True, type=str, help='Name of this run. Used to create folders where to save the weights.')
    parser.add_argument('--wandb_project', default='your-project-name', type=str)
    parser.add_argument('--wandb_entity', default='your-login', type=str)

    # training params you probably don't want to change
    parser.add_argument('--batch_size', default=16, type=int)
    parser.add_argument('--lr_G', default=4e-4, type=float)
    parser.add_argument('--lr_D', default=4e-4, type=float)
    parser.add_argument('--max_epoch', default=2000, type=int)
    parser.add_argument('--show_step', default=500, type=int)
    parser.add_argument('--save_epoch', default=1, type=int)
    parser.add_argument('--optim_level', default='O2', type=str)

    args = parser.parse_args()

    if not args.vgg and args.same_identity:
        raise ValueError("Sorry, you can't use some other dataset than VGG2 Faces with param same_identity=True")

    if args.use_wandb:
        wandb.init(project=args.wandb_project, entity=args.wandb_entity, settings=wandb.Settings(start_method='fork'))

        config = wandb.config
        config.dataset_path = args.dataset_path
        config.weight_adv = args.weight_adv
        config.weight_attr = args.weight_attr
        config.weight_id = args.weight_id
        config.weight_rec = args.weight_rec
        config.weight_eyes = args.weight_eyes
        config.same_person = args.same_person
        config.Vgg2Face = args.vgg
        config.same_identity = args.same_identity
        config.diff_eq_same = args.diff_eq_same
        config.discr_force = args.discr_force
        config.scheduler = args.scheduler
        config.scheduler_step = args.scheduler_step
        config.scheduler_gamma = args.scheduler_gamma
        config.eye_detector_loss = args.eye_detector_loss
        config.pretrained = args.pretrained
        config.run_name = args.run_name
        config.G_path = args.G_path
        config.D_path = args.D_path
        config.batch_size = args.batch_size
        config.lr_G = args.lr_G
        config.lr_D = args.lr_D
    else:
        # Without wandb, preview images are written to ./images instead.
        os.makedirs('./images', exist_ok=True)

    # Create folders for the latest model weights and for the periodic
    # checkpoints. Each one is created independently, so a run whose
    # saved_models_* folder already exists still gets its current_models_* folder.
    os.makedirs(f'./saved_models_{args.run_name}', exist_ok=True)
    os.makedirs(f'./current_models_{args.run_name}', exist_ok=True)

    main(args)