From fb5a498d0bd00fc2b42fbd19b6ef18e1dfeee47e Mon Sep 17 00:00:00 2001
From: RAKOTOARISON Herilalaina <rkt.herilalaina@gmail.com>
Date: Thu, 30 Mar 2017 14:21:51 +0200
Subject: [PATCH] [MRG+1] Change named_steps to Bunch object (#8586)

* Change named_steps to Bunch object

* Update named_steps attribute documentation

* Add test for named steps bunch object

* Delete whitespace in test_pipeline

* Update test_pipeline.py

* Add comment for named_steps usage

* Move dataset/Bunch to utils

* Fix to PEP8 format

* Add __getattribute method to Bunch class, Fix pep8 bug

* Remove __getattribute__, update test_pipeline

* Update test with conflict and non-conflict named_steps

* Add reference to class Pipeline
---
 doc/modules/pipeline.rst                  |  5 +++
 doc/whats_new.rst                         |  6 +++
 sklearn/datasets/base.py                  | 47 +----------------------
 sklearn/datasets/california_housing.py    |  3 +-
 sklearn/datasets/covtype.py               |  2 +-
 sklearn/datasets/kddcup99.py              |  2 +-
 sklearn/datasets/lfw.py                   |  3 +-
 sklearn/datasets/mldata.py                |  3 +-
 sklearn/datasets/olivetti_faces.py        | 10 +++--
 sklearn/datasets/rcv1.py                  |  2 +-
 sklearn/datasets/species_distributions.py |  3 +-
 sklearn/datasets/twenty_newsgroups.py     |  3 +-
 sklearn/pipeline.py                       | 13 +++++--
 sklearn/tests/test_pipeline.py            | 17 ++++++++
 sklearn/utils/__init__.py                 | 46 ++++++++++++++++++++++
 15 files changed, 103 insertions(+), 62 deletions(-)

diff --git a/doc/modules/pipeline.rst b/doc/modules/pipeline.rst
index a48164b094..c90f35753f 100644
--- a/doc/modules/pipeline.rst
+++ b/doc/modules/pipeline.rst
@@ -79,6 +79,11 @@ Parameters of the estimators in the pipeline can be accessed using the
              steps=[('reduce_dim', PCA(copy=True, iterated_power='auto',...)),
                     ('clf', SVC(C=10, cache_size=200, class_weight=None,...))])
 
+Attributes of named_steps map to keys, enabling tab completion in interactive environments::
+
+    >>> pipe.named_steps.reduce_dim is pipe.named_steps['reduce_dim']
+    True
+
 This is particularly important for doing grid searches::
 
     >>> from sklearn.model_selection import GridSearchCV
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index df096d1fae..f2189ba26c 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -277,6 +277,12 @@ API changes summary
      needed for the perplexity calculation. :issue:`7954` by
      :user:`Gary Foreman <garyForeman>`.
 
+   - Replace attribute ``named_steps`` ``dict`` to :class:`sklearn.utils.Bunch`
+     in :class:`sklearn.pipeline.Pipeline` to enable tab completion in interactive
+     environment. In the case conflict value on ``named_steps`` and ``dict``
+     attribute, ``dict`` behavior will be prioritized.
+     :issue:`8481` by :user:`Herilalaina Rakotoarison <herilalaina>`.
+
    - The :func:`sklearn.multioutput.MultiOutputClassifier.predict_proba`
      function used to return a 3d array (``n_samples``, ``n_classes``,
      ``n_outputs``). In the case where different target columns had different
diff --git a/sklearn/datasets/base.py b/sklearn/datasets/base.py
index 2325d97142..2ad2bdb16c 100644
--- a/sklearn/datasets/base.py
+++ b/sklearn/datasets/base.py
@@ -20,58 +20,13 @@ from os.path import isdir
 from os.path import splitext
 from os import listdir
 from os import makedirs
+from ..utils import Bunch
 
 import numpy as np
 
 from ..utils import check_random_state
 
 
