From 1f0815bdcff70d93f70978c20e61b1ca3b747c22 Mon Sep 17 00:00:00 2001
From: Arnaud Joly <arnaud.v.joly@gmail.com>
Date: Wed, 7 Aug 2013 10:23:37 +0200
Subject: [PATCH] ENH rename auc_score to roc_auc_score for explicitness +
 scorer consistency, fix #2096

Conflicts:
	sklearn/metrics/tests/test_metrics.py
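
A minimal usage sketch of the rename, for reviewers (example values are
borrowed from the updated docstring; the snippet is illustrative and not
part of the patch):

    import numpy as np
    from sklearn.metrics import roc_auc_score

    y_true = np.array([0, 0, 1, 1])
    y_scores = np.array([0.1, 0.4, 0.35, 0.8])

    # New, explicit name for the area under the ROC curve.
    print(roc_auc_score(y_true, y_scores))  # 0.75

    # The old name is kept as a deprecated alias; calling it emits a
    # DeprecationWarning and it is scheduled for removal in release 0.16.
    from sklearn.metrics import auc_score
    print(auc_score(y_true, y_scores))  # 0.75, plus a DeprecationWarning

    # The scorer string is unchanged: scoring="roc_auc" now maps to
    # roc_auc_score instead of auc_score.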
---
 doc/modules/classes.rst                     |  2 +-
 doc/modules/model_evaluation.rst            | 26 +++++---
 doc/whats_new.rst                           |  2 +
 sklearn/metrics/__init__.py                 |  7 +-
 sklearn/metrics/metrics.py                  | 61 +++++++++++++++--
 sklearn/metrics/scorer.py                   |  8 +--
 sklearn/metrics/tests/test_metrics.py       | 74 +++++++++++++++++----
 sklearn/metrics/tests/test_score_objects.py |  8 +--
 sklearn/tests/test_grid_search.py           |  4 +-
 9 files changed, 148 insertions(+), 44 deletions(-)

diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 987ccc548a..a577f9fe66 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -715,7 +715,6 @@ details.
 
    metrics.accuracy_score
    metrics.auc
-   metrics.auc_score
    metrics.average_precision_score
    metrics.classification_report
    metrics.confusion_matrix
@@ -730,6 +729,7 @@ details.
    metrics.precision_recall_fscore_support
    metrics.precision_score
    metrics.recall_score
+   metrics.roc_auc_score
    metrics.roc_curve
    metrics.zero_one_loss
 
diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst
index 869ff37083..bc6b68c815 100644
--- a/doc/modules/model_evaluation.rst
+++ b/doc/modules/model_evaluation.rst
@@ -56,7 +56,7 @@ Scoring                    Function
 'f1'                       :func:`sklearn.metrics.f1_score`
 'precision'                :func:`sklearn.metrics.precision_score`
 'recall'                   :func:`sklearn.metrics.recall_score`
-'roc_auc'                  :func:`sklearn.metrics.auc_score`
+'roc_auc'                  :func:`sklearn.metrics.roc_auc_score`
 
 **Clustering**
 'adjusted_rand_score'      :func:`sklearn.metrics.adjusted_rand_score`
@@ -182,11 +182,11 @@ Some of these are restricted to the binary classification case:
 .. autosummary::
    :template: function.rst
 
-   auc_score
    average_precision_score
    hinge_loss
    matthews_corrcoef
    precision_recall_curve
+   roc_auc_score
    roc_curve
 
 
@@ -268,21 +268,21 @@ and with a list of labels format:
     for an example of accuracy score usage using permutations of
     the dataset.
 
-Area under the curve (AUC)
-...........................
+Area under the ROC curve
+.........................
 
-The :func:`auc_score` function computes the 'area under the curve' (AUC) which
-is the area under the receiver operating characteristic (ROC) curve.
+The :func:`roc_auc_score` function computes the area under the receiver
+operating characteristic (ROC) curve.
 
-This function requires  the true binary value and the target scores, which can
+This function requires the true binary values and the target scores, which can
 either be probability estimates of the positive class, confidence values, or
 binary decisions.
 
   >>> import numpy as np
-  >>> from sklearn.metrics import auc_score
+  >>> from sklearn.metrics import roc_auc_score
   >>> y_true = np.array([0, 0, 1, 1])
   >>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
-  >>> auc_score(y_true, y_scores)
+  >>> roc_auc_score(y_true, y_scores)
   0.75
 
 For more information see the
@@ -812,12 +812,16 @@ Wikipedia) <http://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_:
 Here a small example of how to use the :func:`roc_curve` function::
 
     >>> import numpy as np
-    >>> from sklearn import metrics
+    >>> from sklearn.metrics import roc_curve
     >>> y = np.array([1, 1, 2, 2])
     >>> scores = np.array([0.1, 0.4, 0.35, 0.8])
-    >>> fpr, tpr, thresholds = metrics.roc_curve(y, scores, pos_label=2)
+    >>> fpr, tpr, thresholds = roc_curve(y, scores, pos_label=2)
     >>> fpr
     array([ 0. ,  0.5,  0.5,  1. ])
+    >>> tpr
+    array([ 0.5,  0.5,  1. ,  1. ])
+    >>> thresholds
+    array([ 0.8 ,  0.4 ,  0.35,  0.1 ])
 
 
 The following figure shows an example of such ROC curve.
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 84704937f1..84c002772e 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -212,6 +212,8 @@ Changelog
 API changes summary
 -------------------
 
+   - The :func:`auc_score` function has been renamed to :func:`roc_auc_score`.
+
    - Testing scikit-learn with `sklearn.test()` is deprecated. Use
      `nosetest sklearn` from the command line.
 
diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py
index f3aca8f62b..1b63fc1c8b 100644
--- a/sklearn/metrics/__init__.py
+++ b/sklearn/metrics/__init__.py
@@ -6,7 +6,7 @@ and pairwise metrics and distance computations.
 from .metrics import (accuracy_score,
                       average_precision_score,
                       auc,
-                      auc_score,
+                      roc_auc_score,
                       classification_report,
                       confusion_matrix,
                       explained_variance_score,
@@ -31,6 +31,9 @@ from .metrics import (accuracy_score,
 from .metrics import zero_one
 from .metrics import zero_one_score
 
+# Deprecated; to be removed in 0.16
+from .metrics import auc_score
+
 from .scorer import make_scorer, SCORERS
 
 from . import cluster
@@ -54,7 +57,7 @@ __all__ = ['accuracy_score',
            'adjusted_mutual_info_score',
            'adjusted_rand_score',
            'auc',
-           'auc_score',
+           'roc_auc_score',
            'average_precision_score',
            'classification_report',
            'cluster',
diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py
index 982f34e2ac..2b48f458f6 100644
--- a/sklearn/metrics/metrics.py
+++ b/sklearn/metrics/metrics.py
@@ -133,7 +133,7 @@ def auc(x, y, reorder=False):
     """Compute Area Under the Curve (AUC) using the trapezoidal rule
 
     This is a general function, given points on a curve.  For computing the
-    area under the ROC-curve, see :func:`auc_score`.
+    area under the ROC-curve, see :func:`roc_auc_score`.
 
     Parameters
     ----------
@@ -163,7 +163,10 @@ def auc(x, y, reorder=False):
 
     See also
     --------
-    auc_score : Computes the area under the ROC curve
+    roc_auc_score : Computes the area under the ROC curve
+
+    precision_recall_curve :
+        Compute precision-recall pairs for different probability thresholds
 
     """
     x, y = check_arrays(x, y)
@@ -292,7 +295,7 @@ def average_precision_score(y_true, y_score):
 
     See also
     --------
-    auc_score : Area under the ROC curve
+    roc_auc_score : Area under the ROC curve
 
     precision_recall_curve :
         Compute precision-recall pairs for different probability thresholds
@@ -310,7 +313,8 @@ def average_precision_score(y_true, y_score):
     precision, recall, thresholds = precision_recall_curve(y_true, y_score)
     return auc(recall, precision)
 
-
+@deprecated("Function 'auc_score' has been renamed to "
+            "'roc_auc_score' and will be removed in release 0.16.")
 def auc_score(y_true, y_score):
     """Compute Area Under the Curve (AUC) from prediction scores
 
@@ -344,10 +348,53 @@ def auc_score(y_true, y_score):
     Examples
     --------
     >>> import numpy as np
-    >>> from sklearn.metrics import auc_score
+    >>> from sklearn.metrics import roc_auc_score
+    >>> y_true = np.array([0, 0, 1, 1])
+    >>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
+    >>> roc_auc_score(y_true, y_scores)
+    0.75
+
+    """
+    return roc_auc_score(y_true, y_score)
+
+
+def roc_auc_score(y_true, y_score):
+    """Compute Area Under the Curve (AUC) from prediction scores
+
+    Note: this implementation is restricted to the binary classification task.
+
+    Parameters
+    ----------
+
+    y_true : array, shape = [n_samples]
+        True binary labels.
+
+    y_score : array, shape = [n_samples]
+        Target scores, can either be probability estimates of the positive
+        class, confidence values, or binary decisions.
+
+    Returns
+    -------
+    auc : float
+
+    References
+    ----------
+    .. [1] `Wikipedia entry for the Receiver operating characteristic
+            <http://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_
+
+    See also
+    --------
+    average_precision_score : Area under the precision-recall curve
+
+    roc_curve : Compute Receiver operating characteristic (ROC)
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from sklearn.metrics import roc_auc_score
     >>> y_true = np.array([0, 0, 1, 1])
     >>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
-    >>> auc_score(y_true, y_scores)
+    >>> roc_auc_score(y_true, y_scores)
     0.75
 
     """
@@ -593,7 +640,7 @@ def roc_curve(y_true, y_score, pos_label=None):
 
     See also
     --------
-    auc_score : Compute Area Under the Curve (AUC) from prediction scores
+    roc_auc_score : Compute Area Under the Curve (AUC) from prediction scores
 
     Notes
     -----
diff --git a/sklearn/metrics/scorer.py b/sklearn/metrics/scorer.py
index fb4a55b36b..e5c829cad3 100644
--- a/sklearn/metrics/scorer.py
+++ b/sklearn/metrics/scorer.py
@@ -23,7 +23,7 @@ from warnings import warn
 import numpy as np
 
 from . import (r2_score, mean_squared_error, accuracy_score, f1_score,
-               auc_score, average_precision_score, precision_score,
+               roc_auc_score, average_precision_score, precision_score,
                recall_score, log_loss)
 
 from .cluster import adjusted_rand_score
@@ -253,8 +253,8 @@ accuracy_scorer = make_scorer(accuracy_score)
 f1_scorer = make_scorer(f1_score)
 
 # Score functions that need decision values
-auc_scorer = make_scorer(auc_score, greater_is_better=True,
-                         needs_threshold=True)
+roc_auc_scorer = make_scorer(roc_auc_score, greater_is_better=True,
+                             needs_threshold=True)
 average_precision_scorer = make_scorer(average_precision_score,
                                        needs_threshold=True)
 precision_scorer = make_scorer(precision_score)
@@ -269,7 +269,7 @@ adjusted_rand_scorer = make_scorer(adjusted_rand_score)
 
 SCORERS = dict(r2=r2_scorer,
                mean_squared_error=mean_squared_error_scorer,
-               accuracy=accuracy_scorer, f1=f1_scorer, roc_auc=auc_scorer,
+               accuracy=accuracy_scorer, f1=f1_scorer, roc_auc=roc_auc_scorer,
                average_precision=average_precision_scorer,
                precision=precision_scorer, recall=recall_scorer,
                log_loss=log_loss_scorer,
diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py
index aef165e291..5a7dec3783 100644
--- a/sklearn/metrics/tests/test_metrics.py
+++ b/sklearn/metrics/tests/test_metrics.py
@@ -46,6 +46,7 @@ from sklearn.metrics import (accuracy_score,
                              precision_score,
                              recall_score,
                              r2_score,
+                             roc_auc_score,
                              roc_curve,
                              zero_one,
                              zero_one_score,
@@ -106,7 +107,7 @@ CLASSIFICATION_METRICS = {
 }
 
 THRESHOLDED_METRICS = {
-    "auc_score": auc_score,
+    "roc_auc_score": roc_auc_score,
     "average_precision_score": average_precision_score,
 }
 
@@ -299,14 +300,33 @@ def make_prediction(dataset=None, binary=False):
     return y_true, y_pred, probas_pred
 
 
+def _auc(y_true, y_score):
+    pos_label = np.unique(y_true)[1]
+
+    # Count the number of times positive samples are correctly ranked above
+    # negative samples.
+    pos = y_score[y_true == pos_label]
+    neg = y_score[y_true != pos_label]
+    diff_matrix = pos.reshape(1, -1) - neg.reshape(-1, 1)
+    n_correct = np.sum(diff_matrix > 0)
+
+    return n_correct / float(len(pos) * len(neg))
+
+
 def test_roc_curve():
     """Test Area under Receiver Operating Characteristic (ROC) curve"""
     y_true, _, probas_pred = make_prediction(binary=True)
 
     fpr, tpr, thresholds = roc_curve(y_true, probas_pred)
     roc_auc = auc(fpr, tpr)
-    assert_array_almost_equal(roc_auc, 0.90, decimal=2)
-    assert_almost_equal(roc_auc, auc_score(y_true, probas_pred))
+    expected_auc = _auc(y_true, probas_pred)
+    assert_array_almost_equal(roc_auc, expected_auc, decimal=2)
+    assert_almost_equal(roc_auc, roc_auc_score(y_true, probas_pred))
+
+    with warnings.catch_warnings(record=True):
+        assert_almost_equal(roc_auc, auc_score(y_true, probas_pred))
+
+
     assert_equal(fpr.shape, tpr.shape)
     assert_equal(fpr.shape, thresholds.shape)
 
@@ -461,7 +481,7 @@ def test_auc_errors():
 
 
 def test_auc_score_non_binary_class():
-    """Test that auc_score function returns an error when trying to compute AUC
+    """Test that roc_auc_score function returns an error when trying to compute AUC
     for non-binary class values.
     """
     rng = check_random_state(404)
@@ -469,18 +489,39 @@ def test_auc_score_non_binary_class():
     # y_true contains only one class value
     y_true = np.zeros(10, dtype="int")
     assert_raise_message(ValueError, "AUC is defined for binary "
-                         "classification only", auc_score, y_true, y_pred)
+                         "classification only", roc_auc_score, y_true, y_pred)
     y_true = np.ones(10, dtype="int")
     assert_raise_message(ValueError, "AUC is defined for binary "
-                         "classification only", auc_score, y_true, y_pred)
+                         "classification only", roc_auc_score, y_true, y_pred)
     y_true = -np.ones(10, dtype="int")
     assert_raise_message(ValueError, "AUC is defined for binary "
-                         "classification only", auc_score, y_true, y_pred)
+                         "classification only", roc_auc_score, y_true, y_pred)
     # y_true contains three different class values
     y_true = rng.randint(0, 3, size=10)
     assert_raise_message(ValueError, "AUC is defined for binary "
-                         "classification only", auc_score, y_true, y_pred)
+                         "classification only", roc_auc_score, y_true, y_pred)
 
+    with warnings.catch_warnings(record=True):
+        rng = check_random_state(404)
+        y_pred = rng.rand(10)
+        # y_true contains only one class value
+        y_true = np.zeros(10, dtype="int")
+        assert_raise_message(ValueError, "AUC is defined for binary "
+                             "classification only", auc_score,
+                             y_true, y_pred)
+        y_true = np.ones(10, dtype="int")
+        assert_raise_message(ValueError, "AUC is defined for binary "
+                             "classification only", auc_score, y_true,
+                             y_pred)
+        y_true = -np.ones(10, dtype="int")
+        assert_raise_message(ValueError, "AUC is defined for binary "
+                             "classification only", auc_score, y_true,
+                             y_pred)
+        # y_true contains three different class values
+        y_true = rng.randint(0, 3, size=10)
+        assert_raise_message(ValueError, "AUC is defined for binary "
+                             "classification only", auc_score, y_true,
+                             y_pred)
 
 def test_precision_recall_f1_score_binary():
     """Test Precision Recall and F1 Score for binary classification task"""
@@ -871,16 +912,23 @@ def test_precision_recall_curve_errors():
 
 
 def test_score_scale_invariance():
-    # Test that average_precision_score and auc_score are invariant by
+    # Test that average_precision_score and roc_auc_score are invariant to
     # the scaling or shifting of probabilities
     y_true, _, probas_pred = make_prediction(binary=True)
 
-    roc_auc = auc_score(y_true, probas_pred)
-    roc_auc_scaled = auc_score(y_true, 100 * probas_pred)
-    roc_auc_shifted = auc_score(y_true, probas_pred - 10)
+    roc_auc = roc_auc_score(y_true, probas_pred)
+    roc_auc_scaled = roc_auc_score(y_true, 100 * probas_pred)
+    roc_auc_shifted = roc_auc_score(y_true, probas_pred - 10)
     assert_equal(roc_auc, roc_auc_scaled)
     assert_equal(roc_auc, roc_auc_shifted)
 
+    with warnings.catch_warnings(record=True):
+        roc_auc = auc_score(y_true, probas_pred)
+        roc_auc_scaled = auc_score(y_true, 100 * probas_pred)
+        roc_auc_shifted = auc_score(y_true, probas_pred - 10)
+        assert_equal(roc_auc, roc_auc_scaled)
+        assert_equal(roc_auc, roc_auc_shifted)
+
     pr_auc = average_precision_score(y_true, probas_pred)
     pr_auc_scaled = average_precision_score(y_true, 100 * probas_pred)
     pr_auc_shifted = average_precision_score(y_true, probas_pred - 10)
@@ -912,7 +960,7 @@ def test_losses():
                  1 - zero_one_loss(y_true, y_pred))
 
     with warnings.catch_warnings(record=True):
-    # Throw deprecated warning
+        # zero_one_score is deprecated and raises a DeprecationWarning
         assert_equal(zero_one_score(y_true, y_pred),
                      1 - zero_one_loss(y_true, y_pred))
 
diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py
index 8c4eb4e275..040a24fafb 100644
--- a/sklearn/metrics/tests/test_score_objects.py
+++ b/sklearn/metrics/tests/test_score_objects.py
@@ -3,7 +3,7 @@ import pickle
 from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_raises
 
-from sklearn.metrics import (f1_score, r2_score, auc_score, fbeta_score,
+from sklearn.metrics import (f1_score, r2_score, roc_auc_score, fbeta_score,
                              log_loss)
 from sklearn.metrics.cluster import adjusted_rand_score
 from sklearn.metrics import make_scorer, SCORERS
@@ -67,8 +67,8 @@ def test_thresholded_scorers():
     clf = LogisticRegression(random_state=0)
     clf.fit(X_train, y_train)
     score1 = SCORERS['roc_auc'](clf, X_test, y_test)
-    score2 = auc_score(y_test, clf.decision_function(X_test))
-    score3 = auc_score(y_test, clf.predict_proba(X_test)[:, 1])
+    score2 = roc_auc_score(y_test, clf.decision_function(X_test))
+    score3 = roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1])
     assert_almost_equal(score1, score2)
     assert_almost_equal(score1, score3)
 
@@ -80,7 +80,7 @@ def test_thresholded_scorers():
     clf = DecisionTreeClassifier()
     clf.fit(X_train, y_train)
     score1 = SCORERS['roc_auc'](clf, X_test, y_test)
-    score2 = auc_score(y_test, clf.predict_proba(X_test)[:, 1])
+    score2 = roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1])
     assert_almost_equal(score1, score2)
 
     # Test that an exception is raised on more than two classes
diff --git a/sklearn/tests/test_grid_search.py b/sklearn/tests/test_grid_search.py
index c7c3a70c7d..c932b7bb10 100644
--- a/sklearn/tests/test_grid_search.py
+++ b/sklearn/tests/test_grid_search.py
@@ -36,7 +36,7 @@ from sklearn.tree import DecisionTreeClassifier
 from sklearn.cluster import KMeans, MeanShift
 from sklearn.metrics import f1_score
 from sklearn.metrics import make_scorer
-from sklearn.metrics import auc_score
+from sklearn.metrics import roc_auc_score
 from sklearn.cross_validation import KFold, StratifiedKFold
 
 
@@ -572,7 +572,7 @@ def test_grid_search_score_consistency():
                 if score == "f1":
                     correct_score = f1_score(y[test], clf.predict(X[test]))
                 elif score == "roc_auc":
-                    correct_score = auc_score(y[test],
+                    correct_score = roc_auc_score(y[test],
                                               clf.decision_function(X[test]))
                 assert_almost_equal(correct_score, scores[i])
                 i += 1
-- 
GitLab