From ad16bb45c666945563a74b71025e3742d57c463f Mon Sep 17 00:00:00 2001
From: Reiichiro Nakano <reiichiro.s.nakano@gmail.com>
Date: Fri, 20 Oct 2017 06:05:43 +0900
Subject: [PATCH] [MRG+1] Fix cross_val_predict behavior for binary
 classification in decision_function (Fixes #9589) (#9593)

* fix cross_val_predict for binary classification in decision_function

* Add unit tests

* Add unit tests

* Add unit tests

* better fix

* fix conflict

* fix broken

* only calculate n_classes if one of 'decision_function', 'predict_proba', 'predict_log_proba'

* add test for SVC ovo in cross_val_predict

* flake8 fix

* fix case of ovo and imbalanced folds for binary classification

* change assert_raises to assert_raise_message for ovo case

* fix flake8 linetoo long

* add comments and clearer tests

* improve comments and error message for OvO

* fix .format error with L

* use assert_raises_regex for better error message

* raise error in decision_function special cases. change predict_log_proba missing classes to minimum numpy value

* fix broken tests due to special cases of decision_function

* add modified test for decision_function behavior that does not trigger edge cases

* fix typos

* fix typos

* escape regex .

* escape regex .

* address comments. one unaddressed comment

* simplify code

* flake

* wrong classes range

* address comments. adjust error message

* add warning

* change warning to runtimewarning

* add test for the warning

* Use assert_warns_message rather than assert_warns

Other minor fixes

* Note on class-absent replacement values

* Improve error message
---
 sklearn/model_selection/_validation.py        |  58 +++++++++-
 .../model_selection/tests/test_validation.py  | 104 ++++++++++++++++--
 2 files changed, 147 insertions(+), 15 deletions(-)

diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py
index bcdcb9f010..fdf6fa6912 100644
--- a/sklearn/model_selection/_validation.py
+++ b/sklearn/model_selection/_validation.py
@@ -644,6 +644,15 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
     predictions : ndarray
         This is the result of calling ``method``
 
+    Notes
+    -----
+    In the case that one or more classes are absent in a training portion, a
+    default score needs to be assigned to all instances for that class if
+    ``method`` produces columns per class, as in {'decision_function',
+    'predict_proba', 'predict_log_proba'}.  For ``predict_proba`` this value is
+    0.  In order to ensure finite output, we approximate negative infinity by
+    the minimum finite float value for the dtype in other cases.
+
     Examples
     --------
     >>> from sklearn import datasets, linear_model
@@ -746,12 +755,49 @@ def _fit_and_predict(estimator, X, y, train, test, verbose, fit_params,
     predictions = func(X_test)
     if method in ['decision_function', 'predict_proba', 'predict_log_proba']:
         n_classes = len(set(y))
-        predictions_ = np.zeros((_num_samples(X_test), n_classes))
-        if method == 'decision_function' and len(estimator.classes_) == 2:
-            predictions_[:, estimator.classes_[-1]] = predictions
-        else:
-            predictions_[:, estimator.classes_] = predictions
-        predictions = predictions_
+        if n_classes != len(estimator.classes_):
+            recommendation = (
+                'To fix this, use a cross-validation '
+                'technique resulting in properly '
+                'stratified folds')
+            warnings.warn('Number of classes in training fold ({}) does '
+                          'not match total number of classes ({}). '
+                          'Results may not be appropriate for your use case. '
+                          '{}'.format(len(estimator.classes_),
+                                      n_classes, recommendation),
+                          RuntimeWarning)
+            if method == 'decision_function':
+                if (predictions.ndim == 2 and
+                        predictions.shape[1] != len(estimator.classes_)):
+                    # This handles the case when the shape of predictions
+                    # does not match the number of classes used to train
+                    # it with. This case is found when sklearn.svm.SVC is
+                    # set to `decision_function_shape='ovo'`.
+                    raise ValueError('Output shape {} of {} does not match '
+                                     'number of classes ({}) in fold. '
+                                     'Irregular decision_function outputs '
+                                     'are not currently supported by '
+                                     'cross_val_predict'.format(
+                                        predictions.shape, method,
+                                        len(estimator.classes_),
+                                        recommendation))
+                if len(estimator.classes_) <= 2:
+                    # In this special case, `predictions` contains a 1D array.
+                    raise ValueError('Only {} class/es in training fold, this '
+                                     'is not supported for decision_function '
+                                     'with imbalanced folds. {}'.format(
+                                        len(estimator.classes_),
+                                        recommendation))
+
+            float_min = np.finfo(predictions.dtype).min
+            default_values = {'decision_function': float_min,
+                              'predict_log_proba': float_min,
+                              'predict_proba': 0}
+            predictions_for_all_classes = np.full((_num_samples(predictions),
+                                                   n_classes),
+                                                  default_values[method])
+            predictions_for_all_classes[:, estimator.classes_] = predictions
+            predictions = predictions_for_all_classes
     return predictions, test
 
 
diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py
index d57be1e835..b7b1dd781e 100644
--- a/sklearn/model_selection/tests/test_validation.py
+++ b/sklearn/model_selection/tests/test_validation.py
@@ -24,6 +24,7 @@ from sklearn.utils.testing import assert_less
 from sklearn.utils.testing import assert_array_almost_equal
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_warns
+from sklearn.utils.testing import assert_warns_message
 from sklearn.utils.mocking import CheckingClassifier, MockDataFrame
 
 from sklearn.model_selection import cross_val_score
@@ -44,6 +45,7 @@ from sklearn.model_selection._validation import _check_is_permutation
 from sklearn.datasets import make_regression
 from sklearn.datasets import load_boston
 from sklearn.datasets import load_iris
+from sklearn.datasets import load_digits
 from sklearn.metrics import explained_variance_score
 from sklearn.metrics import make_scorer
 from sklearn.metrics import accuracy_score
@@ -54,7 +56,7 @@ from sklearn.metrics import r2_score
 from sklearn.metrics.scorer import check_scoring
 
 from sklearn.linear_model import Ridge, LogisticRegression, SGDClassifier
-from sklearn.linear_model import PassiveAggressiveClassifier
+from sklearn.linear_model import PassiveAggressiveClassifier, RidgeClassifier
 from sklearn.neighbors import KNeighborsClassifier
 from sklearn.svm import SVC
 from sklearn.cluster import KMeans
@@ -800,6 +802,89 @@ def test_cross_val_predict():
 
     assert_raises(ValueError, cross_val_predict, est, X, y, cv=BadCV())
 
+    X, y = load_iris(return_X_y=True)
+
+    warning_message = ('Number of classes in training fold (2) does '
+                       'not match total number of classes (3). '
+                       'Results may not be appropriate for your use case.')
+    assert_warns_message(RuntimeWarning, warning_message,
+                         cross_val_predict, LogisticRegression(),
+                         X, y, method='predict_proba', cv=KFold(2))
+
+
+def test_cross_val_predict_decision_function_shape():
+    X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
+
+    preds = cross_val_predict(LogisticRegression(), X, y,
+                              method='decision_function')
+    assert_equal(preds.shape, (50,))
+
+    X, y = load_iris(return_X_y=True)
+
+    preds = cross_val_predict(LogisticRegression(), X, y,
+                              method='decision_function')
+    assert_equal(preds.shape, (150, 3))
+
+    # This specifically tests imbalanced splits for binary
+    # classification with decision_function. This is only
+    # applicable to classifiers that can be fit on a single
+    # class.
+    X = X[:100]
+    y = y[:100]
+    assert_raise_message(ValueError,
+                         'Only 1 class/es in training fold, this'
+                         ' is not supported for decision_function'
+                         ' with imbalanced folds. To fix '
+                         'this, use a cross-validation technique '
+                         'resulting in properly stratified folds',
+                         cross_val_predict, RidgeClassifier(), X, y,
+                         method='decision_function', cv=KFold(2))
+
+    X, y = load_digits(return_X_y=True)
+    est = SVC(kernel='linear', decision_function_shape='ovo')
+
+    preds = cross_val_predict(est,
+                              X, y,
+                              method='decision_function')
+    assert_equal(preds.shape, (1797, 45))
+
+    ind = np.argsort(y)
+    X, y = X[ind], y[ind]
+    assert_raises_regex(ValueError,
+                        'Output shape \(599L?, 21L?\) of decision_function '
+                        'does not match number of classes \(7\) in fold. '
+                        'Irregular decision_function .*',
+                        cross_val_predict, est, X, y,
+                        cv=KFold(n_splits=3), method='decision_function')
+
+
+def test_cross_val_predict_predict_proba_shape():
+    X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
+
+    preds = cross_val_predict(LogisticRegression(), X, y,
+                              method='predict_proba')
+    assert_equal(preds.shape, (50, 2))
+
+    X, y = load_iris(return_X_y=True)
+
+    preds = cross_val_predict(LogisticRegression(), X, y,
+                              method='predict_proba')
+    assert_equal(preds.shape, (150, 3))
+
+
+def test_cross_val_predict_predict_log_proba_shape():
+    X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
+
+    preds = cross_val_predict(LogisticRegression(), X, y,
+                              method='predict_log_proba')
+    assert_equal(preds.shape, (50, 2))
+
+    X, y = load_iris(return_X_y=True)
+
+    preds = cross_val_predict(LogisticRegression(), X, y,
+                              method='predict_log_proba')
+    assert_equal(preds.shape, (150, 3))
+
 
 def test_cross_val_predict_input_types():
     iris = load_iris()
@@ -1241,11 +1326,12 @@ def get_expected_predictions(X, y, cv, classes, est, method):
         est.fit(X[train], y[train])
         expected_predictions_ = func(X[test])
         # To avoid 2 dimensional indexing
-        exp_pred_test = np.zeros((len(test), classes))
-        if method is 'decision_function' and len(est.classes_) == 2:
-            exp_pred_test[:, est.classes_[-1]] = expected_predictions_
+        if method is 'predict_proba':
+            exp_pred_test = np.zeros((len(test), classes))
         else:
-            exp_pred_test[:, est.classes_] = expected_predictions_
+            exp_pred_test = np.full((len(test), classes),
+                                    np.finfo(expected_predictions.dtype).min)
+        exp_pred_test[:, est.classes_] = expected_predictions_
         expected_predictions[test] = exp_pred_test
 
     return expected_predictions
@@ -1253,9 +1339,9 @@ def get_expected_predictions(X, y, cv, classes, est, method):
 
 def test_cross_val_predict_class_subset():
 
-    X = np.arange(8).reshape(4, 2)
-    y = np.array([0, 0, 1, 2])
-    classes = 3
+    X = np.arange(200).reshape(100, 2)
+    y = np.array([x//10 for x in range(100)])
+    classes = 10
 
     kfold3 = KFold(n_splits=3)
     kfold4 = KFold(n_splits=4)
@@ -1283,7 +1369,7 @@ def test_cross_val_predict_class_subset():
         assert_array_almost_equal(expected_predictions, predictions)
 
         # Testing unordered labels
-        y = [1, 1, -4, 6]
+        y = shuffle(np.repeat(range(10), 10), random_state=0)
         predictions = cross_val_predict(est, X, y, method=method,
                                         cv=kfold3)
         y = le.fit_transform(y)
-- 
GitLab