11from contextlib import nullcontext
2+ from math import ceil
23from multiprocessing .pool import Pool , ThreadPool
34from types import LambdaType
45from typing import Any , Callable , Generator , Optional , TypeVar , Union
@@ -20,14 +21,15 @@ def apply(
2021) -> T :
2122 """
2223 Apply `func` to every Document of the given DocumentArray while multiprocessing,
23- return itself after modification, without in-place changes .
24+ return itself after modification.
2425
2526 :param da: DocumentArray to apply function to
2627 :param func: a function that takes ab:class:`BaseDocument` as input and outputs
2728 a :class:`BaseDocument`.
2829 :param backend: `thread` for multi-threading and `process` for multi-processing.
29- Defaults to `thread`. In general, if `func` is IO-bound then `thread` is a
30- good choice. If `func` is CPU-bound, then you may use `process`.
30+ Defaults to `thread`.
31+ In general, if `func` is IO-bound then `thread` is a good choice.
32+ On the other hand, if `func` is CPU-bound, then you may use `process`.
3133 In practice, you should try yourselves to figure out the best value.
3234 However, if you wish to modify the elements in-place, regardless of IO/CPU-bound,
3335 you should always use `thread` backend.
@@ -46,10 +48,9 @@ def apply(
4648
4749 :return: DocumentArray with applied modifications
4850 """
49- da_new = da .__class_getitem__ (item = da .document_type )()
5051 for i , doc in enumerate (_map (da , func , backend , num_worker , pool , show_progress )):
51- da_new . append ( doc )
52- return da_new
52+ da [ i ] = doc
53+ return da
5354
5455
5556def _map (
@@ -65,12 +66,13 @@ def _map(
6566 yielding the results.
6667
6768 :param da: DocumentArray to apply function to
68- :param func:a function that takes ab :class:`BaseDocument` as input and outputs
69+ :param func: a function that takes a :class:`BaseDocument` as input and outputs
6970 a :class:`BaseDocument`. You can either modify elements in-place or return
70- new Documents.
71+ new Documents (depending on `backend`) .
7172 :param backend: `thread` for multi-threading and `process` for multi-processing.
72- Defaults to `thread`. In general, if `func` is IO-bound then `thread` is a
73- good choice. If `func` is CPU-bound, then you may use `process`.
73+ Defaults to `thread`.
74+ In general, if `func` is IO-bound then `thread` is a good choice.
75+ On the other hand, if `func` is CPU-bound, then you may use `process`.
7476 In practice, you should try yourselves to figure out the best value.
7577 However, if you wish to modify the elements in-place, regardless of IO/CPU-bound,
7678 you should always use `thread` backend.
@@ -110,6 +112,138 @@ def _map(
110112 yield x
111113
112114
def apply_batch(
    da: T,
    func: Callable[[T], T],
    batch_size: int,
    backend: str = 'thread',
    num_worker: Optional[int] = None,
    shuffle: bool = False,
    pool: Optional[Union[Pool, ThreadPool]] = None,
    show_progress: bool = False,
) -> T:
    """Batches itself into mini-batches, applies `func` to every mini-batch, and return itself after the modifications.

    EXAMPLE USAGE

    .. code-block:: python

        from docarray import Document, DocumentArray

        da = DocumentArray([Document(text='The cake is a lie') for _ in range(100)])


        def func(batch):
            batch.texts = [t.upper() for t in batch.texts]
            return batch


        da.apply_batch(func, batch_size=10)
        print(da.texts[:3])

    .. code-block:: text

        ['THE CAKE IS A LIE', 'THE CAKE IS A LIE', 'THE CAKE IS A LIE']

    :param da: DocumentArray to apply function to
    :param func: a function that takes a mini-batch :class:`DocumentArray` as input
        and outputs a :class:`DocumentArray` of the same length.
    :param batch_size: size of each generated batch (except the last batch, which might
        be smaller).
    :param backend: `thread` for multi-threading and `process` for multi-processing.
        Defaults to `thread`.
        In general, if `func` is IO-bound then `thread` is a good choice.
        On the other hand, if `func` is CPU-bound, then you may use `process`.
        In practice, you should try yourselves to figure out the best value.
        However, if you wish to modify the elements in-place, regardless of IO/CPU-bound,
        you should always use `thread` backend.

        .. warning::
            When using `process` backend, your `func` should not modify elements in-place.
            This is because the multiprocessing backend passes the variable via pickle
            and works in another process.
            The passed object and the original object do **not** share the same memory.

    :param num_worker: the number of parallel workers. If not given, the number of CPUs
        in the system will be used.
    :param shuffle: If set, shuffle the Documents before dividing into minibatches.
    :param pool: use an existing/external process or thread pool. If given, you will
        be responsible for closing the pool.
    :param show_progress: show a progress bar. Defaults to False.

    :return: DocumentArray after modifications
    """
    # NOTE(review): when shuffle=True, results are written back in sequential
    # positions, i.e. the array's element order after this call follows the
    # shuffled batch order — confirm this is the intended semantics.
    for batch_id, batch in enumerate(
        _map_batch(
            da, func, batch_size, backend, num_worker, shuffle, pool, show_progress
        )
    ):
        # Clamp the index range to the actual batch length so the last
        # (possibly smaller) batch does not produce out-of-range indices.
        start = batch_id * batch_size
        indices = list(range(start, start + len(batch)))
        da[indices] = batch
    return da
184+
185+
def _map_batch(
    da: T,
    func: Callable[[T], T],
    batch_size: int,
    backend: str = 'thread',
    num_worker: Optional[int] = None,
    shuffle: bool = False,
    pool: Optional[Union[Pool, ThreadPool]] = None,
    show_progress: bool = False,
) -> Generator[T, None, None]:
    """Return an iterator that applies function to every **minibatch** of iterable in parallel, yielding the results.
    Each element in the returned iterator is :class:`DocumentArray`.

    .. seealso::
        - To process single element, please use :meth:`.map`;
        - To return :class:`DocumentArray`, please use :meth:`.apply_batch`.

    :param batch_size: Size of each generated batch (except the last one, which might be smaller).
    :param shuffle: If set, shuffle the Documents before dividing into minibatches.
    :param func: a function that takes :class:`DocumentArray` as input and outputs anything. You can either modify elements
        in-place (only with `thread` backend) or work later on return elements.
    :param backend: if to use multi-`process` or multi-`thread` as the parallelization backend. In general, if your
        ``func`` is IO-bound then perhaps `thread` is good enough. If your ``func`` is CPU-bound then you may use `process`.
        In practice, you should try yourselves to figure out the best value. However, if you wish to modify the elements
        in-place, regardless of IO/CPU-bound, you should always use `thread` backend.

        .. warning::
            When using `process` backend, you should not expect ``func`` modify elements in-place. This is because
            the multiprocessing backing pass the variable via pickle and work in another process. The passed object
            and the original object do **not** share the same memory.

    :param num_worker: the number of parallel workers. If not given, then the number of CPUs in the system will be used.
    :param show_progress: show a progress bar
    :param pool: use an existing/external pool. If given, `backend` is ignored and you will be responsible for closing the pool.

    :yield: anything return from ``func``
    """

    # Pickling restrictions of the process backend: local/lambda/partial
    # callables cannot be shipped to worker processes.
    if backend == 'process' and _is_lambda_or_partial_or_local_function(func):
        raise ValueError(
            f'Multiprocessing does not allow functions that are local, lambda or partial: {func}'
        )

    from rich.progress import track

    # Reuse a caller-supplied pool without closing it (nullcontext), or
    # create our own pool and let the `with` statement manage its lifetime.
    ctx: Union[nullcontext, Union[Pool, ThreadPool]]
    if pool is not None:
        worker_pool = pool
        ctx = nullcontext()
    else:
        worker_pool = _get_pool(backend, num_worker)
        ctx = worker_pool

    total_batches = ceil(len(da) / batch_size)
    with ctx:
        results = worker_pool.imap(func, da.batch(batch_size=batch_size, shuffle=shuffle))
        yield from track(results, total=total_batches, disable=not show_progress)
245+
246+
113247def _get_pool (backend , num_worker ) -> Union [Pool , ThreadPool ]:
114248 if backend == 'thread' :
115249 return ThreadPool (processes = num_worker )
0 commit comments