 from deep_keyphrase.copy_rnn.beam_search import BeamSearch
 from deep_keyphrase.dataloader import KeyphraseDataLoader, RAW_BATCH, TOKENS, INFERENCE_MODE, EVAL_MODE
 from deep_keyphrase.utils.vocab_loader import load_vocab
-from deep_keyphrase.utils.constants import BOS_WORD
+from deep_keyphrase.utils.constants import BOS_WORD, UNK_WORD
 from deep_keyphrase.utils.tokenizer import token_char_tokenize

@@ -47,6 +47,7 @@ def __init__(self, model_info, vocab_info, beam_size, max_target_len, max_src_len,
                                   max_target_len=self.max_target_len,
                                   id2vocab=self.id2vocab,
                                   bos_idx=self.vocab2id[BOS_WORD],
+                                  unk_idx=self.vocab2id[UNK_WORD],
                                   args=self.config)
         self.pred_base_config = {'max_oov_count': self.config.max_oov_count,
                                  'max_src_len': self.max_src_len,
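
The new unk_idx argument gives BeamSearch direct access to the unknown-token id. A common reason to thread this through is to keep UNK out of generated keyphrases by masking its score before each top-k expansion. The helper below is a hypothetical sketch of that pattern (it is not part of BeamSearch itself), assuming the decoder emits a (batch, vocab) tensor of log-probabilities per step:

import torch

def mask_unk(log_probs: torch.Tensor, unk_idx: int) -> torch.Tensor:
    # Hypothetical helper: set the UNK column to -inf so top-k beam
    # expansion can never select the unknown token.
    masked = log_probs.clone()
    masked[:, unk_idx] = float('-inf')
    return masked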

@@ -74,7 +75,7 @@ def predict(self, text_list, batch_size=10, delimiter=None, tokenized=False):
             text_list = [{TOKENS: i} for i in text_list]
         else:
             text_list = [{TOKENS: token_char_tokenize(i)} for i in text_list]
-        args = Munch({'batch_size': batch_size, **self.pred_base_config})
+        args = Munch({'batch_size': batch_size, **self.config._asdict(), **self.pred_base_config})
         loader = KeyphraseDataLoader(data_source=text_list,
                                      vocab2id=self.vocab2id,
                                      mode=INFERENCE_MODE,
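
Merging self.config._asdict() into the Munch exposes every model-config field (presumably a namedtuple, given _asdict) on args at prediction time, not just the handful copied into pred_base_config. Because later ** expansions win in a dict literal, pred_base_config still overrides any same-named config field. A small sketch of that ordering, with made-up keys:

from munch import Munch

config = {'beam_size': 50, 'max_src_len': 1000}   # stand-in for self.config._asdict()
pred_base = {'max_src_len': 1500}                 # stand-in for self.pred_base_config

args = Munch({'batch_size': 10, **config, **pred_base})
assert args.max_src_len == 1500   # rightmost expansion wins
assert args.beam_size == 50       # config fields become attribute-accessible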

@@ -106,6 +107,7 @@ def eval_predict(self, src_filename, dest_filename, args,
                                   max_target_len=self.max_target_len,
                                   id2vocab=self.id2vocab,
                                   bos_idx=self.vocab2id[BOS_WORD],
+                                  unk_idx=self.vocab2id[UNK_WORD],
                                   args=self.config)

         for batch in loader:
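
For reference, a hedged end-to-end sketch of how the updated predictor might be called. The CopyRnnPredictor class name and the file paths are assumptions; the __init__ and predict signatures follow the hunk headers above:

from deep_keyphrase.copy_rnn.predict import CopyRnnPredictor

predictor = CopyRnnPredictor(model_info='copy_rnn.model',  # hypothetical checkpoint path
                             vocab_info='vocab.txt',       # hypothetical vocab path
                             beam_size=50,
                             max_target_len=8,
                             max_src_len=1500)
# With tokenized=False, predict() tokenizes raw strings via token_char_tokenize
keyphrases = predictor.predict(['automatic keyphrase generation with a copy mechanism'],
                               batch_size=10,
                               tokenized=False)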