-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_GAN.py
49 lines (39 loc) · 1.75 KB
/
main_GAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import numpy as np
from utils.natural_process import *
from config.config1 import get_args
from train_GAN import train
from utils.utils import init_settings, setup_seed
from utils.create_model import get_model_by_name
import time
def format_hms(total_seconds: float) -> str:
    """Format a non-negative duration in seconds as ``HH:MM:SS``.

    Unlike ``time.strftime("%H:%M:%S", time.gmtime(s))``, the hour field does
    not wrap around at 24, so durations of a day or more are reported
    correctly (e.g. 90000 seconds -> ``25:00:00``).

    Args:
        total_seconds: Duration in seconds; any fractional part is truncated.

    Returns:
        The duration as a zero-padded ``HH:MM:SS`` string.
    """
    hours, remainder = divmod(int(total_seconds), 3600)
    minutes, seconds = divmod(remainder, 60)
    return '{:02d}:{:02d}:{:02d}'.format(hours, minutes, seconds)


def main():
    """Run ``args.n`` GAN training trials and log the averaged metrics.

    For every trial, ``train`` is called once per entry of ``args.data_dir``;
    the first three values it returns are averaged over the entries, and the
    per-trial averages are then averaged over all trials. The summary and the
    mean wall-clock time per trial are written to ``log.txt`` inside
    ``args.experiment_dir`` and the summary is also printed.

    Raises:
        ValueError: If ``args.n`` is smaller than 1 (no trial to average).
    """
    # `os` is not imported by name in the visible top-of-file imports (it
    # presumably arrives through the star import), so bind it explicitly.
    import os

    args = get_args()
    # Fail early with a clear message instead of a ZeroDivisionError after
    # all the set-up work has already been done.
    if args.n < 1:
        raise ValueError('args.n must be at least 1, got {}'.format(args.n))

    setup_seed(args.seed)
    args.model_name = 'GAN'
    (args.experiment_dir,
     args.checkpoints_dir,
     args.tensorboard_dir) = init_settings(args)
    # NOTE(review): the same gen/dis instances are handed to every train()
    # call below; whether train() re-initialises them is not visible here.
    gen, dis = get_model_by_name(args.model_name)

    # Per-trial averages; presumably SSIM, pixel (L1) error and PSNR, judging
    # by the names used in the log line -- confirm against train().
    ssim, l1, psnr = [], [], []
    start = time.perf_counter()

    # Average over n trials.
    for _ in range(args.n):
        e1, e2, e3 = [], [], []
        # Train on every entry of the data directory (the original note says
        # 10 natural images). Sorted because os.listdir() returns entries in
        # arbitrary order, which would make seeded runs irreproducible.
        for entry in sorted(os.listdir(args.data_dir)):
            args.data_dir_i = os.path.join(args.data_dir, entry)
            args.res_dir_i = os.path.join(args.res_dir, entry)
            os.makedirs(args.res_dir_i, exist_ok=True)
            print(args.data_dir_i)

            res = train(args, gen, dis)
            e1.append(res[0])
            e2.append(res[1])
            e3.append(res[2])
            # Visual separator between the console output of two entries.
            print('===============================================================================')
        ssim.append(np.mean(e1))
        l1.append(np.mean(e2))
        psnr.append(np.mean(e3))

    elapsed = time.perf_counter() - start
    mean_run_time = int(elapsed / args.n)

    # Compute each overall mean once and reuse it for the log and the console.
    mean_ssim, mean_l1, mean_psnr = np.mean(ssim), np.mean(l1), np.mean(psnr)

    log_path = os.path.join(args.experiment_dir, 'log.txt')
    with open(log_path, mode='w', encoding='utf-8') as log_object:
        log_object.write(args.model_name + '\tevaluate\t ssim:{}\t pixel:{}\t psnr:{}'.format(mean_ssim, mean_l1, mean_psnr))
        log_object.write('\nmean-run-time:' + format_hms(mean_run_time))
    print('evaluate\t\t ssim:{}\t\t pixel:{}\t\t psnr:{}'.format(mean_ssim, mean_l1, mean_psnr))


if __name__ == '__main__':
    main()