|
1 | 1 | __all__ = ['find', 'find_batched'] |
2 | 2 |
|
3 | | -from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, Union, cast |
4 | | - |
5 | | -from typing_inspect import is_union_type |
| 3 | +from typing import ( |
| 4 | + Any, |
| 5 | + Dict, |
| 6 | + List, |
| 7 | + NamedTuple, |
| 8 | + Optional, |
| 9 | + Tuple, |
| 10 | + Union, |
| 11 | + cast, |
| 12 | + Type, |
| 13 | + TYPE_CHECKING, |
| 14 | +) |
6 | 15 |
|
7 | 16 | from docarray.array.any_array import AnyDocArray |
8 | 17 | from docarray.array.doc_list.doc_list import DocList |
9 | 18 | from docarray.array.doc_vec.doc_vec import DocVec |
10 | 19 | from docarray.base_doc import BaseDoc |
11 | | -from docarray.helper import _get_field_type_by_access_path |
12 | 20 | from docarray.typing import AnyTensor |
13 | | -from docarray.typing.tensor.abstract_tensor import AbstractTensor |
14 | | -from docarray.typing.tensor.embedding import AnyEmbedding |
15 | | -from docarray.typing.tensor.embedding.ndarray import NdArrayEmbedding |
16 | | -from docarray.utils._internal._typing import safe_issubclass |
| 21 | +from docarray.computation.numpy_backend import NumpyCompBackend |
| 22 | +from docarray.typing.tensor import NdArray |
| 23 | +from docarray.utils._internal.misc import is_tf_available, is_torch_available # noqa |
| 24 | + |
| 25 | +torch_available = is_torch_available() |
| 26 | +if torch_available: |
| 27 | + import torch |
| 28 | + |
| 29 | + from docarray.typing.tensor.torch_tensor import TorchTensor # noqa: F401 |
| 30 | + from docarray.computation.torch_backend import TorchCompBackend |
| 31 | + |
| 32 | +tf_available = is_tf_available() |
| 33 | +if tf_available: |
| 34 | + import tensorflow as tf # type: ignore |
| 35 | + |
| 36 | + from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor # noqa: F401 |
| 37 | + from docarray.computation.tensorflow_backend import TensorFlowCompBackend |
| 38 | + |
| 39 | +if TYPE_CHECKING: |
| 40 | +    from docarray.computation.abstract_comp_backend import ( |
| 41 | +        AbstractComputationalBackend, |
| 42 | +    ) |
| 43 | + from docarray.typing.tensor.abstract_tensor import AbstractTensor |
17 | 44 |
|
18 | 45 |
|
19 | 46 | class FindResult(NamedTuple): |
@@ -110,6 +137,7 @@ class MyDocument(BaseDoc): |
110 | 137 | can be either `cpu` or a `cuda` device. |
111 | 138 | :param descending: sort the results in descending order. |
112 | 139 | Per default, this is chosen based on the `metric` argument. |
| 140 | +    :param cache: precomputed embeddings per search field, together with the indices of the valid (non-deleted) entries; pass the same dict across calls to avoid re-extracting the index embeddings. |
113 | 141 | :return: A named tuple of the form (DocList, AnyTensor), |
114 | 142 |         where the first element contains the closest matches for the query, |
115 | 143 | and the second element contains the corresponding scores. |
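
For context, a minimal usage sketch of the new `cache` argument with `find` (not part of the commit; `MyDoc` and the random data are illustrative, and it assumes `cache` accepts an ordinary dict keyed by search field, as the docstring above suggests):

import numpy as np

from docarray import BaseDoc, DocList
from docarray.typing import NdArray
from docarray.utils.find import find


class MyDoc(BaseDoc):
    embedding: NdArray[128]


index = DocList[MyDoc](MyDoc(embedding=np.random.rand(128)) for _ in range(100))
cache = {}  # populated with (index_embeddings, valid_idx) on the first call

for _ in range(3):
    query = MyDoc(embedding=np.random.rand(128))
    docs, scores = find(index, query, search_field='embedding', limit=5, cache=cache)
    # after the first iteration the index embeddings are read from `cache`
    # instead of being re-extracted from `index`
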
@@ -201,32 +229,26 @@ class MyDocument(BaseDoc): |
201 | 229 | can be either `cpu` or a `cuda` device. |
202 | 230 | :param descending: sort the results in descending order. |
203 | 231 | Per default, this is chosen based on the `metric` argument. |
| 232 | +    :param cache: precomputed embeddings per search field, together with the indices of the valid (non-deleted) entries; pass the same dict across calls to avoid re-extracting the index embeddings. |
204 | 233 | :return: A named tuple of the form (DocList, AnyTensor), |
205 | 234 | where the first element contains the closest matches for each query, |
206 | 235 | and the second element contains the corresponding scores. |
207 | 236 | """ |
208 | 237 | if descending is None: |
209 | 238 | descending = metric.endswith('_sim') # similarity metrics are descending |
210 | 239 |
|
211 | | - embedding_type = _da_attr_type(index, search_field) |
212 | | - if embedding_type == AnyEmbedding: |
213 | | - embedding_type = NdArrayEmbedding |
214 | | - |
215 | | - comp_backend = embedding_type.get_comp_backend() |
216 | | - |
217 | 240 | # extract embeddings from query and index |
218 | 241 | if cache is not None and search_field in cache: |
219 | 242 | index_embeddings, valid_idx = cache[search_field] |
220 | 243 | else: |
221 | | - index_embeddings, valid_idx = _extract_embeddings( |
222 | | - index, search_field, embedding_type |
223 | | - ) |
| 244 | + index_embeddings, valid_idx = _extract_embeddings(index, search_field) |
224 | 245 | if cache is not None: |
225 | 246 | cache[search_field] = ( |
226 | 247 | index_embeddings, |
227 | 248 | valid_idx, |
228 | 249 | ) # cache embedding for next query |
229 | | - query_embeddings, _ = _extract_embeddings(query, search_field, embedding_type) |
| 250 | + query_embeddings, _ = _extract_embeddings(query, search_field) |
| 251 | + _, comp_backend = _get_tensor_type_and_comp_backend_from_tensor(index_embeddings) |
230 | 252 |
|
231 | 253 | # compute distances and return top results |
232 | 254 | metric_fn = getattr(comp_backend.Metrics, metric) |
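
A hedged sketch of the same mechanism for `find_batched` (illustrative names, numpy-only setup, not part of the commit): the computational backend is now inferred from the extracted index embeddings rather than from the document schema, and a shared cache dict lets the second batch of queries reuse the embeddings extracted for the first.

import numpy as np

from docarray import BaseDoc, DocList
from docarray.typing import NdArray
from docarray.utils.find import find_batched


class MyDoc(BaseDoc):
    emb: NdArray[16]


index = DocList[MyDoc](MyDoc(emb=np.random.rand(16)) for _ in range(50))
cache = {}

for _ in range(2):
    queries = DocList[MyDoc](MyDoc(emb=np.random.rand(16)) for _ in range(4))
    docs_per_query, scores = find_batched(
        index, queries, search_field='emb', metric='euclidean_dist', limit=5, cache=cache
    )
    # 'euclidean_dist' does not end in '_sim', so results are sorted ascending by default
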
@@ -272,60 +294,64 @@ def _extract_embedding_single( |
272 | 294 | return emb |
273 | 295 |
|
274 | 296 |
|
| 297 | +def _get_tensor_type_and_comp_backend_from_tensor( |
| 298 | + tensor, |
| 299 | +) -> Tuple[Type['AbstractTensor'], 'AbstractComputationalBackend']: |
| 300 | +    """Determine the tensor type and the computational backend that match the given tensor. |
| 301 | +
|
| 302 | +    :param tensor: the tensor to inspect |
| 303 | + :return: a tuple of the tensor type and the computational backend |
| 304 | + """ |
| 305 | + da_tensor_type: Type['AbstractTensor'] = NdArray |
| 306 | + comp_backend: 'AbstractComputationalBackend' = NumpyCompBackend() |
| 307 | + if torch_available and isinstance(tensor, (TorchTensor, torch.Tensor)): |
| 308 | + comp_backend = TorchCompBackend() |
| 309 | + da_tensor_type = TorchTensor |
| 310 | + elif tf_available and isinstance(tensor, (TensorFlowTensor, tf.Tensor)): |
| 311 | + comp_backend = TensorFlowCompBackend() |
| 312 | + da_tensor_type = TensorFlowTensor |
| 313 | + |
| 314 | + return da_tensor_type, comp_backend |
| 315 | + |
| 316 | + |
275 | 317 | def _extract_embeddings( |
276 | 318 | data: Union[AnyDocArray, BaseDoc, AnyTensor], |
277 | 319 | search_field: str, |
278 | | - embedding_type: Type, |
279 | 320 | ) -> Tuple[AnyTensor, Optional[List[int]]]: |
280 | 321 | """Extract the embeddings from the data. |
281 | 322 |
|
282 | 323 | :param data: the data |
283 | 324 | :param search_field: the embedding field |
284 | | - :param embedding_type: type of the embedding: torch.Tensor, numpy.ndarray etc. |
285 | 325 | :return: a tuple of the embeddings and optionally a list of the non-null indices |
286 | 326 | """ |
287 | 327 | emb: AnyTensor |
288 | 328 | valid_idx = None |
| 329 | + comp_backend = None |
| 330 | + da_tensor_type = None |
289 | 331 | if isinstance(data, DocList): |
290 | 332 | emb_valid = [ |
291 | 333 | (emb, i) |
292 | 334 | for i, emb in enumerate(AnyDocArray._traverse(data, search_field)) |
293 | 335 | if emb is not None |
294 | 336 | ] |
295 | 337 | emb_list, valid_idx = zip(*emb_valid) |
296 | | - emb = embedding_type._docarray_stack(emb_list) |
| 338 | +        if len(emb_list) > 0: |
| 339 | + ( |
| 340 | + da_tensor_type, |
| 341 | + comp_backend, |
| 342 | + ) = _get_tensor_type_and_comp_backend_from_tensor(emb_list[0]) |
| 343 | + else: |
| 344 | +            raise ValueError(f'No embedding could be extracted from data {data}') |
| 345 | + |
| 346 | + emb = da_tensor_type._docarray_stack(emb_list) |
297 | 347 | elif isinstance(data, (DocVec, BaseDoc)): |
298 | 348 | emb = next(AnyDocArray._traverse(data, search_field)) |
299 | 349 | else: # treat data as tensor |
300 | 350 | emb = cast(AnyTensor, data) |
301 | 351 |
|
| 352 | + if comp_backend is None: |
| 353 | + _, comp_backend = _get_tensor_type_and_comp_backend_from_tensor(emb) |
| 354 | + |
302 | 355 | if len(emb.shape) == 1: |
303 | | - emb = emb.get_comp_backend().reshape(array=emb, shape=(1, -1)) |
| 356 | + emb = comp_backend.reshape(tensor=emb, shape=(1, -1)) |
304 | 357 | return emb, valid_idx |
305 | | - |
306 | | - |
307 | | -def _da_attr_type(docs: AnyDocArray, access_path: str) -> Type[AnyTensor]: |
308 | | - """Get the type of the attribute according to the Document type |
309 | | - (schema) of the DocList. |
310 | | -
|
311 | | - :param docs: the DocList |
312 | | - :param access_path: the "__"-separated access path |
313 | | - :return: the type of the attribute |
314 | | - """ |
315 | | - field_type: Optional[Type] = _get_field_type_by_access_path( |
316 | | - docs.doc_type, access_path |
317 | | - ) |
318 | | - if field_type is None: |
319 | | - raise ValueError(f"Access path is not valid: {access_path}") |
320 | | - |
321 | | - if is_union_type(field_type): |
322 | | - # determine type based on the fist element |
323 | | - field_type = type(next(AnyDocArray._traverse(docs[0], access_path))) |
324 | | - |
325 | | - if not safe_issubclass(field_type, AbstractTensor): |
326 | | - raise ValueError( |
327 | | - f'attribute {access_path} is not a tensor-like type, ' |
328 | | - f'but {field_type.__class__.__name__}' |
329 | | - ) |
330 | | - |
331 | | - return cast(Type[AnyTensor], field_type) |
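
To illustrate the net effect of dropping `_da_attr_type`: the tensor type and computational backend are now detected from the data actually stored in the documents, so a Union- or Optional-typed embedding field works without inspecting the schema, and documents whose embedding is None are skipped via the valid-index bookkeeping. A hedged, numpy-only sketch with illustrative names (not part of the commit):

from typing import Optional

import numpy as np

from docarray import BaseDoc, DocList
from docarray.typing import AnyTensor
from docarray.utils.find import find


class MyDoc(BaseDoc):
    emb: Optional[AnyTensor] = None  # resolves to a concrete framework tensor type at runtime


docs = DocList[MyDoc](
    [MyDoc(emb=np.random.rand(8)) for _ in range(10)] + [MyDoc()]  # last doc has no embedding
)
hits, scores = find(docs, np.random.rand(8), search_field='emb', limit=3)
# the numpy backend is picked from the stored embeddings at runtime,
# and the doc with emb=None never appears in `hits`
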
|