Skip to content

Commit f95b623

Browse files
author
Joan Fontanals Martinez
committed
refactor: do not recompute every time num_docs
Signed-off-by: Joan Fontanals Martinez <[email protected]>
1 parent efeab90 commit f95b623

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

docarray/index/backends/hnswlib.py

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -137,7 +137,7 @@ def __init__(self, db_config=None, **kwargs):
137137
self._column_names: List[str] = []
138138
self._create_docs_table()
139139
self._sqlite_conn.commit()
140-
self._num_docs = self._get_num_docs_sqlite()
140+
self._num_docs = 0 # recompute again when needed
141141
self._logger.info(f'{self.__class__.__name__} has been initialized')
142142

143143
@property
@@ -281,7 +281,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
281281

282282
self._send_docs_to_sqlite(docs_validated)
283283
self._sqlite_conn.commit()
284-
self._num_docs = self._get_num_docs_sqlite()
284+
self._num_docs = 0 # recompute again when needed
285285

286286
def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any:
287287
"""
@@ -318,9 +318,6 @@ def _find_batched(
318318
def _find(
319319
self, query: np.ndarray, limit: int, search_field: str = ''
320320
) -> _FindResult:
321-
if self.num_docs() == 0:
322-
return _FindResult(documents=[], scores=[]) # type: ignore
323-
324321
query_batched = np.expand_dims(query, axis=0)
325322
docs, scores = self._find_batched(
326323
queries=query_batched, limit=limit, search_field=search_field
@@ -385,7 +382,7 @@ def _del_items(self, doc_ids: Sequence[str]):
385382

386383
self._delete_docs_from_sqlite(doc_ids)
387384
self._sqlite_conn.commit()
388-
self._num_docs = self._get_num_docs_sqlite()
385+
self._num_docs = 0 # recompute again when needed
389386

390387
def _get_items(self, doc_ids: Sequence[str], out: bool = True) -> Sequence[TSchema]:
391388
"""Get Documents from the hnswlib index, by `id`.
@@ -410,6 +407,8 @@ def num_docs(self) -> int:
410407
"""
411408
Get the number of documents.
412409
"""
410+
if self._num_docs == 0:
411+
self._num_docs = self._get_num_docs_sqlite()
413412
return self._num_docs
414413

415414
###############################################
@@ -605,11 +604,11 @@ def _search_and_filter(
605604
documents and their corresponding scores.
606605
"""
607606
# If there are no documents or hashed_ids is an empty set, return an empty _FindResultBatched
608-
if self.num_docs() == 0 or (hashed_ids is not None and len(hashed_ids) == 0):
607+
if hashed_ids is not None and len(hashed_ids) == 0:
609608
return _FindResultBatched(documents=[], scores=[]) # type: ignore
610609

611610
# Set limit as the minimum of the provided limit and the total number of documents
612-
limit = min(limit, self.num_docs())
611+
limit = limit
613612

614613
# Ensure the search field is in the HNSW indices
615614
if search_field not in self._hnsw_indices:

0 commit comments

Comments
 (0)