From 3dffa08e4349166148cd7da4f07c006c642f86d2 Mon Sep 17 00:00:00 2001
From: Yichuan Liu <yichuanliu2004@gmail.com>
Date: Mon, 12 Sep 2016 20:35:32 +0200
Subject: [PATCH] FIX 7155: GridSearchCV predict_proba delegation to
 SDGClassifier

---
 sklearn/grid_search.py                       | 12 ++--
 sklearn/model_selection/_search.py           | 12 ++--
 sklearn/model_selection/tests/test_search.py | 39 +++++++++++-
 sklearn/tests/test_metaestimators.py         |  3 +-
 sklearn/utils/metaestimators.py              | 63 +++++++++++---------
 sklearn/utils/tests/test_metaestimators.py   | 56 ++++++++++++++++-
 6 files changed, 140 insertions(+), 45 deletions(-)

diff --git a/sklearn/grid_search.py b/sklearn/grid_search.py
index 202327574e..0de08ee9e8 100644
--- a/sklearn/grid_search.py
+++ b/sklearn/grid_search.py
@@ -426,7 +426,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
                           ChangedBehaviorWarning)
         return self.scorer_(self.best_estimator_, X, y)
 
-    @if_delegate_has_method(delegate='estimator')
+    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
     def predict(self, X):
         """Call predict on the estimator with the best found parameters.
 
@@ -442,7 +442,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         """
         return self.best_estimator_.predict(X)
 
-    @if_delegate_has_method(delegate='estimator')
+    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
     def predict_proba(self, X):
         """Call predict_proba on the estimator with the best found parameters.
 
@@ -458,7 +458,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         """
         return self.best_estimator_.predict_proba(X)
 
-    @if_delegate_has_method(delegate='estimator')
+    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
     def predict_log_proba(self, X):
         """Call predict_log_proba on the estimator with the best found parameters.
 
@@ -474,7 +474,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         """
         return self.best_estimator_.predict_log_proba(X)
 
-    @if_delegate_has_method(delegate='estimator')
+    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
     def decision_function(self, X):
         """Call decision_function on the estimator with the best found parameters.
 
@@ -490,7 +490,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         """
         return self.best_estimator_.decision_function(X)
 
-    @if_delegate_has_method(delegate='estimator')
+    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
     def transform(self, X):
         """Call transform on the estimator with the best found parameters.
 
@@ -506,7 +506,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         """
         return self.best_estimator_.transform(X)
 
-    @if_delegate_has_method(delegate='estimator')
+    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
     def inverse_transform(self, Xt):
         """Call inverse_transform on the estimator with the best found parameters.
 
diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index f1880555df..7c6344c02c 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -426,7 +426,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         else:
             check_is_fitted(self, 'best_estimator_')
 
-    @if_delegate_has_method(delegate='estimator')
+    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
     def predict(self, X):
         """Call predict on the estimator with the best found parameters.
 
@@ -443,7 +443,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         self._check_is_fitted('predict')
         return self.best_estimator_.predict(X)
 
-    @if_delegate_has_method(delegate='estimator')
+    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
     def predict_proba(self, X):
         """Call predict_proba on the estimator with the best found parameters.
 
@@ -460,7 +460,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         self._check_is_fitted('predict_proba')
         return self.best_estimator_.predict_proba(X)
 
-    @if_delegate_has_method(delegate='estimator')
+    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
     def predict_log_proba(self, X):
         """Call predict_log_proba on the estimator with the best found parameters.
 
@@ -477,7 +477,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         self._check_is_fitted('predict_log_proba')
         return self.best_estimator_.predict_log_proba(X)
 
-    @if_delegate_has_method(delegate='estimator')
+    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
     def decision_function(self, X):
         """Call decision_function on the estimator with the best found parameters.
 
@@ -494,7 +494,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         self._check_is_fitted('decision_function')
         return self.best_estimator_.decision_function(X)
 
-    @if_delegate_has_method(delegate='estimator')
+    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
     def transform(self, X):
         """Call transform on the estimator with the best found parameters.
 
@@ -511,7 +511,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         self._check_is_fitted('transform')
         return self.best_estimator_.transform(X)
 
