Skip to content

Commit c69b883

Browse files
committed
fix: fix create dynamic code class
1 parent 791e4a0 commit c69b883

File tree

2 files changed

+38
-6
lines changed

2 files changed

+38
-6
lines changed

docarray/utils/create_dynamic_doc_class.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Optional, Type, Union
1+
from typing import Any, Dict, List, Optional, Type, Union, Set
22

33
from pydantic import BaseModel, create_model
44
from pydantic.fields import FieldInfo
@@ -22,7 +22,7 @@
2222
]
2323

2424

25-
def create_pure_python_type_model(model: BaseModel) -> BaseDoc:
25+
def create_pure_python_type_model(model: BaseModel, cached_models: Set[str],) -> BaseDoc:
2626
"""
2727
Take a Pydantic model and cast DocList fields into List fields.
2828
@@ -44,16 +44,18 @@ class MyDoc(BaseDoc):
4444
texts: DocList[TextDoc]
4545
4646
47-
MyDocCorrected = create_new_model_cast_doclist_to_list(CustomDoc)
47+
MyDocCorrected = create_pure_python_type_model(CustomDoc)
4848
```
4949
5050
---
5151
:param model: The input model
52+
:param cached_models: A set of names of models that have been converted to their pure python type model
5253
:return: A new subclass of BaseDoc, where every DocList type in the schema is replaced by List.
5354
"""
5455
fields: Dict[str, Any] = {}
5556
import copy
5657

58+
cached_models = cached_models or set()
5759
fields_copy = copy.deepcopy(model.__fields__)
5860
annotations_copy = copy.deepcopy(model.__annotations__)
5961
for field_name, field in annotations_copy.items():
@@ -67,14 +69,21 @@ class MyDoc(BaseDoc):
6769
try:
6870
if safe_issubclass(field, DocList):
6971
t: Any = field.doc_type
70-
t_aux = create_pure_python_type_model(t)
71-
fields[field_name] = (List[t_aux], field_info)
72+
if t.__name__ in cached_models:
73+
fields[field_name] = (List[cached_models[t.__name__]], field_info)
74+
else:
75+
t_aux = create_pure_python_type_model(t, cached_models)
76+
fields[field_name] = (List[t_aux], field_info)
77+
cached_models.add(t.__name__)
7278
else:
7379
fields[field_name] = (field, field_info)
7480
except TypeError:
7581
fields[field_name] = (field, field_info)
7682

77-
return create_model(model.__name__, __base__=model, __doc__=model.__doc__, **fields)
83+
new_model = create_model(model.__name__, __base__=model, __doc__=model.__doc__, **fields)
84+
cached_models.add(new_model.__name__)
85+
86+
return new_model
7887

7988

8089
def _get_field_annotation_from_schema(

tests/units/util/test_create_dynamic_code_class.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,3 +315,26 @@ class SearchResult(BaseDoc):
315315
QuoteFile_reconstructed_in_gateway_from_Search_results(id='0', texts=textlist)
316316
)
317317
assert reconstructed_in_gateway_from_Search_results.texts[0].text == 'hey'
318+
319+
320+
def test_create_pure_python_model_with_multiple_doclists_of_same_type():
321+
from docarray import DocList, BaseDoc
322+
323+
class MyTextDoc(BaseDoc):
324+
text: str
325+
326+
class QuoteFile(BaseDoc):
327+
texts: DocList[MyTextDoc]
328+
329+
class QuoteFileType(BaseDoc):
330+
"""
331+
QuoteFileType class.
332+
"""
333+
id: str = None # same as name, compatibility reasons for a generic, shared `id` field
334+
name: str = None
335+
total_count: int = None
336+
docs: DocList[QuoteFile] = None
337+
chunks: DocList[QuoteFile] = None
338+
339+
new_model = create_pure_python_type_model(QuoteFileType)
340+
new_model.schema()

0 commit comments

Comments
 (0)