11# -*- coding: UTF-8 -*-
22import argparse
33import tensorflow as tf
4+ from ..dataloader import UNK_WORD , BOS_WORD
5+
6+
def mask_fill(t, mask, num):
    """Fill the masked-out entries of a tensor with a constant.

    :param t: input tensor
    :param mask: boolean (or 0/1) indicator — True KEEPS the value of ``t``,
        False replaces it with ``num``
    :param num: value written at the False positions
    :return: tensor with the same shape and dtype as ``t``
    """
    # Fix: the original arithmetic `t * mask + (1 - mask) * num` produces NaN
    # at the KEPT positions whenever `num` is +/-inf (0 * inf = NaN in IEEE
    # float math) — and the attention caller passes float('-inf') / 2.
    # tf.where selects per element and never multiplies by the fill value.
    keep = tf.cast(mask, dtype=tf.bool)
    fill = tf.cast(num, dtype=t.dtype)
    return tf.where(keep, t, fill)
420
521
622class Attention (tf .keras .layers .Layer ):
7- def __init__ (self , encoder_dim , decoder_dim , dec_len , score_mode = 'general' ):
23+ def __init__ (self , encoder_dim , decoder_dim , score_mode = 'general' ):
824 super ().__init__ ()
925 self .encoder_dim = encoder_dim
1026 self .decoder_dim = decoder_dim
11- self .dec_len = dec_len
1227 self .score_mode = score_mode
1328 self .permuate_1_2 = tf .keras .layers .Permute ((2 , 1 ))
1429 if self .score_mode == 'general' :
@@ -25,8 +40,8 @@ def score(self, query, key, mask):
2540 elif self .score_mode == 'dot' :
2641 attn_weights = tf .matmul (query , self .permuate_1_2 (key ))
2742
28- mask = tf .repeat (tf .expand_dims (mask , 1 ), repeats = self . dec_len , axis = 1 )
29- attn_weights = tf . cast ( mask , tf . float32 ) * attn_weights
43+ mask = tf .repeat (tf .expand_dims (mask , 1 ), repeats = query . shape [ 1 ] , axis = 1 )
44+ attn_weights = mask_fill ( attn_weights , mask , float ( '-inf' ) / 2 )
3045 attn_weights = tf .nn .softmax (attn_weights , axis = 2 )
3146 return attn_weights
3247
@@ -38,20 +53,86 @@ def call(self, decoder_output, encoder_output, enc_mask):
3853
3954
class CopyRnnTF(tf.keras.Model):
    """CopyRNN keyphrase generation model: encoder-decoder with copy mechanism.

    The decoder produces a distribution over the extended vocabulary —
    ``vocab_size`` known words plus ``max_oov_count`` per-batch OOV slots —
    so source tokens outside the vocabulary can still be generated by copying.
    """

    def __init__(self, args: argparse.Namespace, vocab2id):
        super().__init__()
        self.args = args
        self.vocab_size = len(vocab2id)
        self.embedding = tf.keras.layers.Embedding(self.vocab_size, args.embed_dim)
        self.encoder = Encoder(args, self.embedding)
        self.decoder = Decoder(args, self.embedding)
        self.max_target_len = self.args.max_target_len
        # extended vocabulary size: in-vocab words + dynamic OOV slots
        self.total_vocab_size = self.vocab_size + args.max_oov_count
        # bridge the encoder's final (h, c) into the decoder's hidden size
        self.encoder2decoder_state = tf.keras.layers.Dense(args.decoder_hidden_size)
        self.encoder2decoder_cell = tf.keras.layers.Dense(args.decoder_hidden_size)
        self.beam_size = args.beam_size
        self.unk_idx = vocab2id[UNK_WORD]
        self.bos_idx = vocab2id[BOS_WORD]

    def call(self, x, x_with_oov, x_len, enc_output, dec_x, prev_h, prev_c):
        """Run one decoder pass, lazily running the encoder on the first call.

        :param x: source token ids with OOV words mapped to UNK
        :param x_with_oov: source token ids with OOV words given extended ids
        :param x_len: source sequence lengths
        :param enc_output: cached encoder output; None (re)runs the encoder
            and replaces prev_h / prev_c with bridged encoder states
        :param dec_x: decoder input tokens (teacher forcing or previous step)
        :param prev_h: decoder LSTM hidden state (ignored if enc_output is None)
        :param prev_c: decoder LSTM cell state (ignored if enc_output is None)
        :return: (probs, enc_output, prev_h, prev_c)
        """
        if enc_output is None:
            enc_output, prev_h, prev_c = self.encoder(x, x_len)
            prev_h = self.encoder2decoder_state(prev_h)
            # Fix: the cell state must use its own bridge layer; the original
            # sent prev_c through encoder2decoder_state and left
            # encoder2decoder_cell entirely unused.
            prev_c = self.encoder2decoder_cell(prev_c)

        probs, prev_h, prev_c = self.decoder(dec_x, x, x_with_oov, x_len,
                                             enc_output, prev_h, prev_c)
        return probs, enc_output, prev_h, prev_c

    @tf.function
    def beam_search(self, x, x_with_oov, x_len):
        """Beam-search decoding.

        :return: [batch, beam_size, max_target_len] extended-vocab token ids
        """
        # NOTE(review): x.shape[0] must be static under tf.function — confirm
        # the model is always traced with a fixed batch size.
        batch_size = x.shape[0]
        beam_batch_size = self.beam_size * batch_size
        prev_output_tokens = tf.ones([batch_size, 1], dtype=tf.int64) * self.bos_idx

        # First step: run the encoder and one decoder step from <bos>.
        probs, enc_output, prev_h, prev_c = self.call(x, x_with_oov, x_len, None,
                                                      prev_output_tokens, None, None)
        probs = tf.squeeze(probs, axis=1)
        prev_best_probs, prev_best_index = tf.math.top_k(probs, k=self.beam_size)

        # Tile states and inputs beam_size times along the batch axis.
        prev_h = tf.repeat(prev_h, self.beam_size, axis=0)
        prev_c = tf.repeat(prev_c, self.beam_size, axis=0)
        enc_output = tf.repeat(enc_output, self.beam_size, axis=0)

        result_sequences = prev_best_index

        # Fix: KEEP in-vocab ids and replace extended (OOV) ids with UNK before
        # the embedding lookup. mask_fill keeps where the mask is True, so the
        # original `>= vocab_size` mask was inverted — it kept the OOV ids,
        # which overflow the embedding table, and UNK'ed every valid token.
        prev_best_index = mask_fill(prev_best_index,
                                    prev_best_index < self.vocab_size,
                                    self.unk_idx)
        prev_best_index = tf.reshape(prev_best_index, [beam_batch_size, -1])
        x = tf.repeat(x, repeats=self.beam_size, axis=0)
        x_with_oov = tf.repeat(x_with_oov, repeats=self.beam_size, axis=0)
        x_len = tf.repeat(x_len, repeats=self.beam_size, axis=0)

        for _ in range(1, self.max_target_len):
            probs, enc_output, prev_h, prev_c = self.call(x, x_with_oov, x_len,
                                                          enc_output, prev_best_index,
                                                          prev_h, prev_c)
            probs = tf.squeeze(probs, axis=1)
            # Accumulate each beam's running score over all candidate tokens.
            # NOTE(review): the decoder currently returns raw probabilities
            # (log1p is disabled), so additive accumulation is not a true
            # sequence likelihood — confirm whether log-probs were intended.
            accumulated_probs = tf.reshape(prev_best_probs, [beam_batch_size, -1])
            accumulated_probs = tf.repeat(accumulated_probs,
                                          repeats=self.total_vocab_size, axis=1)
            accumulated_probs += probs
            accumulated_probs = tf.reshape(accumulated_probs, [batch_size, -1])
            # Fix: select the LARGEST accumulated scores (the original negated
            # the scores, picking the worst beams, while the first step used an
            # un-negated top_k) and carry them forward (the original computed
            # the new scores but never updated prev_best_probs).
            prev_best_probs, top_token_index = tf.math.top_k(accumulated_probs,
                                                             self.beam_size)

            # Map each flat [beam * total_vocab] candidate index back to its
            # source beam (for state reordering) and its token id.
            select_idx_factor = tf.range(0, batch_size) * self.beam_size
            select_idx_factor = tf.repeat(tf.expand_dims(select_idx_factor, axis=1),
                                          self.beam_size, axis=1)
            state_select_idx = tf.reshape(top_token_index, [beam_batch_size]) // probs.shape[1]
            state_select_idx += tf.reshape(select_idx_factor, [beam_batch_size])

            prev_best_index = top_token_index % probs.shape[1]
            prev_h = tf.gather(prev_h, state_select_idx, axis=0)
            prev_c = tf.gather(prev_c, state_select_idx, axis=0)

            # Reorder the partial hypotheses to follow their source beams,
            # then append this step's tokens.
            result_sequences = tf.reshape(result_sequences, [beam_batch_size, -1])
            result_sequences = tf.gather(result_sequences, state_select_idx, axis=0)
            result_sequences = tf.reshape(result_sequences,
                                          [batch_size, self.beam_size, -1])
            result_sequences = tf.concat(
                [result_sequences, tf.expand_dims(prev_best_index, axis=2)], axis=2)

            prev_best_index = tf.reshape(prev_best_index, [beam_batch_size, -1])
            # Same OOV -> UNK remapping (keep in-vocab ids) before the next
            # embedding lookup; see the fix note above.
            prev_best_index = mask_fill(prev_best_index,
                                        prev_best_index < self.vocab_size,
                                        self.unk_idx)

        return result_sequences
