# -*- coding: UTF-8 -*-
-import torch.nn as nn
import torch
+import torch.nn as nn
+import torch.nn.functional as F
from deep_keyphrase.utils.constants import *
from deep_keyphrase.dataloader import TOKENS, TOKENS_OOV, TOKENS_LENS, OOV_COUNT

@@ -60,10 +61,6 @@ def __init__(self, args, vocab2id):
        super().__init__()
        src_hidden_size = args.src_hidden_size
        target_hidden_size = args.target_hidden_size
-        if args.bidirectional:
-            target_hidden_size *= 2
-        max_oov_count = args.max_oov_count
-        max_len = args.max_src_len
        embed_size = args.embed_size
        embedding = nn.Embedding(len(vocab2id), embed_size, vocab2id[PAD_WORD])
        self.encoder = CopyRnnEncoder(vocab2id=vocab2id,
@@ -75,13 +72,10 @@ def __init__(self, args, vocab2id):
            decoder_src_hidden_size = 2 * src_hidden_size
        else:
            decoder_src_hidden_size = src_hidden_size
-        self.decoder = CopyRnnDecoder(vocab2id=vocab2id,
-                                      embedding=embedding,
-                                      target_hidden_size=target_hidden_size,
-                                      src_hidden_size=decoder_src_hidden_size,
-                                      max_len=max_len,
-                                      dropout=args.dropout,
-                                      max_oov_count=max_oov_count)
+        self.decoder = CopyRnnDecoder(vocab2id=vocab2id, embedding=embedding, args=args)
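+        # when the (possibly bidirectional) encoder output size differs from the decoder
+        # hidden size, add linear layers to project the encoder's final (h, c) into decoder space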
+        if decoder_src_hidden_size != target_hidden_size:
+            self.encoder2decoder_state = nn.Linear(decoder_src_hidden_size, target_hidden_size)
+            self.encoder2decoder_cell = nn.Linear(decoder_src_hidden_size, target_hidden_size)

    def forward(self, src_dict, prev_output_tokens, encoder_output_dict,
                prev_decoder_state, prev_hidden_state):
@@ -104,6 +98,8 @@ def forward(self, src_dict, prev_output_tokens, encoder_output_dict,
        if encoder_output_dict is None:
            encoder_output_dict = self.encoder(src_dict)
            prev_hidden_state = encoder_output_dict['encoder_hidden']
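+            # project the encoder's final hidden and cell states into the decoder hidden size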
+            # the projection layers only exist when the two sizes differ
+            if hasattr(self, 'encoder2decoder_state'):
+                prev_hidden_state[0] = self.encoder2decoder_state(prev_hidden_state[0])
+                prev_hidden_state[1] = self.encoder2decoder_cell(prev_hidden_state[1])

        decoder_prob, prev_decoder_state, prev_hidden_state = self.decoder(
                                                                           src_dict=src_dict,
@@ -123,17 +119,15 @@ def __init__(self, vocab2id, embedding, hidden_size,
        self.embedding = embedding
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
-        self.dropout = dropout
        self.num_layers = 1
        self.pad_idx = vocab2id[PAD_WORD]
-
+        self.dropout = dropout
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
-            num_layers=1,
+            num_layers=self.num_layers,
            bidirectional=bidirectional,
-            batch_first=True,
-            dropout=self.dropout
+            batch_first=True
        )

    def forward(self, src_dict):
@@ -146,6 +140,8 @@ def forward(self, src_dict):
        src_lengths = src_dict[TOKENS_LENS]
        batch_size = len(src_tokens)
        src_embed = self.embedding(src_tokens)
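+        # apply dropout on the input embeddings; a single-layer nn.LSTM ignores its dropout argument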
+        src_embed = F.dropout(src_embed, p=self.dropout, training=self.training)
+
        total_length = src_embed.size(1)
        packed_src_embed = nn.utils.rnn.pack_padded_sequence(src_embed,
                                                             src_lengths,
@@ -172,37 +168,44 @@ def forward(self, src_dict):


class CopyRnnDecoder(nn.Module):
-    def __init__(self, vocab2id, embedding, target_hidden_size, src_hidden_size, max_len,
-                 dropout, max_oov_count, is_copy=True, is_attention=True):
+    def __init__(self, vocab2id, embedding, args):
        super().__init__()
        self.vocab2id = vocab2id
        vocab_size = embedding.num_embeddings
        embed_dim = embedding.embedding_dim
        self.vocab_size = vocab_size
        self.embed_size = embed_dim
        self.embedding = embedding
-        self.target_hidden_size = target_hidden_size
-        self.src_hidden_size = src_hidden_size
-        self.max_len = max_len
-        self.max_oov_count = max_oov_count
-        self.dropout = dropout
+        self.target_hidden_size = args.target_hidden_size
+        if args.bidirectional:
+            self.src_hidden_size = args.src_hidden_size * 2
+        else:
+            self.src_hidden_size = args.src_hidden_size
+        self.max_src_len = args.max_src_len
+        self.max_oov_count = args.max_oov_count
+        self.dropout = args.dropout
        self.pad_idx = vocab2id[PAD_WORD]
-        self.is_copy = is_copy
-        if is_copy:
-            decoder_hidden_size = embed_dim + src_hidden_size
+        self.is_copy = args.copy_net
+        self.input_feeding = args.input_feeding
+
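+        # the decoder LSTM input is the word embedding, concatenated with the selective-read
+        # copy state when copy_net is enabled and with the previous attention context when
+        # input_feeding is enabled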
+        if self.is_copy:
+            decoder_input_size = embed_dim + self.src_hidden_size
        else:
-            decoder_hidden_size = embed_dim
+            decoder_input_size = embed_dim
+
+        if args.input_feeding:
+            decoder_input_size += self.target_hidden_size
+
        self.lstm = nn.LSTM(
-            input_size=decoder_hidden_size,
-            hidden_size=target_hidden_size,
+            input_size=decoder_input_size,
+            hidden_size=self.target_hidden_size,
            num_layers=1,
-            batch_first=True,
-            dropout=self.dropout
+            batch_first=True
        )
-        self.attn_layer = Attention(src_hidden_size, target_hidden_size)
-        self.copy_proj = nn.Linear(src_hidden_size, target_hidden_size, bias=False)
-        self.input_copy_proj = nn.Linear(src_hidden_size, target_hidden_size, bias=False)
-        self.generate_proj = nn.Linear(target_hidden_size, self.vocab_size, bias=False)
+        self.attn_layer = Attention(self.src_hidden_size, self.target_hidden_size)
+        self.copy_proj = nn.Linear(self.src_hidden_size, self.target_hidden_size, bias=False)
+        self.input_copy_proj = nn.Linear(self.src_hidden_size, self.target_hidden_size, bias=False)
+        self.generate_proj = nn.Linear(self.target_hidden_size, self.vocab_size, bias=False)

    def forward(self, prev_output_tokens, encoder_output_dict, prev_context_state,
                prev_rnn_state, src_dict):
@@ -224,7 +227,8 @@ def forward(self, prev_output_tokens, encoder_output_dict, prev_context_state,
        else:
            output = self.forward_rnn(encoder_output_dict=encoder_output_dict,
                                      prev_output_tokens=prev_output_tokens,
-                                      prev_rnn_state=prev_rnn_state)
+                                      prev_rnn_state=prev_rnn_state,
+                                      prev_context_state=prev_context_state)
        return output

    def forward_copyrnn(self,
@@ -242,6 +246,7 @@ def forward_copyrnn(self,

        encoder_output = encoder_output_dict['encoder_output']
        encoder_output_mask = encoder_output_dict['encoder_padding_mask']
+        # copy_state: B x 1 x SH (selective read of the encoder outputs)
        copy_state = self.get_attn_read_input(encoder_output,
                                              prev_context_state,
                                              prev_output_tokens,
@@ -250,7 +255,13 @@ def forward_copyrnn(self,
        # map copied OOV token ids back to UNK to avoid an embedding lookup error
        prev_output_tokens[prev_output_tokens >= self.vocab_size] = self.vocab2id[UNK_WORD]
        src_embed = self.embedding(prev_output_tokens)
-        decoder_input = torch.cat([src_embed, copy_state], dim=2)
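+        # input feeding: concatenate the previous step's attentional output to the decoder input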
+        if self.input_feeding:
+            prev_context_state = prev_context_state.unsqueeze(1)
+            decoder_input = torch.cat([src_embed, prev_context_state, copy_state], dim=2)
+        else:
+            decoder_input = torch.cat([src_embed, copy_state], dim=2)
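+        # dropout on the assembled decoder input, since the single-layer decoder LSTM no longer sets dropout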
+        decoder_input = F.dropout(decoder_input, p=self.dropout, training=self.training)
        rnn_output, rnn_state = self.lstm(decoder_input, prev_rnn_state)
        rnn_state = list(rnn_state)
        # attn_output is the final hidden state of decoder layer
@@ -270,11 +281,16 @@ def forward_copyrnn(self,
        total_prob = torch.log(total_prob)
        return total_prob, attn_output.squeeze(1), rnn_state

-    def forward_rnn(self, encoder_output_dict, prev_output_tokens, prev_rnn_state):
+    def forward_rnn(self, encoder_output_dict, prev_output_tokens, prev_rnn_state, prev_context_state):
        encoder_output = encoder_output_dict['encoder_output']
        encoder_output_mask = encoder_output_dict['encoder_padding_mask']
        src_embed = self.embedding(prev_output_tokens)
-        rnn_output, rnn_state = self.lstm(src_embed, prev_rnn_state)
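+        # same input-feeding scheme as forward_copyrnn, minus the selective-read copy state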
+        if self.input_feeding:
+            prev_context_state = prev_context_state.unsqueeze(1)
+            decoder_input = torch.cat([src_embed, prev_context_state], dim=2)
+        else:
+            decoder_input = src_embed
+        rnn_output, rnn_state = self.lstm(decoder_input, prev_rnn_state)
        rnn_state = list(rnn_state)
        attn_output, attn_weights = self.attn_layer(rnn_output, encoder_output, encoder_output_mask)
        probs = torch.log_softmax(self.generate_proj(attn_output).squeeze(1), 1)
@@ -291,17 +307,18 @@ def get_attn_read_input(self, encoder_output, prev_context_state,
        :return:
        """
        # mask : B x L x 1
-        mask_bool = torch.eq(prev_output_tokens.repeat(1, self.max_len), src_tokens_with_oov).unsqueeze(2)
+        mask_bool = torch.eq(prev_output_tokens.repeat(1, self.max_src_len),
+                             src_tokens_with_oov).unsqueeze(2)
        mask = mask_bool.type_as(encoder_output)
        # B x L x SH
        aggregate_weight = torch.tanh(self.input_copy_proj(torch.mul(mask, encoder_output)))
        # if none of the previous tokens appear in src_tokens, skip the -inf masking to avoid NaN from softmax
-        no_zero_mask = ((mask != 0).sum(dim=1) != 0).repeat(1, self.max_len).unsqueeze(2)
+        no_zero_mask = ((mask != 0).sum(dim=1) != 0).repeat(1, self.max_src_len).unsqueeze(2)
        input_copy_logit_mask = no_zero_mask * mask_bool
        input_copy_logit = torch.bmm(aggregate_weight, prev_context_state.unsqueeze(2))
        input_copy_logit.masked_fill_(input_copy_logit_mask, float('-inf'))
        input_copy_weight = torch.softmax(input_copy_logit.squeeze(2), 1)
-        # B x 1 x L
+        # B x 1 x SH
        copy_state = torch.bmm(input_copy_weight.unsqueeze(1), encoder_output)
        return copy_state
