From d8d7e5d4ab2cb3f2a8e4aac3836d2a2c67da6aa7 Mon Sep 17 00:00:00 2001
From: Olivier Grisel <olivier.grisel@ensta.org>
Date: Mon, 1 Nov 2010 18:44:31 +0100
Subject: [PATCH] s/explained_variance/explained_variance_score

---
 scikits/learn/base.py               |  4 ++--
 scikits/learn/glm/base.py           |  7 ++-----
 scikits/learn/metrics.py            |  2 +-
 scikits/learn/tests/test_metrics.py | 10 +++++-----
 4 files changed, 10 insertions(+), 13 deletions(-)

diff --git a/scikits/learn/base.py b/scikits/learn/base.py
index dab6fa7d10..58db218cf4 100644
--- a/scikits/learn/base.py
+++ b/scikits/learn/base.py
@@ -10,7 +10,7 @@ import copy
 
 import numpy as np
 
-from .metrics import explained_variance
+from .metrics import explained_variance_score
 
 ################################################################################
 def clone(estimator, safe=True):
@@ -242,7 +242,7 @@ class RegressorMixin(object):
         -------
         z : float
         """
-        return explained_variance(y, self.predict(X))
+        return explained_variance_score(y, self.predict(X))
 
 
 ################################################################################
diff --git a/scikits/learn/glm/base.py b/scikits/learn/glm/base.py
index 65f7dfe70e..12e02b6334 100644
--- a/scikits/learn/glm/base.py
+++ b/scikits/learn/glm/base.py
@@ -13,6 +13,7 @@ Generalized Linear models.
 import numpy as np
 
 from ..base import BaseEstimator, RegressorMixin
+from ..metrics import explained_variance_score
 
 ###
 ### TODO: intercept for all models
@@ -43,9 +44,7 @@ class LinearModel(BaseEstimator, RegressorMixin):
 
     def _explained_variance(self, X, y):
         """Compute explained variance a.k.a. r^2"""
-        ## TODO: this should have a tests.
-        return 1 - np.linalg.norm(y - self.predict(X))**2 \
-                 / np.linalg.norm(y)**2
+        return explained_variance_score(y, self.predict(X))
 
     def _center_data (self, X, y):
         """
@@ -64,7 +63,6 @@ class LinearModel(BaseEstimator, RegressorMixin):
             ymean = 0.
         return X, y, Xmean, ymean
 
-
     def _set_intercept(self, Xmean, ymean):
         """Set the intercept_
         """
@@ -73,7 +71,6 @@ class LinearModel(BaseEstimator, RegressorMixin):
         else:
             self.intercept_ = 0
 
-
     def __str__(self):
         if self.coef_ is not None:
             return ("%s \n%s #... Fitted: explained variance=%s" %
diff --git a/scikits/learn/metrics.py b/scikits/learn/metrics.py
index ab5401067a..244f25a891 100644
--- a/scikits/learn/metrics.py
+++ b/scikits/learn/metrics.py
@@ -505,7 +505,7 @@ def mean_square_error(y_true, y_pred):
     return np.linalg.norm(y_pred - y_true) ** 2
 
 
-def explained_variance(y_true, y_pred):
+def explained_variance_score(y_true, y_pred):
     """Explained variance regression loss
 
     Best possible score is 1.0, lower values are worst.
diff --git a/scikits/learn/tests/test_metrics.py b/scikits/learn/tests/test_metrics.py
index bd7a85f919..5b7782bdc6 100644
--- a/scikits/learn/tests/test_metrics.py
+++ b/scikits/learn/tests/test_metrics.py
@@ -12,7 +12,7 @@ from .. import svm
 from ..metrics import auc
 from ..metrics import classification_report
 from ..metrics import confusion_matrix
-from ..metrics import explained_variance
+from ..metrics import explained_variance_score
 from ..metrics import f1_score
 from ..metrics import mean_square_error
 from ..metrics import precision_recall_curve
@@ -209,8 +209,8 @@ def test_losses():
     assert_almost_equal(mean_square_error(y_true, y_pred), 12.999, 2)
     assert_almost_equal(mean_square_error(y_true, y_true), 0.00, 2)
 
-    assert_almost_equal(explained_variance(y_true, y_pred), -0.04, 2)
-    assert_almost_equal(explained_variance(y_true, y_true), 1.00, 2)
+    assert_almost_equal(explained_variance_score(y_true, y_pred), -0.04, 2)
+    assert_almost_equal(explained_variance_score(y_true, y_true), 1.00, 2)
 
 
 def test_symmetry():
@@ -223,8 +223,8 @@ def test_symmetry():
     assert_almost_equal(mean_square_error(y_true, y_pred),
                         mean_square_error(y_pred, y_true))
 
     # not symmetric
-    assert_(explained_variance(y_true, y_pred) != \
-            explained_variance(y_pred, y_true))
+    assert_(explained_variance_score(y_true, y_pred) != \
+            explained_variance_score(y_pred, y_true))
 
     # FIXME: precision and recall aren't symmetric either
-- 
GitLab
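
Note, not part of the patch: for readers unfamiliar with the metric being
renamed here, below is a minimal sketch of what explained_variance_score
computes, assuming the centered definition scikit-learn converged on,
1 - Var(y_true - y_pred) / Var(y_true). The function name and the sample
data are illustrative only; they are not taken from scikits/learn/metrics.py
or from test_metrics.py above.

    # Minimal sketch; assumes the centered explained-variance formula.
    import numpy as np

    def explained_variance_sketch(y_true, y_pred):
        """Best possible score is 1.0; lower values are worse."""
        return 1.0 - np.var(y_true - y_pred) / np.var(y_true)

    y_true = np.array([3.0, -0.5, 2.0, 7.0])  # illustrative data
    y_pred = np.array([2.5, 0.0, 2.0, 8.0])
    print(explained_variance_sketch(y_true, y_pred))  # ~0.957
    print(explained_variance_sketch(y_pred, y_true))  # ~0.964

The numerator Var(y_true - y_pred) is unchanged when the arguments are
swapped, but the denominator is the variance of the first argument, so the
score is not symmetric — which is the property test_symmetry() asserts.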