Commit 4c27409

Author: anna-charlotte

feat: add apply_batch and _map_batch and tests

Signed-off-by: anna-charlotte <[email protected]>
1 parent: 51639db

File tree

3 files changed: +184, -14 lines
  docarray/utils/apply.py
  tests/benchmark_tests/test_apply.py
  tests/units/util/test_apply.py

docarray/utils/apply.py

Lines changed: 144 additions & 10 deletions
@@ -1,4 +1,5 @@
 from contextlib import nullcontext
+from math import ceil
 from multiprocessing.pool import Pool, ThreadPool
 from types import LambdaType
 from typing import Any, Callable, Generator, Optional, TypeVar, Union
@@ -20,14 +21,15 @@ def apply(
 ) -> T:
     """
     Apply `func` to every Document of the given DocumentArray while multiprocessing,
-    return itself after modification, without in-place changes.
+    return itself after modification.
 
     :param da: DocumentArray to apply function to
     :param func: a function that takes ab:class:`BaseDocument` as input and outputs
         a :class:`BaseDocument`.
     :param backend: `thread` for multi-threading and `process` for multi-processing.
-        Defaults to `thread`. In general, if `func` is IO-bound then `thread` is a
-        good choice. If `func` is CPU-bound, then you may use `process`.
+        Defaults to `thread`.
+        In general, if `func` is IO-bound then `thread` is a good choice.
+        On the other hand, if `func` is CPU-bound, then you may use `process`.
         In practice, you should try yourselves to figure out the best value.
         However, if you wish to modify the elements in-place, regardless of IO/CPU-bound,
         you should always use `thread` backend.
@@ -46,10 +48,9 @@ def apply(
 
     :return: DocumentArray with applied modifications
     """
-    da_new = da.__class_getitem__(item=da.document_type)()
     for i, doc in enumerate(_map(da, func, backend, num_worker, pool, show_progress)):
-        da_new.append(doc)
-    return da_new
+        da[i] = doc
+    return da
 
 
 def _map(
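
With this change, `apply` no longer builds a fresh array: it writes each mapped Document back by index and returns the original DocumentArray. A minimal sketch of the resulting contract (the `MyDoc` type and `upper` function below are hypothetical, for illustration only, assuming the v2-style `BaseDocument`/`DocumentArray` API used elsewhere in this commit):

    from docarray import BaseDocument, DocumentArray
    from docarray.utils.apply import apply


    class MyDoc(BaseDocument):  # hypothetical Document type, for illustration only
        text: str = ''


    def upper(doc: MyDoc) -> MyDoc:
        # apply() assigns the returned Document back via da[i] = doc
        doc.text = doc.text.upper()
        return doc


    da = DocumentArray[MyDoc]([MyDoc(text='hello') for _ in range(3)])
    da_out = apply(da=da, func=upper, backend='thread')
    assert da_out is da  # the same array comes back, modified in place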
@@ -65,12 +66,13 @@ def _map(
     yielding the results.
 
     :param da: DocumentArray to apply function to
-    :param func:a function that takes ab:class:`BaseDocument` as input and outputs
+    :param func: a function that takes a :class:`BaseDocument` as input and outputs
         a :class:`BaseDocument`. You can either modify elements in-place or return
-        new Documents.
+        new Documents (depending on `backend`).
     :param backend: `thread` for multi-threading and `process` for multi-processing.
-        Defaults to `thread`. In general, if `func` is IO-bound then `thread` is a
-        good choice. If `func` is CPU-bound, then you may use `process`.
+        Defaults to `thread`.
+        In general, if `func` is IO-bound then `thread` is a good choice.
+        On the other hand, if `func` is CPU-bound, then you may use `process`.
         In practice, you should try yourselves to figure out the best value.
         However, if you wish to modify the elements in-place, regardless of IO/CPU-bound,
         you should always use `thread` backend.
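
For context, `_map` is the generator that `apply` consumes: it lazily yields one processed Document per input, preserving input order (which `test_check_order` below relies on). A rough usage sketch, reusing the hypothetical `upper` and `da` from the sketch above and assuming the keyword defaults implied by `apply`'s call site:

    from docarray.utils.apply import _map

    # each processed Document is yielded as soon as it is ready, in order
    for doc in _map(da, upper, backend='thread'):
        print(doc.text)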
@@ -110,6 +112,138 @@ def _map(
             yield x
 
 
+def apply_batch(
+    da: T,
+    func: Callable[[T], T],
+    batch_size: int,
+    backend: str = 'thread',
+    num_worker: Optional[int] = None,
+    shuffle: bool = False,
+    pool: Optional[Union[Pool, ThreadPool]] = None,
+    show_progress: bool = False,
+) -> T:
+    """Batches itself into mini-batches, applies `func` to every mini-batch, and return itself after the modifications.
+
+    EXAMPLE USAGE
+
+    .. code-block:: python
+
+        from docarray import Document, DocumentArray
+
+        da = DocumentArray([Document(text='The cake is a lie') for _ in range(100)])
+
+
+        def func(doc):
+            da.texts = [t.upper() for t in da.texts]
+            return da
+
+
+        da.apply_batch(func, batch_size=10)
+        print(da.texts[:3])
+
+    .. code-block:: text
+
+        ['THE CAKE IS A LIE', 'THE CAKE IS A LIE', 'THE CAKE IS A LIE']
+
+    :param da: DocumentArray to apply function to
+    :param func: a function that takes a :class:`BaseDocument` as input and outputs
+        a :class:`BaseDocument`.
+    :param batch_size: size of each generated batch (except the last batch, which might
+        be smaller).
+    :param backend: `thread` for multi-threading and `process` for multi-processing.
+        Defaults to `thread`.
+        In general, if `func` is IO-bound then `thread` is a good choice.
+        On the other hand, if `func` is CPU-bound, then you may use `process`.
+        In practice, you should try yourselves to figure out the best value.
+        However, if you wish to modify the elements in-place, regardless of IO/CPU-bound,
+        you should always use `thread` backend.
+
+        .. warning::
+            When using `process` backend, your `func` should not modify elements in-place.
+            This is because the multiprocessing backend passes the variable via pickle
+            and works in another process.
+            The passed object and the original object do **not** share the same memory.
+
+    :param num_worker: the number of parallel workers. If not given, the number of CPUs
+        in the system will be used.
+    :param shuffle: If set, shuffle the Documents before dividing into minibatches.
+    :param pool: use an existing/external process or thread pool. If given, you will
+        be responsible for closing the pool.
+    :param show_progress: show a progress bar. Defaults to False.
+
+    :return DocumentArray after modifications
+    """
+    for i, batch in enumerate(
+        _map_batch(
+            da, func, batch_size, backend, num_worker, shuffle, pool, show_progress
+        )
+    ):
+        indices = [i for i in range(i * batch_size, (i + 1) * batch_size)]
+        da[indices] = batch
+    return da
+
+
+def _map_batch(
+    da: T,
+    func: Callable[[T], T],
+    batch_size: int,
+    backend: str = 'thread',
+    num_worker: Optional[int] = None,
+    shuffle: bool = False,
+    pool: Optional[Union[Pool, ThreadPool]] = None,
+    show_progress: bool = False,
+) -> Generator[T, None, None]:
+    """Return an iterator that applies function to every **minibatch** of iterable in parallel, yielding the results.
+    Each element in the returned iterator is :class:`DocumentArray`.
+
+    .. seealso::
+        - To process single element, please use :meth:`.map`;
+        - To return :class:`DocumentArray`, please use :meth:`.apply_batch`.
+
+    :param batch_size: Size of each generated batch (except the last one, which might be smaller).
+    :param shuffle: If set, shuffle the Documents before dividing into minibatches.
+    :param func: a function that takes :class:`DocumentArray` as input and outputs anything. You can either modify elements
+        in-place (only with `thread` backend) or work later on return elements.
+    :param backend: if to use multi-`process` or multi-`thread` as the parallelization backend. In general, if your
+        ``func`` is IO-bound then perhaps `thread` is good enough. If your ``func`` is CPU-bound then you may use `process`.
+        In practice, you should try yourselves to figure out the best value. However, if you wish to modify the elements
+        in-place, regardless of IO/CPU-bound, you should always use `thread` backend.
+
+        .. warning::
+            When using `process` backend, you should not expect ``func`` modify elements in-place. This is because
+            the multiprocessing backing pass the variable via pickle and work in another process. The passed object
+            and the original object do **not** share the same memory.
+
+    :param num_worker: the number of parallel workers. If not given, then the number of CPUs in the system will be used.
+    :param show_progress: show a progress bar
+    :param pool: use an existing/external pool. If given, `backend` is ignored and you will be responsible for closing the pool.
+
+    :yield: anything return from ``func``
+    """
+
+    if backend == 'process' and _is_lambda_or_partial_or_local_function(func):
+        raise ValueError(
+            f'Multiprocessing does not allow functions that are local, lambda or partial: {func}'
+        )
+
+    from rich.progress import track
+
+    ctx_p: Union[nullcontext, Union[Pool, ThreadPool]]
+    if pool:
+        p = pool
+        ctx_p = nullcontext()
+    else:
+        p = _get_pool(backend, num_worker)
+        ctx_p = p
+
+    with ctx_p:
+        imap = p.imap(func, da.batch(batch_size=batch_size, shuffle=shuffle))
+        for x in track(
+            imap, total=ceil(len(da) / batch_size), disable=not show_progress
+        ):
+            yield x
+
+
 def _get_pool(backend, num_worker) -> Union[Pool, ThreadPool]:
     if backend == 'thread':
         return ThreadPool(processes=num_worker)
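
Taken together: `apply_batch` splits the array into mini-batches via `_map_batch` and assigns each processed batch back by index. A minimal end-to-end sketch in which `func` receives and returns a whole `DocumentArray` batch (`MyDoc` and `upper_batch` are hypothetical; `batch_size` divides the length evenly here, matching the index arithmetic above):

    from docarray import BaseDocument, DocumentArray
    from docarray.utils.apply import apply_batch


    class MyDoc(BaseDocument):  # hypothetical Document type, for illustration only
        text: str = ''


    def upper_batch(batch: DocumentArray[MyDoc]) -> DocumentArray[MyDoc]:
        # func operates on a mini-batch (a DocumentArray), not a single Document
        for doc in batch:
            doc.text = doc.text.upper()
        return batch


    da = DocumentArray[MyDoc]([MyDoc(text='the cake is a lie') for _ in range(100)])
    apply_batch(da=da, func=upper_batch, batch_size=10)
    print(da[0].text)  # THE CAKE IS A LIE

Note that `upper_batch` is a top-level function, which keeps it usable with the `process` backend, where `_map_batch` rejects local, lambda, and partial functions.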

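The `pool` parameter lets callers reuse an external pool; per the docstring, `backend` is then ignored and the caller is responsible for closing it. A short sketch of that path, reusing the hypothetical `da` and `upper_batch` from above:

    from multiprocessing.pool import ThreadPool

    pool = ThreadPool(processes=4)
    try:
        apply_batch(da=da, func=upper_batch, batch_size=10, pool=pool)
    finally:
        pool.close()  # the caller, not apply_batch, owns the pool's lifecycle
        pool.join()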
tests/benchmark_tests/test_apply.py

Lines changed: 4 additions & 0 deletions

@@ -34,7 +34,9 @@ def time_multiprocessing(num_workers: int) -> float:
         return time() - start_time
 
     time_1_cpu = time_multiprocessing(num_workers=1)
+    print(f"time_1_cpu = {time_1_cpu}")
     time_2_cpu = time_multiprocessing(num_workers=2)
+    print(f"time_2_cpu = {time_2_cpu}")
 
     assert time_2_cpu < time_1_cpu
 
@@ -57,6 +59,8 @@ def time_multithreading(num_workers: int) -> float:
         return time() - start_time
 
     time_1_thread = time_multithreading(num_workers=1)
+    print(f"time_1_thread = {time_1_thread}")
     time_2_thread = time_multithreading(num_workers=2)
+    print(f"time_2_thread = {time_2_thread}")
 
     assert time_2_thread < time_1_thread
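
The added prints surface the measured timings in the test output. The timing pattern being instrumented is roughly the one below; this is a self-contained stand-in using a plain ThreadPool and a sleep-based workload, not the real test's docarray workload, which this diff does not show:

    from multiprocessing.pool import ThreadPool
    from time import sleep, time


    def time_pool(num_workers: int) -> float:
        # time a fixed IO-bound workload under a given worker count
        start_time = time()
        with ThreadPool(processes=num_workers) as p:
            p.map(lambda _: sleep(0.2), range(4))
        return time() - start_time


    time_1 = time_pool(num_workers=1)
    print(f"time_1 = {time_1}")
    time_2 = time_pool(num_workers=2)
    print(f"time_2 = {time_2}")
    assert time_2 < time_1  # more workers finish the IO-bound batch faster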

tests/units/util/test_apply.py

Lines changed: 36 additions & 4 deletions

@@ -1,12 +1,14 @@
+from typing import Generator
+
 import pytest
 
 from docarray import DocumentArray
 from docarray.documents import Image
-from docarray.utils.apply import apply
+from docarray.utils.apply import _map_batch, apply, apply_batch
 from tests.units.typing.test_bytes import IMAGE_PATHS
 
 
-def foo(d: Image) -> Image:
+def load_from_doc(d: Image) -> Image:
     if d.url is not None:
         d.tensor = d.url.load()
     return d
@@ -25,7 +27,7 @@ def test_apply(da, backend):
     for tensor in da.tensor:
         assert tensor is None
 
-    da_applied = apply(da=da, func=foo, backend=backend)
+    da_applied = apply(da=da, func=load_from_doc, backend=backend)
 
     assert len(da) == len(da_applied)
     for tensor in da_applied.tensor:
@@ -49,8 +51,38 @@ def local_func(x):
 def test_check_order(backend):
     da = DocumentArray[Image]([Image(id=i) for i in range(2)])
 
-    da_applied = apply(da=da, func=foo, backend=backend)
+    da_applied = apply(da=da, func=load_from_doc, backend=backend)
 
     assert len(da) == len(da_applied)
     for id_1, id_2 in zip(da, da_applied):
         assert id_1 == id_2
+
+
+def load_from_da(da: DocumentArray[Image]) -> DocumentArray[Image]:
+    da_new = da.__class_getitem__(da.document_type)([Image() for _ in da])
+    return da_new
+
+
+@pytest.mark.parametrize('n_docs,batch_size', [(10, 5), (10, 7)])
+@pytest.mark.parametrize('backend', ['thread', 'process'])
+def test_apply_batch_multithreading(n_docs, batch_size, backend):
+
+    da = DocumentArray[Image]([Image(url=IMAGE_PATHS['png']) for _ in range(n_docs)])
+    da_applied = apply_batch(
+        da=da, func=load_from_da, batch_size=batch_size, backend=backend
+    )
+
+    for doc in da_applied:
+        assert isinstance(doc, Image)
+
+
+@pytest.mark.parametrize('n_docs,batch_size', [(10, 5), (10, 7)])
+@pytest.mark.parametrize('backend', ['thread', 'process'])
+def test_map_batch(n_docs, batch_size, backend):
+
+    da = DocumentArray[Image]([Image(url=IMAGE_PATHS['png']) for _ in range(n_docs)])
+    it = _map_batch(da=da, func=load_from_da, batch_size=batch_size, backend=backend)
+    assert isinstance(it, Generator)
+
+    for batch in it:
+        assert isinstance(batch, DocumentArray[Image])
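
The (10, 7) parametrization exercises the uneven final batch: ceil(10 / 7) = 2 batches, one of 7 Documents and one of 3, which is also the total fed to the progress bar in `_map_batch`. A quick self-contained check of that arithmetic (plain Python, independent of docarray):

    from math import ceil

    n_docs, batch_size = 10, 7
    n_batches = ceil(n_docs / batch_size)  # 2
    sizes = [min(batch_size, n_docs - b * batch_size) for b in range(n_batches)]
    assert sizes == [7, 3]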
