feat: configurations
Signed-off-by: jupyterjazz <[email protected]>
jupyterjazz committed Jul 14, 2023
commit cf8aa92fb868a65390c5e988853a067e741235d7
173 changes: 113 additions & 60 deletions docarray/index/backends/milvus.py
@@ -1,5 +1,3 @@
import re
import uuid
from collections import defaultdict
from dataclasses import dataclass, field
from typing import (
@@ -35,7 +33,7 @@
from docarray.array.any_array import AnyDocArray

if TYPE_CHECKING:
from pymilvus import (
from pymilvus import ( # type: ignore[import]
Collection,
CollectionSchema,
DataType,
@@ -57,6 +55,16 @@

ID_VARCHAR_LEN = 1024
SERIALIZED_VARCHAR_LEN = 65_535 # Maximum length that Milvus allows for a VARCHAR field
VALID_METRICS = ['L2', 'IP']
VALID_INDEX_TYPES = [
'FLAT',
'IVF_FLAT',
'IVF_SQ8',
'IVF_PQ',
'HNSW',
'ANNOY',
'DISKANN',
]

TSchema = TypeVar('TSchema', bound=BaseDoc)

@@ -68,6 +76,9 @@ def __init__(self, db_config=None, **kwargs):
self._db_config: MilvusDocumentIndex.DBConfig = cast(
MilvusDocumentIndex.DBConfig, self._db_config
)
self._runtime_config: MilvusDocumentIndex.RuntimeConfig = cast(
MilvusDocumentIndex.RuntimeConfig, self._runtime_config
)

self._client = connections.connect(
db_name="default",
@@ -79,8 +90,7 @@ def __init__(self, db_config=None, **kwargs):

self._validate_columns()
self._field_name = self._get_vector_field_name()
self._create_collection_name()
self._collection = self._init_index()
self._collection = self._create_or_load_collection()
self._build_index()
self._collection.load()
self._logger.info(f'{self.__class__.__name__} has been initialized')
@@ -89,28 +99,42 @@ def __init__(self, db_config=None, **kwargs):
class DBConfig(BaseDocIndex.DBConfig):
"""Dataclass that contains all "static" configurations of MilvusDocumentIndex."""

collection_name: Optional[str] = None
index_name: Optional[str] = None
collection_description: str = ""
host: str = "localhost"
port: int = 19530
user: Optional[str] = ""
password: Optional[str] = ""
token: Optional[str] = ""
index_type: str = "IVF_FLAT"
index_metric: str = "L2"
index_params: Dict = field(default_factory=lambda: {"nlist": 1024})
consistency_level: str = 'Session'
search_params: Dict = field(
default_factory=lambda: {
"metric_type": "L2",
"params": {"nprobe": 10},
}
)
serialize_config: Dict = field(default_factory=lambda: {"protocol": "protobuf"})
default_column_config: Dict[Type, Dict[str, Any]] = field(
default_factory=lambda: defaultdict(dict)
default_factory=lambda: defaultdict(
dict,
{
DataType.FLOAT_VECTOR: {
'index_type': 'IVF_FLAT',
'metric_type': 'L2',
'params': {"nlist": 1024},
},
},
)
)

@dataclass
class RuntimeConfig(BaseDocIndex.RuntimeConfig):
"""Dataclass that contains all "dynamic" configurations of RedisDocumentIndex.

:param batch_size: Batch size for index/get/del.
"""

batch_size: int = 100

def python_type_to_db_type(self, python_type: Type) -> Any:
"""Map python type to database type.
Takes any python type and returns the corresponding database column type.
@@ -139,7 +163,7 @@ def python_type_to_db_type(self, python_type: Type) -> Any:

return None

def _init_index(self) -> Collection:
def _create_or_load_collection(self) -> Collection:
"""
This function initializes or retrieves a Milvus collection with a specified schema,
storing documents as serialized data and using the document's ID as the collection's ID
@@ -150,7 +174,7 @@ def _init_index(self) -> Collection:
column can store in the schema (others are stored in the serialized data)
"""

if not utility.has_collection(self._db_config.collection_name):
if not utility.has_collection(self.index_name):
fields = [
FieldSchema(
name="serialized",
@@ -170,16 +194,15 @@
and not (
info.db_type == DataType.FLOAT_VECTOR
and column_name
!= self._field_name # Only store one vector field in column
!= self._field_name # Only store one vector field as a column
)
and not safe_issubclass(info.docarray_type, AnyDocArray)
):
field_dict: Dict[str, Any] = {}
if info.db_type == DataType.VARCHAR:
field_dict = {'max_length': 256}
elif info.db_type == DataType.FLOAT_VECTOR:
field_dict = {'dim': info.n_dim or info.config.get('dim')}
else:
field_dict = {}

fields.append(
FieldSchema(
@@ -192,29 +215,15 @@

self._logger.info("Collection has been created")
return Collection(
name=self._db_config.collection_name,
name=self.index_name,
schema=CollectionSchema(
fields=fields,
description=self._db_config.collection_description,
),
using='default',
)

return Collection(self._db_config.collection_name)

def _create_collection_name(self):
"""
This function generates a unique and sanitized name for the collection,
ensuring a unique identifier is used if the user does not specify a
collection name.
"""
if self._db_config.collection_name is None:
id = uuid.uuid4().hex
self._db_config.collection_name = f"{self.__class__.__name__}__" + id

self._db_config.collection_name = ''.join(
re.findall('[a-zA-Z0-9_]', self._db_config.collection_name)
)
return Collection(self.index_name)

def _validate_columns(self):
"""
@@ -245,7 +254,20 @@ def _validate_columns(self):

@property
def index_name(self):
return self._db_config.collection_name
default_index_name = (
self._schema.__name__.lower() if self._schema is not None else None
)
if default_index_name is None:
err_msg = (
'A MilvusDocumentIndex must be typed with a Document type. '
'To do so, use the syntax: MilvusDocumentIndex[DocumentType]'
)

self._logger.error(err_msg)
raise ValueError(err_msg)
index_name = self._db_config.index_name or default_index_name
self._logger.debug(f'Retrieved index name: {index_name}')
return index_name

@property
def out_schema(self) -> Type[BaseDoc]:
@@ -260,10 +282,31 @@ def _build_index(self):
required by the Milvus backend.
"""

existing_indices = [index.field_name for index in self._collection.indexes]
if self._field_name in existing_indices:
return

index_type = self._column_infos[self._field_name].config['index_type'].upper()
if index_type not in VALID_INDEX_TYPES:
raise ValueError(
f"Invalid index type '{index_type}' provided. "
f"Must be one of: {', '.join(VALID_INDEX_TYPES)}"
)
metric_type = (
self._column_infos[self._field_name].config.get('space', '').upper()
)
if metric_type not in VALID_METRICS:
self._logger.warning(
f"Invalid or no distance metric '{metric_type}' was provided. "
f"Should be one of: {', '.join(VALID_INDEX_TYPES)}. "
f"Default distance metric will be used."
)
metric_type = self._column_infos[self._field_name].config['metric_type']

index = {
"index_type": self._db_config.index_type,
"metric_type": self._db_config.index_metric,
"params": self._db_config.index_params,
"index_type": index_type,
"metric_type": metric_type,
"params": self._column_infos[self._field_name].config['params'],
}

self._collection.create_index(self._field_name, index)
@@ -305,13 +348,13 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
data_by_columns = self._get_col_value_dict(docs)
self._index_subindex(data_by_columns)

batch_size = 10

positions: Dict[str, int] = {
info.name: num for num, info in enumerate(self._collection.schema.fields)
}

for batch in self._get_batches(docs, batch_size=batch_size):
for batch in self._get_batches(
docs, batch_size=self._runtime_config.batch_size
):
entities: List[List[Any]] = [
[] for _ in range(len(self._collection.schema))
]
@@ -321,7 +364,9 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
for schema_field in self._collection.schema.fields:
if schema_field.name == 'serialized':
continue
column_value = self._get_values_by_column([doc], schema_field.name)[0]
column_value = self._get_values_by_column([doc], schema_field.name)[
0
]
if schema_field.dtype == DataType.FLOAT_VECTOR:
column_value = self._map_embedding(column_value)

@@ -338,7 +383,7 @@ def _filter_by_parent_id(self, id: str) -> Optional[List[str]]:
:return: a list of ids of the subindex documents
"""
docs = self._filter(filter_query=f"parent_id == '{id}'", limit=self.num_docs())
return [doc.id for doc in docs]
return [doc.id for doc in docs] # type: ignore[union-attr]

def num_docs(self) -> int:
"""
@@ -372,29 +417,36 @@ def _get_items(
"""

self._collection.load()

result = self._collection.query(
expr="id in " + str([id for id in doc_ids]),
offset=0,
output_fields=["serialized"],
consistency_level=self._db_config.consistency_level,
)
results: List[Dict] = []
for batch in self._get_batches(
doc_ids, batch_size=self._runtime_config.batch_size
):
results.extend(
self._collection.query(
expr="id in " + str([id for id in batch]),
offset=0,
output_fields=["serialized"],
consistency_level=self._db_config.consistency_level,
)
)

self._collection.release()

return self._docs_from_query_response(result)
return self._docs_from_query_response(results)

def _del_items(self, doc_ids: Sequence[str]):
"""Delete Documents from the index.

:param doc_ids: ids to delete from the Document Store
"""
self._collection.load()
self._collection.delete(
expr="id in " + str([id for id in doc_ids]),
consistency_level=self._db_config.consistency_level,
)

for batch in self._get_batches(
doc_ids, batch_size=self._runtime_config.batch_size
):
self._collection.delete(
expr="id in " + str([id for id in batch]),
consistency_level=self._db_config.consistency_level,
)
self._logger.info(f"{len(doc_ids)} documents has been deleted")

def _filter(
@@ -651,14 +703,15 @@ def _map_embedding(self, embedding: Optional[AnyTensor]) -> Optional[AnyTensor]:
:param embedding: The original raw embedding, which can be in the form of a TensorFlow or PyTorch tensor.
:return embedding: A one-dimensional numpy array representing the flattened version of the original embedding.
"""
if embedding is None:
raise ValueError(
"Embedding is None. Each document must have a valid embedding."
)

if embedding is not None:
embedding = self._to_numpy(embedding)
embedding = self._to_numpy(embedding)
if embedding.ndim > 1:
embedding = np.asarray(embedding).squeeze()

if embedding.ndim > 1:
embedding = np.asarray(embedding).squeeze()
else:
embedding = np.zeros(self._db_config.n_dim)
return embedding

def __contains__(self, item) -> bool:
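For context, the snippet below is a minimal sketch of how the configuration surface added in this diff could be used together: the index_name DBConfig option, the per-column index settings consumed by _build_index (index_type, space, params), and the batch_size RuntimeConfig option. It assumes a local Milvus instance on the default host/port and that extra Field(...) keyword arguments are merged into the column configuration the same way as in other docarray backends; the schema, field names, and parameter values are illustrative only.

import numpy as np
from pydantic import Field

from docarray import BaseDoc, DocList
from docarray.index import MilvusDocumentIndex
from docarray.typing import NdArray


class MyDoc(BaseDoc):
    text: str
    # Per-column settings read by _build_index: index_type, space (metric), params.
    embedding: NdArray = Field(
        dim=128,
        is_embedding=True,
        index_type='HNSW',
        space='IP',
        params={'M': 16, 'efConstruction': 200},
    )


# index_name overrides the default collection name (the lower-cased schema name, 'mydoc').
index = MilvusDocumentIndex[MyDoc](index_name='my_docs')

# RuntimeConfig: batch size used by index / _get_items / _del_items.
index.configure(batch_size=32)

docs = DocList[MyDoc](
    MyDoc(text=f'doc {i}', embedding=np.random.rand(128)) for i in range(100)
)
index.index(docs)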
12 changes: 11 additions & 1 deletion tests/index/milvus/fixtures.py
@@ -1,3 +1,6 @@
import string
import random

import pytest
import time
import os
@@ -10,7 +13,14 @@
@pytest.fixture(scope='session', autouse=True)
def start_storage():
os.system(f"docker compose -f {milvus_yml} up -d --remove-orphans")
time.sleep(10)
time.sleep(2)

yield
os.system(f"docker compose -f {milvus_yml} down --remove-orphans")


@pytest.fixture(scope='function')
def tmp_index_name():
letters = string.ascii_lowercase
random_string = ''.join(random.choice(letters) for _ in range(15))
return random_string
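A possible use of the new tmp_index_name fixture, so each test works against its own randomly named Milvus collection rather than the schema-derived default; the test name, schema, and assertion below are illustrative and not part of this change.

import numpy as np
from pydantic import Field

from docarray import BaseDoc
from docarray.index import MilvusDocumentIndex
from docarray.typing import NdArray
from tests.index.milvus.fixtures import start_storage, tmp_index_name  # noqa: F401


def test_isolated_collection(tmp_index_name):  # hypothetical test
    class Schema(BaseDoc):
        tens: NdArray = Field(dim=10, is_embedding=True)

    # Passing the fixture value as index_name gives this test a fresh collection.
    index = MilvusDocumentIndex[Schema](index_name=tmp_index_name)
    assert index.index_name == tmp_index_name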
13 changes: 12 additions & 1 deletion tests/index/milvus/test_configuration.py
@@ -5,7 +5,7 @@
from docarray import BaseDoc
from docarray.index import MilvusDocumentIndex
from docarray.typing import NdArray
from tests.index.milvus.fixtures import start_storage # noqa: F401
from tests.index.milvus.fixtures import start_storage, tmp_index_name # noqa: F401


pytestmark = [pytest.mark.slow, pytest.mark.index]
@@ -54,3 +54,14 @@ class Schema2(BaseDoc):
ValueError, match='Specifying multiple vector fields is not supported'
):
MilvusDocumentIndex[Schema2]()


def test_runtime_config():
class Schema(BaseDoc):
tens: NdArray = Field(dim=10, is_embedding=True)

index = MilvusDocumentIndex[Schema]()
assert index._runtime_config.batch_size == 100

index.configure(batch_size=10)
assert index._runtime_config.batch_size == 10
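Because index, _get_items, and _del_items now page through documents with _get_batches using RuntimeConfig.batch_size, a small end-to-end sketch of that path follows; it assumes a running local Milvus instance and the list-based getitem/delitem behavior of the document-index base class, with the document count and batch size chosen arbitrarily.

import numpy as np
from pydantic import Field

from docarray import BaseDoc, DocList
from docarray.index import MilvusDocumentIndex
from docarray.typing import NdArray


class Schema(BaseDoc):
    tens: NdArray = Field(dim=10, is_embedding=True)


index = MilvusDocumentIndex[Schema]()
index.configure(batch_size=10)  # force several batches for the 25 docs below

docs = DocList[Schema](Schema(tens=np.random.rand(10)) for _ in range(25))
index.index(docs)                 # indexed in batches of at most 10
ids = [doc.id for doc in docs]

retrieved = index[ids]            # _get_items queries Milvus in batches of 10
del index[ids[:15]]               # _del_items deletes in batches of 10
remaining = index[ids[15:]]       # the documents that were not deleted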