Skip to content

Commit 6cd5912

Browse files
author
Charlotte Gerhaher
authored
feat: add map function (#1187)
* feat: add apply function Signed-off-by: anna-charlotte <[email protected]> * test: add benchmark tests Signed-off-by: anna-charlotte <[email protected]> * fix: apply Signed-off-by: anna-charlotte <[email protected]> * fix: benchmark test Signed-off-by: anna-charlotte <[email protected]> * test: benchmark Signed-off-by: anna-charlotte <[email protected]> * fix: apply Signed-off-by: anna-charlotte <[email protected]> * fix: clean up Signed-off-by: anna-charlotte <[email protected]> * chore: remove benchmark tests from general tests Signed-off-by: anna-charlotte <[email protected]> * chore: fix ci Signed-off-by: anna-charlotte <[email protected]> * feat: add threading option and benchmark test Signed-off-by: anna-charlotte <[email protected]> * test: use both backend options in tests Signed-off-by: anna-charlotte <[email protected]> * feat: add batching to abstract array Signed-off-by: anna-charlotte <[email protected]> * feat: add apply_batch and _map_batch and tests Signed-off-by: anna-charlotte <[email protected]> * test: fix load from da Signed-off-by: anna-charlotte <[email protected]> * docs: update docstrings Signed-off-by: anna-charlotte <[email protected]> * docs: add example for apply Signed-off-by: anna-charlotte <[email protected]> * fix: mypy Signed-off-by: anna-charlotte <[email protected]> * refactor: clean up Signed-off-by: anna-charlotte <[email protected]> * refactor: make batch method private Signed-off-by: anna-charlotte <[email protected]> * fix: apply Signed-off-by: anna-charlotte <[email protected]> * Test: add for apply batch Signed-off-by: anna-charlotte <[email protected]> * fix: benchmark test increase ndocs Signed-off-by: anna-charlotte <[email protected]> * test: clean up Signed-off-by: anna-charlotte <[email protected]> * test: try to fix Signed-off-by: anna-charlotte <[email protected]> * test: try to fix test Signed-off-by: anna-charlotte <[email protected]> * fix: test Signed-off-by: anna-charlotte <[email protected]> * fix: test Signed-off-by: anna-charlotte <[email protected]> * fix: apply suggestions from code review Signed-off-by: anna-charlotte <[email protected]> * fix: remove print statemetns Signed-off-by: anna-charlotte <[email protected]> * fix: apply samis suggestion Signed-off-by: anna-charlotte <[email protected]> * fix: add tests for func da to doc and da to other len da Signed-off-by: anna-charlotte <[email protected]> * fix: revert last commit Signed-off-by: anna-charlotte <[email protected]> * test: add len assert Signed-off-by: anna-charlotte <[email protected]> * test: add assertions Signed-off-by: anna-charlotte <[email protected]> * test: add test to for da extend in batch apply Signed-off-by: anna-charlotte <[email protected]> * test: extend with only one doc Signed-off-by: anna-charlotte <[email protected]> * test: fix Signed-off-by: anna-charlotte <[email protected]> * fix: test Signed-off-by: anna-charlotte <[email protected]> * fix: test Signed-off-by: anna-charlotte <[email protected]> * fix: set docs in apply Signed-off-by: anna-charlotte <[email protected]> * fix: indices Signed-off-by: anna-charlotte <[email protected]> * fix: indices Signed-off-by: anna-charlotte <[email protected]> * fix: indices Signed-off-by: anna-charlotte <[email protected]> * fix: indices Signed-off-by: anna-charlotte <[email protected]> * fix:test Signed-off-by: anna-charlotte <[email protected]> * fix: mypy Signed-off-by: anna-charlotte <[email protected]> * fix: type hint Signed-off-by: anna-charlotte <[email protected]> * fix: remove apply, only keep map Signed-off-by: anna-charlotte <[email protected]> * refactor: map to map_docs Signed-off-by: anna-charlotte <[email protected]> * fix: apply suggestion Signed-off-by: anna-charlotte <[email protected]> * docs: add example usage Signed-off-by: anna-charlotte <[email protected]> --------- Signed-off-by: anna-charlotte <[email protected]>
1 parent 082a39c commit 6cd5912

File tree

11 files changed

+585
-10
lines changed

11 files changed

+585
-10
lines changed

.github/workflows/ci.yml

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ jobs:
117117
- name: Test
118118
id: test
119119
run: |
120-
poetry run pytest -m "not tensorflow" ${{ matrix.test-path }}
120+
poetry run pytest -m "not (tensorflow or benchmark)" ${{ matrix.test-path }}
121121
timeout-minutes: 30
122122
# env:
123123
# JINA_AUTH_TOKEN: "${{ secrets.JINA_AUTH_TOKEN }}"
@@ -162,7 +162,7 @@ jobs:
162162
- name: Test
163163
id: test
164164
run: |
165-
poetry run pytest -m "not tensorflow" ${{ matrix.test-path }}
165+
poetry run pytest -m "not (tensorflow or benchmark)" ${{ matrix.test-path }}
166166
timeout-minutes: 30
167167

168168

@@ -222,10 +222,35 @@ jobs:
222222
poetry run pytest -m 'tensorflow' tests
223223
timeout-minutes: 30
224224

225+
docarray-test-benchmarks:
226+
needs: [lint-ruff, check-black, import-test]
227+
runs-on: ubuntu-latest
228+
strategy:
229+
fail-fast: false
230+
matrix:
231+
python-version: [3.7]
232+
steps:
233+
- uses: actions/[email protected]
234+
- name: Set up Python ${{ matrix.python-version }}
235+
uses: actions/setup-python@v4
236+
with:
237+
python-version: ${{ matrix.python-version }}
238+
- name: Prepare environment
239+
run: |
240+
python -m pip install --upgrade pip
241+
python -m pip install poetry
242+
poetry install --all-extras
243+
244+
- name: Test
245+
id: test
246+
run: |
247+
poetry run pytest -m 'benchmark' tests
248+
timeout-minutes: 30
249+
225250

226251
# just for blocking the merge until all parallel core-test are successful
227252
success-all-test:
228-
needs: [docarray-test, docarray-test-proto3, docarray-test-tensorflow, import-test, check-black, check-mypy, lint-ruff]
253+
needs: [docarray-test, docarray-test-proto3, docarray-test-tensorflow, docarray-test-benchmarks, import-test, check-black, check-mypy, lint-ruff]
229254
if: always()
230255
runs-on: ubuntu-latest
231256
steps:

docarray/array/abstract_array.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
1+
import random
12
from abc import abstractmethod
23
from typing import (
34
TYPE_CHECKING,
45
Any,
56
Dict,
7+
Generator,
68
Generic,
9+
Iterable,
710
List,
811
Sequence,
912
Type,
1013
TypeVar,
1114
Union,
1215
cast,
16+
overload,
1317
)
1418

19+
import numpy as np
20+
1521
from docarray.base_document import BaseDocument
1622
from docarray.display.document_array_summary import DocumentArraySummary
1723
from docarray.typing import NdArray
@@ -24,6 +30,7 @@
2430

2531
T = TypeVar('T', bound='AnyDocumentArray')
2632
T_doc = TypeVar('T_doc', bound=BaseDocument)
33+
IndexIterType = Union[slice, Iterable[int], Iterable[bool], None]
2734

2835

2936
class AnyDocumentArray(Sequence[T_doc], Generic[T_doc], AbstractType):
@@ -79,6 +86,30 @@ def _setter(self, value):
7986

8087
return cls.__typed_da__[cls][item]
8188

89+
@overload
90+
def __getitem__(self: T, item: int) -> T_doc:
91+
...
92+
93+
@overload
94+
def __getitem__(self: T, item: IndexIterType) -> T:
95+
...
96+
97+
@abstractmethod
98+
def __getitem__(self, item: Union[int, IndexIterType]) -> Union[T_doc, T]:
99+
...
100+
101+
@overload
102+
def __setitem__(self: T, key: int, value: T_doc):
103+
...
104+
105+
@overload
106+
def __setitem__(self: T, key: IndexIterType, value: T):
107+
...
108+
109+
@abstractmethod
110+
def __setitem__(self: T, key: Union[int, IndexIterType], value: Union[T, T_doc]):
111+
...
112+
82113
@abstractmethod
83114
def _get_array_attribute(
84115
self: T,
@@ -249,3 +280,39 @@ def summary(self):
249280
Document type.
250281
"""
251282
DocumentArraySummary(self).summary()
283+
284+
def _batch(
285+
self: T,
286+
batch_size: int,
287+
shuffle: bool = False,
288+
show_progress: bool = False,
289+
) -> Generator[T, None, None]:
290+
"""
291+
Creates a `Generator` that yields `DocumentArray` of size `batch_size`.
292+
Note, that the last batch might be smaller than `batch_size`.
293+
294+
:param batch_size: Size of each generated batch.
295+
:param shuffle: If set, shuffle the Documents before dividing into minibatches.
296+
:param show_progress: if set, show a progress bar when batching documents.
297+
:yield: a Generator of `DocumentArray`, each in the length of `batch_size`
298+
"""
299+
from rich.progress import track
300+
301+
if not (isinstance(batch_size, int) and batch_size > 0):
302+
raise ValueError(
303+
f'`batch_size` should be a positive integer, received: {batch_size}'
304+
)
305+
306+
N = len(self)
307+
indices = list(range(N))
308+
n_batches = int(np.ceil(N / batch_size))
309+
310+
if shuffle:
311+
random.shuffle(indices)
312+
313+
for i in track(
314+
range(n_batches),
315+
description='Batching documents',
316+
disable=not show_progress,
317+
):
318+
yield self[indices[i * batch_size : (i + 1) * batch_size]]

docarray/array/array/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,11 @@ def __getitem__(self, item):
177177
raise TypeError(f'Invalid type {type(head)} for indexing')
178178

179179
@overload
180-
def __setitem__(self: T, key: IndexIterType, value: T):
180+
def __setitem__(self: T, key: int, value: T_doc):
181181
...
182182

183183
@overload
184-
def __setitem__(self: T, key: int, value: T_doc):
184+
def __setitem__(self: T, key: IndexIterType, value: T):
185185
...
186186

187187
def __setitem__(self: T, key: Union[int, IndexIterType], value: Union[T, T_doc]):

docarray/display/tensor_display.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,7 @@ def __rich_console__(
3030
from rich.segment import Segment
3131
from rich.style import Style
3232

33-
tensor_normalized = comp_be.minmax_normalize(
34-
comp_be.detach(self.tensor), (0, 5)
35-
)
33+
tensor_normalized = comp_be.minmax_normalize(t_squeezed, (0, 5))
3634

3735
hue = 0.75
3836
saturation = 1.0

docarray/helper.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
1+
from types import LambdaType
2+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type
23

34
if TYPE_CHECKING:
45
from docarray import BaseDocument
@@ -138,3 +139,14 @@ def _get_field_type_by_access_path(
138139
return None
139140
else:
140141
return None
142+
143+
144+
def _is_lambda_or_partial_or_local_function(func: Callable[[Any], Any]) -> bool:
145+
"""
146+
Return True if `func` is lambda, local or partial function, else False.
147+
"""
148+
return (
149+
(isinstance(func, LambdaType) and func.__name__ == '<lambda>')
150+
or not hasattr(func, '__qualname__')
151+
or ('<locals>' in func.__qualname__)
152+
)

0 commit comments

Comments
 (0)