Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
df47474
feat: support redis
jupyterjazz May 17, 2023
0920adc
chore: merge main
jupyterjazz Jun 12, 2023
51558b7
fix: index creation
jupyterjazz Jun 15, 2023
12da714
feat: 1st draft, needs polishing
jupyterjazz Jun 21, 2023
be7bed7
feat: query builder, tests
jupyterjazz Jun 28, 2023
9365327
Merge branch 'main' into feat-add-redis
jupyterjazz Jun 28, 2023
cb25869
chore: update poetry lock
jupyterjazz Jun 28, 2023
341fa9a
chore: run tests
jupyterjazz Jun 28, 2023
d37b28f
Merge branch 'main' into feat-add-redis
jupyterjazz Jun 28, 2023
abca1dd
fix: defaultdict for column config
jupyterjazz Jun 28, 2023
edc8a34
chore: update branch
jupyterjazz Jun 28, 2023
c5abd80
style: ignore mypy errors
jupyterjazz Jun 28, 2023
22434f1
refactor: put vectorfield args in column info
jupyterjazz Jun 28, 2023
4a96194
chore: remove unused code
jupyterjazz Jun 28, 2023
a52cdfd
docs: add docstrings
jupyterjazz Jun 28, 2023
9ac683f
test: add tensorflow test
jupyterjazz Jun 28, 2023
a7f54c3
fix: tensorflow test
jupyterjazz Jun 28, 2023
be9d771
fix: tf tst
jupyterjazz Jun 28, 2023
4dc99e6
refactor: reduce ignore types
jupyterjazz Jun 28, 2023
d2718b4
style: remove other type ignores
jupyterjazz Jun 28, 2023
9f58474
style: try removing import ignores
jupyterjazz Jun 28, 2023
7e791da
style: i think mypy hates me
jupyterjazz Jun 28, 2023
8ed3dbb
feat: batch indexing
jupyterjazz Jun 29, 2023
83afb5f
chore: bump redis version
jupyterjazz Jun 29, 2023
4b0bc73
feat: subindex not fully finished
jupyterjazz Jul 2, 2023
0b15adc
feat: finalize subindex
jupyterjazz Jul 5, 2023
e8fbc47
docs: update readme
jupyterjazz Jul 5, 2023
848e95d
chore: commits not showing
jupyterjazz Jul 5, 2023
c00abe2
Merge branch 'main' into feat-add-redis
jupyterjazz Jul 5, 2023
27dd29a
feat: del and get batched
jupyterjazz Jul 6, 2023
162c6a8
docs: update batchsize docstring
jupyterjazz Jul 6, 2023
16c4323
Merge branch 'main' into feat-add-redis
jupyterjazz Jul 6, 2023
56829d6
refactor: index name
jupyterjazz Jul 9, 2023
c414836
Merge branch 'main' into feat-add-redis
jupyterjazz Jul 9, 2023
7a5ed5e
refactor: default index name following schema
jupyterjazz Jul 9, 2023
0fecb48
chore: update branch
jupyterjazz Jul 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor: default index name following schema
Signed-off-by: jupyterjazz <[email protected]>
  • Loading branch information
jupyterjazz committed Jul 9, 2023
commit 7a5ed5e1ab400f02b01271ae1f29886b26d3e947
19 changes: 13 additions & 6 deletions docarray/index/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@
class RedisDocumentIndex(BaseDocIndex, Generic[TSchema]):
def __init__(self, db_config=None, **kwargs):
"""Initialize RedisDocumentIndex"""
self._index_name = None
super().__init__(db_config=db_config, **kwargs)
self._db_config = cast(RedisDocumentIndex.DBConfig, self._db_config)

Expand Down Expand Up @@ -175,12 +174,20 @@ def _check_index_exists(self, index_name: str) -> bool:

