-
-
Notifications
You must be signed in to change notification settings - Fork 25.5k
/
conftest.py
196 lines (153 loc) · 6.05 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import os
import warnings
from os import environ
from os.path import exists, join
import pytest
from _pytest.doctest import DoctestItem
from sklearn.datasets import get_data_home
from sklearn.datasets._base import _pkl_filepath
from sklearn.datasets._twenty_newsgroups import CACHE_NAME
from sklearn.utils._testing import SkipTest, check_skip_network
from sklearn.utils.fixes import np_base_version, parse_version, sp_version
def setup_labeled_faces():
    """Skip the LFW dataset doctests unless ``lfw_home`` already exists.

    Raises
    ------
    SkipTest
        If there is no ``lfw_home`` entry in the scikit-learn data home.
    """
    lfw_home = join(get_data_home(), "lfw_home")
    if exists(lfw_home):
        return
    raise SkipTest("Skipping dataset loading doctests")
def setup_rcv1():
    """Skip the ``rcv1.rst`` doctests unless the RCV1 data is already present.

    ``check_skip_network()`` is consulted first. Its body is not shown here; it
    presumably raises ``SkipTest`` when network tests are disabled. After that
    the test is skipped when the ``RCV1`` directory is missing from the
    scikit-learn data home, i.e. the dataset has not been loaded beforehand.

    Raises
    ------
    SkipTest
        If the ``RCV1`` directory does not exist in the data home.
    """
    check_skip_network()
    dataset_dir = join(get_data_home(), "RCV1")
    if exists(dataset_dir):
        return
    raise SkipTest("Download RCV1 dataset to run this test.")
def setup_twenty_newsgroups():
    """Skip the 20 newsgroups doctests unless its cache file already exists.

    Raises
    ------
    SkipTest
        If the pickled cache (``CACHE_NAME`` under the data home) is missing.
    """
    pickled_cache = _pkl_filepath(get_data_home(), CACHE_NAME)
    if exists(pickled_cache):
        return
    raise SkipTest("Skipping dataset loading doctests")
def setup_working_with_text_data():
    """Skip the text-data tutorial doctests unless the 20 newsgroups cache exists.

    ``check_skip_network()`` is consulted first. Presumably it raises
    ``SkipTest`` when network tests are disabled; confirm against its source.

    NOTE(review): no caller is visible in this file. Confirm this helper is
    still needed.

    Raises
    ------
    SkipTest
        If the pickled 20 newsgroups cache is missing from the data home.
    """
    check_skip_network()
    is_cached = exists(_pkl_filepath(get_data_home(), CACHE_NAME))
    if not is_cached:
        raise SkipTest("Skipping dataset loading doctests")
def setup_loading_other_datasets():
    """Gate the ``loading_other_datasets.rst`` doctests.

    They need pandas and network access. Network access is opt-in: the
    environment variable ``SKLEARN_SKIP_NETWORK_TESTS`` defaults to ``"1"``,
    and the doctests run only when it is exactly ``"0"``.

    Raises
    ------
    SkipTest
        If pandas is missing, or network tests have not been enabled.
    """
    try:
        import pandas as _pandas  # noqa: F401
    except ImportError:
        raise SkipTest("Skipping loading_other_datasets.rst, pandas not installed")

    network_opt_in = environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1")
    if network_opt_in != "0":
        raise SkipTest(
            "Skipping loading_other_datasets.rst, tests can be "
            "enabled by setting SKLEARN_SKIP_NETWORK_TESTS=0"
        )
def setup_compose():
    """Skip the ``compose.rst`` doctests when pandas cannot be imported.

    Raises
    ------
    SkipTest
        If importing pandas fails with ``ImportError``.
    """
    reason = "Skipping compose.rst, pandas not installed"
    try:
        import pandas as _pandas  # noqa: F401
    except ImportError:
        raise SkipTest(reason)
def setup_impute():
    """Skip the ``impute.rst`` doctests when pandas cannot be imported.

    Raises
    ------
    SkipTest
        If importing pandas fails with ``ImportError``.
    """
    reason = "Skipping impute.rst, pandas not installed"
    try:
        import pandas as _pandas  # noqa: F401
    except ImportError:
        raise SkipTest(reason)
def setup_grid_search():
    """Skip the ``grid_search.rst`` doctests when pandas cannot be imported.

    Raises
    ------
    SkipTest
        If importing pandas fails with ``ImportError``.
    """
    reason = "Skipping grid_search.rst, pandas not installed"
    try:
        import pandas as _pandas  # noqa: F401
    except ImportError:
        raise SkipTest(reason)
def setup_preprocessing():
    """Skip the ``preprocessing.rst`` doctests unless pandas >= 1.1.0 is available.

    Raises
    ------
    SkipTest
        If pandas is not installed, or its version is older than 1.1.0.
    """
    try:
        import pandas
    except ImportError:
        raise SkipTest("Skipping preprocessing.rst, pandas not installed")
    else:
        # Only the import belongs in the ``try``; the version gate runs once
        # pandas is known to be importable.
        if parse_version(pandas.__version__) < parse_version("1.1.0"):
            raise SkipTest("Skipping preprocessing.rst, pandas version < 1.1.0")
