File tree Expand file tree Collapse file tree 2 files changed +19
-5
lines changed
Expand file tree Collapse file tree 2 files changed +19
-5
lines changed Original file line number Diff line number Diff line change 1111from docarray .helper import _get_field_type_by_access_path
1212from docarray .typing import AnyTensor
1313from docarray .typing .tensor .abstract_tensor import AbstractTensor
14+ from docarray .typing .tensor .embedding import AnyEmbedding
15+ from docarray .typing .tensor .embedding .ndarray import NdArrayEmbedding
1416from docarray .utils ._internal ._typing import safe_issubclass
1517
1618
@@ -207,12 +209,10 @@ class MyDocument(BaseDoc):
207209 descending = metric .endswith ('_sim' ) # similarity metrics are descending
208210
209211 embedding_type = _da_attr_type (index , search_field )
210- try :
211- comp_backend = embedding_type .get_comp_backend ()
212- except :
213- from docarray .typing .tensor .embedding .ndarray import NdArrayEmbedding
212+ if embedding_type == AnyEmbedding :
214213 embedding_type = NdArrayEmbedding
215- comp_backend = embedding_type .get_comp_backend ()
214+
215+ comp_backend = embedding_type .get_comp_backend ()
216216
217217 # extract embeddings from query and index
218218 if cache is not None and search_field in cache :
Original file line number Diff line number Diff line change 66from torch import rand
77
88from docarray import BaseDoc , DocList
9+ from docarray .documents import TextDoc
910from docarray .index .backends .in_memory import InMemoryExactNNIndex
1011from docarray .typing import NdArray , TorchTensor
1112
@@ -112,6 +113,19 @@ class MyDoc(BaseDoc):
112113 assert len (scores ) == 0
113114
114115
116+ def test_with_text_doc ():
117+ index = InMemoryExactNNIndex [TextDoc ]()
118+
119+ docs = DocList [TextDoc ](
120+ [TextDoc (text = 'hey' , embedding = np .random .rand (128 )) for i in range (200 )]
121+ )
122+ index .index (docs )
123+ res = index .find_batched (docs [0 :10 ], search_field = 'embedding' )
124+ assert len (res .documents ) == 10
125+ for r in res .documents :
126+ assert len (r ) == 5
127+
128+
115129def test_concatenated_queries (doc_index ):
116130 query = SchemaDoc (text = 'query' , price = 0 , tensor = np .ones (10 ))
117131
You can’t perform that action at this time.
0 commit comments