Skip to content

Commit 586babd

Browse files
committed
fix: fix create dynamic code class
Signed-off-by: Joan Martinez <[email protected]>
1 parent 791e4a0 commit 586babd

File tree

2 files changed

+45
-5
lines changed

2 files changed

+45
-5
lines changed

docarray/utils/create_dynamic_doc_class.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
]
2323

2424

25-
def create_pure_python_type_model(model: BaseModel) -> BaseDoc:
25+
def create_pure_python_type_model(
26+
model: BaseModel,
27+
cached_models: Optional[Dict[str, Any]] = None,
28+
) -> BaseDoc:
2629
"""
2730
Take a Pydantic model and cast DocList fields into List fields.
2831
@@ -44,16 +47,18 @@ class MyDoc(BaseDoc):
4447
texts: DocList[TextDoc]
4548
4649
47-
MyDocCorrected = create_new_model_cast_doclist_to_list(CustomDoc)
50+
MyDocCorrected = create_pure_python_type_model(CustomDoc)
4851
```
4952
5053
---
5154
:param model: The input model
55+
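:param cached_models: A mapping from model names to the models that have already been converted to their pure python type model, so that an already converted model is reused instead of being created again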
5256
:return: A new subclass of BaseDoc, where every DocList type in the schema is replaced by List.
5357
"""
5458
fields: Dict[str, Any] = {}
5559
import copy
5660

61+
cached_models = cached_models or {}
5762
fields_copy = copy.deepcopy(model.__fields__)
5863
annotations_copy = copy.deepcopy(model.__annotations__)
5964
for field_name, field in annotations_copy.items():
@@ -67,14 +72,23 @@ class MyDoc(BaseDoc):
6772
try:
6873
if safe_issubclass(field, DocList):
6974
t: Any = field.doc_type
70-
t_aux = create_pure_python_type_model(t)
71-
fields[field_name] = (List[t_aux], field_info)
75+
if t.__name__ in cached_models:
76+
fields[field_name] = (List[cached_models[t.__name__]], field_info)
77+
else:
78+
t_aux = create_pure_python_type_model(t, cached_models)
79+
cached_models[t.__name__] = t_aux
80+
fields[field_name] = (List[t_aux], field_info)
7281
else:
7382
fields[field_name] = (field, field_info)
7483
except TypeError:
7584
fields[field_name] = (field, field_info)
7685

77-
return create_model(model.__name__, __base__=model, __doc__=model.__doc__, **fields)
86+
new_model = create_model(
87+
model.__name__, __base__=model, __doc__=model.__doc__, **fields
88+
)
89+
cached_models[model.__name__] = new_model
90+
91+
return new_model
7892

7993

8094
def _get_field_annotation_from_schema(

tests/units/util/test_create_dynamic_code_class.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,3 +315,29 @@ class SearchResult(BaseDoc):
315315
QuoteFile_reconstructed_in_gateway_from_Search_results(id='0', texts=textlist)
316316
)
317317
assert reconstructed_in_gateway_from_Search_results.texts[0].text == 'hey'
318+
319+
320+
def test_create_pure_python_model_with_multiple_doclists_of_same_type():
321+
from docarray import DocList, BaseDoc
322+
323+
class MyTextDoc(BaseDoc):
324+
text: str
325+
326+
class QuoteFile(BaseDoc):
327+
texts: DocList[MyTextDoc]
328+
329+
class QuoteFileType(BaseDoc):
330+
"""
331+
QuoteFileType class.
332+
"""
333+
334+
id: str = (
335+
None # same as name, compatibility reasons for a generic, shared `id` field
336+
)
337+
name: str = None
338+
total_count: int = None
339+
docs: DocList[QuoteFile] = None
340+
chunks: DocList[QuoteFile] = None
341+
342+
new_model = create_pure_python_type_model(QuoteFileType)
343+
new_model.schema()

0 commit comments

Comments
 (0)