Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: try to pass all mypy checks
Signed-off-by: Joan Fontanals Martinez <[email protected]>
  • Loading branch information
Joan Fontanals Martinez committed Jul 11, 2023
commit 2babe51848d75261aa2c84369c836e7e6016275e
4 changes: 2 additions & 2 deletions docarray/computation/abstract_comp_backend.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import typing
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypeVar, Union
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypeVar, Union, Iterable

if TYPE_CHECKING:
import numpy as np

# In practice all of the below will be the same type
TTensor = TypeVar('TTensor')
TTensorRetrieval = TypeVar('TTensorRetrieval')
TTensorRetrieval = TypeVar('TTensorRetrieval', bound=Iterable)
TTensorMetrics = TypeVar('TTensorMetrics')


Expand Down
6 changes: 1 addition & 5 deletions docarray/index/backends/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from docarray.utils.find import (
FindResult,
FindResultBatched,
_da_attr_type,
_extract_embeddings,
_FindResult,
_FindResultBatched,
Expand Down Expand Up @@ -196,10 +195,7 @@ def _rebuild_embedding(self):
self._embedding_map = dict()
else:
for field_, embedding in self._embedding_map.items():
embedding_type = _da_attr_type(self._docs, field_)
self._embedding_map[field_] = _extract_embeddings(
self._docs, field_, embedding_type
)
self._embedding_map[field_] = _extract_embeddings(self._docs, field_)

def _del_items(self, doc_ids: Sequence[str]):
"""Delete Documents from the index.
Expand Down
4 changes: 4 additions & 0 deletions docarray/typing/tensor/tensorflow_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,3 +341,7 @@ def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T:
def _docarray_to_ndarray(self) -> np.ndarray:
"""cast itself to a numpy array"""
return self.tensor.numpy()

@property
def shape(self):
return tf.shape(self.tensor)
124 changes: 75 additions & 49 deletions docarray/utils/find.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,46 @@
__all__ = ['find', 'find_batched']

from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, Union, cast

from typing_inspect import is_union_type
from typing import (
Any,
Dict,
List,
NamedTuple,
Optional,
Tuple,
Union,
cast,
Type,
TYPE_CHECKING,
)

from docarray.array.any_array import AnyDocArray
from docarray.array.doc_list.doc_list import DocList
from docarray.array.doc_vec.doc_vec import DocVec
from docarray.base_doc import BaseDoc
from docarray.helper import _get_field_type_by_access_path
from docarray.typing import AnyTensor
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.embedding import AnyEmbedding
from docarray.typing.tensor.embedding.ndarray import NdArrayEmbedding
from docarray.utils._internal._typing import safe_issubclass
from docarray.computation.numpy_backend import NumpyCompBackend
from docarray.typing.tensor import NdArray
from docarray.utils._internal.misc import is_tf_available, is_torch_available # noqa

torch_available = is_torch_available()
if torch_available:
import torch

from docarray.typing.tensor.torch_tensor import TorchTensor # noqa: F401
from docarray.computation.torch_backend import TorchCompBackend

tf_available = is_tf_available()
if tf_available:
import tensorflow as tf # type: ignore

from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor # noqa: F401
from docarray.computation.tensorflow_backend import TensorFlowCompBackend

if TYPE_CHECKING:
from docarray.computation.abstract_numpy_based_backend import (
AbstractComputationalBackend,
)
from docarray.typing.tensor.abstract_tensor import AbstractTensor


class FindResult(NamedTuple):
Expand Down Expand Up @@ -110,6 +137,7 @@ class MyDocument(BaseDoc):
can be either `cpu` or a `cuda` device.
:param descending: sort the results in descending order.
Per default, this is chosen based on the `metric` argument.
:param cache: Precomputed data storing the valid index data per search field together with the valid indexes to account for deleted entries.
:return: A named tuple of the form (DocList, AnyTensor),
where the first element contains the closes matches for the query,
and the second element contains the corresponding scores.
Expand Down Expand Up @@ -201,32 +229,26 @@ class MyDocument(BaseDoc):
can be either `cpu` or a `cuda` device.
:param descending: sort the results in descending order.
Per default, this is chosen based on the `metric` argument.
:param cache: Precomputed data storing the valid index data per search field together with the valid indexes to account for deleted entries.
:return: A named tuple of the form (DocList, AnyTensor),
where the first element contains the closest matches for each query,
and the second element contains the corresponding scores.
"""
if descending is None:
descending = metric.endswith('_sim') # similarity metrics are descending

embedding_type = _da_attr_type(index, search_field)
if embedding_type == AnyEmbedding:
embedding_type = NdArrayEmbedding

comp_backend = embedding_type.get_comp_backend()

# extract embeddings from query and index
if cache is not None and search_field in cache:
index_embeddings, valid_idx = cache[search_field]
else:
index_embeddings, valid_idx = _extract_embeddings(
index, search_field, embedding_type
)
index_embeddings, valid_idx = _extract_embeddings(index, search_field)
if cache is not None:
cache[search_field] = (
index_embeddings,
valid_idx,
) # cache embedding for next query
query_embeddings, _ = _extract_embeddings(query, search_field, embedding_type)
query_embeddings, _ = _extract_embeddings(query, search_field)
_, comp_backend = _get_tensor_type_and_comp_backend_from_tensor(index_embeddings)

# compute distances and return top results
metric_fn = getattr(comp_backend.Metrics, metric)
Expand Down Expand Up @@ -272,60 +294,64 @@ def _extract_embedding_single(
return emb


def _get_tensor_type_and_comp_backend_from_tensor(
tensor,
) -> Tuple[Type['AbstractTensor'], 'AbstractComputationalBackend']:
"""Extract the embeddings from the data.

:param tensor: the tensor for which to extract
:return: a tuple of the tensor type and the computational backend
"""
da_tensor_type: Type['AbstractTensor'] = NdArray
comp_backend: 'AbstractComputationalBackend' = NumpyCompBackend()
if torch_available and isinstance(tensor, (TorchTensor, torch.Tensor)):
comp_backend = TorchCompBackend()
da_tensor_type = TorchTensor
elif tf_available and isinstance(tensor, (TensorFlowTensor, tf.Tensor)):
comp_backend = TensorFlowCompBackend()
da_tensor_type = TensorFlowTensor

return da_tensor_type, comp_backend


def _extract_embeddings(
data: Union[AnyDocArray, BaseDoc, AnyTensor],
search_field: str,
embedding_type: Type,
) -> Tuple[AnyTensor, Optional[List[int]]]:
"""Extract the embeddings from the data.

:param data: the data
:param search_field: the embedding field
:param embedding_type: type of the embedding: torch.Tensor, numpy.ndarray etc.
:return: a tuple of the embeddings and optionally a list of the non-null indices
"""
emb: AnyTensor
valid_idx = None
comp_backend = None
da_tensor_type = None
if isinstance(data, DocList):
emb_valid = [
(emb, i)
for i, emb in enumerate(AnyDocArray._traverse(data, search_field))
if emb is not None
]
emb_list, valid_idx = zip(*emb_valid)
emb = embedding_type._docarray_stack(emb_list)
if len(emb_list) > 0:
(
da_tensor_type,
comp_backend,
) = _get_tensor_type_and_comp_backend_from_tensor(emb_list[0])
else:
raise Exception(f'No embedding could be extracted from data {data}')

emb = da_tensor_type._docarray_stack(emb_list)
elif isinstance(data, (DocVec, BaseDoc)):
emb = next(AnyDocArray._traverse(data, search_field))
else: # treat data as tensor
emb = cast(AnyTensor, data)

if comp_backend is None:
_, comp_backend = _get_tensor_type_and_comp_backend_from_tensor(emb)

if len(emb.shape) == 1:
emb = emb.get_comp_backend().reshape(array=emb, shape=(1, -1))
emb = comp_backend.reshape(tensor=emb, shape=(1, -1))
return emb, valid_idx


def _da_attr_type(docs: AnyDocArray, access_path: str) -> Type[AnyTensor]:
"""Get the type of the attribute according to the Document type
(schema) of the DocList.

:param docs: the DocList
:param access_path: the "__"-separated access path
:return: the type of the attribute
"""
field_type: Optional[Type] = _get_field_type_by_access_path(
docs.doc_type, access_path
)
if field_type is None:
raise ValueError(f"Access path is not valid: {access_path}")

if is_union_type(field_type):
# determine type based on the fist element
field_type = type(next(AnyDocArray._traverse(docs[0], access_path)))

if not safe_issubclass(field_type, AbstractTensor):
raise ValueError(
f'attribute {access_path} is not a tensor-like type, '
f'but {field_type.__class__.__name__}'
)

return cast(Type[AnyTensor], field_type)
4 changes: 2 additions & 2 deletions tests/index/weaviate/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version: '3.8'
version: '3.3'

services:

Expand All @@ -24,4 +24,4 @@ services:
LOG_LEVEL: debug # verbose
LOG_FORMAT: text
# LOG_LEVEL: trace # very verbose
GODEBUG: gctrace=1 # make go garbage collector verbose
GODEBUG: gctrace=1 # make go garbage collector verbose