-class Bunch(dict):
-    """Container object for datasets
-
-    Dictionary-like object that exposes its keys as attributes.
-
-    >>> b = Bunch(a=1, b=2)
-    >>> b['b']
-    2
-    >>> b.b
-    2
-    >>> b.a = 3
-    >>> b['a']
-    3
-    >>> b.c = 6
-    >>> b['c']
-    6
-
-    """
-
-    def __init__(self, **kwargs):
-        super(Bunch, self).__init__(kwargs)
-
-    def __setattr__(self, key, value):
-        self[key] = value
-
-    def __dir__(self):
-        return self.keys()
-
-    def __getattr__(self, key):
-        try:
-            return self[key]
-        except KeyError:
-            raise AttributeError(key)
-
-    def __setstate__(self, state):
-        # Bunch pickles generated with scikit-learn 0.16.* have an non
-        # empty __dict__. This causes a surprising behaviour when
-        # loading these pickles scikit-learn 0.17: reading bunch.key
-        # uses __dict__ but assigning to bunch.key use __setattr__ and
-        # only changes bunch['key']. More details can be found at:
-        # https://github.com/scikit-learn/scikit-learn/issues/6196.
-        # Overriding __setstate__ to be a noop has the effect of
-        # ignoring the pickled __dict__
-        pass
-
-
 def get_data_home(data_home=None):
     """Return the path of the scikit-learn data dir.
 
diff --git a/sklearn/datasets/california_housing.py b/sklearn/datasets/california_housing.py
index 8a74ad9e60..8db5e3139d 100644
--- a/sklearn/datasets/california_housing.py
+++ b/sklearn/datasets/california_housing.py
@@ -35,7 +35,8 @@ except ImportError:
 
 import numpy as np
 
-from .base import get_data_home, Bunch
+from .base import get_data_home
+from ..utils import Bunch
 from .base import _pkl_filepath
 from ..externals import joblib
 
diff --git a/sklearn/datasets/covtype.py b/sklearn/datasets/covtype.py
index 6e0b4d2d0d..a165f7c054 100644
--- a/sklearn/datasets/covtype.py
+++ b/sklearn/datasets/covtype.py
@@ -26,7 +26,7 @@ except ImportError:
 import numpy as np
 
 from .base import get_data_home
-from .base import Bunch
+from ..utils import Bunch
 from .base import _pkl_filepath
 from ..utils.fixes import makedirs
 from ..externals import joblib
diff --git a/sklearn/datasets/kddcup99.py b/sklearn/datasets/kddcup99.py
index c2ed39caa1..4be96e4560 100644
--- a/sklearn/datasets/kddcup99.py
+++ b/sklearn/datasets/kddcup99.py
@@ -23,7 +23,7 @@ except ImportError:
 import numpy as np
 
 from .base import get_data_home
-from .base import Bunch
+from ..utils import Bunch
 from ..externals import joblib, six
 from ..utils import check_random_state
 from ..utils import shuffle as shuffle_method
diff --git a/sklearn/datasets/lfw.py b/sklearn/datasets/lfw.py
index 13aaed805b..e3406f9e3c 100644
--- a/sklearn/datasets/lfw.py
+++ b/sklearn/datasets/lfw.py
@@ -34,7 +34,8 @@ try:
 except ImportError:
     import urllib
 
-from .base import get_data_home, Bunch
+from .base import get_data_home
+from ..utils import Bunch
 from ..externals.joblib import Memory
 
 from ..externals.six import b
diff --git a/sklearn/datasets/mldata.py b/sklearn/datasets/mldata.py
index 82ae9858e9..f5377f203e 100644
--- a/sklearn/datasets/mldata.py
+++ b/sklearn/datasets/mldata.py
@@ -23,7 +23,8 @@ import scipy as sp
 from scipy import io
 from shutil import copyfileobj
 
-from .base import get_data_home, Bunch
+from .base import get_data_home
+from ..utils import Bunch
 
 MLDATA_BASE_URL = "http://mldata.org/repository/data/download/matlab/%s"
 
diff --git a/sklearn/datasets/olivetti_faces.py b/sklearn/datasets/olivetti_faces.py
index 5f3af040dc..9ecab18c0e 100644
--- a/sklearn/datasets/olivetti_faces.py
+++ b/sklearn/datasets/olivetti_faces.py
@@ -37,9 +37,9 @@ except ImportError:
 import numpy as np
 from scipy.io.matlab import loadmat
 
-from .base import get_data_home, Bunch
+from .base import get_data_home
 from .base import _pkl_filepath
-from ..utils import check_random_state
+from ..utils import check_random_state, Bunch
 from ..externals import joblib
 
 
@@ -80,10 +80,12 @@ def fetch_olivetti_faces(data_home=None, shuffle=False, random_state=0,
     An object with the following attributes:
 
     data : numpy array of shape (400, 4096)
-        Each row corresponds to a ravelled face image of original size 64 x 64 pixels.
+        Each row corresponds to a ravelled face image of original size
+        64 x 64 pixels.
 
     images : numpy array of shape (400, 64, 64)
-        Each row is a face image corresponding to one of the 40 subjects of the dataset.
+        Each row is a face image corresponding to one of the 40 subjects
+        of the dataset.
 
     target : numpy array of shape (400, )
         Labels associated to each face image. Those labels are ranging from
diff --git a/sklearn/datasets/rcv1.py b/sklearn/datasets/rcv1.py
index 83b4d223cc..ae45764b40 100644
--- a/sklearn/datasets/rcv1.py
+++ b/sklearn/datasets/rcv1.py
@@ -20,12 +20,12 @@ import numpy as np
 import scipy.sparse as sp
 
 from .base import get_data_home
-from .base import Bunch
 from .base import _pkl_filepath
 from ..utils.fixes import makedirs
 from ..externals import joblib
 from .svmlight_format import load_svmlight_files
 from ..utils import shuffle as shuffle_
+from ..utils import Bunch
 
 
 URL = ('http://jmlr.csail.mit.edu/papers/volume5/lewis04a/'
diff --git a/sklearn/datasets/species_distributions.py b/sklearn/datasets/species_distributions.py
index 330c535620..f34eb92d33 100644
--- a/sklearn/datasets/species_distributions.py
+++ b/sklearn/datasets/species_distributions.py
@@ -50,7 +50,8 @@ except ImportError:
 
 import numpy as np
 
-from sklearn.datasets.base import get_data_home, Bunch
+from sklearn.datasets.base import get_data_home
+from ..utils import Bunch
 from sklearn.datasets.base import _pkl_filepath
 from sklearn.externals import joblib
 
diff --git a/sklearn/datasets/twenty_newsgroups.py b/sklearn/datasets/twenty_newsgroups.py
index 128610fd28..47b543d8d2 100644
--- a/sklearn/datasets/twenty_newsgroups.py
+++ b/sklearn/datasets/twenty_newsgroups.py
@@ -47,10 +47,9 @@ import numpy as np
 import scipy.sparse as sp
 
 from .base import get_data_home
-from .base import Bunch
 from .base import load_files
 from .base import _pkl_filepath
-from ..utils import check_random_state
+from ..utils import check_random_state, Bunch
 from ..feature_extraction.text import CountVectorizer
 from ..preprocessing import normalize
 from ..externals import joblib, six
diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py
index 61d7b12b75..6dfc7284cc 100644
--- a/sklearn/pipeline.py
+++ b/sklearn/pipeline.py
@@ -20,6 +20,7 @@ from .externals.joblib import Parallel, delayed, Memory
 from .externals import six
 from .utils import tosequence
 from .utils.metaestimators import if_delegate_has_method
+from .utils import Bunch
 
 __all__ = ['Pipeline', 'FeatureUnion']
 
@@ -122,7 +123,7 @@ class Pipeline(_BasePipeline):
 
     Attributes
     ----------
-    named_steps : dict
+    named_steps : bunch object, a dictionary with attribute access
         Read-only attribute to access any step parameter by user given name.
         Keys are step names and values are steps parameters.
 
@@ -157,7 +158,12 @@ class Pipeline(_BasePipeline):
     array([False, False,  True,  True, False, False, True,  True, False,
            True,  False,  True,  True, False, True,  False, True, True,
            False, False], dtype=bool)
-
+    >>> # Another way to get selected features chosen by anova_filter
+    >>> anova_svm.named_steps.anova.get_support()
+    ... # doctest: +NORMALIZE_WHITESPACE
+    array([False, False,  True,  True, False, False, True,  True, False,
+           True,  False,  True,  True, False, True,  False, True, True,
+           False, False], dtype=bool)
     """
 
     # BaseEstimator interface
