diff --git a/sklearn/grid_search.py b/sklearn/grid_search.py index 202327574ee76b80fbceca8da056029d195e7d5f..0de08ee9e89f06486d0e35f0dbfba63d403c45b6 100644 --- a/sklearn/grid_search.py +++ b/sklearn/grid_search.py @@ -426,7 +426,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, ChangedBehaviorWarning) return self.scorer_(self.best_estimator_, X, y) - @if_delegate_has_method(delegate='estimator') + @if_delegate_has_method(delegate=('best_estimator_', 'estimator')) def predict(self, X): """Call predict on the estimator with the best found parameters. @@ -442,7 +442,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, """ return self.best_estimator_.predict(X) - @if_delegate_has_method(delegate='estimator') + @if_delegate_has_method(delegate=('best_estimator_', 'estimator')) def predict_proba(self, X): """Call predict_proba on the estimator with the best found parameters. @@ -458,7 +458,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, """ return self.best_estimator_.predict_proba(X) - @if_delegate_has_method(delegate='estimator') + @if_delegate_has_method(delegate=('best_estimator_', 'estimator')) def predict_log_proba(self, X): """Call predict_log_proba on the estimator with the best found parameters. @@ -474,7 +474,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, """ return self.best_estimator_.predict_log_proba(X) - @if_delegate_has_method(delegate='estimator') + @if_delegate_has_method(delegate=('best_estimator_', 'estimator')) def decision_function(self, X): """Call decision_function on the estimator with the best found parameters. @@ -490,7 +490,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, """ return self.best_estimator_.decision_function(X) - @if_delegate_has_method(delegate='estimator') + @if_delegate_has_method(delegate=('best_estimator_', 'estimator')) def transform(self, X): """Call transform on the estimator with the best found parameters. @@ -506,7 +506,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, """ return self.best_estimator_.transform(X) - @if_delegate_has_method(delegate='estimator') + @if_delegate_has_method(delegate=('best_estimator_', 'estimator')) def inverse_transform(self, Xt): """Call inverse_transform on the estimator with the best found parameters. diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index f1880555df33904e3321fa1a4db6e0c74fa2512d..7c6344c02c853e40b0a9cdadf95715452375fda0 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -426,7 +426,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, else: check_is_fitted(self, 'best_estimator_') - @if_delegate_has_method(delegate='estimator') + @if_delegate_has_method(delegate=('best_estimator_', 'estimator')) def predict(self, X): """Call predict on the estimator with the best found parameters. @@ -443,7 +443,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, self._check_is_fitted('predict') return self.best_estimator_.predict(X) - @if_delegate_has_method(delegate='estimator') + @if_delegate_has_method(delegate=('best_estimator_', 'estimator')) def predict_proba(self, X): """Call predict_proba on the estimator with the best found parameters. @@ -460,7 +460,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, self._check_is_fitted('predict_proba') return self.best_estimator_.predict_proba(X) - @if_delegate_has_method(delegate='estimator') + @if_delegate_has_method(delegate=('best_estimator_', 'estimator')) def predict_log_proba(self, X): """Call predict_log_proba on the estimator with the best found parameters. @@ -477,7 +477,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, self._check_is_fitted('predict_log_proba') return self.best_estimator_.predict_log_proba(X) - @if_delegate_has_method(delegate='estimator') + @if_delegate_has_method(delegate=('best_estimator_', 'estimator')) def decision_function(self, X): """Call decision_function on the estimator with the best found parameters. @@ -494,7 +494,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, self._check_is_fitted('decision_function') return self.best_estimator_.decision_function(X) - @if_delegate_has_method(delegate='estimator') + @if_delegate_has_method(delegate=('best_estimator_', 'estimator')) def transform(self, X): """Call transform on the estimator with the best found parameters. @@ -511,7 +511,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, self._check_is_fitted('transform') return self.best_estimator_.transform(X) - @if_delegate_has_method(delegate='estimator') + @if_delegate_has_method(delegate=('best_estimator_', 'estimator')) def inverse_transform(self, Xt): """Call inverse_transform on the estimator with the best found params. diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 141c1a21b46e98dc7e5de73bb8fca2dab3f627cb..bb21a386d35b73b001003d07cc170cbda1d3345e 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -58,6 +58,7 @@ from sklearn.metrics import make_scorer from sklearn.metrics import roc_auc_score from sklearn.preprocessing import Imputer from sklearn.pipeline import Pipeline +from sklearn.linear_model import SGDClassifier # Neither of the following two estimators inherit from BaseEstimator, @@ -967,11 +968,13 @@ def test_grid_search_failing_classifier(): refit=False, error_score=0.0) assert_warns(FitFailedWarning, gs.fit, X, y) n_candidates = len(gs.cv_results_['params']) + # Ensure that grid scores were set to zero as required for those fits # that are expected to fail. - get_cand_scores = lambda i: np.array(list( - gs.cv_results_['split%d_test_score' % s][i] - for s in range(gs.n_splits_))) + def get_cand_scores(i): + return np.array(list(gs.cv_results_['split%d_test_score' % s][i] + for s in range(gs.n_splits_))) + assert all((np.all(get_cand_scores(cand_i) == 0.0) for cand_i in range(n_candidates) if gs.cv_results_['param_parameter'][cand_i] == @@ -1028,3 +1031,33 @@ def test_parameters_sampler_replacement(): sampler = ParameterSampler(params_distribution, n_iter=7) samples = list(sampler) assert_equal(len(samples), 7) + + +def test_stochastic_gradient_loss_param(): + # Make sure the predict_proba works when loss is specified + # as one of the parameters in the param_grid. + param_grid = { + 'loss': ['log'], + } + X = np.arange(20).reshape(5, -1) + y = [0, 0, 1, 1, 1] + clf = GridSearchCV(estimator=SGDClassifier(loss='hinge'), + param_grid=param_grid) + + # When the estimator is not fitted, `predict_proba` is not available as the + # loss is 'hinge'. + assert_false(hasattr(clf, "predict_proba")) + clf.fit(X, y) + clf.predict_proba(X) + clf.predict_log_proba(X) + + # Make sure `predict_proba` is not available when setting loss=['hinge'] + # in param_grid + param_grid = { + 'loss': ['hinge'], + } + clf = GridSearchCV(estimator=SGDClassifier(loss='hinge'), + param_grid=param_grid) + assert_false(hasattr(clf, "predict_proba")) + clf.fit(X, y) + assert_false(hasattr(clf, "predict_proba")) diff --git a/sklearn/tests/test_metaestimators.py b/sklearn/tests/test_metaestimators.py index 4c6ace3d3aeb837b8d520766383bfd9824a28e96..f0f30cb91ae72d79f5fc2b0698fe7e1dfdbada12 100644 --- a/sklearn/tests/test_metaestimators.py +++ b/sklearn/tests/test_metaestimators.py @@ -39,7 +39,8 @@ DELEGATING_METAESTIMATORS = [ skip_methods=['transform', 'inverse_transform', 'score']), DelegatorData('BaggingClassifier', BaggingClassifier, skip_methods=['transform', 'inverse_transform', 'score', - 'predict_proba', 'predict_log_proba', 'predict']) + 'predict_proba', 'predict_log_proba', + 'predict']) ] diff --git a/sklearn/utils/metaestimators.py b/sklearn/utils/metaestimators.py index 9850ea50ea8a6a9e29dd4219908a8879966becb2..346064448b008f84e5a597f450e361329405dc8b 100644 --- a/sklearn/utils/metaestimators.py +++ b/sklearn/utils/metaestimators.py @@ -14,16 +14,22 @@ class _IffHasAttrDescriptor(object): """Implements a conditional property using the descriptor protocol. Using this class to create a decorator will raise an ``AttributeError`` - if the ``attribute_name`` is not present on the base object. + if none of the delegates (specified in ``delegate_names``) is an attribute + of the base object or the first found delegate does not have an attribute + ``attribute_name``. - This allows ducktyping of the decorated method based on ``attribute_name``. + This allows ducktyping of the decorated method based on + ``delegate.attribute_name``. Here ``delegate`` is the first item in + ``delegate_names`` for which ``hasattr(object, delegate) is True``. See https://docs.python.org/3/howto/descriptor.html for an explanation of descriptors. """ - def __init__(self, fn, attribute_name): + def __init__(self, fn, delegate_names, attribute_name): self.fn = fn - self.get_attribute = attrgetter(attribute_name) + self.delegate_names = delegate_names + self.attribute_name = attribute_name + # update the docstring of the descriptor update_wrapper(self, fn) @@ -32,7 +38,17 @@ class _IffHasAttrDescriptor(object): if obj is not None: # delegate only on instances, not the classes. # this is to allow access to the docstrings. - self.get_attribute(obj) + for delegate_name in self.delegate_names: + try: + delegate = attrgetter(delegate_name)(obj) + except AttributeError: + continue + else: + getattr(delegate, self.attribute_name) + break + else: + attrgetter(self.delegate_names[-1])(obj) + # lambda, but not partial, allows help() to work with update_wrapper out = lambda *args, **kwargs: self.fn(obj, *args, **kwargs) # update the docstring of the returned function @@ -46,27 +62,18 @@ def if_delegate_has_method(delegate): This enables ducktyping by hasattr returning True according to the sub-estimator. - >>> from sklearn.utils.metaestimators import if_delegate_has_method - >>> - >>> - >>> class MetaEst(object): - ... def __init__(self, sub_est): - ... self.sub_est = sub_est - ... - ... @if_delegate_has_method(delegate='sub_est') - ... def predict(self, X): - ... return self.sub_est.predict(X) - ... - >>> class HasPredict(object): - ... def predict(self, X): - ... return X.sum(axis=1) - ... - >>> class HasNoPredict(object): - ... pass - ... - >>> hasattr(MetaEst(HasPredict()), 'predict') - True - >>> hasattr(MetaEst(HasNoPredict()), 'predict') - False + Parameters + ---------- + delegate : string, list of strings or tuple of strings + Name of the sub-estimator that can be accessed as an attribute of the + base object. If a list or a tuple of names are provided, the first + sub-estimator that is an attribute of the base object will be used. + """ - return lambda fn: _IffHasAttrDescriptor(fn, '%s.%s' % (delegate, fn.__name__)) + if isinstance(delegate, list): + delegate = tuple(delegate) + if not isinstance(delegate, tuple): + delegate = (delegate,) + + return lambda fn: _IffHasAttrDescriptor(fn, delegate, + attribute_name=fn.__name__) diff --git a/sklearn/utils/tests/test_metaestimators.py b/sklearn/utils/tests/test_metaestimators.py index cb1f46ef80eb11e4e47b370895c11b511c96890c..d73c67d0d19832d26f3602c8d8c68a535bbd2a5d 100644 --- a/sklearn/utils/tests/test_metaestimators.py +++ b/sklearn/utils/tests/test_metaestimators.py @@ -1,5 +1,5 @@ +from nose.tools import assert_true, assert_false from sklearn.utils.metaestimators import if_delegate_has_method -from nose.tools import assert_true class Prefix(object): @@ -24,3 +24,57 @@ def test_delegated_docstring(): in str(MockMetaEstimator.func.__doc__)) assert_true("This is a mock delegated function" in str(MockMetaEstimator().func.__doc__)) + + +class MetaEst(object): + """A mock meta estimator""" + def __init__(self, sub_est, better_sub_est=None): + self.sub_est = sub_est + self.better_sub_est = better_sub_est + + @if_delegate_has_method(delegate='sub_est') + def predict(self): + pass + + +class MetaEstTestTuple(MetaEst): + """A mock meta estimator to test passing a tuple of delegates""" + + @if_delegate_has_method(delegate=('sub_est', 'better_sub_est')) + def predict(self): + pass + + +class MetaEstTestList(MetaEst): + """A mock meta estimator to test passing a list of delegates""" + + @if_delegate_has_method(delegate=['sub_est', 'better_sub_est']) + def predict(self): + pass + + +class HasPredict(object): + """A mock sub-estimator with predict method""" + + def predict(self): + pass + + +class HasNoPredict(object): + """A mock sub-estimator with no predict method""" + pass + + +def test_if_delegate_has_method(): + assert_true(hasattr(MetaEst(HasPredict()), 'predict')) + assert_false(hasattr(MetaEst(HasNoPredict()), 'predict')) + assert_false( + hasattr(MetaEstTestTuple(HasNoPredict(), HasNoPredict()), 'predict')) + assert_true( + hasattr(MetaEstTestTuple(HasPredict(), HasNoPredict()), 'predict')) + assert_false( + hasattr(MetaEstTestTuple(HasNoPredict(), HasPredict()), 'predict')) + assert_false( + hasattr(MetaEstTestList(HasNoPredict(), HasPredict()), 'predict')) + assert_true( + hasattr(MetaEstTestList(HasPredict(), HasPredict()), 'predict'))