Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions docarray/index/backends/hnswlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __init__(self, db_config=None, **kwargs):
self._column_names: List[str] = []
self._create_docs_table()
self._sqlite_conn.commit()
self._num_docs = self._get_num_docs_sqlite()
self._num_docs = 0 # recompute again when needed
self._logger.info(f'{self.__class__.__name__} has been initialized')

@property
Expand Down Expand Up @@ -281,7 +281,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):

self._send_docs_to_sqlite(docs_validated)
self._sqlite_conn.commit()
self._num_docs = self._get_num_docs_sqlite()
self._num_docs = 0 # recompute again when needed

def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any:
"""
Expand Down Expand Up @@ -318,9 +318,6 @@ def _find_batched(
def _find(
self, query: np.ndarray, limit: int, search_field: str = ''
) -> _FindResult:
if self.num_docs() == 0:
return _FindResult(documents=[], scores=[]) # type: ignore

query_batched = np.expand_dims(query, axis=0)
docs, scores = self._find_batched(
queries=query_batched, limit=limit, search_field=search_field
Expand Down Expand Up @@ -385,7 +382,7 @@ def _del_items(self, doc_ids: Sequence[str]):

self._delete_docs_from_sqlite(doc_ids)
self._sqlite_conn.commit()
self._num_docs = self._get_num_docs_sqlite()
self._num_docs = 0 # recompute again when needed

def _get_items(self, doc_ids: Sequence[str], out: bool = True) -> Sequence[TSchema]:
"""Get Documents from the hnswlib index, by `id`.
Expand All @@ -410,6 +407,8 @@ def num_docs(self) -> int:
"""
Get the number of documents.
"""
if self._num_docs == 0:
self._num_docs = self._get_num_docs_sqlite()
return self._num_docs

###############################################
Expand Down Expand Up @@ -605,7 +604,7 @@ def _search_and_filter(
documents and their corresponding scores.
"""
# If there are no documents or hashed_ids is an empty set, return an empty _FindResultBatched
if self.num_docs() == 0 or (hashed_ids is not None and len(hashed_ids) == 0):
if hashed_ids is not None and len(hashed_ids) == 0:
return _FindResultBatched(documents=[], scores=[]) # type: ignore

# Set limit as the minimum of the provided limit and the total number of documents
Expand All @@ -628,8 +627,11 @@ def accept_hashed_ids(id):

# If hashed_ids is provided, k is the minimum of limit and the length of hashed_ids; else it is limit
k = min(limit, len(hashed_ids)) if hashed_ids else limit

labels, distances = index.knn_query(queries, k=k, **extra_kwargs)
try:
labels, distances = index.knn_query(queries, k=k, **extra_kwargs)
except RuntimeError:
k = min(k, self.num_docs())
labels, distances = index.knn_query(queries, k=k, **extra_kwargs)

result_das = [
self._get_docs_sqlite_hashed_id(
Expand Down