Skip to content
Snippets Groups Projects
Commit 51c8c161 authored by Sam Shleifer's avatar Sam Shleifer Committed by Gael Varoquaux
Browse files

[MRG+1] RFE can raise NotFittedError (#9283)

* RFE can raise NotFittedError

* boom boom

* dont change tests

* tests pass

* remove two extra lines
parent c2dfd751
No related branches found
No related tags found
No related merge requests found
......@@ -9,6 +9,7 @@
import numpy as np
from ..utils import check_X_y, safe_sqr
from ..utils.metaestimators import if_delegate_has_method
from ..utils.validation import check_is_fitted
from ..base import BaseEstimator
from ..base import MetaEstimatorMixin
from ..base import clone
......@@ -233,6 +234,7 @@ class RFE(BaseEstimator, MetaEstimatorMixin, SelectorMixin):
y : array of shape [n_samples]
The predicted target values.
"""
check_is_fitted(self, 'estimator_')
return self.estimator_.predict(self.transform(X))
@if_delegate_has_method(delegate='estimator')
......@@ -248,21 +250,26 @@ class RFE(BaseEstimator, MetaEstimatorMixin, SelectorMixin):
y : array of shape [n_samples]
The target values.
"""
check_is_fitted(self, 'estimator_')
return self.estimator_.score(self.transform(X), y)
def _get_support_mask(self):
check_is_fitted(self, 'support_')
return self.support_
@if_delegate_has_method(delegate='estimator')
def decision_function(self, X):
check_is_fitted(self, 'estimator_')
return self.estimator_.decision_function(self.transform(X))
@if_delegate_has_method(delegate='estimator')
def predict_proba(self, X):
check_is_fitted(self, 'estimator_')
return self.estimator_.predict_proba(self.transform(X))
@if_delegate_has_method(delegate='estimator')
def predict_log_proba(self, X):
check_is_fitted(self, 'estimator_')
return self.estimator_.predict_log_proba(self.transform(X))
......
......@@ -7,11 +7,14 @@ import numpy as np
from sklearn.base import BaseEstimator
from sklearn.externals.six import iterkeys
from sklearn.datasets import make_classification
from sklearn.utils.testing import assert_true, assert_false, assert_raises
from sklearn.utils.validation import check_is_fitted
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.feature_selection import RFE, RFECV
from sklearn.ensemble import BaggingClassifier
from sklearn.exceptions import NotFittedError
class DelegatorData(object):
......@@ -64,8 +67,7 @@ def test_metaestimator_delegation():
return True
def _check_fit(self):
if not hasattr(self, 'coef_'):
raise RuntimeError('Estimator is not fit')
check_is_fitted(self, 'coef_')
@hides
def inverse_transform(self, X, *args, **kwargs):
......@@ -116,8 +118,8 @@ def test_metaestimator_delegation():
assert_true(hasattr(delegator, method),
msg="%s does not have method %r when its delegate does"
% (delegator_data.name, method))
# delegation before fit raises an exception
assert_raises(Exception, getattr(delegator, method),
# delegation before fit raises a NotFittedError
assert_raises(NotFittedError, getattr(delegator, method),
delegator_data.fit_args[0])
delegator.fit(*delegator_data.fit_args)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment