
Commit 0c3bdfd

feat: add apply function

Author: anna-charlotte
Signed-off-by: anna-charlotte <[email protected]>
1 parent: 28b96fe

File tree: 3 files changed, +230 −3 lines changed


docarray/display/tensor_display.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -30,9 +30,7 @@ def __rich_console__(
         from rich.segment import Segment
         from rich.style import Style

-        tensor_normalized = comp_be.minmax_normalize(
-            comp_be.detach(self.tensor), (0, 5)
-        )
+        tensor_normalized = comp_be.minmax_normalize(t_squeezed, (0, 5))

         hue = 0.75
         saturation = 1.0
```

docarray/utils/apply.py

Lines changed: 105 additions & 0 deletions
```python
import uuid
from contextlib import nullcontext
from types import LambdaType
from typing import TYPE_CHECKING, Any, Callable, Generator, Optional, TypeVar, Union

from docarray import BaseDocument
from docarray.array.abstract_array import AnyDocumentArray

if TYPE_CHECKING:
    from multiprocessing.pool import Pool


T = TypeVar('T', bound=AnyDocumentArray)


def apply(
    da: T,
    func: Callable[[BaseDocument], BaseDocument],
    num_worker: Optional[int] = None,
    pool: Optional['Pool'] = None,
    show_progress: bool = False,
) -> T:
    """
    Apply `func` to every Document of the given DocumentArray in parallel using
    multiprocessing, and return a new DocumentArray with the results; the input
    DocumentArray is not modified in place.

    :param da: DocumentArray to apply function to
    :param func: a function that takes a :class:`BaseDocument` as input and outputs
        a :class:`BaseDocument`.
    :param num_worker: the number of parallel workers. If not given, the number of
        CPUs in the system will be used.
    :param pool: use an existing/external process or thread pool. If given, you will
        be responsible for closing the pool.
    :param show_progress: show a progress bar. Defaults to False.

    :return: DocumentArray with applied modifications
    """
    da_new = da.__class_getitem__(item=da.document_type)()
    for doc in _map(
        da, func, num_worker=num_worker, pool=pool, show_progress=show_progress
    ):
        da_new.append(doc)
    return da_new


def _map(
    da: T,
    func: Callable[[BaseDocument], BaseDocument],
    num_worker: Optional[int] = None,
    pool: Optional['Pool'] = None,
    show_progress: bool = False,
) -> Generator['BaseDocument', None, None]:
    """
    Return an iterator that applies `func` to every Document in `da` in parallel,
    yielding the results.

    :param da: DocumentArray to apply function to
    :param func: a function that takes a :class:`BaseDocument` as input and outputs
        a :class:`BaseDocument`. You can either modify elements in-place or return
        new Documents.
    :param num_worker: the number of parallel workers. If not given, the number of
        CPUs in the system will be used.
    :param pool: use an existing/external process or thread pool. If given, you will
        be responsible for closing the pool.
    :param show_progress: show a progress bar. Defaults to False.

    :yield: Documents returned from `func`
    """
    from rich.progress import track

    # Lambdas and locally defined functions cannot be pickled, so promote them
    # to module-level functions before handing them to the process pool.
    if _is_lambda_or_partial_or_local_function(func):
        func = _globalize_function(func)

    ctx_p: Union[nullcontext, 'Pool']
    if pool:
        p = pool
        ctx_p = nullcontext()  # an external pool is closed by its owner
    else:
        from multiprocessing.pool import Pool

        p = Pool(processes=num_worker)
        ctx_p = p

    with ctx_p:
        for x in track(p.imap(func, da), total=len(da), disable=not show_progress):
            yield x


def _is_lambda_or_partial_or_local_function(func: Callable[[Any], Any]):
    return (
        (isinstance(func, LambdaType) and func.__name__ == '<lambda>')
        or not hasattr(func, '__qualname__')
        or ('<locals>' in func.__qualname__)
    )


def _globalize_function(func):
    import sys

    def result(*args, **kwargs):
        return func(*args, **kwargs)

    # Register the wrapper under a unique module-level name so that the
    # multiprocessing machinery can pickle it by reference.
    result.__name__ = result.__qualname__ = uuid.uuid4().hex
    setattr(sys.modules[result.__module__], result.__name__, result)
    return result
```
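
For orientation, here is a minimal usage sketch of the new `apply` function. It mirrors the `foo` helper exercised in the tests below; `load_tensor` and the image path are illustrative placeholders, not part of the commit.

```python
from docarray import DocumentArray
from docarray.documents import Image
from docarray.utils.apply import apply


def load_tensor(doc: Image) -> Image:
    # Hypothetical per-Document transformation: load the image into a tensor.
    if doc.url is not None:
        doc.tensor = doc.url.load()
    return doc


if __name__ == '__main__':  # guard needed for multiprocessing on spawn platforms
    da = DocumentArray[Image]([Image(url='path/to/image.png') for _ in range(100)])
    # Runs `load_tensor` over all Documents in a process pool and returns a new
    # DocumentArray; `da` itself is left unmodified.
    da_loaded = apply(da=da, func=load_tensor, num_worker=4, show_progress=True)
```

Because `load_tensor` is a top-level function it is picklable as-is; a lambda or local function would be routed through `_globalize_function` first.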

tests/units/util/test_apply.py

Lines changed: 124 additions & 0 deletions
```python
import time
from multiprocessing import cpu_count
from typing import Optional

import numpy as np
import pytest

from docarray import BaseDocument, DocumentArray
from docarray.documents import Image
from docarray.typing import NdArray
from docarray.utils.apply import apply
from tests.units.typing.test_bytes import IMAGE_PATHS


def foo(d: Image) -> Image:
    if d.url is not None:
        d.tensor = d.url.load()
    return d


@pytest.fixture()
def da():
    da = DocumentArray[Image](
        [Image(url=url) for url in IMAGE_PATHS.values() for _ in range(10)]
    )
    return da


def test_apply(da):
    for tensor in da.tensor:
        assert tensor is None

    da_applied = apply(da=da, func=foo)

    assert len(da) == len(da_applied)
    for tensor in da_applied.tensor:
        assert tensor is not None


def test_apply_with_lambda(da):
    for tensor in da.tensor:
        assert tensor is None

    da_applied = apply(da=da, func=lambda x: x)

    assert len(da) == len(da_applied)
    for tensor in da_applied.tensor:
        assert tensor is None


def test_apply_with_local_function(da):
    def local_func(d: Image) -> Image:
        if d.url is not None:
            d.tensor = d.url.load()
        return d

    for tensor in da.tensor:
        assert tensor is None

    da_applied = apply(da=da, func=local_func)

    assert len(da) == len(da_applied)
    for tensor in da_applied.tensor:
        assert tensor is None


class MyDoc(BaseDocument):
    tensor_a: Optional[NdArray]
    tensor_b: Optional[NdArray]
    tensor_matmul: Optional[NdArray]


@pytest.fixture()
def func():
    def matmul(doc):
        if doc.tensor_a is not None and doc.tensor_b is not None:
            doc.tensor_matmul = np.matmul(doc.tensor_a, doc.tensor_b)
        return doc

    return matmul


def matmul(doc):
    if doc.tensor_a is not None and doc.tensor_b is not None:
        doc.tensor_matmul = np.matmul(doc.tensor_a, doc.tensor_b)
    return doc


def test_benchmark(func):
    time_mproc = []
    time_no_mproc = []

    for n_docs in [1, 2]:
        da = DocumentArray[MyDoc](
            [
                MyDoc(
                    tensor_a=np.random.randn(100, 200),
                    tensor_b=np.random.randn(200, 100),
                )
                for _ in range(n_docs)
            ]
        )

        # with multiprocessing
        start_time = time.time()
        apply(da=da, func=func)
        duration_mproc = time.time() - start_time
        time_mproc.append(duration_mproc)

        # without multiprocessing
        start_time = time.time()
        da_no_mproc = DocumentArray[MyDoc]()
        for doc in da:
            da_no_mproc.append(func(doc))
        duration_no_mproc = time.time() - start_time
        time_no_mproc.append(duration_no_mproc)

    # If more than one CPU is available, check that the runtime grows more
    # slowly with the number of documents when using multiprocessing than
    # without it.
    print(f"cpu_count() = {cpu_count()}")
    if cpu_count() > 1:
        growth_factor_mproc = time_mproc[1] / time_mproc[0]
        growth_factor_no_mproc = time_no_mproc[1] / time_no_mproc[0]
        assert growth_factor_mproc < growth_factor_no_mproc
```
