Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions docarray/utils/create_dynamic_doc_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
]


def create_pure_python_type_model(model: BaseModel) -> BaseDoc:
def create_pure_python_type_model(
model: BaseModel,
cached_models: Optional[Dict[str, Any]] = None,
) -> BaseDoc:
"""
Take a Pydantic model and cast DocList fields into List fields.

Expand All @@ -44,16 +47,18 @@ class MyDoc(BaseDoc):
texts: DocList[TextDoc]


MyDocCorrected = create_new_model_cast_doclist_to_list(CustomDoc)
MyDocCorrected = create_pure_python_type_model(CustomDoc)
```

---
:param model: The input model
:param cached_models: A set of names of models that have been converted to their pure python type model
:return: A new subclass of BaseDoc, where every DocList type in the schema is replaced by List.
"""
fields: Dict[str, Any] = {}
import copy

cached_models = cached_models or {}
fields_copy = copy.deepcopy(model.__fields__)
annotations_copy = copy.deepcopy(model.__annotations__)
for field_name, field in annotations_copy.items():
Expand All @@ -67,14 +72,23 @@ class MyDoc(BaseDoc):
try:
if safe_issubclass(field, DocList):
t: Any = field.doc_type
t_aux = create_pure_python_type_model(t)
fields[field_name] = (List[t_aux], field_info)
if t.__name__ in cached_models:
fields[field_name] = (List[cached_models[t.__name__]], field_info)
else:
t_aux = create_pure_python_type_model(t, cached_models)
cached_models[t.__name__] = t_aux
fields[field_name] = (List[t_aux], field_info)
else:
fields[field_name] = (field, field_info)
except TypeError:
fields[field_name] = (field, field_info)

return create_model(model.__name__, __base__=model, __doc__=model.__doc__, **fields)
new_model = create_model(
model.__name__, __base__=model, __doc__=model.__doc__, **fields
)
cached_models[model.__name__] = new_model

return new_model


def _get_field_annotation_from_schema(
Expand Down
26 changes: 26 additions & 0 deletions tests/units/util/test_create_dynamic_code_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,29 @@ class SearchResult(BaseDoc):
QuoteFile_reconstructed_in_gateway_from_Search_results(id='0', texts=textlist)
)
assert reconstructed_in_gateway_from_Search_results.texts[0].text == 'hey'


def test_create_pure_python_model_with_multiple_doclists_of_same_type():
from docarray import DocList, BaseDoc

class MyTextDoc(BaseDoc):
text: str

class QuoteFile(BaseDoc):
texts: DocList[MyTextDoc]

class QuoteFileType(BaseDoc):
"""
QuoteFileType class.
"""

id: str = (
None # same as name, compatibility reasons for a generic, shared `id` field
)
name: str = None
total_count: int = None
docs: DocList[QuoteFile] = None
chunks: DocList[QuoteFile] = None

new_model = create_pure_python_type_model(QuoteFileType)
new_model.schema()