11from docarray import DocList , BaseDoc
22from docarray .typing import AnyTensor
33from pydantic import create_model
4- from typing import Dict , List , Any , Union , Optional , Tuple , Type
5- from typing_extensions import TypeAlias
4+ from typing import Dict , List , Any , Union , Optional , Type
65
76
8- def create_new_model_cast_doclist_to_list (model : BaseDoc ) -> BaseDoc :
9- fields : Dict [str , Tuple [ Type , Dict ] ] = {}
7+ def create_new_model_cast_doclist_to_list (model : Any ) -> BaseDoc :
8+ fields : Dict [str , Any ] = {}
109 for field_name , field in model .__annotations__ .items ():
1110 try :
1211 if issubclass (field , DocList ):
13- fields [field_name ] = (List [field .doc_type ], {})
12+ t : Any = field .doc_type
13+ fields [field_name ] = (List [t ], {})
1414 else :
1515 fields [field_name ] = (field , {})
1616 except TypeError :
@@ -30,7 +30,7 @@ def _get_field_from_type(
3030):
3131 field_type = field_schema .get ('type' , None )
3232 tensor_shape = field_schema .get ('tensor/array shape' , None )
33- ret : TypeAlias
33+ ret : Any
3434 if 'anyOf' in field_schema :
3535 any_of_types = []
3636 for any_of_schema in field_schema ['anyOf' ]:
@@ -82,15 +82,14 @@ def _get_field_from_type(
8282 for rec in range (num_recursions ):
8383 ret = List [ret ]
8484 elif field_type == 'object' or field_type is None :
85+ doc_type : Any
8586 if 'additionalProperties' in field_schema : # handle Dictionaries
8687 additional_props = field_schema ['additionalProperties' ]
8788 if additional_props .get ('type' ) == 'object' :
88- ret = Dict [
89- str ,
90- create_base_doc_from_schema (
91- additional_props , field_name , cached_models = cached_models
92- ),
93- ]
89+ doc_type = create_base_doc_from_schema (
90+ additional_props , field_name , cached_models = cached_models
91+ )
92+ ret = Dict [str , doc_type ]
9493 else :
9594 ret = Dict [str , Any ]
9695 else :
@@ -110,19 +109,17 @@ def _get_field_from_type(
110109 else : # object reference in definitions
111110 if obj_ref :
112111 ref_name = obj_ref .split ('/' )[- 1 ]
113- ret = DocList [
114- create_base_doc_from_schema (
115- root_schema ['definitions' ][ref_name ],
116- ref_name ,
117- cached_models = cached_models ,
118- )
119- ]
112+ doc_type = create_base_doc_from_schema (
113+ root_schema ['definitions' ][ref_name ],
114+ ref_name ,
115+ cached_models = cached_models ,
116+ )
117+ ret = DocList [doc_type ]
120118 else :
121- ret = DocList [
122- create_base_doc_from_schema (
123- field_schema , field_name , cached_models = cached_models
124- )
125- ]
119+ doc_type = create_base_doc_from_schema (
120+ field_schema , field_name , cached_models = cached_models
121+ )
122+ ret = DocList [doc_type ]
126123 elif field_type == 'array' :
127124 ret = _get_field_from_type (
128125 field_schema = field_schema .get ('items' , {}),
@@ -148,7 +145,7 @@ def create_base_doc_from_schema(
148145 schema : Dict [str , Any ], model_name : str , cached_models : Optional [Dict ] = None
149146) -> Type :
150147 cached_models = cached_models if cached_models is not None else {}
151- fields = {}
148+ fields : Dict [ str , Any ] = {}
152149 if model_name in cached_models :
153150 return cached_models [model_name ]
154151 for field_name , field_schema in schema .get ('properties' , {}).items ():
0 commit comments