Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
chore: apply some mypy changes
  • Loading branch information
Joan Fontanals Martinez committed Jun 21, 2023
commit e3e26a7eba3d6418a4fd3976225a1ba68efb2ad4
12 changes: 7 additions & 5 deletions docarray/utils/create.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from docarray import DocList, BaseDoc
from docarray.typing import AnyTensor
from pydantic import create_model
from typing import Dict, List, Any, Union, Optional
from typing import Dict, List, Any, Union, Optional, Tuple, Type
from typing_extensions import TypeAlias


def _create_aux_model_doc_list_to_list(model):
fields = {}
def create_new_model_cast_doclist_to_list(model: BaseDoc) -> BaseDoc:
fields: Dict[str, Tuple[Type, Dict]] = {}
for field_name, field in model.__annotations__.items():
try:
if issubclass(field, DocList):
Expand All @@ -29,6 +30,7 @@ def _get_field_from_type(
):
field_type = field_schema.get('type', None)
tensor_shape = field_schema.get('tensor/array shape', None)
ret: TypeAlias
if 'anyOf' in field_schema:
any_of_types = []
for any_of_schema in field_schema['anyOf']:
Expand Down Expand Up @@ -143,8 +145,8 @@ def _get_field_from_type(


def create_base_doc_from_schema(
schema: Dict[str, any], model_name: str, cached_models: Optional[Dict] = None
) -> type:
schema: Dict[str, Any], model_name: str, cached_models: Optional[Dict] = None
) -> Type:
cached_models = cached_models if cached_models is not None else {}
fields = {}
if model_name in cached_models:
Expand Down
14 changes: 7 additions & 7 deletions tests/units/util/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List, Dict, Union, Any
from docarray.utils.create import (
create_base_doc_from_schema,
_create_aux_model_doc_list_to_list,
create_new_model_cast_doclist_to_list,
)
import numpy as np
from typing import Optional
Expand All @@ -26,7 +26,7 @@ class CustomDoc(BaseDoc):
lu: List[Union[str, int]] = [0, 1, 2]
tags: Optional[Dict[str, Any]] = None

CustomDocCopy = _create_aux_model_doc_list_to_list(CustomDoc)
CustomDocCopy = create_new_model_cast_doclist_to_list(CustomDoc)
new_custom_doc_model = create_base_doc_from_schema(
CustomDocCopy.schema(), 'CustomDoc', {}
)
Expand Down Expand Up @@ -95,7 +95,7 @@ class CustomDoc(BaseDoc):
class TextDocWithId(BaseDoc):
ia: str

TextDocWithIdCopy = _create_aux_model_doc_list_to_list(TextDocWithId)
TextDocWithIdCopy = create_new_model_cast_doclist_to_list(TextDocWithId)
new_textdoc_with_id_model = create_base_doc_from_schema(
TextDocWithIdCopy.schema(), 'TextDocWithId', {}
)
Expand Down Expand Up @@ -125,7 +125,7 @@ class TextDocWithId(BaseDoc):
class ResultTestDoc(BaseDoc):
matches: DocList[TextDocWithId]

ResultTestDocCopy = _create_aux_model_doc_list_to_list(ResultTestDoc)
ResultTestDocCopy = create_new_model_cast_doclist_to_list(ResultTestDoc)
new_result_test_doc_with_id_model = create_base_doc_from_schema(
ResultTestDocCopy.schema(), 'ResultTestDoc', {}
)
Expand Down Expand Up @@ -171,7 +171,7 @@ class CustomDoc(BaseDoc):
tags: Optional[Dict[str, Any]] = None
lf: List[float] = [3.0, 4.1]

CustomDocCopy = _create_aux_model_doc_list_to_list(CustomDoc)
CustomDocCopy = create_new_model_cast_doclist_to_list(CustomDoc)
new_custom_doc_model = create_base_doc_from_schema(
CustomDocCopy.schema(), 'CustomDoc'
)
Expand All @@ -196,7 +196,7 @@ class CustomDoc(BaseDoc):
class TextDocWithId(BaseDoc):
ia: str

TextDocWithIdCopy = _create_aux_model_doc_list_to_list(TextDocWithId)
TextDocWithIdCopy = create_new_model_cast_doclist_to_list(TextDocWithId)
new_textdoc_with_id_model = create_base_doc_from_schema(
TextDocWithIdCopy.schema(), 'TextDocWithId', {}
)
Expand All @@ -219,7 +219,7 @@ class TextDocWithId(BaseDoc):
class ResultTestDoc(BaseDoc):
matches: DocList[TextDocWithId]

ResultTestDocCopy = _create_aux_model_doc_list_to_list(ResultTestDoc)
ResultTestDocCopy = create_new_model_cast_doclist_to_list(ResultTestDoc)
new_result_test_doc_with_id_model = create_base_doc_from_schema(
ResultTestDocCopy.schema(), 'ResultTestDoc', {}
)
Expand Down