# -*- coding: UTF-8 -*-
22import argparse
3- import time
4- import traceback
5- import os
63import torch
7- import torch .nn as nn
8- import torch .optim as optim
9- from torch .utils .tensorboard import SummaryWriter
10- from pysenal import write_json , get_logger
4+ from pysenal import write_json
115from deep_keyphrase .utils .vocab_loader import load_vocab
126from deep_keyphrase .copy_rnn .model import CopyRNN
7+ from deep_keyphrase .base_trainer import BaseTrainer
138from deep_keyphrase .dataloader import (KeyphraseDataLoader , TOKENS , TARGET )
14- from deep_keyphrase .evaluation import KeyphraseEvaluator
159from deep_keyphrase .copy_rnn .predict import CopyRnnPredictor
16- from deep_keyphrase .utils .constants import PAD_WORD
1710
1811
19- class Trainer (object ):
20- def __init__ (self , args ):
21- self .args = args
22- self .vocab2id = load_vocab (args .vocab_path , args .vocab_size )
23-
24- self .model = CopyRNN (args , self .vocab2id )
25- if torch .cuda .is_available ():
26- self .model = self .model .cuda ()
27- if args .train_parallel :
28- self .model = nn .DataParallel (self .model )
29- self .loss_func = nn .NLLLoss (ignore_index = self .vocab2id [PAD_WORD ])
30- self .optimizer = optim .Adam (self .model .parameters (), lr = 1e-4 )
31- self .logger = get_logger ('train' )
32- self .train_loader = KeyphraseDataLoader (args .train_filename ,
33- self .vocab2id ,
34- args .batch_size ,
35- args .max_src_len ,
36- args .max_oov_count ,
37- args .max_target_len ,
38- 'train' )
39- timemark = time .strftime ('%Y%m%d-%H%M%S' , time .localtime (time .time ()))
40- self .dest_dir = os .path .join (args .dest_base_dir , args .exp_name + '-' + timemark ) + '/'
41- os .mkdir (self .dest_dir )
42- if not args .tensorboard_dir :
43- tensorboard_dir = self .dest_dir + 'logs/'
44- else :
45- tensorboard_dir = args .tensorboard_dir
46- self .writer = SummaryWriter (tensorboard_dir )
47- self .eval_topn = (5 , 10 )
48- self .macro_evaluator = KeyphraseEvaluator (self .eval_topn , 'macro' )
49- self .micro_evaluator = KeyphraseEvaluator (self .eval_topn , 'micro' )
50- self .best_f1 = None
51- self .not_update_count = 0
52-
53- def train (self ):
54- step = 0
55- is_stop = False
56- self .logger .info ('destination dir:{}' .format (self .dest_dir ))
57- for epoch in range (1 , self .args .epochs + 1 ):
58- for batch_idx , batch in enumerate (self .train_loader ):
59- self .model .train ()
60- try :
61- loss = self .train_batch (batch )
62- except Exception as e :
63- err_stack = traceback .format_exc ()
64- self .logger .error (err_stack )
65- loss = 0.0
66- step += 1
67- self .writer .add_scalar ('loss' , loss , step )
68- if step and step % self .args .save_model_step == 0 :
69- self .save_model (step , epoch )
70- if self .not_update_count >= self .args .early_stop_tolerance :
71- is_stop = True
72- break
73- if is_stop :
74- break
12+ class CopyRnnTrainer (BaseTrainer ):
13+ def __init__ (self ):
14+ args = self .parse_args ()
15+ model = CopyRNN (args , load_vocab (args .vocab_path ))
16+ super ().__init__ (args , model )
7517
7618 def train_batch (self , batch ):
7719 loss = 0
@@ -81,10 +23,7 @@ def train_batch(self, batch):
8123 targets = targets .cuda ()
8224 batch_size = len (batch [TOKENS ])
8325 encoder_output = None
84- if self .args .bidirectional :
85- decoder_state = torch .zeros (batch_size , self .args .target_hidden_size * 2 )
86- else :
87- decoder_state = torch .zeros (batch_size , self .args .target_hidden_size )
26+ decoder_state = torch .zeros (batch_size , self .args .target_hidden_size )
8827 hidden_state = None
8928 for target_index in range (self .args .max_target_len ):
9029 if target_index == 0 :
@@ -94,7 +33,8 @@ def train_batch(self, batch):
9433 if self .args .teacher_forcing :
9534 prev_output_tokens = targets [:, target_index ].unsqueeze (1 )
9635 else :
97- prev_output_tokens = torch .topk (decoder_prob , 1 , 1 )
36+ best_probs , prev_output_tokens = torch .topk (decoder_prob , 1 , 1 )
37+
9838 output = self .model (src_dict = batch ,
9939 prev_output_tokens = prev_output_tokens ,
10040 encoder_output_dict = encoder_output ,
@@ -115,21 +55,6 @@ def train_batch(self, batch):
11555 self .optimizer .step ()
11656 return loss
11757
118- def save_model (self , step , epoch ):
119- valid_f1 = self .evaluate (step )
120- if self .best_f1 is None :
121- self .best_f1 = valid_f1
122- elif valid_f1 >= self .best_f1 :
123- self .best_f1 = valid_f1
124- self .not_update_count = 0
125- else :
126- self .not_update_count += 1
127- exp_name = self .args .exp_name
128- model_basename = self .dest_dir + '{}_epoch_{}_batch_{}' .format (exp_name , epoch , step )
129- torch .save (self .model .state_dict (), model_basename + '.model' )
130- write_json (model_basename + '.json' , vars (self .args ))
131- self .logger .info ('epoch {} step {}, model saved' .format (epoch , step ))
132-
13358 def evaluate (self , step ):
13459 predictor = CopyRnnPredictor (model_info = {'model' : self .model , 'config' : self .args },
13560 vocab_info = self .vocab2id ,
@@ -151,7 +76,7 @@ def evaluate(self, step):
15176 pred_test_filename += '.batch_{}.pred.jsonl' .format (step )
15277
15378 predictor .eval_predict (self .args .test_filename , pred_test_filename ,
154- self .args .batch_size , self .model , True )
79+ self .args .eval_batch_size , self .model , True )
15580 test_macro_ret = self .macro_evaluator .evaluate (pred_test_filename )
15681 for n , counter in test_macro_ret .items ():
15782 for k , v in counter .items ():
@@ -161,8 +86,46 @@ def evaluate(self, step):
16186 # valid_micro_ret = self.micro_evaluator.evaluate(pred_test_filename)
16287 return valid_macro_ret [self .eval_topn [- 1 ]]['recall' ]
16388
164- def get_basename (self , filename ):
165- return os .path .splitext (os .path .basename (filename ))[0 ]
89+ def parse_args (self ):
90+ parser = argparse .ArgumentParser ()
91+ # train and evaluation parameter
92+ parser .add_argument ("-exp_name" , required = True , type = str , help = '' )
93+ parser .add_argument ("-train_filename" , required = True , type = str , help = '' )
94+ parser .add_argument ("-valid_filename" , required = True , type = str , help = '' )
95+ parser .add_argument ("-test_filename" , required = True , type = str , help = '' )
96+ parser .add_argument ("-dest_base_dir" , required = True , type = str , help = '' )
97+ parser .add_argument ("-vocab_path" , required = True , type = str , help = '' )
98+ parser .add_argument ("-vocab_size" , type = int , default = 500000 , help = '' )
99+ parser .add_argument ("-epochs" , type = int , default = 10 , help = '' )
100+ parser .add_argument ("-batch_size" , type = int , default = 64 , help = '' )
101+ parser .add_argument ("-learning_rate" , type = float , default = 1e-4 , help = '' )
102+ parser .add_argument ("-eval_batch_size" , type = int , default = 20 , help = '' )
103+ parser .add_argument ("-dropout" , type = float , default = 0.0 , help = '' )
104+ parser .add_argument ("-grad_norm" , type = float , default = 0.0 , help = '' )
105+ parser .add_argument ("-max_grad" , type = float , default = 5.0 , help = '' )
106+ parser .add_argument ("-shuffle_in_batch" , action = 'store_true' , help = '' )
107+ parser .add_argument ("-teacher_forcing" , action = 'store_true' , help = '' )
108+ parser .add_argument ("-beam_size" , type = float , default = 50 , help = '' )
109+ parser .add_argument ('-tensorboard_dir' , type = str , default = '' , help = '' )
110+ parser .add_argument ('-logfile' , type = str , default = 'train_log.log' , help = '' )
111+ parser .add_argument ('-save_model_step' , type = int , default = 5000 , help = '' )
112+ parser .add_argument ('-early_stop_tolerance' , type = int , default = 50 , help = '' )
113+ parser .add_argument ('-train_parallel' , action = 'store_true' , help = '' )
114+
115+ # model specific parameter
116+ parser .add_argument ("-embed_size" , type = int , default = 200 , help = '' )
117+ parser .add_argument ("-max_oov_count" , type = int , default = 100 , help = '' )
118+ parser .add_argument ("-max_src_len" , type = int , default = 1500 , help = '' )
119+ parser .add_argument ("-max_target_len" , type = int , default = 8 , help = '' )
120+ parser .add_argument ("-src_hidden_size" , type = int , default = 100 , help = '' )
121+ parser .add_argument ("-target_hidden_size" , type = int , default = 100 , help = '' )
122+ parser .add_argument ('-src_num_layers' , type = int , default = 1 , help = '' )
123+ parser .add_argument ('-target_num_layers' , type = int , default = 1 , help = '' )
124+ parser .add_argument ("-bidirectional" , action = 'store_true' , help = '' )
125+ parser .add_argument ("-copy_net" , action = 'store_true' , help = '' )
126+
127+ args = parser .parse_args ()
128+ return args
166129
167130
168131def accuracy (probs , true_indices , pad_idx ):
@@ -172,42 +135,5 @@ def accuracy(probs, true_indices, pad_idx):
172135 return torch .sum (tp_result ).numpy () / true_indices .numel ()
173136
174137
175- def main ():
176- parser = argparse .ArgumentParser ()
177- parser .add_argument ("-exp_name" , required = True , type = str , help = '' )
178- parser .add_argument ("-train_filename" , required = True , type = str , help = '' )
179- parser .add_argument ("-valid_filename" , required = True , type = str , help = '' )
180- parser .add_argument ("-test_filename" , required = True , type = str , help = '' )
181- parser .add_argument ("-dest_base_dir" , required = True , type = str , help = '' )
182- parser .add_argument ("-vocab_path" , required = True , type = str , help = '' )
183- parser .add_argument ("-vocab_size" , type = int , default = 500000 , help = '' )
184- parser .add_argument ("-embed_size" , type = int , default = 200 , help = '' )
185- parser .add_argument ("-max_oov_count" , type = int , default = 100 , help = '' )
186- parser .add_argument ("-max_src_len" , type = int , default = 1500 , help = '' )
187- parser .add_argument ("-max_target_len" , type = int , default = 8 , help = '' )
188- parser .add_argument ("-src_hidden_size" , type = int , default = 100 , help = '' )
189- parser .add_argument ("-target_hidden_size" , type = int , default = 100 , help = '' )
190- parser .add_argument ('-src_num_layers' , type = int , default = 1 , help = '' )
191- parser .add_argument ('-target_num_layers' , type = int , default = 1 , help = '' )
192- parser .add_argument ("-epochs" , type = int , default = 10 , help = '' )
193- parser .add_argument ("-batch_size" , type = int , default = 50 , help = '' )
194- parser .add_argument ("-eval_batch_size" , type = int , default = 10 , help = '' )
195- parser .add_argument ("-dropout" , type = float , default = 0.5 , help = '' )
196- parser .add_argument ("-grad_norm" , type = float , default = 0.0 , help = '' )
197- parser .add_argument ("-max_grad" , type = float , default = 10.0 , help = '' )
198- parser .add_argument ("-shuffle_in_batch" , action = 'store_true' , help = '' )
199- parser .add_argument ("-bidirectional" , default = True , action = 'store_true' , help = '' )
200- parser .add_argument ("-use_vanilla_rnn_search" , default = False , action = 'store_true' , help = '' )
201- parser .add_argument ("-teacher_forcing" , action = 'store_true' , help = '' )
202- parser .add_argument ("-beam_size" , type = float , default = 50 , help = '' )
203- parser .add_argument ('-tensorboard_dir' , type = str , default = '' , help = '' )
204- parser .add_argument ('-logfile' , type = str , default = 'train_log.log' , help = '' )
205- parser .add_argument ('-save_model_step' , type = int , default = 5000 , help = '' )
206- parser .add_argument ('-early_stop_tolerance' , type = int , default = 50 , help = '' )
207- parser .add_argument ('-train_parallel' , action = 'store_true' , help = '' )
208- args = parser .parse_args ()
209- Trainer (args ).train ()
210-
211-
212138if __name__ == '__main__' :
213- main ()
139+ CopyRnnTrainer (). train ()
# NOTE: removed trailing "0 commit comments" — GitHub page scrape artifact, not source code.