@@ -28,6 +28,70 @@ def test_from_list():
2828 assert (tensor == np .zeros ((2 , 2 ))).all ()
2929
3030
31+ def test_from_scalar_int ():
32+ """Test that scalar integers are properly converted to 1-dimensional arrays"""
33+ tensor = parse_obj_as (NdArray , 10 )
34+
35+ assert isinstance (tensor , NdArray )
36+ assert isinstance (tensor , np .ndarray )
37+ # Scalar should be wrapped in 1-dimensional array
38+ assert tensor .shape == (1 ,)
39+ assert tensor [0 ] == 10
40+ assert tensor .dtype in [np .int32 , np .int64 ] # Platform dependent
41+
42+
43+ def test_from_scalar_float ():
44+ """Test that scalar floats are properly converted to 1-dimensional arrays"""
45+ tensor = parse_obj_as (NdArray , 10.5 )
46+
47+ assert isinstance (tensor , NdArray )
48+ assert isinstance (tensor , np .ndarray )
49+ # Scalar should be wrapped in 1-dimensional array
50+ assert tensor .shape == (1 ,)
51+ assert tensor [0 ] == 10.5
52+ assert tensor .dtype in [np .float32 , np .float64 ]
53+
54+
55+ def test_from_scalar_complex ():
56+ """Test that scalar complex numbers are properly converted to 1-dimensional arrays"""
57+ tensor = parse_obj_as (NdArray , 3 + 4j )
58+
59+ assert isinstance (tensor , NdArray )
60+ assert isinstance (tensor , np .ndarray )
61+ # Scalar should be wrapped in 1-dimensional array
62+ assert tensor .shape == (1 ,)
63+ assert tensor [0 ] == 3 + 4j
64+ assert np .iscomplexobj (tensor )
65+
66+
67+ def test_from_scalar_bool ():
68+ """Test that scalar booleans are properly converted to 1-dimensional arrays"""
69+ tensor = parse_obj_as (NdArray , True )
70+
71+ assert isinstance (tensor , NdArray )
72+ assert isinstance (tensor , np .ndarray )
73+ # Scalar should be wrapped in 1-dimensional array
74+ assert tensor .shape == (1 ,)
75+ assert tensor [0 ] == True
76+ assert tensor .dtype == np .bool_
77+
78+
79+ def test_from_scalar_in_document ():
80+ """Test that scalar values work correctly when used in a Document"""
81+ class MyDoc (BaseDoc ):
82+ arr : NdArray
83+
84+ # Test with integer
85+ doc_int = MyDoc (arr = 42 )
86+ assert doc_int .arr .shape == (1 ,)
87+ assert doc_int .arr [0 ] == 42
88+
89+ # Test with float
90+ doc_float = MyDoc (arr = 3.14 )
91+ assert doc_float .arr .shape == (1 ,)
92+ assert doc_float .arr [0 ] == 3.14
93+
94+
3195def test_json_schema ():
3296 schema_json_of (NdArray )
3397
0 commit comments