-    @if_delegate_has_method(delegate='estimator')
+    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
     def inverse_transform(self, Xt):
         """Call inverse_transform on the estimator with the best found params.
 
diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py
index 141c1a21b4..bb21a386d3 100644
--- a/sklearn/model_selection/tests/test_search.py
+++ b/sklearn/model_selection/tests/test_search.py
@@ -58,6 +58,7 @@ from sklearn.metrics import make_scorer
 from sklearn.metrics import roc_auc_score
 from sklearn.preprocessing import Imputer
 from sklearn.pipeline import Pipeline
+from sklearn.linear_model import SGDClassifier
 
 
 # Neither of the following two estimators inherit from BaseEstimator,
@@ -967,11 +968,13 @@ def test_grid_search_failing_classifier():
                       refit=False, error_score=0.0)
     assert_warns(FitFailedWarning, gs.fit, X, y)
     n_candidates = len(gs.cv_results_['params'])
+
     # Ensure that grid scores were set to zero as required for those fits
     # that are expected to fail.
-    get_cand_scores = lambda i: np.array(list(
-        gs.cv_results_['split%d_test_score' % s][i]
-        for s in range(gs.n_splits_)))
+    def get_cand_scores(i):
+        return np.array(list(gs.cv_results_['split%d_test_score' % s][i]
+                             for s in range(gs.n_splits_)))
+
     assert all((np.all(get_cand_scores(cand_i) == 0.0)
                 for cand_i in range(n_candidates)
                 if gs.cv_results_['param_parameter'][cand_i] ==
@@ -1028,3 +1031,33 @@ def test_parameters_sampler_replacement():
     sampler = ParameterSampler(params_distribution, n_iter=7)
     samples = list(sampler)
     assert_equal(len(samples), 7)
+
+
+def test_stochastic_gradient_loss_param():
+    # Make sure the predict_proba works when loss is specified
+    # as one of the parameters in the param_grid.
+    param_grid = {
+        'loss': ['log'],
+    }
+    X = np.arange(20).reshape(5, -1)
+    y = [0, 0, 1, 1, 1]
+    clf = GridSearchCV(estimator=SGDClassifier(loss='hinge'),
+                       param_grid=param_grid)
+
+    # When the estimator is not fitted, `predict_proba` is not available as the
+    # loss is 'hinge'.
+    assert_false(hasattr(clf, "predict_proba"))
+    clf.fit(X, y)
+    clf.predict_proba(X)
+    clf.predict_log_proba(X)
+
+    # Make sure `predict_proba` is not available when setting loss=['hinge']
+    # in param_grid
+    param_grid = {
+        'loss': ['hinge'],
+    }
+    clf = GridSearchCV(estimator=SGDClassifier(loss='hinge'),
+                       param_grid=param_grid)
+    assert_false(hasattr(clf, "predict_proba"))
+    clf.fit(X, y)
+    assert_false(hasattr(clf, "predict_proba"))
diff --git a/sklearn/tests/test_metaestimators.py b/sklearn/tests/test_metaestimators.py
index 4c6ace3d3a..f0f30cb91a 100644
--- a/sklearn/tests/test_metaestimators.py
+++ b/sklearn/tests/test_metaestimators.py
@@ -39,7 +39,8 @@ DELEGATING_METAESTIMATORS = [
                   skip_methods=['transform', 'inverse_transform', 'score']),
     DelegatorData('BaggingClassifier', BaggingClassifier,
                   skip_methods=['transform', 'inverse_transform', 'score',
-                                'predict_proba', 'predict_log_proba', 'predict'])
+                                'predict_proba', 'predict_log_proba',
+                                'predict'])
 ]
 
 
diff --git a/sklearn/utils/metaestimators.py b/sklearn/utils/metaestimators.py
index 9850ea50ea..346064448b 100644
--- a/sklearn/utils/metaestimators.py
+++ b/sklearn/utils/metaestimators.py
@@ -14,16 +14,22 @@ class _IffHasAttrDescriptor(object):
     """Implements a conditional property using the descriptor protocol.
 
     Using this class to create a decorator will raise an ``AttributeError``
-    if the ``attribute_name`` is not present on the base object.
+    if none of the delegates (specified in ``delegate_names``) is an attribute
+    of the base object or the first found delegate does not have an attribute
+    ``attribute_name``.
 
