11# -*- coding: UTF-8 -*-
22import time
33import traceback
4+ import logging
45import os
56import gc
67import torch
1617
1718class BaseTrainer (object ):
1819 def __init__ (self , args , model ):
20+ torch .manual_seed (0 )
1921 self .args = args
2022 self .vocab2id = load_vocab (self .args .vocab_path , self .args .vocab_size )
2123
@@ -34,9 +36,18 @@ def __init__(self, args, model):
3436 self .args .max_oov_count ,
3537 self .args .max_target_len ,
3638 'train' )
37- timemark = time .strftime ('%Y%m%d-%H%M%S' , time .localtime (time .time ()))
38- self .dest_dir = os .path .join (self .args .dest_base_dir , self .args .exp_name + '-' + timemark ) + '/'
39- os .mkdir (self .dest_dir )
39+ if self .args .train_from :
40+ self .dest_dir = os .path .dirname (self .args .train_from ) + '/'
41+ else :
42+ timemark = time .strftime ('%Y%m%d-%H%M%S' , time .localtime (time .time ()))
43+ self .dest_dir = os .path .join (self .args .dest_base_dir , self .args .exp_name + '-' + timemark ) + '/'
44+ os .mkdir (self .dest_dir )
45+
46+ fh = logging .FileHandler (os .path .join (self .dest_dir , args .logfile ))
47+ fh .setLevel (logging .INFO )
48+ fh .setFormatter (logging .Formatter ('[%(asctime)s] %(message)s' ))
49+ self .logger .addHandler (fh )
50+
4051 if not self .args .tensorboard_dir :
4152 tensorboard_dir = self .dest_dir + 'logs/'
4253 else :
@@ -46,7 +57,7 @@ def __init__(self, args, model):
4657 self .macro_evaluator = KeyphraseEvaluator (self .eval_topn , 'macro' )
4758 self .micro_evaluator = KeyphraseEvaluator (self .eval_topn , 'micro' )
4859 self .best_f1 = None
49- self .best_step = None
60+ self .best_step = 0
5061 self .not_update_count = 0
5162
5263 def parse_args (self ):
@@ -55,7 +66,12 @@ def parse_args(self):
5566 def train (self ):
5667 step = 0
5768 is_stop = False
58- self .logger .info ('destination dir:{}' .format (self .dest_dir ))
69+ if self .args .train_from :
70+ step = self .args .step
71+ self .logger .info ('train from destination dir:{}' .format (self .dest_dir ))
72+ self .logger .info ('train from step {}' .format (step ))
73+ else :
74+ self .logger .info ('destination dir:{}' .format (self .dest_dir ))
5975 for epoch in range (1 , self .args .epochs + 1 ):
6076 for batch_idx , batch in enumerate (self .train_loader ):
6177 self .model .train ()
@@ -97,6 +113,8 @@ def evaluate_and_save_model(self, step, epoch):
97113 model_basename = self .dest_dir + '{}_epoch_{}_batch_{}' .format (exp_name , epoch , step )
98114 torch .save (self .model .state_dict (), model_basename + '.model' )
99115 write_json (model_basename + '.json' , vars (self .args ))
116+ score_msg_tmpl = 'best score: step {} macro f1@{} {:.4f}'
117+ self .logger .info (score_msg_tmpl .format (self .best_step , self .eval_topn [- 1 ], self .best_f1 ))
100118 self .logger .info ('epoch {} step {}, model saved' .format (epoch , step ))
101119
102120 def evaluate (self , step ):
0 commit comments