Skip to content

Commit 7730bbf

Browse files
committed
fix: fix float in doc
1 parent f71a5e6 commit 7730bbf

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

docarray/utils/create_dynamic_doc_class.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,9 @@ def _get_field_annotation_from_schema(
140140
for rec in range(num_recursions):
141141
ret = List[ret]
142142
elif field_type == 'number':
143-
if num_recursions <= 1:
143+
if num_recursions == 0:
144+
ret = float
145+
elif num_recursions == 1:
144146
# This is a hack because AnyTensor is more generic than a simple List and it comes as simple List
145147
if is_tensor:
146148
ret = AnyTensor

tests/units/util/test_create_dynamic_code_class.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ class Nested1Doc(BaseDoc):
2727
class CustomDoc(BaseDoc):
2828
tensor: Optional[AnyTensor] = None
2929
url: ImageUrl
30+
num: float = 0.5
31+
num_num: List[float] = [1.5, 2.5]
3032
lll: List[List[List[int]]] = [[[5]]]
3133
fff: List[List[List[float]]] = [[[5.2]]]
3234
single_text: TextDoc
@@ -47,6 +49,8 @@ class CustomDoc(BaseDoc):
4749
original_custom_docs = DocList[CustomDoc](
4850
[
4951
CustomDoc(
52+
num=3.5,
53+
num_num=[4.5, 5.5],
5054
url='photo.jpg',
5155
lll=[[[40]]],
5256
fff=[[[40.2]]],
@@ -78,6 +82,8 @@ class CustomDoc(BaseDoc):
7882

7983
assert len(custom_partial_da) == 1
8084
assert custom_partial_da[0].url == 'photo.jpg'
85+
assert custom_partial_da[0].num == 3.5
86+
assert custom_partial_da[0].num_num == [4.5, 5.5]
8187
assert custom_partial_da[0].lll == [[[40]]]
8288
if is_pydantic_v2:
8389
assert custom_partial_da[0].lu == [3, 4]
@@ -94,6 +100,8 @@ class CustomDoc(BaseDoc):
94100
assert custom_partial_da[0].single_text.text == 'single hey ha'
95101
assert custom_partial_da[0].single_text.embedding.shape == (2,)
96102
assert original_back[0].nested.nested.value == 'hello world'
103+
assert original_back[0].num == 3.5
104+
assert original_back[0].num_num == [4.5, 5.5]
97105
assert original_back[0].classvar == 'classvar'
98106
assert original_back[0].nested.classvar == 'classvar1'
99107
assert original_back[0].nested.nested.classvar == 'classvar2'

0 commit comments

Comments
 (0)