 from sklearn.cluster import KMeans
 from sklearn.exceptions import SkipTestWarning
 from sklearn.preprocessing import label_binarize
-from sklearn.utils.estimator_checks import _mark_xfail_checks
-from sklearn.utils.estimator_checks import _set_check_estimator_ids
+from sklearn.utils.estimator_checks import _maybe_mark_xfail
+from sklearn.utils.estimator_checks import _get_check_estimator_ids
 from sklearn.utils._testing import assert_allclose
 from sklearn.utils._testing import assert_raises_regex
 from sklearn.utils.multiclass import type_of_target
@@ -44,7 +44,7 @@ def _set_checking_parameters(estimator):
     if name == "ClusterCentroids":
         estimator.set_params(
             voting="soft",
-            estimator=KMeans(random_state=0, algorithm="full"),
+            estimator=KMeans(random_state=0, algorithm="full", n_init=1),
         )
     if name == "KMeansSMOTE":
         estimator.set_params(kmeans_estimator=12)
@@ -117,21 +117,19 @@ def parametrize_with_checks(estimators):
     ...     def test_sklearn_compatible_estimator(estimator, check):
     ...         check(estimator)
     """
-    names = (type(estimator).__name__ for estimator in estimators)
+    def checks_generator():
+        for estimator in estimators:
+            name = type(estimator).__name__
+            for check in _yield_all_checks(estimator):
+                check = partial(check, name)
+                yield _maybe_mark_xfail(estimator, check, pytest)

-    checks_generator = ((clone(estimator), partial(check, name))
-                        for name, estimator in zip(names, estimators)
-                        for check in _yield_all_checks(estimator))
+    return pytest.mark.parametrize("estimator, check", checks_generator(),
+                                   ids=_get_check_estimator_ids)

-    checks_with_marks = (
-        _mark_xfail_checks(estimator, check, pytest)
-        for estimator, check in checks_generator)

-    return pytest.mark.parametrize("estimator, check", checks_with_marks,
-                                   ids=_set_check_estimator_ids)
-
-
-def check_target_type(name, estimator):
+def check_target_type(name, estimator_orig):
+    estimator = clone(estimator_orig)
     # should raise warning if the target is continuous (we cannot raise error)
     X = np.random.random((20, 2))
     y = np.linspace(0, 1, 20)
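For context (not part of the patch): the rewrite above works because `pytest.mark.parametrize` accepts any iterable of argument tuples together with an `ids` callable, so the `(estimator, check)` pairs can be produced lazily and, via `_maybe_mark_xfail`, individually marked as expected failures when an estimator's tags declare a check as xfail. A minimal standalone sketch of that mechanism, with purely illustrative names:

```python
import pytest


def _cases():
    # Lazily yield (name, check) pairs, mirroring the generator-based
    # construction used by parametrize_with_checks above.
    for est in ("RandomOverSampler", "SMOTE"):
        for check in ("check_fit", "check_sparse"):
            yield est, check


@pytest.mark.parametrize("est, check", _cases(), ids=str)
def test_case(est, check):
    # Each yielded pair becomes one collected test case with a readable id.
    assert isinstance(est, str)
    assert isinstance(check, str)
```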
@@ -148,7 +146,8 @@ def check_target_type(name, estimator):
     )


-def check_samplers_one_label(name, sampler):
+def check_samplers_one_label(name, sampler_orig):
+    sampler = clone(sampler_orig)
     error_string_fit = "Sampler can't balance when only one class is present."
     X = np.random.random((20, 2))
     y = np.zeros(20)
@@ -168,7 +167,8 @@ def check_samplers_one_label(name, sampler):
         raise AssertionError(error_string_fit)


-def check_samplers_fit(name, sampler):
+def check_samplers_fit(name, sampler_orig):
+    sampler = clone(sampler_orig)
     np.random.seed(42)  # Make this test reproducible
     X = np.random.random((30, 2))
     y = np.array([1] * 20 + [0] * 10)
@@ -178,7 +178,8 @@ def check_samplers_fit(name, sampler):
     ), "No fitted attribute sampling_strategy_"


-def check_samplers_fit_resample(name, sampler):
+def check_samplers_fit_resample(name, sampler_orig):
+    sampler = clone(sampler_orig)
     X, y = make_classification(
         n_samples=1000,
         n_classes=3,
@@ -213,7 +214,8 @@ def check_samplers_fit_resample(name, sampler):
         )


-def check_samplers_sampling_strategy_fit_resample(name, sampler):
+def check_samplers_sampling_strategy_fit_resample(name, sampler_orig):
+    sampler = clone(sampler_orig)
     # in this test we will force all samplers to not change the class 1
     X, y = make_classification(
         n_samples=1000,
@@ -240,7 +242,8 @@ def check_samplers_sampling_strategy_fit_resample(name, sampler):
     assert Counter(y_res)[1] == expected_stat


-def check_samplers_sparse(name, sampler):
+def check_samplers_sparse(name, sampler_orig):
+    sampler = clone(sampler_orig)
     # check that sparse matrices can be passed through the sampler leading to
     # the same results than dense
     X, y = make_classification(
@@ -252,14 +255,16 @@ def check_samplers_sparse(name, sampler):
     )
     X_sparse = sparse.csr_matrix(X)
     X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
+    sampler = clone(sampler)
     X_res, y_res = sampler.fit_resample(X, y)
     assert sparse.issparse(X_res_sparse)
-    assert_allclose(X_res_sparse.A, X_res)
+    assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
     assert_allclose(y_res_sparse, y_res)


-def check_samplers_pandas(name, sampler):
+def check_samplers_pandas(name, sampler_orig):
     pd = pytest.importorskip("pandas")
+    sampler = clone(sampler_orig)
     # Check that the samplers handle pandas dataframe and pandas series
     X, y = make_classification(
         n_samples=1000,
@@ -290,7 +295,8 @@ def check_samplers_pandas(name, sampler):
     assert_allclose(y_res_s.to_numpy(), y_res)


-def check_samplers_list(name, sampler):
+def check_samplers_list(name, sampler_orig):
+    sampler = clone(sampler_orig)
     # Check that the can samplers handle simple lists
     X, y = make_classification(
         n_samples=1000,
@@ -312,7 +318,8 @@ def check_samplers_list(name, sampler):
     assert_allclose(y_res, y_res_list)


-def check_samplers_multiclass_ova(name, sampler):
+def check_samplers_multiclass_ova(name, sampler_orig):
+    sampler = clone(sampler_orig)
     # Check that multiclass target lead to the same results than OVA encoding
     X, y = make_classification(
         n_samples=1000,
@@ -329,7 +336,8 @@ def check_samplers_multiclass_ova(name, sampler):
     assert_allclose(y_res, y_res_ova.argmax(axis=1))


-def check_samplers_2d_target(name, sampler):
+def check_samplers_2d_target(name, sampler_orig):
+    sampler = clone(sampler_orig)
     X, y = make_classification(
         n_samples=100,
         n_classes=3,
@@ -342,7 +350,8 @@ def check_samplers_2d_target(name, sampler):
     sampler.fit_resample(X, y)


-def check_samplers_preserve_dtype(name, sampler):
+def check_samplers_preserve_dtype(name, sampler_orig):
+    sampler = clone(sampler_orig)
     X, y = make_classification(
         n_samples=1000,
         n_classes=3,
@@ -358,7 +367,8 @@ def check_samplers_preserve_dtype(name, sampler):
     assert y.dtype == y_res.dtype, "y dtype is not preserved"


-def check_samplers_sample_indices(name, sampler):
+def check_samplers_sample_indices(name, sampler_orig):
+    sampler = clone(sampler_orig)
     X, y = make_classification(
         n_samples=1000,
         n_classes=3,
@@ -374,17 +384,21 @@ def check_samplers_sample_indices(name, sampler):
     assert not hasattr(sampler, "sample_indices_")


-def check_classifier_on_multilabel_or_multioutput_targets(name, estimator):
+def check_classifier_on_multilabel_or_multioutput_targets(
+    name, estimator_orig
+):
+    estimator = clone(estimator_orig)
     X, y = make_multilabel_classification(n_samples=30)
     msg = "Multilabel and multioutput targets are not supported."
     with pytest.raises(ValueError, match=msg):
         estimator.fit(X, y)


-def check_classifiers_with_encoded_labels(name, classifier):
+def check_classifiers_with_encoded_labels(name, classifier_orig):
     # Non-regression test for #709
     # https://github.com/scikit-learn-contrib/imbalanced-learn/issues/709
     pytest.importorskip("pandas")
+    classifier = clone(classifier_orig)
     df, y = fetch_openml("iris", version=1, as_frame=True, return_X_y=True)
     df, y = make_imbalance(
         df, y, sampling_strategy={
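A minimal sketch (not part of the patch) of how the refactored decorator is consumed by a downstream test module, assuming imblearn's `parametrize_with_checks` import path and an illustrative estimator list:

```python
from imblearn.over_sampling import RandomOverSampler
from imblearn.utils.estimator_checks import parametrize_with_checks


@parametrize_with_checks([RandomOverSampler(random_state=0)])
def test_sklearn_compatible_estimator(estimator, check):
    # One test case per (estimator, check) pair yielded by the decorator;
    # each check clones the estimator it receives, so cases stay independent.
    check(estimator)
```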