Skip to content

Commit 4539a2c

Browse files
authored
feat(array): add batch_size (#182)
1 parent a4d1d4e commit 4539a2c

File tree

6 files changed

+101
-72
lines changed

6 files changed

+101
-72
lines changed

docarray/array/mixins/io/binary.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,7 @@ def _load_binary_stream(
9393

9494
from .... import Document
9595

96-
if _show_progress:
97-
from rich.progress import track as _track
98-
99-
track = lambda x: _track(x, description='Deserializing')
100-
else:
101-
track = lambda x: x
96+
from rich.progress import track
10297

10398
with file_ctx as f:
10499
version_numdocs_lendoc0 = f.read(9)
@@ -107,7 +102,9 @@ def _load_binary_stream(
107102
# 8 bytes (uint64)
108103
num_docs = int.from_bytes(version_numdocs_lendoc0[1:9], 'big', signed=False)
109104

110-
for _ in track(range(num_docs)):
105+
for _ in track(
106+
range(num_docs), description='Deserializing', disable=not _show_progress
107+
):
111108
# 4 bytes (uint32)
112109
len_current_doc_in_bytes = int.from_bytes(
113110
f.read(4), 'big', signed=False
@@ -155,18 +152,14 @@ def _load_binary_all(
155152
version = int.from_bytes(d[0:1], 'big', signed=False)
156153
# 8 bytes (uint64)
157154
num_docs = int.from_bytes(d[1:9], 'big', signed=False)
158-
if show_progress:
159-
from rich.progress import track as _track
160155

161-
track = lambda x: _track(x, description='Deserializing')
162-
else:
163-
track = lambda x: x
156+
from rich.progress import track
164157

165158
# this 9 is version + num_docs bytes used
166159
start_pos = 9
167160
docs = []
168161

169-
for _ in track(range(num_docs)):
162+
for _ in track(range(num_docs), disable=not show_progress):
170163
# 4 bytes (uint32)
171164
len_current_doc_in_bytes = int.from_bytes(
172165
d[start_pos : start_pos + 4], 'big', signed=False
@@ -280,12 +273,6 @@ def to_bytes(
280273
f.write(pickle.dumps(self))
281274
elif protocol in ('pickle', 'protobuf'):
282275
# Binary format for streaming case
283-
if _show_progress:
284-
from rich.progress import track as _track
285-
286-
track = lambda x: _track(x, description='Serializing')
287-
else:
288-
track = lambda x: x
289276

290277
# V1 DocArray streaming serialization format
291278
# | 1 byte | 8 bytes | 4 bytes | variable | 4 bytes | variable ...
@@ -296,7 +283,11 @@ def to_bytes(
296283
num_docs_as_bytes = len(self).to_bytes(8, 'big', signed=False)
297284
f.write(version_byte + num_docs_as_bytes)
298285

299-
for d in track(self):
286+
from rich.progress import track
287+
288+
for d in track(
289+
self, description='Serializing', disable=not _show_progress
290+
):
300291
# 4 bytes (uint32)
301292
doc_as_bytes = d.to_bytes(protocol=protocol, compress=compress)
302293

docarray/array/mixins/parallel.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -87,15 +87,12 @@ def map(
8787
if _is_lambda_or_local_function(func) and backend == 'process':
8888
func = _globalize_lambda_function(func)
8989

90-
if show_progress:
91-
from rich.progress import track as _track
92-
93-
track = lambda x: _track(x, total=len(self))
94-
else:
95-
track = lambda x: x
90+
from rich.progress import track
9691

9792
with _get_pool(backend, num_worker) as p:
98-
for x in track(p.imap(func, self)):
93+
for x in track(
94+
p.imap(func, self), total=len(self), disable=not show_progress
95+
):
9996
yield x
10097

10198
@overload
@@ -184,16 +181,13 @@ def map_batch(
184181
if _is_lambda_or_local_function(func) and backend == 'process':
185182
func = _globalize_lambda_function(func)
186183

187-
if show_progress:
188-
from rich.progress import track as _track
189-
190-
track = lambda x: _track(x, total=ceil(len(self) / batch_size))
191-
else:
192-
track = lambda x: x
184+
from rich.progress import track
193185

194186
with _get_pool(backend, num_worker) as p:
195187
for x in track(
196-
p.imap(func, self.batch(batch_size=batch_size, shuffle=shuffle))
188+
p.imap(func, self.batch(batch_size=batch_size, shuffle=shuffle)),
189+
total=ceil(len(self) / batch_size),
190+
disable=not show_progress,
197191
):
198192
yield x
199193

docarray/array/mixins/plot.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
class PlotMixin:
16-
"""Helper functions for plotting the arrays. """
16+
"""Helper functions for plotting the arrays."""
1717

1818
def _ipython_display_(self):
1919
"""Displays the object in IPython as a side effect"""
@@ -304,6 +304,7 @@ def plot_image_sprites(
304304
channel_axis: int = -1,
305305
image_source: str = 'tensor',
306306
skip_empty: bool = False,
307+
show_progress: bool = False,
307308
) -> None:
308309
"""Generate a sprite image for all image tensors in this DocumentArray-like object.
309310
@@ -316,6 +317,7 @@ def plot_image_sprites(
316317
:param channel_axis: the axis id of the color channel, ``-1`` indicates the color channel info at the last axis
317318
:param image_source: specify where the image comes from, can be ``uri`` or ``tensor``. empty tensor will fallback to uri
318319
:param skip_empty: skip Documents that have neither .uri nor .tensor.
320+
:param show_progress: whether to show a progress bar.
319321
"""
320322
if not self:
321323
raise ValueError(f'{self!r} is empty')
@@ -335,8 +337,11 @@ def plot_image_sprites(
335337
[img_size * img_per_row, img_size * img_per_row, 3], dtype='uint8'
336338
)
337339
img_id = 0
340+
341+
from rich.progress import track
342+
338343
try:
339-
for d in self:
344+
for d in track(self, description='Plotting', disable=not show_progress):
340345

341346
if not d.uri and d.tensor is None:
342347
if skip_empty:

docarray/array/mixins/post.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1-
from typing import TYPE_CHECKING
1+
from typing import TYPE_CHECKING, Optional
22

33
if TYPE_CHECKING:
44
from ... import DocumentArray
55

66

77
class PostMixin:
8-
"""Helper functions for posting DocumentArray to Jina Flow. """
8+
"""Helper functions for posting DocumentArray to Jina Flow."""
99

10-
def post(self, host: str, show_progress: bool = False) -> 'DocumentArray':
10+
def post(
11+
self, host: str, show_progress: bool = False, batch_size: Optional[int] = None
12+
) -> 'DocumentArray':
1113
"""Posting itself to a remote Flow/Sandbox and get the modified DocumentArray back
1214
1315
:param host: a host string. Can be one of the following:
@@ -19,9 +21,13 @@ def post(self, host: str, show_progress: bool = False) -> 'DocumentArray':
1921
- `jinahub+sandbox://Hello/endpoint`
2022
2123
:param show_progress: whether to show a progress bar
24+
:param batch_size: number of Documents sent in each request; defaults to sending all Documents in a single request
2225
:return: the new DocumentArray returned from remote
2326
"""
2427

28+
if not self:
29+
return
30+
2531
from urllib.parse import urlparse
2632

2733
r = urlparse(host)
@@ -32,20 +38,28 @@ def post(self, host: str, show_progress: bool = False) -> 'DocumentArray':
3238
._replace(path='')
3339
.geturl()
3440
)
41+
batch_size = batch_size or len(self)
3542

3643
if r.scheme.startswith('jinahub'):
3744
from jina import Flow
3845

3946
f = Flow(quiet=True).add(uses=standardized_host)
4047
with f:
41-
return f.post(_on, inputs=self, show_progress=show_progress)
48+
return f.post(
49+
_on,
50+
inputs=self,
51+
show_progress=show_progress,
52+
request_size=batch_size,
53+
)
4254
elif r.scheme in ('grpc', 'http', 'websocket'):
4355
if _port is None:
4456
raise ValueError(f'can not determine port from {host}')
4557

4658
from jina import Client
4759

4860
c = Client(host=r.hostname, port=_port, protocol=r.scheme)
49-
return c.post(_on, inputs=self, show_progress=show_progress)
61+
return c.post(
62+
_on, inputs=self, show_progress=show_progress, request_size=batch_size
63+
)
5064
else:
5165
raise ValueError(f'unsupported scheme: {r.scheme}')

docarray/helper.py

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -178,46 +178,70 @@ def get_full_version() -> Dict:
178178
}
179179

180180

181+
assigned_ports = set()
182+
unassigned_ports = []
183+
DEFAULT_MIN_PORT = 49153
184+
MAX_PORT = 65535
185+
186+
187+
def reset_ports():
188+
def _get_unassigned_ports():
189+
# if we are running out of ports, lower default minimum port
190+
if MAX_PORT - DEFAULT_MIN_PORT - len(assigned_ports) < 100:
191+
min_port = int(os.environ.get('JINA_RANDOM_PORT_MIN', '16384'))
192+
else:
193+
min_port = int(
194+
os.environ.get('JINA_RANDOM_PORT_MIN', str(DEFAULT_MIN_PORT))
195+
)
196+
max_port = int(os.environ.get('JINA_RANDOM_PORT_MAX', str(MAX_PORT)))
197+
return set(range(min_port, max_port + 1)) - set(assigned_ports)
198+
199+
unassigned_ports.clear()
200+
assigned_ports.clear()
201+
unassigned_ports.extend(_get_unassigned_ports())
202+
random.shuffle(unassigned_ports)
203+
204+
181205
def random_port() -> Optional[int]:
182206
"""
183-
Get a random available port number from '49153' to '65535'.
207+
Get a random available port number.
184208
185209
:return: A random port.
186210
"""
187211

188-
import threading
189-
import multiprocessing
190-
from contextlib import closing
191-
import socket
192-
193-
def _get_port(port=0):
194-
with multiprocessing.Lock():
195-
with threading.Lock():
196-
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
197-
try:
198-
s.bind(('', port))
199-
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
200-
return s.getsockname()[1]
201-
except OSError:
202-
pass
203-
204-
_port = None
205-
if 'JINA_RANDOM_PORT_MIN' in os.environ or 'JINA_RANDOM_PORT_MAX' in os.environ:
206-
min_port = int(os.environ.get('JINA_RANDOM_PORT_MIN', '49153'))
207-
max_port = int(os.environ.get('JINA_RANDOM_PORT_MAX', '65535'))
208-
all_ports = list(range(min_port, max_port + 1))
209-
random.shuffle(all_ports)
210-
for _port in all_ports:
211-
if _get_port(_port) is not None:
212+
def _random_port():
213+
import socket
214+
215+
def _check_bind(port):
216+
with socket.socket() as s:
217+
try:
218+
s.bind(('', port))
219+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
220+
return port
221+
except OSError:
222+
return None
223+
224+
_port = None
225+
if len(unassigned_ports) == 0:
226+
reset_ports()
227+
for idx, _port in enumerate(unassigned_ports):
228+
if _check_bind(_port) is not None:
212229
break
213230
else:
214231
raise OSError(
215-
f'can not find an available port between [{min_port}, {max_port}].'
232+
f'can not find an available port in {len(unassigned_ports)} unassigned ports, assigned already {len(assigned_ports)} ports'
216233
)
217-
else:
218-
_port = _get_port()
234+
int_port = int(_port)
235+
unassigned_ports.pop(idx)
236+
assigned_ports.add(int_port)
237+
return int_port
219238

220-
return int(_port)
239+
try:
240+
return _random_port()
241+
except OSError:
242+
assigned_ports.clear()
243+
unassigned_ports.clear()
244+
return _random_port()
221245

222246

223247
class cached_property:

tests/unit/array/mixins/test_post.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
],
1818
)
1919
@pytest.mark.parametrize('show_pbar', [True, False])
20-
def test_post_to_a_flow(show_pbar, conn_config):
20+
@pytest.mark.parametrize('batch_size', [None, 1, 10])
21+
def test_post_to_a_flow(show_pbar, conn_config, batch_size):
2122
from jina import Flow
2223

2324
def start_flow(stop_event, **kwargs):
@@ -41,7 +42,7 @@ def start_flow(stop_event, **kwargs):
4142

4243
da = DocumentArray.empty(100)
4344
try:
44-
da.post(conn_config[1].replace('$port', str(p)))
45+
da.post(conn_config[1].replace('$port', str(p)), batch_size=batch_size)
4546
except:
4647
raise
4748
finally:

0 commit comments

Comments
 (0)