Skip to content

Commit 4b018d8

Browse files
author
Joan Fontanals Martinez
committed
feat: add method to create BaseDoc from schema
Signed-off-by: Joan Fontanals Martinez <[email protected]>
1 parent f507a5f commit 4b018d8

File tree

2 files changed

+405
-0
lines changed

2 files changed

+405
-0
lines changed

docarray/utils/create.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
from docarray import DocList, BaseDoc
2+
from docarray.typing import AnyTensor
3+
from pydantic import create_model
4+
from typing import Dict, List, Any, Union, Optional
5+
6+
7+
def _create_aux_model_doc_list_to_list(model):
    """Derive a model from ``model`` whose ``DocList`` fields are plain lists.

    Every pydantic field annotated directly on ``model`` is re-declared on the
    derived model. An annotation that is a ``DocList`` subclass is replaced by
    ``List[<its doc_type>]``; any other annotation is kept unchanged.

    Each re-declared field reuses the ``field_info`` of the original field, so
    defaults, required-ness and metadata are preserved. (``create_model``
    reads the second element of a ``(type, value)`` tuple as the field's
    default/``FieldInfo``.)

    :param model: model class to derive from; it is used as ``__base__`` and
        its ``__name__`` and ``__validators__`` are reused.
        NOTE(review): relies on the pydantic v1-style ``__fields__`` /
        ``__validators__`` attributes -- confirm the supported pydantic version.
    :return: the model class built by ``pydantic.create_model``.
    """
    declared_fields = model.__fields__
    fields = {}
    for field_name, field in model.__annotations__.items():
        if field_name not in declared_fields:
            # Annotations that are not pydantic fields (e.g. ``ClassVar``) are
            # simply inherited from ``__base__``; re-declaring them would
            # overwrite their value.
            continue
        field_info = declared_fields[field_name].field_info
        try:
            if issubclass(field, DocList):
                fields[field_name] = (List[field.doc_type], field_info)
            else:
                fields[field_name] = (field, field_info)
        except TypeError:
            # ``issubclass`` rejects non-class annotations such as
            # ``Optional[...]`` or ``List[...]``; keep those as they are.
            fields[field_name] = (field, field_info)
    return create_model(
        model.__name__, __base__=model, __validators__=model.__validators__, **fields
    )
20+
21+
22+
def _wrap_in_lists(inner_type, depth):
    """Wrap ``inner_type`` in ``List[...]`` ``depth`` times (0 leaves it as is)."""
    for _ in range(depth):
        inner_type = List[inner_type]
    return inner_type


def _with_definitions(schema, root_schema):
    """Return ``schema`` carrying the ``definitions`` table of ``root_schema``.

    ``create_base_doc_from_schema`` uses the schema it receives as the root for
    resolving ``$ref`` entries, so a sub-schema must bring the table along or
    references nested more than one level deep cannot be resolved.
    """
    definitions = root_schema.get('definitions')
    if definitions is None or 'definitions' in schema:
        return schema
    return {**schema, 'definitions': definitions}


def _model_from_ref(obj_ref, root_schema, cached_models):
    """Build (or fetch from ``cached_models``) the model a ``$ref`` points to."""
    ref_name = obj_ref.split('/')[-1]
    return create_base_doc_from_schema(
        _with_definitions(root_schema['definitions'][ref_name], root_schema),
        ref_name,
        cached_models=cached_models,
    )


