Skip to content

Commit bab8099

Browse files
author
Joan Fontanals Martinez
committed
test: undo changes in docker compose
Signed-off-by: Joan Fontanals Martinez <[email protected]>
1 parent 2babe51 commit bab8099

File tree

3 files changed

+31
-18
lines changed

3 files changed

+31
-18
lines changed

docarray/array/list_advance_indexing.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from typing import (
2-
TYPE_CHECKING,
32
Any,
43
Iterable,
54
List,
@@ -14,7 +13,18 @@
1413
import numpy as np
1514
from typing_extensions import SupportsIndex
1615

17-
from docarray.utils._internal.misc import import_library
16+
from docarray.utils._internal.misc import (
17+
is_torch_available,
18+
is_tf_available,
19+
)
20+
21+
torch_available = is_torch_available()
22+
if torch_available:
23+
import torch
24+
tf_available = is_tf_available()
25+
if tf_available:
26+
import tensorflow as tf # type: ignore
27+
from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor
1828

1929
T_item = TypeVar('T_item')
2030
T = TypeVar('T', bound='ListAdvancedIndexing')
@@ -75,17 +85,20 @@ def _normalize_index_item(
7585
return item.tolist()
7686

7787
# torch index types
78-
if TYPE_CHECKING:
79-
import torch
80-
else:
81-
torch = import_library('torch', raise_error=True)
82-
83-
allowed_torch_dtypes = [
84-
torch.bool,
85-
torch.int64,
86-
]
87-
if isinstance(item, torch.Tensor) and (item.dtype in allowed_torch_dtypes):
88-
return item.tolist()
88+
if torch_available:
89+
90+
allowed_torch_dtypes = [
91+
torch.bool,
92+
torch.int64,
93+
]
94+
if isinstance(item, torch.Tensor) and (item.dtype in allowed_torch_dtypes):
95+
return item.tolist()
96+
97+
if tf_available:
98+
if isinstance(item, tf.Tensor):
99+
return item.numpy().tolist()
100+
if isinstance(item, TensorFlowTensor):
101+
return item.tensor.numpy().tolist()
89102

90103
return item
91104

tests/index/in_memory/test_in_memory.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def test_with_text_doc_ndarray():
127127
[TextDoc(text='hey', embedding=np.random.rand(128)) for _ in range(200)]
128128
)
129129
index.index(docs)
130-
res = index.find_batched(docs[0:10], search_field='embedding')
130+
res = index.find_batched(docs[0:10], search_field='embedding', limit=5)
131131
assert len(res.documents) == 10
132132
for r in res.documents:
133133
assert len(r) == 5
@@ -144,7 +144,7 @@ def test_with_text_doc_tensorflow():
144144
]
145145
)
146146
index.index(docs)
147-
res = index.find_batched(docs[0:10], search_field='embedding')
147+
res = index.find_batched(docs[0:10], search_field='embedding', limit=5)
148148
assert len(res.documents) == 10
149149
for r in res.documents:
150150
assert len(r) == 5
@@ -159,7 +159,7 @@ def test_with_text_doc_torch():
159159
[TextDoc(text='hey', embedding=torch.rand(128)) for _ in range(200)]
160160
)
161161
index.index(docs)
162-
res = index.find_batched(docs[0:10], search_field='embedding')
162+
res = index.find_batched(docs[0:10], search_field='embedding', limit=5)
163163
assert len(res.documents) == 10
164164
for r in res.documents:
165165
assert len(r) == 5

tests/index/weaviate/docker-compose.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
version: '3.3'
1+
version: '3.8'
22

33
services:
44

@@ -24,4 +24,4 @@ services:
2424
LOG_LEVEL: debug # verbose
2525
LOG_FORMAT: text
2626
# LOG_LEVEL: trace # very verbose
27-
GODEBUG: gctrace=1 # make go garbage collector verbose
27+
GODEBUG: gctrace=1 # make go garbage collector verbose

0 commit comments

Comments
 (0)