diff --git a/doc/whats_new/v0.0.4.rst b/doc/whats_new/v0.0.4.rst index 2133e0af4..732f9c12e 100644 --- a/doc/whats_new/v0.0.4.rst +++ b/doc/whats_new/v0.0.4.rst @@ -74,7 +74,11 @@ Bug fixes - Fix bug which was not preserving the dtype of X and y when generating samples. - issue:`448` by :user:`Guillaume Lemaitre <glemaitre>`. + :issue:`448` by :user:`Guillaume Lemaitre <glemaitre>`. + +- Add the option to pass a ``Memory`` object to :func:`make_pipeline` like + in :class:`pipeline.Pipeline` class. + :issue:`458` by :user:`Christos Aridas <chkoar>`. Maintenance ........... @@ -117,4 +121,4 @@ Deprecation - Deprecate :class:`imblearn.ensemble.EasyEnsemble` in favor of meta-estimator :class:`imblearn.ensemble.EasyEnsembleClassifier` which follow the exact algorithm described in the literature. - :issue:`455` by :user:`Guillaume Lemaitre <glemaitre>`. \ No newline at end of file + :issue:`455` by :user:`Guillaume Lemaitre <glemaitre>`. diff --git a/imblearn/pipeline.py b/imblearn/pipeline.py index 90af7ab6c..ecb4c8b6c 100644 --- a/imblearn/pipeline.py +++ b/imblearn/pipeline.py @@ -600,15 +600,50 @@ def _fit_sample_one(sampler, X, y, **fit_params): return X_res, y_res, sampler -def make_pipeline(*steps): +def make_pipeline(*steps, **kwargs): """Construct a Pipeline from the given estimators. This is a shorthand for the Pipeline constructor; it does not require, and does not permit, naming the estimators. Instead, their names will be set to the lowercase of their types automatically. + Parameters + ---------- + *steps : list of estimators. + + memory : None, str or object with the joblib.Memory interface, optional + Used to cache the fitted transformers of the pipeline. By default, + no caching is performed. If a string is given, it is the path to + the caching directory. Enabling caching triggers a clone of + the transformers before fitting. Therefore, the transformer + instance given to the pipeline cannot be inspected + directly. 
Use the attribute ``named_steps`` or ``steps`` to + inspect estimators within the pipeline. Caching the + transformers is advantageous when fitting is time consuming. + Returns ------- p : Pipeline + + See also + -------- + imblearn.pipeline.Pipeline : Class for creating a pipeline of + transforms with a final estimator. + + Examples + -------- + >>> from sklearn.naive_bayes import GaussianNB + >>> from sklearn.preprocessing import StandardScaler + >>> make_pipeline(StandardScaler(), GaussianNB(priors=None)) + ... # doctest: +NORMALIZE_WHITESPACE + Pipeline(memory=None, + steps=[('standardscaler', + StandardScaler(copy=True, with_mean=True, with_std=True)), + ('gaussiannb', + GaussianNB(priors=None))]) """ - return Pipeline(pipeline._name_estimators(steps)) + memory = kwargs.pop('memory', None) + if kwargs: + raise TypeError('Unknown keyword arguments: "{}"' + .format(list(kwargs.keys())[0])) + return Pipeline(pipeline._name_estimators(steps), memory=memory) diff --git a/imblearn/tests/test_pipeline.py b/imblearn/tests/test_pipeline.py index 97b65c4ac..db11504ed 100644 --- a/imblearn/tests/test_pipeline.py +++ b/imblearn/tests/test_pipeline.py @@ -12,6 +12,7 @@ import numpy as np from pytest import raises +from sklearn.utils.testing import assert_true from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_array_almost_equal from sklearn.utils.testing import assert_allclose @@ -31,6 +32,7 @@ from imblearn.under_sampling import (RandomUnderSampler, EditedNearestNeighbours as ENN) + JUNK_FOOD_DOCS = ( "the pizza pizza beer copyright", "the pizza burger beer copyright", @@ -1073,3 +1075,15 @@ def test_pipeline_fit_then_sample_3_samplers_with_sampler_last_estimator(): X_fit_then_sample_res, y_fit_then_sample_res = pipeline.sample(X, y) assert_array_equal(X_fit_sample_resampled, X_fit_then_sample_res) assert_array_equal(y_fit_sample_resampled, y_fit_then_sample_res) + + +def test_make_pipeline_memory(): + cachedir = mkdtemp() + try: 
+ memory = Memory(cachedir=cachedir, verbose=10) + pipeline = make_pipeline(DummyTransf(), SVC(), memory=memory) + assert pipeline.memory is memory + pipeline = make_pipeline(DummyTransf(), SVC()) + assert pipeline.memory is None + finally: + shutil.rmtree(cachedir)