Commit 732d86b

move cmd parameter parsing from function to class method, add more hints in logs
1 parent 4a23039 commit 732d86b
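
The commit replaces the module-level main() entry point with a CopyRnnTrainer class that subclasses BaseTrainer and parses its own command line via a parse_args() method. A minimal sketch of the pattern, assuming a BaseTrainer whose constructor takes (args, model) as the super().__init__(args, model) call in the diff implies (the real BaseTrainer is not part of this diff):

    import argparse

    class BaseTrainer(object):
        # Hypothetical stand-in: the real base class presumably owns the
        # optimizer, data loaders, logging and the train() loop.
        def __init__(self, args, model):
            self.args = args
            self.model = model

    class CopyRnnTrainer(BaseTrainer):
        def __init__(self):
            args = self.parse_args()        # CLI parsing is now a method
            model = {'placeholder': True}   # stands in for CopyRNN(...)
            super().__init__(args, model)   # hand args + model to the base

        def parse_args(self):
            parser = argparse.ArgumentParser()
            parser.add_argument('-exp_name', required=True, type=str)
            return parser.parse_args()

    if __name__ == '__main__':
        print(CopyRnnTrainer().args.exp_name)

The payoff is that callers no longer thread an args object around: CopyRnnTrainer().train() is the whole entry point.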


deep_keyphrase/copy_rnn/train.py

Lines changed: 52 additions & 126 deletions
@@ -1,77 +1,19 @@
 # -*- coding: UTF-8 -*-
 import argparse
-import time
-import traceback
-import os
 import torch
-import torch.nn as nn
-import torch.optim as optim
-from torch.utils.tensorboard import SummaryWriter
-from pysenal import write_json, get_logger
+from pysenal import write_json
 from deep_keyphrase.utils.vocab_loader import load_vocab
 from deep_keyphrase.copy_rnn.model import CopyRNN
+from deep_keyphrase.base_trainer import BaseTrainer
 from deep_keyphrase.dataloader import (KeyphraseDataLoader, TOKENS, TARGET)
-from deep_keyphrase.evaluation import KeyphraseEvaluator
 from deep_keyphrase.copy_rnn.predict import CopyRnnPredictor
-from deep_keyphrase.utils.constants import PAD_WORD
 
 
-class Trainer(object):
-    def __init__(self, args):
-        self.args = args
-        self.vocab2id = load_vocab(args.vocab_path, args.vocab_size)
-
-        self.model = CopyRNN(args, self.vocab2id)
-        if torch.cuda.is_available():
-            self.model = self.model.cuda()
-        if args.train_parallel:
-            self.model = nn.DataParallel(self.model)
-        self.loss_func = nn.NLLLoss(ignore_index=self.vocab2id[PAD_WORD])
-        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4)
-        self.logger = get_logger('train')
-        self.train_loader = KeyphraseDataLoader(args.train_filename,
-                                                self.vocab2id,
-                                                args.batch_size,
-                                                args.max_src_len,
-                                                args.max_oov_count,
-                                                args.max_target_len,
-                                                'train')
-        timemark = time.strftime('%Y%m%d-%H%M%S', time.localtime(time.time()))
-        self.dest_dir = os.path.join(args.dest_base_dir, args.exp_name + '-' + timemark) + '/'
-        os.mkdir(self.dest_dir)
-        if not args.tensorboard_dir:
-            tensorboard_dir = self.dest_dir + 'logs/'
-        else:
-            tensorboard_dir = args.tensorboard_dir
-        self.writer = SummaryWriter(tensorboard_dir)
-        self.eval_topn = (5, 10)
-        self.macro_evaluator = KeyphraseEvaluator(self.eval_topn, 'macro')
-        self.micro_evaluator = KeyphraseEvaluator(self.eval_topn, 'micro')
-        self.best_f1 = None
-        self.not_update_count = 0
-
-    def train(self):
-        step = 0
-        is_stop = False
-        self.logger.info('destination dir:{}'.format(self.dest_dir))
-        for epoch in range(1, self.args.epochs + 1):
-            for batch_idx, batch in enumerate(self.train_loader):
-                self.model.train()
-                try:
-                    loss = self.train_batch(batch)
-                except Exception as e:
-                    err_stack = traceback.format_exc()
-                    self.logger.error(err_stack)
-                    loss = 0.0
-                step += 1
-                self.writer.add_scalar('loss', loss, step)
-                if step and step % self.args.save_model_step == 0:
-                    self.save_model(step, epoch)
-                    if self.not_update_count >= self.args.early_stop_tolerance:
-                        is_stop = True
-                        break
-            if is_stop:
-                break
+class CopyRnnTrainer(BaseTrainer):
+    def __init__(self):
+        args = self.parse_args()
+        model = CopyRNN(args, load_vocab(args.vocab_path))
+        super().__init__(args, model)
 
     def train_batch(self, batch):
         loss = 0
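
Everything deleted above — the NLLLoss and Adam setup, the SummaryWriter, destination-directory creation, the evaluators, the early-stop bookkeeping, and the epoch/batch training loop — presumably moves into BaseTrainer, which is not shown in this commit. Note also that the vocabulary is now loaded as load_vocab(args.vocab_path), without the args.vocab_size cap the old constructor passed.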
@@ -81,10 +23,7 @@ def train_batch(self, batch):
             targets = targets.cuda()
         batch_size = len(batch[TOKENS])
         encoder_output = None
-        if self.args.bidirectional:
-            decoder_state = torch.zeros(batch_size, self.args.target_hidden_size * 2)
-        else:
-            decoder_state = torch.zeros(batch_size, self.args.target_hidden_size)
+        decoder_state = torch.zeros(batch_size, self.args.target_hidden_size)
         hidden_state = None
         for target_index in range(self.args.max_target_len):
             if target_index == 0:
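
With the bidirectional branch gone, the initial decoder state is always batch_size × target_hidden_size; presumably CopyRNN itself now reduces a bidirectional encoder's output to target_hidden_size rather than the trainer widening the decoder state.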
@@ -94,7 +33,8 @@ def train_batch(self, batch):
                 if self.args.teacher_forcing:
                     prev_output_tokens = targets[:, target_index].unsqueeze(1)
                 else:
-                    prev_output_tokens = torch.topk(decoder_prob, 1, 1)
+                    best_probs, prev_output_tokens = torch.topk(decoder_prob, 1, 1)
+
             output = self.model(src_dict=batch,
                                 prev_output_tokens=prev_output_tokens,
                                 encoder_output_dict=encoder_output,
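
The replaced line was a latent bug: torch.topk returns a (values, indices) pair, so the old assignment bound that whole tuple to prev_output_tokens rather than the predicted token ids. The fix unpacks both. A minimal illustration:

    import torch

    decoder_prob = torch.tensor([[0.1, 0.7, 0.2],
                                 [0.5, 0.3, 0.2]])
    # topk along dim 1 with k=1 yields the top values and their indices
    best_probs, prev_output_tokens = torch.topk(decoder_prob, 1, 1)
    print(best_probs)           # tensor([[0.7000], [0.5000]])
    print(prev_output_tokens)   # tensor([[1], [0]])  <- token ids fed back in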
@@ -115,21 +55,6 @@ def train_batch(self, batch):
         self.optimizer.step()
         return loss
 
-    def save_model(self, step, epoch):
-        valid_f1 = self.evaluate(step)
-        if self.best_f1 is None:
-            self.best_f1 = valid_f1
-        elif valid_f1 >= self.best_f1:
-            self.best_f1 = valid_f1
-            self.not_update_count = 0
-        else:
-            self.not_update_count += 1
-        exp_name = self.args.exp_name
-        model_basename = self.dest_dir + '{}_epoch_{}_batch_{}'.format(exp_name, epoch, step)
-        torch.save(self.model.state_dict(), model_basename + '.model')
-        write_json(model_basename + '.json', vars(self.args))
-        self.logger.info('epoch {} step {}, model saved'.format(epoch, step))
-
     def evaluate(self, step):
         predictor = CopyRnnPredictor(model_info={'model': self.model, 'config': self.args},
                                      vocab_info=self.vocab2id,
@@ -151,7 +76,7 @@ def evaluate(self, step):
         pred_test_filename += '.batch_{}.pred.jsonl'.format(step)
 
         predictor.eval_predict(self.args.test_filename, pred_test_filename,
-                               self.args.batch_size, self.model, True)
+                               self.args.eval_batch_size, self.model, True)
         test_macro_ret = self.macro_evaluator.evaluate(pred_test_filename)
         for n, counter in test_macro_ret.items():
             for k, v in counter.items():
@@ -161,8 +86,46 @@ def evaluate(self, step):
         # valid_micro_ret = self.micro_evaluator.evaluate(pred_test_filename)
         return valid_macro_ret[self.eval_topn[-1]]['recall']
 
-    def get_basename(self, filename):
-        return os.path.splitext(os.path.basename(filename))[0]
+    def parse_args(self):
+        parser = argparse.ArgumentParser()
+        # train and evaluation parameter
+        parser.add_argument("-exp_name", required=True, type=str, help='')
+        parser.add_argument("-train_filename", required=True, type=str, help='')
+        parser.add_argument("-valid_filename", required=True, type=str, help='')
+        parser.add_argument("-test_filename", required=True, type=str, help='')
+        parser.add_argument("-dest_base_dir", required=True, type=str, help='')
+        parser.add_argument("-vocab_path", required=True, type=str, help='')
+        parser.add_argument("-vocab_size", type=int, default=500000, help='')
+        parser.add_argument("-epochs", type=int, default=10, help='')
+        parser.add_argument("-batch_size", type=int, default=64, help='')
+        parser.add_argument("-learning_rate", type=float, default=1e-4, help='')
+        parser.add_argument("-eval_batch_size", type=int, default=20, help='')
+        parser.add_argument("-dropout", type=float, default=0.0, help='')
+        parser.add_argument("-grad_norm", type=float, default=0.0, help='')
+        parser.add_argument("-max_grad", type=float, default=5.0, help='')
+        parser.add_argument("-shuffle_in_batch", action='store_true', help='')
+        parser.add_argument("-teacher_forcing", action='store_true', help='')
+        parser.add_argument("-beam_size", type=float, default=50, help='')
+        parser.add_argument('-tensorboard_dir', type=str, default='', help='')
+        parser.add_argument('-logfile', type=str, default='train_log.log', help='')
+        parser.add_argument('-save_model_step', type=int, default=5000, help='')
+        parser.add_argument('-early_stop_tolerance', type=int, default=50, help='')
+        parser.add_argument('-train_parallel', action='store_true', help='')
+
+        # model specific parameter
+        parser.add_argument("-embed_size", type=int, default=200, help='')
+        parser.add_argument("-max_oov_count", type=int, default=100, help='')
+        parser.add_argument("-max_src_len", type=int, default=1500, help='')
+        parser.add_argument("-max_target_len", type=int, default=8, help='')
+        parser.add_argument("-src_hidden_size", type=int, default=100, help='')
+        parser.add_argument("-target_hidden_size", type=int, default=100, help='')
+        parser.add_argument('-src_num_layers', type=int, default=1, help='')
+        parser.add_argument('-target_num_layers', type=int, default=1, help='')
+        parser.add_argument("-bidirectional", action='store_true', help='')
+        parser.add_argument("-copy_net", action='store_true', help='')
+
+        args = parser.parse_args()
+        return args
 
 
 def accuracy(probs, true_indices, pad_idx):
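
Compared with the old main() removed in the final hunk below, several defaults shift with the move: -batch_size 50 → 64, -eval_batch_size 10 → 20, -dropout 0.5 → 0.0, -max_grad 10.0 → 5.0. New flags are -learning_rate (default 1e-4, matching the Adam rate the old constructor hard-coded) and -copy_net; -bidirectional loses its default=True, so it is now off unless passed explicitly; and -use_vanilla_rnn_search is dropped.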
@@ -172,42 +135,5 @@ def accuracy(probs, true_indices, pad_idx):
     return torch.sum(tp_result).numpy() / true_indices.numel()
 
 
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("-exp_name", required=True, type=str, help='')
-    parser.add_argument("-train_filename", required=True, type=str, help='')
-    parser.add_argument("-valid_filename", required=True, type=str, help='')
-    parser.add_argument("-test_filename", required=True, type=str, help='')
-    parser.add_argument("-dest_base_dir", required=True, type=str, help='')
-    parser.add_argument("-vocab_path", required=True, type=str, help='')
-    parser.add_argument("-vocab_size", type=int, default=500000, help='')
-    parser.add_argument("-embed_size", type=int, default=200, help='')
-    parser.add_argument("-max_oov_count", type=int, default=100, help='')
-    parser.add_argument("-max_src_len", type=int, default=1500, help='')
-    parser.add_argument("-max_target_len", type=int, default=8, help='')
-    parser.add_argument("-src_hidden_size", type=int, default=100, help='')
-    parser.add_argument("-target_hidden_size", type=int, default=100, help='')
-    parser.add_argument('-src_num_layers', type=int, default=1, help='')
-    parser.add_argument('-target_num_layers', type=int, default=1, help='')
-    parser.add_argument("-epochs", type=int, default=10, help='')
-    parser.add_argument("-batch_size", type=int, default=50, help='')
-    parser.add_argument("-eval_batch_size", type=int, default=10, help='')
-    parser.add_argument("-dropout", type=float, default=0.5, help='')
-    parser.add_argument("-grad_norm", type=float, default=0.0, help='')
-    parser.add_argument("-max_grad", type=float, default=10.0, help='')
-    parser.add_argument("-shuffle_in_batch", action='store_true', help='')
-    parser.add_argument("-bidirectional", default=True, action='store_true', help='')
-    parser.add_argument("-use_vanilla_rnn_search", default=False, action='store_true', help='')
-    parser.add_argument("-teacher_forcing", action='store_true', help='')
-    parser.add_argument("-beam_size", type=float, default=50, help='')
-    parser.add_argument('-tensorboard_dir', type=str, default='', help='')
-    parser.add_argument('-logfile', type=str, default='train_log.log', help='')
-    parser.add_argument('-save_model_step', type=int, default=5000, help='')
-    parser.add_argument('-early_stop_tolerance', type=int, default=50, help='')
-    parser.add_argument('-train_parallel', action='store_true', help='')
-    args = parser.parse_args()
-    Trainer(args).train()
-
-
 if __name__ == '__main__':
-    main()
+    CopyRnnTrainer().train()
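
With main() gone, launching training goes through the class directly. An invocation would look roughly like this — the paths are placeholders, but the flag names are the ones defined in parse_args above:

    python deep_keyphrase/copy_rnn/train.py \
        -exp_name copyrnn \
        -train_filename data/train.jsonl \
        -valid_filename data/valid.jsonl \
        -test_filename data/test.jsonl \
        -dest_base_dir output/ \
        -vocab_path data/vocab.txt \
        -teacher_forcing -copy_net -bidirectional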
