Skip to content

Commit 10c4196

Browse files
chkoarglemaitre
authored andcommitted
BUG: Add memory to make_pipeline function (#458)
1 parent 7b704ea commit 10c4196

File tree

3 files changed

+57
-4
lines changed

3 files changed

+57
-4
lines changed

doc/whats_new/v0.0.4.rst

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,11 @@ Bug fixes
7474

7575
- Fix bug which was not preserving the dtype of X and y when generating
7676
samples.
77-
issue:`448` by :user:`Guillaume Lemaitre <glemaitre>`.
77+
:issue:`448` by :user:`Guillaume Lemaitre <glemaitre>`.
78+
79+
- Add the option to pass a ``Memory`` object to :func:`make_pipeline` like
80+
in :class:`pipeline.Pipeline` class.
81+
:issue:`458` by :user:`Christos Aridas <chkoar>`.
7882

7983
Maintenance
8084
...........
@@ -117,4 +121,4 @@ Deprecation
117121
- Deprecate :class:`imblearn.ensemble.EasyEnsemble` in favor of meta-estimator
118122
:class:`imblearn.ensemble.EasyEnsembleClassifier` which follow the exact
119123
algorithm described in the literature.
120-
:issue:`455` by :user:`Guillaume Lemaitre <glemaitre>`.
124+
:issue:`455` by :user:`Guillaume Lemaitre <glemaitre>`.

imblearn/pipeline.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -600,15 +600,50 @@ def _fit_sample_one(sampler, X, y, **fit_params):
600600
return X_res, y_res, sampler
601601

602602

603-
def make_pipeline(*steps):
603+
def make_pipeline(*steps, **kwargs):
604604
"""Construct a Pipeline from the given estimators.
605605
606606
This is a shorthand for the Pipeline constructor; it does not require, and
607607
does not permit, naming the estimators. Instead, their names will be set
608608
to the lowercase of their types automatically.
609609
610+
Parameters
611+
----------
612+
*steps : list of estimators.
613+
614+
memory : None, str or object with the joblib.Memory interface, optional
615+
Used to cache the fitted transformers of the pipeline. By default,
616+
no caching is performed. If a string is given, it is the path to
617+
the caching directory. Enabling caching triggers a clone of
618+
the transformers before fitting. Therefore, the transformer
619+
instance given to the pipeline cannot be inspected
620+
directly. Use the attribute ``named_steps`` or ``steps`` to
621+
inspect estimators within the pipeline. Caching the
622+
transformers is advantageous when fitting is time consuming.
623+
610624
Returns
611625
-------
612626
p : Pipeline
627+
628+
See also
629+
--------
630+
imblearn.pipeline.Pipeline : Class for creating a pipeline of
631+
transforms with a final estimator.
632+
633+
Examples
634+
--------
635+
>>> from sklearn.naive_bayes import GaussianNB
636+
>>> from sklearn.preprocessing import StandardScaler
637+
>>> make_pipeline(StandardScaler(), GaussianNB(priors=None))
638+
... # doctest: +NORMALIZE_WHITESPACE
639+
Pipeline(memory=None,
640+
steps=[('standardscaler',
641+
StandardScaler(copy=True, with_mean=True, with_std=True)),
642+
('gaussiannb',
643+
GaussianNB(priors=None))])
613644
"""
614-
return Pipeline(pipeline._name_estimators(steps))
645+
memory = kwargs.pop('memory', None)
646+
if kwargs:
647+
raise TypeError('Unknown keyword arguments: "{}"'
648+
.format(list(kwargs.keys())[0]))
649+
return Pipeline(pipeline._name_estimators(steps), memory=memory)

imblearn/tests/test_pipeline.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import numpy as np
1313
from pytest import raises
1414

15+
from sklearn.utils.testing import assert_true
1516
from sklearn.utils.testing import assert_array_equal
1617
from sklearn.utils.testing import assert_array_almost_equal
1718
from sklearn.utils.testing import assert_allclose
@@ -31,6 +32,7 @@
3132
from imblearn.under_sampling import (RandomUnderSampler,
3233
EditedNearestNeighbours as ENN)
3334

35+
3436
JUNK_FOOD_DOCS = (
3537
"the pizza pizza beer copyright",
3638
"the pizza burger beer copyright",
@@ -1073,3 +1075,15 @@ def test_pipeline_fit_then_sample_3_samplers_with_sampler_last_estimator():
10731075
X_fit_then_sample_res, y_fit_then_sample_res = pipeline.sample(X, y)
10741076
assert_array_equal(X_fit_sample_resampled, X_fit_then_sample_res)
10751077
assert_array_equal(y_fit_sample_resampled, y_fit_then_sample_res)
1078+
1079+
1080+
def test_make_pipeline_memory():
1081+
cachedir = mkdtemp()
1082+
try:
1083+
memory = Memory(cachedir=cachedir, verbose=10)
1084+
pipeline = make_pipeline(DummyTransf(), SVC(), memory=memory)
1085+
assert pipeline.memory is memory
1086+
pipeline = make_pipeline(DummyTransf(), SVC())
1087+
assert pipeline.memory is None
1088+
finally:
1089+
shutil.rmtree(cachedir)

0 commit comments

Comments
 (0)