@@ -227,7 +233,8 @@ class Pipeline(_BasePipeline):
 
     @property
     def named_steps(self):
-        return dict(self.steps)
+        # Use Bunch object to improve autocomplete
+        return Bunch(**dict(self.steps))
 
     @property
     def _final_estimator(self):
diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py
index 33e3128931..d4c4844fe3 100644
--- a/sklearn/tests/test_pipeline.py
+++ b/sklearn/tests/test_pipeline.py
@@ -509,6 +509,23 @@ def test_set_pipeline_steps():
     assert_raises(TypeError, pipeline.fit_transform, [[1]], [1])
 
 
+def test_pipeline_named_steps():
+    transf = Transf()
+    mult2 = Mult(mult=2)
+    pipeline = Pipeline([('mock', transf), ("mult", mult2)])
+
+    # Test access via named_steps bunch object
+    assert_true('mock' in pipeline.named_steps)
+    assert_true('mock2' not in pipeline.named_steps)
+    assert_true(pipeline.named_steps.mock is transf)
+    assert_true(pipeline.named_steps.mult is mult2)
+
+    # Test bunch with conflict attribute of dict
+    pipeline = Pipeline([('values', transf), ("mult", mult2)])
+    assert_true(pipeline.named_steps.values is not transf)
+    assert_true(pipeline.named_steps.mult is mult2)
+
+
 def test_set_pipeline_step_none():
     # Test setting Pipeline steps to None
     X = np.array([[1]])
diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py
index a4e5b6a4f3..0bc4d6de33 100644
--- a/sklearn/utils/__init__.py
+++ b/sklearn/utils/__init__.py
@@ -28,6 +28,52 @@ __all__ = ["murmurhash3_32", "as_float_array",
            "check_symmetric", "indices_to_mask", "deprecated"]
 
 
+class Bunch(dict):
+    """Container object for datasets
+
+    Dictionary-like object that exposes its keys as attributes.
+
+    >>> b = Bunch(a=1, b=2)
+    >>> b['b']
+    2
+    >>> b.b
+    2
+    >>> b.a = 3
+    >>> b['a']
+    3
+    >>> b.c = 6
+    >>> b['c']
+    6
+
+    """
+
+    def __init__(self, **kwargs):
+        super(Bunch, self).__init__(kwargs)
+
+    def __setattr__(self, key, value):
+        self[key] = value
+
+    def __dir__(self):
+        return self.keys()
+
+    def __getattr__(self, key):
+        try:
+            return self[key]
+        except KeyError:
+            raise AttributeError(key)
+
+    def __setstate__(self, state):
+        # Bunch pickles generated with scikit-learn 0.16.* have an non
+        # empty __dict__. This causes a surprising behaviour when
+        # loading these pickles scikit-learn 0.17: reading bunch.key
+        # uses __dict__ but assigning to bunch.key use __setattr__ and
+        # only changes bunch['key']. More details can be found at:
+        # https://github.com/scikit-learn/scikit-learn/issues/6196.
+        # Overriding __setstate__ to be a noop has the effect of
+        # ignoring the pickled __dict__
+        pass
+
+
 def safe_mask(X, mask):
     """Return a mask which is safe to use on X.
 
-- 
GitLab