def setup_unsupervised_learning():
    """Prepare the ``unsupervised_learning.rst`` doctests.

    Skips them when scikit-image is not installed. Otherwise it adds a global
    warnings filter that ignores ``DeprecationWarning`` messages starting with
    "The binary mode of fromstring". The original comment attributes this
    warning to ``scipy.misc.face``.

    Raises
    ------
    SkipTest
        If importing scikit-image (``skimage``) fails with ``ImportError``.
    """
    try:
        import skimage as _skimage  # noqa: F401
    except ImportError:
        raise SkipTest("Skipping unsupervised_learning.rst, scikit-image not installed")

    # NOTE(review): the filter is never removed, so it stays active for every
    # test that runs later in the same process.
    warnings.filterwarnings(
        action="ignore",
        message="The binary mode of fromstring",
        category=DeprecationWarning,
    )
def skip_if_matplotlib_not_installed(fname):
    """Skip the doctests of *fname* when matplotlib cannot be imported.

    Parameters
    ----------
    fname : str
        Path of the documentation file. Only its basename is used, in the
        skip message.

    Raises
    ------
    SkipTest
        If importing matplotlib fails with ``ImportError``.
    """
    try:
        import matplotlib as _matplotlib  # noqa: F401
    except ImportError:
        raise SkipTest(
            f"Skipping doctests for {os.path.basename(fname)}, matplotlib not installed"
        )
def skip_if_cupy_not_installed(fname):
    """Skip the doctests of *fname* when CuPy cannot be imported.

    Parameters
    ----------
    fname : str
        Path of the documentation file. Only its basename is used, in the
        skip message.

    Raises
    ------
    SkipTest
        If importing cupy fails with ``ImportError``.
    """
    try:
        import cupy as _cupy  # noqa: F401
    except ImportError:
        raise SkipTest(
            f"Skipping doctests for {os.path.basename(fname)}, cupy not installed"
        )
def pytest_runtest_setup(item):
    """pytest hook run before each test: skip doc tests whose needs are unmet.

    The decision depends only on the path of the file *item* was collected
    from. Each ``setup_*`` / ``skip_if_*`` helper called here raises
    ``SkipTest`` when its requirement (cached dataset, optional dependency,
    network opt-in) is not satisfied.

    Parameters
    ----------
    item : pytest item
        The collected test that is about to run.
    """
    # Normalize the filename to forward slashes on Windows, so the suffix
    # checks below behave the same on every platform.
    fname = item.fspath.strpath.replace(os.sep, "/")

    if fname.endswith("datasets/index.rst"):
        # The index page is gated on every one of these requirements. They
        # must not sit in the elif ladder below: there, only the first
        # matching branch would run. Each call raises SkipTest on its own
        # unmet requirement.
        setup_labeled_faces()
        setup_rcv1()
        setup_twenty_newsgroups()
        setup_compose()
    elif fname.endswith("datasets/labeled_faces.rst"):
        setup_labeled_faces()
    elif fname.endswith("datasets/rcv1.rst"):
        setup_rcv1()
    elif fname.endswith("datasets/twenty_newsgroups.rst"):
        setup_twenty_newsgroups()
    elif fname.endswith("modules/compose.rst"):
        setup_compose()
    elif fname.endswith("datasets/loading_other_datasets.rst"):
        setup_loading_other_datasets()
    elif fname.endswith("modules/impute.rst"):
        setup_impute()
    elif fname.endswith("modules/grid_search.rst"):
        setup_grid_search()
    elif fname.endswith("modules/preprocessing.rst"):
        setup_preprocessing()
    elif fname.endswith("statistical_inference/unsupervised_learning.rst"):
        setup_unsupervised_learning()

    # Pages whose doctests need matplotlib; str.endswith accepts a tuple.
    rst_files_requiring_matplotlib = (
        "modules/partial_dependence.rst",
        "modules/tree.rst",
    )
    if fname.endswith(rst_files_requiring_matplotlib):
        skip_if_matplotlib_not_installed(fname)

    if fname.endswith("array_api.rst"):
        skip_if_cupy_not_installed(fname)
def pytest_configure(config):
    """pytest hook: select matplotlib's non-interactive ``agg`` backend.

    The backend applies for the whole session, doctests included. When
    matplotlib is not installed (``ImportError``), the hook does nothing.

    Parameters
    ----------
    config : pytest config
        Unused.
    """
    try:
        from matplotlib import use as select_backend

        select_backend("agg")
    except ImportError:
        # matplotlib is optional for the documentation tests.
        pass
def pytest_collection_modifyitems(config, items):
    """Called after collect is completed.

    Every doctest item gets an empty ``globs`` namespace. If numpy >= 2 or
    scipy < 1.14 is installed, every doctest item is also marked as skipped,
    because their reprs have changed (see the skip reasons below).

    Parameters
    ----------
    config : pytest config
        Unused.
    items : list of collected items
        The doctest items in it are modified in place.
    """
    skip_reasons = []
    if np_base_version >= parse_version("2"):
        # Skip doctests when using numpy 2 for now. See the following discussion
        # to decide what to do in the longer term:
        # https://github.com/scikit-learn/scikit-learn/issues/27339
        skip_reasons.append("Due to NEP 51 numpy scalar repr has changed in numpy 2")
    if sp_version < parse_version("1.14"):
        skip_reasons.append("Scipy sparse matrix repr has changed in scipy 1.14")

    # When both conditions hold, report both reasons rather than letting the
    # later one overwrite the earlier one.
    skip_marker = (
        pytest.mark.skip(reason="; ".join(skip_reasons)) if skip_reasons else None
    )

    for item in items:
        if not isinstance(item, DoctestItem):
            continue
        # Normally doctest has the entire module's scope. Here we set globs to
        # an empty dict to remove the module's scope:
        # https://docs.python.org/3/library/doctest.html#what-s-the-execution-context
        item.dtest.globs = {}
        if skip_marker is not None:
            item.add_marker(skip_marker)