@property
def index_name(self):
if not self._index_name:
self._index_name = index_name = (
self._db_config.index_name or 'index_name__' + self._random_name()
default_index_name = (
self._schema.__name__.lower() if self._schema is not None else None
)
if default_index_name is None:
err_msg = (
'A RedisDocumentIndex must be typed with a Document type. '
'To do so, use the syntax: RedisDocumentIndex[DocumentType]'
)
self._logger.debug(f'Retrieved index name: {index_name}')
return self._index_name

self._logger.error(err_msg)
raise ValueError(err_msg)
index_name = self._db_config.index_name or default_index_name
self._logger.debug(f'Retrieved index name: {index_name}')
return index_name

@property
def out_schema(self) -> Type[BaseDoc]:
Expand Down
17 changes: 1 addition & 16 deletions tests/index/redis/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import time
import uuid
import pytest
import redis


@pytest.fixture(scope='session', autouse=True)
Expand All @@ -18,19 +17,5 @@ def start_redis():


@pytest.fixture(scope='function')
def tmp_collection_name():
def tmp_index_name():
return uuid.uuid4().hex


@pytest.fixture
def redis_client():
"""This fixture provides a Redis client"""
client = redis.Redis(host='localhost', port=6379)
yield client
client.flushall()


@pytest.fixture
def redis_config(redis_client):
"""This fixture provides the Redis client and flushes all data after each test case"""
return redis_client
8 changes: 4 additions & 4 deletions tests/index/redis/test_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from docarray import BaseDoc
from docarray.index import RedisDocumentIndex
from docarray.typing import NdArray
from tests.index.redis.fixtures import start_redis # noqa: F401
from tests.index.redis.fixtures import start_redis, tmp_index_name # noqa: F401


pytestmark = [pytest.mark.slow, pytest.mark.index]
Expand All @@ -23,16 +23,16 @@ class Schema(BaseDoc):
assert index.num_docs() == 10


def test_configure_index():
def test_configure_index(tmp_index_name):
class Schema(BaseDoc):
tens: NdArray[100] = Field(space='cosine')
title: str
year: int

types = {'id': 'TAG', 'tens': 'VECTOR', 'title': 'TEXT', 'year': 'NUMERIC'}
index = RedisDocumentIndex[Schema](host='localhost')
index = RedisDocumentIndex[Schema](host='localhost', index_name=tmp_index_name)

attr_bytes = index._client.ft(index._index_name).info()['attributes']
attr_bytes = index._client.ft(index.index_name).info()['attributes']
attr = [[byte.decode() for byte in sublist] for sublist in attr_bytes]

assert len(Schema.__fields__) == len(attr)
Expand Down
36 changes: 19 additions & 17 deletions tests/index/redis/test_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from docarray import BaseDoc, DocList
from docarray.index import RedisDocumentIndex
from docarray.typing import NdArray, TorchTensor
from tests.index.redis.fixtures import start_redis # noqa: F401
from tests.index.redis.fixtures import start_redis, tmp_index_name # noqa: F401

pytestmark = [pytest.mark.slow, pytest.mark.index]

Expand All @@ -27,9 +27,9 @@ class TorchDoc(BaseDoc):


@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip'])
def test_find_simple_schema(space):
def test_find_simple_schema(space, tmp_index_name):
schema = get_simple_schema(space=space)
db = RedisDocumentIndex[schema](host='localhost')
db = RedisDocumentIndex[schema](host='localhost', index_name=tmp_index_name)

index_docs = [schema(tens=np.random.rand(N_DIM)) for _ in range(10)]
index_docs.append(schema(tens=np.ones(N_DIM)))
Expand Down Expand Up @@ -68,8 +68,8 @@ def test_find_limit_larger_than_index():


@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip'])
def test_find_torch(space):
db = RedisDocumentIndex[TorchDoc](host='localhost')
def test_find_torch(space, tmp_index_name):
db = RedisDocumentIndex[TorchDoc](host='localhost', index_name=tmp_index_name)
index_docs = [TorchDoc(tens=np.random.rand(N_DIM)) for _ in range(10)]
index_docs.append(TorchDoc(tens=np.ones(N_DIM, dtype=np.float32)))
db.index(index_docs)
Expand All @@ -91,13 +91,13 @@ def test_find_torch(space):

@pytest.mark.tensorflow
@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip'])
def test_find_tensorflow(space):
def test_find_tensorflow(space, tmp_index_name):
from docarray.typing import TensorFlowTensor

class TfDoc(BaseDoc):
tens: TensorFlowTensor[10]

db = RedisDocumentIndex[TfDoc](host='localhost')
db = RedisDocumentIndex[TfDoc](host='localhost', index_name=tmp_index_name)

index_docs = [TfDoc(tens=np.random.rand(N_DIM)) for _ in range(10)]
index_docs.append(TfDoc(tens=np.ones(10)))
Expand All @@ -121,12 +121,12 @@ class TfDoc(BaseDoc):


@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip'])
def test_find_flat_schema(space):
def test_find_flat_schema(space, tmp_index_name):
class FlatSchema(BaseDoc):
tens_one: NdArray = Field(dim=N_DIM, space=space)
tens_two: NdArray = Field(dim=50, space=space)

index = RedisDocumentIndex[FlatSchema](host='localhost')
index = RedisDocumentIndex[FlatSchema](host='localhost', index_name=tmp_index_name)

index_docs = [
FlatSchema(tens_one=np.random.rand(N_DIM), tens_two=np.random.rand(50))
Expand Down Expand Up @@ -156,7 +156,7 @@ class FlatSchema(BaseDoc):


@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip'])
def test_find_nested_schema(space):
def test_find_nested_schema(space, tmp_index_name):
class SimpleDoc(BaseDoc):
tens: NdArray[N_DIM] = Field(space=space)

Expand All @@ -168,7 +168,9 @@ class DeepNestedDoc(BaseDoc):
d: NestedDoc
tens: NdArray = Field(space=space, dim=N_DIM)

index = RedisDocumentIndex[DeepNestedDoc](host='localhost')
index = RedisDocumentIndex[DeepNestedDoc](
host='localhost', index_name=tmp_index_name
)

index_docs = [
DeepNestedDoc(
Expand Down Expand Up @@ -243,12 +245,12 @@ class MyDoc(BaseDoc):
assert q.id == matches[0].id


def test_query_builder():
def test_query_builder(tmp_index_name):
class SimpleSchema(BaseDoc):
tensor: NdArray[N_DIM] = Field(space='cosine')
price: int

db = RedisDocumentIndex[SimpleSchema](host='localhost')
db = RedisDocumentIndex[SimpleSchema](host='localhost', index_name=tmp_index_name)

index_docs = [
SimpleSchema(tensor=np.array([i + 1] * 10), price=i + 1) for i in range(10)
Expand All @@ -269,7 +271,7 @@ class SimpleSchema(BaseDoc):
assert doc.price <= 3


def test_text_search():
def test_text_search(tmp_index_name):
class SimpleSchema(BaseDoc):
description: str
some_field: Optional[int]
Expand All @@ -286,15 +288,15 @@ class SimpleSchema(BaseDoc):

docs = [SimpleSchema(description=text) for text in texts_to_index]

db = RedisDocumentIndex[SimpleSchema](host='localhost')
db = RedisDocumentIndex[SimpleSchema](host='localhost', index_name=tmp_index_name)
db.index(docs)

docs, _ = db.text_search(query=query_string, search_field='description')

assert docs[0].description == texts_to_index[0]


def test_filter():
def test_filter(tmp_index_name):
class SimpleSchema(BaseDoc):
description: str
price: int
Expand All @@ -304,7 +306,7 @@ class SimpleSchema(BaseDoc):
doc3 = SimpleSchema(description='Random book', price=40)
docs = [doc1, doc2, doc3]

db = RedisDocumentIndex[SimpleSchema](host='localhost')
db = RedisDocumentIndex[SimpleSchema](host='localhost', index_name=tmp_index_name)
db.index(docs)

# filter on price < 45
Expand Down
22 changes: 11 additions & 11 deletions tests/index/redis/test_index_get_del.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from docarray import BaseDoc
from docarray.index import RedisDocumentIndex
from docarray.typing import NdArray
from tests.index.redis.fixtures import start_redis # noqa: F401
from tests.index.redis.fixtures import start_redis, tmp_index_name # noqa: F401

pytestmark = [pytest.mark.slow, pytest.mark.index]

Expand Down Expand Up @@ -39,8 +39,8 @@ def test_num_docs(ten_simple_docs):
assert index.num_docs() == 10


def test_get_single(ten_simple_docs):
index = RedisDocumentIndex[SimpleDoc](host='localhost')
def test_get_single(ten_simple_docs, tmp_index_name):
index = RedisDocumentIndex[SimpleDoc](host='localhost', index_name=tmp_index_name)
index.index(ten_simple_docs)

assert index.num_docs() == 10
Expand All @@ -54,9 +54,9 @@ def test_get_single(ten_simple_docs):
index['some_id']


def test_get_multiple(ten_simple_docs):
def test_get_multiple(ten_simple_docs, tmp_index_name):
docs_to_get_idx = [0, 2, 4, 6, 8]
index = RedisDocumentIndex[SimpleDoc](host='localhost')
index = RedisDocumentIndex[SimpleDoc](host='localhost', index_name=tmp_index_name)
index.index(ten_simple_docs)

assert index.num_docs() == 10
Expand All @@ -68,8 +68,8 @@ def test_get_multiple(ten_simple_docs):
assert np.allclose(d_out.tens, d_in.tens)


def test_del_single(ten_simple_docs):
index = RedisDocumentIndex[SimpleDoc](host='localhost')
def test_del_single(ten_simple_docs, tmp_index_name):
index = RedisDocumentIndex[SimpleDoc](host='localhost', index_name=tmp_index_name)
index.index(ten_simple_docs)
assert index.num_docs() == 10

Expand All @@ -82,10 +82,10 @@ def test_del_single(ten_simple_docs):
index[doc_id]


def test_del_multiple(ten_simple_docs):
def test_del_multiple(ten_simple_docs, tmp_index_name):
docs_to_del_idx = [0, 2, 4, 6, 8]

index = RedisDocumentIndex[SimpleDoc](host='localhost')
index = RedisDocumentIndex[SimpleDoc](host='localhost', index_name=tmp_index_name)
index.index(ten_simple_docs)

assert index.num_docs() == 10
Expand All @@ -101,8 +101,8 @@ def test_del_multiple(ten_simple_docs):
assert np.allclose(index[doc.id].tens, doc.tens)


def test_contains(ten_simple_docs):
index = RedisDocumentIndex[SimpleDoc](host='localhost')
def test_contains(ten_simple_docs, tmp_index_name):
index = RedisDocumentIndex[SimpleDoc](host='localhost', index_name=tmp_index_name)
index.index(ten_simple_docs)

for doc in ten_simple_docs:
Expand Down
9 changes: 4 additions & 5 deletions tests/index/redis/test_persist_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from docarray import BaseDoc
from docarray.index import RedisDocumentIndex
from docarray.typing import NdArray
from tests.index.redis.fixtures import start_redis # noqa: F401
from tests.index.redis.fixtures import start_redis, tmp_index_name # noqa: F401


pytestmark = [pytest.mark.slow, pytest.mark.index]
Expand All @@ -15,12 +15,11 @@ class SimpleDoc(BaseDoc):
tens: NdArray[10] = Field(dim=1000)


def test_persist():
def test_persist(tmp_index_name):
query = SimpleDoc(tens=np.random.random((10,)))

# create index
index = RedisDocumentIndex[SimpleDoc](host='localhost')
index_name = index._index_name
index = RedisDocumentIndex[SimpleDoc](host='localhost', index_name=tmp_index_name)

assert index.num_docs() == 0

Expand All @@ -29,7 +28,7 @@ def test_persist():
find_results_before = index.find(query, search_field='tens', limit=5)

# load existing index
index = RedisDocumentIndex[SimpleDoc](host='localhost', index_name=index_name)
index = RedisDocumentIndex[SimpleDoc](host='localhost', index_name=tmp_index_name)
assert index.num_docs() == 10
find_results_after = index.find(query, search_field='tens', limit=5)
for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]):
Expand Down
14 changes: 7 additions & 7 deletions tests/index/redis/test_subindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,22 @@ class ListDoc(BaseDoc):
list_tens: NdArray[20] = Field(space='l2')


class MyDoc(BaseDoc):
class NestedDoc(BaseDoc):
docs: DocList[SimpleDoc]
list_docs: DocList[ListDoc]
my_tens: NdArray[30] = Field(space='l2')


@pytest.fixture(scope='session')
def index():
index = RedisDocumentIndex[MyDoc](host='localhost')
index = RedisDocumentIndex[NestedDoc](host='localhost')
return index


@pytest.fixture(scope='session')
def data():
my_docs = [
MyDoc(
NestedDoc(
id=f'{i}',
docs=DocList[SimpleDoc](
[
Expand Down Expand Up @@ -99,7 +99,7 @@ def test_subindex_index(index, data):
def test_subindex_get(index, data):
index.index(data)
doc = index['1']
assert type(doc) == MyDoc
assert type(doc) == NestedDoc
assert doc.id == '1'
assert len(doc.docs) == 5
assert type(doc.docs[0]) == SimpleDoc
Expand Down Expand Up @@ -158,7 +158,7 @@ def test_subindex_contain(index, data):
assert not index.subindex_contains(empty_doc)

# Empty index
empty_index = RedisDocumentIndex[MyDoc](host='localhost')
empty_index = RedisDocumentIndex[NestedDoc](host='localhost')
assert empty_doc not in empty_index


Expand All @@ -174,7 +174,7 @@ def test_find_subindex(index, data):
root_docs, docs, scores = index.find_subindex(
query, subindex='docs', search_field='simple_tens', limit=5
)
assert type(root_docs[0]) == MyDoc
assert type(root_docs[0]) == NestedDoc
assert type(docs[0]) == SimpleDoc
assert len(scores) == 5
for root_doc, doc in zip(root_docs, docs):
Expand All @@ -188,7 +188,7 @@ def test_find_subindex(index, data):
)
assert len(docs) == 5
assert len(scores) == 5
assert type(root_docs[0]) == MyDoc
assert type(root_docs[0]) == NestedDoc
assert type(docs[0]) == SimpleDoc
for root_doc, doc in zip(root_docs, docs):
assert np.allclose(doc.simple_tens, np.ones(10))
Expand Down