1- import time
2- from multiprocessing import cpu_count
3- from typing import Optional
4-
5- import numpy as np
61import pytest
72
8- from docarray import BaseDocument , DocumentArray
3+ from docarray import DocumentArray
94from docarray .documents import Image
10- from docarray .typing import NdArray
115from docarray .utils .apply import apply
126from tests .units .typing .test_bytes import IMAGE_PATHS
137
@@ -21,7 +15,7 @@ def foo(d: Image) -> Image:
2115@pytest .fixture ()
2216def da ():
2317 da = DocumentArray [Image ](
24- [Image (url = url ) for url in IMAGE_PATHS .values () for _ in range (10 )]
18+ [Image (url = url ) for url in IMAGE_PATHS .values () for _ in range (2 )]
2519 )
2620 return da
2721
@@ -37,88 +31,24 @@ def test_apply(da):
3731 assert tensor is not None
3832
3933
40- def test_apply_with_lambda (da ):
41- for tensor in da . tensor :
42- assert tensor is None
34+ def test_apply_lambda_func_raise_exception (da ):
35+ with pytest . raises ( ValueError , match = 'Multiprocessing does not allow' ) :
36+ apply ( da = da , func = lambda x : x )
4337
44- da_applied = apply (da = da , func = lambda x : x )
4538
46- assert len (da ) == len ( da_applied )
47- for tensor in da_applied . tensor :
48- assert tensor is None
39+ def test_apply_local_func_raise_exception (da ):
40+ def local_func ( x ) :
41+ return x
4942
43+ with pytest .raises (ValueError , match = 'Multiprocessing does not allow' ):
44+ apply (da = da , func = local_func )
5045
51- def test_apply_with_local_function (da ):
52- def local_func (d : Image ) -> Image :
53- if d .url is not None :
54- d .tensor = d .url .load ()
55- return d
5646
57- for tensor in da . tensor :
58- assert tensor is None
47+ def test_check_order () :
48+ da = DocumentArray [ Image ]([ Image ( id = i ) for i in range ( 2 )])
5949
60- da_applied = apply (da = da , func = local_func )
50+ da_applied = apply (da = da , func = foo )
6151
6252 assert len (da ) == len (da_applied )
63- for tensor in da_applied .tensor :
64- assert tensor is None
65-
66-
67- class MyDoc (BaseDocument ):
68- tensor_a : Optional [NdArray ]
69- tensor_b : Optional [NdArray ]
70- tensor_matmul : Optional [NdArray ]
71-
72-
73- @pytest .fixture ()
74- def func ():
75- def matmul (doc ):
76- if doc .tensor_a is not None and doc .tensor_b is not None :
77- doc .tensor_matmul = np .matmul (doc .tensor_a , doc .tensor_b )
78- return doc
79-
80- return matmul
81-
82-
83- def matmul (doc ):
84- if doc .tensor_a is not None and doc .tensor_b is not None :
85- doc .tensor_matmul = np .matmul (doc .tensor_a , doc .tensor_b )
86- return doc
87-
88-
89- def test_benchmark (func ):
90- time_mproc = []
91- time_no_mproc = []
92-
93- for n_docs in [1 , 2 ]:
94- da = DocumentArray [MyDoc ](
95- [
96- MyDoc (
97- tensor_a = np .random .randn (100 , 200 ),
98- tensor_b = np .random .randn (200 , 100 ),
99- )
100- for _ in range (n_docs )
101- ]
102- )
103-
104- # with multiprocessing
105- start_time = time .time ()
106- apply (da = da , func = func )
107- duration_mproc = time .time () - start_time
108- time_mproc .append (duration_mproc )
109-
110- # without multiprocessing
111- start_time = time .time ()
112- da_no_mproc = DocumentArray [MyDoc ]()
113- for i , doc in enumerate (da ):
114- da_no_mproc .append (func (doc ))
115- duration_no_mproc = time .time () - start_time
116- time_no_mproc .append (duration_no_mproc )
117-
118- # if more than 1 CPU available, check that when using multiprocessing
119- # grows slower with more documents, then without multiprocessing.
120- print (f"cpu_count() = { cpu_count ()} " )
121- if cpu_count () > 1 :
122- growth_factor_mproc = time_mproc [1 ] / time_mproc [0 ]
123- growth_factor_no_mproc = time_no_mproc [1 ] / time_no_mproc [0 ]
124- assert growth_factor_mproc < growth_factor_no_mproc
53+ for id_1 , id_2 in zip (da , da_applied ):
54+ assert id_1 == id_2
0 commit comments