Skip to content

Commit 19aec21

Browse files
Joan Fontanalsjupyterjazz
andauthored
refactor: do not recompute every time num_docs (#1729)
Signed-off-by: Joan Fontanals Martinez <[email protected]> Signed-off-by: jupyterjazz <[email protected]> Co-authored-by: jupyterjazz <[email protected]>
1 parent 24143a1 commit 19aec21

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

docarray/index/backends/hnswlib.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def __init__(self, db_config=None, **kwargs):
137137
self._column_names: List[str] = []
138138
self._create_docs_table()
139139
self._sqlite_conn.commit()
140-
self._num_docs = self._get_num_docs_sqlite()
140+
self._num_docs = 0 # recompute again when needed
141141
self._logger.info(f'{self.__class__.__name__} has been initialized')
142142

143143
@property
@@ -281,7 +281,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
281281

282282
self._send_docs_to_sqlite(docs_validated)
283283
self._sqlite_conn.commit()
284-
self._num_docs = self._get_num_docs_sqlite()
284+
self._num_docs = 0 # recompute again when needed
285285

286286
def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any:
287287
"""
@@ -318,9 +318,6 @@ def _find_batched(
318318
def _find(
319319
self, query: np.ndarray, limit: int, search_field: str = ''
320320
) -> _FindResult:
321-
if self.num_docs() == 0:
322-
return _FindResult(documents=[], scores=[]) # type: ignore
323-
324321
query_batched = np.expand_dims(query, axis=0)
325322
docs, scores = self._find_batched(
326323
queries=query_batched, limit=limit, search_field=search_field
@@ -385,7 +382,7 @@ def _del_items(self, doc_ids: Sequence[str]):
385382

386383
self._delete_docs_from_sqlite(doc_ids)
387384
self._sqlite_conn.commit()
388-
self._num_docs = self._get_num_docs_sqlite()
385+
self._num_docs = 0 # recompute again when needed
389386

390387
def _get_items(self, doc_ids: Sequence[str], out: bool = True) -> Sequence[TSchema]:
391388
"""Get Documents from the hnswlib index, by `id`.
@@ -410,6 +407,8 @@ def num_docs(self) -> int:
410407
"""
411408
Get the number of documents.
412409
"""
410+
if self._num_docs == 0:
411+
self._num_docs = self._get_num_docs_sqlite()
413412
return self._num_docs
414413

415414
###############################################
@@ -605,7 +604,7 @@ def _search_and_filter(
605604
documents and their corresponding scores.
606605
"""
607606
# If there are no documents or hashed_ids is an empty set, return an empty _FindResultBatched
608-
if self.num_docs() == 0 or (hashed_ids is not None and len(hashed_ids) == 0):
607+
if hashed_ids is not None and len(hashed_ids) == 0:
609608
return _FindResultBatched(documents=[], scores=[]) # type: ignore
610609

611610
# Set limit as the minimum of the provided limit and the total number of documents
@@ -628,8 +627,11 @@ def accept_hashed_ids(id):
628627

629628
# If hashed_ids is provided, k is the minimum of limit and the length of hashed_ids; else it is limit
630629
k = min(limit, len(hashed_ids)) if hashed_ids else limit
631-
632-
labels, distances = index.knn_query(queries, k=k, **extra_kwargs)
630+
try:
631+
labels, distances = index.knn_query(queries, k=k, **extra_kwargs)
632+
except RuntimeError:
633+
k = min(k, self.num_docs())
634+
labels, distances = index.knn_query(queries, k=k, **extra_kwargs)
633635

634636
result_das = [
635637
self._get_docs_sqlite_hashed_id(

0 commit comments

Comments
 (0)