Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docarray/utils/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from docarray.helper import _get_field_type_by_access_path
from docarray.typing import AnyTensor
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.embedding import AnyEmbedding
from docarray.typing.tensor.embedding.ndarray import NdArrayEmbedding
from docarray.utils._internal._typing import safe_issubclass


Expand Down Expand Up @@ -207,6 +209,9 @@ class MyDocument(BaseDoc):
descending = metric.endswith('_sim') # similarity metrics are descending

embedding_type = _da_attr_type(index, search_field)
if embedding_type == AnyEmbedding:
embedding_type = NdArrayEmbedding

comp_backend = embedding_type.get_comp_backend()

# extract embeddings from query and index
Expand Down
14 changes: 14 additions & 0 deletions tests/index/in_memory/test_in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch import rand

from docarray import BaseDoc, DocList
from docarray.documents import TextDoc
from docarray.index.backends.in_memory import InMemoryExactNNIndex
from docarray.typing import NdArray, TorchTensor

Expand Down Expand Up @@ -112,6 +113,19 @@ class MyDoc(BaseDoc):
assert len(scores) == 0


def test_with_text_doc():
index = InMemoryExactNNIndex[TextDoc]()

docs = DocList[TextDoc](
[TextDoc(text='hey', embedding=np.random.rand(128)) for i in range(200)]
)
index.index(docs)
res = index.find_batched(docs[0:10], search_field='embedding')
assert len(res.documents) == 10
for r in res.documents:
assert len(r) == 5


def test_concatenated_queries(doc_index):
query = SchemaDoc(text='query', price=0, tensor=np.ones(10))

Expand Down