Skip to content

Commit b9243a8

Browse files
committed
Fix scalar coercion & add comprehensive scalar validation tests for NdArray and JaxArray
1 parent f5fc0f6 commit b9243a8

File tree

3 files changed

+72
-1
lines changed

3 files changed

+72
-1
lines changed

docarray/typing/tensor/jaxarray.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,10 @@ def _docarray_validate(
142142
pass # handled below
143143
elif isinstance(value, str):
144144
value = orjson.loads(value)
145-
145+
# Handle scalar values (int, float, etc.) - wrap in 1D array
146+
elif isinstance(value, (int, float, complex, bool, np.number)):
147+
arr_from_scalar: jnp.ndarray = jnp.array([value])
148+
return cls._docarray_from_native(arr_from_scalar)
146149
try:
147150
arr: jnp.ndarray = jnp.ndarray(value)
148151
return cls._docarray_from_native(arr)

docarray/typing/tensor/ndarray.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,10 @@ def _docarray_validate(
137137
return cls._docarray_from_native(arr_from_list)
138138
except Exception:
139139
pass # handled below
140+
# Handle scalar values (int, float, etc.) - wrap in 1D array
141+
elif isinstance(value, (int, float, complex, bool, np.number)):
142+
arr_from_scalar: np.ndarray = np.array([value])
143+
return cls._docarray_from_native(arr_from_scalar)
140144
try:
141145
arr: np.ndarray = np.ndarray(value)
142146
return cls._docarray_from_native(arr)

tests/units/typing/tensor/test_ndarray.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,70 @@ def test_from_list():
2828
assert (tensor == np.zeros((2, 2))).all()
2929

3030

31+
def test_from_scalar_int():
32+
"""Test that scalar integers are properly converted to 1-dimensional arrays"""
33+
tensor = parse_obj_as(NdArray, 10)
34+
35+
assert isinstance(tensor, NdArray)
36+
assert isinstance(tensor, np.ndarray)
37+
# Scalar should be wrapped in 1-dimensional array
38+
assert tensor.shape == (1,)
39+
assert tensor[0] == 10
40+
assert tensor.dtype in [np.int32, np.int64] # Platform dependent
41+
42+
43+
def test_from_scalar_float():
44+
"""Test that scalar floats are properly converted to 1-dimensional arrays"""
45+
tensor = parse_obj_as(NdArray, 10.5)
46+
47+
assert isinstance(tensor, NdArray)
48+
assert isinstance(tensor, np.ndarray)
49+
# Scalar should be wrapped in 1-dimensional array
50+
assert tensor.shape == (1,)
51+
assert tensor[0] == 10.5
52+
assert tensor.dtype in [np.float32, np.float64]
53+
54+
55+
def test_from_scalar_complex():
56+
"""Test that scalar complex numbers are properly converted to 1-dimensional arrays"""
57+
tensor = parse_obj_as(NdArray, 3+4j)
58+
59+
assert isinstance(tensor, NdArray)
60+
assert isinstance(tensor, np.ndarray)
61+
# Scalar should be wrapped in 1-dimensional array
62+
assert tensor.shape == (1,)
63+
assert tensor[0] == 3+4j
64+
assert np.iscomplexobj(tensor)
65+
66+
67+
def test_from_scalar_bool():
68+
"""Test that scalar booleans are properly converted to 1-dimensional arrays"""
69+
tensor = parse_obj_as(NdArray, True)
70+
71+
assert isinstance(tensor, NdArray)
72+
assert isinstance(tensor, np.ndarray)
73+
# Scalar should be wrapped in 1-dimensional array
74+
assert tensor.shape == (1,)
75+
assert tensor[0] == True
76+
assert tensor.dtype == np.bool_
77+
78+
79+
def test_from_scalar_in_document():
80+
"""Test that scalar values work correctly when used in a Document"""
81+
class MyDoc(BaseDoc):
82+
arr: NdArray
83+
84+
# Test with integer
85+
doc_int = MyDoc(arr=42)
86+
assert doc_int.arr.shape == (1,)
87+
assert doc_int.arr[0] == 42
88+
89+
# Test with float
90+
doc_float = MyDoc(arr=3.14)
91+
assert doc_float.arr.shape == (1,)
92+
assert doc_float.arr[0] == 3.14
93+
94+
3195
def test_json_schema():
3296
schema_json_of(NdArray)
3397

0 commit comments

Comments
 (0)