def _get_field_from_type(
    field_schema,
    field_name,
    root_schema,
    cached_models,
    is_tensor=False,
    num_recursions=0,
):
    """Translate one JSON-schema field description into a Python type.

    :param field_schema: schema dict of the field (or of an array's ``items``).
    :param field_name: name of the field; used in error messages and as the
        model name for inline object schemas.
    :param root_schema: schema whose ``definitions`` table is used to resolve
        ``$ref`` entries.
    :param cached_models: dict of already created models, forwarded to
        ``create_base_doc_from_schema``.
    :param is_tensor: whether the enclosing schema carried a
        ``'tensor/array shape'`` entry.
    :param num_recursions: number of enclosing ``array`` levels.
    :return: the Python type for the field:

        - ``anyOf`` -> ``Union`` of the member types,
        - ``string`` / ``integer`` / ``boolean`` / ``number`` -> ``str`` /
          ``int`` / ``bool`` / ``float``,
        - ``object`` with ``additionalProperties`` -> ``Dict[str, ...]``,
        - object reference -> the referenced model, or ``DocList`` of it when
          inside an array,
        - ``array`` -> the item type wrapped in ``List`` once per array level.
    :raises ValueError: if the schema ``type`` is not one of the handled ones.
    :raises KeyError: if a ``$ref`` is present but ``root_schema`` has no
        matching ``definitions`` entry.
    """
    field_type = field_schema.get('type', None)
    tensor_shape = field_schema.get('tensor/array shape', None)
    if 'anyOf' in field_schema:
        any_of_types = []
        for any_of_schema in field_schema['anyOf']:
            if '$ref' in any_of_schema:
                any_of_types.append(
                    _model_from_ref(
                        any_of_schema['$ref'], root_schema, cached_models
                    )
                )
            else:
                any_of_types.append(
                    _get_field_from_type(
                        any_of_schema,
                        field_name,
                        root_schema=root_schema,
                        cached_models=cached_models,
                        is_tensor=tensor_shape is not None,
                        num_recursions=0,
                    )
                )  # No Union of Lists
        ret = _wrap_in_lists(Union[tuple(any_of_types)], num_recursions)
    elif field_type == 'string':
        ret = _wrap_in_lists(str, num_recursions)
    elif field_type == 'integer':
        ret = _wrap_in_lists(int, num_recursions)
    elif field_type == 'number':
        if is_tensor and num_recursions <= 1:
            # This is a hack because AnyTensor is more generic than a simple
            # List and it comes as simple List
            ret = AnyTensor
        else:
            # A scalar number (no enclosing array) must map to ``float``,
            # not ``List[float]``.
            ret = _wrap_in_lists(float, num_recursions)
    elif field_type == 'boolean':
        ret = _wrap_in_lists(bool, num_recursions)
    elif field_type == 'object' or field_type is None:
        if 'additionalProperties' in field_schema:  # handle Dictionaries
            additional_props = field_schema['additionalProperties']
            # JSON Schema also allows a plain boolean here; only a dict can
            # describe an object-valued mapping.
            if (
                isinstance(additional_props, dict)
                and additional_props.get('type') == 'object'
            ):
                ret = Dict[
                    str,
                    create_base_doc_from_schema(
                        _with_definitions(additional_props, root_schema),
                        field_name,
                        cached_models=cached_models,
                    ),
                ]
            else:
                ret = Dict[str, Any]
        else:
            obj_ref = field_schema.get('$ref') or field_schema.get('allOf', [{}])[
                0
            ].get('$ref', None)
            if num_recursions == 0:  # single object reference
                if obj_ref:
                    ret = _model_from_ref(obj_ref, root_schema, cached_models)
                else:
                    ret = Any
            else:  # object reference in definitions
                if obj_ref:
                    ret = DocList[
                        _model_from_ref(obj_ref, root_schema, cached_models)
                    ]
                else:
                    ret = DocList[
                        create_base_doc_from_schema(
                            _with_definitions(field_schema, root_schema),
                            field_name,
                            cached_models=cached_models,
                        )
                    ]
    elif field_type == 'array':
        ret = _get_field_from_type(
            field_schema=field_schema.get('items', {}),
            field_name=field_name,
            root_schema=root_schema,
            cached_models=cached_models,
            is_tensor=tensor_shape is not None,
            num_recursions=num_recursions + 1,
        )
    else:
        if num_recursions > 0:
            raise ValueError(
                f"Unknown array item type: {field_type} for field_name {field_name}"
            )
        else:
            raise ValueError(
                f"Unknown field type: {field_type} for field_name {field_name}"
            )
    return ret
143+
144+
145+
def create_base_doc_from_schema(
    schema: Dict[str, Any], model_name: str, cached_models: Optional[Dict] = None
) -> type:
    """Create a ``BaseDoc`` subclass from a JSON schema.

    One field is declared per entry in ``schema['properties']``; its type is
    derived by ``_get_field_from_type`` with ``schema`` as the root schema.
    The schema's ``default`` (``None`` when absent) becomes the field default,
    and the schema's ``description`` is attached as field metadata rather than
    being used as the default value.

    :param schema: JSON schema dict describing the document.
    :param model_name: name of the model to create; also the cache key.
    :param cached_models: models already created, keyed by name. A hit is
        returned as is; a newly created model is stored in it. A fresh dict is
        used when ``None``.
    :return: the model class, created with ``BaseDoc`` as ``__base__``.
    """
    # Imported locally so this function does not depend on the module's
    # top-level import list.
    from pydantic import Field

    cached_models = cached_models if cached_models is not None else {}
    if model_name in cached_models:
        return cached_models[model_name]

    fields = {}
    for field_name, field_schema in schema.get('properties', {}).items():
        field_type = _get_field_from_type(
            field_schema=field_schema,
            field_name=field_name,
            root_schema=schema,
            cached_models=cached_models,
            is_tensor=False,
            num_recursions=0,
        )
        # ``create_model`` reads the second tuple element as the default value
        # (or FieldInfo), so the description must go through ``Field``.
        fields[field_name] = (
            field_type,
            Field(
                default=field_schema.get('default'),
                description=field_schema.get('description'),
            ),
        )

    model = create_model(model_name, __base__=BaseDoc, **fields)
    cached_models[model_name] = model
    return model

0 commit comments

Comments
 (0)