fix: reconstruct on filter
Signed-off-by: Joan Fontanals Martinez <[email protected]>
Joan Fontanals Martinez committed Jul 27, 2023
commit 3f56d4407308171b729cc9a30ed690718058e871
34 changes: 27 additions & 7 deletions docarray/index/backends/hnswlib.py
@@ -138,6 +138,7 @@ def __init__(self, db_config=None, **kwargs):
self._column_names: List[str] = []
self._create_docs_table()
self._sqlite_conn.commit()
self._num_docs = 0 # recompute again when needed
self._logger.info(f'{self.__class__.__name__} has been initialized')

@property
@@ -279,6 +280,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
self._index(data_by_columns, docs_validated, **kwargs)
self._send_docs_to_sqlite(docs_validated)
self._sqlite_conn.commit()
self._num_docs = 0 # recompute again when needed

def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any:
"""
@@ -329,7 +331,19 @@ def _filter(
limit: int,
) -> DocList:
rows = self._execute_filter(filter_query=filter_query, limit=limit)
return DocList[self.out_schema](self._doc_from_bytes(blob) for _, blob in rows) # type: ignore[name-defined]
hashed_ids = [doc_id for doc_id, _ in rows]
embeddings: OrderedDict[str, list] = OrderedDict()
Member commented:
Why specifically OrderedDict? I think a normal dict in Python is already ordered by default.

Member (Author) replied:
I would still like to be explicit about the ordering requirement.

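For reference on the exchange above: since Python 3.7, plain dicts are guaranteed to preserve insertion order, which is the reviewer's point. A minimal standalone check (the column names below are made up, not DocArray's):

```python
from collections import OrderedDict

# Since Python 3.7, insertion order of plain dicts is part of the language spec,
# so both containers iterate their keys in the same order here.
plain = {}
ordered = OrderedDict()
for key in ('vec', 'text_embedding', 'image_embedding'):  # hypothetical column names
    plain[key] = []
    ordered[key] = []

assert list(plain) == list(ordered) == ['vec', 'text_embedding', 'image_embedding']
```
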
for col_name, index in self._hnsw_indices.items():
embeddings[col_name] = index.get_items(hashed_ids)

docs = DocList.__class_getitem__(cast(Type[BaseDoc], self.out_schema))()
for i, row in enumerate(rows):
reconstruct_embeddings = {}
for col_name in embeddings.keys():
reconstruct_embeddings[col_name] = embeddings[col_name][i]
docs.append(self._doc_from_bytes(row[1], reconstruct_embeddings))

return docs
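
The mechanism the new `_filter` body relies on is that hnswlib can hand back the stored vectors for a set of labels via `Index.get_items`, which is what makes reconstructing the embeddings for filtered documents possible. A rough standalone sketch of that round trip (all setup here is hypothetical, not the DocArray backend itself):

```python
import hnswlib
import numpy as np

# Hypothetical standalone index, only to show get_items round-tripping vectors.
dim = 8
index = hnswlib.Index(space='l2', dim=dim)
index.init_index(max_elements=100)

vectors = np.random.rand(10, dim).astype(np.float32)
labels = np.arange(10)
index.add_items(vectors, labels)

# Fetch the stored embeddings for a subset of labels (e.g. the hashed ids that
# survived a filter); results come back in the order the labels were passed.
subset = [3, 7, 1]
recovered = np.asarray(index.get_items(subset))
np.testing.assert_array_almost_equal(recovered, vectors[subset])
```
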

def _filter_batched(
self,
@@ -379,6 +393,7 @@ def _del_items(self, doc_ids: Sequence[str]):

self._delete_docs_from_sqlite(doc_ids)
self._sqlite_conn.commit()
self._num_docs = 0 # recompute again when needed

def _get_items(self, doc_ids: Sequence[str], out: bool = True) -> Sequence[TSchema]:
"""Get Documents from the hnswlib index, by `id`.
@@ -403,7 +418,9 @@ def num_docs(self) -> int:
"""
Get the number of documents.
"""
return self._get_num_docs_sqlite()
if self._num_docs == 0:
self._num_docs = self._get_num_docs_sqlite()
return self._num_docs

###############################################
# Helpers #
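
The `self._num_docs = 0` assignments added in `__init__`, `index` and `_del_items`, together with the new `num_docs` body, implement a lazily recomputed counter: writes reset the cached value and the next `num_docs()` call recounts from SQLite. A minimal sketch of that pattern in isolation (hypothetical class and table name, not the actual backend code):

```python
import sqlite3


class LazyDocCounter:
    """Caches a row count and recounts only after a write invalidates it."""

    def __init__(self, conn: sqlite3.Connection):
        self._conn = conn
        self._num_docs = 0  # 0 means "unknown, recompute on next read"

    def invalidate(self) -> None:
        # Mirrors the `self._num_docs = 0` lines added after index/delete calls.
        self._num_docs = 0

    def num_docs(self) -> int:
        if self._num_docs == 0:
            self._num_docs = self._conn.execute(
                'SELECT COUNT(*) FROM docs'
            ).fetchone()[0]
        return self._num_docs


if __name__ == '__main__':
    conn = sqlite3.connect(':memory:')
    conn.execute('CREATE TABLE docs (id TEXT PRIMARY KEY)')
    counter = LazyDocCounter(conn)
    conn.execute("INSERT INTO docs VALUES ('doc-1')")
    counter.invalidate()
    assert counter.num_docs() == 1
```

One side effect of using 0 as the "unknown" sentinel is that a genuinely empty index recounts on every call, which is harmless but worth keeping in mind.
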
@@ -641,15 +658,18 @@ def accept_hashed_ids(id):
"""Accepts IDs that are in hashed_ids."""
return id in hashed_ids # type: ignore[operator]

# Choose the appropriate filter function based on whether hashed_ids was provided
extra_kwargs = {}
if hashed_ids:
extra_kwargs['filter'] = accept_hashed_ids
extra_kwargs = {'filter': accept_hashed_ids} if hashed_ids else {}

# If hashed_ids is provided, k is the minimum of limit and the length of hashed_ids; else it is limit
k = min(limit, len(hashed_ids)) if hashed_ids else limit
index = self._hnsw_indices[search_field]
labels, distances = index.knn_query(queries, k=int(limit), **extra_kwargs)

try:
labels, distances = index.knn_query(queries, k=k, **extra_kwargs)
except RuntimeError:
k = min(k, self.num_docs())
labels, distances = index.knn_query(queries, k=k, **extra_kwargs)

result_das = [
self._get_docs_sqlite_hashed_id(
ids_per_query.tolist(),
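The search hunk above caps `k` at the number of allow-listed ids and retries with a smaller `k` when hnswlib raises a `RuntimeError` (which it does when it cannot return `k` results). A rough standalone sketch of the same call pattern, assuming hnswlib >= 0.7 where `knn_query` accepts a `filter` callable; the data and the allow-list are invented:

```python
import hnswlib
import numpy as np

# Hypothetical data: 20 indexed vectors, of which only a few pass a metadata filter.
dim = 4
index = hnswlib.Index(space='l2', dim=dim)
index.init_index(max_elements=50)
index.add_items(np.random.rand(20, dim).astype(np.float32), np.arange(20))
index.set_ef(50)  # large ef so the filtered search can reach every element

allowed = {2, 5, 11}  # e.g. labels that survived a pre-filter
queries = np.random.rand(1, dim).astype(np.float32)

# Never ask for more neighbours than there are allow-listed candidates.
limit = 10
k = min(limit, len(allowed)) if allowed else limit


def accept(label: int) -> bool:
    """Only return neighbours whose label is on the allow-list."""
    return label in allowed


try:
    labels, distances = index.knn_query(queries, k=k, filter=accept)
except RuntimeError:
    # hnswlib raises RuntimeError when it cannot return k results;
    # fall back to at most the number of indexed documents and retry.
    k = min(k, index.get_current_count())
    labels, distances = index.knn_query(queries, k=k, filter=accept)

print(labels, distances)
```
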
12 changes: 8 additions & 4 deletions tests/index/hnswlib/test_filter.py
@@ -69,10 +69,14 @@ def test_build_query_invalid_query():
HnswDocumentIndex._build_filter_query(query, param_values)


def test_filter_eq(doc_index):
docs = doc_index.filter({'text': {'$eq': 'text 1'}})
assert len(docs) == 1
assert docs[0].text == 'text 1'
def test_filter_eq(doc_index, docs):
filter_result = doc_index.filter({'text': {'$eq': 'text 1'}})
assert len(filter_result) == 1
assert filter_result[0].text == 'text 1'
assert filter_result[0].text == docs[1].text
assert filter_result[0].price == docs[1].price
assert filter_result[0].id == docs[1].id
np.testing.assert_array_almost_equal(filter_result[0].tensor, docs[1].tensor)


def test_filter_neq(doc_index):