Skip to content

Commit e174a5d

Browse files
implement BaseTrainer
1 parent 5e8f982 commit e174a5d

File tree

1 file changed

+106
-0
lines changed

1 file changed

+106
-0
lines changed

deep_keyphrase/base_trainer.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# -*- coding: UTF-8 -*-
2+
import time
3+
import traceback
4+
import os
5+
import gc
6+
import torch
7+
import torch.nn as nn
8+
import torch.optim as optim
9+
from torch.utils.tensorboard import SummaryWriter
10+
from pysenal import write_json, get_logger
11+
from deep_keyphrase.utils.vocab_loader import load_vocab
12+
from deep_keyphrase.dataloader import KeyphraseDataLoader
13+
from deep_keyphrase.evaluation import KeyphraseEvaluator
14+
from deep_keyphrase.utils.constants import PAD_WORD
15+
16+
17+
class BaseTrainer(object):
    """Shared training scaffold for keyphrase models.

    The constructor wires up the vocabulary, model, loss, optimizer, training
    data loader, output directory, TensorBoard writer and evaluators.
    ``train`` then drives the epoch/batch loop with periodic evaluation,
    checkpointing and early stopping.

    Subclasses must implement ``train_batch`` and ``evaluate``.
    """

    def __init__(self, args, model):
        """Set up everything needed to run ``train``.

        Args:
            args: Configuration object. The attributes read here are
                ``vocab_path``, ``vocab_size``, ``train_parallel``,
                ``learning_rate``, ``train_filename``, ``batch_size``,
                ``max_src_len``, ``max_oov_count``, ``max_target_len``,
                ``dest_base_dir``, ``exp_name`` and ``tensorboard_dir``.
                ``train`` additionally reads ``epochs``, ``save_model_step``
                and ``early_stop_tolerance``. It must work with ``vars()``
                because it is dumped to JSON next to every checkpoint
                (presumably an ``argparse.Namespace`` -- confirm with caller).
            model: The ``torch.nn.Module`` to train.

        Raises:
            OSError: If the run directory cannot be created, e.g. because a
                directory with the same name already exists.
        """
        self.args = args
        self.vocab2id = load_vocab(self.args.vocab_path, self.args.vocab_size)

        self.model = model
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        if args.train_parallel:
            self.model = nn.DataParallel(self.model)

        # Padding positions are excluded from the loss.
        self.loss_func = nn.NLLLoss(ignore_index=self.vocab2id[PAD_WORD])
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        self.logger = get_logger('train')
        self.train_loader = KeyphraseDataLoader(self.args.train_filename,
                                                self.vocab2id,
                                                self.args.batch_size,
                                                self.args.max_src_len,
                                                self.args.max_oov_count,
                                                self.args.max_target_len,
                                                'train')

        # Each run gets its own timestamped directory; the trailing '/' is
        # kept because paths below are built by plain string concatenation.
        timemark = time.strftime('%Y%m%d-%H%M%S')
        self.dest_dir = os.path.join(self.args.dest_base_dir, self.args.exp_name + '-' + timemark) + '/'
        # makedirs (rather than mkdir) so a missing dest_base_dir is created
        # too; it still raises if the run directory itself already exists.
        os.makedirs(self.dest_dir)

        if not self.args.tensorboard_dir:
            tensorboard_dir = self.dest_dir + 'logs/'
        else:
            tensorboard_dir = self.args.tensorboard_dir
        self.writer = SummaryWriter(tensorboard_dir)

        self.eval_topn = (5, 10)
        self.macro_evaluator = KeyphraseEvaluator(self.eval_topn, 'macro')
        self.micro_evaluator = KeyphraseEvaluator(self.eval_topn, 'micro')

        # Early-stopping bookkeeping, updated by evaluate_and_save_model().
        self.best_f1 = None
        self.best_step = None
        self.not_update_count = 0

    def parse_args(self):
        """Hook for subclasses; the base class provides no implementation.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError('parse_args is not implemented')

    def train(self):
        """Run the training loop with periodic evaluation and early stopping.

        Every batch counts as one step. Every ``args.save_model_step`` steps
        the model is evaluated and checkpointed. Training stops as soon as
        ``not_update_count`` reaches ``args.early_stop_tolerance``, or after
        ``args.epochs`` epochs.
        """
        step = 0
        is_stop = False
        self.logger.info('destination dir:{}'.format(self.dest_dir))
        for epoch in range(1, self.args.epochs + 1):
            for batch in self.train_loader:
                self.model.train()
                try:
                    loss = self.train_batch(batch)
                except Exception:
                    # Best effort: a failing batch is logged and skipped
                    # instead of aborting the run. It still counts as a step
                    # and is recorded with a placeholder loss of 0.0.
                    self.logger.error(traceback.format_exc())
                    loss = 0.0
                step += 1
                self.writer.add_scalar('loss', loss, step)
                # Drop the reference to the loss before the next batch.
                del loss
                gc.collect()
                if step % self.args.save_model_step == 0:
                    torch.cuda.empty_cache()
                    self.evaluate_and_save_model(step, epoch)
                    if self.not_update_count >= self.args.early_stop_tolerance:
                        is_stop = True
                        break
            if is_stop:
                self.logger.info('best step {}'.format(self.best_step))
                break

    def train_batch(self, batch):
        """Train on a single batch; must be implemented by subclasses.

        Args:
            batch: One item yielded by ``self.train_loader``.

        Returns:
            The batch loss, which ``train`` logs as the ``'loss'`` scalar.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError('train_batch method is not implemented')

    def evaluate_and_save_model(self, step, epoch):
        """Evaluate, update the early-stopping state and write a checkpoint.

        A checkpoint (``.model`` state dict plus a ``.json`` dump of the
        arguments) is written on every call, whether or not the score
        improved.

        Args:
            step: Global step count, used for bookkeeping and the file name.
            epoch: Current epoch number, used in the file name.
        """
        valid_f1 = self.evaluate(step)
        # The first evaluation always becomes the best; ties count as an
        # improvement and reset the patience counter.
        if self.best_f1 is None or valid_f1 >= self.best_f1:
            self.best_f1 = valid_f1
            self.best_step = step
            self.not_update_count = 0
        else:
            self.not_update_count += 1

        exp_name = self.args.exp_name
        model_basename = self.dest_dir + '{}_epoch_{}_batch_{}'.format(exp_name, epoch, step)
        # NOTE(review): when the model is wrapped in nn.DataParallel the saved
        # keys carry a 'module.' prefix -- confirm the loading code expects it.
        torch.save(self.model.state_dict(), model_basename + '.model')
        write_json(model_basename + '.json', vars(self.args))
        self.logger.info('epoch {} step {}, model saved'.format(epoch, step))

    def evaluate(self, step):
        """Evaluate the model; must be implemented by subclasses.

        Args:
            step: Global step count at which evaluation is triggered.

        Returns:
            A score where larger is better; ``evaluate_and_save_model``
            compares it against the best value seen so far.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError('evaluate method is not implemented')

    def get_basename(self, filename):
        """Return the file name of *filename* without directory and last extension."""
        return os.path.splitext(os.path.basename(filename))[0]

0 commit comments

Comments
 (0)