Skip to content

Commit ac1c338

Browse files
fix padding error
1 parent 0bca198 commit ac1c338

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

deep_keyphrase/copy_rnn/predict_tf.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,9 @@ def generate_input(self, tokens):
102102
sent_len = len(token_ids)
103103

104104
if len(token_ids) < self.max_src_len:
105-
token_ids.extend([PAD_WORD] * (self.max_src_len - len(token_ids)))
106-
token_ids_with_oov.extend([PAD_WORD] * (self.max_src_len - len(token_ids)))
105+
pad_tokens = [self.vocab2id[PAD_WORD]] * (self.max_src_len - len(token_ids))
106+
token_ids.extend(pad_tokens)
107+
token_ids_with_oov.extend(pad_tokens)
107108
elif len(token_ids) > self.max_src_len:
108109
token_ids = token_ids[:self.max_src_len]
109110
token_ids_with_oov = token_ids_with_oov[:self.max_src_len]

0 commit comments

Comments
 (0)