Skip to content

Commit 14e3cc1

Browse files
author
anna-charlotte
committed
fix: apply
Signed-off-by: anna-charlotte <[email protected]>
1 parent 1f064f3 commit 14e3cc1

File tree

4 files changed

+56
-41
lines changed

4 files changed

+56
-41
lines changed

docarray/utils/apply.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from types import LambdaType
55
from typing import Any, Callable, Generator, Optional, TypeVar, Union
66

7+
from rich.progress import track
8+
79
from docarray import BaseDocument
810
from docarray.array.abstract_array import AnyDocumentArray
911

@@ -18,10 +20,10 @@ def apply(
1820
num_worker: Optional[int] = None,
1921
pool: Optional[Union[Pool, ThreadPool]] = None,
2022
show_progress: bool = False,
21-
) -> T:
23+
) -> None:
2224
"""
23-
Apply `func` to every Document of the given DocumentArray while multithreading or
24-
multiprocessing, return itself after modification.
25+
Apply `func` to every Document of the given DocumentArray in-place, using
26+
multithreading or multiprocessing.
2527
2628
EXAMPLE USAGE
2729
@@ -38,7 +40,7 @@ def load_url_to_tensor(img: Image) -> Image:
3840
3941
4042
da = DocumentArray[Image]([Image(url='path/to/img.png') for _ in range(100)])
41-
da = apply(
43+
apply(
4244
da, load_url_to_tensor, backend='thread'
4345
) # threading is usually a good option for IO-bound tasks such as loading an image from url
4446
@@ -68,11 +70,9 @@ def load_url_to_tensor(img: Image) -> Image:
6870
be responsible for closing the pool.
6971
:param show_progress: show a progress bar. Defaults to False.
7072
71-
:return: DocumentArray with applied modifications
7273
"""
7374
for i, doc in enumerate(_map(da, func, backend, num_worker, pool, show_progress)):
7475
da[i] = doc
75-
return da
7676

7777

