|
10 | 10 | from docarray.index.backends.in_memory import InMemoryExactNNIndex |
11 | 11 | from docarray.typing import NdArray, TorchTensor |
12 | 12 |
|
| 13 | +from docarray.utils._internal.misc import is_tf_available |
| 14 | + |
| 15 | +tf_available = is_tf_available() |
| 16 | +if tf_available: |
| 17 | + import tensorflow as tf |
| 18 | + from docarray.typing import TensorFlowTensor |
| 19 | + |
13 | 20 |
|
14 | 21 | class SchemaDoc(BaseDoc): |
15 | 22 | text: str |
@@ -113,11 +120,43 @@ class MyDoc(BaseDoc): |
113 | 120 | assert len(scores) == 0 |
114 | 121 |
|
115 | 122 |
|
116 | | -def test_with_text_doc(): |
| 123 | +def test_with_text_doc_ndarray(): |
| 124 | + index = InMemoryExactNNIndex[TextDoc]() |
| 125 | + |
| 126 | + docs = DocList[TextDoc]( |
| 127 | + [TextDoc(text='hey', embedding=np.random.rand(128)) for _ in range(200)] |
| 128 | + ) |
| 129 | + index.index(docs) |
| 130 | + res = index.find_batched(docs[0:10], search_field='embedding') |
| 131 | + assert len(res.documents) == 10 |
| 132 | + for r in res.documents: |
| 133 | + assert len(r) == 5 |
| 134 | + |
| 135 | + |
| 136 | +@pytest.mark.tensorflow |
| 137 | +def test_with_text_doc_tensorflow(): |
| 138 | + index = InMemoryExactNNIndex[TextDoc]() |
| 139 | + |
| 140 | + docs = DocList[TextDoc]( |
| 141 | + [ |
| 142 | + TextDoc(text='hey', embedding=tf.random.uniform(shape=[128])) |
| 143 | + for _ in range(200) |
| 144 | + ] |
| 145 | + ) |
| 146 | + index.index(docs) |
| 147 | + res = index.find_batched(docs[0:10], search_field='embedding') |
| 148 | + assert len(res.documents) == 10 |
| 149 | + for r in res.documents: |
| 150 | + assert len(r) == 5 |
| 151 | + |
| 152 | + |
| 153 | +def test_with_text_doc_torch(): |
| 154 | + import torch |
| 155 | + |
117 | 156 | index = InMemoryExactNNIndex[TextDoc]() |
118 | 157 |
|
119 | 158 | docs = DocList[TextDoc]( |
120 | | - [TextDoc(text='hey', embedding=np.random.rand(128)) for i in range(200)] |
| 159 | + [TextDoc(text='hey', embedding=torch.rand(128)) for _ in range(200)] |
121 | 160 | ) |
122 | 161 | index.index(docs) |
123 | 162 | res = index.find_batched(docs[0:10], search_field='embedding') |
|
0 commit comments