Skip to content

Commit f96171f

Browse files
author
Joan Fontanals Martinez
committed
fix: allow find in in memory index with AnyEmbedding
Signed-off-by: Joan Fontanals Martinez <[email protected]>
1 parent a2ccb08 commit f96171f

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

docarray/utils/find.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from docarray.helper import _get_field_type_by_access_path
1212
from docarray.typing import AnyTensor
1313
from docarray.typing.tensor.abstract_tensor import AbstractTensor
14+
from docarray.typing.tensor.embedding import AnyEmbedding
15+
from docarray.typing.tensor.embedding.ndarray import NdArrayEmbedding
1416
from docarray.utils._internal._typing import safe_issubclass
1517

1618

@@ -207,12 +209,10 @@ class MyDocument(BaseDoc):
207209
descending = metric.endswith('_sim') # similarity metrics are descending
208210

209211
embedding_type = _da_attr_type(index, search_field)
210-
try:
211-
comp_backend = embedding_type.get_comp_backend()
212-
except:
213-
from docarray.typing.tensor.embedding.ndarray import NdArrayEmbedding
212+
if embedding_type == AnyEmbedding:
214213
embedding_type = NdArrayEmbedding
215-
comp_backend = embedding_type.get_comp_backend()
214+
215+
comp_backend = embedding_type.get_comp_backend()
216216

217217
# extract embeddings from query and index
218218
if cache is not None and search_field in cache:

tests/index/in_memory/test_in_memory.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from torch import rand
77

88
from docarray import BaseDoc, DocList
9+
from docarray.documents import TextDoc
910
from docarray.index.backends.in_memory import InMemoryExactNNIndex
1011
from docarray.typing import NdArray, TorchTensor
1112

@@ -112,6 +113,19 @@ class MyDoc(BaseDoc):
112113
assert len(scores) == 0
113114

114115

116+
def test_with_text_doc():
117+
index = InMemoryExactNNIndex[TextDoc]()
118+
119+
docs = DocList[TextDoc](
120+
[TextDoc(text='hey', embedding=np.random.rand(128)) for i in range(200)]
121+
)
122+
index.index(docs)
123+
res = index.find_batched(docs[0:10], search_field='embedding')
124+
assert len(res.documents) == 10
125+
for r in res.documents:
126+
assert len(r) == 5
127+
128+
115129
def test_concatenated_queries(doc_index):
116130
query = SchemaDoc(text='query', price=0, tensor=np.ones(10))
117131

0 commit comments

Comments
 (0)