
Commit 0bca198

remove redundant code
1 parent 0ae1e77 commit 0bca198

File tree: 1 file changed (+35, -14)

deep_keyphrase/dataloader.py

Lines changed: 35 additions & 14 deletions
@@ -1,11 +1,9 @@
 # -*- coding: UTF-8 -*-
 import random
-import time
 import traceback
 import sys
 import numpy as np
-import torch
-import torch.multiprocessing as multiprocessing
+import multiprocessing
 from pysenal import read_jsonline_lazy, get_chunk, read_jsonline
 from deep_keyphrase.utils.constants import *
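With torch gone from the module top level, the loader's queues only ever carry plain Python/numpy payloads, so the stdlib multiprocessing module is enough; torch is imported later, inside batch2tensor, only when that backend is selected. A minimal illustration of the framework-free queue payload (local names, nothing from the repository):

import multiprocessing

import numpy as np

q = multiprocessing.Queue()
# a batch stays a dict of numpy arrays until batch2tensor converts it
q.put({'tokens': np.zeros((4, 10), dtype=np.int64)})
print(q.get()['tokens'].shape)  # (4, 10)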
@@ -44,6 +42,7 @@ def __init__(self, data_source, vocab2id, mode, args):
         self.max_oov_count = args.max_oov_count
         self.max_target_len = args.max_target_len
         self.mode = mode
+        self.fix_batch_size = args.fix_batch_size
         self.prefetch = args.prefetch
         self.lazy_loading = args.lazy_loading
         self.shuffle = args.shuffle
@@ -104,6 +103,8 @@ def __init__(self, loader):
         self.token_field = loader.token_field
         self.keyphrases_field = loader.keyphrases_field
         self.lazy_loading = loader.lazy_loading
+        self.backend = loader.args.backend
+        self.fix_batch_size = loader.fix_batch_size
         self.num_workers = multiprocessing.cpu_count() // 2 or 1

         if self.loader.mode == TRAIN_MODE:
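self.backend comes from loader.args.backend and self.fix_batch_size from the loader, so the surrounding training script presumably defines both options. A hypothetical argparse wiring; the flag names simply mirror the attribute names and are an assumption, not part of this diff:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--backend', choices=['torch', 'tf'], default='torch',
                    help='framework used by batch2tensor to wrap numpy arrays')
parser.add_argument('--fix_batch_size', action='store_true',
                    help='emit only full batches of exactly batch_size items')
args = parser.parse_args(['--backend', 'tf', '--fix_batch_size'])
print(args.backend, args.fix_batch_size)  # tf True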
@@ -223,7 +224,7 @@ def __prefetch(self):
                 self.input_queue.put(batch)
                 self._batch_count_in_output_queue += 1
         else:
-            for _ in range(self.num_workers * 5):
+            for _ in range(self.num_workers):
                 try:
                     item_chunk = next(self._data)
                 except StopIteration:
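Dropping the multiplier means each prefetch round enqueues at most one raw chunk per worker instead of five, bounding queue memory at the cost of some read-ahead. A hedged stand-in for that feed step (local names and a single-process queue, purely for illustration):

import queue

def feed_round(input_queue, data_iter, num_workers):
    # push at most one raw chunk per worker this round
    for _ in range(num_workers):
        try:
            chunk = next(data_iter)
        except StopIteration:
            break
        input_queue.put(chunk)

q = queue.Queue()
feed_round(q, iter([[1, 2], [3, 4], [5, 6]]), num_workers=2)
print(q.qsize())  # 2 -- the third chunk waits for the next round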
@@ -273,16 +274,31 @@ def get_batches_raw(self, item_chunk, batch):
         batches = []

         for item in item_chunk:
-            if batch and len(batch) > self.batch_size:
-                for sliced_batch in get_chunk(batch, self.batch_size):
-                    batches.append(sliced_batch)
-                batch = []
-            flatten_items = self.flatten_raw_item(item)
-            if batch and len(batch) + len(flatten_items) > self.batch_size:
-                batches.append(batch)
-                batch = flatten_items
-            else:
+            if self.fix_batch_size:
+                if batch and len(batch) > self.batch_size:
+                    tail_count = len(batch) % self.batch_size
+                    if tail_count:
+                        batch_chunk = batch[:-tail_count]
+                        batch = batch[-tail_count:]
+                    else:
+                        batch_chunk = batch
+                        batch = []
+                    for sliced_batch in get_chunk(batch_chunk, self.batch_size):
+                        batches.append(sliced_batch)
+
+                flatten_items = self.flatten_raw_item(item)
                 batch.extend(flatten_items)
+            else:
+                if batch and len(batch) > self.batch_size:
+                    for sliced_batch in get_chunk(batch, self.batch_size):
+                        batches.append(sliced_batch)
+                    batch = []
+                flatten_items = self.flatten_raw_item(item)
+                if batch and len(batch) + len(flatten_items) > self.batch_size:
+                    batches.append(batch)
+                    batch = flatten_items
+                else:
+                    batch.extend(flatten_items)
         # batches = self.reorder_batch_list(batches)
         return batches, batch
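The new fix_batch_size branch keeps every emitted batch at exactly batch_size items and carries the remainder (len(batch) % batch_size items) forward into the next round, instead of flushing a short tail batch. A self-contained sketch of that carry logic; get_chunk is replaced here by a local equivalent of pysenal's chunker so the snippet runs on its own:

def get_chunk(items, size):
    # local stand-in for pysenal.get_chunk: consecutive slices of `size`
    for start in range(0, len(items), size):
        yield items[start:start + size]

def drain_full_batches(batch, batch_size):
    """Emit only full batches; carry the tail (len < batch_size) forward."""
    tail_count = len(batch) % batch_size
    if tail_count:
        batch_chunk, batch = batch[:-tail_count], batch[-tail_count:]
    else:
        batch_chunk, batch = batch, []
    return list(get_chunk(batch_chunk, batch_size)), batch

batches, carry = drain_full_batches(list(range(10)), batch_size=4)
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(carry)    # [8, 9] -- held for the next round instead of emitted short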

@@ -333,7 +349,12 @@ def batch2tensor(self, batch):
         new_batch = {}
         for key, val in batch.items():
             if isinstance(val, np.ndarray):
-                new_batch[key] = torch.as_tensor(val)
+                if self.backend == 'torch':
+                    import torch
+                    new_batch[key] = torch.as_tensor(val)
+                elif self.backend == 'tf':
+                    import tensorflow as tf
+                    new_batch[key] = tf.constant(val)
             else:
                 new_batch[key] = val
         return new_batch
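Deferring the framework import into batch2tensor means torch (or tensorflow) is only required when that backend is actually selected. A minimal standalone version of the dispatch, assuming at least one of the two packages is installed:

import numpy as np

def to_backend(val, backend='torch'):
    # deferred imports: the chosen framework is only loaded on first use
    if backend == 'torch':
        import torch
        return torch.as_tensor(val)
    if backend == 'tf':
        import tensorflow as tf
        return tf.constant(val)
    return val  # fallback added here for safety; the diff has no such branch

print(to_backend(np.arange(6).reshape(2, 3)).shape)  # torch.Size([2, 3])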
