Skip to content

Commit a643f6a

Browse files
author
Joan Fontanals
authored
refactor: hnswlib performance (#1727)
Signed-off-by: Joan Fontanals Martinez <[email protected]>
1 parent 87ec19f commit a643f6a

File tree

6 files changed

+92
-50
lines changed

6 files changed

+92
-50
lines changed

docarray/index/backends/hnswlib.py

Lines changed: 69 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import hashlib
33
import os
44
import sqlite3
5-
from collections import defaultdict
5+
from collections import OrderedDict, defaultdict
66
from dataclasses import dataclass, field
77
from pathlib import Path
88
from typing import (
@@ -32,7 +32,9 @@
3232
_raise_not_composable,
3333
_raise_not_supported,
3434
)
35-
from docarray.index.backends.helper import _collect_query_args
35+
from docarray.index.backends.helper import (
36+
_collect_query_args,
37+
)
3638
from docarray.proto import DocProto
3739
from docarray.typing.tensor.abstract_tensor import AbstractTensor
3840
from docarray.typing.tensor.ndarray import NdArray
@@ -63,7 +65,6 @@
6365
HNSWLIB_PY_VEC_TYPES.append(tf.Tensor)
6466
HNSWLIB_PY_VEC_TYPES.append(TensorFlowTensor)
6567

66-
6768
TSchema = TypeVar('TSchema', bound=BaseDoc)
6869
T = TypeVar('T', bound='HnswDocumentIndex')
6970

@@ -107,7 +108,11 @@ def __init__(self, db_config=None, **kwargs):
107108
if col.config
108109
}
109110
self._hnsw_indices = {}
111+
sub_docs_exist = False
112+
cosine_metric_index_exist = False
110113
for col_name, col in self._column_infos.items():
114+
if '__' in col_name:
115+
sub_docs_exist = True
111116
if safe_issubclass(col.docarray_type, AnyDocArray):
112117
continue
113118
if not col.config:
@@ -127,7 +132,12 @@ def __init__(self, db_config=None, **kwargs):
127132
else:
128133
self._hnsw_indices[col_name] = self._create_index(col_name, col)
129134
self._logger.info(f'Created a new index for column `{col_name}`')
135+
if self._hnsw_indices[col_name].space == 'cosine':
136+
cosine_metric_index_exist = True
130137

138+
self._apply_optim_no_embedding_in_sqlite = (
139+
not sub_docs_exist and not cosine_metric_index_exist
140+
) # Optimization: do not serialize embeddings to SQLite, since they are expensive to send and can be reconstructed from the HNSW index itself. Only applied when there are no sub-documents and no cosine-metric index.
131141
# SQLite setup
132142
self._sqlite_db_path = os.path.join(self._work_dir, 'docs_sqlite.db')
133143
self._logger.debug(f'DB path set to {self._sqlite_db_path}')
@@ -276,9 +286,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
276286
docs_validated = self._validate_docs(docs)
277287
self._update_subindex_data(docs_validated)
278288
data_by_columns = self._get_col_value_dict(docs_validated)
279-
280289
self._index(data_by_columns, docs_validated, **kwargs)
281-
282290
self._send_docs_to_sqlite(docs_validated)
283291
self._sqlite_conn.commit()
284292
self._num_docs = 0 # recompute again when needed
@@ -332,7 +340,19 @@ def _filter(
332340
limit: int,
333341
) -> DocList:
334342
rows = self._execute_filter(filter_query=filter_query, limit=limit)
335-
return DocList[self.out_schema](self._doc_from_bytes(blob) for _, blob in rows) # type: ignore[name-defined]
343+
hashed_ids = [doc_id for doc_id, _ in rows]
344+
embeddings: OrderedDict[str, list] = OrderedDict()
345+
for col_name, index in self._hnsw_indices.items():
346+
embeddings[col_name] = index.get_items(hashed_ids)
347+
348+
docs = DocList.__class_getitem__(cast(Type[BaseDoc], self.out_schema))()
349+
for i, row in enumerate(rows):
350+
reconstruct_embeddings = {}
351+
for col_name in embeddings.keys():
352+
reconstruct_embeddings[col_name] = embeddings[col_name][i]
353+
docs.append(self._doc_from_bytes(row[1], reconstruct_embeddings))
354+
355+
return docs
336356