7878
def _map(
@@ -115,7 +115,6 @@ def _map(
115115
116116
:yield: Documents returned from `func`
117117
"""
118-
from rich.progress import track
119118

120119
if backend == 'process' and _is_lambda_or_partial_or_local_function(func):
121120
raise ValueError(
@@ -145,10 +144,9 @@ def apply_batch(
145144
shuffle: bool = False,
146145
pool: Optional[Union[Pool, ThreadPool]] = None,
147146
show_progress: bool = False,
148-
) -> T:
147+
) -> None:
149148
"""
150-
Batches itself into mini-batches, applies `func` to every mini-batch, and return
151-
itself after the modifications.
149+
Batch `da` into mini-batches and apply `func` to every mini-batch in-place.
152150
153151
EXAMPLE USAGE
154152
@@ -168,7 +166,7 @@ def upper_case_name(da: DocumentArray[MyDoc]) -> DocumentArray[MyDoc]:
168166
169167
170168
da = DocumentArray[MyDoc]([MyDoc(name='my orange cat') for _ in range(100)])
171-
da = apply_batch(da, upper_case_name, batch_size=10)
169+
apply_batch(da, upper_case_name, batch_size=10)
172170
print(da.name[:3])
173171
174172
.. code-block:: text
@@ -200,8 +198,6 @@ def upper_case_name(da: DocumentArray[MyDoc]) -> DocumentArray[MyDoc]:
200198
:param pool: use an existing/external process or thread pool. If given, you will
201199
be responsible for closing the pool.
202200
:param show_progress: show a progress bar. Defaults to False.
203-
204-
:return DocumentArray after modifications
205201
"""
206202
for i, batch in enumerate(
207203
_map_batch(
@@ -210,7 +206,6 @@ def upper_case_name(da: DocumentArray[MyDoc]) -> DocumentArray[MyDoc]:
210206
):
211207
indices = [i for i in range(i * batch_size, (i + 1) * batch_size)]
212208
da[indices] = batch
213-
return da
214209

215210

216211
def _map_batch(
@@ -256,15 +251,13 @@ def _map_batch(
256251
:param pool: use an existing/external pool. If given, `backend` is ignored and you will
257252
be responsible for closing the pool.
258253
259-
:yield: anything return from ``func``
254+
:yield: DocumentArrays returned from `func`
260255
"""
261256
if backend == 'process' and _is_lambda_or_partial_or_local_function(func):
262257
raise ValueError(
263258
f'Multiprocessing does not allow functions that are local, lambda or partial: {func}'
264259
)
265260

266-
from rich.progress import track
267-
268261
ctx_p: Union[nullcontext, Union[Pool, ThreadPool]]
269262
if pool:
270263
p = pool
@@ -274,7 +267,7 @@ def _map_batch(
274267
ctx_p = p
275268

276269
with ctx_p:
277-
imap = p.imap(func, da.batch(batch_size=batch_size, shuffle=shuffle))
270+
imap = p.imap(func, da._batch(batch_size=batch_size, shuffle=shuffle))
278271
for x in track(
279272
imap, total=ceil(len(da) / batch_size), disable=not show_progress
280273
):

tests/benchmark_tests/test_apply.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from docarray import BaseDocument, DocumentArray
77
from docarray.documents import Image
88
from docarray.typing import NdArray
9-
from docarray.utils.apply import apply
9+
from docarray.utils.apply import apply, apply_batch
1010
from tests.units.typing.test_bytes import IMAGE_PATHS
1111

1212
pytestmark = pytest.mark.benchmark
@@ -34,17 +34,14 @@ def time_multiprocessing(num_workers: int) -> float:
3434
return time() - start_time
3535

3636
time_1_cpu = time_multiprocessing(num_workers=1)
37-
print(f"time_1_cpu = {time_1_cpu}")
3837
time_2_cpu = time_multiprocessing(num_workers=2)
39-
print(f"time_2_cpu = {time_2_cpu}")
4038

4139
assert time_2_cpu < time_1_cpu
4240

4341

4442
def io_intensive(img: Image) -> Image:
4543
# some io intensive function: load and set image url
46-
t = img.url.load()
47-
img.tensor = t
44+
img.tensor = img.url.load()
4845
return img
4946

5047

@@ -59,8 +56,35 @@ def time_multithreading(num_workers: int) -> float:
5956
return time() - start_time
6057

6158
time_1_thread = time_multithreading(num_workers=1)
62-
print(f"time_1_thread = {time_1_thread}")
6359
time_2_thread = time_multithreading(num_workers=2)
64-
print(f"time_2_thread = {time_2_thread}")
60+
61+
assert time_2_thread < time_1_thread
62+
63+
64+
def io_intensive_batch(da: DocumentArray[Image]) -> DocumentArray[Image]:
65+
# some io intensive function: load and set image url
66+
for doc in da:
67+
doc.tensor = doc.url.load()
68+
return da
69+
70+
71+
def test_apply_batch_multithreading_benchmark():
72+
def time_multithreading_batch(num_workers: int) -> float:
73+
n_docs = 100
74+
da = DocumentArray[Image](
75+
[Image(url=IMAGE_PATHS['png']) for _ in range(n_docs)]
76+
)
77+
start_time = time()
78+
apply_batch(
79+
da=da,
80+
func=io_intensive_batch,
81+
backend='thread',
82+
num_worker=num_workers,
83+
batch_size=10,
84+
)
85+
return time() - start_time
86+
87+
time_1_thread = time_multithreading_batch(num_workers=1)
88+
time_2_thread = time_multithreading_batch(num_workers=2)
6589

6690
assert time_2_thread < time_1_thread

tests/units/array/test_batching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class MyDoc(BaseDocument):
2626
if stack:
2727
da = da.stack()
2828

29-
batches = list(da.batch(batch_size=batch_size, shuffle=shuffle))
29+
batches = list(da._batch(batch_size=batch_size, shuffle=shuffle))
3030
assert len(batches) == n_batches
3131

3232
for i, batch in enumerate(batches):

tests/units/util/test_apply.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from docarray.utils.apply import _map_batch, apply, apply_batch
88
from tests.units.typing.test_bytes import IMAGE_PATHS
99

10+
N_DOCS = 2
11+
1012

1113
def load_from_doc(d: Image) -> Image:
1214
if d.url is not None:
@@ -16,9 +18,7 @@ def load_from_doc(d: Image) -> Image:
1618

1719
@pytest.fixture()
1820
def da():
19-
da = DocumentArray[Image](
20-
[Image(url=url) for url in IMAGE_PATHS.values() for _ in range(2)]
21-
)
21+
da = DocumentArray[Image]([Image(url=IMAGE_PATHS['png']) for _ in range(N_DOCS)])
2222
return da
2323

2424

@@ -27,10 +27,10 @@ def test_apply(da, backend):
2727
for tensor in da.tensor:
2828
assert tensor is None
2929

30-
da_applied = apply(da=da, func=load_from_doc, backend=backend)
30+
apply(da=da, func=load_from_doc, backend=backend)
3131

32-
assert len(da) == len(da_applied)
33-
for tensor in da_applied.tensor:
32+
assert len(da) == N_DOCS
33+
for tensor in da.tensor:
3434
assert tensor is not None
3535

3636

@@ -49,13 +49,13 @@ def local_func(x):
4949

5050
@pytest.mark.parametrize('backend', ['thread', 'process'])
5151
def test_check_order(backend):
52-
da = DocumentArray[Image]([Image(id=i) for i in range(2)])
52+
da = DocumentArray[Image]([Image(id=i) for i in range(N_DOCS)])
5353

54-
da_applied = apply(da=da, func=load_from_doc, backend=backend)
54+
apply(da=da, func=load_from_doc, backend=backend)
5555

56-
assert len(da) == len(da_applied)
57-
for id_1, id_2 in zip(da, da_applied):
58-
assert id_1 == id_2
56+
assert len(da) == N_DOCS
57+
for i, id_1 in enumerate(da.id):
58+
assert id_1 == str(i)
5959

6060

6161
def load_from_da(da: DocumentArray[Image]) -> DocumentArray[Image]:
@@ -69,11 +69,9 @@ def load_from_da(da: DocumentArray[Image]) -> DocumentArray[Image]:
6969
def test_apply_batch_multithreading(n_docs, batch_size, backend):
7070

7171
da = DocumentArray[Image]([Image(url=IMAGE_PATHS['png']) for _ in range(n_docs)])
72-
da_applied = apply_batch(
73-
da=da, func=load_from_da, batch_size=batch_size, backend=backend
74-
)
72+
apply_batch(da=da, func=load_from_da, batch_size=batch_size, backend=backend)
7573

76-
for doc in da_applied:
74+
for doc in da:
7775
assert isinstance(doc, Image)
7876

7977

0 commit comments

Comments
 (0)