From fb5a498d0bd00fc2b42fbd19b6ef18e1dfeee47e Mon Sep 17 00:00:00 2001 From: RAKOTOARISON Herilalaina <rkt.herilalaina@gmail.com> Date: Thu, 30 Mar 2017 14:21:51 +0200 Subject: [PATCH] [MRG+1] Change named_steps to Bunch object (#8586) * Change named_steps to Bunch object * Update named_steps attribute documentation * Add test for named steps bunch object * Delete whitespace in test_pipeline * Update test_pipeline.py * Add comment for named_steps usage * Move dataset/Bunch to utils * Fix to PEP8 format * Add __getattribute method to Bunch class, Fix pep8 bug * Remove __getattribute__, update test_pipeline * Update test with conflict and non-conflict named_steps * Add reference to class Pipeline --- doc/modules/pipeline.rst | 5 +++ doc/whats_new.rst | 6 +++ sklearn/datasets/base.py | 47 +---------------------- sklearn/datasets/california_housing.py | 3 +- sklearn/datasets/covtype.py | 2 +- sklearn/datasets/kddcup99.py | 2 +- sklearn/datasets/lfw.py | 3 +- sklearn/datasets/mldata.py | 3 +- sklearn/datasets/olivetti_faces.py | 10 +++-- sklearn/datasets/rcv1.py | 2 +- sklearn/datasets/species_distributions.py | 3 +- sklearn/datasets/twenty_newsgroups.py | 3 +- sklearn/pipeline.py | 13 +++++-- sklearn/tests/test_pipeline.py | 17 ++++++++ sklearn/utils/__init__.py | 46 ++++++++++++++++++++++ 15 files changed, 103 insertions(+), 62 deletions(-) diff --git a/doc/modules/pipeline.rst b/doc/modules/pipeline.rst index a48164b094..c90f35753f 100644 --- a/doc/modules/pipeline.rst +++ b/doc/modules/pipeline.rst @@ -79,6 +79,11 @@ Parameters of the estimators in the pipeline can be accessed using the steps=[('reduce_dim', PCA(copy=True, iterated_power='auto',...)), ('clf', SVC(C=10, cache_size=200, class_weight=None,...))]) +Attributes of named_steps map to keys, enabling tab completion in interactive environments:: + + >>> pipe.named_steps.reduce_dim is pipe.named_steps['reduce_dim'] + True + This is particularly important for doing grid searches:: >>> from sklearn.model_selection import GridSearchCV diff --git a/doc/whats_new.rst b/doc/whats_new.rst index df096d1fae..f2189ba26c 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -277,6 +277,12 @@ API changes summary needed for the perplexity calculation. :issue:`7954` by :user:`Gary Foreman <garyForeman>`. + - Replace attribute ``named_steps`` ``dict`` to :class:`sklearn.utils.Bunch` + in :class:`sklearn.pipeline.Pipeline` to enable tab completion in interactive + environment. In the case conflict value on ``named_steps`` and ``dict`` + attribute, ``dict`` behavior will be prioritized. + :issue:`8481` by :user:`Herilalaina Rakotoarison <herilalaina>`. + - The :func:`sklearn.multioutput.MultiOutputClassifier.predict_proba` function used to return a 3d array (``n_samples``, ``n_classes``, ``n_outputs``). In the case where different target columns had different diff --git a/sklearn/datasets/base.py b/sklearn/datasets/base.py index 2325d97142..2ad2bdb16c 100644 --- a/sklearn/datasets/base.py +++ b/sklearn/datasets/base.py @@ -20,58 +20,13 @@ from os.path import isdir from os.path import splitext from os import listdir from os import makedirs +from ..utils import Bunch import numpy as np from ..utils import check_random_state -class Bunch(dict): - """Container object for datasets - - Dictionary-like object that exposes its keys as attributes. - - >>> b = Bunch(a=1, b=2) - >>> b['b'] - 2 - >>> b.b - 2 - >>> b.a = 3 - >>> b['a'] - 3 - >>> b.c = 6 - >>> b['c'] - 6 - - """ - - def __init__(self, **kwargs): - super(Bunch, self).__init__(kwargs) - - def __setattr__(self, key, value): - self[key] = value - - def __dir__(self): - return self.keys() - - def __getattr__(self, key): - try: - return self[key] - except KeyError: - raise AttributeError(key) - - def __setstate__(self, state): - # Bunch pickles generated with scikit-learn 0.16.* have an non - # empty __dict__. This causes a surprising behaviour when - # loading these pickles scikit-learn 0.17: reading bunch.key - # uses __dict__ but assigning to bunch.key use __setattr__ and - # only changes bunch['key']. More details can be found at: - # https://github.com/scikit-learn/scikit-learn/issues/6196. - # Overriding __setstate__ to be a noop has the effect of - # ignoring the pickled __dict__ - pass - - def get_data_home(data_home=None): """Return the path of the scikit-learn data dir. diff --git a/sklearn/datasets/california_housing.py b/sklearn/datasets/california_housing.py index 8a74ad9e60..8db5e3139d 100644 --- a/sklearn/datasets/california_housing.py +++ b/sklearn/datasets/california_housing.py @@ -35,7 +35,8 @@ except ImportError: import numpy as np -from .base import get_data_home, Bunch +from .base import get_data_home +from ..utils import Bunch from .base import _pkl_filepath from ..externals import joblib diff --git a/sklearn/datasets/covtype.py b/sklearn/datasets/covtype.py index 6e0b4d2d0d..a165f7c054 100644 --- a/sklearn/datasets/covtype.py +++ b/sklearn/datasets/covtype.py @@ -26,7 +26,7 @@ except ImportError: import numpy as np from .base import get_data_home -from .base import Bunch +from ..utils import Bunch from .base import _pkl_filepath from ..utils.fixes import makedirs from ..externals import joblib diff --git a/sklearn/datasets/kddcup99.py b/sklearn/datasets/kddcup99.py index c2ed39caa1..4be96e4560 100644 --- a/sklearn/datasets/kddcup99.py +++ b/sklearn/datasets/kddcup99.py @@ -23,7 +23,7 @@ except ImportError: import numpy as np from .base import get_data_home -from .base import Bunch +from ..utils import Bunch from ..externals import joblib, six from ..utils import check_random_state from ..utils import shuffle as shuffle_method diff --git a/sklearn/datasets/lfw.py b/sklearn/datasets/lfw.py index 13aaed805b..e3406f9e3c 100644 --- a/sklearn/datasets/lfw.py +++ b/sklearn/datasets/lfw.py @@ -34,7 +34,8 @@ try: except ImportError: import urllib -from .base import get_data_home, Bunch +from .base import get_data_home +from ..utils import Bunch from ..externals.joblib import Memory from ..externals.six import b diff --git a/sklearn/datasets/mldata.py b/sklearn/datasets/mldata.py index 82ae9858e9..f5377f203e 100644 --- a/sklearn/datasets/mldata.py +++ b/sklearn/datasets/mldata.py @@ -23,7 +23,8 @@ import scipy as sp from scipy import io from shutil import copyfileobj -from .base import get_data_home, Bunch +from .base import get_data_home +from ..utils import Bunch MLDATA_BASE_URL = "http://mldata.org/repository/data/download/matlab/%s" diff --git a/sklearn/datasets/olivetti_faces.py b/sklearn/datasets/olivetti_faces.py index 5f3af040dc..9ecab18c0e 100644 --- a/sklearn/datasets/olivetti_faces.py +++ b/sklearn/datasets/olivetti_faces.py @@ -37,9 +37,9 @@ except ImportError: import numpy as np from scipy.io.matlab import loadmat -from .base import get_data_home, Bunch +from .base import get_data_home from .base import _pkl_filepath -from ..utils import check_random_state +from ..utils import check_random_state, Bunch from ..externals import joblib @@ -80,10 +80,12 @@ def fetch_olivetti_faces(data_home=None, shuffle=False, random_state=0, An object with the following attributes: data : numpy array of shape (400, 4096) - Each row corresponds to a ravelled face image of original size 64 x 64 pixels. + Each row corresponds to a ravelled face image of original size + 64 x 64 pixels. images : numpy array of shape (400, 64, 64) - Each row is a face image corresponding to one of the 40 subjects of the dataset. + Each row is a face image corresponding to one of the 40 subjects + of the dataset. target : numpy array of shape (400, ) Labels associated to each face image. Those labels are ranging from diff --git a/sklearn/datasets/rcv1.py b/sklearn/datasets/rcv1.py index 83b4d223cc..ae45764b40 100644 --- a/sklearn/datasets/rcv1.py +++ b/sklearn/datasets/rcv1.py @@ -20,12 +20,12 @@ import numpy as np import scipy.sparse as sp from .base import get_data_home -from .base import Bunch from .base import _pkl_filepath from ..utils.fixes import makedirs from ..externals import joblib from .svmlight_format import load_svmlight_files from ..utils import shuffle as shuffle_ +from ..utils import Bunch URL = ('http://jmlr.csail.mit.edu/papers/volume5/lewis04a/' diff --git a/sklearn/datasets/species_distributions.py b/sklearn/datasets/species_distributions.py index 330c535620..f34eb92d33 100644 --- a/sklearn/datasets/species_distributions.py +++ b/sklearn/datasets/species_distributions.py @@ -50,7 +50,8 @@ except ImportError: import numpy as np -from sklearn.datasets.base import get_data_home, Bunch +from sklearn.datasets.base import get_data_home +from ..utils import Bunch from sklearn.datasets.base import _pkl_filepath from sklearn.externals import joblib diff --git a/sklearn/datasets/twenty_newsgroups.py b/sklearn/datasets/twenty_newsgroups.py index 128610fd28..47b543d8d2 100644 --- a/sklearn/datasets/twenty_newsgroups.py +++ b/sklearn/datasets/twenty_newsgroups.py @@ -47,10 +47,9 @@ import numpy as np import scipy.sparse as sp from .base import get_data_home -from .base import Bunch from .base import load_files from .base import _pkl_filepath -from ..utils import check_random_state +from ..utils import check_random_state, Bunch from ..feature_extraction.text import CountVectorizer from ..preprocessing import normalize from ..externals import joblib, six diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 61d7b12b75..6dfc7284cc 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -20,6 +20,7 @@ from .externals.joblib import Parallel, delayed, Memory from .externals import six from .utils import tosequence from .utils.metaestimators import if_delegate_has_method +from .utils import Bunch __all__ = ['Pipeline', 'FeatureUnion'] @@ -122,7 +123,7 @@ class Pipeline(_BasePipeline): Attributes ---------- - named_steps : dict + named_steps : bunch object, a dictionary with attribute access Read-only attribute to access any step parameter by user given name. Keys are step names and values are steps parameters. @@ -157,7 +158,12 @@ class Pipeline(_BasePipeline): array([False, False, True, True, False, False, True, True, False, True, False, True, True, False, True, False, True, True, False, False], dtype=bool) - + >>> # Another way to get selected features chosen by anova_filter + >>> anova_svm.named_steps.anova.get_support() + ... # doctest: +NORMALIZE_WHITESPACE + array([False, False, True, True, False, False, True, True, False, + True, False, True, True, False, True, False, True, True, + False, False], dtype=bool) """ # BaseEstimator interface @@ -227,7 +233,8 @@ class Pipeline(_BasePipeline): @property def named_steps(self): - return dict(self.steps) + # Use Bunch object to improve autocomplete + return Bunch(**dict(self.steps)) @property def _final_estimator(self): diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index 33e3128931..d4c4844fe3 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -509,6 +509,23 @@ def test_set_pipeline_steps(): assert_raises(TypeError, pipeline.fit_transform, [[1]], [1]) +def test_pipeline_named_steps(): + transf = Transf() + mult2 = Mult(mult=2) + pipeline = Pipeline([('mock', transf), ("mult", mult2)]) + + # Test access via named_steps bunch object + assert_true('mock' in pipeline.named_steps) + assert_true('mock2' not in pipeline.named_steps) + assert_true(pipeline.named_steps.mock is transf) + assert_true(pipeline.named_steps.mult is mult2) + + # Test bunch with conflict attribute of dict + pipeline = Pipeline([('values', transf), ("mult", mult2)]) + assert_true(pipeline.named_steps.values is not transf) + assert_true(pipeline.named_steps.mult is mult2) + + def test_set_pipeline_step_none(): # Test setting Pipeline steps to None X = np.array([[1]]) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index a4e5b6a4f3..0bc4d6de33 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -28,6 +28,52 @@ __all__ = ["murmurhash3_32", "as_float_array", "check_symmetric", "indices_to_mask", "deprecated"] +class Bunch(dict): + """Container object for datasets + + Dictionary-like object that exposes its keys as attributes. + + >>> b = Bunch(a=1, b=2) + >>> b['b'] + 2 + >>> b.b + 2 + >>> b.a = 3 + >>> b['a'] + 3 + >>> b.c = 6 + >>> b['c'] + 6 + + """ + + def __init__(self, **kwargs): + super(Bunch, self).__init__(kwargs) + + def __setattr__(self, key, value): + self[key] = value + + def __dir__(self): + return self.keys() + + def __getattr__(self, key): + try: + return self[key] + except KeyError: + raise AttributeError(key) + + def __setstate__(self, state): + # Bunch pickles generated with scikit-learn 0.16.* have an non + # empty __dict__. This causes a surprising behaviour when + # loading these pickles scikit-learn 0.17: reading bunch.key + # uses __dict__ but assigning to bunch.key use __setattr__ and + # only changes bunch['key']. More details can be found at: + # https://github.com/scikit-learn/scikit-learn/issues/6196. + # Overriding __setstate__ to be a noop has the effect of + # ignoring the pickled __dict__ + pass + + def safe_mask(X, mask): """Return a mask which is safe to use on X. -- GitLab