337357
def _filter_batched(
338358
self,
@@ -501,12 +521,24 @@ def _get_docs_sqlite_unsorted(self, univ_ids: Sequence[int], out: bool = True):
501521
assert isinstance(id_, int) or is_np_int(id_)
502522
sql_id_list = '(' + ', '.join(str(id_) for id_ in univ_ids) + ')'
503523
self._sqlite_cursor.execute(
504-
'SELECT data FROM docs WHERE doc_id IN %s' % sql_id_list,
524+
'SELECT doc_id, data FROM docs WHERE doc_id IN %s' % sql_id_list,
505525
)
506-
rows = self._sqlite_cursor.fetchall()
526+
rows = (
527+
self._sqlite_cursor.fetchall()
528+
) # rows (doc_id, data) are not guaranteed to come back in the same order as univ_ids
529+
embeddings: OrderedDict[str, list] = OrderedDict()
530+
for col_name, index in self._hnsw_indices.items():
531+
embeddings[col_name] = index.get_items([row[0] for row in rows])
532+
507533
schema = self.out_schema if out else self._schema
508-
docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], schema))
509-
return docs_cls([self._doc_from_bytes(row[0], out) for row in rows])
534+
docs = DocList.__class_getitem__(cast(Type[BaseDoc], schema))()
535+
for i, (_, data_bytes) in enumerate(rows):
536+
reconstruct_embeddings = {}
537+
for col_name in embeddings.keys():
538+
reconstruct_embeddings[col_name] = embeddings[col_name][i]
539+
docs.append(self._doc_from_bytes(data_bytes, reconstruct_embeddings, out))
540+
541+
return docs
510542

