diff --git a/doc/whats_new.rst b/doc/whats_new.rst index eb661dacae87abf3cb5f44aae050cdc4d9e062fb..92d140cd0b19fc7e576c35d76ecfe72e93cdc56f 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -389,6 +389,10 @@ Bug fixes classes, and some values proposed in the docstring could raise errors. :issue:`5359` by `Tom Dupre la Tour`_. + - Fixed a bug where :func:`model_selection.validation_curve` + reused the same estimator for each parameter value. + :issue:`7365` by :user:`Aleksandr Sandrovskii <Sundrique>`. + API changes summary ------------------- diff --git a/sklearn/learning_curve.py b/sklearn/learning_curve.py index 0cfe4c3cad031f9ff49eb8b80e9dee3c83a9226b..cfe1aba4ea178c3dfcce5fe45b49d3a8f19b6f7f 100644 --- a/sklearn/learning_curve.py +++ b/sklearn/learning_curve.py @@ -348,7 +348,7 @@ def validation_curve(estimator, X, y, param_name, param_range, cv=None, parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch, verbose=verbose) out = parallel(delayed(_fit_and_score)( - estimator, X, y, scorer, train, test, verbose, + clone(estimator), X, y, scorer, train, test, verbose, parameters={param_name: v}, fit_params=None, return_train_score=True) for train, test in cv for v in param_range) diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index db830619567d8e5e7c889b2357faa8d76503b8af..61a9039114fa69c18a3a04879600695b9181ea98 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -988,7 +988,7 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None, parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch, verbose=verbose) out = parallel(delayed(_fit_and_score)( - estimator, X, y, scorer, train, test, verbose, + clone(estimator), X, y, scorer, train, test, verbose, parameters={param_name: v}, fit_params=None, return_train_score=True) # NOTE do not change order of iteration to allow one time cv splitters for train, test in cv.split(X, y, groups) for v in 
param_range) diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index c05b25ce67f12ab85994c32ba79af44bd130f84b..5817c31f5f99a60c0352bf2ef48f36f329f47fa1 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -133,6 +133,22 @@ class MockEstimatorWithParameter(BaseEstimator): return X is self.X_subset +class MockEstimatorWithSingleFitCallAllowed(MockEstimatorWithParameter): + """Dummy classifier that fails if fit is called more than once""" + + def fit(self, X_subset, y_subset): + assert_false( + hasattr(self, 'fit_called_'), + 'fit is called the second time' + ) + self.fit_called_ = True + return super(MockEstimatorWithSingleFitCallAllowed, + self).fit(X_subset, y_subset) + + def predict(self, X): + raise NotImplementedError + + class MockClassifier(object): """Dummy classifier to test the cross-validation""" @@ -852,6 +868,18 @@ def test_validation_curve(): assert_array_almost_equal(test_scores.mean(axis=1), 1 - param_range) +def test_validation_curve_clone_estimator(): + X, y = make_classification(n_samples=2, n_features=1, n_informative=1, + n_redundant=0, n_classes=2, + n_clusters_per_class=1, random_state=0) + + param_range = np.linspace(1, 0, 10) + _, _ = validation_curve( + MockEstimatorWithSingleFitCallAllowed(), X, y, + param_name="param", param_range=param_range, cv=2 + ) + + def test_validation_curve_cv_splits_consistency(): n_samples = 100 n_splits = 5 diff --git a/sklearn/tests/test_learning_curve.py b/sklearn/tests/test_learning_curve.py index 129dba52ac8313216e065d0d09e21328f46a25ba..48cb8ce0ea0b0661709c232f94536ebad373d097 100644 --- a/sklearn/tests/test_learning_curve.py +++ b/sklearn/tests/test_learning_curve.py @@ -12,6 +12,7 @@ from sklearn.utils.testing import assert_warns from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_array_almost_equal +from 
sklearn.utils.testing import assert_false from sklearn.datasets import make_classification with warnings.catch_warnings(): @@ -93,6 +94,19 @@ class MockEstimatorFailing(BaseEstimator): return None +class MockEstimatorWithSingleFitCallAllowed(MockEstimatorWithParameter): + """Dummy classifier that fails if fit is called more than once""" + + def fit(self, X_subset, y_subset): + assert_false( + hasattr(self, 'fit_called_'), + 'fit is called the second time' + ) + self.fit_called_ = True + return super(MockEstimatorWithSingleFitCallAllowed, + self).fit(X_subset, y_subset) + + def test_learning_curve(): X, y = make_classification(n_samples=30, n_features=1, n_informative=1, n_redundant=0, n_classes=2, @@ -284,3 +298,15 @@ def test_validation_curve(): assert_array_almost_equal(train_scores.mean(axis=1), param_range) assert_array_almost_equal(test_scores.mean(axis=1), 1 - param_range) + + +def test_validation_curve_clone_estimator(): + X, y = make_classification(n_samples=2, n_features=1, n_informative=1, + n_redundant=0, n_classes=2, + n_clusters_per_class=1, random_state=0) + + param_range = np.linspace(1, 0, 10) + _, _ = validation_curve( + MockEstimatorWithSingleFitCallAllowed(), X, y, + param_name="param", param_range=param_range, cv=2 + )