From a2639ac0db4b56c24f766c5626d567eefa97d5d5 Mon Sep 17 00:00:00 2001 From: Anne-Laure Fouque <afouque@is208050.(none)> Date: Mon, 29 Nov 2010 15:52:58 +0100 Subject: [PATCH] ENH added R^2 coeff which is now used as score function in RegressorMixin Signed-off-by: Fabian Pedregosa <fabian.pedregosa@inria.fr> --- scikits/learn/base.py | 6 +++--- scikits/learn/linear_model/base.py | 4 ++-- scikits/learn/metrics.py | 19 +++++++++++++++++++ scikits/learn/tests/test_metrics.py | 12 ++++++++++-- 4 files changed, 34 insertions(+), 7 deletions(-) diff --git a/scikits/learn/base.py b/scikits/learn/base.py index 7deca23023..77d6fbf71d 100644 --- a/scikits/learn/base.py +++ b/scikits/learn/base.py @@ -10,7 +10,7 @@ import copy import numpy as np -from .metrics import explained_variance_score +from .metrics import r2_score ################################################################################ def clone(estimator, safe=True): @@ -236,7 +236,7 @@ class RegressorMixin(object): """ def score(self, X, y): - """ Returns the explained variance of the prediction + """ Returns the coefficient of determination of the prediction Parameters ---------- @@ -249,7 +249,7 @@ class RegressorMixin(object): ------- z : float """ - return explained_variance_score(y, self.predict(X)) + return r2_score(y, self.predict(X)) ################################################################################ diff --git a/scikits/learn/linear_model/base.py b/scikits/learn/linear_model/base.py index cc4d3856a6..4d3f08b41d 100644 --- a/scikits/learn/linear_model/base.py +++ b/scikits/learn/linear_model/base.py @@ -13,7 +13,7 @@ Generalized Linear models. 
import numpy as np from ..base import BaseEstimator, RegressorMixin, ClassifierMixin -from ..metrics import explained_variance_score +from ..metrics import r2_score from .sgd_fast import Hinge, Log, ModifiedHuber, SquaredLoss, Huber ### @@ -45,7 +45,7 @@ class LinearModel(BaseEstimator, RegressorMixin): def _explained_variance(self, X, y): """Compute explained variance a.k.a. r^2""" - return explained_variance_score(y, self.predict(X)) + return r2_score(y, self.predict(X)) @staticmethod def _center_data(X, y, fit_intercept): diff --git a/scikits/learn/metrics.py b/scikits/learn/metrics.py index f79b4e783c..74ac9ebdcf 100644 --- a/scikits/learn/metrics.py +++ b/scikits/learn/metrics.py @@ -512,6 +512,25 @@ def explained_variance_score(y_true, y_pred): y_pred : array-like """ return 1 - np.var(y_true - y_pred) / np.var(y_true) + + +def r2_score(y_true, y_pred): + """R^2 (coefficient of determination) regression score function + + Best possible score is 1.0, lower values are worse. + + Note: not a symmetric function. 
+ + return the R^2 score + + Parameters + ---------- + y_true : array-like + + y_pred : array-like + """ + return 1 -((y_true - y_pred)**2).sum() / ((y_true-y_true.mean())**2).sum() + ############################################################################### diff --git a/scikits/learn/tests/test_metrics.py b/scikits/learn/tests/test_metrics.py index 93fac32958..eeb603bbe5 100644 --- a/scikits/learn/tests/test_metrics.py +++ b/scikits/learn/tests/test_metrics.py @@ -2,7 +2,9 @@ import random import numpy as np import nose -from numpy.testing import assert_ +# from numpy.testing import assert_ +# numpy.testing.assert_ only exists in recent versions of numpy +from nose.tools import assert_true from numpy.testing import assert_array_almost_equal from numpy.testing import assert_array_equal from numpy.testing import assert_equal, assert_almost_equal @@ -13,6 +15,7 @@ from ..metrics import auc from ..metrics import classification_report from ..metrics import confusion_matrix from ..metrics import explained_variance_score +from ..metrics import r2_score from ..metrics import f1_score from ..metrics import mean_square_error from ..metrics import precision_recall_curve @@ -222,6 +225,9 @@ def test_losses(): assert_almost_equal(explained_variance_score(y_true, y_pred), -0.04, 2) assert_almost_equal(explained_variance_score(y_true, y_true), 1.00, 2) + assert_almost_equal(r2_score(y_true, y_pred), -0.04, 2) + assert_almost_equal(r2_score(y_true, y_true), 1.00, 2) + def test_symmetry(): """Test the symmetry of score and loss functions""" @@ -233,8 +239,10 @@ def test_symmetry(): assert_almost_equal(mean_square_error(y_true, y_pred), mean_square_error(y_pred, y_true)) # not symmetric - assert_(explained_variance_score(y_true, y_pred) != \ + assert_true(explained_variance_score(y_true, y_pred) != \ explained_variance_score(y_pred, y_true)) + assert_true(r2_score(y_true, y_pred) != \ + r2_score(y_pred, y_true)) # FIXME: precision and recall aren't symmetric either -- GitLab