Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
refactor: hnswlib performance
Signed-off-by: Joan Fontanals Martinez <[email protected]>
  • Loading branch information
Joan Fontanals Martinez committed Jul 26, 2023
commit 9aebf2e912515716c4ef35088141673a5ac270bc
57 changes: 36 additions & 21 deletions docarray/index/backends/hnswlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import hashlib
import os
import sqlite3
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import (
Expand Down Expand Up @@ -34,7 +35,7 @@
_collect_query_args,
_execute_find_and_filter_query,
)
from docarray.proto import DocProto
from docarray.proto import DocProto, NdArrayProto, NodeProto
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.ndarray import NdArray
from docarray.utils._internal._typing import safe_issubclass
Expand Down Expand Up @@ -63,7 +64,6 @@
HNSWLIB_PY_VEC_TYPES.append(tf.Tensor)
HNSWLIB_PY_VEC_TYPES.append(TensorFlowTensor)


TSchema = TypeVar('TSchema', bound=BaseDoc)
T = TypeVar('T', bound='HnswDocumentIndex')

Expand Down Expand Up @@ -127,7 +127,6 @@ def __init__(self, db_config=None, **kwargs):
self._sqlite_cursor = self._sqlite_conn.cursor()
self._create_docs_table()
self._sqlite_conn.commit()
self._num_docs = self._get_num_docs_sqlite()
self._logger.info(f'{self.__class__.__name__} has been initialized')

@property
Expand Down Expand Up @@ -255,12 +254,9 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
docs_validated = self._validate_docs(docs)
self._update_subindex_data(docs_validated)
data_by_columns = self._get_col_value_dict(docs_validated)

self._index(data_by_columns, docs_validated, **kwargs)

self._send_docs_to_sqlite(docs_validated)
self._sqlite_conn.commit()
self._num_docs = self._get_num_docs_sqlite()

def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any:
"""
Expand Down Expand Up @@ -293,10 +289,6 @@ def _find_batched(
limit: int,
search_field: str = '',
) -> _FindResultBatched:
if self.num_docs() == 0:
return _FindResultBatched(documents=[], scores=[]) # type: ignore

limit = min(limit, self.num_docs())

index = self._hnsw_indices[search_field]
labels, distances = index.knn_query(queries, k=int(limit))
Expand All @@ -311,9 +303,6 @@ def _find_batched(
def _find(
self, query: np.ndarray, limit: int, search_field: str = ''
) -> _FindResult:
if self.num_docs() == 0:
return _FindResult(documents=[], scores=[]) # type: ignore

query_batched = np.expand_dims(query, axis=0)
docs, scores = self._find_batched(
queries=query_batched, limit=limit, search_field=search_field
Expand Down Expand Up @@ -381,7 +370,6 @@ def _del_items(self, doc_ids: Sequence[str]):

self._delete_docs_from_sqlite(doc_ids)
self._sqlite_conn.commit()
self._num_docs = self._get_num_docs_sqlite()

def _get_items(self, doc_ids: Sequence[str], out: bool = True) -> Sequence[TSchema]:
"""Get Documents from the hnswlib index, by `id`.
Expand All @@ -406,7 +394,7 @@ def num_docs(self) -> int:
"""
Get the number of documents.
"""
return self._num_docs
return self._get_num_docs_sqlite()

###############################################
# Helpers #
Expand Down Expand Up @@ -471,10 +459,19 @@ def _get_docs_sqlite_unsorted(self, univ_ids: Sequence[int], out: bool = True):
self._sqlite_cursor.execute(
'SELECT data FROM docs WHERE doc_id IN %s' % sql_id_list,
)
embeddings: OrderedDict[str, list] = OrderedDict()
for col_name, index in self._hnsw_indices.items():
embeddings[col_name] = index.get_items(univ_ids)
rows = self._sqlite_cursor.fetchall()
schema = self.out_schema if out else self._schema
docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], schema))
return docs_cls([self._doc_from_bytes(row[0], out) for row in rows])
docs = DocList.__class_getitem__(cast(Type[BaseDoc], schema))()
for i, row in enumerate(rows):
reconstruct_embeddings = {}
for col_name in embeddings.keys():
reconstruct_embeddings[col_name] = embeddings[col_name][i]
docs.append(self._doc_from_bytes(row[0], reconstruct_embeddings, out))

return docs

def _get_docs_sqlite_doc_id(
self, doc_ids: Sequence[str], out: bool = True
Expand Down Expand Up @@ -509,12 +506,30 @@ def _get_num_docs_sqlite(self) -> int:

# serialization helpers
def _doc_to_bytes(self, doc: BaseDoc) -> bytes:
return doc.to_protobuf().SerializeToString()

def _doc_from_bytes(self, data: bytes, out: bool = True) -> BaseDoc:
pb = doc.to_protobuf()
for col_name in self._hnsw_indices.keys():
pb.data[col_name].Clear()
return pb.SerializeToString()

def _doc_from_bytes(
self, data: bytes, reconstruct_embeddings: Dict, out: bool = True
) -> BaseDoc:
schema = self.out_schema if out else self._schema
schema_cls = cast(Type[BaseDoc], schema)
return schema_cls.from_protobuf(DocProto.FromString(data))
pb = DocProto.FromString(data)
for k, v in reconstruct_embeddings.items():
nd_proto = NdArrayProto()
np_array = np.array(v)
nd_proto.dense.buffer = np_array.tobytes()
nd_proto.dense.ClearField('shape')
nd_proto.dense.shape.extend(list(np_array.shape))
nd_proto.dense.dtype = np_array.dtype.str
node_proto = NodeProto(ndarray=nd_proto, type='ndarray')

pb.data[k].MergeFrom(node_proto)

doc = schema_cls.from_protobuf(pb)
return doc

def _get_root_doc_id(self, id: str, root: str, sub: str) -> str:
"""Get the root_id given the id of a subindex Document and the root and subindex name for hnswlib.
Expand Down