 import hashlib
 import os
 import sqlite3
-from collections import defaultdict
+from collections import OrderedDict, defaultdict
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import (
     _raise_not_composable,
     _raise_not_supported,
 )
-from docarray.index.backends.helper import _collect_query_args
+from docarray.index.backends.helper import (
+    _collect_query_args,
+)
 from docarray.proto import DocProto
 from docarray.typing.tensor.abstract_tensor import AbstractTensor
 from docarray.typing.tensor.ndarray import NdArray
     HNSWLIB_PY_VEC_TYPES.append(tf.Tensor)
     HNSWLIB_PY_VEC_TYPES.append(TensorFlowTensor)
 
-
 TSchema = TypeVar('TSchema', bound=BaseDoc)
 T = TypeVar('T', bound='HnswDocumentIndex')
 
@@ -107,7 +108,11 @@ def __init__(self, db_config=None, **kwargs):
             if col.config
         }
         self._hnsw_indices = {}
+        sub_docs_exist = False
+        cosine_metric_index_exist = False
         for col_name, col in self._column_infos.items():
+            if '__' in col_name:
+                sub_docs_exist = True
             if safe_issubclass(col.docarray_type, AnyDocArray):
                 continue
             if not col.config:
@@ -127,7 +132,12 @@ def __init__(self, db_config=None, **kwargs):
             else:
                 self._hnsw_indices[col_name] = self._create_index(col_name, col)
                 self._logger.info(f'Created a new index for column `{col_name}`')
+                if self._hnsw_indices[col_name].space == 'cosine':
+                    cosine_metric_index_exist = True
 
+        self._apply_optim_no_embedding_in_sqlite = (
+            not sub_docs_exist and not cosine_metric_index_exist
+        )  # optimization: don't serialize embeddings to SQLite; they are expensive to store and can be reconstructed from the HNSW index itself
         # SQLite setup
         self._sqlite_db_path = os.path.join(self._work_dir, 'docs_sqlite.db')
         self._logger.debug(f'DB path set to {self._sqlite_db_path}')
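Context for the two guard flags above: hnswlib stores vectors verbatim for the 'l2' and 'ip' spaces but normalizes them on insertion when the space is 'cosine', so the original embedding could not be rebuilt from the index; subindex ('__') columns are likewise excluded, presumably because they do not map one-to-one onto the root document's protobuf fields. A minimal sketch of the lossless round-trip the optimization relies on (hnswlib required; toy ids and dimensions):

```python
import hnswlib
import numpy as np

# 'l2' keeps the stored vector bit-identical, so get_items can stand in
# for the embedding copy that is no longer serialized to SQLite.
index = hnswlib.Index(space='l2', dim=8)
index.init_index(max_elements=10)
vec = np.random.rand(1, 8).astype(np.float32)
index.add_items(vec, ids=[7])
restored = np.asarray(index.get_items([7]), dtype=np.float32)
# lossless under 'l2'; under 'cosine' the stored vector is normalized,
# so this assertion would typically fail
assert np.allclose(vec, restored)
```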
@@ -276,9 +286,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
         docs_validated = self._validate_docs(docs)
         self._update_subindex_data(docs_validated)
         data_by_columns = self._get_col_value_dict(docs_validated)
-
         self._index(data_by_columns, docs_validated, **kwargs)
-
         self._send_docs_to_sqlite(docs_validated)
         self._sqlite_conn.commit()
         self._num_docs = 0  # recompute again when needed
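The `self._num_docs = 0` line is the invalidation half of a cached count: writes mark the counter stale, and the count is recomputed from SQLite only when next requested. A toy sketch of that pattern (the class and `_count_from_sqlite` are illustrative names, not the module's API):

```python
class _LazyDocCount:
    """Invalidate on write, recount only on read."""

    def __init__(self) -> None:
        self._num_docs = 0  # 0 means stale; recompute on next access

    def index(self, docs: list) -> None:
        # ... write docs to storage ...
        self._num_docs = 0  # recompute again when needed

    def num_docs(self) -> int:
        if self._num_docs == 0:
            self._num_docs = self._count_from_sqlite()
        return self._num_docs

    def _count_from_sqlite(self) -> int:
        return 42  # stand-in for SELECT COUNT(*) FROM docs
```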
@@ -332,7 +340,19 @@ def _filter(
         limit: int,
     ) -> DocList:
         rows = self._execute_filter(filter_query=filter_query, limit=limit)
-        return DocList[self.out_schema](self._doc_from_bytes(blob) for _, blob in rows)  # type: ignore[name-defined]
+        hashed_ids = [doc_id for doc_id, _ in rows]
+        embeddings: OrderedDict[str, list] = OrderedDict()
+        for col_name, index in self._hnsw_indices.items():
+            embeddings[col_name] = index.get_items(hashed_ids)
+
+        docs = DocList.__class_getitem__(cast(Type[BaseDoc], self.out_schema))()
+        for i, row in enumerate(rows):
+            reconstruct_embeddings = {}
+            for col_name in embeddings.keys():
+                reconstruct_embeddings[col_name] = embeddings[col_name][i]
+            docs.append(self._doc_from_bytes(row[1], reconstruct_embeddings))
+
+        return docs
 
     def _filter_batched(
         self,
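The rebuilt `_filter` above leans on an ordering guarantee: `hashed_ids` is derived from `rows`, and hnswlib's `get_items` returns vectors in the order of the ids passed in, so `embeddings[col_name][i]` belongs to `rows[i]`. A toy illustration of the zipping step (all values made up):

```python
# Stand-ins for rows fetched from SQLite and per-column vectors fetched
# from the HNSW indices; get_items preserves the order of its id argument.
rows = [(101, b'doc-a-bytes'), (102, b'doc-b-bytes')]
embeddings = {'embedding': [[0.1, 0.2], [0.3, 0.4]]}

for i, (doc_id, blob) in enumerate(rows):
    # this dict is what _doc_from_bytes receives as reconstruct_embeddings
    per_doc = {col: vecs[i] for col, vecs in embeddings.items()}
```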
@@ -501,12 +521,24 @@ def _get_docs_sqlite_unsorted(self, univ_ids: Sequence[int], out: bool = True):
             assert isinstance(id_, int) or is_np_int(id_)
         sql_id_list = '(' + ', '.join(str(id_) for id_ in univ_ids) + ')'
         self._sqlite_cursor.execute(
-            'SELECT data FROM docs WHERE doc_id IN %s' % sql_id_list,
+            'SELECT doc_id, data FROM docs WHERE doc_id IN %s' % sql_id_list,
         )
-        rows = self._sqlite_cursor.fetchall()
+        rows = (
+            self._sqlite_cursor.fetchall()
+        )  # rows are not guaranteed to come back in the order of the requested doc_ids
+        embeddings: OrderedDict[str, list] = OrderedDict()
+        for col_name, index in self._hnsw_indices.items():
+            embeddings[col_name] = index.get_items([row[0] for row in rows])
+
         schema = self.out_schema if out else self._schema
-        docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], schema))
-        return docs_cls([self._doc_from_bytes(row[0], out) for row in rows])
+        docs = DocList.__class_getitem__(cast(Type[BaseDoc], schema))()
+        for i, (_, data_bytes) in enumerate(rows):
+            reconstruct_embeddings = {}
+            for col_name in embeddings.keys():
+                reconstruct_embeddings[col_name] = embeddings[col_name][i]
+            docs.append(self._doc_from_bytes(data_bytes, reconstruct_embeddings, out))
+
+        return docs
 
     def _get_docs_sqlite_doc_id(
         self, doc_ids: Sequence[str], out: bool = True
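The ordering comment is the crux of this hunk: SQLite does not guarantee that `WHERE doc_id IN (...)` returns rows in the order of the id list, so the query now also selects `doc_id` and the embeddings are fetched per returned row rather than per requested id. A self-contained demonstration of the pitfall:

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE docs (doc_id INTEGER PRIMARY KEY, data BLOB)')
conn.executemany('INSERT INTO docs VALUES (?, ?)', [(2, b'b'), (1, b'a')])
rows = conn.execute(
    'SELECT doc_id, data FROM docs WHERE doc_id IN (2, 1)'
).fetchall()
# rows typically come back in index order, e.g. [(1, b'a'), (2, b'b')],
# not in the order of the IN clause, so alignment must key off the doc_id
```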
@@ -541,12 +573,31 @@ def _get_num_docs_sqlite(self) -> int:
 
     # serialization helpers
     def _doc_to_bytes(self, doc: BaseDoc) -> bytes:
-        return doc.to_protobuf().SerializeToString()
-
-    def _doc_from_bytes(self, data: bytes, out: bool = True) -> BaseDoc:
+        pb = doc.to_protobuf()
+        if self._apply_optim_no_embedding_in_sqlite:
+            for col_name in self._hnsw_indices.keys():
+                pb.data[col_name].Clear()
+        return pb.SerializeToString()
+
+    def _doc_from_bytes(
+        self, data: bytes, reconstruct_embeddings: Dict, out: bool = True
+    ) -> BaseDoc:
         schema = self.out_schema if out else self._schema
         schema_cls = cast(Type[BaseDoc], schema)
-        return schema_cls.from_protobuf(DocProto.FromString(data))
+        pb = DocProto.FromString(
+            data
+        )  # cannot build the doc object directly: validation may fail because the embedding field may not be Optional
+        if self._apply_optim_no_embedding_in_sqlite:
+            for k, v in reconstruct_embeddings.items():
+                node_proto = (
+                    schema_cls._get_field_type(k)
+                    ._docarray_from_ndarray(np.array(v))
+                    ._to_node_protobuf()
+                )
+                pb.data[k].MergeFrom(node_proto)
+
+        doc = schema_cls.from_protobuf(pb)
+        return doc
 
     def _get_root_doc_id(self, id: str, root: str, sub: str) -> str:
         """Get the root_id given the id of a subindex Document and the root and subindex name for hnswlib.
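The reason for patching at the protobuf level instead of building the document and assigning the embedding afterwards: with a non-Optional embedding field, `from_protobuf` on the stripped bytes would fail validation before there is any object to patch. A hedged sketch with a toy schema (not the module's code):

```python
import numpy as np
from docarray import BaseDoc
from docarray.typing import NdArray

class MyDoc(BaseDoc):
    embedding: NdArray  # required, not Optional

doc = MyDoc(embedding=np.ones(4))
pb = doc.to_protobuf()
pb.data['embedding'].Clear()  # what _doc_to_bytes does on write
# MyDoc.from_protobuf(pb) at this point would fail validation
pb.data['embedding'].MergeFrom(doc.to_protobuf().data['embedding'])
restored = MyDoc.from_protobuf(pb)  # what _doc_from_bytes does on read
```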
@@ -608,25 +659,23 @@ def _search_and_filter(
             return _FindResultBatched(documents=[], scores=[])  # type: ignore
 
-        # Set limit as the minimum of the provided limit and the total number of documents
-        limit = min(limit, self.num_docs())
+        limit = limit
 
         # Ensure the search field is in the HNSW indices
         if search_field not in self._hnsw_indices:
             raise ValueError(
                 f'Search field {search_field} is not present in the HNSW indices'
             )
 
-        index = self._hnsw_indices[search_field]
-
         def accept_hashed_ids(id):
             """Accepts IDs that are in hashed_ids."""
             return id in hashed_ids  # type: ignore[operator]
 
-        # Choose the appropriate filter function based on whether hashed_ids was provided
         extra_kwargs = {'filter': accept_hashed_ids} if hashed_ids else {}
 
         # If hashed_ids is provided, k is the minimum of limit and the length of hashed_ids; else it is limit
         k = min(limit, len(hashed_ids)) if hashed_ids else limit
+        index = self._hnsw_indices[search_field]
+
         try:
             labels, distances = index.knn_query(queries, k=k, **extra_kwargs)
         except RuntimeError:
@@ -639,7 +688,6 @@ def accept_hashed_ids(id):
             )
             for ids_per_query in labels
         ]
-
         return _FindResultBatched(documents=result_das, scores=distances)
 
     @classmethod
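For reviewers unfamiliar with the `extra_kwargs` path: hnswlib's `knn_query` accepts a `filter` callable (hnswlib >= 0.7.0) invoked per candidate label, and it raises `RuntimeError` when it cannot return `k` results, which is why `k` is clamped to `len(hashed_ids)` and the call sits in a try/except. A minimal sketch with toy data:

```python
import hnswlib
import numpy as np

index = hnswlib.Index(space='l2', dim=4)
index.init_index(max_elements=100)
index.add_items(np.random.rand(100, 4).astype(np.float32), ids=np.arange(100))
index.set_ef(50)  # widen the search so the few admissible labels are reachable

allowed = {1, 5, 7}
labels, distances = index.knn_query(
    np.random.rand(2, 4).astype(np.float32),
    k=min(10, len(allowed)),  # k above the admissible count can raise RuntimeError
    filter=lambda label: label in allowed,
)
```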