Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
chore: fix mypy more
Signed-off-by: Joan Fontanals Martinez <[email protected]>
  • Loading branch information
Joan Fontanals Martinez committed Jun 21, 2023
commit e5f065ffd712948a528284e3bb00d68981e7d109
47 changes: 22 additions & 25 deletions docarray/utils/create.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from docarray import DocList, BaseDoc
from docarray.typing import AnyTensor
from pydantic import create_model
from typing import Dict, List, Any, Union, Optional, Tuple, Type
from typing_extensions import TypeAlias
from typing import Dict, List, Any, Union, Optional, Type


def create_new_model_cast_doclist_to_list(model: BaseDoc) -> BaseDoc:
fields: Dict[str, Tuple[Type, Dict]] = {}
def create_new_model_cast_doclist_to_list(model: Any) -> BaseDoc:
fields: Dict[str, Any] = {}
for field_name, field in model.__annotations__.items():
try:
if issubclass(field, DocList):
fields[field_name] = (List[field.doc_type], {})
t: Any = field.doc_type
fields[field_name] = (List[t], {})
else:
fields[field_name] = (field, {})
except TypeError:
Expand All @@ -30,7 +30,7 @@ def _get_field_from_type(
):
field_type = field_schema.get('type', None)
tensor_shape = field_schema.get('tensor/array shape', None)
ret: TypeAlias
ret: Any
if 'anyOf' in field_schema:
any_of_types = []
for any_of_schema in field_schema['anyOf']:
Expand Down Expand Up @@ -82,15 +82,14 @@ def _get_field_from_type(
for rec in range(num_recursions):
ret = List[ret]
elif field_type == 'object' or field_type is None:
doc_type: Any
if 'additionalProperties' in field_schema: # handle Dictionaries
additional_props = field_schema['additionalProperties']
if additional_props.get('type') == 'object':
ret = Dict[
str,
create_base_doc_from_schema(
additional_props, field_name, cached_models=cached_models
),
]
doc_type = create_base_doc_from_schema(
additional_props, field_name, cached_models=cached_models
)
ret = Dict[str, doc_type]
else:
ret = Dict[str, Any]
else:
Expand All @@ -110,19 +109,17 @@ def _get_field_from_type(
else: # object reference in definitions
if obj_ref:
ref_name = obj_ref.split('/')[-1]
ret = DocList[
create_base_doc_from_schema(
root_schema['definitions'][ref_name],
ref_name,
cached_models=cached_models,
)
]
doc_type = create_base_doc_from_schema(
root_schema['definitions'][ref_name],
ref_name,
cached_models=cached_models,
)
ret = DocList[doc_type]
else:
ret = DocList[
create_base_doc_from_schema(
field_schema, field_name, cached_models=cached_models
)
]
doc_type = create_base_doc_from_schema(
field_schema, field_name, cached_models=cached_models
)
ret = DocList[doc_type]
elif field_type == 'array':
ret = _get_field_from_type(
field_schema=field_schema.get('items', {}),
Expand All @@ -148,7 +145,7 @@ def create_base_doc_from_schema(
schema: Dict[str, Any], model_name: str, cached_models: Optional[Dict] = None
) -> Type:
cached_models = cached_models if cached_models is not None else {}
fields = {}
fields: Dict[str, Any] = {}
if model_name in cached_models:
return cached_models[model_name]
for field_name, field_schema in schema.get('properties', {}).items():
Expand Down