Skip to content

Commit e7315e8

Browse files
author
Joan Fontanals Martinez
committed
test: add tests for torch and tf
Signed-off-by: Joan Fontanals Martinez <[email protected]>
1 parent bc0e40d commit e7315e8

File tree

1 file changed

+41
-2
lines changed

1 file changed

+41
-2
lines changed

tests/index/in_memory/test_in_memory.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@
1010
from docarray.index.backends.in_memory import InMemoryExactNNIndex
1111
from docarray.typing import NdArray, TorchTensor
1212

13+
from docarray.utils._internal.misc import is_tf_available
14+
15+
tf_available = is_tf_available()
16+
if tf_available:
17+
import tensorflow as tf
18+
from docarray.typing import TensorFlowTensor
19+
1320

1421
class SchemaDoc(BaseDoc):
1522
text: str
@@ -113,11 +120,43 @@ class MyDoc(BaseDoc):
113120
assert len(scores) == 0
114121

115122

116-
def test_with_text_doc():
123+
def test_with_text_doc_ndarray():
124+
index = InMemoryExactNNIndex[TextDoc]()
125+
126+
docs = DocList[TextDoc](
127+
[TextDoc(text='hey', embedding=np.random.rand(128)) for _ in range(200)]
128+
)
129+
index.index(docs)
130+
res = index.find_batched(docs[0:10], search_field='embedding')
131+
assert len(res.documents) == 10
132+
for r in res.documents:
133+
assert len(r) == 5
134+
135+
136+
@pytest.mark.tensorflow
137+
def test_with_text_doc_tensorflow():
138+
index = InMemoryExactNNIndex[TextDoc]()
139+
140+
docs = DocList[TextDoc](
141+
[
142+
TextDoc(text='hey', embedding=tf.random.uniform(shape=[128]))
143+
for _ in range(200)
144+
]
145+
)
146+
index.index(docs)
147+
res = index.find_batched(docs[0:10], search_field='embedding')
148+
assert len(res.documents) == 10
149+
for r in res.documents:
150+
assert len(r) == 5
151+
152+
153+
def test_with_text_doc_torch():
154+
import torch
155+
117156
index = InMemoryExactNNIndex[TextDoc]()
118157

119158
docs = DocList[TextDoc](
120-
[TextDoc(text='hey', embedding=np.random.rand(128)) for i in range(200)]
159+
[TextDoc(text='hey', embedding=torch.rand(128)) for _ in range(200)]
121160
)
122161
index.index(docs)
123162
res = index.find_batched(docs[0:10], search_field='embedding')

0 commit comments

Comments
 (0)