Commit e85515c

start implementing CopyCNN

1 parent 5f6971f · commit e85515c
File tree: 4 files changed, +193 −12 lines

    deep_keyphrase/copy_cnn/beam_search.py
    deep_keyphrase/copy_cnn/model.py
    deep_keyphrase/copy_cnn/predict.py
    deep_keyphrase/copy_cnn/train.py
deep_keyphrase/copy_cnn/beam_search.py

Lines changed: 12 additions & 1 deletion

```diff
@@ -1 +1,12 @@
-# -*- coding: UTF-8 -*-
+# -*- coding: UTF-8 -*-
+
+
+class CopyCnnBeamSearch(object):
+    def __init__(self):
+        pass
+
+    def beam_search(self):
+        pass
+
+    def greedy_search(self):
+        pass
```
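Both search methods are still stubs. For orientation, here is a minimal generic beam search over a scoring callback. Everything in it is hypothetical: `step_fn` is an assumed interface (previous tokens in, per-hypothesis log-probs out), and refinements such as freezing finished hypotheses and length normalization are omitted.

```python
import torch

def beam_search_sketch(step_fn, bos_id, eos_id, beam_size, max_len, vocab_size):
    """Generic beam search; step_fn(hyps) -> (num_hyps, vocab_size) log-probs."""
    hyps = torch.full((1, 1), bos_id, dtype=torch.long)  # start from <bos>
    scores = torch.zeros(1)                              # cumulative log-prob per hypothesis
    for _ in range(max_len):
        total = scores.unsqueeze(1) + step_fn(hyps)      # (num_hyps, vocab_size)
        scores, flat_idx = total.view(-1).topk(min(beam_size, total.numel()))
        prev = torch.div(flat_idx, vocab_size, rounding_mode='floor')  # surviving hypothesis
        next_tok = (flat_idx % vocab_size).unsqueeze(1)  # chosen continuation
        hyps = torch.cat([hyps[prev], next_tok], dim=1)
        if (hyps[:, -1] == eos_id).all():                # every beam emitted <eos>
            break
    return hyps[scores.argmax()]
```

`greedy_search` then falls out as the `beam_size=1` special case.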

deep_keyphrase/copy_cnn/model.py

Lines changed: 101 additions & 7 deletions

```diff
@@ -1,27 +1,121 @@
 # -*- coding: UTF-8 -*-
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+from deep_keyphrase.dataloader import (TOKENS, TOKENS_LENS, TARGET)
 
 
-class CopyCnnModel(nn.Module):
-    def __init__(self):
+class Attention(nn.Module):
+    """WIP: attention over the encoder states."""
+
+    def __init__(self, dim_size):
         super().__init__()
+        self.in_proj = nn.Linear(dim_size, dim_size)
 
-    def forward(self):
+    def forward(self, x, target_embedding, encoder_input, encoder_output, encoder_mask):
         pass
 
 
+class CopyCnn(nn.Module):
+    def __init__(self, args, vocab2id):
+        super().__init__()
+        self.args = args
+        self.vocab2id = vocab2id
+        self.embedding = nn.Embedding(len(vocab2id), args.dim_size)
+        self.encoder = CopyCnnEncoder(vocab2id=vocab2id, embedding=self.embedding, args=args)
+        self.decoder = CopyCnnDecoder(vocab2id=vocab2id, embedding=self.embedding, args=args)
+
+    def forward(self, src_dict, encoder_output):
+        # encode once; cached encoder output can be reused across decoding steps
+        if encoder_output is None:
+            encoder_output = self.encoder(src_dict)
+
+
 class CopyCnnEncoder(nn.Module):
-    def __init__(self):
+    def __init__(self, vocab2id, embedding, args):
         super().__init__()
+        self.vocab2id = vocab2id
+        self.embedding = embedding
+        self.args = args
+        self.dim_size = args.dim_size
+        self.kernel_size = (args.kernel_width, self.dim_size)
+        self.dropout = args.dropout
+        # nn.ModuleList (not a plain list) so the conv parameters are registered
+        self.convolution_layers = nn.ModuleList()
+        for i in range(args.encoder_layer_num):
+            layer = nn.Conv2d(in_channels=1, out_channels=2 * self.dim_size,
+                              kernel_size=self.kernel_size, bias=True)
+            self.convolution_layers.append(layer)
 
     def forward(self, src_dict):
-        pass
+        tokens = src_dict[TOKENS]
+        x = self.embedding(tokens).unsqueeze(1)
+        # x = tokens.unsqueeze(1)
+        layer_output = [x]
+        for layer in self.convolution_layers:
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = layer(x)
+            # GLU gates the 2 * dim channels back down to dim, then residual
+            x = F.glu(x, dim=1) + layer_output[-1]
+            layer_output.append(x)
+        return x
 
 
 class CopyCnnDecoder(nn.Module):
-    def __init__(self):
+    def __init__(self, vocab2id, embedding, args):
         super().__init__()
+        self.vocab2id = vocab2id
+        self.embedding = embedding
+        self.args = args
+        self.vocab_size = self.args.vocab_size
+        self.max_oov_count = self.args.max_oov_count
+        # extended vocabulary: fixed vocab plus per-batch OOV slots for copying
+        self.total_vocab_size = self.vocab_size + self.max_oov_count
+        self.dim_size = args.dim_size
+        self.kernel_size = (args.kernel_width, self.dim_size)
+        self.dropout = args.dropout
+        self.convolution_layers = nn.ModuleList()
+        self.attn_linear_layers = nn.ModuleList()
+        self.decoder_layer_num = args.decoder_layer_num
+        for i in range(self.decoder_layer_num):
+            conv_layer = nn.Conv2d(in_channels=1, out_channels=2 * self.dim_size,
+                                   kernel_size=self.kernel_size, bias=True)
+            self.convolution_layers.append(conv_layer)
+            attn_linear_layer = nn.Linear(self.dim_size, self.dim_size, bias=True)
+            self.attn_linear_layers.append(attn_linear_layer)
+        self.generate_proj = nn.Linear(self.dim_size, self.vocab_size)
+        self.copy_proj = nn.Linear(self.dim_size, self.total_vocab_size)
+
+    def forward(self, src_dict, prev_tokens, encoder_output):
+        """
+        :param src_dict: batch dict with source tokens and target keyphrases
+        :param prev_tokens: previously generated target tokens
+        :param encoder_output: output of CopyCnnEncoder
+        :return: generation logits (copy scoring is still WIP)
+        """
+        src_tokens = src_dict[TOKENS]
+        tokens = src_dict[TARGET][:, :-1]
+        x = self.embedding(tokens).unsqueeze(1)
+        prev_x = self.embedding(prev_tokens)
+        src_x = self.embedding(src_tokens)
+        layer_output = [x]
+        for conv_layer, linear_layer in zip(self.convolution_layers, self.attn_linear_layers):
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = conv_layer(x)
+            x = F.glu(x, dim=1) + layer_output[-1]
+            # attention
+            d = linear_layer(x) + prev_x
+            attn_weights = torch.softmax(torch.bmm(encoder_output, d.unsqueeze(2)), dim=1)
+            c = attn_weights * (encoder_output + src_x)
+            # residual connection
+            final_output = x + c.squeeze(2)
+            layer_output.append(final_output)
+        generate_logits = self.generate_proj(layer_output[-1])
+        return generate_logits
+
+    def forward_one_pass(self):
+        pass
+
+    def forward_auto_regressive(self):
+        pass
 
-    def forward(self):
+    def get_attn_read(self, encoder_output, src_tokens_with_oov, decoder_output, encoder_output_mask):
         pass
```
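The encoder and decoder blocks follow the convolutional seq2seq (ConvS2S) pattern: each Conv2d emits 2 * dim_size channels, a GLU gates them back down to dim_size, and a residual adds the previous layer's output. As written, though, the shapes cannot line up yet: without padding the convolution shortens the sequence axis, and the GLU output is channels-first while the embeddings are laid out as (batch, 1, seq_len, dim). A standalone shape check, assuming an odd kernel width and symmetric padding (illustrative values, not the repo's code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, seq_len, dim, kernel_width = 2, 7, 100, 5

x = torch.randn(batch, 1, seq_len, dim)  # embedded tokens with a channel axis

# Pad the sequence axis so output length == input length and the
# residual connection can line up (the commit's Conv2d has no padding).
conv = nn.Conv2d(in_channels=1, out_channels=2 * dim,
                 kernel_size=(kernel_width, dim),
                 padding=((kernel_width - 1) // 2, 0))

h = conv(x)              # (batch, 2*dim, seq_len, 1)
h = F.glu(h, dim=1)      # gating halves the channels -> (batch, dim, seq_len, 1)
print(h.shape)           # torch.Size([2, 100, 7, 1])

# To add the residual, the gated output still has to be reshaped back
# into the (batch, 1, seq_len, dim) embedding layout:
h = h.squeeze(3).transpose(1, 2).unsqueeze(1)
print((h + x).shape)     # torch.Size([2, 1, 7, 100])
```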

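`total_vocab_size = vocab_size + max_oov_count` and the `src_tokens_with_oov` parameter of the `get_attn_read` stub point at a CopyNet-style extended vocabulary, where source words outside the fixed vocab get temporary ids at and above `vocab_size` so they can still be predicted by copying. One common way to merge the generate and copy heads under that scheme, sketched here as an assumption about where this commit is heading rather than its actual code:

```python
import torch
import torch.nn.functional as F

def extended_vocab_probs(gen_logits, copy_logits, src_tokens_with_oov, total_vocab_size):
    """gen_logits: (batch, vocab_size) scores over the fixed vocabulary.
    copy_logits: (batch, src_len), one score per source position.
    src_tokens_with_oov: (batch, src_len) ids, with source OOVs mapped into
    the range [vocab_size, vocab_size + max_oov_count).
    """
    batch, vocab_size = gen_logits.size()
    # one softmax over the concatenation, so generating and copying compete
    probs = F.softmax(torch.cat([gen_logits, copy_logits], dim=1), dim=1)
    gen_probs, copy_probs = probs[:, :vocab_size], probs[:, vocab_size:]
    total = torch.zeros(batch, total_vocab_size)
    total[:, :vocab_size] = gen_probs
    # route each position's copy mass to the id of the token sitting there
    total.scatter_add_(1, src_tokens_with_oov, copy_probs)
    return total
```

A single softmax over the concatenated scores is the usual CopyNet formulation; it lets the model trade probability mass between generating a vocabulary word and copying a source token.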
deep_keyphrase/copy_cnn/predict.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -9,5 +9,6 @@ def __init__(self, model_info):
     def predict(self, input_list, batch_size, delimiter=''):
         pass
 
-    def eval_predict(self):
+    def eval_predict(self, src_filename, dest_filename, args,
+                     model=None, remove_existed=False):
         pass
```

deep_keyphrase/copy_cnn/train.py

Lines changed: 78 additions & 3 deletions

```diff
@@ -1,7 +1,13 @@
 # -*- coding: UTF-8 -*-
+import os
 import argparse
+import torch
+from collections import OrderedDict
+from munch import Munch
+from pysenal import read_json
 from deep_keyphrase.base_trainer import BaseTrainer
 from deep_keyphrase.utils.vocab_loader import load_vocab
+from deep_keyphrase.copy_cnn.model import CopyCnn
 
 
 class CopyCnnTrainer(BaseTrainer):
@@ -12,16 +18,85 @@ def __init__(self):
         super().__init__(self.args, model)
 
     def load_model(self):
-        pass
+        if not self.args.train_from:
+            model = CopyCnn(self.args, self.vocab2id)
+        else:
+            model_path = self.args.train_from
+            config_path = os.path.join(os.path.dirname(model_path),
+                                       self.get_basename(model_path) + '.json')
+
+            # restore the config the checkpoint was trained with
+            old_config = read_json(config_path)
+            old_config['train_from'] = model_path
+            old_config['step'] = int(model_path.rsplit('_', 1)[-1].split('.')[0])
+            self.args = Munch(old_config)
+            self.vocab2id = load_vocab(self.args.vocab_path, self.args.vocab_size)
+
+            model = CopyCnn(self.args, self.vocab2id)
+
+            if torch.cuda.is_available():
+                checkpoint = torch.load(model_path)
+            else:
+                checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
+            state_dict = OrderedDict()
+            # strip the 'module.' prefix left on keys by DataParallel checkpoints
+            for k, v in checkpoint.items():
+                if k.startswith('module.'):
+                    k = k[7:]
+                state_dict[k] = v
+            model.load_state_dict(state_dict)
+
+        return model
 
     def train_batch(self, batch, step):
-        pass
+        # WIP: forward pass, loss computation and optimizer step still missing
+        self.model.train()
+        loss = 0
+        self.optimizer.zero_grad()
 
     def evaluate(self, step):
         pass
 
     def parse_args(self, args=None):
         parser = argparse.ArgumentParser()
-        # parser.add_argument()
+        parser.add_argument("-exp_name", required=True, type=str, help='')
+        parser.add_argument("-train_filename", required=True, type=str, help='')
+        parser.add_argument("-valid_filename", required=True, type=str, help='')
+        parser.add_argument("-test_filename", required=True, type=str, help='')
+        parser.add_argument("-dest_base_dir", required=True, type=str, help='')
+        parser.add_argument("-vocab_path", required=True, type=str, help='')
+        parser.add_argument("-vocab_size", type=int, default=500000, help='')
+        parser.add_argument("-train_from", default='', type=str, help='')
+        parser.add_argument("-token_field", default='tokens', type=str, help='')
+        parser.add_argument("-keyphrase_field", default='keyphrases', type=str, help='')
+        # parser.add_argument("-auto_regressive", action='store_true', help='')
+        parser.add_argument("-epochs", type=int, default=10, help='')
+        parser.add_argument("-batch_size", type=int, default=64, help='')
+        parser.add_argument("-learning_rate", type=float, default=1e-4, help='')
+        parser.add_argument("-eval_batch_size", type=int, default=50, help='')
+        parser.add_argument("-dropout", type=float, default=0.0, help='')
+        parser.add_argument("-grad_norm", type=float, default=0.0, help='')
+        parser.add_argument("-max_grad", type=float, default=5.0, help='')
+        parser.add_argument("-shuffle", action='store_true', help='')
+        # parser.add_argument("-teacher_forcing", action='store_true', help='')
+        parser.add_argument("-beam_size", type=int, default=50, help='')
+        parser.add_argument('-tensorboard_dir', type=str, default='', help='')
+        parser.add_argument('-logfile', type=str, default='train_log.log', help='')
+        parser.add_argument('-save_model_step', type=int, default=5000, help='')
+        parser.add_argument('-early_stop_tolerance', type=int, default=100, help='')
+        parser.add_argument('-train_parallel', action='store_true', help='')
+        # parser.add_argument('-schedule_lr', action='store_true', help='')
+        # parser.add_argument('-schedule_step', type=int, default=100000, help='')
+        # parser.add_argument('-schedule_gamma', type=float, default=0.5, help='')
+        # parser.add_argument('-processed', action='store_true', help='')
+        parser.add_argument('-prefetch', action='store_true', help='')
 
+        parser.add_argument('-dim_size', type=int, default=100, help='')
+        parser.add_argument('-kernel_width', type=int, default=5, help='')
+        parser.add_argument('-encoder_layer_num', type=int, default=6, help='')
+        parser.add_argument('-decoder_layer_num', type=int, default=6, help='')
+        # required by CopyCnnDecoder; default here is an assumption
+        parser.add_argument('-max_oov_count', type=int, default=100, help='')
 
         args = parser.parse_args(args)
         return args
+
+
+if __name__ == '__main__':
+    CopyCnnTrainer().train()
```
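With the `__main__` entry point above, a run could be launched as below once train_batch is completed. The flags are the ones registered in parse_args; the data paths and experiment name are placeholders, not files from the repo.

```sh
python deep_keyphrase/copy_cnn/train.py \
    -exp_name copy_cnn_kp20k \
    -train_filename data/kp20k.train.jsonl \
    -valid_filename data/kp20k.valid.jsonl \
    -test_filename data/kp20k.test.jsonl \
    -dest_base_dir output/copy_cnn/ \
    -vocab_path data/vocab.txt \
    -dim_size 100 -kernel_width 5 \
    -encoder_layer_num 6 -decoder_layer_num 6 \
    -shuffle -prefetch
```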
