# -*- coding: UTF-8 -*-
import os
import argparse
import torch
from collections import OrderedDict
from munch import Munch
from pysenal import read_json
from deep_keyphrase.base_trainer import BaseTrainer
from deep_keyphrase.utils.vocab_loader import load_vocab
from deep_keyphrase.copy_cnn.model import CopyCnn


class CopyCnnTrainer(BaseTrainer):
    def __init__(self):
        self.args = self.parse_args()
        self.vocab2id = load_vocab(self.args.vocab_path, self.args.vocab_size)
        model = self.load_model()
        super().__init__(self.args, model)

    def load_model(self):
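        # build the model from scratch unless a checkpoint path is given via -train_from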
        if not self.args.train_from:
            model = CopyCnn(self.args, self.vocab2id)
        else:
            model_path = self.args.train_from
            config_path = os.path.join(os.path.dirname(model_path),
                                       self.get_basename(model_path) + '.json')

            old_config = read_json(config_path)
            old_config['train_from'] = model_path
            old_config['step'] = int(model_path.rsplit('_', 1)[-1].split('.')[0])
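            # restore the saved training configuration and vocabulary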
            self.args = Munch(old_config)
            self.vocab2id = load_vocab(self.args.vocab_path, self.args.vocab_size)

            model = CopyCnn(self.args, self.vocab2id)

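            # load the checkpoint weights, mapping to CPU when no GPU is available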
            if torch.cuda.is_available():
                checkpoint = torch.load(model_path)
            else:
                checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
            state_dict = OrderedDict()
            # strip the 'module.' prefix to avoid errors when loading a model trained with DataParallel
            for k, v in checkpoint.items():
                if k.startswith('module.'):
                    k = k[7:]
                state_dict[k] = v
            model.load_state_dict(state_dict)

        return model

    def train_batch(self, batch, step):
        self.model.train()
        loss = 0
        self.optimizer.zero_grad()
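        # The actual training step is not implemented yet. A minimal sketch of the
        # remaining logic, assuming the forward pass returns token probabilities and
        # that self.loss_func and targets exist (names used for illustration only):
        #
        #     probs = self.model(batch)
        #     loss = self.loss_func(probs, targets)
        #     loss.backward()
        #     torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad)
        #     self.optimizer.step()
        #     return loss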

    def evaluate(self, step):
        pass

    def parse_args(self, args=None):
        parser = argparse.ArgumentParser()
        parser.add_argument("-exp_name", required=True, type=str, help='')
        parser.add_argument("-train_filename", required=True, type=str, help='')
        parser.add_argument("-valid_filename", required=True, type=str, help='')
        parser.add_argument("-test_filename", required=True, type=str, help='')
        parser.add_argument("-dest_base_dir", required=True, type=str, help='')
        parser.add_argument("-vocab_path", required=True, type=str, help='')
        parser.add_argument("-vocab_size", type=int, default=500000, help='')
        parser.add_argument("-train_from", default='', type=str, help='')
        parser.add_argument("-token_field", default='tokens', type=str, help='')
        parser.add_argument("-keyphrase_field", default='keyphrases', type=str, help='')
        # parser.add_argument("-auto_regressive", action='store_true', help='')
        parser.add_argument("-epochs", type=int, default=10, help='')
        parser.add_argument("-batch_size", type=int, default=64, help='')
        parser.add_argument("-learning_rate", type=float, default=1e-4, help='')
        parser.add_argument("-eval_batch_size", type=int, default=50, help='')
        parser.add_argument("-dropout", type=float, default=0.0, help='')
        parser.add_argument("-grad_norm", type=float, default=0.0, help='')
        parser.add_argument("-max_grad", type=float, default=5.0, help='')
        parser.add_argument("-shuffle", action='store_true', help='')
        # parser.add_argument("-teacher_forcing", action='store_true', help='')
        parser.add_argument("-beam_size", type=int, default=50, help='')
        parser.add_argument('-tensorboard_dir', type=str, default='', help='')
        parser.add_argument('-logfile', type=str, default='train_log.log', help='')
        parser.add_argument('-save_model_step', type=int, default=5000, help='')
        parser.add_argument('-early_stop_tolerance', type=int, default=100, help='')
        parser.add_argument('-train_parallel', action='store_true', help='')
        # parser.add_argument('-schedule_lr', action='store_true', help='')
        # parser.add_argument('-schedule_step', type=int, default=100000, help='')
        # parser.add_argument('-schedule_gamma', type=float, default=0.5, help='')
        # parser.add_argument('-processed', action='store_true', help='')
        parser.add_argument('-prefetch', action='store_true', help='')

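        # CopyCnn model hyperparameters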
        parser.add_argument('-dim_size', type=int, default=100, help='')
        parser.add_argument('-kernel_width', type=int, default=5, help='')
        parser.add_argument('-encoder_layer_num', type=int, default=6, help='')
        parser.add_argument('-decoder_layer_num', type=int, default=6, help='')

        args = parser.parse_args(args)
        return args


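# Example invocation (module path and file paths below are illustrative placeholders):
#   python -m deep_keyphrase.copy_cnn.train \
#       -exp_name copy_cnn_kp20k \
#       -train_filename data/kp20k.train.json \
#       -valid_filename data/kp20k.valid.json \
#       -test_filename data/kp20k.test.json \
#       -dest_base_dir output/ \
#       -vocab_path data/vocab.txt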
if __name__ == '__main__':
    CopyCnnTrainer().train()