'''
Utility functions for training.
Contact: Liming Zhao ([email protected])
'''
import os
import shutil
import time
import logging

import mxnet as mx
import numpy as np


# Utility functions
def mkdir(dirname, clean=False):
    """Create a directory; if `clean` is set, remove any existing one first."""
    if clean and os.path.exists(dirname):
        shutil.rmtree(dirname)
    if not os.path.exists(dirname):
        os.makedirs(dirname)
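
# Hedged example: mkdir('./snapshots', clean=True) wipes and recreates the
# directory before a new run ('./snapshots' is an illustrative path).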

def cal_params(symbol, input_shapes={"data": (1, 3, 32, 32)}):
    """Count the learnable parameters of a symbol; returns a string in millions."""
    arg_shapes, _, _ = symbol.infer_shape(**input_shapes)
    assert arg_shapes is not None
    arg_names = symbol.list_arguments()
    input_names = input_shapes.keys()
    param_names = [key for key in arg_names if key not in input_names]
    param_name_shapes = [x for x in zip(arg_names, arg_shapes) if x[0] in param_names]
    params_num = 0
    for k, s in param_name_shapes:
        params_num += np.prod(s)
    return '%.4fM' % (params_num / 1000000.0)
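
# A minimal usage sketch for cal_params, assuming MXNet's symbol API; the toy
# network below is illustrative, not a model from this repository.
def _demo_cal_params():
    data = mx.sym.Variable('data')
    net = mx.sym.Convolution(data, num_filter=16, kernel=(3, 3), name='conv1')
    net = mx.sym.FullyConnected(net, num_hidden=10, name='fc1')
    # For a 1x3x32x32 input this prints '0.1445M' (144,458 parameters).
    print(cal_params(net, input_shapes={"data": (1, 3, 32, 32)}))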

class Scheduler(mx.lr_scheduler.MultiFactorScheduler):
    """Multi-factor LR scheduler whose steps are given in epochs, not iterations."""
    def __init__(self, epoch_step, factor, epoch_size):
        super(Scheduler, self).__init__(
            step=[epoch_size * s for s in epoch_step],
            factor=factor
        )
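
# A hedged usage sketch for Scheduler: drop the learning rate by 10x after
# epochs 120 and 160 with batch size 128 on a 50k-image training set. All
# numbers here are illustrative assumptions.
def _demo_scheduler():
    epoch_size = 50000 // 128  # batches per epoch
    return Scheduler(epoch_step=[120, 160], factor=0.1, epoch_size=epoch_size)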

@mx.optimizer.Optimizer.register
class Nesterov(mx.optimizer.NAG):
    # Same as the Torch implementation: also apply weight decay to '_bias' and
    # '_beta' parameters (MXNet's default set_wd_mult zeroes weight decay for
    # any parameter whose name does not end in '_weight' or '_gamma').
    def set_wd_mult(self, args_wd_mult):
        self.wd_mult = {}
        for n in self.idx2name.values():
            if not (n.endswith('_weight') or n.endswith('_bias')
                    or n.endswith('_gamma') or n.endswith('_beta')):
                self.wd_mult[n] = 0.0
        if self.sym is not None:
            attr = self.sym.list_attr(recursive=True)
            for k, v in attr.items():
                if k.endswith('_wd_mult'):
                    self.wd_mult[k[:-len('_wd_mult')]] = float(v)
        self.wd_mult.update(args_wd_mult)
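
# A hedged usage sketch for the registered optimizer: Optimizer.register keys
# the class by its lowercased name, so it can be created as 'nesterov'. The
# hyper-parameters are illustrative assumptions.
def _demo_nesterov():
    return mx.optimizer.create('nesterov', learning_rate=0.1, momentum=0.9,
                               wd=1e-4)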

class InfoCallback(mx.callback.Speedometer):
    """Log training speed and running metric averages every `frequent` batches.

    Parameters
    ----------
    batch_size : int
        batch size of the data
    frequent : int
        how often, in batches, to log
    """
    def __init__(self, batch_size, frequent=50):
        mx.callback.Speedometer.__init__(self, batch_size, frequent)
        self.total_top1 = 0.0
        self.total_top5 = 0.0
        self.total_loss = 0.0

    def __call__(self, param):
        """Callback to show speed and running averages of loss/top-1/top-5."""
        count = param.nbatch
        if self.last_count > count:
            self.init = False
        self.last_count = count
        if self.init:
            if count % self.frequent == 0:
                speed = self.frequent * self.batch_size / (time.time() - self.tic)
                if param.eval_metric is not None:
                    name_value = param.eval_metric.get_name_value()
                    param.eval_metric.reset()
                    log_info = 'Epoch[%d] Batch [%d]\tSpeed: %.0f' % (param.epoch, count, speed)
                    for name, value in name_value:
                        value = value if not np.isinf(value) else 10.0  # clamp inf (log(0)) for display
                        if name == 'cross-entropy':
                            self.total_loss += 1.0 * value * self.frequent
                            log_info += '\tloss: %.4f(%.4f)' % (value, self.total_loss / count)
                        elif name == 'accuracy':
                            self.total_top1 += 1.0 * value * self.frequent
                            log_info += '\ttop1: %.4f(%.4f)' % (100.0 * value, 100.0 * self.total_top1 / count)
                        elif 'top_k' in name:
                            self.total_top5 += 1.0 * value * self.frequent
                            log_info += '\ttop5: %.4f(%.4f)' % (100.0 * value, 100.0 * self.total_top5 / count)
                    logging.info(log_info)
                else:
                    logging.info("Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec",
                                 param.epoch, count, speed)
                self.tic = time.time()
        else:
            self.init = True
            self.tic = time.time()
            self.total_top1 = 0.0
            self.total_top5 = 0.0
            self.total_loss = 0.0
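
# A hedged end-to-end sketch tying the utilities together with mx.mod.Module.
# `symbol`, the data iterators, and all hyper-parameters are illustrative
# placeholders; only mkdir, cal_params, Scheduler, Nesterov (as 'nesterov'),
# and InfoCallback come from this module.
def _demo_training(symbol, train_iter, val_iter, batch_size=128):
    mkdir('./checkpoints', clean=False)  # hypothetical output dir
    logging.info('model size: %s', cal_params(symbol))
    # Metrics named 'accuracy', 'top_k_*', and 'cross-entropy' are exactly
    # what InfoCallback.__call__ parses above.
    metric = mx.metric.CompositeEvalMetric()
    metric.add(mx.metric.Accuracy())
    metric.add(mx.metric.TopKAccuracy(top_k=5))
    metric.add(mx.metric.CrossEntropy())
    epoch_size = 50000 // batch_size  # assumed CIFAR-10-sized training set
    mod = mx.mod.Module(symbol, context=mx.cpu())
    mod.fit(train_iter,
            eval_data=val_iter,
            eval_metric=metric,
            optimizer='nesterov',  # resolved via the register decorator above
            optimizer_params={'learning_rate': 0.1, 'momentum': 0.9, 'wd': 1e-4,
                              'lr_scheduler': Scheduler([120, 160], 0.1, epoch_size)},
            batch_end_callback=InfoCallback(batch_size, frequent=50),
            num_epoch=200)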