Skip to content

Commit dc4051f

Browse files
authored
MNT update test framework for sklearn 0.24 (#788)
1 parent edd7522 commit dc4051f

File tree

8 files changed

+54
-37
lines changed

8 files changed

+54
-37
lines changed

azure-pipelines.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
# Linux environment to test the latest available dependencies and MKL.
3333
pylatest_pip_openblas_pandas:
3434
DISTRIB: 'conda-pip-latest'
35-
PYTHON_VERSION: '3.8'
35+
PYTHON_VERSION: '3.9'
3636
COVERAGE: 'true'
3737
PANDAS_VERSION: '*'
3838
TEST_DOCSTRINGS: 'true'

build_tools/azure/test_script.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ except ImportError:
2121
python -c "import multiprocessing as mp; print('%d CPUs' % mp.cpu_count())"
2222
pip list
2323

24-
TEST_CMD="python -m pytest --showlocals --durations=20 --junitxml=$JUNITXML"
24+
TEST_CMD="python -m pytest -vsl --durations=20 --junitxml=$JUNITXML"
2525

2626
if [[ "$COVERAGE" == "true" ]]; then
2727
export COVERAGE_PROCESS_START="$BUILD_SOURCESDIRECTORY/.coveragerc"

doc/conf.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@
4343
'sphinx_gallery.gen_gallery',
4444
]
4545

46+
# bibtex file
47+
bibtex_bibfiles = ['bibtex/refs.bib']
48+
4649
# this is needed for some reason...
4750
# see https://github.com/numpy/numpydoc/issues/69
4851
numpydoc_show_class_members = False
@@ -345,8 +348,8 @@ def patch_signature(subject, bound_method=False, follow_wrapped=True):
345348
# https://github.com/readthedocs/sphinx_rtd_theme/pull/747/files
346349
def setup(app):
347350
app.registry.documenters["class"] = PatchedClassDocumenter
348-
app.add_javascript("js/copybutton.js")
349-
app.add_stylesheet("basic.css")
351+
app.add_js_file("js/copybutton.js")
352+
app.add_css_file("basic.css")
350353
# app.connect('autodoc-process-docstring', generate_example_rst)
351354

352355

doc/over_sampling.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ In addition, :class:`RandomOverSampler` allows to sample heterogeneous data
6060

6161
>>> import numpy as np
6262
>>> X_hetero = np.array([['xxx', 1, 1.0], ['yyy', 2, 2.0], ['zzz', 3, 3.0]],
63-
... dtype=np.object)
63+
... dtype=object)
6464
>>> y_hetero = np.array([0, 0, 1])
6565
>>> X_resampled, y_resampled = ros.fit_resample(X_hetero, y_hetero)
6666
>>> print(X_resampled)

doc/under_sampling.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ In addition, :class:`RandomUnderSampler` allows to sample heterogeneous data
107107
(e.g. containing some strings)::
108108

109109
>>> X_hetero = np.array([['xxx', 1, 1.0], ['yyy', 2, 2.0], ['zzz', 3, 3.0]],
110-
... dtype=np.object)
110+
... dtype=object)
111111
>>> y_hetero = np.array([0, 0, 1])
112112
>>> X_resampled, y_resampled = rus.fit_resample(X_hetero, y_hetero)
113113
>>> print(X_resampled)

imblearn/over_sampling/tests/test_random_over_sampler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_multiclass_fit_resample():
115115

116116
def test_random_over_sampling_heterogeneous_data():
117117
X_hetero = np.array(
118-
[["xxx", 1, 1.0], ["yyy", 2, 2.0], ["zzz", 3, 3.0]], dtype=np.object
118+
[["xxx", 1, 1.0], ["yyy", 2, 2.0], ["zzz", 3, 3.0]], dtype=object
119119
)
120120
y = np.array([0, 0, 1])
121121
ros = RandomOverSampler(random_state=RND_SEED)

imblearn/under_sampling/_prototype_selection/tests/test_random_under_sampler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def test_multiclass_fit_resample():
101101

102102
def test_random_under_sampling_heterogeneous_data():
103103
X_hetero = np.array(
104-
[["xxx", 1, 1.0], ["yyy", 2, 2.0], ["zzz", 3, 3.0]], dtype=np.object
104+
[["xxx", 1, 1.0], ["yyy", 2, 2.0], ["zzz", 3, 3.0]], dtype=object
105105
)
106106
y = np.array([0, 0, 1])
107107
rus = RandomUnderSampler(random_state=RND_SEED)

