From eee8be490b0994a99d911ad8bab2c2e59dc4e123 Mon Sep 17 00:00:00 2001
From: Stephen Hoover <shoover@civisanalytics.com>
Date: Fri, 10 Feb 2017 01:36:20 -0600
Subject: [PATCH] [MRG+1] Add classes_ parameter to hyperparameter CV classes
 (#8295)

---
 doc/whats_new.rst                             |  9 ++++--
 sklearn/model_selection/_search.py            |  5 ++++
 sklearn/model_selection/tests/test_search.py  | 30 ++++++++++++++++++-
 .../model_selection/tests/test_validation.py  | 16 ++++++++--
 4 files changed, 53 insertions(+), 7 deletions(-)

diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 6be337bbe6..7a93e8feee 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -66,9 +66,12 @@ Enhancements
      now uses significantly less memory when assigning data points to their
      nearest cluster center. :issue:`7721` by :user:`Jon Crall <Erotemic>`.
 
-   - Added ``classes_`` attribute to :class:`model_selection.GridSearchCV`
-     that matches the ``classes_`` attribute of ``best_estimator_``. :issue:`7661`
-     by :user:`Alyssa Batula <abatula>` and :user:`Dylan Werner-Meier <unautre>`.
+   - Added ``classes_`` attribute to :class:`model_selection.GridSearchCV`,
+     :class:`model_selection.RandomizedSearchCV`, :class:`grid_search.GridSearchCV`,
+     and :class:`grid_search.RandomizedSearchCV` that matches the ``classes_``
+     attribute of ``best_estimator_``. :issue:`7661` and :issue:`8295`
+     by :user:`Alyssa Batula <abatula>`, :user:`Dylan Werner-Meier <unautre>`,
+     and :user:`Stephen Hoover <stephen-hoover>`.
 
    - The ``min_weight_fraction_leaf`` constraint in tree construction is now
      more efficient, taking a fast path to declare a node a leaf if its weight
diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index 3d5846596f..7c7224af47 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -532,6 +532,11 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         self._check_is_fitted('inverse_transform')
         return self.best_estimator_.transform(Xt)
 
+    @property
+    def classes_(self):
+        self._check_is_fitted("classes_")
+        return self.best_estimator_.classes_
+
     def fit(self, X, y=None, groups=None, **fit_params):
         """Run fit with all sets of parameters.
 
diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py
index 29bb29264d..98e92aa515 100644
--- a/sklearn/model_selection/tests/test_search.py
+++ b/sklearn/model_selection/tests/test_search.py
@@ -59,7 +59,7 @@ from sklearn.metrics import make_scorer
 from sklearn.metrics import roc_auc_score
 from sklearn.preprocessing import Imputer
 from sklearn.pipeline import Pipeline
-from sklearn.linear_model import SGDClassifier
+from sklearn.linear_model import Ridge, SGDClassifier
 
 from sklearn.model_selection.tests.common import OneTimeSplitter
 
@@ -73,6 +73,7 @@ class MockClassifier(object):
 
     def fit(self, X, Y):
         assert_true(len(X) == len(Y))
+        self.classes_ = np.unique(Y)
         return self
 
     def predict(self, T):
@@ -323,6 +324,33 @@ def test_grid_search_groups():
         gs.fit(X, y)
 
 
+def test_classes__property():
+    # Test that classes_ property matches best_estimator_.classes_
+    X = np.arange(100).reshape(10, 10)
+    y = np.array([0] * 5 + [1] * 5)
+    Cs = [.1, 1, 10]
+
+    grid_search = GridSearchCV(LinearSVC(random_state=0), {'C': Cs})
+    grid_search.fit(X, y)
+    assert_array_equal(grid_search.best_estimator_.classes_,
+                       grid_search.classes_)
+
+    # Test that regressors do not have a classes_ attribute
+    grid_search = GridSearchCV(Ridge(), {'alpha': [1.0, 2.0]})
+    grid_search.fit(X, y)
+    assert_false(hasattr(grid_search, 'classes_'))
+
+    # Test that the grid searcher has no classes_ attribute before it's fit
+    grid_search = GridSearchCV(LinearSVC(random_state=0), {'C': Cs})
+    assert_false(hasattr(grid_search, 'classes_'))
+
+    # Test that the grid searcher has no classes_ attribute without a refit
+    grid_search = GridSearchCV(LinearSVC(random_state=0),
+                               {'C': Cs}, refit=False)
+    grid_search.fit(X, y)
+    assert_false(hasattr(grid_search, 'classes_'))
+
+
 def test_trivial_cv_results_attr():
     # Test search over a "grid" with only one point.
     # Non-regression test: grid_scores_ wouldn't be set by GridSearchCV.
diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py
index c6ae5f3fdd..cc6f5973a0 100644
--- a/sklearn/model_selection/tests/test_validation.py
+++ b/sklearn/model_selection/tests/test_validation.py
@@ -62,6 +62,7 @@ from sklearn.datasets import make_classification
 from sklearn.datasets import make_multilabel_classification
 
 from sklearn.model_selection.tests.common import OneTimeSplitter
+from sklearn.model_selection import GridSearchCV
 
 
 try:
@@ -914,7 +915,7 @@ def test_cross_val_predict_sparse_prediction():
     assert_array_almost_equal(preds_sparse, preds)
 
 
-def test_cross_val_predict_with_method():
+def check_cross_val_predict_with_method(est):
     iris = load_iris()
     X, y = iris.data, iris.target
     X, y = shuffle(X, y, random_state=0)
@@ -924,8 +925,6 @@ def test_cross_val_predict_with_method():
 
     methods = ['decision_function', 'predict_proba', 'predict_log_proba']
     for method in methods:
-        est = LogisticRegression()
-
         predictions = cross_val_predict(est, X, y, method=method)
         assert_equal(len(predictions), len(y))
 
@@ -955,6 +954,17 @@ def test_cross_val_predict_with_method():
         assert_array_equal(predictions, predictions_ystr)
 
 
+def test_cross_val_predict_with_method():
+    check_cross_val_predict_with_method(LogisticRegression())
+
+
+def test_gridsearchcv_cross_val_predict_with_method():
+    est = GridSearchCV(LogisticRegression(random_state=42),
+                       {'C': [0.1, 1]},
+                       cv=2)
+    check_cross_val_predict_with_method(est)
+
+
 def get_expected_predictions(X, y, cv, classes, est, method):
 
     expected_predictions = np.zeros([len(y), classes])
-- 
GitLab