-    This allows ducktyping of the decorated method based on ``attribute_name``.
+    This allows ducktyping of the decorated method based on
+    ``delegate.attribute_name``. Here ``delegate`` is the first item in
+    ``delegate_names`` for which ``hasattr(object, delegate) is True``.
 
     See https://docs.python.org/3/howto/descriptor.html for an explanation of
     descriptors.
     """
-    def __init__(self, fn, attribute_name):
+    def __init__(self, fn, delegate_names, attribute_name):
         self.fn = fn
-        self.get_attribute = attrgetter(attribute_name)
+        self.delegate_names = delegate_names
+        self.attribute_name = attribute_name
+
         # update the docstring of the descriptor
         update_wrapper(self, fn)
 
@@ -32,7 +38,17 @@ class _IffHasAttrDescriptor(object):
         if obj is not None:
             # delegate only on instances, not the classes.
             # this is to allow access to the docstrings.
-            self.get_attribute(obj)
+            for delegate_name in self.delegate_names:
+                try:
+                    delegate = attrgetter(delegate_name)(obj)
+                except AttributeError:
+                    continue
+                else:
+                    getattr(delegate, self.attribute_name)
+                    break
+            else:
+                attrgetter(self.delegate_names[-1])(obj)
+
         # lambda, but not partial, allows help() to work with update_wrapper
         out = lambda *args, **kwargs: self.fn(obj, *args, **kwargs)
         # update the docstring of the returned function
@@ -46,27 +62,18 @@ def if_delegate_has_method(delegate):
     This enables ducktyping by hasattr returning True according to the
     sub-estimator.
 
-    >>> from sklearn.utils.metaestimators import if_delegate_has_method
-    >>>
-    >>>
-    >>> class MetaEst(object):
-    ...     def __init__(self, sub_est):
-    ...         self.sub_est = sub_est
-    ...
-    ...     @if_delegate_has_method(delegate='sub_est')
-    ...     def predict(self, X):
-    ...         return self.sub_est.predict(X)
-    ...
-    >>> class HasPredict(object):
-    ...     def predict(self, X):
-    ...         return X.sum(axis=1)
-    ...
-    >>> class HasNoPredict(object):
-    ...     pass
-    ...
-    >>> hasattr(MetaEst(HasPredict()), 'predict')
-    True
-    >>> hasattr(MetaEst(HasNoPredict()), 'predict')
-    False
+    Parameters
+    ----------
+    delegate : string, list of strings or tuple of strings
+        Name of the sub-estimator that can be accessed as an attribute of the
+        base object. If a list or a tuple of names are provided, the first
+        sub-estimator that is an attribute of the base object  will be used.
+
     """
-    return lambda fn: _IffHasAttrDescriptor(fn, '%s.%s' % (delegate, fn.__name__))
+    if isinstance(delegate, list):
+        delegate = tuple(delegate)
+    if not isinstance(delegate, tuple):
+        delegate = (delegate,)
+
+    return lambda fn: _IffHasAttrDescriptor(fn, delegate,
+                                            attribute_name=fn.__name__)
diff --git a/sklearn/utils/tests/test_metaestimators.py b/sklearn/utils/tests/test_metaestimators.py
index cb1f46ef80..d73c67d0d1 100644
--- a/sklearn/utils/tests/test_metaestimators.py
+++ b/sklearn/utils/tests/test_metaestimators.py
@@ -1,5 +1,5 @@
+from nose.tools import assert_true, assert_false
 from sklearn.utils.metaestimators import if_delegate_has_method
-from nose.tools import assert_true
 
 
 class Prefix(object):
@@ -24,3 +24,57 @@ def test_delegated_docstring():
                 in str(MockMetaEstimator.func.__doc__))
     assert_true("This is a mock delegated function"
                 in str(MockMetaEstimator().func.__doc__))
+
+
+class MetaEst(object):
+    """A mock meta estimator"""
+    def __init__(self, sub_est, better_sub_est=None):
+        self.sub_est = sub_est
+        self.better_sub_est = better_sub_est
+
+    @if_delegate_has_method(delegate='sub_est')
+    def predict(self):
+        pass
+
+
+class MetaEstTestTuple(MetaEst):
+    """A mock meta estimator to test passing a tuple of delegates"""
+
+    @if_delegate_has_method(delegate=('sub_est', 'better_sub_est'))
+    def predict(self):
+        pass
+
+
+class MetaEstTestList(MetaEst):
+    """A mock meta estimator to test passing a list of delegates"""
+
+    @if_delegate_has_method(delegate=['sub_est', 'better_sub_est'])
+    def predict(self):
+        pass
+
+
+class HasPredict(object):
+    """A mock sub-estimator with predict method"""
+
+    def predict(self):
+        pass
+
+
+class HasNoPredict(object):
+    """A mock sub-estimator with no predict method"""
+    pass
+
+
+def test_if_delegate_has_method():
+    assert_true(hasattr(MetaEst(HasPredict()), 'predict'))
+    assert_false(hasattr(MetaEst(HasNoPredict()), 'predict'))
+    assert_false(
+        hasattr(MetaEstTestTuple(HasNoPredict(), HasNoPredict()), 'predict'))
+    assert_true(
+        hasattr(MetaEstTestTuple(HasPredict(), HasNoPredict()), 'predict'))
+    assert_false(
+        hasattr(MetaEstTestTuple(HasNoPredict(), HasPredict()), 'predict'))
+    assert_false(
+        hasattr(MetaEstTestList(HasNoPredict(), HasPredict()), 'predict'))
+    assert_true(
+        hasattr(MetaEstTestList(HasPredict(), HasPredict()), 'predict'))
-- 
GitLab