Skip to content

Commit 253c41b

Browse files
fix batch_size type
1 parent 7a1f81b commit 253c41b

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

deep_keyphrase/copy_rnn/predict_tf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ def eval_predict(self, batch, model=None, delimiter=None):
2929
def predict(self, text_list):
3030
x_batch, x_oov_batch, sent_len_batch, oov_list_batch = self.generate_input_batch(text_list)
3131
batch_size = len(x_batch)
32-
result_tensor = self.model.beam_search(x_batch, x_oov_batch, sent_len_batch, [batch_size])
32+
batch_size_np = np.array([batch_size], dtype=np.long)
33+
result_tensor = self.model.beam_search(x_batch, x_oov_batch, sent_len_batch, batch_size_np)
3334
result_np = result_tensor.numpy()
3435
return self.__idx2result_beam('', oov_list_batch, result_np)
3536

0 commit comments

Comments
 (0)