diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 13efcfd6cc84dabacaea5597894222e493bd6709..a894753b0f46bb7e2a92b09972af51629b35300b 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -122,6 +122,10 @@ Decomposition, manifold learning and clustering with large datasets when ``n_components='mle'`` on Python 3 versions. :issue:`9886` by :user:`Hanmin Qin <qinhanmin2014>`. +- Fixed a bug when setting parameters on meta-estimator, involving both a + wrapped estimator and its parameter. :issue:`9999` by :user:`Marcus Voss + <marcus-voss>` and `Joel Nothman`_. + Metrics - Fixed a bug due to floating point error in :func:`metrics.roc_auc_score` with diff --git a/sklearn/base.py b/sklearn/base.py index b653b7149c373360afdc48cae308bbec0128d9d7..81c7e5dae7bcca62f45685955f3b1133720ba5fb 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -263,6 +263,7 @@ class BaseEstimator(object): nested_params[key][sub_key] = value else: setattr(self, key, value) + valid_params[key] = value for key, sub_params in nested_params.items(): valid_params[key].set_params(**sub_params) diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py index 580a4e2ecac9f826b98dedf081fed8a49549dc54..4620dcbd03604de11181446be836b5ad4aed4eeb 100644 --- a/sklearn/tests/test_base.py +++ b/sklearn/tests/test_base.py @@ -246,6 +246,14 @@ def test_set_params_passes_all_parameters(): estimator__min_samples_leaf=2) +def test_set_params_updates_valid_params(): + # Check that set_params tries to set SVC().C, not + # DecisionTreeClassifier().C + gscv = GridSearchCV(DecisionTreeClassifier(), {}) + gscv.set_params(estimator=SVC(), estimator__C=42.0) + assert gscv.estimator.C == 42.0 + + def test_score_sample_weight(): rng = np.random.RandomState(0)