Skip to content

Commit 4a62a03

Browse files
Add incremental training
1 parent 433fff1 commit 4a62a03

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

deep_keyphrase/base_trainer.py

Lines changed: 23 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
# -*- coding: UTF-8 -*-
22
import time
33
import traceback
4+
import logging
45
import os
56
import gc
67
import torch
@@ -16,6 +17,7 @@
1617

1718
class BaseTrainer(object):
1819
def __init__(self, args, model):
20+
torch.manual_seed(0)
1921
self.args = args
2022
self.vocab2id = load_vocab(self.args.vocab_path, self.args.vocab_size)
2123

@@ -34,9 +36,18 @@ def __init__(self, args, model):
3436
self.args.max_oov_count,
3537
self.args.max_target_len,
3638
'train')
37-
timemark = time.strftime('%Y%m%d-%H%M%S', time.localtime(time.time()))
38-
self.dest_dir = os.path.join(self.args.dest_base_dir, self.args.exp_name + '-' + timemark) + '/'
39-
os.mkdir(self.dest_dir)
39+
if self.args.train_from:
40+
self.dest_dir = os.path.dirname(self.args.train_from) + '/'
41+
else:
42+
timemark = time.strftime('%Y%m%d-%H%M%S', time.localtime(time.time()))
43+
self.dest_dir = os.path.join(self.args.dest_base_dir, self.args.exp_name + '-' + timemark) + '/'
44+
os.mkdir(self.dest_dir)
45+
46+
fh = logging.FileHandler(os.path.join(self.dest_dir, args.logfile))
47+
fh.setLevel(logging.INFO)
48+
fh.setFormatter(logging.Formatter('[%(asctime)s] %(message)s'))
49+
self.logger.addHandler(fh)
50+
4051
if not self.args.tensorboard_dir:
4152
tensorboard_dir = self.dest_dir + 'logs/'
4253
else:
@@ -46,7 +57,7 @@ def __init__(self, args, model):
4657
self.macro_evaluator = KeyphraseEvaluator(self.eval_topn, 'macro')
4758
self.micro_evaluator = KeyphraseEvaluator(self.eval_topn, 'micro')
4859
self.best_f1 = None
49-
self.best_step = None
60+
self.best_step = 0
5061
self.not_update_count = 0
5162

5263
def parse_args(self):
@@ -55,7 +66,12 @@ def parse_args(self):
5566
def train(self):
5667
step = 0
5768
is_stop = False
58-
self.logger.info('destination dir:{}'.format(self.dest_dir))
69+
if self.args.train_from:
70+
step = self.args.step
71+
self.logger.info('train from destination dir:{}'.format(self.dest_dir))
72+
self.logger.info('train from step {}'.format(step))
73+
else:
74+
self.logger.info('destination dir:{}'.format(self.dest_dir))
5975
for epoch in range(1, self.args.epochs + 1):
6076
for batch_idx, batch in enumerate(self.train_loader):
6177
self.model.train()
@@ -97,6 +113,8 @@ def evaluate_and_save_model(self, step, epoch):
97113
model_basename = self.dest_dir + '{}_epoch_{}_batch_{}'.format(exp_name, epoch, step)
98114
torch.save(self.model.state_dict(), model_basename + '.model')
99115
write_json(model_basename + '.json', vars(self.args))
116+
score_msg_tmpl = 'best score: step {} macro f1@{} {:.4f}'
117+
self.logger.info(score_msg_tmpl.format(self.best_step, self.eval_topn[-1], self.best_f1))
100118
self.logger.info('epoch {} step {}, model saved'.format(epoch, step))
101119

102120
def evaluate(self, step):

0 commit comments

Comments
 (0)