Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docarray/utils/create_dynamic_doc_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class MyDoc(BaseDoc):
"""
fields: Dict[str, Any] = {}
for field_name, field in model.__annotations__.items():
if field_name not in model.__fields__:
continue
field_info = model.__fields__[field_name].field_info
Copy link

@NarekA NarekA Sep 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW: It might be a good idea to use something like pydash.get here in case traversal breaks at a different location.

try:
if safe_issubclass(field, DocList):
Expand Down
8 changes: 7 additions & 1 deletion tests/units/util/test_create_dynamic_code_class.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, ClassVar

import numpy as np
import pytest
Expand All @@ -17,9 +17,11 @@
def test_create_pydantic_model_from_schema(transformation):
class Nested2Doc(BaseDoc):
value: str
classvar: ClassVar[str] = 'classvar2'

class Nested1Doc(BaseDoc):
nested: Nested2Doc
classvar: ClassVar[str] = 'classvar1'

class CustomDoc(BaseDoc):
tensor: Optional[AnyTensor]
Expand All @@ -34,6 +36,7 @@ class CustomDoc(BaseDoc):
lu: List[Union[str, int]] = [0, 1, 2]
tags: Optional[Dict[str, Any]] = None
nested: Nested1Doc
classvar: ClassVar[str] = 'classvar'

CustomDocCopy = create_pure_python_type_model(CustomDoc)
new_custom_doc_model = create_base_doc_from_schema(
Expand Down Expand Up @@ -87,6 +90,9 @@ class CustomDoc(BaseDoc):
assert custom_partial_da[0].single_text.text == 'single hey ha'
assert custom_partial_da[0].single_text.embedding.shape == (2,)
assert original_back[0].nested.nested.value == 'hello world'
assert original_back[0].classvar == 'classvar'
assert original_back[0].nested.classvar == 'classvar1'
assert original_back[0].nested.nested.classvar == 'classvar2'

assert len(original_back) == 1
assert original_back[0].url == 'photo.jpg'
Expand Down