|
| 1 | +import uuid |
| 2 | +from contextlib import nullcontext |
| 3 | +from types import LambdaType |
| 4 | +from typing import TYPE_CHECKING, Any, Callable, Generator, Optional, TypeVar, Union |
| 5 | + |
| 6 | +from docarray import BaseDocument |
| 7 | +from docarray.array.abstract_array import AnyDocumentArray |
| 8 | + |
| 9 | +if TYPE_CHECKING: |
| 10 | + from multiprocessing.pool import Pool |
| 11 | + |
| 12 | + |
| 13 | +T = TypeVar('T', bound=AnyDocumentArray) |
| 14 | + |
| 15 | + |
def apply(
    da: T,
    func: Callable[[BaseDocument], BaseDocument],
    num_worker: Optional[int] = None,
    pool: Optional['Pool'] = None,
    show_progress: bool = False,
) -> T:
    """
    Apply `func` to every Document of the given DocumentArray using a worker
    pool and return the results in a new DocumentArray. The input array itself
    is not appended to or replaced.

    :param da: DocumentArray to apply function to
    :param func: a function that takes a :class:`BaseDocument` as input and outputs
        a :class:`BaseDocument`.
    :param num_worker: the number of parallel workers. If not given, the number of
        CPUs in the system will be used. Only used when no `pool` is given.
    :param pool: use an existing/external process or thread pool. If given, you will
        be responsible for closing the pool.
    :param show_progress: show a progress bar. Defaults to False.

    :return: a new DocumentArray, of the same class and document type as `da`,
        holding the Documents returned by `func`
    """
    da_new = da.__class_getitem__(item=da.document_type)()
    # Pass by keyword: `_map` declares `pool` before `show_progress`, so passing
    # these positionally in the order (show_progress, pool) swaps the two.
    for doc in _map(
        da,
        func,
        num_worker=num_worker,
        pool=pool,
        show_progress=show_progress,
    ):
        da_new.append(doc)
    return da_new
| 42 | + |
| 43 | + |
def _map(
    da: T,
    func: Callable[[BaseDocument], BaseDocument],
    num_worker: Optional[int] = None,
    pool: Optional['Pool'] = None,
    show_progress: bool = False,
) -> Generator['BaseDocument', None, None]:
    """
    Return an iterator that applies `func` to every Document in `da` in parallel,
    yielding the results.

    :param da: DocumentArray to apply function to
    :param func: a function that takes a :class:`BaseDocument` as input and outputs
        a :class:`BaseDocument`. You can either modify elements in-place or return
        new Documents.
    :param num_worker: the number of parallel workers. If not given, the number of
        CPUs in the system will be used. Ignored when `pool` is given.
    :param pool: use an existing/external process or thread pool. If given, you will
        be responsible for closing the pool.
    :param show_progress: show a progress bar. Defaults to False.

    :yield: Documents returned from `func`
    """
    from rich.progress import track

    if _is_lambda_or_partial_or_local_function(func):
        # Re-register such callables under a module-level name before handing
        # them to the pool (presumably so they can be pickled by reference).
        func = _globalize_function(func)

    ctx_p: Union[nullcontext, 'Pool']
    if pool:
        # Caller-owned pool: use it, but do not enter/exit it here so that the
        # caller stays responsible for closing it.
        p = pool
        ctx_p = nullcontext()
    else:
        from multiprocessing.pool import Pool

        # Pool created here is also the context manager, so it is cleaned up
        # when the `with` block below is left.
        p = Pool(processes=num_worker)
        ctx_p = p

    with ctx_p:
        for x in track(p.imap(func, da), total=len(da), disable=not show_progress):
            yield x
| 87 | + |
| 88 | + |
def _is_lambda_or_partial_or_local_function(func: Callable[[Any], Any]):
    """
    Tell whether `func` is one of the callables that get special handling
    before being sent to a worker pool.

    :param func: the callable to inspect
    :return: True if `func` is a lambda, has no `__qualname__` attribute
        (e.g. a `functools.partial` object), or has `<locals>` in its
        qualified name (i.e. was defined inside another function);
        False otherwise.
    """
    # Lambdas: LambdaType is the plain function type, so the name check is
    # what actually distinguishes a lambda from a regular `def`.
    if isinstance(func, LambdaType) and func.__name__ == '<lambda>':
        return True

    # Callables without a qualified name at all.
    try:
        qualified_name = func.__qualname__
    except AttributeError:
        return True

    # Functions nested inside another function's scope.
    return '<locals>' in qualified_name
| 95 | + |
| 96 | + |
def _globalize_function(func):
    """
    Wrap `func` in a forwarding function that is registered as a module-level
    attribute under a freshly generated unique name.

    The wrapper's `__name__` and `__qualname__` are both set to a random
    `uuid4` hex string, and the wrapper is stored under that name on the
    module it was defined in. Presumably this is what lets lambdas, partials
    and local functions be looked up by name when sent to worker processes.

    :param func: the callable to wrap
    :return: the registered wrapper, which forwards all arguments to `func`
    """
    import sys

    def forwarder(*args, **kwargs):
        return func(*args, **kwargs)

    unique_name = uuid.uuid4().hex
    forwarder.__name__ = unique_name
    forwarder.__qualname__ = unique_name

    # Expose the wrapper as an attribute of its own defining module.
    owner_module = sys.modules[forwarder.__module__]
    setattr(owner_module, unique_name, forwarder)
    return forwarder
0 commit comments