feat: query builder
Signed-off-by: jupyterjazz <[email protected]>
jupyterjazz committed Jul 14, 2023
commit 87d487d230da59faf20b51febe65caa899d24213
152 changes: 134 additions & 18 deletions docarray/index/backends/milvus.py
@@ -13,13 +13,21 @@
TypeVar,
Union,
cast,
Tuple,
)

import numpy as np

from docarray import BaseDoc, DocList
from docarray.index.abstract import BaseDocIndex
from docarray.index.backends.helper import _execute_find_and_filter_query
from docarray.index.abstract import (
BaseDocIndex,
_raise_not_supported,
_raise_not_composable,
)
from docarray.index.backends.helper import (
_execute_find_and_filter_query,
_collect_query_args,
)
from docarray.typing import AnyTensor, NdArray
from docarray.typing.id import ID
from docarray.typing.tensor.abstract_tensor import AbstractTensor
@@ -53,8 +61,7 @@
Hits,
)

ID_VARCHAR_LEN = 1024
SERIALIZED_VARCHAR_LEN = 65_535 # Maximum length that Milvus allows for a VARCHAR field
MAX_LEN = 65_535 # Maximum length that Milvus allows for a VARCHAR field
VALID_METRICS = ['L2', 'IP']
VALID_INDEX_TYPES = [
'FLAT',
@@ -97,7 +104,23 @@ def __init__(self, db_config=None, **kwargs):

@dataclass
class DBConfig(BaseDocIndex.DBConfig):
"""Dataclass that contains all "static" configurations of MilvusDocumentIndex."""
"""Dataclass that contains all "static" configurations of MilvusDocumentIndex.

:param index_name: The name of the index in the Milvus database. If not provided, a default index name will be used.
:param collection_description: Description of the collection in the database.
:param host: Hostname of the server where the database resides. Default is 'localhost'.
:param port: Port number used to connect to the database. Default is 19530.
:param user: User for the database. Can be an empty string if no user is required.
:param password: Password for the specified user. Can be an empty string if no password is required.
:param token: Token for secure connection. Can be an empty string if no token is required.
:param consistency_level: The level of consistency for the database session. Default is 'Session'.
:param search_params: Dictionary containing parameters for search operations,
default has a single key 'params' with 'nprobe' set to 10.
:param serialize_config: Dictionary containing configuration for serialization,
default is {'protocol': 'protobuf'}.
:param default_column_config: Dictionary that defines the default configuration
for each data type column.
"""

index_name: Optional[str] = None
collection_description: str = ""
@@ -135,6 +158,23 @@ class RuntimeConfig(BaseDocIndex.RuntimeConfig):

batch_size: int = 100

class QueryBuilder(BaseDocIndex.QueryBuilder):
def __init__(self, query: Optional[List[Tuple[str, Dict]]] = None):
super().__init__()
# list of tuples (method name, kwargs)
self._queries: List[Tuple[str, Dict]] = query or []

def build(self, *args, **kwargs) -> Any:
"""Build the query object."""
return self._queries

find = _collect_query_args('find')
filter = _collect_query_args('filter')
text_search = _raise_not_supported('text_search')
find_batched = _raise_not_composable('find_batched')
filter_batched = _raise_not_composable('filter_batched')
text_search_batched = _raise_not_supported('text_search_batched')

def python_type_to_db_type(self, python_type: Type) -> Any:
"""Map python type to database type.
Takes any python type and returns the corresponding database column type.
@@ -179,13 +219,13 @@ def _create_or_load_collection(self) -> Collection:
FieldSchema(
name="serialized",
dtype=DataType.VARCHAR,
max_length=SERIALIZED_VARCHAR_LEN,
max_length=MAX_LEN,
),
FieldSchema(
name="id",
dtype=DataType.VARCHAR,
is_primary=True,
max_length=ID_VARCHAR_LEN,
max_length=MAX_LEN,
),
]
for column_name, info in self._column_infos.items():
@@ -200,7 +240,7 @@ def _create_or_load_collection(self) -> Collection:
):
field_dict: Dict[str, Any] = {}
if info.db_type == DataType.VARCHAR:
field_dict = {'max_length': 256}
field_dict = {'max_length': MAX_LEN}
elif info.db_type == DataType.FLOAT_VECTOR:
field_dict = {'dim': info.n_dim or info.config.get('dim')}

@@ -454,6 +494,14 @@ def _filter(
filter_query: Any,
limit: int,
) -> Union[DocList, List[Dict]]:
"""
Filters the index based on the given filter query.

:param filter_query: The filter condition.
:param limit: The maximum number of results to return.
:return: Filter results.
"""

self._collection.load()

result = self._collection.query(
@@ -472,6 +520,13 @@ def _filter_batched(
filter_queries: Any,
limit: int,
) -> Union[List[DocList], List[List[Dict]]]:
"""
Filters the index based on the given batch of filter queries.

:param filter_queries: The filter conditions.
:param limit: The maximum number of results to return for each filter query.
:return: Filter results.
"""
return [
self._filter(filter_query=query, limit=limit) for query in filter_queries
]
@@ -542,6 +597,33 @@ def _find(
limit: int,
search_field: str = '',
) -> _FindResult:
"""
Conducts a search on the index.

:param query: The vector query to search.
:param limit: The maximum number of results to return.
:param search_field: The field to perform the search on.
:return: Search results.
"""

return self._hybrid_search(query=query, limit=limit, search_field=search_field)

def _hybrid_search(
self,
query: np.ndarray,
limit: int,
search_field: str = '',
expr: Optional[str] = None,
):
"""
Conducts a hybrid search on the index.

:param query: The vector query to search.
:param limit: The maximum number of results to return.
:param search_field: The field to perform the search on.
:param expr: Boolean expression used for filtering.
:return: Search results.
"""
self._collection.load()

results = self._collection.search(
@@ -550,7 +632,7 @@ def _find(
param=self._db_config.search_params,
limit=limit,
offset=0,
expr=None,
expr=expr,
output_fields=["serialized"],
consistency_level=self._db_config.consistency_level,
)
@@ -624,6 +706,15 @@ def _find_batched(
limit: int,
search_field: str = '',
) -> _FindResultBatched:
"""
Conducts a batched search on the index.

:param queries: The queries to search.
:param limit: The maximum number of results to return for each query.
:param search_field: The field to perform the search on.
:return: Search results.
"""

self._collection.load()

results = self._collection.search(
@@ -648,15 +739,44 @@
)

def execute_query(self, query: Any, *args, **kwargs) -> Any:
if args or kwargs:
"""
Executes a hybrid query on the index.

:param query: Query to execute on the index.
:return: Query results.
"""
components: Dict[str, List[Dict[str, Any]]] = {}
for component, value in query:
if component not in components:
components[component] = []
components[component].append(value)

if (
len(components) != 2
or len(components.get('find', [])) != 1
or len(components.get('filter', [])) != 1
):
raise ValueError(
f'args and kwargs not supported for `execute_query` on {type(self)}'
'The query must contain exactly one "find" and one "filter" component.'
)
find_res = _execute_find_and_filter_query(
doc_index=self,

expr = components['filter'][0]['filter_query']
query = components['find'][0]['query']
limit = (
components['find'][0].get('limit')
or components['filter'][0].get('limit')
or 10
)
docs, scores = self._hybrid_search(
query=query,
expr=expr,
search_field=self._field_name,
limit=limit,
)
return find_res
if isinstance(docs, List) and not isinstance(docs, DocList):
docs = self._dict_list_to_docarray(docs)

return FindResult(documents=docs, scores=scores)

def _docs_from_query_response(self, result: Sequence[Dict]) -> Sequence[TSchema]:
return DocList[self._schema](
Expand Down Expand Up @@ -722,7 +842,3 @@ def __contains__(self, item) -> bool:
)

return len(result) > 0

def __exit__(self, exc_type, exc_val, exc_tb):
self._collection.release()
self._loaded = False
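
For context, a minimal usage sketch of the query builder added above, assuming a running Milvus instance. The schema, index name, and import path are illustrative (they mirror the new test at the bottom of this diff) and are not part of the commit itself.

import numpy as np
from pydantic import Field

from docarray import BaseDoc
from docarray.index import MilvusDocumentIndex
from docarray.typing import NdArray


class SimpleSchema(BaseDoc):
    tensor: NdArray[10] = Field(is_embedding=True)
    price: int


# DBConfig fields such as host/port fall back to 'localhost'/19530 as documented above.
db = MilvusDocumentIndex[SimpleSchema](index_name='query_builder_demo')
db.index([SimpleSchema(tensor=np.array([i + 1] * 10), price=i + 1) for i in range(10)])

# build() returns the collected list of (method name, kwargs) tuples.
q = (
    db.build_query()
    .find(query=np.ones(10), limit=5)      # becomes the vector query passed to _hybrid_search
    .filter(filter_query='price <= 3')     # becomes the Milvus boolean `expr`
    .build()
)

# execute_query expects exactly one `find` and one `filter` component
# and returns a FindResult(documents=..., scores=...).
docs, scores = db.execute_query(q)
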
7 changes: 1 addition & 6 deletions docarray/index/backends/redis.py
@@ -94,11 +94,6 @@ def __init__(self, db_config=None, **kwargs):
self._create_index()
self._logger.info(f'{self.__class__.__name__} has been initialized')

@staticmethod
def _random_name() -> str:
"""Generate a random index name."""
return uuid.uuid4().hex

def _create_index(self) -> None:
"""Create a new index in the Redis database if it doesn't already exist."""
if not self._check_index_exists(self.index_name):
@@ -220,7 +215,7 @@ class DBConfig(BaseDocIndex.DBConfig):
:param host: The host address for the Redis server. Default is 'localhost'.
:param port: The port number for the Redis server. Default is 6379.
:param index_name: The name of the index in the Redis database.
In case it's not provided, a random index name will be generated.
If not provided, a default index name will be used.
:param username: The username for the Redis server. Default is None.
:param password: The password for the Redis server. Default is None.
:param text_scorer: The method for scoring text during text search.
26 changes: 26 additions & 0 deletions tests/index/milvus/test_find.py
@@ -260,3 +260,29 @@ class SimpleSchema(BaseDoc):

docs = index.filter(f"id == '{index_docs[0].id}'", limit=5)
assert docs[0].id == index_docs[0].id


def test_query_builder(tmp_index_name):
class SimpleSchema(BaseDoc):
tensor: NdArray[10] = Field(is_embedding=True)
price: int

db = MilvusDocumentIndex[SimpleSchema](index_name=tmp_index_name)

index_docs = [
SimpleSchema(tensor=np.array([i + 1] * 10), price=i + 1) for i in range(10)
]
db.index(index_docs)

q = (
db.build_query()
.find(query=np.ones(10), limit=5)
.filter(filter_query='price <= 3')
.build()
)

docs, scores = db.execute_query(q)

assert len(docs) == 3
for doc in docs:
assert doc.price <= 3
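

One more hedged sketch, not part of this commit: the validation added in execute_query requires both components, so a find-only query should be rejected with a ValueError. Imports and the tmp_index_name fixture are assumed to match the existing test module (its import block is truncated in this diff).

import numpy as np
import pytest
from pydantic import Field

from docarray import BaseDoc
from docarray.index import MilvusDocumentIndex
from docarray.typing import NdArray


def test_query_builder_requires_both_components(tmp_index_name):
    class SimpleSchema(BaseDoc):
        tensor: NdArray[10] = Field(is_embedding=True)
        price: int

    db = MilvusDocumentIndex[SimpleSchema](index_name=tmp_index_name)

    # Only a `find` component is collected here; execute_query checks for
    # exactly one `find` and one `filter`, so this should raise.
    q = db.build_query().find(query=np.ones(10), limit=5).build()

    with pytest.raises(ValueError):
        db.execute_query(q)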