Skip to content

Commit 597a72f

Browse files
implement basic CopyTransformer trainer logic
1 parent 732d86b commit 597a72f

File tree

1 file changed

+94
-0
lines changed

1 file changed

+94
-0
lines changed
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# -*- coding: UTF-8 -*-
2+
import argparse
3+
import torch
4+
from deep_keyphrase.base_trainer import BaseTrainer
5+
from deep_keyphrase.copy_transformer.model import CopyTransformer
6+
from deep_keyphrase.dataloader import (TARGET, TOKENS)
7+
from deep_keyphrase.utils.vocab_loader import load_vocab
8+
9+
10+
class CopyTransformerTrainer(BaseTrainer):
11+
def __init__(self):
12+
args = self.parse_args()
13+
vocab2id = load_vocab(args.vocab_path, vocab_size=args.vocab_size)
14+
model = CopyTransformer(args, vocab2id)
15+
super().__init__(args, model)
16+
17+
def train_batch(self, batch):
18+
torch.autograd.set_detect_anomaly(True)
19+
loss = 0
20+
self.optimizer.zero_grad()
21+
targets = batch[TARGET]
22+
if torch.cuda.is_available():
23+
targets = targets.cuda()
24+
batch_size = len(batch[TOKENS])
25+
encoder_output = encoder_mask = None
26+
prev_copy_state = None
27+
prev_decoder_state = torch.zeros(batch_size, self.args.input_dim)
28+
for target_index in range(self.args.max_target_len):
29+
prev_output_tokens = targets[:, :target_index + 1].clone()
30+
true_indices = targets[:, target_index + 1].clone()
31+
output = self.model(src_dict=batch,
32+
prev_output_tokens=prev_output_tokens,
33+
encoder_output=encoder_output,
34+
encoder_mask=encoder_mask,
35+
prev_decoder_state=prev_decoder_state,
36+
position=target_index,
37+
prev_copy_state=prev_copy_state)
38+
probs, prev_decoder_state, prev_copy_state, encoder_output, encoder_mask = output
39+
loss += self.loss_func(probs, true_indices)
40+
41+
loss.backward()
42+
self.optimizer.step()
43+
# torch.cuda.empty_cache()
44+
return loss
45+
46+
def evaluate(self, step):
47+
pass
48+
49+
def parse_args(self):
50+
parser = argparse.ArgumentParser()
51+
# train and evaluation parameter
52+
parser.add_argument("-exp_name", required=True, type=str, help='')
53+
parser.add_argument("-train_filename", required=True, type=str, help='')
54+
parser.add_argument("-valid_filename", required=True, type=str, help='')
55+
parser.add_argument("-test_filename", required=True, type=str, help='')
56+
parser.add_argument("-dest_base_dir", required=True, type=str, help='')
57+
parser.add_argument("-vocab_path", required=True, type=str, help='')
58+
parser.add_argument("-vocab_size", type=int, default=500000, help='')
59+
parser.add_argument("-epochs", type=int, default=10, help='')
60+
parser.add_argument("-batch_size", type=int, default=12, help='')
61+
parser.add_argument("-learning_rate", type=float, default=1e-4, help='')
62+
parser.add_argument("-eval_batch_size", type=int, default=1, help='')
63+
parser.add_argument("-dropout", type=float, default=0.0, help='')
64+
parser.add_argument("-grad_norm", type=float, default=0.0, help='')
65+
parser.add_argument("-max_grad", type=float, default=5.0, help='')
66+
parser.add_argument("-shuffle_in_batch", action='store_true', help='')
67+
parser.add_argument("-teacher_forcing", action='store_true', help='')
68+
parser.add_argument("-beam_size", type=float, default=50, help='')
69+
parser.add_argument('-tensorboard_dir', type=str, default='', help='')
70+
parser.add_argument('-logfile', type=str, default='train_log.log', help='')
71+
parser.add_argument('-save_model_step', type=int, default=5000, help='')
72+
parser.add_argument('-early_stop_tolerance', type=int, default=50, help='')
73+
parser.add_argument('-train_parallel', action='store_true', help='')
74+
75+
# model specific parameter
76+
parser.add_argument("-input_dim", type=int, default=256, help='')
77+
parser.add_argument("-src_head_size", type=int, default=4, help='')
78+
parser.add_argument("-target_head_size", type=int, default=4, help='')
79+
parser.add_argument("-feed_forward_dim", type=int, default=1024, help='')
80+
parser.add_argument("-src_dropout", type=int, default=0.1, help='')
81+
parser.add_argument("-target_dropout", type=int, default=0.1, help='')
82+
parser.add_argument("-src_layers", type=int, default=6, help='')
83+
parser.add_argument("-target_layers", type=int, default=6, help='')
84+
parser.add_argument("-max_src_len", type=int, default=1000, help='')
85+
parser.add_argument("-max_target_len", type=int, default=8, help='')
86+
parser.add_argument("-max_oov_count", type=int, default=100, help='')
87+
parser.add_argument("-copy_net", action='store_true', help='')
88+
89+
args = parser.parse_args()
90+
return args
91+
92+
93+
if __name__ == '__main__':
94+
CopyTransformerTrainer().train()

0 commit comments

Comments
 (0)