55136
56137class Encoder (tf .keras .layers .Layer ):
57138 def __init__ (self , args , embedding ):
@@ -60,12 +141,20 @@ def __init__(self, args, embedding):
60141 self .embedding = embedding
61142 self .lstm = tf .keras .layers .LSTM (self .args .encoder_hidden_size ,
62143 return_state = True , return_sequences = True )
144+ if args .bidirectional :
145+ self .lstm = tf .keras .layers .Bidirectional (self .lstm )
146+
63147 self .max_dec = self .args .max_src_len
64148
65149 def call (self , x , x_len ):
66150 embed_x = self .embedding (x )
67151 mask = tf .sequence_mask (x_len , maxlen = self .max_dec )
68- lstm_output , state_h , state_c = self .lstm (embed_x , mask = mask )
152+ if self .args .bidirectional :
153+ lstm_output , state_fw_h , state_fw_c , state_bw_h , state_bw_c = self .lstm (embed_x , mask = mask )
154+ state_h = tf .concat ([state_fw_h , state_bw_h ], axis = 1 )
155+ state_c = tf .concat ([state_fw_c , state_bw_c ], axis = 1 )
156+ else :
157+ lstm_output , state_h , state_c = self .lstm (embed_x )
69158 return lstm_output , state_h , state_c
70159
71160
@@ -80,17 +169,16 @@ def __init__(self, args, embedding):
80169 self .max_enc = self .args .max_src_len
81170 self .lstm = tf .keras .layers .LSTM (self .args .decoder_hidden_size ,
82171 return_state = True , return_sequences = True )
83- self .attention = Attention (self .args .encoder_hidden_size ,
84- self .args .decoder_hidden_size ,
85- self .args .max_target_len )
172+ if self .args .bidirectional :
173+ enc_hidden_size = self .args .encoder_hidden_size * 2
174+ else :
175+ enc_hidden_size = self .args .encoder_hidden_size
176+ self .attention = Attention (enc_hidden_size ,
177+ self .args .decoder_hidden_size )
86178 self .generate_layer = tf .keras .layers .Dense (self .vocab_size , use_bias = False )
87179 self .concat_layer = tf .keras .layers .Concatenate ()
88180 self .copy_layer = tf .keras .layers .Dense (self .args .decoder_hidden_size )
89181 self .permuate_1_2 = tf .keras .layers .Permute ((2 , 1 ))
90- i1 , i2 = tf .meshgrid (tf .range (self .args .batch_size , dtype = tf .int64 ),
91- tf .range (self .args .max_target_len , dtype = tf .int64 ), indexing = "ij" )
92- self .i1 = tf .tile (i1 [:, :, tf .newaxis ], [1 , 1 , self .args .max_src_len ])
93- self .i2 = tf .tile (i2 [:, :, tf .newaxis ], [1 , 1 , self .args .max_src_len ])
94182
95183 def call (self , dec_x , enc_x , enc_x_with_oov , enc_len , enc_output , enc_h , enc_c ):
96184 return self .call_one_pass (dec_x , enc_x , enc_x_with_oov , enc_len , enc_output , enc_h , enc_c )
@@ -100,10 +188,8 @@ def call_one_pass(self, dec_x, enc_x: tf.Tensor, enc_x_with_oov, enc_len, enc_ou
100188 use for teacher forcing during training
101189 :return:
102190 """
103- # batch_size = enc_x.shape[0]
104191 embed_dec_x = self .embedding (dec_x )
105192 mask = tf .sequence_mask (enc_len , maxlen = self .max_dec )
106- # print(mask.shape)
107193 hidden_states , state_h , state_c = self .lstm (embed_dec_x , (enc_h , enc_c ))
108194 attn_output = self .attention (hidden_states , enc_output , mask )
109195 generation_logits = tf .exp (self .generate_layer (attn_output ))
@@ -114,7 +200,7 @@ def call_one_pass(self, dec_x, enc_x: tf.Tensor, enc_x_with_oov, enc_len, enc_ou
114200 copy_logits = self .get_copy_score (enc_output , enc_x_with_oov , attn_output )
115201 total_logits = generation_logits + copy_logits
116202 total_prob = total_logits / tf .reduce_sum (total_logits , axis = 1 , keepdims = True )
117- total_prob = tf .math .log1p (total_prob )
203+ # total_prob = tf.math.log1p(total_prob)
118204 return total_prob , state_h , state_c
119205
120206 def call_auto_regressive (self , x , prev_state ):
@@ -125,23 +211,25 @@ def get_copy_score(self, src_output, x_with_oov: tf.Tensor, tgt_output):
125211 batch_size = x_with_oov .shape [0 ]
126212 total_vocab_size = self .vocab_size + self .max_oov_count
127213 dec_len = tgt_output .shape [1 ]
128- enc_len = src_output .shape [1 ]
129214
130215 tgt_output = self .permuate_1_2 (tgt_output )
131216 copy_score_in_seq = tf .matmul (tf .tanh (self .copy_layer (src_output )), tgt_output )
132217
133218 copy_score_in_seq = self .permuate_1_2 (copy_score_in_seq )
134219 copy_score_in_seq = tf .exp (copy_score_in_seq )
135220
136- i1 , i2 = tf .meshgrid (tf .range (batch_size , dtype = tf .int64 ),
137- tf .range (self .max_dec , dtype = tf .int64 ), indexing = "ij" )
138-
221+ batch_idx , src_idx , _ = tf .meshgrid (tf .range (batch_size , dtype = tf .int64 ),
222+ tf .range (dec_len , dtype = tf .int64 ),
223+ tf .range (self .max_dec , dtype = tf .int64 ), indexing = "ij" )
224+ # batch_idx = tf.transpose(tf.broadcast_to(tf.range(batch_size, dtype=tf.int64),
225+ # [self.max_dec, dec_len, batch_size]))
226+ # src_idx = tf.broadcast_to(tf.range(dec_len, dtype=tf.int64), [batch_size, dec_len])
227+ # src_idx = tf.repeat(tf.expand_dims(src_idx, axis=2), repeats=self.max_dec, axis=2)
228+ x_with_oov = tf .repeat (tf .expand_dims (x_with_oov , axis = 1 ), repeats = dec_len , axis = 1 )
139229 # Create final indices
140- idx = tf .stack ([i1 , i2 , x_with_oov ], axis = - 1 )
141- idx = tf .expand_dims (idx , axis = 1 )
142- idx = tf .repeat (idx , repeats = dec_len , axis = 1 )
230+ score_idx = tf .stack ([batch_idx , src_idx , x_with_oov ], axis = - 1 )
143231 # Output shape
144232 to_shape = [batch_size , dec_len , total_vocab_size ]
145233 # Get scattered tensor
146- copy_score_in_vocab = tf .scatter_nd (idx , copy_score_in_seq , to_shape )
234+ copy_score_in_vocab = tf .scatter_nd (score_idx , copy_score_in_seq , to_shape )
147235 return copy_score_in_vocab
0 commit comments