Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions docarray/array/list_advance_indexing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import (
TYPE_CHECKING,
Any,
Iterable,
List,
Expand All @@ -14,7 +13,18 @@
import numpy as np
from typing_extensions import SupportsIndex

from docarray.utils._internal.misc import import_library
from docarray.utils._internal.misc import (
is_torch_available,
is_tf_available,
)

torch_available = is_torch_available()
if torch_available:
import torch
tf_available = is_tf_available()
if tf_available:
import tensorflow as tf # type: ignore
from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor

T_item = TypeVar('T_item')
T = TypeVar('T', bound='ListAdvancedIndexing')
Expand Down Expand Up @@ -75,17 +85,20 @@ def _normalize_index_item(
return item.tolist()

# torch index types
if TYPE_CHECKING:
import torch
else:
torch = import_library('torch', raise_error=True)

allowed_torch_dtypes = [
torch.bool,
torch.int64,
]
if isinstance(item, torch.Tensor) and (item.dtype in allowed_torch_dtypes):
return item.tolist()
if torch_available:

allowed_torch_dtypes = [
torch.bool,
torch.int64,
]
if isinstance(item, torch.Tensor) and (item.dtype in allowed_torch_dtypes):
return item.tolist()

if tf_available:
if isinstance(item, tf.Tensor):
return item.numpy().tolist()
if isinstance(item, TensorFlowTensor):
return item.tensor.numpy().tolist()

return item

Expand Down
4 changes: 2 additions & 2 deletions docarray/computation/abstract_comp_backend.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import typing
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypeVar, Union
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypeVar, Union, Iterable

if TYPE_CHECKING:
import numpy as np

# In practice all of the below will be the same type
TTensor = TypeVar('TTensor')
TTensorRetrieval = TypeVar('TTensorRetrieval')
TTensorRetrieval = TypeVar('TTensorRetrieval', bound=Iterable)
TTensorMetrics = TypeVar('TTensorMetrics')


Expand Down
6 changes: 1 addition & 5 deletions docarray/index/backends/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from docarray.utils.find import (
FindResult,
FindResultBatched,
_da_attr_type,
_extract_embeddings,
_FindResult,
_FindResultBatched,
Expand Down Expand Up @@ -196,10 +195,7 @@ def _rebuild_embedding(self):
self._embedding_map = dict()
else:
for field_, embedding in self._embedding_map.items():
embedding_type = _da_attr_type(self._docs, field_)
self._embedding_map[field_] = _extract_embeddings(
self._docs, field_, embedding_type
)
self._embedding_map[field_] = _extract_embeddings(self._docs, field_)

def _del_items(self, doc_ids: Sequence[str]):
"""Delete Documents from the index.
Expand Down
4 changes: 4 additions & 0 deletions docarray/typing/tensor/tensorflow_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,3 +341,7 @@ def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T:
def _docarray_to_ndarray(self) -> np.ndarray:
"""cast itself to a numpy array"""
return self.tensor.numpy()

@property
def shape(self):
    """Shape of the wrapped tensor, as computed by `tf.shape`."""
    wrapped = self.tensor
    return tf.shape(wrapped)
119 changes: 75 additions & 44 deletions docarray/utils/find.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,46 @@
__all__ = ['find', 'find_batched']

from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, Union, cast

from typing_inspect import is_union_type
from typing import (
Any,
Dict,
List,
NamedTuple,
Optional,
Tuple,
Union,
cast,
Type,
TYPE_CHECKING,
)

from docarray.array.any_array import AnyDocArray
from docarray.array.doc_list.doc_list import DocList
from docarray.array.doc_vec.doc_vec import DocVec
from docarray.base_doc import BaseDoc
from docarray.helper import _get_field_type_by_access_path
from docarray.typing import AnyTensor
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal._typing import safe_issubclass
from docarray.computation.numpy_backend import NumpyCompBackend
from docarray.typing.tensor import NdArray
from docarray.utils._internal.misc import is_tf_available, is_torch_available # noqa

torch_available = is_torch_available()
if torch_available:
import torch

from docarray.typing.tensor.torch_tensor import TorchTensor # noqa: F401
from docarray.computation.torch_backend import TorchCompBackend

tf_available = is_tf_available()
if tf_available:
import tensorflow as tf # type: ignore

from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor # noqa: F401
from docarray.computation.tensorflow_backend import TensorFlowCompBackend

if TYPE_CHECKING:
from docarray.computation.abstract_numpy_based_backend import (
AbstractComputationalBackend,
)
from docarray.typing.tensor.abstract_tensor import AbstractTensor


class FindResult(NamedTuple):
Expand Down Expand Up @@ -108,6 +137,7 @@ class MyDocument(BaseDoc):
can be either `cpu` or a `cuda` device.
:param descending: sort the results in descending order.
Per default, this is chosen based on the `metric` argument.
:param cache: Precomputed data storing the valid index data per search field together with the valid indexes to account for deleted entries.
:return: A named tuple of the form (DocList, AnyTensor),
where the first element contains the closest matches for the query,
and the second element contains the corresponding scores.
Expand Down Expand Up @@ -199,29 +229,26 @@ class MyDocument(BaseDoc):
can be either `cpu` or a `cuda` device.
:param descending: sort the results in descending order.
Per default, this is chosen based on the `metric` argument.
:param cache: Precomputed data storing the valid index data per search field together with the valid indexes to account for deleted entries.
:return: A named tuple of the form (DocList, AnyTensor),
where the first element contains the closest matches for each query,
and the second element contains the corresponding scores.
"""
if descending is None:
descending = metric.endswith('_sim') # similarity metrics are descending

embedding_type = _da_attr_type(index, search_field)
comp_backend = embedding_type.get_comp_backend()

# extract embeddings from query and index
if cache is not None and search_field in cache:
index_embeddings, valid_idx = cache[search_field]
else:
index_embeddings, valid_idx = _extract_embeddings(
index, search_field, embedding_type
)
index_embeddings, valid_idx = _extract_embeddings(index, search_field)
if cache is not None:
cache[search_field] = (
index_embeddings,
valid_idx,
) # cache embedding for next query
query_embeddings, _ = _extract_embeddings(query, search_field, embedding_type)
query_embeddings, _ = _extract_embeddings(query, search_field)
_, comp_backend = _get_tensor_type_and_comp_backend_from_tensor(index_embeddings)

# compute distances and return top results
metric_fn = getattr(comp_backend.Metrics, metric)
Expand Down Expand Up @@ -267,60 +294,64 @@ def _extract_embedding_single(
return emb


def _get_tensor_type_and_comp_backend_from_tensor(
    tensor,
) -> Tuple[Type['AbstractTensor'], 'AbstractComputationalBackend']:
    """Pick the docarray tensor type and computational backend for a tensor.

    The decision is made with `isinstance` checks against the torch and
    tensorflow tensor classes, each only when the respective framework is
    available. Anything else falls back to the NumPy pair.

    :param tensor: the tensor whose framework should be detected
    :return: a tuple of the tensor type and the computational backend
    """
    if torch_available and isinstance(tensor, (TorchTensor, torch.Tensor)):
        return TorchTensor, TorchCompBackend()
    if tf_available and isinstance(tensor, (TensorFlowTensor, tf.Tensor)):
        return TensorFlowTensor, TensorFlowCompBackend()
    # Default for numpy arrays and any type not recognised above.
    return NdArray, NumpyCompBackend()


def _extract_embeddings(
    data: Union[AnyDocArray, BaseDoc, AnyTensor],
    search_field: str,
    embedding_type: Optional[Type] = None,
) -> Tuple[AnyTensor, Optional[List[int]]]:
    """Extract the embeddings from the data.

    :param data: the data
    :param search_field: the embedding field
    :param embedding_type: accepted only so that three-argument callers keep
        working; it is not used, the tensor type is inferred from the data
    :return: a tuple of the embeddings and optionally the indices of the
        non-null embeddings (only set when `data` is a DocList)
    :raises Exception: if `data` is a DocList without any non-null embedding
    """
    emb: AnyTensor
    valid_idx = None
    comp_backend = None
    if isinstance(data, DocList):
        emb_valid = [
            (candidate, i)
            for i, candidate in enumerate(AnyDocArray._traverse(data, search_field))
            if candidate is not None
        ]
        # Check before unpacking: `zip(*[])` yields nothing, so unpacking it
        # into two names would raise an unrelated ValueError first.
        if not emb_valid:
            raise Exception(f'No embedding could be extracted from data {data}')

        emb_list, valid_idx = zip(*emb_valid)
        (
            da_tensor_type,
            comp_backend,
        ) = _get_tensor_type_and_comp_backend_from_tensor(emb_list[0])
        emb = da_tensor_type._docarray_stack(emb_list)
    elif isinstance(data, (DocVec, BaseDoc)):
        emb = next(AnyDocArray._traverse(data, search_field))
    else:  # treat data as tensor
        emb = cast(AnyTensor, data)

    if comp_backend is None:
        _, comp_backend = _get_tensor_type_and_comp_backend_from_tensor(emb)

    # A single 1-D embedding is promoted to a batch of one row.
    if len(emb.shape) == 1:
        emb = comp_backend.reshape(tensor=emb, shape=(1, -1))
    return emb, valid_idx


def _da_attr_type(docs: AnyDocArray, access_path: str) -> Type[AnyTensor]:
    """Get the type of the attribute according to the Document type
    (schema) of the DocList.

    :param docs: the DocList
    :param access_path: the "__"-separated access path
    :return: the type of the attribute
    :raises ValueError: if the access path cannot be resolved, or if the
        resolved type is not a subclass of AbstractTensor
    """
    field_type: Optional[Type] = _get_field_type_by_access_path(
        docs.doc_type, access_path
    )
    if field_type is None:
        raise ValueError(f"Access path is not valid: {access_path}")

    if is_union_type(field_type):
        # determine type based on the first element
        field_type = type(next(AnyDocArray._traverse(docs[0], access_path)))

    if not safe_issubclass(field_type, AbstractTensor):
        # `field_type` is itself a class, so report its own name; going
        # through `__class__` would print the metaclass (usually "type").
        type_name = getattr(field_type, '__name__', repr(field_type))
        raise ValueError(
            f'attribute {access_path} is not a tensor-like type, '
            f'but {type_name}'
        )

    return cast(Type[AnyTensor], field_type)
53 changes: 53 additions & 0 deletions tests/index/in_memory/test_in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,17 @@
from torch import rand

from docarray import BaseDoc, DocList
from docarray.documents import TextDoc
from docarray.index.backends.in_memory import InMemoryExactNNIndex
from docarray.typing import NdArray, TorchTensor

from docarray.utils._internal.misc import is_tf_available

tf_available = is_tf_available()
if tf_available:
import tensorflow as tf
from docarray.typing import TensorFlowTensor


class SchemaDoc(BaseDoc):
text: str
Expand Down Expand Up @@ -112,6 +120,51 @@ class MyDoc(BaseDoc):
assert len(scores) == 0


def test_with_text_doc_ndarray():
    """Batched search over TextDocs that carry numpy embeddings.

    Indexes 200 documents, queries with the first 10 of them and expects
    5 matches for every query.
    """
    index = InMemoryExactNNIndex[TextDoc]()

    embedded_docs = [
        TextDoc(text='hey', embedding=np.random.rand(128)) for _ in range(200)
    ]
    docs = DocList[TextDoc](embedded_docs)
    index.index(docs)

    queries = docs[0:10]
    res = index.find_batched(queries, search_field='embedding', limit=5)

    assert len(res.documents) == 10
    for matches in res.documents:
        assert len(matches) == 5


@pytest.mark.tensorflow
def test_with_text_doc_tensorflow():
    """Batched search over TextDocs that carry tensorflow embeddings.

    Indexes 200 documents, queries with the first 10 of them and expects
    5 matches for every query.
    """
    index = InMemoryExactNNIndex[TextDoc]()

    embedded_docs = [
        TextDoc(text='hey', embedding=tf.random.uniform(shape=[128]))
        for _ in range(200)
    ]
    docs = DocList[TextDoc](embedded_docs)
    index.index(docs)

    queries = docs[0:10]
    res = index.find_batched(queries, search_field='embedding', limit=5)

    assert len(res.documents) == 10
    for matches in res.documents:
        assert len(matches) == 5


def test_with_text_doc_torch():
    """Batched search over TextDocs that carry torch embeddings.

    Indexes 200 documents, queries with the first 10 of them and expects
    5 matches for every query.
    """
    import torch

    index = InMemoryExactNNIndex[TextDoc]()

    embedded_docs = [
        TextDoc(text='hey', embedding=torch.rand(128)) for _ in range(200)
    ]
    docs = DocList[TextDoc](embedded_docs)
    index.index(docs)

    queries = docs[0:10]
    res = index.find_batched(queries, search_field='embedding', limit=5)

    assert len(res.documents) == 10
    for matches in res.documents:
        assert len(matches) == 5


def test_concatenated_queries(doc_index):
query = SchemaDoc(text='query', price=0, tensor=np.ones(10))

Expand Down