Skip to content

Commit a41e3f7

Browse files
author
Joan Fontanals Martinez
committed
fix: try to pass all mypy checks
Signed-off-by: Joan Fontanals Martinez <[email protected]>
1 parent e7315e8 commit a41e3f7

File tree

4 files changed

+80
-58
lines changed

4 files changed

+80
-58
lines changed

docarray/computation/abstract_comp_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import typing
22
from abc import ABC, abstractmethod
3-
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypeVar, Union
3+
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypeVar, Union, Iterable
44

55
if TYPE_CHECKING:
66
import numpy as np
77

88
# In practice all of the below will be the same type
99
TTensor = TypeVar('TTensor')
10-
TTensorRetrieval = TypeVar('TTensorRetrieval')
10+
TTensorRetrieval = TypeVar('TTensorRetrieval', bound=Iterable)
1111
TTensorMetrics = TypeVar('TTensorMetrics')
1212

1313

docarray/index/backends/in_memory.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from docarray.utils.find import (
3434
FindResult,
3535
FindResultBatched,
36-
_da_attr_type,
3736
_extract_embeddings,
3837
_FindResult,
3938
_FindResultBatched,
@@ -196,10 +195,7 @@ def _rebuild_embedding(self):
196195
self._embedding_map = dict()
197196
else:
198197
for field_, embedding in self._embedding_map.items():
199-
embedding_type = _da_attr_type(self._docs, field_)
200-
self._embedding_map[field_] = _extract_embeddings(
201-
self._docs, field_, embedding_type
202-
)
198+
self._embedding_map[field_] = _extract_embeddings(self._docs, field_)
203199

204200
def _del_items(self, doc_ids: Sequence[str]):
205201
"""Delete Documents from the index.

docarray/utils/find.py

Lines changed: 75 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,46 @@
11
__all__ = ['find', 'find_batched']
22

3-
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, Union, cast
4-
5-
from typing_inspect import is_union_type
3+
from typing import (
4+
Any,
5+
Dict,
6+
List,
7+
NamedTuple,
8+
Optional,
9+
Tuple,
10+
Union,
11+
cast,
12+
Type,
13+
TYPE_CHECKING,
14+
)
615

716
from docarray.array.any_array import AnyDocArray
817
from docarray.array.doc_list.doc_list import DocList
918
from docarray.array.doc_vec.doc_vec import DocVec
1019
from docarray.base_doc import BaseDoc
11-
from docarray.helper import _get_field_type_by_access_path
1220
from docarray.typing import AnyTensor
13-
from docarray.typing.tensor.abstract_tensor import AbstractTensor
14-
from docarray.typing.tensor.embedding import AnyEmbedding
15-
from docarray.typing.tensor.embedding.ndarray import NdArrayEmbedding
16-
from docarray.utils._internal._typing import safe_issubclass
21+
from docarray.computation.numpy_backend import NumpyCompBackend
22+
from docarray.typing.tensor import NdArray
23+
from docarray.utils._internal.misc import is_tf_available, is_torch_available # noqa
24+
25+
torch_available = is_torch_available()
26+
if torch_available:
27+
import torch
28+
29+
from docarray.typing.tensor.torch_tensor import TorchTensor # noqa: F401
30+
from docarray.computation.torch_backend import TorchCompBackend
31+
32+
tf_available = is_tf_available()
33+
if tf_available:
34+
import tensorflow as tf # type: ignore
35+
36+
from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor # noqa: F401
37+
from docarray.computation.tensorflow_backend import TensorFlowCompBackend
38+
39+
if TYPE_CHECKING:
40+
from docarray.computation.abstract_numpy_based_backend import (
41+
AbstractComputationalBackend,
42+
)
43+
from docarray.typing.tensor.abstract_tensor import AbstractTensor
1744

1845

1946
class FindResult(NamedTuple):
@@ -110,6 +137,7 @@ class MyDocument(BaseDoc):
110137
can be either `cpu` or a `cuda` device.
111138
:param descending: sort the results in descending order.
112139
Per default, this is chosen based on the `metric` argument.
140+
:param cache: Precomputed data storing the valid index data per search field together with the valid indexes to account for deleted entries.
113141
:return: A named tuple of the form (DocList, AnyTensor),
114142
where the first element contains the closest matches for the query,
115143
and the second element contains the corresponding scores.
@@ -201,32 +229,26 @@ class MyDocument(BaseDoc):
201229
can be either `cpu` or a `cuda` device.
202230
:param descending: sort the results in descending order.
203231
Per default, this is chosen based on the `metric` argument.
232+
:param cache: Precomputed data storing the valid index data per search field together with the valid indexes to account for deleted entries.
204233
:return: A named tuple of the form (DocList, AnyTensor),
205234
where the first element contains the closest matches for each query,
206235
and the second element contains the corresponding scores.
207236
"""
208237
if descending is None:
209238
descending = metric.endswith('_sim') # similarity metrics are descending
210239

211-
embedding_type = _da_attr_type(index, search_field)
212-
if embedding_type == AnyEmbedding:
213-
embedding_type = NdArrayEmbedding
214-
215-
comp_backend = embedding_type.get_comp_backend()
216-
217240
# extract embeddings from query and index
218241
if cache is not None and search_field in cache:
219242
index_embeddings, valid_idx = cache[search_field]
220243
else:
221-
index_embeddings, valid_idx = _extract_embeddings(
222-
index, search_field, embedding_type
223-
)
244+
index_embeddings, valid_idx = _extract_embeddings(index, search_field)
224245
if cache is not None:
225246
cache[search_field] = (
226247
index_embeddings,
227248
valid_idx,
228249
) # cache embedding for next query
229-
query_embeddings, _ = _extract_embeddings(query, search_field, embedding_type)
250+
query_embeddings, _ = _extract_embeddings(query, search_field)
251+
_, comp_backend = _get_tensor_type_and_comp_backend_from_tensor(index_embeddings)
230252

231253
# compute distances and return top results
232254
metric_fn = getattr(comp_backend.Metrics, metric)
@@ -272,60 +294,64 @@ def _extract_embedding_single(
272294
return emb
273295

274296

297+
def _get_tensor_type_and_comp_backend_from_tensor(
298+
tensor,
299+
) -> Tuple[Type['AbstractTensor'], 'AbstractComputationalBackend']:
300+
"""Determine the tensor type and the computational backend matching a tensor.
301+
302+
:param tensor: the tensor whose type and computational backend should be determined
303+
:return: a tuple of the tensor type and the computational backend
304+
"""
305+
da_tensor_type: Type['AbstractTensor'] = NdArray
306+
comp_backend: 'AbstractComputationalBackend' = NumpyCompBackend()
307+
if torch_available and isinstance(tensor, (TorchTensor, torch.Tensor)):
308+
comp_backend = TorchCompBackend()
309+
da_tensor_type = TorchTensor
310+
elif tf_available and isinstance(tensor, (TensorFlowTensor, tf.Tensor)):
311+
comp_backend = TensorFlowCompBackend()
312+
da_tensor_type = TensorFlowTensor
313+
314+
return da_tensor_type, comp_backend
315+
316+
275317
def _extract_embeddings(
276318
data: Union[AnyDocArray, BaseDoc, AnyTensor],
277319
search_field: str,
278-
embedding_type: Type,
279320
) -> Tuple[AnyTensor, Optional[List[int]]]:
280321
"""Extract the embeddings from the data.
281322
282323
:param data: the data
283324
:param search_field: the embedding field
284-
:param embedding_type: type of the embedding: torch.Tensor, numpy.ndarray etc.
285325
:return: a tuple of the embeddings and optionally a list of the non-null indices
286326
"""
287327
emb: AnyTensor
288328
valid_idx = None
329+
comp_backend = None
330+
da_tensor_type = None
289331
if isinstance(data, DocList):
290332
emb_valid = [
291333
(emb, i)
292334
for i, emb in enumerate(AnyDocArray._traverse(data, search_field))
293335
if emb is not None
294336
]
295337
emb_list, valid_idx = zip(*emb_valid)
296-
emb = embedding_type._docarray_stack(emb_list)
338+
if len(emb_list > 0):
339+
(
340+
da_tensor_type,
341+
comp_backend,
342+
) = _get_tensor_type_and_comp_backend_from_tensor(emb_list[0])
343+
else:
344+
raise Exception(f'No embedding could be extracted from data {data}')
345+
346+
emb = da_tensor_type._docarray_stack(emb_list)
297347
elif isinstance(data, (DocVec, BaseDoc)):
298348
emb = next(AnyDocArray._traverse(data, search_field))
299349
else: # treat data as tensor
300350
emb = cast(AnyTensor, data)
301351

352+
if comp_backend is None:
353+
_, comp_backend = _get_tensor_type_and_comp_backend_from_tensor(emb)
354+
302355
if len(emb.shape) == 1:
303-
emb = emb.get_comp_backend().reshape(array=emb, shape=(1, -1))
356+
emb = comp_backend.reshape(tensor=emb, shape=(1, -1))
304357
return emb, valid_idx
305-
306-
307-
def _da_attr_type(docs: AnyDocArray, access_path: str) -> Type[AnyTensor]:
308-
"""Get the type of the attribute according to the Document type
309-
(schema) of the DocList.
310-
311-
:param docs: the DocList
312-
:param access_path: the "__"-separated access path
313-
:return: the type of the attribute
314-
"""
315-
field_type: Optional[Type] = _get_field_type_by_access_path(
316-
docs.doc_type, access_path
317-
)
318-
if field_type is None:
319-
raise ValueError(f"Access path is not valid: {access_path}")
320-
321-
if is_union_type(field_type):
322-
# determine type based on the fist element
323-
field_type = type(next(AnyDocArray._traverse(docs[0], access_path)))
324-
325-
if not safe_issubclass(field_type, AbstractTensor):
326-
raise ValueError(
327-
f'attribute {access_path} is not a tensor-like type, '
328-
f'but {field_type.__class__.__name__}'
329-
)
330-
331-
return cast(Type[AnyTensor], field_type)

tests/index/weaviate/docker-compose.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
version: '3.8'
1+
version: '3.3'
22

33
services:
44

@@ -24,4 +24,4 @@ services:
2424
LOG_LEVEL: debug # verbose
2525
LOG_FORMAT: text
2626
# LOG_LEVEL: trace # very verbose
27-
GODEBUG: gctrace=1 # make go garbage collector verbose
27+
GODEBUG: gctrace=1 # make go garbage collector verbose

0 commit comments

Comments
 (0)