Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/index/hnswlib/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_filter_eq(doc_index, docs):
assert filter_result[0].text == docs[1].text
assert filter_result[0].price == docs[1].price
assert filter_result[0].id == docs[1].id
np.testing.assert_array_almost_equal(filter_result[0].tensor, docs[1].tensor)
assert np.allclose(filter_result[0].tensor, docs[1].tensor)


def test_filter_neq(doc_index):
Expand Down
22 changes: 11 additions & 11 deletions tests/index/hnswlib/test_index_get_del.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
for d in ten_simple_docs:
id_ = d.id
assert index[id_].id == id_
assert np.all(index[id_].tens == d.tens)
assert np.allclose(index[id_].tens, d.tens)

# flat
index = HnswDocumentIndex[FlatDoc](work_dir=str(flat_path))
Expand All @@ -221,8 +221,8 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
for d in ten_flat_docs:
id_ = d.id
assert index[id_].id == id_
assert np.all(index[id_].tens_one == d.tens_one)
assert np.all(index[id_].tens_two == d.tens_two)
assert np.allclose(index[id_].tens_one, d.tens_one)
assert np.allclose(index[id_].tens_two, d.tens_two)

# nested
index = HnswDocumentIndex[NestedDoc](work_dir=str(nested_path))
Expand All @@ -233,7 +233,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
id_ = d.id
assert index[id_].id == id_
assert index[id_].d.id == d.d.id
assert np.all(index[id_].d.tens == d.d.tens)
assert np.allclose(index[id_].d.tens, d.d.tens)


def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
Expand All @@ -252,7 +252,7 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path)
retrieved_docs = index[ids_to_get]
for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs):
assert d_out.id == id_
assert np.all(d_out.tens == d_in.tens)
assert np.allclose(d_out.tens, d_in.tens)

# flat
index = HnswDocumentIndex[FlatDoc](work_dir=str(flat_path))
Expand All @@ -264,8 +264,8 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path)
retrieved_docs = index[ids_to_get]
for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs):
assert d_out.id == id_
assert np.all(d_out.tens_one == d_in.tens_one)
assert np.all(d_out.tens_two == d_in.tens_two)
assert np.allclose(d_out.tens_one, d_in.tens_one)
assert np.allclose(d_out.tens_two, d_in.tens_two)

# nested
index = HnswDocumentIndex[NestedDoc](work_dir=str(nested_path))
Expand All @@ -278,7 +278,7 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path)
for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs):
assert d_out.id == id_
assert d_out.d.id == d_in.d.id
assert np.all(d_out.d.tens == d_in.d.tens)
assert np.allclose(d_out.d.tens, d_in.d.tens)


def test_get_key_error(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
Expand All @@ -303,7 +303,7 @@ def test_del_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
index[id_]
else:
assert index[id_].id == id_
assert np.all(index[id_].tens == d.tens)
assert np.allclose(index[id_].tens, d.tens)
# delete again
del index[ten_simple_docs[3].id]
assert index.num_docs() == 8
Expand All @@ -314,7 +314,7 @@ def test_del_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
index[id_]
else:
assert index[id_].id == id_
assert np.all(index[id_].tens == d.tens)
assert np.allclose(index[id_].tens, d.tens)


def test_del_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
Expand All @@ -333,7 +333,7 @@ def test_del_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path)
index[doc.id]
else:
assert index[doc.id].id == doc.id
assert np.all(index[doc.id].tens == doc.tens)
assert np.allclose(index[doc.id].tens, doc.tens)


def test_del_key_error(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
Expand Down
6 changes: 3 additions & 3 deletions tests/index/hnswlib/test_persist_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_persist_and_restore(tmp_path):
query = SimpleDoc(tens=np.random.random((10,)))

# create index
index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path))
_ = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path))

# load existing index file
index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path))
Expand All @@ -38,7 +38,7 @@ def test_persist_and_restore(tmp_path):
find_results_after = index.find(query, search_field='tens', limit=5)
for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]):
assert doc_before.id == doc_after.id
assert (doc_before.tens == doc_after.tens).all()
assert np.allclose(doc_before.tens, doc_after.tens)

# add new data
index.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(5)])
Expand Down Expand Up @@ -70,7 +70,7 @@ def test_persist_and_restore_nested(tmp_path):
find_results_after = index.find(query, search_field='d__tens', limit=5)
for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]):
assert doc_before.id == doc_after.id
assert (doc_before.tens == doc_after.tens).all()
assert np.allclose(doc_before.tens, doc_after.tens)

# delete and restore
index.index(
Expand Down