@@ -138,6 +138,7 @@ def __init__(self, db_config=None, **kwargs):
138138 self ._column_names : List [str ] = []
139139 self ._create_docs_table ()
140140 self ._sqlite_conn .commit ()
141+ self ._num_docs = 0 # recompute again when needed
141142 self ._logger .info (f'{ self .__class__ .__name__ } has been initialized' )
142143
143144 @property
@@ -279,6 +280,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
279280 self ._index (data_by_columns , docs_validated , ** kwargs )
280281 self ._send_docs_to_sqlite (docs_validated )
281282 self ._sqlite_conn .commit ()
283+ self ._num_docs = 0 # recompute again when needed
282284
283285 def execute_query (self , query : List [Tuple [str , Dict ]], * args , ** kwargs ) -> Any :
284286 """
@@ -329,7 +331,19 @@ def _filter(
329331 limit : int ,
330332 ) -> DocList :
331333 rows = self ._execute_filter (filter_query = filter_query , limit = limit )
332- return DocList [self .out_schema ](self ._doc_from_bytes (blob ) for _ , blob in rows ) # type: ignore[name-defined]
334+ hashed_ids = [doc_id for doc_id , _ in rows ]
335+ embeddings : OrderedDict [str , list ] = OrderedDict ()
336+ for col_name , index in self ._hnsw_indices .items ():
337+ embeddings [col_name ] = index .get_items (hashed_ids )
338+
339+ docs = DocList .__class_getitem__ (cast (Type [BaseDoc ], self .out_schema ))()
340+ for i , row in enumerate (rows ):
341+ reconstruct_embeddings = {}
342+ for col_name in embeddings .keys ():
343+ reconstruct_embeddings [col_name ] = embeddings [col_name ][i ]
344+ docs .append (self ._doc_from_bytes (row [1 ], reconstruct_embeddings ))
345+
346+ return docs
333347
334348 def _filter_batched (
335349 self ,
@@ -379,6 +393,7 @@ def _del_items(self, doc_ids: Sequence[str]):
379393
380394 self ._delete_docs_from_sqlite (doc_ids )
381395 self ._sqlite_conn .commit ()
396+ self ._num_docs = 0 # recompute again when needed
382397
383398 def _get_items (self , doc_ids : Sequence [str ], out : bool = True ) -> Sequence [TSchema ]:
384399 """Get Documents from the hnswlib index, by `id`.
@@ -403,7 +418,9 @@ def num_docs(self) -> int:
403418 """
404419 Get the number of documents.
405420 """
406- return self ._get_num_docs_sqlite ()
421+ if self ._num_docs == 0 :
422+ self ._num_docs = self ._get_num_docs_sqlite ()
423+ return self ._num_docs
407424
408425 ###############################################
409426 # Helpers #
@@ -641,15 +658,18 @@ def accept_hashed_ids(id):
641658 """Accepts IDs that are in hashed_ids."""
642659 return id in hashed_ids # type: ignore[operator]
643660
644- # Choose the appropriate filter function based on whether hashed_ids was provided
645- extra_kwargs = {}
646- if hashed_ids :
647- extra_kwargs ['filter' ] = accept_hashed_ids
661+ extra_kwargs = {'filter' : accept_hashed_ids } if hashed_ids else {}
648662
649663 # If hashed_ids is provided, k is the minimum of limit and the length of hashed_ids; else it is limit
650664 k = min (limit , len (hashed_ids )) if hashed_ids else limit
651665 index = self ._hnsw_indices [search_field ]
652- labels , distances = index .knn_query (queries , k = int (limit ), ** extra_kwargs )
666+
667+ try :
668+ labels , distances = index .knn_query (queries , k = k , ** extra_kwargs )
669+ except RuntimeError :
670+ k = min (k , self .num_docs ())
671+ labels , distances = index .knn_query (queries , k = k , ** extra_kwargs )
672+
653673 result_das = [
654674 self ._get_docs_sqlite_hashed_id (
655675 ids_per_query .tolist (),
0 commit comments