Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 28 additions & 19 deletions tests/index/hnswlib/test_subindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,7 @@ class MyDoc(BaseDoc):


@pytest.fixture(scope='session')
def index():
index = HnswDocumentIndex[MyDoc](work_dir='./tmp')
return index


def test_subindex_init(index):
assert isinstance(index._subindices['docs'], HnswDocumentIndex)
assert isinstance(index._subindices['list_docs'], HnswDocumentIndex)
assert isinstance(
index._subindices['list_docs']._subindices['docs'], HnswDocumentIndex
)


def test_subindex_index(index):
def index_docs():
my_docs = [
MyDoc(
id=f'{i}',
Expand Down Expand Up @@ -82,15 +69,31 @@ def test_subindex_index(index):
)
for i in range(5)
]
return my_docs


def test_subindex_init(tmpdir, index_docs):
index = HnswDocumentIndex[MyDoc](work_dir=str(tmpdir))
index.index(index_docs)
assert isinstance(index._subindices['docs'], HnswDocumentIndex)
assert isinstance(index._subindices['list_docs'], HnswDocumentIndex)
assert isinstance(
index._subindices['list_docs']._subindices['docs'], HnswDocumentIndex
)


index.index(my_docs)
def test_subindex_index(tmpdir, index_docs):
index = HnswDocumentIndex[MyDoc](work_dir=str(tmpdir))
index.index(index_docs)
assert index.num_docs() == 5
assert index._subindices['docs'].num_docs() == 25
assert index._subindices['list_docs'].num_docs() == 25
assert index._subindices['list_docs']._subindices['docs'].num_docs() == 125


def test_subindex_get(index):
def test_subindex_get(tmpdir, index_docs):
index = HnswDocumentIndex[MyDoc](work_dir=str(tmpdir))
index.index(index_docs)
doc = index['1']
assert type(doc) == MyDoc
assert doc.id == '1'
Expand All @@ -116,7 +119,9 @@ def test_subindex_get(index):
assert np.allclose(doc.my_tens, np.ones(30) * 2)


def test_find_subindex(index):
def test_find_subindex(tmpdir, index_docs):
index = HnswDocumentIndex[MyDoc](work_dir=str(tmpdir))
index.index(index_docs)
# root level
query = np.ones((30,))
with pytest.raises(ValueError):
Expand Down Expand Up @@ -148,15 +153,19 @@ def test_find_subindex(index):
assert root_doc.id == f'{doc.id.split("-")[2]}'


def test_subindex_del(index):
def test_subindex_del(tmpdir, index_docs):
index = HnswDocumentIndex[MyDoc](work_dir=str(tmpdir))
index.index(index_docs)
del index['0']
assert index.num_docs() == 4
assert index._subindices['docs'].num_docs() == 20
assert index._subindices['list_docs'].num_docs() == 20
assert index._subindices['list_docs']._subindices['docs'].num_docs() == 100


def test_subindex_contain(index):
def test_subindex_contain(tmpdir, index_docs):
index = HnswDocumentIndex[MyDoc](work_dir=str(tmpdir))
index.index(index_docs)
# Checks for individual simple_docs within list_docs
for i in range(4):
doc = index[f'{i + 1}']
Expand Down