File tree Expand file tree Collapse file tree 3 files changed +31
-18
lines changed
Expand file tree Collapse file tree 3 files changed +31
-18
lines changed Original file line number Diff line number Diff line change 11from typing import (
2- TYPE_CHECKING ,
32 Any ,
43 Iterable ,
54 List ,
1413import numpy as np
1514from typing_extensions import SupportsIndex
1615
17- from docarray .utils ._internal .misc import import_library
16+ from docarray .utils ._internal .misc import (
17+ is_torch_available ,
18+ is_tf_available ,
19+ )
20+
21+ torch_available = is_torch_available ()
22+ if torch_available :
23+ import torch
24+ tf_available = is_tf_available ()
25+ if tf_available :
26+ import tensorflow as tf # type: ignore
27+ from docarray .typing .tensor .tensorflow_tensor import TensorFlowTensor
1828
1929T_item = TypeVar ('T_item' )
2030T = TypeVar ('T' , bound = 'ListAdvancedIndexing' )
@@ -75,17 +85,20 @@ def _normalize_index_item(
7585 return item .tolist ()
7686
7787 # torch index types
78- if TYPE_CHECKING :
79- import torch
80- else :
81- torch = import_library ('torch' , raise_error = True )
82-
83- allowed_torch_dtypes = [
84- torch .bool ,
85- torch .int64 ,
86- ]
87- if isinstance (item , torch .Tensor ) and (item .dtype in allowed_torch_dtypes ):
88- return item .tolist ()
88+ if torch_available :
89+
90+ allowed_torch_dtypes = [
91+ torch .bool ,
92+ torch .int64 ,
93+ ]
94+ if isinstance (item , torch .Tensor ) and (item .dtype in allowed_torch_dtypes ):
95+ return item .tolist ()
96+
97+ if tf_available :
98+ if isinstance (item , tf .Tensor ):
99+ return item .numpy ().tolist ()
100+ if isinstance (item , TensorFlowTensor ):
101+ return item .tensor .numpy ().tolist ()
89102
90103 return item
91104
Original file line number Diff line number Diff line change @@ -127,7 +127,7 @@ def test_with_text_doc_ndarray():
127127 [TextDoc (text = 'hey' , embedding = np .random .rand (128 )) for _ in range (200 )]
128128 )
129129 index .index (docs )
130- res = index .find_batched (docs [0 :10 ], search_field = 'embedding' )
130+ res = index .find_batched (docs [0 :10 ], search_field = 'embedding' , limit = 5 )
131131 assert len (res .documents ) == 10
132132 for r in res .documents :
133133 assert len (r ) == 5
@@ -144,7 +144,7 @@ def test_with_text_doc_tensorflow():
144144 ]
145145 )
146146 index .index (docs )
147- res = index .find_batched (docs [0 :10 ], search_field = 'embedding' )
147+ res = index .find_batched (docs [0 :10 ], search_field = 'embedding' , limit = 5 )
148148 assert len (res .documents ) == 10
149149 for r in res .documents :
150150 assert len (r ) == 5
@@ -159,7 +159,7 @@ def test_with_text_doc_torch():
159159 [TextDoc (text = 'hey' , embedding = torch .rand (128 )) for _ in range (200 )]
160160 )
161161 index .index (docs )
162- res = index .find_batched (docs [0 :10 ], search_field = 'embedding' )
162+ res = index .find_batched (docs [0 :10 ], search_field = 'embedding' , limit = 5 )
163163 assert len (res .documents ) == 10
164164 for r in res .documents :
165165 assert len (r ) == 5
Original file line number Diff line number Diff line change 1- version : ' 3.3 '
1+ version : ' 3.8 '
22
33services :
44
@@ -24,4 +24,4 @@ services:
2424 LOG_LEVEL : debug # verbose
2525 LOG_FORMAT : text
2626 # LOG_LEVEL: trace # very verbose
27- GODEBUG : gctrace=1 # make go garbage collector verbose
27+ GODEBUG : gctrace=1 # make go garbage collector verbose
You can’t perform that action at this time.
0 commit comments