
Commit 43efdc2

implement input feeding, add dropout after embedding layer, optimize call
1 parent 476b7b9 commit 43efdc2

1 file changed: deep_keyphrase/copy_rnn/model.py (+60 additions, −43 deletions)
@@ -1,6 +1,7 @@
 # -*- coding: UTF-8 -*-
-import torch.nn as nn
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 from deep_keyphrase.utils.constants import *
 from deep_keyphrase.dataloader import TOKENS, TOKENS_OOV, TOKENS_LENS, OOV_COUNT

@@ -60,10 +61,6 @@ def __init__(self, args, vocab2id):
         super().__init__()
         src_hidden_size = args.src_hidden_size
         target_hidden_size = args.target_hidden_size
-        if args.bidirectional:
-            target_hidden_size *= 2
-        max_oov_count = args.max_oov_count
-        max_len = args.max_src_len
         embed_size = args.embed_size
         embedding = nn.Embedding(len(vocab2id), embed_size, vocab2id[PAD_WORD])
         self.encoder = CopyRnnEncoder(vocab2id=vocab2id,
@@ -75,13 +72,10 @@ def __init__(self, args, vocab2id):
             decoder_src_hidden_size = 2 * src_hidden_size
         else:
             decoder_src_hidden_size = src_hidden_size
-        self.decoder = CopyRnnDecoder(vocab2id=vocab2id,
-                                      embedding=embedding,
-                                      target_hidden_size=target_hidden_size,
-                                      src_hidden_size=decoder_src_hidden_size,
-                                      max_len=max_len,
-                                      dropout=args.dropout,
-                                      max_oov_count=max_oov_count)
+        self.decoder = CopyRnnDecoder(vocab2id=vocab2id, embedding=embedding, args=args)
+        if decoder_src_hidden_size != target_hidden_size:
+            self.encoder2decoder_state = nn.Linear(decoder_src_hidden_size, target_hidden_size)
+            self.encoder2decoder_cell = nn.Linear(decoder_src_hidden_size, target_hidden_size)

     def forward(self, src_dict, prev_output_tokens, encoder_output_dict,
                 prev_decoder_state, prev_hidden_state):
@@ -104,6 +98,8 @@ def forward(self, src_dict, prev_output_tokens, encoder_output_dict,
         if encoder_output_dict is None:
             encoder_output_dict = self.encoder(src_dict)
             prev_hidden_state = encoder_output_dict['encoder_hidden']
+            prev_hidden_state[0] = self.encoder2decoder_state(prev_hidden_state[0])
+            prev_hidden_state[1] = self.encoder2decoder_cell(prev_hidden_state[1])

         decoder_prob, prev_decoder_state, prev_hidden_state = self.decoder(
             src_dict=src_dict,
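
Taken together, the two hunks above replace the old approach of doubling target_hidden_size for a bidirectional encoder with an explicit bridge: when the (possibly concatenated) encoder state width differs from the decoder's, two linear layers project the encoder's final hidden and cell states into the decoder's dimensions before decoding starts. A minimal standalone sketch of the idea, with hypothetical sizes:

import torch
import torch.nn as nn

# Sketch of the encoder-to-decoder state bridge; sizes are hypothetical.
src_hidden_size, target_hidden_size, batch = 100, 150, 4
encoder_state_size = 2 * src_hidden_size  # bidirectional: fwd/bwd states concatenated

bridge_state = nn.Linear(encoder_state_size, target_hidden_size)
bridge_cell = nn.Linear(encoder_state_size, target_hidden_size)

h_n = torch.randn(1, batch, encoder_state_size)  # encoder final hidden state
c_n = torch.randn(1, batch, encoder_state_size)  # encoder final cell state

decoder_h0 = bridge_state(h_n)  # (1, B, target_hidden_size), valid decoder initial hidden
decoder_c0 = bridge_cell(c_n)   # (1, B, target_hidden_size), valid decoder initial cell

One caveat worth noting: the diff creates encoder2decoder_state and encoder2decoder_cell only when the widths differ, yet forward applies them unconditionally, so a configuration where decoder_src_hidden_size equals target_hidden_size would raise an AttributeError unless the call is guarded the same way.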
@@ -123,17 +119,15 @@ def __init__(self, vocab2id, embedding, hidden_size,
         self.embedding = embedding
         self.hidden_size = hidden_size
         self.bidirectional = bidirectional
-        self.dropout = dropout
         self.num_layers = 1
         self.pad_idx = vocab2id[PAD_WORD]
-
+        self.dropout = dropout
         self.lstm = nn.LSTM(
             input_size=embed_dim,
             hidden_size=hidden_size,
-            num_layers=1,
+            num_layers=self.num_layers,
             bidirectional=bidirectional,
-            batch_first=True,
-            dropout=self.dropout
+            batch_first=True
         )

     def forward(self, src_dict):
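
Removing dropout= from this LSTM is more than cleanup: nn.LSTM only applies dropout between stacked layers, so with num_layers=1 it has no effect (PyTorch warns about exactly this combination). The commit instead applies dropout functionally to the embeddings, as the next hunk shows. A minimal sketch with hypothetical dimensions:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Dropout between embedding and a single-layer LSTM (hypothetical dims).
# nn.LSTM's own `dropout` argument only acts between stacked layers and is
# a no-op when num_layers=1.
embedding = nn.Embedding(1000, 64, padding_idx=0)
lstm = nn.LSTM(input_size=64, hidden_size=128, num_layers=1, batch_first=True)

tokens = torch.randint(1, 1000, (4, 10))   # (B, L)
embedded = embedding(tokens)               # (B, L, E)
embedded = F.dropout(embedded, p=0.5, training=True)  # inside a module, pass self.training
output, (h_n, c_n) = lstm(embedded)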
@@ -146,6 +140,8 @@ def forward(self, src_dict):
         src_lengths = src_dict[TOKENS_LENS]
         batch_size = len(src_tokens)
         src_embed = self.embedding(src_tokens)
+        src_embed = F.dropout(src_embed, p=self.dropout, training=self.training)
+
         total_length = src_embed.size(1)
         packed_src_embed = nn.utils.rnn.pack_padded_sequence(src_embed,
                                                              src_lengths,
@@ -172,37 +168,44 @@ def forward(self, src_dict):


 class CopyRnnDecoder(nn.Module):
-    def __init__(self, vocab2id, embedding, target_hidden_size, src_hidden_size, max_len,
-                 dropout, max_oov_count, is_copy=True, is_attention=True):
+    def __init__(self, vocab2id, embedding, args):
         super().__init__()
         self.vocab2id = vocab2id
         vocab_size = embedding.num_embeddings
         embed_dim = embedding.embedding_dim
         self.vocab_size = vocab_size
         self.embed_size = embed_dim
         self.embedding = embedding
-        self.target_hidden_size = target_hidden_size
-        self.src_hidden_size = src_hidden_size
-        self.max_len = max_len
-        self.max_oov_count = max_oov_count
-        self.dropout = dropout
+        self.target_hidden_size = args.target_hidden_size
+        if args.bidirectional:
+            self.src_hidden_size = args.src_hidden_size * 2
+        else:
+            self.src_hidden_size = args.src_hidden_size
+        self.max_src_len = args.max_src_len
+        self.max_oov_count = args.max_oov_count
+        self.dropout = args.dropout
         self.pad_idx = vocab2id[PAD_WORD]
-        self.is_copy = is_copy
-        if is_copy:
-            decoder_hidden_size = embed_dim + src_hidden_size
+        self.is_copy = args.copy_net
+        self.input_feeding = args.input_feeding
+
+        if self.is_copy:
+            decoder_input_size = embed_dim + self.src_hidden_size
         else:
-            decoder_hidden_size = embed_dim
+            decoder_input_size = embed_dim
+
+        if args.input_feeding:
+            decoder_input_size += self.target_hidden_size
+
         self.lstm = nn.LSTM(
-            input_size=decoder_hidden_size,
-            hidden_size=target_hidden_size,
+            input_size=decoder_input_size,
+            hidden_size=self.target_hidden_size,
             num_layers=1,
-            batch_first=True,
-            dropout=self.dropout
+            batch_first=True
         )
-        self.attn_layer = Attention(src_hidden_size, target_hidden_size)
-        self.copy_proj = nn.Linear(src_hidden_size, target_hidden_size, bias=False)
-        self.input_copy_proj = nn.Linear(src_hidden_size, target_hidden_size, bias=False)
-        self.generate_proj = nn.Linear(target_hidden_size, self.vocab_size, bias=False)
+        self.attn_layer = Attention(self.src_hidden_size, self.target_hidden_size)
+        self.copy_proj = nn.Linear(self.src_hidden_size, self.target_hidden_size, bias=False)
+        self.input_copy_proj = nn.Linear(self.src_hidden_size, self.target_hidden_size, bias=False)
+        self.generate_proj = nn.Linear(self.target_hidden_size, self.vocab_size, bias=False)

     def forward(self, prev_output_tokens, encoder_output_dict, prev_context_state,
                 prev_rnn_state, src_dict):
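
After this refactor the decoder LSTM's input width is assembled from whichever features are enabled in args: the token embedding always, the selective-read vector when copy_net is set, and the previous attentional state when input_feeding is set. The arithmetic, spelled out with hypothetical sizes:

# Hypothetical sizes; the flags mirror args.copy_net and args.input_feeding.
embed_dim = 64
src_hidden_size = 2 * 100        # bidirectional encoder, states concatenated
target_hidden_size = 150
copy_net = True
input_feeding = True

decoder_input_size = embed_dim                # embedded previous token
if copy_net:
    decoder_input_size += src_hidden_size     # selective read over encoder states
if input_feeding:
    decoder_input_size += target_hidden_size  # previous attentional hidden state

print(decoder_input_size)  # 64 + 200 + 150 = 414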
@@ -224,7 +227,8 @@ def forward(self, prev_output_tokens, encoder_output_dict, prev_context_state,
         else:
             output = self.forward_rnn(encoder_output_dict=encoder_output_dict,
                                       prev_output_tokens=prev_output_tokens,
-                                      prev_rnn_state=prev_rnn_state)
+                                      prev_rnn_state=prev_rnn_state,
+                                      prev_context_state=prev_context_state)
         return output

     def forward_copyrnn(self,
@@ -242,6 +246,7 @@ def forward_copyrnn(self,

         encoder_output = encoder_output_dict['encoder_output']
         encoder_output_mask = encoder_output_dict['encoder_padding_mask']
+        # B x 1 x L
         copy_state = self.get_attn_read_input(encoder_output,
                                               prev_context_state,
                                               prev_output_tokens,
@@ -250,7 +255,13 @@ def forward_copyrnn(self,
         # map copied oov tokens to OOV idx to avoid embedding lookup error
         prev_output_tokens[prev_output_tokens >= self.vocab_size] = self.vocab2id[UNK_WORD]
         src_embed = self.embedding(prev_output_tokens)
-        decoder_input = torch.cat([src_embed, copy_state], dim=2)
+        if self.input_feeding:
+            prev_context_state = prev_context_state.unsqueeze(1)
+            decoder_input = torch.cat([src_embed, prev_context_state, copy_state], dim=2)
+            # print(decoder_input.size())
+        else:
+            decoder_input = torch.cat([src_embed, copy_state], dim=2)
+        decoder_input = F.dropout(decoder_input, p=self.dropout, training=self.training)
         rnn_output, rnn_state = self.lstm(decoder_input, prev_rnn_state)
         rnn_state = list(rnn_state)
         # attn_output is the final hidden state of decoder layer
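
This hunk is the heart of input feeding (Luong et al., 2015): the attentional hidden state from step t-1 (prev_context_state) is concatenated into the input of step t, so the decoder is informed of what it has already attended to. One step of that concatenation, sketched with hypothetical shapes:

import torch

# Input to one decoder step under input feeding (hypothetical shapes).
B, E, SH, TH = 4, 64, 200, 150  # batch, embed, encoder ctx width, decoder hidden

src_embed = torch.randn(B, 1, E)     # embedded previous output token
prev_context = torch.randn(B, TH)    # attentional hidden state from step t-1
copy_state = torch.randn(B, 1, SH)   # selective read over encoder outputs

decoder_input = torch.cat([src_embed, prev_context.unsqueeze(1), copy_state], dim=2)
print(decoder_input.shape)           # torch.Size([4, 1, 414])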
@@ -270,11 +281,16 @@ def forward_copyrnn(self,
         total_prob = torch.log(total_prob)
         return total_prob, attn_output.squeeze(1), rnn_state

-    def forward_rnn(self, encoder_output_dict, prev_output_tokens, prev_rnn_state):
+    def forward_rnn(self, encoder_output_dict, prev_output_tokens, prev_rnn_state, prev_context_state):
         encoder_output = encoder_output_dict['encoder_output']
         encoder_output_mask = encoder_output_dict['encoder_padding_mask']
         src_embed = self.embedding(prev_output_tokens)
-        rnn_output, rnn_state = self.lstm(src_embed, prev_rnn_state)
+        if self.input_feeding:
+            prev_context_state = prev_context_state.unsqueeze(1)
+            decoder_input = torch.cat([src_embed, prev_context_state], dim=2)
+        else:
+            decoder_input = src_embed
+        rnn_output, rnn_state = self.lstm(decoder_input, prev_rnn_state)
         rnn_state = list(rnn_state)
         attn_output, attn_weights = self.attn_layer(rnn_output, encoder_output, encoder_output_mask)
         probs = torch.log_softmax(self.generate_proj(attn_output).squeeze(1), 1)
@@ -291,17 +307,18 @@ def get_attn_read_input(self, encoder_output, prev_context_state,
         :return:
         """
         # mask : B x L x 1
-        mask_bool = torch.eq(prev_output_tokens.repeat(1, self.max_len), src_tokens_with_oov).unsqueeze(2)
+        mask_bool = torch.eq(prev_output_tokens.repeat(1, self.max_src_len),
+                             src_tokens_with_oov).unsqueeze(2)
         mask = mask_bool.type_as(encoder_output)
         # B x L x SH
         aggregate_weight = torch.tanh(self.input_copy_proj(torch.mul(mask, encoder_output)))
         # when all prev_tokens are not in src_tokens, don't execute mask -inf to avoid nan result in softmax
-        no_zero_mask = ((mask != 0).sum(dim=1) != 0).repeat(1, self.max_len).unsqueeze(2)
+        no_zero_mask = ((mask != 0).sum(dim=1) != 0).repeat(1, self.max_src_len).unsqueeze(2)
         input_copy_logit_mask = no_zero_mask * mask_bool
         input_copy_logit = torch.bmm(aggregate_weight, prev_context_state.unsqueeze(2))
         input_copy_logit.masked_fill_(input_copy_logit_mask, float('-inf'))
         input_copy_weight = torch.softmax(input_copy_logit.squeeze(2), 1)
-        # B x 1 x L
+        # B x 1 x SH
         copy_state = torch.bmm(input_copy_weight.unsqueeze(1), encoder_output)
         return copy_state
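
get_attn_read_input is CopyNet's selective read: source positions that match the previously emitted token are scored against the decoder's last attentional state, softmaxed, and used for a weighted read of the encoder outputs; the result enters the next decoder step as copy_state. A simplified sketch with hypothetical shapes (it weights matching positions uniformly rather than scoring them through input_copy_proj as the method above does):

import torch

# Simplified selective read (hypothetical shapes, uniform match weights).
B, L, SH = 4, 8, 200

encoder_output = torch.randn(B, L, SH)
src_tokens = torch.randint(0, 50, (B, L))
prev_token = torch.randint(0, 50, (B, 1))  # previously generated token

# B x L: 1.0 where the previous token occurs in the source
match = torch.eq(prev_token.expand(-1, L), src_tokens).float()

# uniform weights over the matches; all-miss rows stay zero (no divide-by-zero)
weights = match / match.sum(dim=1, keepdim=True).clamp(min=1.0)

# B x 1 x SH: weighted read of encoder states, fed into the next decoder input
copy_state = torch.bmm(weights.unsqueeze(1), encoder_output)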
