From 3a0ea190c58394b080e84d200486c3f864822481 Mon Sep 17 00:00:00 2001
From: Stephen Hoover <shoover@civisanalytics.com>
Date: Thu, 9 Feb 2017 00:42:23 -0600
Subject: [PATCH] [MRG+1] Accept keyword parameters to hyperparameter search
 fit methods (#8278)

* ENH Accept keyword parameters to hyperparameter search fit methods

Deprecate ``fit_params`` as a constructor argument to the hyperparameter search classes and instead accept keyword parameters to their ``fit`` methods. This makes the ``fit`` methods of these classes conform to the Estimator API and allows the hyperparameter search classes to be used with other CV utilities such as ``cross_val_predict`` (usage sketch below).

* CR: Expanded tests, remove deprecated use in Ridge

* Make tests consistent in Python 2 and 3
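
Sketch of the API after this change (the estimator, data, and weight
array below are illustrative, not part of the patch):

    import numpy as np
    from sklearn.model_selection import GridSearchCV
    from sklearn.svm import SVC

    X = np.random.RandomState(0).rand(20, 3)
    y = np.array([0, 1] * 10)
    w = np.ones(20)  # one weight per sample

    # Deprecated: GridSearchCV(SVC(), {'C': [0.1, 1.0]},
    #                          fit_params={'sample_weight': w})
    search = GridSearchCV(SVC(), {'C': [0.1, 1.0]})
    # New: data-dependent parameters go directly to ``fit``.
    search.fit(X, y, sample_weight=w)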
---
 doc/whats_new.rst                            | 11 ++++
 sklearn/linear_model/ridge.py                | 22 +++----
 sklearn/linear_model/tests/test_ridge.py     |  6 +-
 sklearn/model_selection/_search.py           | 28 +++++---
 sklearn/model_selection/tests/test_search.py | 69 ++++++++++++++++++++
 sklearn/utils/mocking.py                     | 15 ++++-
 6 files changed, 122 insertions(+), 29 deletions(-)
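
A minimal sketch of the deprecation behavior implemented in
``BaseSearchCV.fit`` (the estimator and data are illustrative; this
note sits below the ``---`` separator, which ``git am`` ignores):

    import warnings
    import numpy as np
    from sklearn.model_selection import GridSearchCV
    from sklearn.svm import SVC

    X = np.random.RandomState(0).rand(20, 3)
    y = np.array([0, 1] * 10)

    # The constructor argument still works but emits a DeprecationWarning.
    search = GridSearchCV(SVC(), {'C': [0.1, 1.0]},
                          fit_params={'sample_weight': np.ones(20)})
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        search.fit(X, y)
    assert any(issubclass(c.category, DeprecationWarning) for c in caught)

    # Passing parameters in both places emits a RuntimeWarning, and the
    # keyword arguments given to ``fit`` win.
    search.fit(X, y, sample_weight=np.ones(20))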

diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index c0467a5cbf..6be337bbe6 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -222,6 +222,17 @@ API changes summary
      (``n_samples``, ``n_classes``) for that particular output.
      :issue:`8093` by :user:`Peter Bull <pjbull>`.
 
+    - Deprecate the ``fit_params`` constructor argument to
+      :class:`sklearn.model_selection.GridSearchCV` and
+      :class:`sklearn.model_selection.RandomizedSearchCV` in favor of
+      passing keyword parameters to the ``fit`` methods of those
+      classes. Data-dependent parameters needed for model training
+      should be passed as keyword arguments to ``fit``; conforming to
+      this convention allows the hyperparameter search classes to be
+      used with tools such as
+      :func:`sklearn.model_selection.cross_val_predict`.
+      :issue:`2879` by :user:`Stephen Hoover <stephen-hoover>`.
+
 .. _changes_0_18_1:
 
 Version 0.18.1
diff --git a/sklearn/linear_model/ridge.py b/sklearn/linear_model/ridge.py
index 84ec97056a..3b89434ac6 100644
--- a/sklearn/linear_model/ridge.py
+++ b/sklearn/linear_model/ridge.py
@@ -213,7 +213,7 @@ def ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
         Regularization strength; must be a positive float. Regularization
         improves the conditioning of the problem and reduces the variance of
         the estimates. Larger values specify stronger regularization.
-        Alpha corresponds to ``C^-1`` in other linear models such as 
+        Alpha corresponds to ``C^-1`` in other linear models such as
         LogisticRegression or LinearSVC. If an array is passed, penalties are
         assumed to be specific to the targets. Hence they must correspond in
         number.
@@ -508,7 +508,7 @@ class Ridge(_BaseRidge, RegressorMixin):
         Regularization strength; must be a positive float. Regularization
         improves the conditioning of the problem and reduces the variance of
         the estimates. Larger values specify stronger regularization.
-        Alpha corresponds to ``C^-1`` in other linear models such as 
+        Alpha corresponds to ``C^-1`` in other linear models such as
         LogisticRegression or LinearSVC. If an array is passed, penalties are
         assumed to be specific to the targets. Hence they must correspond in
         number.
@@ -653,7 +653,7 @@ class RidgeClassifier(LinearClassifierMixin, _BaseRidge):
         Regularization strength; must be a positive float. Regularization
         improves the conditioning of the problem and reduces the variance of
         the estimates. Larger values specify stronger regularization.
-        Alpha corresponds to ``C^-1`` in other linear models such as 
+        Alpha corresponds to ``C^-1`` in other linear models such as
         LogisticRegression or LinearSVC.
 
     class_weight : dict or 'balanced', optional
@@ -1090,11 +1090,9 @@ class _BaseRidgeCV(LinearModel):
                 raise ValueError("cv!=None and store_cv_values=True "
                                  " are incompatible")
             parameters = {'alpha': self.alphas}
-            fit_params = {'sample_weight': sample_weight}
             gs = GridSearchCV(Ridge(fit_intercept=self.fit_intercept),
-                              parameters, fit_params=fit_params, cv=self.cv,
-                              scoring=self.scoring)
-            gs.fit(X, y)
+                              parameters, cv=self.cv, scoring=self.scoring)
+            gs.fit(X, y, sample_weight=sample_weight)
             estimator = gs.best_estimator_
             self.alpha_ = gs.best_estimator_.alpha
 
@@ -1119,8 +1117,8 @@ class RidgeCV(_BaseRidgeCV, RegressorMixin):
         Regularization strength; must be a positive float. Regularization
         improves the conditioning of the problem and reduces the variance of
         the estimates. Larger values specify stronger regularization.
-        Alpha corresponds to ``C^-1`` in other linear models such as 
-        LogisticRegression or LinearSVC. 
+        Alpha corresponds to ``C^-1`` in other linear models such as
+        LogisticRegression or LinearSVC.
 
     fit_intercept : boolean
         Whether to calculate the intercept for this model. If set
@@ -1152,7 +1150,7 @@ class RidgeCV(_BaseRidgeCV, RegressorMixin):
         - An iterable yielding train/test splits.
 
         For integer/None inputs, if ``y`` is binary or multiclass,
-        :class:`sklearn.model_selection.StratifiedKFold` is used, else, 
+        :class:`sklearn.model_selection.StratifiedKFold` is used, else,
         :class:`sklearn.model_selection.KFold` is used.
 
         Refer :ref:`User Guide <cross_validation>` for the various
@@ -1222,8 +1220,8 @@ class RidgeClassifierCV(LinearClassifierMixin, _BaseRidgeCV):
         Regularization strength; must be a positive float. Regularization
         improves the conditioning of the problem and reduces the variance of
         the estimates. Larger values specify stronger regularization.
-        Alpha corresponds to ``C^-1`` in other linear models such as 
-        LogisticRegression or LinearSVC. 
+        Alpha corresponds to ``C^-1`` in other linear models such as
+        LogisticRegression or LinearSVC.
 
     fit_intercept : boolean
         Whether to calculate the intercept for this model. If set
diff --git a/sklearn/linear_model/tests/test_ridge.py b/sklearn/linear_model/tests/test_ridge.py
index c6f076483e..433801e45a 100644
--- a/sklearn/linear_model/tests/test_ridge.py
+++ b/sklearn/linear_model/tests/test_ridge.py
@@ -604,10 +604,8 @@ def test_ridgecv_sample_weight():
 
         # Check using GridSearchCV directly
         parameters = {'alpha': alphas}
-        fit_params = {'sample_weight': sample_weight}
-        gs = GridSearchCV(Ridge(), parameters, fit_params=fit_params,
-                          cv=cv)
-        gs.fit(X, y)
+        gs = GridSearchCV(Ridge(), parameters, cv=cv)
+        gs.fit(X, y, sample_weight=sample_weight)
 
         assert_equal(ridgecv.alpha_, gs.best_estimator_.alpha)
         assert_array_almost_equal(ridgecv.coef_, gs.best_estimator_.coef_)
diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index 566ec8c996..3d5846596f 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -532,7 +532,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         self._check_is_fitted('inverse_transform')
         return self.best_estimator_.transform(Xt)
 
-    def fit(self, X, y=None, groups=None):
+    def fit(self, X, y=None, groups=None, **fit_params):
         """Run fit with all sets of parameters.
 
         Parameters
@@ -549,7 +549,23 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         groups : array-like, with shape (n_samples,), optional
             Group labels for the samples used while splitting the dataset into
             train/test set.
+
+        **fit_params : dict of string -> object
+            Parameters passed to the ``fit`` method of the estimator.
         """
+        if self.fit_params:
+            warnings.warn('"fit_params" as a constructor argument was '
+                          'deprecated in version 0.19 and will be removed '
+                          'in version 0.21. Pass fit parameters to the '
+                          '"fit" method instead.', DeprecationWarning)
+            if fit_params:
+                warnings.warn('Ignoring fit_params passed as a constructor '
+                              'argument in favor of keyword arguments to '
+                              'the "fit" method.', RuntimeWarning)
+            else:
+                fit_params = self.fit_params
         estimator = self.estimator
         cv = check_cv(self.cv, y, classifier=is_classifier(estimator))
         self.scorer_ = check_scoring(self.estimator, scoring=self.scoring)
@@ -572,7 +586,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
             pre_dispatch=pre_dispatch
         )(delayed(_fit_and_score)(clone(base_estimator), X, y, self.scorer_,
                                   train, test, self.verbose, parameters,
-                                  fit_params=self.fit_params,
+                                  fit_params=fit_params,
                                   return_train_score=self.return_train_score,
                                   return_n_test_samples=True,
                                   return_times=True, return_parameters=False,
@@ -655,9 +669,9 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
             best_estimator = clone(base_estimator).set_params(
                 **best_parameters)
             if y is not None:
-                best_estimator.fit(X, y, **self.fit_params)
+                best_estimator.fit(X, y, **fit_params)
             else:
-                best_estimator.fit(X, **self.fit_params)
+                best_estimator.fit(X, **fit_params)
             self.best_estimator_ = best_estimator
         return self
 
@@ -730,9 +744,6 @@ class GridSearchCV(BaseSearchCV):
         ``scorer(estimator, X, y)``.
         If ``None``, the ``score`` method of the estimator is used.
 
-    fit_params : dict, optional
-        Parameters to pass to the fit method.
-
     n_jobs : int, default=1
         Number of jobs to run in parallel.
 
@@ -990,9 +1001,6 @@ class RandomizedSearchCV(BaseSearchCV):
         ``scorer(estimator, X, y)``.
         If ``None``, the ``score`` method of the estimator is used.
 
-    fit_params : dict, optional
-        Parameters to pass to the fit method.
-
     n_jobs : int, default=1
         Number of jobs to run in parallel.
 
diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py
index 117b81a35a..29bb29264d 100644
--- a/sklearn/model_selection/tests/test_search.py
+++ b/sklearn/model_selection/tests/test_search.py
@@ -17,6 +17,7 @@ from sklearn.utils.testing import assert_equal
 from sklearn.utils.testing import assert_not_equal
 from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_warns
+from sklearn.utils.testing import assert_warns_message
 from sklearn.utils.testing import assert_raise_message
 from sklearn.utils.testing import assert_false, assert_true
 from sklearn.utils.testing import assert_array_equal
@@ -173,6 +174,74 @@ def test_grid_search():
     assert_raises(ValueError, grid_search.fit, X, y)
 
 
+def check_hyperparameter_searcher_with_fit_params(klass, **klass_kwargs):
+    X = np.arange(100).reshape(10, 10)
+    y = np.array([0] * 5 + [1] * 5)
+    clf = CheckingClassifier(expected_fit_params=['spam', 'eggs'])
+    searcher = klass(clf, {'foo_param': [1, 2, 3]}, cv=2, **klass_kwargs)
+
+    # The CheckingClassifier raises an AssertionError if
+    # a parameter is missing or has length != len(X).
+    assert_raise_message(AssertionError,
+                         "Expected fit parameter(s) ['eggs'] not seen.",
+                         searcher.fit, X, y, spam=np.ones(10))
+    assert_raise_message(AssertionError,
+                         "Fit parameter spam has length 1; expected 4.",
+                         searcher.fit, X, y, spam=np.ones(1),
+                         eggs=np.zeros(10))
+    searcher.fit(X, y, spam=np.ones(10), eggs=np.zeros(10))
+
+
+def test_grid_search_with_fit_params():
+    check_hyperparameter_searcher_with_fit_params(GridSearchCV)
+
+
+def test_random_search_with_fit_params():
+    check_hyperparameter_searcher_with_fit_params(RandomizedSearchCV, n_iter=1)
+
+
+def test_grid_search_fit_params_deprecation():
+    # NOTE: Remove this test in v0.21
+
+    # Use of `fit_params` in the class constructor is deprecated,
+    # but will still work until v0.21.
+    X = np.arange(100).reshape(10, 10)
+    y = np.array([0] * 5 + [1] * 5)
+    clf = CheckingClassifier(expected_fit_params=['spam'])
+    grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]},
+                               fit_params={'spam': np.ones(10)})
+    assert_warns(DeprecationWarning, grid_search.fit, X, y)
+
+
+def test_grid_search_fit_params_two_places():
+    # NOTE: Remove this test in v0.21
+
+    # If users pass fit parameters both to the constructor
+    # (deprecated use) and to the `fit` method, the values
+    # given to the constructor are ignored.
+    X = np.arange(100).reshape(10, 10)
+    y = np.array([0] * 5 + [1] * 5)
+    clf = CheckingClassifier(expected_fit_params=['spam'])
+
+    # The "spam" array is too short and will raise an
+    # error in the CheckingClassifier if used.
+    grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]},
+                               fit_params={'spam': np.ones(1)})
+
+    expected_warning = ('Ignoring fit_params passed as a constructor '
+                        'argument in favor of keyword arguments to '
+                        'the "fit" method.')
+    assert_warns_message(RuntimeWarning, expected_warning,
+                         grid_search.fit, X, y, spam=np.ones(10))
+
+    # Verify that `fit` prefers its own kwargs by giving valid
+    # kwargs in the constructor and invalid kwargs in the method call.
+    grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]},
+                               fit_params={'spam': np.ones(10)})
+    assert_raise_message(AssertionError, "Fit parameter spam has length 1",
+                         grid_search.fit, X, y, spam=np.ones(1))
+
+
 @ignore_warnings
 def test_grid_search_no_score():
     # Test grid-search on classifier that has no score function.
diff --git a/sklearn/utils/mocking.py b/sklearn/utils/mocking.py
index c02bf8431f..013644a285 100644
--- a/sklearn/utils/mocking.py
+++ b/sklearn/utils/mocking.py
@@ -44,13 +44,14 @@ class CheckingClassifier(BaseEstimator, ClassifierMixin):
     This allows testing whether pipelines / cross-validation or metaestimators
     changed the input.
     """
-    def __init__(self, check_y=None,
-                 check_X=None, foo_param=0):
+    def __init__(self, check_y=None, check_X=None, foo_param=0,
+                 expected_fit_params=None):
         self.check_y = check_y
         self.check_X = check_X
         self.foo_param = foo_param
+        self.expected_fit_params = expected_fit_params
 
-    def fit(self, X, y):
+    def fit(self, X, y, **fit_params):
         assert_true(len(X) == len(y))
         if self.check_X is not None:
             assert_true(self.check_X(X))
@@ -58,6 +59,16 @@ class CheckingClassifier(BaseEstimator, ClassifierMixin):
             assert_true(self.check_y(y))
         self.classes_ = np.unique(check_array(y, ensure_2d=False,
                                               allow_nd=True))
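+        # All expected fit parameters must be passed, and every passed
+        # parameter must have one entry per sample in X.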
+        if self.expected_fit_params:
+            missing = set(self.expected_fit_params) - set(fit_params)
+            assert_true(len(missing) == 0, 'Expected fit parameter(s) %s not '
+                                           'seen.' % list(missing))
+            for key, value in fit_params.items():
+                assert_true(len(value) == len(X),
+                            'Fit parameter %s has length %d; '
+                            'expected %d.' % (key, len(value), len(X)))
 
         return self
 
-- 
GitLab