Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: allow find in in memory index with AnyEmbedding
Signed-off-by: Joan Fontanals Martinez <[email protected]>
  • Loading branch information
Joan Fontanals Martinez committed Jul 10, 2023
commit bc0e40dcc875c405ee5f59c60cda31a7fff78e91
10 changes: 5 additions & 5 deletions docarray/utils/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from docarray.helper import _get_field_type_by_access_path
from docarray.typing import AnyTensor
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.embedding import AnyEmbedding
from docarray.typing.tensor.embedding.ndarray import NdArrayEmbedding
from docarray.utils._internal._typing import safe_issubclass


Expand Down Expand Up @@ -207,12 +209,10 @@ class MyDocument(BaseDoc):
descending = metric.endswith('_sim') # similarity metrics are descending

embedding_type = _da_attr_type(index, search_field)
try:
comp_backend = embedding_type.get_comp_backend()
except:
from docarray.typing.tensor.embedding.ndarray import NdArrayEmbedding
if embedding_type == AnyEmbedding:
embedding_type = NdArrayEmbedding
comp_backend = embedding_type.get_comp_backend()

comp_backend = embedding_type.get_comp_backend()

# extract embeddings from query and index
if cache is not None and search_field in cache:
Expand Down
14 changes: 14 additions & 0 deletions tests/index/in_memory/test_in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch import rand

from docarray import BaseDoc, DocList
from docarray.documents import TextDoc
from docarray.index.backends.in_memory import InMemoryExactNNIndex
from docarray.typing import NdArray, TorchTensor

Expand Down Expand Up @@ -112,6 +113,19 @@ class MyDoc(BaseDoc):
assert len(scores) == 0


def test_with_text_doc():
index = InMemoryExactNNIndex[TextDoc]()

docs = DocList[TextDoc](
[TextDoc(text='hey', embedding=np.random.rand(128)) for i in range(200)]
)
index.index(docs)
res = index.find_batched(docs[0:10], search_field='embedding')
assert len(res.documents) == 10
for r in res.documents:
assert len(r) == 5


def test_concatenated_queries(doc_index):
query = SchemaDoc(text='query', price=0, tensor=np.ones(10))

Expand Down