diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 6be337bbe676525fe2fcfbffd7b29f04e0a42d31..7a93e8feee74ac88a19a761685b2736fcf81bddd 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -66,9 +66,12 @@ Enhancements
      now uses significantly less memory when assigning data points to their
      nearest cluster center. :issue:`7721` by :user:`Jon Crall <Erotemic>`.
 
-   - Added ``classes_`` attribute to :class:`model_selection.GridSearchCV`
-     that matches the ``classes_`` attribute of ``best_estimator_``. :issue:`7661`
-     by :user:`Alyssa Batula <abatula>` and :user:`Dylan Werner-Meier <unautre>`.
+   - Added ``classes_`` attribute to :class:`model_selection.GridSearchCV`,
+     :class:`model_selection.RandomizedSearchCV`, :class:`grid_search.GridSearchCV`,
+     and :class:`grid_search.RandomizedSearchCV` that matches the ``classes_``
+     attribute of ``best_estimator_``. :issue:`7661` and :issue:`8295`
+     by :user:`Alyssa Batula <abatula>`, :user:`Dylan Werner-Meier <unautre>`,
+     and :user:`Stephen Hoover <stephen-hoover>`.
 
    - The ``min_weight_fraction_leaf`` constraint in tree construction is now
      more efficient, taking a fast path to declare a node a leaf if its weight
diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index 3d5846596f82b9c74208b9ee03cebc091a609e09..7c7224af474b8a627749413233e9270cdbfebcb2 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -532,6 +532,11 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         self._check_is_fitted('inverse_transform')
         return self.best_estimator_.transform(Xt)
 
+    @property
+    def classes_(self):
+        self._check_is_fitted("classes_")
+        return self.best_estimator_.classes_
+
     def fit(self, X, y=None, groups=None, **fit_params):
         """Run fit with all sets of parameters.
 
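With the classes_ property added to BaseSearchCV above, a refit search object exposes the same class labels as its best_estimator_. A minimal usage sketch, not part of the patch (the toy data and parameter grid here are arbitrary):

import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.svm import LinearSVC

# Arbitrary toy binary-classification data.
X = np.random.RandomState(0).rand(20, 4)
y = np.array([0] * 10 + [1] * 10)

search = GridSearchCV(LinearSVC(random_state=0), {'C': [0.1, 1, 10]})
search.fit(X, y)

# The property simply forwards to the refit best_estimator_;
# nothing extra is stored on the search object itself.
assert np.array_equal(search.classes_, search.best_estimator_.classes_)
print(search.classes_)  # [0 1]
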
diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py
index 29bb29264dad6ebabb91b23437fa3fefb7e2dd3d..98e92aa5154f601f4c37a226cdc5b988bf76ebb2 100644
--- a/sklearn/model_selection/tests/test_search.py
+++ b/sklearn/model_selection/tests/test_search.py
@@ -59,7 +59,7 @@ from sklearn.metrics import make_scorer
 from sklearn.metrics import roc_auc_score
 from sklearn.preprocessing import Imputer
 from sklearn.pipeline import Pipeline
-from sklearn.linear_model import SGDClassifier
+from sklearn.linear_model import Ridge, SGDClassifier
 
 from sklearn.model_selection.tests.common import OneTimeSplitter
 
@@ -73,6 +73,7 @@ class MockClassifier(object):
 
     def fit(self, X, Y):
         assert_true(len(X) == len(Y))
+        self.classes_ = np.unique(Y)
         return self
 
     def predict(self, T):
@@ -323,6 +324,33 @@ def test_grid_search_groups():
     gs.fit(X, y)
 
 
+def test_classes__property():
+    # Test that classes_ property matches best_estimator_.classes_
+    X = np.arange(100).reshape(10, 10)
+    y = np.array([0] * 5 + [1] * 5)
+    Cs = [.1, 1, 10]
+
+    grid_search = GridSearchCV(LinearSVC(random_state=0), {'C': Cs})
+    grid_search.fit(X, y)
+    assert_array_equal(grid_search.best_estimator_.classes_,
+                       grid_search.classes_)
+
+    # Test that regressors do not have a classes_ attribute
+    grid_search = GridSearchCV(Ridge(), {'alpha': [1.0, 2.0]})
+    grid_search.fit(X, y)
+    assert_false(hasattr(grid_search, 'classes_'))
+
+    # Test that the grid searcher has no classes_ attribute before it's fit
+    grid_search = GridSearchCV(LinearSVC(random_state=0), {'C': Cs})
+    assert_false(hasattr(grid_search, 'classes_'))
+
+    # Test that the grid searcher has no classes_ attribute without a refit
+    grid_search = GridSearchCV(LinearSVC(random_state=0),
+                               {'C': Cs}, refit=False)
+    grid_search.fit(X, y)
+    assert_false(hasattr(grid_search, 'classes_'))
+
+
 def test_trivial_cv_results_attr():
     # Test search over a "grid" with only one point.
     # Non-regression test: grid_scores_ wouldn't be set by GridSearchCV.
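The tests above rely on hasattr() reporting False whenever the delegated attribute is unavailable: _check_is_fitted raises NotFittedError, which in scikit-learn also derives from AttributeError, and a regressor's best_estimator_ has no classes_ at all. A small illustration reusing the data from the test, assuming the patched property in _search.py:

import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
from sklearn.svm import LinearSVC

X = np.arange(100).reshape(10, 10)
y = np.array([0] * 5 + [1] * 5)

# Unfitted searcher: accessing classes_ raises NotFittedError, which
# subclasses AttributeError, so hasattr() returns False.
search = GridSearchCV(LinearSVC(random_state=0), {'C': [0.1, 1, 10]})
print(hasattr(search, 'classes_'))  # False

# Fitted on a regressor: best_estimator_ itself has no classes_.
search = GridSearchCV(Ridge(), {'alpha': [1.0, 2.0]})
search.fit(X, y)
print(hasattr(search, 'classes_'))  # False
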
diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py
index c6ae5f3fdd18a553e6de886957c95f607f0cfefc..cc6f5973a0b096c81fe1c9487242636ef4e8fa75 100644
--- a/sklearn/model_selection/tests/test_validation.py
+++ b/sklearn/model_selection/tests/test_validation.py
@@ -62,6 +62,7 @@ from sklearn.datasets import make_classification
 from sklearn.datasets import make_multilabel_classification
 
 from sklearn.model_selection.tests.common import OneTimeSplitter
+from sklearn.model_selection import GridSearchCV
 
 
 try:
@@ -914,7 +915,7 @@ def test_cross_val_predict_sparse_prediction():
     assert_array_almost_equal(preds_sparse, preds)
 
 
-def test_cross_val_predict_with_method():
+def check_cross_val_predict_with_method(est):
     iris = load_iris()
     X, y = iris.data, iris.target
     X, y = shuffle(X, y, random_state=0)
@@ -924,8 +925,6 @@
     methods = ['decision_function', 'predict_proba', 'predict_log_proba']
     for method in methods:
-        est = LogisticRegression()
-
         predictions = cross_val_predict(est, X, y, method=method)
         assert_equal(len(predictions), len(y))
 
@@ -955,6 +954,17 @@
     assert_array_equal(predictions, predictions_ystr)
 
 
+def test_cross_val_predict_with_method():
+    check_cross_val_predict_with_method(LogisticRegression())
+
+
+def test_gridsearchcv_cross_val_predict_with_method():
+    est = GridSearchCV(LogisticRegression(random_state=42),
+                       {'C': [0.1, 1]},
+                       cv=2)
+    check_cross_val_predict_with_method(est)
+
+
 def get_expected_predictions(X, y, cv, classes, est, method):
 
     expected_predictions = np.zeros([len(y), classes])
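The new test_gridsearchcv_cross_val_predict_with_method exercises what this change enables: a GridSearchCV-wrapped classifier can be passed straight to cross_val_predict with a prediction method that needs class information. A sketch mirroring that test (same estimator and grid; the shape and probability checks are only illustrative):

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, cross_val_predict

iris = load_iris()
X, y = iris.data, iris.target

# The fitted search now exposes classes_, so prediction methods that need
# class information work when the search object is used as the estimator.
est = GridSearchCV(LogisticRegression(random_state=42), {'C': [0.1, 1]}, cv=2)
proba = cross_val_predict(est, X, y, method='predict_proba')

print(proba.shape)                          # (150, 3): one column per class
print(np.allclose(proba.sum(axis=1), 1.0))  # True: rows are probability vectors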