Merged
Changes from 1 commit
Commits
36 commits
df47474
feat: support redis
jupyterjazz May 17, 2023
0920adc
chore: merge main
jupyterjazz Jun 12, 2023
51558b7
fix: index creation
jupyterjazz Jun 15, 2023
12da714
feat: 1st draft, needs polishing
jupyterjazz Jun 21, 2023
be7bed7
feat: query builder, tests
jupyterjazz Jun 28, 2023
9365327
Merge branch 'main' into feat-add-redis
jupyterjazz Jun 28, 2023
cb25869
chore: update poetry lock
jupyterjazz Jun 28, 2023
341fa9a
chore: run tests
jupyterjazz Jun 28, 2023
d37b28f
Merge branch 'main' into feat-add-redis
jupyterjazz Jun 28, 2023
abca1dd
fix: defaultdict for column config
jupyterjazz Jun 28, 2023
edc8a34
chore: update branch
jupyterjazz Jun 28, 2023
c5abd80
style: ignore mypy errors
jupyterjazz Jun 28, 2023
22434f1
refactor: put vectorfield args in column info
jupyterjazz Jun 28, 2023
4a96194
chore: remove unused code
jupyterjazz Jun 28, 2023
a52cdfd
docs: add docstrings
jupyterjazz Jun 28, 2023
9ac683f
test: add tensorflow test
jupyterjazz Jun 28, 2023
a7f54c3
fix: tensorflow test
jupyterjazz Jun 28, 2023
be9d771
fix: tf tst
jupyterjazz Jun 28, 2023
4dc99e6
refactor: reduce ignore types
jupyterjazz Jun 28, 2023
d2718b4
style: remove other type ignores
jupyterjazz Jun 28, 2023
9f58474
style: try removing import ignores
jupyterjazz Jun 28, 2023
7e791da
style: i think mypy hates me
jupyterjazz Jun 28, 2023
8ed3dbb
feat: batch indexing
jupyterjazz Jun 29, 2023
83afb5f
chore: bump redis version
jupyterjazz Jun 29, 2023
4b0bc73
feat: subindex not fully finished
jupyterjazz Jul 2, 2023
0b15adc
feat: finalize subindex
jupyterjazz Jul 5, 2023
e8fbc47
docs: update readme
jupyterjazz Jul 5, 2023
848e95d
chore: commits not showing
jupyterjazz Jul 5, 2023
c00abe2
Merge branch 'main' into feat-add-redis
jupyterjazz Jul 5, 2023
27dd29a
feat: del and get batched
jupyterjazz Jul 6, 2023
162c6a8
docs: update batchsize docstring
jupyterjazz Jul 6, 2023
16c4323
Merge branch 'main' into feat-add-redis
jupyterjazz Jul 6, 2023
56829d6
refactor: index name
jupyterjazz Jul 9, 2023
c414836
Merge branch 'main' into feat-add-redis
jupyterjazz Jul 9, 2023
7a5ed5e
refactor: default index name following schema
jupyterjazz Jul 9, 2023
0fecb48
chore: update branch
jupyterjazz Jul 9, 2023
feat: 1st draft, needs polishing
Signed-off-by: jupyterjazz <[email protected]>
jupyterjazz committed Jun 21, 2023
commit 12da714a673af63045f11fbc226f6895687ddae0
269 changes: 245 additions & 24 deletions docarray/index/backends/redis.py
@@ -1,3 +1,4 @@
import uuid
from typing import (
TypeVar,
Generic,
@@ -10,12 +11,16 @@
Generator,
Type,
cast,
TYPE_CHECKING, Iterator,
TYPE_CHECKING,
Iterator,
Mapping,
)
from dataclasses import dataclass, field

import binascii
import numpy as np
import pickle

from redis.commands.search.query import Query

from docarray import BaseDoc, DocList
from docarray.index.abstract import BaseDocIndex
@@ -34,18 +39,31 @@
VectorField,
)
from redis.commands.search.indexDefinition import IndexDefinition, IndexType

from redis.commands.search.querystring import (
DistjunctUnion,
IntersectNode,
equal,
ge,
gt,
intersect,
le,
lt,
union,
)

TSchema = TypeVar('TSchema', bound=BaseDoc)

VALID_DISTANCES = ['L2', 'IP', 'COSINE']
VALID_ALGORITHMS = ['FLAT', 'HNSW']


class RedisDocumentIndex(BaseDocIndex, Generic[TSchema]):
def __init__(self, db_config=None, **kwargs):
super().__init__(db_config=db_config, **kwargs)
self._db_config = cast(RedisDocumentIndex.DBConfig, self._db_config)

if not self._db_config.index_name:
self._db_config.index_name = 'index_name__' + 'random_name' # todo
self._db_config.index_name = 'index_name__' + self._random_name()
self._prefix = self._db_config.index_name + ':'

# initialize Redis client
@@ -54,15 +72,29 @@ def __init__(self, db_config=None, **kwargs):
port=self._db_config.port,
username=self._db_config.username,
password=self._db_config.password,
decode_responses=False,
)
self._create_index()
self._logger.info(f'{self.__class__.__name__} has been initialized')

@staticmethod
def _random_name():
return uuid.uuid4().hex

def _create_index(self):
if not self._check_index_exists(self._db_config.index_name):
schema = []
for column, info in self._column_infos.items():

if info.db_type == VectorField:
space = info.config.get('space')
if space:
for valid_dist in VALID_DISTANCES:
if space.upper() == valid_dist:
space = valid_dist
if space not in VALID_DISTANCES:
space = self._db_config.distance

schema.append(
info.db_type(
name=column,
@@ -72,18 +104,19 @@ def _create_index(self):
attributes={
'TYPE': 'FLOAT32',
'DIM': info.n_dim,
'DISTANCE_METRIC': 'COSINE',
'DISTANCE_METRIC': space,
},
)
)
else:
schema.append(info.db_type(name=column))


# Create Redis Index
self._client.ft(self._db_config.index_name).create_index(
fields=schema,
definition=IndexDefinition(prefix=[self._prefix], index_type=IndexType.HASH),
definition=IndexDefinition(
prefix=[self._prefix], index_type=IndexType.HASH
),
)

self._logger.info(f'index {self._db_config.index_name} has been created')
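
# A hedged sketch of the index creation above, reduced to the underlying
# redis-py calls: a RediSearch index over HASH keys under a prefix, with one
# VectorField per tensor column. The index name, field name, dimensionality,
# and the localhost connection are illustrative assumptions.
import redis
from redis.commands.search.field import VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType

r = redis.Redis(host='localhost', port=6379)
r.ft('my_index').create_index(
    fields=[
        VectorField(
            'tens',
            'FLAT',
            {'TYPE': 'FLOAT32', 'DIM': 128, 'DISTANCE_METRIC': 'COSINE'},
        )
    ],
    definition=IndexDefinition(prefix=['my_index:'], index_type=IndexType.HASH),
)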
@@ -111,7 +144,22 @@ class DBConfig(BaseDocIndex.DBConfig):
index_name: Optional[str] = None
username: Optional[str] = None
password: Optional[str] = None
algorithm: str = 'FLAT'
algorithm: str = field(default='FLAT')
distance: str = field(default='COSINE')
ef_construction: Optional[int] = None
m: Optional[int] = None
ef_runtime: Optional[int] = None
block_size: Optional[int] = None
initial_cap: Optional[int] = None

def __post_init__(self):
if self.algorithm not in VALID_ALGORITHMS:
raise ValueError(f"Invalid algorithm '{self.algorithm}' provided. "
f"Must be one of: {', '.join(VALID_ALGORITHMS)}")

if self.distance not in VALID_DISTANCES:
raise ValueError(f"Invalid distance metric '{self.distance}' provided. "
f"Must be one of: {', '.join(VALID_DISTANCES)}")

@dataclass
class RuntimeConfig(BaseDocIndex.RuntimeConfig):
@@ -142,7 +190,9 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
raise ValueError(f'Unsupported column type for {type(self)}: {python_type}')

@staticmethod
def _generate_item(column_to_data: Dict[str, Generator[Any, None, None]]) -> Iterator[Dict[str, Any]]:
def _generate_item(
column_to_data: Dict[str, Generator[Any, None, None]]
) -> Iterator[Dict[str, Any]]:
"""
Given a dictionary of generators, yield a dictionary where each item consists of a key and
a single item from the corresponding generator.
@@ -159,21 +209,25 @@ def _generate_item(column_to_data: Dict[str, Generator[Any, None, None]]) -> Iterator[Dict[str, Any]]:
item_dict = {}
for key, it in zip(keys, iterators):
item = next(it, None)
                if item is None:  # skip None items
continue
if isinstance(item, AbstractTensor):
item_dict[key] = pickle.dumps(item)

if key == 'id' and not item:
return

if item is None:
item_dict[key] = '__None__'
elif isinstance(item, AbstractTensor):
item_dict[key] = np.array(
item._docarray_to_ndarray(), dtype=np.float32
).tobytes()
else:
item_dict[key] = item

if not item_dict: # If item_dict is empty, break the loop
break
yield item_dict
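
# What _generate_item yields for plain (non-tensor) columns, sketched with
# illustrative values: each dict pairs one element from every column generator.
items = list(
    RedisDocumentIndex._generate_item(
        {'id': iter(['0', '1']), 'title': iter(['hello', 'world'])}
    )
)
# -> [{'id': '0', 'title': 'hello'}, {'id': '1', 'title': 'world'}]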

def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]):
ids = []
pipeline = self._client.pipeline(transaction=False)
batch_size = 10
        batch_size = 10  # todo: make this configurable (e.g. 1000)
for item in self._generate_item(column_to_data):
doc_id = self._prefix + item.pop('id')
pipeline.hset(
@@ -195,40 +249,207 @@ def num_docs(self) -> int:
return self._client.ft(self._db_config.index_name).info()['num_docs']
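
# The write pattern used by _index above, in isolation: one HSET per document
# on a non-transactional pipeline, flushed once per batch. Key and field names
# are illustrative.
import redis

r = redis.Redis()
pipe = r.pipeline(transaction=False)
pipe.hset('my_index:0', mapping={'title': 'hello'})
pipe.hset('my_index:1', mapping={'title': 'world'})
pipe.execute()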

def _del_items(self, doc_ids: Sequence[str]):
pass
doc_ids = [self._prefix + id for id in doc_ids if self._doc_exists(id)]
if doc_ids:
self._client.delete(*doc_ids)

def _doc_exists(self, doc_id):
return self._client.exists(self._prefix + doc_id)

def _get_items(
self, doc_ids: Sequence[str]
) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]:
pass
if not doc_ids:
return []

pipe = self._client.pipeline()
for id in doc_ids:
pipe.hgetall(self._prefix + id)

results = pipe.execute()

docs = [
{k.decode('utf-8'): v.decode('utf-8', 'ignore') for k, v in d.items()}
for d in results
]

docs = [{k: v for k, v in d.items() if k != 'tens'} for d in docs] # todo (vector decoding problem)
docs = [{k: None if v == '__None__' else v for k, v in d.items()} for d in docs] # todo (converting to None)
return docs
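
# The read path above in isolation: batched HGETALL on a single pipeline.
# Redis returns {bytes: bytes} mappings per key, which is why _get_items
# decodes every field. Keys are illustrative.
import redis

r = redis.Redis()
pipe = r.pipeline()
for key in ['my_index:0', 'my_index:1']:
    pipe.hgetall(key)
raw = pipe.execute()  # e.g. [{b'title': b'hello'}, {b'title': b'world'}]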

def execute_query(self, query: Any, *args, **kwargs) -> Any:
pass

def _convert_to_schema(self, document):
doc_kwargs = {}
for column, info in self._column_infos.items():
if column == 'id':
doc_kwargs['id'] = document.id[len(self._prefix) :]
elif document[column] == '__None__':
doc_kwargs[column] = None
elif info.db_type == VectorField:
# byte_string = document[column]
# byte_data = byte_string.encode('utf-8')
doc_kwargs[column] = np.frombuffer(document[column], dtype=np.float32)
elif info.db_type == NumericField:
doc_kwargs[column] = info.docarray_type(document[column])
else:
doc_kwargs[column] = document[column]

return doc_kwargs
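
# The vector decode step above, checked in isolation: FLOAT32 tensors
# round-trip through Redis as raw bytes via tobytes() and np.frombuffer().
import numpy as np

v = np.array([1.0, 2.0, 3.0], dtype=np.float32)
assert (np.frombuffer(v.tobytes(), dtype=np.float32) == v).all()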

def _find(
self, query: np.ndarray, limit: int, search_field: str = ''
) -> _FindResult:
pass
query_str = '*'
redis_query = (
Query(f'{query_str}=>[KNN {limit} @{search_field} $vec AS vector_score]')
.sort_by('vector_score')
.paging(0, limit)
.dialect(2)
)
query_params: Mapping[str, str] = { # type: ignore
'vec': np.array(query, dtype=np.float32).tobytes()
}
results = (
self._client.ft(self._db_config.index_name)
.search(redis_query, query_params)
.docs
)

scores = [document['vector_score'] for document in results]
docs = [self._convert_to_schema(document) for document in results]

return _FindResult(documents=docs, scores=scores)
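
# The KNN query above with concrete numbers, as a sketch: the 10 nearest
# vectors to `q` in a vector field 'tens'. Index name, field name, and
# dimensionality are assumptions.
import numpy as np
import redis
from redis.commands.search.query import Query

r = redis.Redis()
q = np.random.rand(128).astype(np.float32)
knn = (
    Query('*=>[KNN 10 @tens $vec AS vector_score]')
    .sort_by('vector_score')
    .paging(0, 10)
    .dialect(2)
)
hits = r.ft('my_index').search(knn, {'vec': q.tobytes()}).docs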

def _find_batched(
self, queries: np.ndarray, limit: int, search_field: str = ''
) -> _FindResultBatched:
pass
docs, scores = [], []
for query in queries:
results = self._find(query=query, search_field=search_field, limit=limit)
docs.append(results.documents)
scores.append(results.scores)

return _FindResultBatched(documents=docs, scores=scores)

def _filter(self, filter_query: Any, limit: int) -> Union[DocList, List[Dict]]:
pass
query_str = self._get_redis_filter_query(filter_query)
q = Query(query_str)
q.paging(0, limit)

results = self._client.ft(index_name=self._db_config.index_name).search(q).docs
docs = [self._convert_to_schema(document) for document in results]

return docs

def _build_query_node(self, key, condition):
operator = list(condition.keys())[0]
value = condition[operator]

query_dict = {}

if operator in ['$ne', '$eq']:
if isinstance(value, bool):
query_dict[key] = equal(int(value))
elif isinstance(value, (int, float)):
query_dict[key] = equal(value)
else:
query_dict[key] = value
elif operator == '$gt':
query_dict[key] = gt(value)
elif operator == '$gte':
query_dict[key] = ge(value)
elif operator == '$lt':
query_dict[key] = lt(value)
elif operator == '$lte':
query_dict[key] = le(value)
else:
raise ValueError(
                f'Expecting filter operator to be one of $gt, $gte, $lt, $lte, $eq, $ne, $and, or $or, got {operator} instead'
)

if operator == '$ne':
return DistjunctUnion(**query_dict)
return IntersectNode(**query_dict)

def _build_query_nodes(self, filter):
nodes = []
for k, v in filter.items():
if k == '$and':
children = self._build_query_nodes(v)
node = intersect(*children)
nodes.append(node)
elif k == '$or':
children = self._build_query_nodes(v)
node = union(*children)
nodes.append(node)
else:
child = self._build_query_node(k, v)
nodes.append(child)

return nodes

def _get_redis_filter_query(self, filter: Union[str, Dict]):
if isinstance(filter, dict):
nodes = self._build_query_nodes(filter)
query_str = intersect(*nodes).to_string()
elif isinstance(filter, str):
query_str = filter
else:
raise ValueError(f'Unexpected type of filter: {type(filter)}, expected str')

return query_str
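
# How the node builders above compose for a concrete filter, sketched with
# assumed numeric fields 'price' and 'year':
#   {'$or': [{'price': {'$lte': 5}}, {'year': {'$gt': 2000}}]}
from redis.commands.search.querystring import gt, intersect, le, union

query_str = union(intersect(price=le(5)), intersect(year=gt(2000))).to_string()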

def _filter_batched(
self, filter_queries: Any, limit: int
) -> Union[List[DocList], List[List[Dict]]]:
pass
results = []
for query in filter_queries:
results.append(self._filter(filter_query=query, limit=limit))
return results

def _text_search(
self, query: str, limit: int, search_field: str = ''
) -> _FindResult:
pass
query_str = '|'.join(query.split(' '))

scorer = 'BM25'
if scorer not in [
'BM25',
'TFIDF',
'TFIDF.DOCNORM',
'DISMAX',
'DOCSCORE',
'HAMMING',
]:
raise ValueError(
f'Expecting a valid text similarity ranking algorithm, got {scorer} instead'
)
q = (
Query(f'@{search_field}:{query_str}')
.scorer(scorer)
.with_scores()
.paging(0, limit)
)

results = self._client.ft(index_name=self._db_config.index_name).search(q).docs

scores = [document['score'] for document in results]
docs = [self._convert_to_schema(document) for document in results]

return _FindResult(documents=docs, scores=scores)
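
# The text query above in isolation: whitespace-split tokens OR-ed together
# and ranked with BM25. The field name and index name are assumptions.
import redis
from redis.commands.search.query import Query

r = redis.Redis()
tq = Query('@title:hello|world').scorer('BM25').with_scores().paging(0, 10)
hits = r.ft('my_index').search(tq).docs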

def _text_search_batched(
self, queries: Sequence[str], limit: int, search_field: str = ''
) -> _FindResultBatched:
pass
docs, scores = [], []
for query in queries:
results = self._text_search(
query=query, search_field=search_field, limit=limit
)
docs.append(results.documents)
scores.append(results.scores)

return _FindResultBatched(documents=docs, scores=scores)
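
# Putting it together: a hedged end-to-end sketch. The schema, the field
# names, and the public index()/find() entry points inherited from
# BaseDocIndex are assumptions about surrounding docarray code not shown
# in this diff.
import numpy as np
from docarray import BaseDoc, DocList
from docarray.typing import NdArray


class MyDoc(BaseDoc):
    title: str
    tens: NdArray[128]


index = RedisDocumentIndex[MyDoc]()
docs = DocList[MyDoc](
    [MyDoc(title=f'doc {i}', tens=np.random.rand(128)) for i in range(3)]
)
index.index(docs)
results = index.find(np.random.rand(128), search_field='tens', limit=2)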