imblearn/utils/estimator_checks.py

+43-29
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
from sklearn.cluster import KMeans
2626
from sklearn.exceptions import SkipTestWarning
2727
from sklearn.preprocessing import label_binarize
28-
from sklearn.utils.estimator_checks import _mark_xfail_checks
29-
from sklearn.utils.estimator_checks import _set_check_estimator_ids
28+
from sklearn.utils.estimator_checks import _maybe_mark_xfail
29+
from sklearn.utils.estimator_checks import _get_check_estimator_ids
3030
from sklearn.utils._testing import assert_allclose
3131
from sklearn.utils._testing import assert_raises_regex
3232
from sklearn.utils.multiclass import type_of_target
@@ -44,7 +44,7 @@ def _set_checking_parameters(estimator):
4444
if name == "ClusterCentroids":
4545
estimator.set_params(
4646
voting="soft",
47-
estimator=KMeans(random_state=0, algorithm="full"),
47+
estimator=KMeans(random_state=0, algorithm="full", n_init=1),
4848
)
4949
if name == "KMeansSMOTE":
5050
estimator.set_params(kmeans_estimator=12)
@@ -117,21 +117,19 @@ def parametrize_with_checks(estimators):
117117
... def test_sklearn_compatible_estimator(estimator, check):
118118
... check(estimator)
119119
"""
120-
names = (type(estimator).__name__ for estimator in estimators)
120+
def checks_generator():
121+
for estimator in estimators:
122+
name = type(estimator).__name__
123+
for check in _yield_all_checks(estimator):
124+
check = partial(check, name)
125+
yield _maybe_mark_xfail(estimator, check, pytest)
121126

122-
checks_generator = ((clone(estimator), partial(check, name))
123-
for name, estimator in zip(names, estimators)
124-
for check in _yield_all_checks(estimator))
127+
return pytest.mark.parametrize("estimator, check", checks_generator(),
128+
ids=_get_check_estimator_ids)
125129

126-
checks_with_marks = (
127-
_mark_xfail_checks(estimator, check, pytest)
128-
for estimator, check in checks_generator)
129130

130-
return pytest.mark.parametrize("estimator, check", checks_with_marks,
131-
ids=_set_check_estimator_ids)
132-
133-
134-
def check_target_type(name, estimator):
131+
def check_target_type(name, estimator_orig):
132+
estimator = clone(estimator_orig)
135133
# should raise warning if the target is continuous (we cannot raise error)
136134
X = np.random.random((20, 2))
137135
y = np.linspace(0, 1, 20)
@@ -148,7 +146,8 @@ def check_target_type(name, estimator):
148146
)
149147

150148

151-
def check_samplers_one_label(name, sampler):
149+
def check_samplers_one_label(name, sampler_orig):
150+
sampler = clone(sampler_orig)
152151
error_string_fit = "Sampler can't balance when only one class is present."
153152
X = np.random.random((20, 2))
154153
y = np.zeros(20)
@@ -168,7 +167,8 @@ def check_samplers_one_label(name, sampler):
168167
raise AssertionError(error_string_fit)
169168

170169

171-
def check_samplers_fit(name, sampler):
170+
def check_samplers_fit(name, sampler_orig):
171+
sampler = clone(sampler_orig)
172172
np.random.seed(42) # Make this test reproducible
173173
X = np.random.random((30, 2))
174174
y = np.array([1] * 20 + [0] * 10)
@@ -178,7 +178,8 @@ def check_samplers_fit(name, sampler):
178178
), "No fitted attribute sampling_strategy_"
179179

180180

181-
def check_samplers_fit_resample(name, sampler):
181+
def check_samplers_fit_resample(name, sampler_orig):
182+
sampler = clone(sampler_orig)
182183
X, y = make_classification(
183184
n_samples=1000,
184185
n_classes=3,
@@ -213,7 +214,8 @@ def check_samplers_fit_resample(name, sampler):
213214
)
214215

215216

216-
def check_samplers_sampling_strategy_fit_resample(name, sampler):
217+
def check_samplers_sampling_strategy_fit_resample(name, sampler_orig):
218+
sampler = clone(sampler_orig)
217219
# in this test we will force all samplers to not change the class 1
218220
X, y = make_classification(
219221
n_samples=1000,
@@ -240,7 +242,8 @@ def check_samplers_sampling_strategy_fit_resample(name, sampler):
240242
assert Counter(y_res)[1] == expected_stat
241243

242244

243-
def check_samplers_sparse(name, sampler):
245+
def check_samplers_sparse(name, sampler_orig):
246+
sampler = clone(sampler_orig)
244247
# check that sparse matrices can be passed through the sampler leading to
245248
# the same results than dense
246249
X, y = make_classification(
@@ -252,14 +255,16 @@ def check_samplers_sparse(name, sampler):
252255
)
253256
X_sparse = sparse.csr_matrix(X)
254257
X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
258+
sampler = clone(sampler)
255259
X_res, y_res = sampler.fit_resample(X, y)
256260
assert sparse.issparse(X_res_sparse)
257-
assert_allclose(X_res_sparse.A, X_res)
261+
assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
258262
assert_allclose(y_res_sparse, y_res)
259263

260264

261-
def check_samplers_pandas(name, sampler):
265+
def check_samplers_pandas(name, sampler_orig):
262266
pd = pytest.importorskip("pandas")
267+
sampler = clone(sampler_orig)
263268
# Check that the samplers handle pandas dataframe and pandas series
264269
X, y = make_classification(
265270
n_samples=1000,
@@ -290,7 +295,8 @@ def check_samplers_pandas(name, sampler):
290295
assert_allclose(y_res_s.to_numpy(), y_res)
291296

292297

293-
def check_samplers_list(name, sampler):
298+
def check_samplers_list(name, sampler_orig):
299+
sampler = clone(sampler_orig)
294300
# Check that the can samplers handle simple lists
295301
X, y = make_classification(
296302
n_samples=1000,
@@ -312,7 +318,8 @@ def check_samplers_list(name, sampler):
312318
assert_allclose(y_res, y_res_list)
313319

314320

315-
def check_samplers_multiclass_ova(name, sampler):
321+
def check_samplers_multiclass_ova(name, sampler_orig):
322+
sampler = clone(sampler_orig)
316323
# Check that multiclass target lead to the same results than OVA encoding
317324
X, y = make_classification(
318325
n_samples=1000,
@@ -329,7 +336,8 @@ def check_samplers_multiclass_ova(name, sampler):
329336
assert_allclose(y_res, y_res_ova.argmax(axis=1))
330337

331338

332-
def check_samplers_2d_target(name, sampler):
339+
def check_samplers_2d_target(name, sampler_orig):
340+
sampler = clone(sampler_orig)
333341
X, y = make_classification(
334342
n_samples=100,
335343
n_classes=3,
@@ -342,7 +350,8 @@ def check_samplers_2d_target(name, sampler):
342350
sampler.fit_resample(X, y)
343351

344352

345-
def check_samplers_preserve_dtype(name, sampler):
353+
def check_samplers_preserve_dtype(name, sampler_orig):
354+
sampler = clone(sampler_orig)
346355
X, y = make_classification(
347356
n_samples=1000,
348357
n_classes=3,
@@ -358,7 +367,8 @@ def check_samplers_preserve_dtype(name, sampler):
358367
assert y.dtype == y_res.dtype, "y dtype is not preserved"
359368

360369

361-
def check_samplers_sample_indices(name, sampler):
370+
def check_samplers_sample_indices(name, sampler_orig):
371+
sampler = clone(sampler_orig)
362372
X, y = make_classification(
363373
n_samples=1000,
364374
n_classes=3,
@@ -374,17 +384,21 @@ def check_samplers_sample_indices(name, sampler):
374384
assert not hasattr(sampler, "sample_indices_")
375385

376386

377-
def check_classifier_on_multilabel_or_multioutput_targets(name, estimator):
387+
def check_classifier_on_multilabel_or_multioutput_targets(
388+
name, estimator_orig
389+
):
390+
estimator = clone(estimator_orig)
378391
X, y = make_multilabel_classification(n_samples=30)
379392
msg = "Multilabel and multioutput targets are not supported."
380393
with pytest.raises(ValueError, match=msg):
381394
estimator.fit(X, y)
382395

383396

384-
def check_classifiers_with_encoded_labels(name, classifier):
397+
def check_classifiers_with_encoded_labels(name, classifier_orig):
385398
# Non-regression test for #709
386399
# https://github.com/scikit-learn-contrib/imbalanced-learn/issues/709
387400
pytest.importorskip("pandas")
401+
classifier = clone(classifier_orig)
388402
df, y = fetch_openml("iris", version=1, as_frame=True, return_X_y=True)
389403
df, y = make_imbalance(
390404
df, y, sampling_strategy={

0 commit comments

Comments
 (0)