511543
def _get_docs_sqlite_doc_id(
512544
self, doc_ids: Sequence[str], out: bool = True
@@ -541,12 +573,32 @@ def _get_num_docs_sqlite(self) -> int:
541573

542574
# serialization helpers
543575
def _doc_to_bytes(self, doc: BaseDoc) -> bytes:
544-
return doc.to_protobuf().SerializeToString()
545-
546-
def _doc_from_bytes(self, data: bytes, out: bool = True) -> BaseDoc:
576+
pb = doc.to_protobuf()
577+
if self._apply_optim_no_embedding_in_sqlite:
578+
for col_name in self._hnsw_indices.keys():
579+
pb.data[col_name].Clear()
580+
pb.data[col_name].Clear()
581+
return pb.SerializeToString()
582+
583+
def _doc_from_bytes(
584+
self, data: bytes, reconstruct_embeddings: Dict, out: bool = True
585+
) -> BaseDoc:
547586
schema = self.out_schema if out else self._schema
548587
schema_cls = cast(Type[BaseDoc], schema)
549-
return schema_cls.from_protobuf(DocProto.FromString(data))
588+
pb = DocProto.FromString(
589+
data
590+
) # The DocArray object cannot be reconstructed directly from the stored bytes: validation may fail because the embedding field may not be Optional, so the protobuf is parsed first.
591+
if self._apply_optim_no_embedding_in_sqlite:
592+
for k, v in reconstruct_embeddings.items():
593+
node_proto = (
594+
schema_cls._get_field_type(k)
595+
._docarray_from_ndarray(np.array(v))
596+
._to_node_protobuf()
597+
)
598+
pb.data[k].MergeFrom(node_proto)
599+
600+
doc = schema_cls.from_protobuf(pb)
601+
return doc
550602

551603
def _get_root_doc_id(self, id: str, root: str, sub: str) -> str:
552604
"""Get the root_id given the id of a subindex Document and the root and subindex name for hnswlib.
@@ -608,25 +660,24 @@ def _search_and_filter(
608660
return _FindResultBatched(documents=[], scores=[]) # type: ignore
609661

610662
# Set limit as the minimum of the provided limit and the total number of documents (note: after this change the cap is no longer applied; the provided limit is used as given)
611-
limit = min(limit, self.num_docs())
663+
limit = limit
612664

613665
# Ensure the search field is in the HNSW indices
614666
if search_field not in self._hnsw_indices:
615667
raise ValueError(
616668
f'Search field {search_field} is not present in the HNSW indices'
617669
)
618670

619-
index = self._hnsw_indices[search_field]
620-
621671
def accept_hashed_ids(id):
622672
"""Accepts IDs that are in hashed_ids."""
623673
return id in hashed_ids # type: ignore[operator]
624674

625-
# Choose the appropriate filter function based on whether hashed_ids was provided
626675
extra_kwargs = {'filter': accept_hashed_ids} if hashed_ids else {}
627676

628677
# If hashed_ids is provided, k is the minimum of limit and the length of hashed_ids; else it is limit
629678
k = min(limit, len(hashed_ids)) if hashed_ids else limit
679+
index = self._hnsw_indices[search_field]
680+
630681
try:
631682
labels, distances = index.knn_query(queries, k=k, **extra_kwargs)
632683
except RuntimeError:
@@ -639,7 +690,6 @@ def accept_hashed_ids(id):
639690
)
640691
for ids_per_query in labels
641692
]
642-
643693
return _FindResultBatched(documents=result_das, scores=distances)
644694

645695
@classmethod

tests/index/elastic/v7/docker-compose.yml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,3 @@ services:
88
- ES_JAVA_OPTS=-Xmx1024m
99
ports:
1010
- "9200:9200"
11-
networks:
12-
- elastic
13-
14-
networks:
15-
elastic:
16-
name: elastic

tests/index/elastic/v8/docker-compose.yml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,3 @@ services:
88
- ES_JAVA_OPTS=-Xmx1024m
99
ports:
1010
- "9200:9200"
11-
networks:
12-
- elastic
13-
14-
networks:
15-
elastic:
16-
name: elastic

tests/index/hnswlib/test_filter.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,14 @@ def test_build_query_invalid_query():
6969
HnswDocumentIndex._build_filter_query(query, param_values)
7070

7171

72-
def test_filter_eq(doc_index):
73-
docs = doc_index.filter({'text': {'$eq': 'text 1'}})
74-
assert len(docs) == 1
75-
assert docs[0].text == 'text 1'
72+
def test_filter_eq(doc_index, docs):
73+
filter_result = doc_index.filter({'text': {'$eq': 'text 1'}})
74+
assert len(filter_result) == 1
75+
assert filter_result[0].text == 'text 1'
76+
assert filter_result[0].text == docs[1].text
77+
assert filter_result[0].price == docs[1].price
78+
assert filter_result[0].id == docs[1].id
79+
assert np.allclose(filter_result[0].tensor, docs[1].tensor)
7680

7781

7882
def test_filter_neq(doc_index):

tests/index/hnswlib/test_index_get_del.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
211211
for d in ten_simple_docs:
212212
id_ = d.id
213213
assert index[id_].id == id_
214-
assert np.all(index[id_].tens == d.tens)
214+
assert np.allclose(index[id_].tens, d.tens)
215215

216216
# flat
217217
index = HnswDocumentIndex[FlatDoc](work_dir=str(flat_path))
@@ -221,8 +221,8 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
221221
for d in ten_flat_docs:
222222
id_ = d.id
223223
assert index[id_].id == id_
224-
assert np.all(index[id_].tens_one == d.tens_one)
225-
assert np.all(index[id_].tens_two == d.tens_two)
224+
assert np.allclose(index[id_].tens_one, d.tens_one)
225+
assert np.allclose(index[id_].tens_two, d.tens_two)
226226

227227
# nested
228228
index = HnswDocumentIndex[NestedDoc](work_dir=str(nested_path))
@@ -233,7 +233,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
233233
id_ = d.id
234234
assert index[id_].id == id_
235235
assert index[id_].d.id == d.d.id
236-
assert np.all(index[id_].d.tens == d.d.tens)
236+
assert np.allclose(index[id_].d.tens, d.d.tens)
237237

238238

239239
def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
@@ -252,7 +252,7 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path)
252252
retrieved_docs = index[ids_to_get]
253253
for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs):
254254
assert d_out.id == id_
255-
assert np.all(d_out.tens == d_in.tens)
255+
assert np.allclose(d_out.tens, d_in.tens)
256256

257257
# flat
258258
index = HnswDocumentIndex[FlatDoc](work_dir=str(flat_path))
@@ -264,8 +264,8 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path)
264264
retrieved_docs = index[ids_to_get]
265265
for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs):
266266
assert d_out.id == id_
267-
assert np.all(d_out.tens_one == d_in.tens_one)
268-
assert np.all(d_out.tens_two == d_in.tens_two)
267+
assert np.allclose(d_out.tens_one, d_in.tens_one)
268+
assert np.allclose(d_out.tens_two, d_in.tens_two)
269269

270270
# nested
271271
index = HnswDocumentIndex[NestedDoc](work_dir=str(nested_path))
@@ -278,7 +278,7 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path)
278278
for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs):
279279
assert d_out.id == id_
280280
assert d_out.d.id == d_in.d.id
281-
assert np.all(d_out.d.tens == d_in.d.tens)
281+
assert np.allclose(d_out.d.tens, d_in.d.tens)
282282

283283

284284
def test_get_key_error(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
@@ -303,7 +303,7 @@ def test_del_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
303303
index[id_]
304304
else:
305305
assert index[id_].id == id_
306-
assert np.all(index[id_].tens == d.tens)
306+
assert np.allclose(index[id_].tens, d.tens)
307307
# delete again
308308
del index[ten_simple_docs[3].id]
309309
assert index.num_docs() == 8
@@ -314,7 +314,7 @@ def test_del_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
314314
index[id_]
315315
else:
316316
assert index[id_].id == id_
317-
assert np.all(index[id_].tens == d.tens)
317+
assert np.allclose(index[id_].tens, d.tens)
318318

319319

320320
def test_del_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
@@ -333,7 +333,7 @@ def test_del_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path)
333333
index[doc.id]
334334
else:
335335
assert index[doc.id].id == doc.id
336-
assert np.all(index[doc.id].tens == doc.tens)
336+
assert np.allclose(index[doc.id].tens, doc.tens)
337337

338338

339339
def test_del_key_error(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
@@ -410,5 +410,5 @@ class TextSimpleDoc(SimpleDoc):
410410
for doc in res.documents:
411411
if doc.id == docs[0].id:
412412
found = True
413-
assert (doc.tens == new_tensor).all()
413+
assert np.allclose(doc.tens, new_tensor)
414414
assert found

tests/index/hnswlib/test_persist_data.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_persist_and_restore(tmp_path):
2222
query = SimpleDoc(tens=np.random.random((10,)))
2323

2424
# create index
25-
index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path))
25+
_ = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path))
2626

2727
# load existing index file
2828
index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path))
@@ -38,7 +38,7 @@ def test_persist_and_restore(tmp_path):
3838
find_results_after = index.find(query, search_field='tens', limit=5)
3939
for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]):
4040
assert doc_before.id == doc_after.id
41-
assert (doc_before.tens == doc_after.tens).all()
41+
assert np.allclose(doc_before.tens, doc_after.tens)
4242

4343
# add new data
4444
index.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(5)])
@@ -70,7 +70,7 @@ def test_persist_and_restore_nested(tmp_path):
7070
find_results_after = index.find(query, search_field='d__tens', limit=5)
7171
for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]):
7272
assert doc_before.id == doc_after.id
73-
assert (doc_before.tens == doc_after.tens).all()
73+
assert np.allclose(doc_before.tens, doc_after.tens)
7474

7575
# delete and restore
7676
index.index(

0 commit comments

Comments
 (0)