 from types import LambdaType
 from typing import Any, Callable, Generator, Optional, TypeVar, Union
 
+from rich.progress import track
+
 from docarray import BaseDocument
 from docarray.array.abstract_array import AnyDocumentArray
@@ -18,10 +20,10 @@ def apply(
     num_worker: Optional[int] = None,
     pool: Optional[Union[Pool, ThreadPool]] = None,
     show_progress: bool = False,
-) -> T:
+) -> None:
     """
-    Apply `func` to every Document of the given DocumentArray while multithreading or
-    multiprocessing, return itself after modification.
+    Apply `func` to every Document of the given DocumentArray in-place while multithreading
+    or multiprocessing.
 
     EXAMPLE USAGE
 
@@ -38,7 +40,7 @@ def load_url_to_tensor(img: Image) -> Image:
 
 
     da = DocumentArray[Image]([Image(url='path/to/img.png') for _ in range(100)])
-    da = apply(
+    apply(
         da, load_url_to_tensor, backend='thread'
     )  # threading is usually a good option for IO-bound tasks such as loading an image from url
@@ -68,11 +70,9 @@ def load_url_to_tensor(img: Image) -> Image:
         be responsible for closing the pool.
     :param show_progress: show a progress bar. Defaults to False.
 
-    :return: DocumentArray with applied modifications
     """
     for i, doc in enumerate(_map(da, func, backend, num_worker, pool, show_progress)):
         da[i] = doc
-    return da
 
 
 def _map(
@@ -115,7 +115,6 @@ def _map(
 
     :yield: Documents returned from `func`
     """
-    from rich.progress import track
 
     if backend == 'process' and _is_lambda_or_partial_or_local_function(func):
         raise ValueError(
@@ -145,10 +144,9 @@ def apply_batch(
     shuffle: bool = False,
     pool: Optional[Union[Pool, ThreadPool]] = None,
     show_progress: bool = False,
-) -> T:
+) -> None:
     """
-    Batches itself into mini-batches, applies `func` to every mini-batch, and return
-    itself after the modifications.
+    Batches itself into mini-batches, applies `func` to every mini-batch in-place.
 
     EXAMPLE USAGE
 
@@ -168,7 +166,7 @@ def upper_case_name(da: DocumentArray[MyDoc]) -> DocumentArray[MyDoc]:
 
 
     da = DocumentArray[MyDoc]([MyDoc(name='my orange cat') for _ in range(100)])
-    da = apply_batch(da, upper_case_name, batch_size=10)
+    apply_batch(da, upper_case_name, batch_size=10)
     print(da.name[:3])
 
     .. code-block:: text
@@ -200,8 +198,6 @@ def upper_case_name(da: DocumentArray[MyDoc]) -> DocumentArray[MyDoc]:
     :param pool: use an existing/external process or thread pool. If given, you will
         be responsible for closing the pool.
     :param show_progress: show a progress bar. Defaults to False.
-
-    :return DocumentArray after modifications
     """
     for i, batch in enumerate(
         _map_batch(
@@ -210,7 +206,6 @@ def upper_case_name(da: DocumentArray[MyDoc]) -> DocumentArray[MyDoc]:
     ):
         indices = [i for i in range(i * batch_size, (i + 1) * batch_size)]
         da[indices] = batch
-    return da
 
 
 def _map_batch(
@@ -256,15 +251,13 @@ def _map_batch(
     :param pool: use an existing/external pool. If given, `backend` is ignored and you will
         be responsible for closing the pool.
 
-    :yield: anything return from ``func``
+    :yield: DocumentArrays returned from `func`
     """
     if backend == 'process' and _is_lambda_or_partial_or_local_function(func):
         raise ValueError(
             f'Multiprocessing does not allow functions that are local, lambda or partial: {func}'
         )
 
-    from rich.progress import track
-
     ctx_p: Union[nullcontext, Union[Pool, ThreadPool]]
     if pool:
         p = pool
@@ -274,7 +267,7 @@ def _map_batch(
         ctx_p = p
 
     with ctx_p:
-        imap = p.imap(func, da.batch(batch_size=batch_size, shuffle=shuffle))
+        imap = p.imap(func, da._batch(batch_size=batch_size, shuffle=shuffle))
         for x in track(
             imap, total=ceil(len(da) / batch_size), disable=not show_progress
         ):