Skip to content

Commit 3ee5f3b

Browse files
Fix OOV mapping: replace out-of-vocabulary token indices (>= vocab_size) with the UNK index during beam search
1 parent 4f26941 commit 3ee5f3b

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

deep_keyphrase/copy_rnn/beam_search.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,12 +5,14 @@
55

66

77
class BeamSearch(object):
8-
def __init__(self, model, beam_size, max_target_len, id2vocab, bos_idx, args):
8+
def __init__(self, model, beam_size, max_target_len, id2vocab, bos_idx, unk_idx, args):
99
self.model = model
1010
self.beam_size = beam_size
1111
self.id2vocab = id2vocab
12+
self.vocab_size = len(self.id2vocab)
1213
self.max_target_len = max_target_len
1314
self.bos_idx = bos_idx
15+
self.unk_idx = unk_idx
1416
self.target_hidden_size = args.target_hidden_size
1517

1618
def beam_search(self, src_dict, delimiter=None):
@@ -41,6 +43,9 @@ def beam_search(self, src_dict, delimiter=None):
4143
prev_hidden_state=hidden_state)
4244
decoder_prob, encoder_output_dict, decoder_state, hidden_state = model_output
4345
prev_best_probs, prev_best_index = torch.topk(decoder_prob, self.beam_size, 1)
46+
# map out-of-vocabulary token ids (index >= vocab_size) to unk_idx, in place
47+
oov_token_mask = prev_best_index >= self.vocab_size
48+
prev_best_index.masked_fill_(oov_token_mask, self.unk_idx)
4449
# B*b x TH
4550
prev_decoder_state = decoder_state.unsqueeze(1).repeat(1, self.beam_size, 1)
4651
prev_decoder_state = prev_decoder_state.view(beam_batch_size, -1)
@@ -88,6 +93,9 @@ def beam_search(self, src_dict, delimiter=None):
8893
prev_best_index = top_token_index % decoder_prob.size(1)
8994
prev_hidden_state[0] = prev_hidden_state[0].index_select(1, state_select_idx)
9095
prev_hidden_state[1] = prev_hidden_state[1].index_select(1, state_select_idx)
96+
# map out-of-vocabulary token ids (index >= vocab_size) to unk_idx, in place
97+
oov_token_mask = prev_best_index >= self.vocab_size
98+
prev_best_index.masked_fill_(oov_token_mask, self.unk_idx)
9199

92100
result_sequences = result_sequences.view(batch_size * self.beam_size, -1)
93101
result_sequences = result_sequences.index_select(0, state_select_idx)

0 commit comments

Comments
 (0)