diff --git a/doc/modules/pipeline.rst b/doc/modules/pipeline.rst
index a48164b09470e2287c3972ee63a61291575234aa..c90f35753fb0055219df7c558e9510c969d0f504 100644
--- a/doc/modules/pipeline.rst
+++ b/doc/modules/pipeline.rst
@@ -79,6 +79,12 @@ Parameters of the estimators in the pipeline can be accessed using the
     steps=[('reduce_dim', PCA(copy=True, iterated_power='auto',...)),
            ('clf', SVC(C=10, cache_size=200, class_weight=None,...))])
 
+The ``named_steps`` attribute maps step names to steps and supports attribute
+access, enabling tab completion in interactive environments::
+
+    >>> pipe.named_steps.reduce_dim is pipe.named_steps['reduce_dim']
+    True
+
 This is particularly important for doing grid searches::
 
     >>> from sklearn.model_selection import GridSearchCV
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index df096d1faec4203b436cb5dce205f6ececb597e9..f2189ba26ca007cac243fedc284fa2757527a5f6 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -277,6 +277,13 @@ API changes summary
      needed for the perplexity calculation. :issue:`7954` by
      :user:`Gary Foreman <garyForeman>`.
 
+   - The ``named_steps`` attribute of :class:`sklearn.pipeline.Pipeline` now
+     returns a :class:`sklearn.utils.Bunch` (a ``dict`` subclass) instead of a
+     plain ``dict``, enabling tab completion of step names in interactive
+     environments. If a step name clashes with an existing ``dict`` attribute,
+     the ``dict`` behavior takes precedence.
+     :issue:`8481` by :user:`Herilalaina Rakotoarison <herilalaina>`.
+
    - The :func:`sklearn.multioutput.MultiOutputClassifier.predict_proba`
      function used to return a 3d array (``n_samples``, ``n_classes``,
      ``n_outputs``). In the case where different target columns had different
diff --git a/sklearn/datasets/base.py b/sklearn/datasets/base.py
index 2325d971428d2f74f5a2708c3b24142f81ab2b19..2ad2bdb16cbfad129997a586d2d41e546a3a1a85 100644
--- a/sklearn/datasets/base.py
+++ b/sklearn/datasets/base.py
@@ -20,58 +20,13 @@ from os.path import isdir
 from os.path import splitext
 from os import listdir
 from os import makedirs
+from ..utils import Bunch
 
 import numpy as np
 
 from ..utils import check_random_state
 
 
-class Bunch(dict):
-    """Container object for datasets
-
-    Dictionary-like object that exposes its keys as attributes.
-
-    >>> b = Bunch(a=1, b=2)
-    >>> b['b']
-    2
-    >>> b.b
-    2
-    >>> b.a = 3
-    >>> b['a']
-    3
-    >>> b.c = 6
-    >>> b['c']
-    6
-
-    """
-
-    def __init__(self, **kwargs):
-        super(Bunch, self).__init__(kwargs)
-
-    def __setattr__(self, key, value):
-        self[key] = value
-
-    def __dir__(self):
-        return self.keys()
-
-    def __getattr__(self, key):
-        try:
-            return self[key]
-        except KeyError:
-            raise AttributeError(key)
-
-    def __setstate__(self, state):
-        # Bunch pickles generated with scikit-learn 0.16.* have an non
-        # empty __dict__. This causes a surprising behaviour when
-        # loading these pickles scikit-learn 0.17: reading bunch.key
-        # uses __dict__ but assigning to bunch.key use __setattr__ and
-        # only changes bunch['key']. More details can be found at:
-        # https://github.com/scikit-learn/scikit-learn/issues/6196.
-        # Overriding __setstate__ to be a noop has the effect of
-        # ignoring the pickled __dict__
-        pass
-
-
 def get_data_home(data_home=None):
     """Return the path of the scikit-learn data dir.
 
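The precedence rule described in the ``whats_new`` entry above follows from how
``Bunch`` resolves attributes: ``__getattr__`` is only consulted when ordinary
attribute lookup fails, so ``dict`` methods shadow steps of the same name while
item access is unaffected. A minimal sketch, assuming the patch above is
applied; the step names ``keys`` and ``svc`` are chosen purely to illustrate
the collision::

    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.svm import SVC

    pipe = Pipeline([('keys', StandardScaler()), ('svc', SVC())])

    # Non-conflicting names resolve to the step itself.
    assert pipe.named_steps.svc is pipe.named_steps['svc']

    # 'keys' collides with dict.keys: attribute access keeps the dict method,
    # while item access still returns the step.
    assert callable(pipe.named_steps.keys)
    assert isinstance(pipe.named_steps['keys'], StandardScaler)

This mirrors the ``values`` case exercised in ``test_pipeline_named_steps``
further down.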
diff --git a/sklearn/datasets/california_housing.py b/sklearn/datasets/california_housing.py
index 8a74ad9e60e35275ee28e01dd9a02c98754cffb7..8db5e3139d159c7a79cd10a1e70338d79f1447a9 100644
--- a/sklearn/datasets/california_housing.py
+++ b/sklearn/datasets/california_housing.py
@@ -35,7 +35,8 @@ except ImportError:
 
 import numpy as np
 
-from .base import get_data_home, Bunch
+from .base import get_data_home
+from ..utils import Bunch
 from .base import _pkl_filepath
 from ..externals import joblib
 
diff --git a/sklearn/datasets/covtype.py b/sklearn/datasets/covtype.py
index 6e0b4d2d0d21c739093d6debb36f008c559cddfd..a165f7c054c62ae135c7f0b243dde92c0b97be8b 100644
--- a/sklearn/datasets/covtype.py
+++ b/sklearn/datasets/covtype.py
@@ -26,7 +26,7 @@ except ImportError:
 import numpy as np
 
 from .base import get_data_home
-from .base import Bunch
+from ..utils import Bunch
 from .base import _pkl_filepath
 from ..utils.fixes import makedirs
 from ..externals import joblib
diff --git a/sklearn/datasets/kddcup99.py b/sklearn/datasets/kddcup99.py
index c2ed39caa10a6d51466be44e049098e20840e466..4be96e45605f9e6458ac415e4922b17e725eea2b 100644
--- a/sklearn/datasets/kddcup99.py
+++ b/sklearn/datasets/kddcup99.py
@@ -23,7 +23,7 @@ except ImportError:
 import numpy as np
 
 from .base import get_data_home
-from .base import Bunch
+from ..utils import Bunch
 from ..externals import joblib, six
 from ..utils import check_random_state
 from ..utils import shuffle as shuffle_method
diff --git a/sklearn/datasets/lfw.py b/sklearn/datasets/lfw.py
index 13aaed805b4fbff6c15fc86cd194b66454b55d45..e3406f9e3ce7e2cd5e1d581791ac3e14c67c2550 100644
--- a/sklearn/datasets/lfw.py
+++ b/sklearn/datasets/lfw.py
@@ -34,7 +34,8 @@ try:
 except ImportError:
     import urllib
 
-from .base import get_data_home, Bunch
+from .base import get_data_home
+from ..utils import Bunch
 from ..externals.joblib import Memory
 
 from ..externals.six import b
diff --git a/sklearn/datasets/mldata.py b/sklearn/datasets/mldata.py
index 82ae9858e9df620d39d6ff869e1c5ac2a5778726..f5377f203e1da31831f0eab4a1dc14fc7332d626 100644
--- a/sklearn/datasets/mldata.py
+++ b/sklearn/datasets/mldata.py
@@ -23,7 +23,8 @@ import scipy as sp
 from scipy import io
 from shutil import copyfileobj
 
-from .base import get_data_home, Bunch
+from .base import get_data_home
+from ..utils import Bunch
 
 MLDATA_BASE_URL = "http://mldata.org/repository/data/download/matlab/%s"
 
diff --git a/sklearn/datasets/olivetti_faces.py b/sklearn/datasets/olivetti_faces.py
index 5f3af040dc1a41ab0e1a0442b66b1dff7720129c..9ecab18c0e5f34e01bd0a16ddd2377c59e23d116 100644
--- a/sklearn/datasets/olivetti_faces.py
+++ b/sklearn/datasets/olivetti_faces.py
@@ -37,9 +37,9 @@ except ImportError:
 import numpy as np
 from scipy.io.matlab import loadmat
 
-from .base import get_data_home, Bunch
+from .base import get_data_home
 from .base import _pkl_filepath
-from ..utils import check_random_state
+from ..utils import check_random_state, Bunch
 from ..externals import joblib
 
 
@@ -80,10 +80,12 @@ def fetch_olivetti_faces(data_home=None, shuffle=False, random_state=0,
     An object with the following attributes:
 
     data : numpy array of shape (400, 4096)
-        Each row corresponds to a ravelled face image of original size 64 x 64 pixels.
+        Each row corresponds to a ravelled face image of original size
+        64 x 64 pixels.
 
     images : numpy array of shape (400, 64, 64)
-        Each row is a face image corresponding to one of the 40 subjects of the dataset.
+        Each row is a face image corresponding to one of the 40 subjects
+        of the dataset.
 
     target : numpy array of shape (400, )
         Labels associated to each face image. Those labels are ranging from
diff --git a/sklearn/datasets/rcv1.py b/sklearn/datasets/rcv1.py
index 83b4d223cc3619b9a7d5b57bc69703b984c59bb9..ae45764b4042b79f943a7877900a954229cbd24a 100644
--- a/sklearn/datasets/rcv1.py
+++ b/sklearn/datasets/rcv1.py
@@ -20,12 +20,12 @@ import numpy as np
 import scipy.sparse as sp
 
 from .base import get_data_home
-from .base import Bunch
 from .base import _pkl_filepath
 from ..utils.fixes import makedirs
 from ..externals import joblib
 from .svmlight_format import load_svmlight_files
 from ..utils import shuffle as shuffle_
+from ..utils import Bunch
 
 
 URL = ('http://jmlr.csail.mit.edu/papers/volume5/lewis04a/'
diff --git a/sklearn/datasets/species_distributions.py b/sklearn/datasets/species_distributions.py
index 330c535620b7d4e84df26ffe293865beb0b46036..f34eb92d3366d73053bccfe7e83020b10b488bc3 100644
--- a/sklearn/datasets/species_distributions.py
+++ b/sklearn/datasets/species_distributions.py
@@ -50,7 +50,8 @@ except ImportError:
 
 import numpy as np
 
-from sklearn.datasets.base import get_data_home, Bunch
+from sklearn.datasets.base import get_data_home
+from ..utils import Bunch
 from sklearn.datasets.base import _pkl_filepath
 from sklearn.externals import joblib
 
diff --git a/sklearn/datasets/twenty_newsgroups.py b/sklearn/datasets/twenty_newsgroups.py
index 128610fd2830f03407cbdb36de0184306612abdb..47b543d8d2e16192a6f9d408a46aad6c230015bf 100644
--- a/sklearn/datasets/twenty_newsgroups.py
+++ b/sklearn/datasets/twenty_newsgroups.py
@@ -47,10 +47,9 @@ import numpy as np
 import scipy.sparse as sp
 
 from .base import get_data_home
-from .base import Bunch
 from .base import load_files
 from .base import _pkl_filepath
-from ..utils import check_random_state
+from ..utils import check_random_state, Bunch
 from ..feature_extraction.text import CountVectorizer
 from ..preprocessing import normalize
 from ..externals import joblib, six
diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py
index 61d7b12b7564d640f1b8b45fc4f0a33692aa8de5..6dfc7284cc6812a737ea69012ec27e244b562e44 100644
--- a/sklearn/pipeline.py
+++ b/sklearn/pipeline.py
@@ -20,6 +20,7 @@ from .externals.joblib import Parallel, delayed, Memory
 from .externals import six
 from .utils import tosequence
 from .utils.metaestimators import if_delegate_has_method
+from .utils import Bunch
 
 __all__ = ['Pipeline', 'FeatureUnion']
 
@@ -122,7 +123,7 @@ class Pipeline(_BasePipeline):
 
     Attributes
     ----------
-    named_steps : dict
+    named_steps : bunch object, a dictionary with attribute access
         Read-only attribute to access any step parameter by user given name.
         Keys are step names and values are steps parameters.
 
@@ -157,7 +158,12 @@ class Pipeline(_BasePipeline):
     array([False, False,  True,  True, False, False,  True,  True, False,
            True,  False,  True,  True, False,  True, False,  True,  True,
           False, False], dtype=bool)
-
+    >>> # Another way to get the selected features chosen by anova_filter
+    >>> anova_svm.named_steps.anova.get_support()
+    ... # doctest: +NORMALIZE_WHITESPACE
+    array([False, False,  True,  True, False, False,  True,  True, False,
+            True, False,  True,  True, False,  True, False,  True,  True,
+           False, False], dtype=bool)
     """
 
     # BaseEstimator interface
@@ -227,7 +233,8 @@
 
     @property
     def named_steps(self):
-        return dict(self.steps)
+        # Use a Bunch object so that step names can be tab-completed
+        return Bunch(**dict(self.steps))
 
     @property
     def _final_estimator(self):
diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py
index 33e3128931aff2570a765e76b4d54b77b507c6fa..d4c4844fe375de21ab4ab65f2822a89191d0c572 100644
--- a/sklearn/tests/test_pipeline.py
+++ b/sklearn/tests/test_pipeline.py
@@ -509,6 +509,23 @@ def test_set_pipeline_steps():
     assert_raises(TypeError, pipeline.fit_transform, [[1]], [1])
 
 
+def test_pipeline_named_steps():
+    transf = Transf()
+    mult2 = Mult(mult=2)
+    pipeline = Pipeline([('mock', transf), ("mult", mult2)])
+
+    # Test access via the named_steps bunch object
+    assert_true('mock' in pipeline.named_steps)
+    assert_true('mock2' not in pipeline.named_steps)
+    assert_true(pipeline.named_steps.mock is transf)
+    assert_true(pipeline.named_steps.mult is mult2)
+
+    # Test that a step name conflicting with a dict attribute keeps the
+    # dict behaviour for attribute access
+    pipeline = Pipeline([('values', transf), ("mult", mult2)])
+    assert_true(pipeline.named_steps.values is not transf)
+    assert_true(pipeline.named_steps.mult is mult2)
+
+
 def test_set_pipeline_step_none():
     # Test setting Pipeline steps to None
     X = np.array([[1]])
diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py
index a4e5b6a4f3ea52f703a9dcd5d869b2349db16918..0bc4d6de33c3f17d33e693ce970a36a0f961b237 100644
--- a/sklearn/utils/__init__.py
+++ b/sklearn/utils/__init__.py
@@ -28,6 +28,52 @@ __all__ = ["murmurhash3_32", "as_float_array",
            "check_symmetric", "indices_to_mask", "deprecated"]
 
 
+class Bunch(dict):
+    """Container object for datasets
+
+    Dictionary-like object that exposes its keys as attributes.
+
+    >>> b = Bunch(a=1, b=2)
+    >>> b['b']
+    2
+    >>> b.b
+    2
+    >>> b.a = 3
+    >>> b['a']
+    3
+    >>> b.c = 6
+    >>> b['c']
+    6
+
+    """
+
+    def __init__(self, **kwargs):
+        super(Bunch, self).__init__(kwargs)
+
+    def __setattr__(self, key, value):
+        self[key] = value
+
+    def __dir__(self):
+        return self.keys()
+
+    def __getattr__(self, key):
+        try:
+            return self[key]
+        except KeyError:
+            raise AttributeError(key)
+
+    def __setstate__(self, state):
+        # Bunch pickles generated with scikit-learn 0.16.* have a non-empty
+        # __dict__. This causes a surprising behaviour when loading these
+        # pickles with scikit-learn 0.17: reading bunch.key uses __dict__ but
+        # assigning to bunch.key uses __setattr__ and only changes
+        # bunch['key']. More details can be found at:
+        # https://github.com/scikit-learn/scikit-learn/issues/6196.
+        # Overriding __setstate__ to be a noop has the effect of
+        # ignoring the pickled __dict__
+        pass
+
+
 def safe_mask(X, mask):
     """Return a mask which is safe to use on X.
 
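For completeness, the relocated ``Bunch`` is what drives the completion:
``__getattr__`` and ``__setattr__`` keep attributes and keys in sync,
``__dir__`` reports the keys to interactive shells, and the noop
``__setstate__`` keeps old pickles from resurrecting a stale ``__dict__``.
A small usage sketch, assuming this patch is applied; the keys are arbitrary
examples::

    from sklearn.utils import Bunch

    b = Bunch(reduce_dim='pca', clf='svc')

    # Attribute assignment writes through to the underlying dict.
    b.clf = 'linear_svc'
    assert b['clf'] == 'linear_svc'

    # dir() is what IPython/Jupyter consult for completion candidates,
    # so the keys show up as completions.
    assert {'reduce_dim', 'clf'} <= set(dir(b))

``Pipeline.named_steps`` builds such a ``Bunch`` from ``self.steps`` on every
access, so mutating the returned object does not modify the pipeline itself.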