From ad16bb45c666945563a74b71025e3742d57c463f Mon Sep 17 00:00:00 2001 From: Reiichiro Nakano <reiichiro.s.nakano@gmail.com> Date: Fri, 20 Oct 2017 06:05:43 +0900 Subject: [PATCH] [MRG+1] Fix cross_val_predict behavior for binary classification in decision_function (Fixes #9589) (#9593) * fix cross_val_predict for binary classification in decision_function * Add unit tests * Add unit tests * Add unit tests * better fix * fix conflict * fix broken * only calculate n_classes if one of 'decision_function', 'predict_proba', 'predict_log_proba' * add test for SVC ovo in cross_val_predict * flake8 fix * fix case of ovo and imbalanced folds for binary classification * change assert_raises to assert_raise_message for ovo case * fix flake8 line too long * add comments and clearer tests * improve comments and error message for OvO * fix .format error with L * use assert_raises_regex for better error message * raise error in decision_function special cases. change predict_log_proba missing classes to minimum numpy value * fix broken tests due to special cases of decision_function * add modified test for decision_function behavior that does not trigger edge cases * fix typos * fix typos * escape regex . * escape regex . * address comments. one unaddressed comment * simplify code * flake * wrong classes range * address comments. 
adjust error message * add warning * change warning to runtimewarning * add test for the warning * Use assert_warns_message rather than assert_warns Other minor fixes * Note on class-absent replacement values * Improve error message --- sklearn/model_selection/_validation.py | 58 +++++++++- .../model_selection/tests/test_validation.py | 104 ++++++++++++++++-- 2 files changed, 147 insertions(+), 15 deletions(-) diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index bcdcb9f010..fdf6fa6912 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -644,6 +644,15 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1, predictions : ndarray This is the result of calling ``method`` + Notes + ----- + In the case that one or more classes are absent in a training portion, a + default score needs to be assigned to all instances for that class if + ``method`` produces columns per class, as in {'decision_function', + 'predict_proba', 'predict_log_proba'}. For ``predict_proba`` this value is + 0. In order to ensure finite output, we approximate negative infinity by + the minimum finite float value for the dtype in other cases. 
+ Examples -------- >>> from sklearn import datasets, linear_model @@ -746,12 +755,49 @@ def _fit_and_predict(estimator, X, y, train, test, verbose, fit_params, predictions = func(X_test) if method in ['decision_function', 'predict_proba', 'predict_log_proba']: n_classes = len(set(y)) - predictions_ = np.zeros((_num_samples(X_test), n_classes)) - if method == 'decision_function' and len(estimator.classes_) == 2: - predictions_[:, estimator.classes_[-1]] = predictions - else: - predictions_[:, estimator.classes_] = predictions - predictions = predictions_ + if n_classes != len(estimator.classes_): + recommendation = ( + 'To fix this, use a cross-validation ' + 'technique resulting in properly ' + 'stratified folds') + warnings.warn('Number of classes in training fold ({}) does ' + 'not match total number of classes ({}). ' + 'Results may not be appropriate for your use case. ' + '{}'.format(len(estimator.classes_), + n_classes, recommendation), + RuntimeWarning) + if method == 'decision_function': + if (predictions.ndim == 2 and + predictions.shape[1] != len(estimator.classes_)): + # This handles the case when the shape of predictions + # does not match the number of classes used to train + # it with. This case is found when sklearn.svm.SVC is + # set to `decision_function_shape='ovo'`. + raise ValueError('Output shape {} of {} does not match ' + 'number of classes ({}) in fold. ' + 'Irregular decision_function outputs ' + 'are not currently supported by ' + 'cross_val_predict'.format( + predictions.shape, method, + len(estimator.classes_), + recommendation)) + if len(estimator.classes_) <= 2: + # In this special case, `predictions` contains a 1D array. + raise ValueError('Only {} class/es in training fold, this ' + 'is not supported for decision_function ' + 'with imbalanced folds. 
{}'.format( + len(estimator.classes_), + recommendation)) + + float_min = np.finfo(predictions.dtype).min + default_values = {'decision_function': float_min, + 'predict_log_proba': float_min, + 'predict_proba': 0} + predictions_for_all_classes = np.full((_num_samples(predictions), + n_classes), + default_values[method]) + predictions_for_all_classes[:, estimator.classes_] = predictions + predictions = predictions_for_all_classes return predictions, test diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index d57be1e835..b7b1dd781e 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -24,6 +24,7 @@ from sklearn.utils.testing import assert_less from sklearn.utils.testing import assert_array_almost_equal from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_warns +from sklearn.utils.testing import assert_warns_message from sklearn.utils.mocking import CheckingClassifier, MockDataFrame from sklearn.model_selection import cross_val_score @@ -44,6 +45,7 @@ from sklearn.model_selection._validation import _check_is_permutation from sklearn.datasets import make_regression from sklearn.datasets import load_boston from sklearn.datasets import load_iris +from sklearn.datasets import load_digits from sklearn.metrics import explained_variance_score from sklearn.metrics import make_scorer from sklearn.metrics import accuracy_score @@ -54,7 +56,7 @@ from sklearn.metrics import r2_score from sklearn.metrics.scorer import check_scoring from sklearn.linear_model import Ridge, LogisticRegression, SGDClassifier -from sklearn.linear_model import PassiveAggressiveClassifier +from sklearn.linear_model import PassiveAggressiveClassifier, RidgeClassifier from sklearn.neighbors import KNeighborsClassifier from sklearn.svm import SVC from sklearn.cluster import KMeans @@ -800,6 +802,89 @@ def test_cross_val_predict(): 
assert_raises(ValueError, cross_val_predict, est, X, y, cv=BadCV()) + X, y = load_iris(return_X_y=True) + + warning_message = ('Number of classes in training fold (2) does ' + 'not match total number of classes (3). ' + 'Results may not be appropriate for your use case.') + assert_warns_message(RuntimeWarning, warning_message, + cross_val_predict, LogisticRegression(), + X, y, method='predict_proba', cv=KFold(2)) + + +def test_cross_val_predict_decision_function_shape(): + X, y = make_classification(n_classes=2, n_samples=50, random_state=0) + + preds = cross_val_predict(LogisticRegression(), X, y, + method='decision_function') + assert_equal(preds.shape, (50,)) + + X, y = load_iris(return_X_y=True) + + preds = cross_val_predict(LogisticRegression(), X, y, + method='decision_function') + assert_equal(preds.shape, (150, 3)) + + # This specifically tests imbalanced splits for binary + # classification with decision_function. This is only + # applicable to classifiers that can be fit on a single + # class. + X = X[:100] + y = y[:100] + assert_raise_message(ValueError, + 'Only 1 class/es in training fold, this' + ' is not supported for decision_function' + ' with imbalanced folds. To fix ' + 'this, use a cross-validation technique ' + 'resulting in properly stratified folds', + cross_val_predict, RidgeClassifier(), X, y, + method='decision_function', cv=KFold(2)) + + X, y = load_digits(return_X_y=True) + est = SVC(kernel='linear', decision_function_shape='ovo') + + preds = cross_val_predict(est, + X, y, + method='decision_function') + assert_equal(preds.shape, (1797, 45)) + + ind = np.argsort(y) + X, y = X[ind], y[ind] + assert_raises_regex(ValueError, + 'Output shape \(599L?, 21L?\) of decision_function ' + 'does not match number of classes \(7\) in fold. 
' + 'Irregular decision_function .*', + cross_val_predict, est, X, y, + cv=KFold(n_splits=3), method='decision_function') + + +def test_cross_val_predict_predict_proba_shape(): + X, y = make_classification(n_classes=2, n_samples=50, random_state=0) + + preds = cross_val_predict(LogisticRegression(), X, y, + method='predict_proba') + assert_equal(preds.shape, (50, 2)) + + X, y = load_iris(return_X_y=True) + + preds = cross_val_predict(LogisticRegression(), X, y, + method='predict_proba') + assert_equal(preds.shape, (150, 3)) + + +def test_cross_val_predict_predict_log_proba_shape(): + X, y = make_classification(n_classes=2, n_samples=50, random_state=0) + + preds = cross_val_predict(LogisticRegression(), X, y, + method='predict_log_proba') + assert_equal(preds.shape, (50, 2)) + + X, y = load_iris(return_X_y=True) + + preds = cross_val_predict(LogisticRegression(), X, y, + method='predict_log_proba') + assert_equal(preds.shape, (150, 3)) + def test_cross_val_predict_input_types(): iris = load_iris() @@ -1241,11 +1326,12 @@ def get_expected_predictions(X, y, cv, classes, est, method): est.fit(X[train], y[train]) expected_predictions_ = func(X[test]) # To avoid 2 dimensional indexing - exp_pred_test = np.zeros((len(test), classes)) - if method is 'decision_function' and len(est.classes_) == 2: - exp_pred_test[:, est.classes_[-1]] = expected_predictions_ + if method is 'predict_proba': + exp_pred_test = np.zeros((len(test), classes)) else: - exp_pred_test[:, est.classes_] = expected_predictions_ + exp_pred_test = np.full((len(test), classes), + np.finfo(expected_predictions.dtype).min) + exp_pred_test[:, est.classes_] = expected_predictions_ expected_predictions[test] = exp_pred_test return expected_predictions @@ -1253,9 +1339,9 @@ def get_expected_predictions(X, y, cv, classes, est, method): def test_cross_val_predict_class_subset(): - X = np.arange(8).reshape(4, 2) - y = np.array([0, 0, 1, 2]) - classes = 3 + X = np.arange(200).reshape(100, 2) + y = np.array([x//10 
for x in range(100)]) + classes = 10 kfold3 = KFold(n_splits=3) kfold4 = KFold(n_splits=4) @@ -1283,7 +1369,7 @@ def test_cross_val_predict_class_subset(): assert_array_almost_equal(expected_predictions, predictions) # Testing unordered labels - y = [1, 1, -4, 6] + y = shuffle(np.repeat(range(10), 10), random_state=0) predictions = cross_val_predict(est, X, y, method=method, cv=kfold3) y = le.fit_transform(y) -- GitLab