forked from data-apis/array-api-tests
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path test_utility_functions.py
More file actions
62 lines (52 loc) · 2 KB
/
Copy path test_utility_functions.py
File metadata and controls
62 lines (52 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import pytest
from hypothesis import given
from hypothesis import strategies as st
from . import _array_module as xp
from . import dtype_helpers as dh
from . import hypothesis_helpers as hh
from . import pytest_helpers as ph
from . import shape_helpers as sh
from . import xps
# Module-level mark: pytest applies ``pytest.mark.ci`` to every test in this
# module (presumably used to select these tests for CI runs — confirm against
# the project's pytest configuration).
pytestmark = pytest.mark.ci
@given(
    x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)),
    data=st.data(),
)
def test_all(x, data):
    """Test ``xp.all`` against Python's built-in ``all``.

    Draws optional ``axis``/``keepdims`` keyword arguments, calls ``xp.all``
    and checks that the output has the ``xp.bool`` dtype, the expected
    (optionally keepdims) shape, and that each output element equals
    ``all()`` applied to the corresponding input elements along the reduced
    axes.

    Args:
        x: Input array of any scalar dtype whose sides are all at least 1.
        data: Hypothesis data object used to draw the keyword arguments.
    """
    kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw")
    keepdims = kw.get("keepdims", False)

    out = xp.all(x, **kw)

    ph.assert_dtype("all", x.dtype, out.dtype, xp.bool)
    _axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
    ph.assert_keepdimable_shape("all", x.shape, out.shape, _axes, keepdims, **kw)

    scalar_type = dh.get_scalar_type(x.dtype)
    # Pair each group of reduced input indices with its output position.
    for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
        result = bool(out[out_idx])
        # Build the full list (rather than a short-circuiting generator) so
        # every element of the slice is converted to a Python scalar.
        elements = [scalar_type(x[idx]) for idx in indices]
        expected = all(elements)
        # ``result`` and ``expected`` are both Python bools, so compare them
        # as bools. NOTE(review): passing the input's scalar type (e.g. float)
        # here routes upstream pytest_helpers to a tolerance-based comparison
        # (math.isclose with abs_tol=1) under which True and False compare
        # equal — confirm against the local helper.
        ph.assert_scalar_equals("all", bool, out_idx, result, expected)
@given(
    x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()),
    data=st.data(),
)
def test_any(x, data):
    """Test ``xp.any`` against Python's built-in ``any``.

    Draws optional ``axis``/``keepdims`` keyword arguments, calls ``xp.any``
    and checks that the output has the ``xp.bool`` dtype, the expected
    (optionally keepdims) shape, and that each output element equals
    ``any()`` applied to the corresponding input elements along the reduced
    axes.

    Args:
        x: Input array of any scalar dtype. Unlike ``test_all``, no minimum
            side length is requested from ``hh.shapes``.
        data: Hypothesis data object used to draw the keyword arguments.
    """
    kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw")
    keepdims = kw.get("keepdims", False)

    out = xp.any(x, **kw)

    ph.assert_dtype("any", x.dtype, out.dtype, xp.bool)
    _axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
    ph.assert_keepdimable_shape("any", x.shape, out.shape, _axes, keepdims, **kw)

    scalar_type = dh.get_scalar_type(x.dtype)
    # Pair each group of reduced input indices with its output position.
    for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
        result = bool(out[out_idx])
        # Build the full list (rather than a short-circuiting generator) so
        # every element of the slice is converted to a Python scalar.
        elements = [scalar_type(x[idx]) for idx in indices]
        expected = any(elements)
        # ``result`` and ``expected`` are both Python bools, so compare them
        # as bools. NOTE(review): passing the input's scalar type (e.g. float)
        # here routes upstream pytest_helpers to a tolerance-based comparison
        # (math.isclose with abs_tol=1) under which True and False compare
        # equal — confirm against the local helper.
        ph.assert_scalar_equals("any", bool, out_idx, result, expected)