Skip to content

Commit 629a4fc

Browse files
adjust predictor call
1 parent: 3ee5f3b · commit: 629a4fc

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

deep_keyphrase/copy_rnn/predict.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
8 8
from deep_keyphrase.copy_rnn.beam_search import BeamSearch
9 9
from deep_keyphrase.dataloader import KeyphraseDataLoader, RAW_BATCH, TOKENS, INFERENCE_MODE, EVAL_MODE
10 10
from deep_keyphrase.utils.vocab_loader import load_vocab
11-
from deep_keyphrase.utils.constants import BOS_WORD
11+
from deep_keyphrase.utils.constants import BOS_WORD, UNK_WORD
12 12
from deep_keyphrase.utils.tokenizer import token_char_tokenize
13 13

14 14

@@ -47,6 +47,7 @@ def __init__(self, model_info, vocab_info, beam_size, max_target_len, max_src_le
47 47
max_target_len=self.max_target_len,
48 48
id2vocab=self.id2vocab,
49 49
bos_idx=self.vocab2id[BOS_WORD],
50+
unk_idx=self.vocab2id[UNK_WORD],
50 51
args=self.config)
51 52
self.pred_base_config = {'max_oov_count': self.config.max_oov_count,
52 53
'max_src_len': self.max_src_len,
@@ -74,7 +75,7 @@ def predict(self, text_list, batch_size=10, delimiter=None, tokenized=False):
74 75
text_list = [{TOKENS: i} for i in text_list]
75 76
else:
76 77
text_list = [{TOKENS: token_char_tokenize(i)} for i in text_list]
77-
args = Munch({'batch_size': batch_size, **self.pred_base_config})
78+
args = Munch({'batch_size': batch_size, **self.config._asdict(), **self.pred_base_config})
78 79
loader = KeyphraseDataLoader(data_source=text_list,
79 80
vocab2id=self.vocab2id,
80 81
mode=INFERENCE_MODE,
@@ -106,6 +107,7 @@ def eval_predict(self, src_filename, dest_filename, args,
106 107
max_target_len=self.max_target_len,
107 108
id2vocab=self.id2vocab,
108 109
bos_idx=self.vocab2id[BOS_WORD],
110+
unk_idx=self.vocab2id[UNK_WORD],
109 111
args=self.config)
110 112

111 113
for batch in loader:

0 commit comments

Comments (0)