From 1f0815bdcff70d93f70978c20e61b1ca3b747c22 Mon Sep 17 00:00:00 2001
From: Arnaud Joly <arnaud.v.joly@gmail.com>
Date: Wed, 7 Aug 2013 10:23:37 +0200
Subject: [PATCH] ENH more explicit name for auc + consistency for scorer, fix #2096

Conflicts:
	sklearn/metrics/tests/test_metrics.py
---
 doc/modules/classes.rst                     |  2 +-
 doc/modules/model_evaluation.rst            | 26 +++++---
 doc/whats_new.rst                           |  2 +
 sklearn/metrics/__init__.py                 |  7 +-
 sklearn/metrics/metrics.py                  | 61 +++++++++++++++--
 sklearn/metrics/scorer.py                   |  8 +--
 sklearn/metrics/tests/test_metrics.py       | 74 +++++++++++++++++----
 sklearn/metrics/tests/test_score_objects.py |  8 +--
 sklearn/tests/test_grid_search.py           |  4 +-
 9 files changed, 148 insertions(+), 44 deletions(-)

diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 987ccc548a..a577f9fe66 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -715,7 +715,6 @@ details.

    metrics.accuracy_score
    metrics.auc
-   metrics.auc_score
    metrics.average_precision_score
    metrics.classification_report
    metrics.confusion_matrix
@@ -730,6 +729,7 @@ details.
    metrics.precision_recall_fscore_support
    metrics.precision_score
    metrics.recall_score
+   metrics.roc_auc_score
    metrics.roc_curve
    metrics.zero_one_loss

diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst
index 869ff37083..bc6b68c815 100644
--- a/doc/modules/model_evaluation.rst
+++ b/doc/modules/model_evaluation.rst
@@ -56,7 +56,7 @@ Scoring                     Function
 'f1'                        :func:`sklearn.metrics.f1_score`
 'precision'                 :func:`sklearn.metrics.precision_score`
 'recall'                    :func:`sklearn.metrics.recall_score`
-'roc_auc'                   :func:`sklearn.metrics.auc_score`
+'roc_auc'                   :func:`sklearn.metrics.roc_auc_score`

 **Clustering**
 'adjusted_rand_score'       :func:`sklearn.metrics.adjusted_rand_score`
@@ -182,11 +182,11 @@ Some of these are restricted to the binary classification case:
 .. autosummary::
    :template: function.rst

-   auc_score
   average_precision_score
   hinge_loss
   matthews_corrcoef
   precision_recall_curve
+   roc_auc_score
   roc_curve


@@ -268,21 +268,21 @@ and with a list of labels format:
  for an example of accuracy score usage using permutations of
  the dataset.

-Area under the curve (AUC)
-...........................
+Area under the ROC curve
+.........................

-The :func:`auc_score` function computes the 'area under the curve' (AUC) which
-is the area under the receiver operating characteristic (ROC) curve.
+The :func:`roc_auc_score` function computes the area under the receiver
+operating characteristic (ROC) curve.

-This function requires the true binary value and the target scores, which can 
+This function requires the true binary value and the target scores, which can
 either be probability estimates of the positive class, confidence values, or
 binary decisions.

   >>> import numpy as np
-  >>> from sklearn.metrics import auc_score
+  >>> from sklearn.metrics import roc_auc_score
   >>> y_true = np.array([0, 0, 1, 1])
   >>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
-  >>> auc_score(y_true, y_scores)
+  >>> roc_auc_score(y_true, y_scores)
   0.75

 For more information see the
@@ -812,12 +812,16 @@ Wikipedia) <http://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_:
 Here a small example of how to use the :func:`roc_curve` function::

     >>> import numpy as np
-    >>> from sklearn import metrics
+    >>> from sklearn.metrics import roc_curve
     >>> y = np.array([1, 1, 2, 2])
     >>> scores = np.array([0.1, 0.4, 0.35, 0.8])
-    >>> fpr, tpr, thresholds = metrics.roc_curve(y, scores, pos_label=2)
+    >>> fpr, tpr, thresholds = roc_curve(y, scores, pos_label=2)
     >>> fpr
     array([ 0. ,  0.5,  0.5,  1. ])
+    >>> tpr
+    array([ 0.5,  0.5,  1. ,  1. ])
+    >>> thresholds
+    array([ 0.8 ,  0.4 ,  0.35,  0.1 ])

 The following figure shows an example of such ROC curve.

diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 84704937f1..84c002772e 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -212,6 +212,8 @@ Changelog
 API changes summary
 -------------------

+   - The :func:`auc_score` was renamed :func:`roc_auc_score`.
+
   - Testing scikit-learn with `sklearn.test()` is deprecated. Use `nosetest
     sklearn` from the command line.

diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py
index f3aca8f62b..1b63fc1c8b 100644
--- a/sklearn/metrics/__init__.py
+++ b/sklearn/metrics/__init__.py
@@ -6,7 +6,7 @@ and pairwise metrics and distance computations.
 from .metrics import (accuracy_score,
                       average_precision_score,
                       auc,
-                      auc_score,
+                      roc_auc_score,
                       classification_report,
                       confusion_matrix,
                       explained_variance_score,
@@ -31,6 +31,9 @@ from .metrics import (accuracy_score,
 from .metrics import zero_one
 from .metrics import zero_one_score

+# Deprecated in 0.16
+from .metrics import auc_score
+
 from .scorer import make_scorer, SCORERS

 from . import cluster
@@ -54,7 +57,7 @@ __all__ = ['accuracy_score',
            'adjusted_mutual_info_score',
            'adjusted_rand_score',
            'auc',
-           'auc_score',
+           'roc_auc_score',
            'average_precision_score',
            'classification_report',
            'cluster',
diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py
index 982f34e2ac..2b48f458f6 100644
--- a/sklearn/metrics/metrics.py
+++ b/sklearn/metrics/metrics.py
@@ -133,7 +133,7 @@ def auc(x, y, reorder=False):
     """Compute Area Under the Curve (AUC) using the trapezoidal rule

     This is a general function, given points on a curve. For computing the
-    area under the ROC-curve, see :func:`auc_score`.
+    area under the ROC-curve, see :func:`roc_auc_score`.

     Parameters
     ----------
@@ -163,7 +163,10 @@ def auc(x, y, reorder=False):

     See also
     --------
-    auc_score : Computes the area under the ROC curve
+    roc_auc_score : Computes the area under the ROC curve
+
+    precision_recall_curve :
+        Compute precision-recall pairs for different probability thresholds

     """
     x, y = check_arrays(x, y)
@@ -292,7 +295,7 @@ def average_precision_score(y_true, y_score):

     See also
     --------
-    auc_score : Area under the ROC curve
+    roc_auc_score : Area under the ROC curve

     precision_recall_curve :
         Compute precision-recall pairs for different probability thresholds
@@ -310,7 +313,8 @@ def average_precision_score(y_true, y_score):
     precision, recall, thresholds = precision_recall_curve(y_true, y_score)
     return auc(recall, precision)

-
+@deprecated("Function 'auc_score' has been renamed to "
+            "'roc_auc_score' and will be removed in release 0.16.")
 def auc_score(y_true, y_score):
     """Compute Area Under the Curve (AUC) from prediction scores

@@ -344,10 +348,53 @@ def auc_score(y_true, y_score):
     Examples
     --------
     >>> import numpy as np
-    >>> from sklearn.metrics import auc_score
+    >>> from sklearn.metrics import roc_auc_score
+    >>> y_true = np.array([0, 0, 1, 1])
+    >>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
+    >>> roc_auc_score(y_true, y_scores)
+    0.75
+
+    """
+    return roc_auc_score(y_true, y_score)
+
+
+def roc_auc_score(y_true, y_score):
+    """Compute Area Under the Curve (AUC) from prediction scores
+
+    Note: this implementation is restricted to the binary classification task.
+
+    Parameters
+    ----------
+
+    y_true : array, shape = [n_samples]
+        True binary labels.
+
+    y_score : array, shape = [n_samples]
+        Target scores, can either be probability estimates of the positive
+        class, confidence values, or binary decisions.
+
+    Returns
+    -------
+    auc : float
+
+    References
+    ----------
+    .. [1] `Wikipedia entry for the Receiver operating characteristic
+            <http://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_
+
+    See also
+    --------
+    average_precision_score : Area under the precision-recall curve
+
+    roc_curve : Compute Receiver operating characteristic (ROC)
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from sklearn.metrics import roc_auc_score
     >>> y_true = np.array([0, 0, 1, 1])
     >>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
-    >>> auc_score(y_true, y_scores)
+    >>> roc_auc_score(y_true, y_scores)
     0.75

     """
@@ -593,7 +640,7 @@ def roc_curve(y_true, y_score, pos_label=None):

     See also
     --------
-    auc_score : Compute Area Under the Curve (AUC) from prediction scores
+    roc_auc_score : Compute Area Under the Curve (AUC) from prediction scores

     Notes
     -----
diff --git a/sklearn/metrics/scorer.py b/sklearn/metrics/scorer.py
index fb4a55b36b..e5c829cad3 100644
--- a/sklearn/metrics/scorer.py
+++ b/sklearn/metrics/scorer.py
@@ -23,7 +23,7 @@ from warnings import warn
 import numpy as np

 from . import (r2_score, mean_squared_error, accuracy_score, f1_score,
-               auc_score, average_precision_score, precision_score,
+               roc_auc_score, average_precision_score, precision_score,
                recall_score, log_loss)

 from .cluster import adjusted_rand_score
@@ -253,8 +253,8 @@ accuracy_scorer = make_scorer(accuracy_score)
 f1_scorer = make_scorer(f1_score)

 # Score functions that need decision values
-auc_scorer = make_scorer(auc_score, greater_is_better=True,
-                         needs_threshold=True)
+roc_auc_scorer = make_scorer(roc_auc_score, greater_is_better=True,
+                             needs_threshold=True)
 average_precision_scorer = make_scorer(average_precision_score,
                                        needs_threshold=True)
 precision_scorer = make_scorer(precision_score)
@@ -269,7 +269,7 @@ adjusted_rand_scorer = make_scorer(adjusted_rand_score)

 SCORERS = dict(r2=r2_scorer,
                mean_squared_error=mean_squared_error_scorer,
-               accuracy=accuracy_scorer, f1=f1_scorer, roc_auc=auc_scorer,
+               accuracy=accuracy_scorer, f1=f1_scorer, roc_auc=roc_auc_scorer,
                average_precision=average_precision_scorer,
                precision=precision_scorer, recall=recall_scorer,
                log_loss=log_loss_scorer,
diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py
index aef165e291..5a7dec3783 100644
--- a/sklearn/metrics/tests/test_metrics.py
+++ b/sklearn/metrics/tests/test_metrics.py
@@ -46,6 +46,7 @@ from sklearn.metrics import (accuracy_score,
                              precision_score,
                              recall_score,
                              r2_score,
+                             roc_auc_score,
                              roc_curve,
                              zero_one,
                              zero_one_score,
@@ -106,7 +107,7 @@ CLASSIFICATION_METRICS = {
 }

 THRESHOLDED_METRICS = {
-    "auc_score": auc_score,
+    "roc_auc_score": roc_auc_score,
     "average_precision_score": average_precision_score,
 }

@@ -299,14 +300,33 @@ def make_prediction(dataset=None, binary=False):
     return y_true, y_pred, probas_pred


+def _auc(y_true, y_score):
+    pos_label = np.unique(y_true)[1]
+
+    # Count the number of times positive samples are correctly ranked above
+    # negative samples.
+    pos = y_score[y_true == pos_label]
+    neg = y_score[y_true != pos_label]
+    diff_matrix = pos.reshape(1, -1) - neg.reshape(-1, 1)
+    n_correct = np.sum(diff_matrix > 0)
+
+    return n_correct / float(len(pos) * len(neg))
+
+
 def test_roc_curve():
     """Test Area under Receiver Operating Characteristic (ROC) curve"""
     y_true, _, probas_pred = make_prediction(binary=True)

     fpr, tpr, thresholds = roc_curve(y_true, probas_pred)
     roc_auc = auc(fpr, tpr)
-    assert_array_almost_equal(roc_auc, 0.90, decimal=2)
-    assert_almost_equal(roc_auc, auc_score(y_true, probas_pred))
+    expected_auc = _auc(y_true, probas_pred)
+    assert_array_almost_equal(roc_auc, expected_auc, decimal=2)
+    assert_almost_equal(roc_auc, roc_auc_score(y_true, probas_pred))
+
+    with warnings.catch_warnings(record=True):
+        assert_almost_equal(roc_auc, auc_score(y_true, probas_pred))
+
+
     assert_equal(fpr.shape, tpr.shape)
     assert_equal(fpr.shape, thresholds.shape)
@@ -461,7 +481,7 @@ def test_auc_errors():


 def test_auc_score_non_binary_class():
-    """Test that auc_score function returns an error when trying to compute AUC
+    """Test that roc_auc_score function returns an error when trying to compute AUC
     for non-binary class values.
""" rng = check_random_state(404) @@ -469,18 +489,39 @@ def test_auc_score_non_binary_class(): # y_true contains only one class value y_true = np.zeros(10, dtype="int") assert_raise_message(ValueError, "AUC is defined for binary " - "classification only", auc_score, y_true, y_pred) + "classification only", roc_auc_score, y_true, y_pred) y_true = np.ones(10, dtype="int") assert_raise_message(ValueError, "AUC is defined for binary " - "classification only", auc_score, y_true, y_pred) + "classification only", roc_auc_score, y_true, y_pred) y_true = -np.ones(10, dtype="int") assert_raise_message(ValueError, "AUC is defined for binary " - "classification only", auc_score, y_true, y_pred) + "classification only", roc_auc_score, y_true, y_pred) # y_true contains three different class values y_true = rng.randint(0, 3, size=10) assert_raise_message(ValueError, "AUC is defined for binary " - "classification only", auc_score, y_true, y_pred) + "classification only", roc_auc_score, y_true, y_pred) + with warnings.catch_warnings(record=True): + rng = check_random_state(404) + y_pred = rng.rand(10) + # y_true contains only one class value + y_true = np.zeros(10, dtype="int") + assert_raise_message(ValueError, "AUC is defined for binary " + "classification only", auc_score, + y_true, y_pred) + y_true = np.ones(10, dtype="int") + assert_raise_message(ValueError, "AUC is defined for binary " + "classification only", auc_score, y_true, + y_pred) + y_true = -np.ones(10, dtype="int") + assert_raise_message(ValueError, "AUC is defined for binary " + "classification only", auc_score, y_true, + y_pred) + # y_true contains three different class values + y_true = rng.randint(0, 3, size=10) + assert_raise_message(ValueError, "AUC is defined for binary " + "classification only", auc_score, y_true, + y_pred) def test_precision_recall_f1_score_binary(): """Test Precision Recall and F1 Score for binary classification task""" @@ -871,16 +912,23 @@ def test_precision_recall_curve_errors(): def test_score_scale_invariance(): - # Test that average_precision_score and auc_score are invariant by + # Test that average_precision_score and roc_auc_score are invariant by # the scaling or shifting of probabilities y_true, _, probas_pred = make_prediction(binary=True) - roc_auc = auc_score(y_true, probas_pred) - roc_auc_scaled = auc_score(y_true, 100 * probas_pred) - roc_auc_shifted = auc_score(y_true, probas_pred - 10) + roc_auc = roc_auc_score(y_true, probas_pred) + roc_auc_scaled = roc_auc_score(y_true, 100 * probas_pred) + roc_auc_shifted = roc_auc_score(y_true, probas_pred - 10) assert_equal(roc_auc, roc_auc_scaled) assert_equal(roc_auc, roc_auc_shifted) + with warnings.catch_warnings(): + roc_auc = auc_score(y_true, probas_pred) + roc_auc_scaled = auc_score(y_true, 100 * probas_pred) + roc_auc_shifted = auc_score(y_true, probas_pred - 10) + assert_equal(roc_auc, roc_auc_scaled) + assert_equal(roc_auc, roc_auc_shifted) + pr_auc = average_precision_score(y_true, probas_pred) pr_auc_scaled = average_precision_score(y_true, 100 * probas_pred) pr_auc_shifted = average_precision_score(y_true, probas_pred - 10) @@ -912,7 +960,7 @@ def test_losses(): 1 - zero_one_loss(y_true, y_pred)) with warnings.catch_warnings(record=True): - # Throw deprecated warning + # Throw deprecated warning assert_equal(zero_one_score(y_true, y_pred), 1 - zero_one_loss(y_true, y_pred)) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 8c4eb4e275..040a24fafb 100644 --- 
--- a/sklearn/metrics/tests/test_score_objects.py
+++ b/sklearn/metrics/tests/test_score_objects.py
@@ -3,7 +3,7 @@ import pickle

 from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_raises
-from sklearn.metrics import (f1_score, r2_score, auc_score, fbeta_score,
+from sklearn.metrics import (f1_score, r2_score, roc_auc_score, fbeta_score,
                              log_loss)
 from sklearn.metrics.cluster import adjusted_rand_score
 from sklearn.metrics import make_scorer, SCORERS
@@ -67,8 +67,8 @@ def test_thresholded_scorers():
     clf = LogisticRegression(random_state=0)
     clf.fit(X_train, y_train)
     score1 = SCORERS['roc_auc'](clf, X_test, y_test)
-    score2 = auc_score(y_test, clf.decision_function(X_test))
-    score3 = auc_score(y_test, clf.predict_proba(X_test)[:, 1])
+    score2 = roc_auc_score(y_test, clf.decision_function(X_test))
+    score3 = roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1])
     assert_almost_equal(score1, score2)
     assert_almost_equal(score1, score3)

@@ -80,7 +80,7 @@ def test_thresholded_scorers():
     clf = DecisionTreeClassifier()
     clf.fit(X_train, y_train)
     score1 = SCORERS['roc_auc'](clf, X_test, y_test)
-    score2 = auc_score(y_test, clf.predict_proba(X_test)[:, 1])
+    score2 = roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1])
     assert_almost_equal(score1, score2)

     # Test that an exception is raised on more than two classes
diff --git a/sklearn/tests/test_grid_search.py b/sklearn/tests/test_grid_search.py
index c7c3a70c7d..c932b7bb10 100644
--- a/sklearn/tests/test_grid_search.py
+++ b/sklearn/tests/test_grid_search.py
@@ -36,7 +36,7 @@ from sklearn.tree import DecisionTreeClassifier
 from sklearn.cluster import KMeans, MeanShift
 from sklearn.metrics import f1_score
 from sklearn.metrics import make_scorer
-from sklearn.metrics import auc_score
+from sklearn.metrics import roc_auc_score
 from sklearn.cross_validation import KFold, StratifiedKFold


@@ -572,7 +572,7 @@ def test_grid_search_score_consistency():
                 if score == "f1":
                     correct_score = f1_score(y[test], clf.predict(X[test]))
                 elif score == "roc_auc":
-                    correct_score = auc_score(y[test],
+                    correct_score = roc_auc_score(y[test],
                                               clf.decision_function(X[test]))
                 assert_almost_equal(correct_score, scores[i])
                 i += 1
-- 
GitLab
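
The rename is purely an API change: `roc_auc_score` returns exactly what `auc_score` returned, the `'roc_auc'` scorer string keeps working, and the old function name survives (with a deprecation warning on each call) until its removal in release 0.16. A minimal migration sketch against the patched 0.14-era API; the dataset and estimator below are illustrative choices, not part of the patch::

    import warnings
    import numpy as np

    from sklearn.cross_validation import cross_val_score  # 0.14-era import path
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import roc_auc_score

    y_true = np.array([0, 0, 1, 1])
    y_scores = np.array([0.1, 0.4, 0.35, 0.8])

    # New spelling of the metric function.
    print(roc_auc_score(y_true, y_scores))          # 0.75

    # The old spelling still imports, but calling it now emits a
    # DeprecationWarning pointing at roc_auc_score.
    with warnings.catch_warnings(record=True):
        warnings.simplefilter("always")
        from sklearn.metrics import auc_score
        print(auc_score(y_true, y_scores))          # 0.75, plus a warning

    # The scorer string is untouched; it is simply backed by roc_auc_scorer now.
    X, y = make_classification(n_samples=200, random_state=0)
    clf = LogisticRegression(random_state=0)
    print(cross_val_score(clf, X, y, scoring="roc_auc", cv=5).mean())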
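test_score_objects.py relies on the `'roc_auc'` scorer feeding continuous decision values (from `decision_function`, or `predict_proba` as a fallback) to `roc_auc_score`, which is what `needs_threshold=True` in `make_scorer` arranges. A short sketch of that equivalence; the train/test split and classifier here are assumptions for illustration, not taken from the patch::

    from sklearn.cross_validation import train_test_split
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import SCORERS, roc_auc_score

    X, y = make_classification(n_samples=200, random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    clf = LogisticRegression(random_state=0).fit(X_train, y_train)

    # The scorer object passes continuous decision values to roc_auc_score ...
    score_via_scorer = SCORERS['roc_auc'](clf, X_test, y_test)
    # ... so it matches calling the metric on decision_function output directly.
    score_direct = roc_auc_score(y_test, clf.decision_function(X_test))
    assert abs(score_via_scorer - score_direct) < 1e-10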
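The new `_auc` reference helper in test_metrics.py cross-checks the trapezoidal ROC AUC against its ranking interpretation: the fraction of (positive, negative) pairs in which the positive sample receives the higher score. A self-contained sketch of that check; `pairwise_ranking_auc` is a hypothetical name, and the agreement assumes untied scores, since ties are counted as incorrect here exactly as in the test helper::

    import numpy as np
    from sklearn.metrics import auc, roc_curve, roc_auc_score

    def pairwise_ranking_auc(y_true, y_score, pos_label=1):
        # Fraction of positive/negative pairs ranked correctly, mirroring the
        # _auc reference helper added to test_metrics.py by this patch.
        pos = y_score[y_true == pos_label]
        neg = y_score[y_true != pos_label]
        n_correct = np.sum(pos.reshape(-1, 1) > neg.reshape(1, -1))
        return n_correct / float(len(pos) * len(neg))

    rng = np.random.RandomState(0)
    y_true = np.array([0] * 25 + [1] * 25)
    y_score = rng.rand(50)

    fpr, tpr, _ = roc_curve(y_true, y_score)
    print(auc(fpr, tpr))                          # trapezoidal area under the ROC curve
    print(roc_auc_score(y_true, y_score))         # the renamed metric
    print(pairwise_ranking_auc(y_true, y_score))  # ranking formulation, same value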
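The `@deprecated(...)` decorator applied to `auc_score` comes from scikit-learn's own utilities; the sketch below is only an illustration of the underlying pattern (warn on every call, then delegate to the new function) under a hypothetical name, not the actual implementation::

    import functools
    import warnings

    def deprecated_alias(new_func, old_name, removal="0.16"):
        # Hypothetical stand-in for sklearn's deprecation decorator, shown to
        # illustrate what the wrapper added in this patch does at call time.
        @functools.wraps(new_func)
        def wrapper(*args, **kwargs):
            warnings.warn("Function %r has been renamed to %r and will be "
                          "removed in release %s."
                          % (old_name, new_func.__name__, removal),
                          DeprecationWarning, stacklevel=2)
            return new_func(*args, **kwargs)
        return wrapper

    # e.g. auc_score = deprecated_alias(roc_auc_score, "auc_score")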