Commit 4e76108

implement CopyRNN, add beam_search, add bidirectional
1 parent 9934be2 commit 4e76108

File tree

1 file changed: +113 -25 lines

deep_keyphrase/copy_rnn/model_tf.py

Lines changed: 113 additions & 25 deletions
@@ -1,14 +1,29 @@
 # -*- coding: UTF-8 -*-
 import argparse
 import tensorflow as tf
+from ..dataloader import UNK_WORD, BOS_WORD
+
+
+def mask_fill(t, mask, num):
+    """
+    :param t: input tensor
+    :param mask: mask indicator, True to keep a value and False to mask it
+    :param num: value to fill in at masked positions
+    :return: tensor with masked positions replaced by num
+    """
+    t_dtype = t.dtype
+    mask = tf.cast(mask, dtype=t_dtype)
+    neg_mask = 1 - mask
+    filled_t = t * mask + neg_mask * num
+    return filled_t
 
 
 class Attention(tf.keras.layers.Layer):
-    def __init__(self, encoder_dim, decoder_dim, dec_len, score_mode='general'):
+    def __init__(self, encoder_dim, decoder_dim, score_mode='general'):
         super().__init__()
         self.encoder_dim = encoder_dim
         self.decoder_dim = decoder_dim
-        self.dec_len = dec_len
         self.score_mode = score_mode
         self.permuate_1_2 = tf.keras.layers.Permute((2, 1))
         if self.score_mode == 'general':
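
The new mask_fill helper keeps values where the mask is True and writes num where it is False, via an arithmetic blend rather than tf.where. A minimal standalone sketch of its behaviour (the helper is restated so the snippet runs on its own; tensor values are illustrative):

import tensorflow as tf

def mask_fill(t, mask, num):
    # True positions keep t's value; False positions become num.
    mask = tf.cast(mask, dtype=t.dtype)
    return t * mask + (1 - mask) * num

t = tf.constant([[1.0, 2.0, 3.0]])
keep = tf.constant([[True, True, False]])
print(mask_fill(t, keep, -9.0).numpy())  # [[ 1.  2. -9.]]

Because kept positions are computed as the original value plus 0 * num, a literal -inf fill value yields NaN there (0 * inf is undefined in IEEE arithmetic), so a large finite negative sentinel is the safer choice when the result feeds a softmax.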
@@ -25,8 +40,8 @@ def score(self, query, key, mask):
         elif self.score_mode == 'dot':
             attn_weights = tf.matmul(query, self.permuate_1_2(key))
 
-        mask = tf.repeat(tf.expand_dims(mask, 1), repeats=self.dec_len, axis=1)
-        attn_weights = tf.cast(mask, tf.float32) * attn_weights
+        mask = tf.repeat(tf.expand_dims(mask, 1), repeats=query.shape[1], axis=1)
+        attn_weights = mask_fill(attn_weights, mask, float('-inf') / 2)
         attn_weights = tf.nn.softmax(attn_weights, axis=2)
         return attn_weights
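
The switch from multiplying the attention weights by the mask to filling masked slots before the softmax matters because a zeroed score still gets exp(0) = 1 worth of unnormalized weight. A small illustration with made-up scores, using a finite -1e9 sentinel in place of the commit's float('-inf') / 2:

import tensorflow as tf

scores = tf.constant([[2.0, 1.0, 5.0]])
mask = tf.constant([[True, True, False]])  # last position is padding

# Multiplicative masking: the padded score becomes 0, but softmax
# still gives that slot exp(0) worth of probability mass.
zeroed = scores * tf.cast(mask, tf.float32)
print(tf.nn.softmax(zeroed).numpy())   # ~[[0.665 0.245 0.090]]

# Fill-style masking: a very negative score drives the slot to ~0.
filled = tf.where(mask, scores, tf.fill(tf.shape(scores), -1e9))
print(tf.nn.softmax(filled).numpy())   # ~[[0.731 0.269 0.   ]]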

@@ -38,20 +53,86 @@ def call(self, decoder_output, encoder_output, enc_mask):
 
 
 class CopyRnnTF(tf.keras.Model):
-    def __init__(self, args: argparse.Namespace, vocab_size):
+    def __init__(self, args: argparse.Namespace, vocab2id):
         super().__init__()
-        self.embedding = tf.keras.layers.Embedding(vocab_size, args.embed_dim)
+        self.args = args
+        self.vocab_size = len(vocab2id)
+        self.embedding = tf.keras.layers.Embedding(self.vocab_size, args.embed_dim)
         self.encoder = Encoder(args, self.embedding)
         self.decoder = Decoder(args, self.embedding)
+        self.max_target_len = self.args.max_target_len
+        self.total_vocab_size = self.vocab_size + args.max_oov_count
+        self.encoder2decoder_state = tf.keras.layers.Dense(args.decoder_hidden_size)
+        self.encoder2decoder_cell = tf.keras.layers.Dense(args.decoder_hidden_size)
+        self.beam_size = args.beam_size
+        self.unk_idx = vocab2id[UNK_WORD]
+        self.bos_idx = vocab2id[BOS_WORD]
 
     def call(self, x, x_with_oov, x_len, enc_output, dec_x, prev_h, prev_c):
         if enc_output is None:
             enc_output, prev_h, prev_c = self.encoder(x, x_len)
+            prev_h = self.encoder2decoder_state(prev_h)
+            prev_c = self.encoder2decoder_state(prev_c)
 
         probs, prev_h, prev_c = self.decoder(dec_x, x, x_with_oov, x_len, enc_output, prev_h, prev_c)
 
         return probs, enc_output, prev_h, prev_c
 
+    @tf.function
+    def beam_search(self, x, x_with_oov, x_len):
+        batch_size = x.shape[0]
+        beam_batch_size = self.beam_size * batch_size
+        prev_output_tokens = tf.ones([batch_size, 1], dtype=tf.int64) * self.bos_idx
+
+        probs, enc_output, prev_h, prev_c = self.call(x, x_with_oov, x_len, None,
+                                                      prev_output_tokens, None, None)
+        probs = tf.squeeze(probs, axis=1)
+        prev_best_probs, prev_best_index = tf.math.top_k(probs, k=self.beam_size)
+
+        prev_h = tf.repeat(prev_h, self.beam_size, axis=0)
+        prev_c = tf.repeat(prev_c, self.beam_size, axis=0)
+        enc_output = tf.repeat(enc_output, self.beam_size, axis=0)
+
+        result_sequences = prev_best_index
+
+        prev_best_index = mask_fill(prev_best_index, prev_best_index >= self.vocab_size, self.unk_idx)
+        prev_best_index = tf.reshape(prev_best_index, [beam_batch_size, -1])
+        x = tf.repeat(x, repeats=self.beam_size, axis=0)
+
+        x_with_oov = tf.repeat(x_with_oov, repeats=self.beam_size, axis=0)
+        x_len = tf.repeat(x_len, repeats=self.beam_size, axis=0)
+
+        for target_idx in range(1, self.max_target_len):
+            probs, enc_output, prev_h, prev_c = self.call(x, x_with_oov, x_len, enc_output,
+                                                          prev_best_index,
+                                                          prev_h, prev_c)
+            probs = tf.squeeze(probs, axis=1)
+            accumulated_probs = tf.reshape(prev_best_probs, [beam_batch_size, -1])
+            accumulated_probs = tf.repeat(accumulated_probs, repeats=self.total_vocab_size, axis=1)
+            accumulated_probs += probs
+            accumulated_probs = tf.reshape(accumulated_probs, [batch_size, -1])
+            beam_search_best_probs, top_token_index = tf.math.top_k(-accumulated_probs, self.beam_size)
+
+            select_idx_factor = tf.range(0, batch_size) * self.beam_size
+            select_idx_factor = tf.repeat(tf.expand_dims(select_idx_factor, axis=1), self.beam_size, axis=1)
+            state_select_idx = tf.reshape(top_token_index, [beam_batch_size]) // probs.shape[1]
+            state_select_idx += tf.reshape(select_idx_factor, [beam_batch_size])
+
+            prev_best_index = top_token_index % probs.shape[1]
+            prev_h = tf.gather(prev_h, state_select_idx, axis=0)
+            prev_c = tf.gather(prev_c, state_select_idx, axis=0)
+
+            result_sequences = tf.reshape(result_sequences, [beam_batch_size, -1])
+            result_sequences = tf.gather(result_sequences, state_select_idx, axis=0)
+            result_sequences = tf.reshape(result_sequences, [batch_size, self.beam_size, -1])
+
+            result_sequences = tf.concat([result_sequences, tf.expand_dims(prev_best_index, axis=2)], axis=2)
+
+            prev_best_index = tf.reshape(prev_best_index, [beam_batch_size, -1])
+            prev_best_index = mask_fill(prev_best_index, prev_best_index >= self.vocab_size, self.unk_idx)
+
+        return result_sequences
 
 class Encoder(tf.keras.layers.Layer):
     def __init__(self, args, embedding):
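
A hedged sketch of how the new beam_search entry point might be driven at inference time. The argparse fields mirror what the constructor and layers read; the concrete values, the toy vocabulary, the input tensors, and the import paths are illustrative assumptions, not code from this commit:

import argparse
import tensorflow as tf
from deep_keyphrase.dataloader import UNK_WORD, BOS_WORD  # assumed path
from deep_keyphrase.copy_rnn.model_tf import CopyRnnTF    # assumed path

# Hypothetical hyper-parameters, sized for a toy run.
args = argparse.Namespace(
    embed_dim=100, encoder_hidden_size=50, decoder_hidden_size=100,
    max_src_len=6, max_target_len=4, max_oov_count=10,
    beam_size=3, bidirectional=True)

# Toy vocabulary; a real run would reuse the training vocab.
vocab2id = {UNK_WORD: 0, BOS_WORD: 1, 'neural': 2, 'keyphrase': 3}
model = CopyRnnTF(args, vocab2id)

batch_size = 2
x = tf.zeros([batch_size, args.max_src_len], dtype=tf.int64)
x_with_oov = x  # OOV tokens would get ids >= len(vocab2id) here
x_len = tf.constant([6, 4], dtype=tf.int64)

# beam_size candidate id sequences per batch item:
sequences = model.beam_search(x, x_with_oov, x_len)  # [2, 3, max_target_len]

The returned ids live in the OOV-extended vocabulary, so values >= len(vocab2id) point back at source words; inside the loop the code maps them to UNK before feeding them into the next decoding step.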
@@ -60,12 +141,20 @@ def __init__(self, args, embedding):
         self.embedding = embedding
         self.lstm = tf.keras.layers.LSTM(self.args.encoder_hidden_size,
                                          return_state=True, return_sequences=True)
+        if args.bidirectional:
+            self.lstm = tf.keras.layers.Bidirectional(self.lstm)
+
         self.max_dec = self.args.max_src_len
 
     def call(self, x, x_len):
         embed_x = self.embedding(x)
         mask = tf.sequence_mask(x_len, maxlen=self.max_dec)
-        lstm_output, state_h, state_c = self.lstm(embed_x, mask=mask)
+        if self.args.bidirectional:
+            lstm_output, state_fw_h, state_fw_c, state_bw_h, state_bw_c = self.lstm(embed_x, mask=mask)
+            state_h = tf.concat([state_fw_h, state_bw_h], axis=1)
+            state_c = tf.concat([state_fw_c, state_bw_c], axis=1)
+        else:
+            lstm_output, state_h, state_c = self.lstm(embed_x)
         return lstm_output, state_h, state_c
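
For reference, a standalone sketch (illustrative shapes) of why the bidirectional branch unpacks five return values: wrapping an LSTM built with return_state=True in Bidirectional returns the merged output sequence plus separate (h, c) states for the forward and backward directions, which the encoder then concatenates:

import tensorflow as tf

lstm = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(50, return_state=True, return_sequences=True))
x = tf.random.normal([2, 6, 100])  # [batch, src_len, embed_dim]

out, fw_h, fw_c, bw_h, bw_c = lstm(x)
state_h = tf.concat([fw_h, bw_h], axis=1)  # [2, 100]
state_c = tf.concat([fw_c, bw_c], axis=1)  # [2, 100]
print(out.shape, state_h.shape)  # (2, 6, 100) (2, 100)

The doubled state width is also why the Decoder below builds its Attention with encoder_hidden_size * 2 when bidirectional is set.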

@@ -80,17 +169,16 @@ def __init__(self, args, embedding):
         self.max_enc = self.args.max_src_len
         self.lstm = tf.keras.layers.LSTM(self.args.decoder_hidden_size,
                                          return_state=True, return_sequences=True)
-        self.attention = Attention(self.args.encoder_hidden_size,
-                                   self.args.decoder_hidden_size,
-                                   self.args.max_target_len)
+        if self.args.bidirectional:
+            enc_hidden_size = self.args.encoder_hidden_size * 2
+        else:
+            enc_hidden_size = self.args.encoder_hidden_size
+        self.attention = Attention(enc_hidden_size,
+                                   self.args.decoder_hidden_size)
         self.generate_layer = tf.keras.layers.Dense(self.vocab_size, use_bias=False)
         self.concat_layer = tf.keras.layers.Concatenate()
         self.copy_layer = tf.keras.layers.Dense(self.args.decoder_hidden_size)
         self.permuate_1_2 = tf.keras.layers.Permute((2, 1))
-        i1, i2 = tf.meshgrid(tf.range(self.args.batch_size, dtype=tf.int64),
-                             tf.range(self.args.max_target_len, dtype=tf.int64), indexing="ij")
-        self.i1 = tf.tile(i1[:, :, tf.newaxis], [1, 1, self.args.max_src_len])
-        self.i2 = tf.tile(i2[:, :, tf.newaxis], [1, 1, self.args.max_src_len])
 
     def call(self, dec_x, enc_x, enc_x_with_oov, enc_len, enc_output, enc_h, enc_c):
         return self.call_one_pass(dec_x, enc_x, enc_x_with_oov, enc_len, enc_output, enc_h, enc_c)
@@ -100,10 +188,8 @@ def call_one_pass(self, dec_x, enc_x: tf.Tensor, enc_x_with_oov, enc_len, enc_ou
         use for teacher forcing during training
         :return:
         """
-        # batch_size = enc_x.shape[0]
         embed_dec_x = self.embedding(dec_x)
         mask = tf.sequence_mask(enc_len, maxlen=self.max_dec)
-        # print(mask.shape)
         hidden_states, state_h, state_c = self.lstm(embed_dec_x, (enc_h, enc_c))
         attn_output = self.attention(hidden_states, enc_output, mask)
         generation_logits = tf.exp(self.generate_layer(attn_output))
@@ -114,7 +200,7 @@ def call_one_pass(self, dec_x, enc_x: tf.Tensor, enc_x_with_oov, enc_len, enc_ou
         copy_logits = self.get_copy_score(enc_output, enc_x_with_oov, attn_output)
         total_logits = generation_logits + copy_logits
         total_prob = total_logits / tf.reduce_sum(total_logits, axis=1, keepdims=True)
-        total_prob = tf.math.log1p(total_prob)
+        # total_prob = tf.math.log1p(total_prob)
         return total_prob, state_h, state_c
 
     def call_auto_regressive(self, x, prev_state):
@@ -125,23 +211,25 @@ def get_copy_score(self, src_output, x_with_oov: tf.Tensor, tgt_output):
         batch_size = x_with_oov.shape[0]
         total_vocab_size = self.vocab_size + self.max_oov_count
         dec_len = tgt_output.shape[1]
-        enc_len = src_output.shape[1]
 
         tgt_output = self.permuate_1_2(tgt_output)
         copy_score_in_seq = tf.matmul(tf.tanh(self.copy_layer(src_output)), tgt_output)
 
         copy_score_in_seq = self.permuate_1_2(copy_score_in_seq)
         copy_score_in_seq = tf.exp(copy_score_in_seq)
 
-        i1, i2 = tf.meshgrid(tf.range(batch_size, dtype=tf.int64),
-                             tf.range(self.max_dec, dtype=tf.int64), indexing="ij")
-
+        batch_idx, src_idx, _ = tf.meshgrid(tf.range(batch_size, dtype=tf.int64),
+                                            tf.range(dec_len, dtype=tf.int64),
+                                            tf.range(self.max_dec, dtype=tf.int64), indexing="ij")
+        # batch_idx = tf.transpose(tf.broadcast_to(tf.range(batch_size, dtype=tf.int64),
+        #                                          [self.max_dec, dec_len, batch_size]))
+        # src_idx = tf.broadcast_to(tf.range(dec_len, dtype=tf.int64), [batch_size, dec_len])
+        # src_idx = tf.repeat(tf.expand_dims(src_idx, axis=2), repeats=self.max_dec, axis=2)
+        x_with_oov = tf.repeat(tf.expand_dims(x_with_oov, axis=1), repeats=dec_len, axis=1)
         # Create final indices
-        idx = tf.stack([i1, i2, x_with_oov], axis=-1)
-        idx = tf.expand_dims(idx, axis=1)
-        idx = tf.repeat(idx, repeats=dec_len, axis=1)
+        score_idx = tf.stack([batch_idx, src_idx, x_with_oov], axis=-1)
         # Output shape
         to_shape = [batch_size, dec_len, total_vocab_size]
         # Get scattered tensor
-        copy_score_in_vocab = tf.scatter_nd(idx, copy_score_in_seq, to_shape)
+        copy_score_in_vocab = tf.scatter_nd(score_idx, copy_score_in_seq, to_shape)
         return copy_score_in_vocab
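
A compact illustration (toy sizes and made-up ids, not code from this commit) of the tf.scatter_nd step: per-source-position copy scores are scattered onto the OOV-extended vocabulary axis, and a token that occurs at several source positions has its scores summed into one slot:

import tensorflow as tf

batch_size, dec_len, src_len, total_vocab = 1, 1, 3, 6
copy_score_in_seq = tf.constant([[[0.5, 0.2, 0.3]]])     # scores per source position
x_with_oov = tf.constant([[[4, 1, 4]]], dtype=tf.int64)  # extended-vocab ids, tiled to dec_len

batch_idx, dec_idx, _ = tf.meshgrid(
    tf.range(batch_size, dtype=tf.int64),
    tf.range(dec_len, dtype=tf.int64),
    tf.range(src_len, dtype=tf.int64), indexing="ij")
score_idx = tf.stack([batch_idx, dec_idx, x_with_oov], axis=-1)

# Duplicate indices accumulate: id 4 receives 0.5 + 0.3 = 0.8.
scores = tf.scatter_nd(score_idx, copy_score_in_seq,
                       [batch_size, dec_len, total_vocab])
print(scores.numpy())  # [[[0.  0.2 0.  0.  0.8 0. ]]]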
