
Commit 94a479e

Author: Joan Fontanals
fix: fix search in memory with AnyEmbedding (#1696)
Signed-off-by: Joan Fontanals Martinez <[email protected]>
1 parent 0a1da30 commit 94a479e
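
In a nutshell: documents such as `TextDoc` type their embedding as `AnyEmbedding`, a Union, so the concrete tensor class cannot be read off the schema; after this fix it is inferred from the stored tensors at search time. A minimal sketch of the scenario, assuming docarray with numpy only (the document texts and sizes are made up for illustration):

import numpy as np

from docarray import DocList
from docarray.documents import TextDoc  # TextDoc.embedding is typed as AnyEmbedding (a Union)
from docarray.index.backends.in_memory import InMemoryExactNNIndex

# Index a few TextDocs; the backend for the embedding field is picked from the
# actual numpy arrays, not from the union type declared on the schema.
index = InMemoryExactNNIndex[TextDoc]()
docs = DocList[TextDoc](
    [TextDoc(text=f'doc {i}', embedding=np.random.rand(128)) for i in range(50)]
)
index.index(docs)

query = TextDoc(text='query', embedding=np.random.rand(128))
matches, scores = index.find(query, search_field='embedding', limit=5)
assert len(matches) == 5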

File tree: 6 files changed (+161, -64 lines)


docarray/array/list_advance_indexing.py

Lines changed: 26 additions & 13 deletions
@@ -1,5 +1,4 @@
 from typing import (
-    TYPE_CHECKING,
     Any,
     Iterable,
     List,
@@ -14,7 +13,18 @@
 import numpy as np
 from typing_extensions import SupportsIndex
 
-from docarray.utils._internal.misc import import_library
+from docarray.utils._internal.misc import (
+    is_torch_available,
+    is_tf_available,
+)
+
+torch_available = is_torch_available()
+if torch_available:
+    import torch
+tf_available = is_tf_available()
+if tf_available:
+    import tensorflow as tf  # type: ignore
+    from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor
 
 T_item = TypeVar('T_item')
 T = TypeVar('T', bound='ListAdvancedIndexing')
@@ -75,17 +85,20 @@ def _normalize_index_item(
         return item.tolist()
 
     # torch index types
-    if TYPE_CHECKING:
-        import torch
-    else:
-        torch = import_library('torch', raise_error=True)
-
-    allowed_torch_dtypes = [
-        torch.bool,
-        torch.int64,
-    ]
-    if isinstance(item, torch.Tensor) and (item.dtype in allowed_torch_dtypes):
-        return item.tolist()
+    if torch_available:
+
+        allowed_torch_dtypes = [
+            torch.bool,
+            torch.int64,
+        ]
+        if isinstance(item, torch.Tensor) and (item.dtype in allowed_torch_dtypes):
+            return item.tolist()
+
+    if tf_available:
+        if isinstance(item, tf.Tensor):
+            return item.numpy().tolist()
+        if isinstance(item, TensorFlowTensor):
+            return item.tensor.numpy().tolist()
 
     return item
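
For context, `_normalize_index_item` is what lets a `DocList` be indexed with tensor masks and index arrays; the change above makes torch strictly optional and adds TensorFlow tensors. A rough usage sketch, assuming torch is installed (the `Item` schema and the mask values are illustrative):

import torch

from docarray import BaseDoc, DocList
from docarray.typing import TorchTensor


class Item(BaseDoc):
    tensor: TorchTensor


docs = DocList[Item]([Item(tensor=torch.rand(4)) for _ in range(3)])

# Boolean and int64 index tensors are normalized to plain Python lists first.
subset = docs[torch.tensor([True, False, True])]
picked = docs[torch.tensor([0, 2], dtype=torch.int64)]
assert len(subset) == 2 and len(picked) == 2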

docarray/computation/abstract_comp_backend.py

Lines changed: 2 additions & 2 deletions
@@ -1,13 +1,13 @@
 import typing
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypeVar, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypeVar, Union, Iterable
 
 if TYPE_CHECKING:
     import numpy as np
 
 # In practice all of the below will be the same type
 TTensor = TypeVar('TTensor')
-TTensorRetrieval = TypeVar('TTensorRetrieval')
+TTensorRetrieval = TypeVar('TTensorRetrieval', bound=Iterable)
 TTensorMetrics = TypeVar('TTensorMetrics')
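
The `bound=Iterable` constraint just promises type checkers that whatever concrete retrieval type a backend substitutes can be iterated over. A minimal, generic illustration of a bounded `TypeVar` (not a docarray API, names here are made up):

from typing import Iterable, TypeVar

TRetrieval = TypeVar('TRetrieval', bound=Iterable)


def first_of(retrieval: TRetrieval):
    # The bound guarantees iteration is valid for any admissible substitution.
    return next(iter(retrieval))


assert first_of([3, 1, 2]) == 3  # lists satisfy the Iterable bound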

docarray/index/backends/in_memory.py

Lines changed: 1 addition & 5 deletions
@@ -33,7 +33,6 @@
 from docarray.utils.find import (
     FindResult,
     FindResultBatched,
-    _da_attr_type,
     _extract_embeddings,
     _FindResult,
     _FindResultBatched,
@@ -196,10 +195,7 @@ def _rebuild_embedding(self):
             self._embedding_map = dict()
         else:
             for field_, embedding in self._embedding_map.items():
-                embedding_type = _da_attr_type(self._docs, field_)
-                self._embedding_map[field_] = _extract_embeddings(
-                    self._docs, field_, embedding_type
-                )
+                self._embedding_map[field_] = _extract_embeddings(self._docs, field_)
 
     def _del_items(self, doc_ids: Sequence[str]):
         """Delete Documents from the index.

docarray/typing/tensor/tensorflow_tensor.py

Lines changed: 4 additions & 0 deletions
@@ -341,3 +341,7 @@ def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T:
     def _docarray_to_ndarray(self) -> np.ndarray:
         """cast itself to a numpy array"""
         return self.tensor.numpy()
+
+    @property
+    def shape(self):
+        return tf.shape(self.tensor)
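
`TensorFlowTensor` wraps the native tensor in its `.tensor` attribute rather than subclassing `tf.Tensor`, so generic code that inspects `emb.shape` (as `find` now does) needs this property. A small sketch, assuming TensorFlow is installed; `MyDoc` is only an illustrative schema:

import tensorflow as tf

from docarray import BaseDoc
from docarray.typing import TensorFlowTensor


class MyDoc(BaseDoc):
    emb: TensorFlowTensor


doc = MyDoc(emb=tf.zeros((3, 4)))
print(type(doc.emb.tensor))      # the wrapped native tf.Tensor
print(doc.emb.shape)             # exposed on the wrapper via tf.shape(...)
assert len(doc.emb.shape) == 2   # rank check of the kind find() performs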

docarray/utils/find.py

Lines changed: 75 additions & 44 deletions
@@ -1,17 +1,46 @@
 __all__ = ['find', 'find_batched']
 
-from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, Union, cast
-
-from typing_inspect import is_union_type
+from typing import (
+    Any,
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+    Type,
+    TYPE_CHECKING,
+)
 
 from docarray.array.any_array import AnyDocArray
 from docarray.array.doc_list.doc_list import DocList
 from docarray.array.doc_vec.doc_vec import DocVec
 from docarray.base_doc import BaseDoc
-from docarray.helper import _get_field_type_by_access_path
 from docarray.typing import AnyTensor
-from docarray.typing.tensor.abstract_tensor import AbstractTensor
-from docarray.utils._internal._typing import safe_issubclass
+from docarray.computation.numpy_backend import NumpyCompBackend
+from docarray.typing.tensor import NdArray
+from docarray.utils._internal.misc import is_tf_available, is_torch_available  # noqa
+
+torch_available = is_torch_available()
+if torch_available:
+    import torch
+
+    from docarray.typing.tensor.torch_tensor import TorchTensor  # noqa: F401
+    from docarray.computation.torch_backend import TorchCompBackend
+
+tf_available = is_tf_available()
+if tf_available:
+    import tensorflow as tf  # type: ignore
+
+    from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor  # noqa: F401
+    from docarray.computation.tensorflow_backend import TensorFlowCompBackend
+
+if TYPE_CHECKING:
+    from docarray.computation.abstract_numpy_based_backend import (
+        AbstractComputationalBackend,
+    )
+    from docarray.typing.tensor.abstract_tensor import AbstractTensor
 
 
 class FindResult(NamedTuple):
@@ -108,6 +137,7 @@ class MyDocument(BaseDoc):
         can be either `cpu` or a `cuda` device.
     :param descending: sort the results in descending order.
         Per default, this is chosen based on the `metric` argument.
+    :param cache: Precomputed data storing the valid index data per search field together with the valid indexes to account for deleted entries.
     :return: A named tuple of the form (DocList, AnyTensor),
         where the first element contains the closes matches for the query,
         and the second element contains the corresponding scores.
@@ -199,29 +229,26 @@ class MyDocument(BaseDoc):
         can be either `cpu` or a `cuda` device.
     :param descending: sort the results in descending order.
         Per default, this is chosen based on the `metric` argument.
+    :param cache: Precomputed data storing the valid index data per search field together with the valid indexes to account for deleted entries.
    :return: A named tuple of the form (DocList, AnyTensor),
        where the first element contains the closest matches for each query,
        and the second element contains the corresponding scores.
    """
    if descending is None:
        descending = metric.endswith('_sim')  # similarity metrics are descending
 
-    embedding_type = _da_attr_type(index, search_field)
-    comp_backend = embedding_type.get_comp_backend()
-
     # extract embeddings from query and index
     if cache is not None and search_field in cache:
         index_embeddings, valid_idx = cache[search_field]
     else:
-        index_embeddings, valid_idx = _extract_embeddings(
-            index, search_field, embedding_type
-        )
+        index_embeddings, valid_idx = _extract_embeddings(index, search_field)
         if cache is not None:
             cache[search_field] = (
                 index_embeddings,
                 valid_idx,
             )  # cache embedding for next query
-    query_embeddings, _ = _extract_embeddings(query, search_field, embedding_type)
+    query_embeddings, _ = _extract_embeddings(query, search_field)
+    _, comp_backend = _get_tensor_type_and_comp_backend_from_tensor(index_embeddings)
 
     # compute distances and return top results
     metric_fn = getattr(comp_backend.Metrics, metric)
@@ -267,60 +294,64 @@ def _extract_embedding_single(
     return emb
 
 
+def _get_tensor_type_and_comp_backend_from_tensor(
+    tensor,
+) -> Tuple[Type['AbstractTensor'], 'AbstractComputationalBackend']:
+    """Extract the embeddings from the data.
+
+    :param tensor: the tensor for which to extract
+    :return: a tuple of the tensor type and the computational backend
+    """
+    da_tensor_type: Type['AbstractTensor'] = NdArray
+    comp_backend: 'AbstractComputationalBackend' = NumpyCompBackend()
+    if torch_available and isinstance(tensor, (TorchTensor, torch.Tensor)):
+        comp_backend = TorchCompBackend()
+        da_tensor_type = TorchTensor
+    elif tf_available and isinstance(tensor, (TensorFlowTensor, tf.Tensor)):
+        comp_backend = TensorFlowCompBackend()
+        da_tensor_type = TensorFlowTensor
+
+    return da_tensor_type, comp_backend
+
+
 def _extract_embeddings(
     data: Union[AnyDocArray, BaseDoc, AnyTensor],
     search_field: str,
-    embedding_type: Type,
 ) -> Tuple[AnyTensor, Optional[List[int]]]:
     """Extract the embeddings from the data.
 
     :param data: the data
     :param search_field: the embedding field
-    :param embedding_type: type of the embedding: torch.Tensor, numpy.ndarray etc.
     :return: a tuple of the embeddings and optionally a list of the non-null indices
     """
     emb: AnyTensor
     valid_idx = None
+    comp_backend = None
+    da_tensor_type = None
     if isinstance(data, DocList):
         emb_valid = [
             (emb, i)
             for i, emb in enumerate(AnyDocArray._traverse(data, search_field))
            if emb is not None
        ]
        emb_list, valid_idx = zip(*emb_valid)
-        emb = embedding_type._docarray_stack(emb_list)
+        if len(emb_list) > 0:
+            (
+                da_tensor_type,
+                comp_backend,
+            ) = _get_tensor_type_and_comp_backend_from_tensor(emb_list[0])
+        else:
+            raise Exception(f'No embedding could be extracted from data {data}')
+
+        emb = da_tensor_type._docarray_stack(emb_list)
     elif isinstance(data, (DocVec, BaseDoc)):
         emb = next(AnyDocArray._traverse(data, search_field))
     else:  # treat data as tensor
         emb = cast(AnyTensor, data)
 
+    if comp_backend is None:
+        _, comp_backend = _get_tensor_type_and_comp_backend_from_tensor(emb)
+
     if len(emb.shape) == 1:
-        emb = emb.get_comp_backend().reshape(array=emb, shape=(1, -1))
+        emb = comp_backend.reshape(tensor=emb, shape=(1, -1))
     return emb, valid_idx
-
-
-def _da_attr_type(docs: AnyDocArray, access_path: str) -> Type[AnyTensor]:
-    """Get the type of the attribute according to the Document type
-    (schema) of the DocList.
-
-    :param docs: the DocList
-    :param access_path: the "__"-separated access path
-    :return: the type of the attribute
-    """
-    field_type: Optional[Type] = _get_field_type_by_access_path(
-        docs.doc_type, access_path
-    )
-    if field_type is None:
-        raise ValueError(f"Access path is not valid: {access_path}")
-
-    if is_union_type(field_type):
-        # determine type based on the fist element
-        field_type = type(next(AnyDocArray._traverse(docs[0], access_path)))
-
-    if not safe_issubclass(field_type, AbstractTensor):
-        raise ValueError(
-            f'attribute {access_path} is not a tensor-like type, '
-            f'but {field_type.__class__.__name__}'
-        )
-
-    return cast(Type[AnyTensor], field_type)
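
With `_da_attr_type` removed, `find` selects the computational backend by inspecting the first non-null embedding it extracts, so union-typed fields such as `AnyEmbedding` work no matter which framework produced the tensors. A hedged sketch of calling the public `find` helper, assuming torch is installed; `MyDoc` is an illustrative schema:

import torch

from docarray import BaseDoc, DocList
from docarray.typing import AnyEmbedding
from docarray.utils.find import find


class MyDoc(BaseDoc):
    embedding: AnyEmbedding  # a Union: the schema alone does not fix the backend


index = DocList[MyDoc]([MyDoc(embedding=torch.rand(16)) for _ in range(100)])
query = MyDoc(embedding=torch.rand(16))

# The torch backend is chosen from the stored tensors, not from the declared type.
top_matches, scores = find(index, query, search_field='embedding', limit=5)
assert len(top_matches) == 5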

tests/index/in_memory/test_in_memory.py

Lines changed: 53 additions & 0 deletions
@@ -6,9 +6,17 @@
 from torch import rand
 
 from docarray import BaseDoc, DocList
+from docarray.documents import TextDoc
 from docarray.index.backends.in_memory import InMemoryExactNNIndex
 from docarray.typing import NdArray, TorchTensor
 
+from docarray.utils._internal.misc import is_tf_available
+
+tf_available = is_tf_available()
+if tf_available:
+    import tensorflow as tf
+    from docarray.typing import TensorFlowTensor
+
 
 class SchemaDoc(BaseDoc):
     text: str
@@ -112,6 +120,51 @@ class MyDoc(BaseDoc):
     assert len(scores) == 0
 
 
+def test_with_text_doc_ndarray():
+    index = InMemoryExactNNIndex[TextDoc]()
+
+    docs = DocList[TextDoc](
+        [TextDoc(text='hey', embedding=np.random.rand(128)) for _ in range(200)]
+    )
+    index.index(docs)
+    res = index.find_batched(docs[0:10], search_field='embedding', limit=5)
+    assert len(res.documents) == 10
+    for r in res.documents:
+        assert len(r) == 5
+
+
+@pytest.mark.tensorflow
+def test_with_text_doc_tensorflow():
+    index = InMemoryExactNNIndex[TextDoc]()
+
+    docs = DocList[TextDoc](
+        [
+            TextDoc(text='hey', embedding=tf.random.uniform(shape=[128]))
+            for _ in range(200)
+        ]
+    )
+    index.index(docs)
+    res = index.find_batched(docs[0:10], search_field='embedding', limit=5)
+    assert len(res.documents) == 10
+    for r in res.documents:
+        assert len(r) == 5
+
+
+def test_with_text_doc_torch():
+    import torch
+
+    index = InMemoryExactNNIndex[TextDoc]()
+
+    docs = DocList[TextDoc](
+        [TextDoc(text='hey', embedding=torch.rand(128)) for _ in range(200)]
+    )
+    index.index(docs)
+    res = index.find_batched(docs[0:10], search_field='embedding', limit=5)
+    assert len(res.documents) == 10
+    for r in res.documents:
+        assert len(r) == 5
+
+
 def test_concatenated_queries(doc_index):
     query = SchemaDoc(text='query', price=0, tensor=np.ones(10))