From a2639ac0db4b56c24f766c5626d567eefa97d5d5 Mon Sep 17 00:00:00 2001
From: Anne-Laure Fouque <afouque@is208050.(none)>
Date: Mon, 29 Nov 2010 15:52:58 +0100
Subject: [PATCH] ENH added R^2 coeff which is now used as score function in
 RegressorMixin

Signed-off-by: Fabian Pedregosa <fabian.pedregosa@inria.fr>
---
 scikits/learn/base.py               |  6 +++---
 scikits/learn/linear_model/base.py  |  4 ++--
 scikits/learn/metrics.py            | 19 +++++++++++++++++++
 scikits/learn/tests/test_metrics.py | 12 ++++++++++--
 4 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/scikits/learn/base.py b/scikits/learn/base.py
index 7deca23023..77d6fbf71d 100644
--- a/scikits/learn/base.py
+++ b/scikits/learn/base.py
@@ -10,7 +10,7 @@ import copy
 
 import numpy as np
 
-from .metrics import explained_variance_score
+from .metrics import r2_score
 
 ################################################################################
 def clone(estimator, safe=True):
@@ -236,7 +236,7 @@ class RegressorMixin(object):
     """
 
     def score(self, X, y):
-        """ Returns the explained variance of the prediction
+        """ Returns the coefficient of determination of the prediction
 
             Parameters
             ----------
@@ -249,7 +249,7 @@ class RegressorMixin(object):
             -------
             z : float
         """
-        return explained_variance_score(y, self.predict(X))
+        return r2_score(y, self.predict(X))
 
 
 ################################################################################
diff --git a/scikits/learn/linear_model/base.py b/scikits/learn/linear_model/base.py
index cc4d3856a6..4d3f08b41d 100644
--- a/scikits/learn/linear_model/base.py
+++ b/scikits/learn/linear_model/base.py
@@ -13,7 +13,7 @@ Generalized Linear models.
 import numpy as np
 
 from ..base import BaseEstimator, RegressorMixin, ClassifierMixin
-from ..metrics import explained_variance_score
+from ..metrics import r2_score
 from .sgd_fast import Hinge, Log, ModifiedHuber, SquaredLoss, Huber
 
 ###
@@ -45,7 +45,7 @@ class LinearModel(BaseEstimator, RegressorMixin):
 
     def _explained_variance(self, X, y):
         """Compute explained variance a.k.a. r^2"""
-        return explained_variance_score(y, self.predict(X))
+        return r2_score(y, self.predict(X))
 
     @staticmethod
     def _center_data(X, y, fit_intercept):
diff --git a/scikits/learn/metrics.py b/scikits/learn/metrics.py
index f79b4e783c..74ac9ebdcf 100644
--- a/scikits/learn/metrics.py
+++ b/scikits/learn/metrics.py
@@ -512,6 +512,25 @@ def explained_variance_score(y_true, y_pred):
     y_pred : array-like
     """
     return 1 - np.var(y_true - y_pred) / np.var(y_true)
+    
+
+def r2_score(y_true, y_pred):
+    """R^2 (coefficient of determination) regression score function
+
+    Best possible score is 1.0, lower values are worse.
+
+    Note: not a symmetric function.
+
+    Returns the R^2 score.
+
+    Parameters
+    ----------
+    y_true : array-like
+
+    y_pred : array-like
+    """
+    return 1 -((y_true - y_pred)**2).sum() / ((y_true-y_true.mean())**2).sum()
+
 
 
 ###############################################################################
diff --git a/scikits/learn/tests/test_metrics.py b/scikits/learn/tests/test_metrics.py
index 93fac32958..eeb603bbe5 100644
--- a/scikits/learn/tests/test_metrics.py
+++ b/scikits/learn/tests/test_metrics.py
@@ -2,7 +2,9 @@ import random
 import numpy as np
 import nose
 
-from numpy.testing import assert_
+# from numpy.testing import assert_
+# numpy.testing.assert_ only exists in recent versions of numpy
+from nose.tools import assert_true
 from numpy.testing import assert_array_almost_equal
 from numpy.testing import assert_array_equal
 from numpy.testing import assert_equal, assert_almost_equal
@@ -13,6 +15,7 @@ from ..metrics import auc
 from ..metrics import classification_report
 from ..metrics import confusion_matrix
 from ..metrics import explained_variance_score
+from ..metrics import r2_score
 from ..metrics import f1_score
 from ..metrics import mean_square_error
 from ..metrics import precision_recall_curve
@@ -222,6 +225,9 @@ def test_losses():
     assert_almost_equal(explained_variance_score(y_true, y_pred), -0.04, 2)
     assert_almost_equal(explained_variance_score(y_true, y_true), 1.00, 2)
 
+    assert_almost_equal(r2_score(y_true, y_pred), -0.04, 2)
+    assert_almost_equal(r2_score(y_true, y_true), 1.00, 2)
+
 
 def test_symmetry():
     """Test the symmetry of score and loss functions"""
@@ -233,8 +239,10 @@ def test_symmetry():
     assert_almost_equal(mean_square_error(y_true, y_pred),
                         mean_square_error(y_pred, y_true))
     # not symmetric
-    assert_(explained_variance_score(y_true, y_pred) != \
+    assert_true(explained_variance_score(y_true, y_pred) != \
             explained_variance_score(y_pred, y_true))
+    assert_true(r2_score(y_true, y_pred) != \
+            r2_score(y_pred, y_true))
     # FIXME: precision and recall aren't symmetric either
 
 
-- 
GitLab