@@ -1,11 +1,9 @@
 # -*- coding: UTF-8 -*-
 import random
-import time
 import traceback
 import sys
 import numpy as np
-import torch
-import torch.multiprocessing as multiprocessing
+import multiprocessing
 from pysenal import read_jsonline_lazy, get_chunk, read_jsonline
 from deep_keyphrase.utils.constants import *
 
@@ -44,6 +42,7 @@ def __init__(self, data_source, vocab2id, mode, args): |
         self.max_oov_count = args.max_oov_count
         self.max_target_len = args.max_target_len
         self.mode = mode
+        self.fix_batch_size = args.fix_batch_size
         self.prefetch = args.prefetch
         self.lazy_loading = args.lazy_loading
         self.shuffle = args.shuffle
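For context, a minimal sketch of how the two new options could be wired into an argument parser. The flag names mirror the attributes read above (`args.fix_batch_size`, `args.backend`); the project's real parser lives elsewhere, so this wiring is an assumption:

```python
import argparse

# Hypothetical wiring; the repo's actual parser and defaults may differ.
parser = argparse.ArgumentParser()
parser.add_argument('-fix_batch_size', action='store_true',
                    help='only emit batches with exactly batch_size items')
parser.add_argument('-backend', type=str, default='torch',
                    choices=['torch', 'tf'],
                    help='framework used to convert numpy batches to tensors')
args = parser.parse_args(['-fix_batch_size'])
assert args.fix_batch_size and args.backend == 'torch'
```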
@@ -104,6 +103,8 @@ def __init__(self, loader): |
         self.token_field = loader.token_field
         self.keyphrases_field = loader.keyphrases_field
         self.lazy_loading = loader.lazy_loading
+        self.backend = loader.args.backend
+        self.fix_batch_size = loader.fix_batch_size
         self.num_workers = multiprocessing.cpu_count() // 2 or 1
 
         if self.loader.mode == TRAIN_MODE:
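The worker count uses a small idiom worth spelling out: `cpu_count() // 2` floors to 0 on a single-core machine, and since 0 is falsy, the trailing `or 1` falls back to one worker. A quick check:

```python
import multiprocessing

# `//` floors to 0 when cpu_count() == 1; `or 1` then picks the fallback,
# so the loader never ends up with zero workers.
num_workers = multiprocessing.cpu_count() // 2 or 1
assert num_workers >= 1
```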
@@ -223,7 +224,7 @@ def __prefetch(self): |
             self.input_queue.put(batch)
             self._batch_count_in_output_queue += 1
         else:
-            for _ in range(self.num_workers * 5):
+            for _ in range(self.num_workers):
                 try:
                     item_chunk = next(self._data)
                 except StopIteration:
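With the smaller bound, each prefetch round feeds at most one chunk per worker into the input queue, so the number of in-flight batches stays proportional to the worker count instead of five times it. A standalone sketch of that producer loop, where `data_iter` and `input_queue` are stand-ins for the loader's `self._data` iterator and its multiprocessing queue:

```python
def prefetch_round(data_iter, input_queue, num_workers):
    """Push at most one pending chunk per worker onto the input queue.

    Illustrative sketch only; names are assumptions, not the repo's API.
    """
    pushed = 0
    for _ in range(num_workers):
        try:
            item_chunk = next(data_iter)
        except StopIteration:
            break  # source exhausted; end this round early
        input_queue.put(item_chunk)
        pushed += 1
    return pushed
```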
@@ -273,16 +274,31 @@ def get_batches_raw(self, item_chunk, batch): |
         batches = []
 
         for item in item_chunk:
-            if batch and len(batch) > self.batch_size:
-                for sliced_batch in get_chunk(batch, self.batch_size):
-                    batches.append(sliced_batch)
-                batch = []
-            flatten_items = self.flatten_raw_item(item)
-            if batch and len(batch) + len(flatten_items) > self.batch_size:
-                batches.append(batch)
-                batch = flatten_items
-            else:
+            if self.fix_batch_size:
+                if batch and len(batch) > self.batch_size:
+                    tail_count = len(batch) % self.batch_size
+                    if tail_count:
+                        batch_chunk = batch[:-tail_count]
+                        batch = batch[-tail_count:]
+                    else:
+                        batch_chunk = batch
+                        batch = []
+                    for sliced_batch in get_chunk(batch_chunk, self.batch_size):
+                        batches.append(sliced_batch)
+
+                flatten_items = self.flatten_raw_item(item)
                 batch.extend(flatten_items)
+            else:
+                if batch and len(batch) > self.batch_size:
+                    for sliced_batch in get_chunk(batch, self.batch_size):
+                        batches.append(sliced_batch)
+                    batch = []
+                flatten_items = self.flatten_raw_item(item)
+                if batch and len(batch) + len(flatten_items) > self.batch_size:
+                    batches.append(batch)
+                    batch = flatten_items
+                else:
+                    batch.extend(flatten_items)
         # batches = self.reorder_batch_list(batches)
         return batches, batch
 
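The `fix_batch_size` branch only emits batches of exactly `batch_size` items: the tail (`len(batch) % batch_size` items) is carried over in `batch` rather than flushed as a short batch, which is the old behavior kept in the `else` branch. A self-contained sketch of that slicing step, with illustrative names:

```python
def slice_full_batches(batch, batch_size):
    """Split accumulated items into full batches, carrying the tail over.

    Standalone sketch of the fix_batch_size branch: only slices of
    exactly batch_size items are emitted; the remainder stays pending.
    """
    batches = []
    if len(batch) > batch_size:
        tail_count = len(batch) % batch_size
        if tail_count:
            batch_chunk, batch = batch[:-tail_count], batch[-tail_count:]
        else:
            batch_chunk, batch = batch, []
        for start in range(0, len(batch_chunk), batch_size):
            batches.append(batch_chunk[start:start + batch_size])
    return batches, batch


# 25 items with batch_size=10 -> two full batches, 5 items stay pending
full, pending = slice_full_batches(list(range(25)), 10)
assert [len(b) for b in full] == [10, 10] and len(pending) == 5
```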
@@ -333,7 +349,12 @@ def batch2tensor(self, batch): |
         new_batch = {}
         for key, val in batch.items():
             if isinstance(val, np.ndarray):
-                new_batch[key] = torch.as_tensor(val)
+                if self.backend == 'torch':
+                    import torch
+                    new_batch[key] = torch.as_tensor(val)
+                elif self.backend == 'tf':
+                    import tensorflow as tf
+                    new_batch[key] = tf.constant(val)
             else:
                 new_batch[key] = val
         return new_batch
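Deferring the `torch`/`tensorflow` imports into `batch2tensor` (together with dropping the module-level `import torch`) means only the selected backend has to be installed. A standalone sketch of the per-value dispatch:

```python
import numpy as np

def to_backend_tensor(val, backend='torch'):
    """Convert a numpy array to the configured backend's tensor type.

    Illustrative sketch of the dispatch above; imports happen lazily
    so the unused framework never needs to be importable.
    """
    if not isinstance(val, np.ndarray):
        return val  # non-array values pass through unchanged
    if backend == 'torch':
        import torch
        return torch.as_tensor(val)  # may share memory with the array
    elif backend == 'tf':
        import tensorflow as tf
        return tf.constant(val)  # copies the data into a tf.Tensor
    raise ValueError('unsupported backend: {}'.format(backend))
```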