Skip to content

Commit d72cbf5

Browse files
author
anna-charlotte
committed
fix: apply
Signed-off-by: anna-charlotte <[email protected]>
1 parent 06ba7eb commit d72cbf5

File tree

2 files changed

+20
-102
lines changed

2 files changed

+20
-102
lines changed

docarray/utils/apply.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import uuid
21
from contextlib import nullcontext
32
from types import LambdaType
43
from typing import TYPE_CHECKING, Any, Callable, Generator, Optional, TypeVar, Union
@@ -36,7 +35,7 @@ def apply(
3635
:return: DocumentArray with applied modifications
3736
"""
3837
da_new = da.__class_getitem__(item=da.document_type)()
39-
for i, doc in enumerate(_map(da, func, num_worker, show_progress, pool)):
38+
for i, doc in enumerate(_map(da, func, num_worker, pool, show_progress)):
4039
da_new.append(doc)
4140
return da_new
4241

@@ -67,9 +66,9 @@ def _map(
6766
from rich.progress import track
6867

6968
if _is_lambda_or_partial_or_local_function(func):
70-
print(f"func = {func}")
71-
func = _globalize_function(func)
72-
print(f"func = {func}")
69+
raise ValueError(
70+
f'Multiprocessing does not allow functions that are local, lambda or partial: {func}'
71+
)
7372

7473
ctx_p: Union[nullcontext, 'Pool']
7574
if pool:
@@ -86,20 +85,9 @@ def _map(
8685
yield x
8786

8887

89-
def _is_lambda_or_partial_or_local_function(func: Callable[[Any], Any]):
88+
def _is_lambda_or_partial_or_local_function(func: Callable[[Any], Any]) -> bool:
9089
return (
9190
(isinstance(func, LambdaType) and func.__name__ == '<lambda>')
9291
or not hasattr(func, '__qualname__')
9392
or ('<locals>' in func.__qualname__)
9493
)
95-
96-
97-
def _globalize_function(func):
98-
import sys
99-
100-
def result(*args, **kwargs):
101-
return func(*args, **kwargs)
102-
103-
result.__name__ = result.__qualname__ = uuid.uuid4().hex
104-
setattr(sys.modules[result.__module__], result.__name__, result)
105-
return result

tests/units/util/test_apply.py

Lines changed: 15 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,7 @@
1-
import time
2-
from multiprocessing import cpu_count
3-
from typing import Optional
4-
5-
import numpy as np
61
import pytest
72

8-
from docarray import BaseDocument, DocumentArray
3+
from docarray import DocumentArray
94
from docarray.documents import Image
10-
from docarray.typing import NdArray
115
from docarray.utils.apply import apply
126
from tests.units.typing.test_bytes import IMAGE_PATHS
137

@@ -21,7 +15,7 @@ def foo(d: Image) -> Image:
2115
@pytest.fixture()
2216
def da():
2317
da = DocumentArray[Image](
24-
[Image(url=url) for url in IMAGE_PATHS.values() for _ in range(10)]
18+
[Image(url=url) for url in IMAGE_PATHS.values() for _ in range(2)]
2519
)
2620
return da
2721

@@ -37,88 +31,24 @@ def test_apply(da):
3731
assert tensor is not None
3832

3933

40-
def test_apply_with_lambda(da):
41-
for tensor in da.tensor:
42-
assert tensor is None
34+
def test_apply_lambda_func_raise_exception(da):
35+
with pytest.raises(ValueError, match='Multiprocessing does not allow'):
36+
apply(da=da, func=lambda x: x)
4337

44-
da_applied = apply(da=da, func=lambda x: x)
4538

46-
assert len(da) == len(da_applied)
47-
for tensor in da_applied.tensor:
48-
assert tensor is None
39+
def test_apply_local_func_raise_exception(da):
40+
def local_func(x):
41+
return x
4942

43+
with pytest.raises(ValueError, match='Multiprocessing does not allow'):
44+
apply(da=da, func=local_func)
5045

51-
def test_apply_with_local_function(da):
52-
def local_func(d: Image) -> Image:
53-
if d.url is not None:
54-
d.tensor = d.url.load()
55-
return d
5646

57-
for tensor in da.tensor:
58-
assert tensor is None
47+
def test_check_order():
48+
da = DocumentArray[Image]([Image(id=i) for i in range(2)])
5949

60-
da_applied = apply(da=da, func=local_func)
50+
da_applied = apply(da=da, func=foo)
6151

6252
assert len(da) == len(da_applied)
63-
for tensor in da_applied.tensor:
64-
assert tensor is None
65-
66-
67-
class MyDoc(BaseDocument):
68-
tensor_a: Optional[NdArray]
69-
tensor_b: Optional[NdArray]
70-
tensor_matmul: Optional[NdArray]
71-
72-
73-
@pytest.fixture()
74-
def func():
75-
def matmul(doc):
76-
if doc.tensor_a is not None and doc.tensor_b is not None:
77-
doc.tensor_matmul = np.matmul(doc.tensor_a, doc.tensor_b)
78-
return doc
79-
80-
return matmul
81-
82-
83-
def matmul(doc):
84-
if doc.tensor_a is not None and doc.tensor_b is not None:
85-
doc.tensor_matmul = np.matmul(doc.tensor_a, doc.tensor_b)
86-
return doc
87-
88-
89-
def test_benchmark(func):
90-
time_mproc = []
91-
time_no_mproc = []
92-
93-
for n_docs in [1, 2]:
94-
da = DocumentArray[MyDoc](
95-
[
96-
MyDoc(
97-
tensor_a=np.random.randn(100, 200),
98-
tensor_b=np.random.randn(200, 100),
99-
)
100-
for _ in range(n_docs)
101-
]
102-
)
103-
104-
# with multiprocessing
105-
start_time = time.time()
106-
apply(da=da, func=func)
107-
duration_mproc = time.time() - start_time
108-
time_mproc.append(duration_mproc)
109-
110-
# without multiprocessing
111-
start_time = time.time()
112-
da_no_mproc = DocumentArray[MyDoc]()
113-
for i, doc in enumerate(da):
114-
da_no_mproc.append(func(doc))
115-
duration_no_mproc = time.time() - start_time
116-
time_no_mproc.append(duration_no_mproc)
117-
118-
# if more than 1 CPU available, check that when using multiprocessing
119-
# grows slower with more documents, then without multiprocessing.
120-
print(f"cpu_count() = {cpu_count()}")
121-
if cpu_count() > 1:
122-
growth_factor_mproc = time_mproc[1] / time_mproc[0]
123-
growth_factor_no_mproc = time_no_mproc[1] / time_no_mproc[0]
124-
assert growth_factor_mproc < growth_factor_no_mproc
53+
for id_1, id_2 in zip(da, da_applied):
54+
assert id_1 == id_2

0 commit comments

Comments
 (0)