diff --git a/scikits/learn/base.py b/scikits/learn/base.py
index dab6fa7d10aa991608af76b1672daaf05659184e..58db218cf48e917d64ad48a388ae92a81ade86fe 100644
--- a/scikits/learn/base.py
+++ b/scikits/learn/base.py
@@ -10,7 +10,7 @@ import copy
 
 import numpy as np
 
-from .metrics import explained_variance
+from .metrics import explained_variance_score
 
 ################################################################################
 def clone(estimator, safe=True):
@@ -242,7 +242,7 @@ class RegressorMixin(object):
         -------
         z : float
         """
-        return explained_variance(y, self.predict(X))
+        return explained_variance_score(y, self.predict(X))
 
 
 ################################################################################
diff --git a/scikits/learn/glm/base.py b/scikits/learn/glm/base.py
index 65f7dfe70e80d821248d9d022e3d35d0747d522d..12e02b633413e7bd9d7bafe486f24bac14f71734 100644
--- a/scikits/learn/glm/base.py
+++ b/scikits/learn/glm/base.py
@@ -13,6 +13,7 @@ Generalized Linear models.
 import numpy as np
 
 from ..base import BaseEstimator, RegressorMixin
+from ..metrics import explained_variance_score
 
 ###
 ### TODO: intercept for all models
@@ -43,9 +44,7 @@ class LinearModel(BaseEstimator, RegressorMixin):
 
     def _explained_variance(self, X, y):
         """Compute explained variance a.k.a. r^2"""
-        ## TODO: this should have a tests.
-        return 1 - np.linalg.norm(y - self.predict(X))**2 \
-                 / np.linalg.norm(y)**2
+        return explained_variance_score(y, self.predict(X))
 
     def _center_data (self, X, y):
         """
@@ -64,7 +63,6 @@ class LinearModel(BaseEstimator, RegressorMixin):
             ymean = 0.
         return X, y, Xmean, ymean
 
-
     def _set_intercept(self, Xmean, ymean):
         """Set the intercept_
         """
@@ -73,7 +71,6 @@ class LinearModel(BaseEstimator, RegressorMixin):
         else:
             self.intercept_ = 0
 
-
     def __str__(self):
         if self.coef_ is not None:
             return ("%s \n%s #... Fitted: explained variance=%s" %
diff --git a/scikits/learn/metrics.py b/scikits/learn/metrics.py
index ab5401067a008018c5848d29dc0219c345afd741..244f25a891acff4565f59d117202506f3260e426 100644
--- a/scikits/learn/metrics.py
+++ b/scikits/learn/metrics.py
@@ -505,7 +505,7 @@ def mean_square_error(y_true, y_pred):
     return np.linalg.norm(y_pred - y_true) ** 2
 
 
-def explained_variance(y_true, y_pred):
+def explained_variance_score(y_true, y_pred):
     """Explained variance regression loss
 
     Best possible score is 1.0, lower values are worst.
diff --git a/scikits/learn/tests/test_metrics.py b/scikits/learn/tests/test_metrics.py
index bd7a85f919286699d534eaefe5f5945c085a3730..5b7782bdc6e311fb7adf78e4dfd83972e118f285 100644
--- a/scikits/learn/tests/test_metrics.py
+++ b/scikits/learn/tests/test_metrics.py
@@ -12,7 +12,7 @@ from .. import svm
 from ..metrics import auc
 from ..metrics import classification_report
 from ..metrics import confusion_matrix
-from ..metrics import explained_variance
+from ..metrics import explained_variance_score
 from ..metrics import f1_score
 from ..metrics import mean_square_error
 from ..metrics import precision_recall_curve
@@ -209,8 +209,8 @@ def test_losses():
     assert_almost_equal(mean_square_error(y_true, y_pred), 12.999, 2)
     assert_almost_equal(mean_square_error(y_true, y_true), 0.00, 2)
 
-    assert_almost_equal(explained_variance(y_true, y_pred), -0.04, 2)
-    assert_almost_equal(explained_variance(y_true, y_true), 1.00, 2)
+    assert_almost_equal(explained_variance_score(y_true, y_pred), -0.04, 2)
+    assert_almost_equal(explained_variance_score(y_true, y_true), 1.00, 2)
 
 
 def test_symmetry():
@@ -223,8 +223,8 @@ def test_symmetry():
     assert_almost_equal(mean_square_error(y_true, y_pred),
                         mean_square_error(y_pred, y_true))
     # not symmetric
-    assert_(explained_variance(y_true, y_pred) != \
-            explained_variance(y_pred, y_true))
+    assert_(explained_variance_score(y_true, y_pred) != \
+            explained_variance_score(y_pred, y_true))
     # FIXME: precision and recall aren't symmetric either
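A minimal usage sketch of the renamed scorer, in case it helps review. It assumes the scikits.learn package from this patch is importable under that name; the array values are illustrative only and are not taken from the test suite.

    import numpy as np

    from scikits.learn.metrics import explained_variance_score

    # Hypothetical regression targets and predictions (illustrative values only).
    y_true = np.array([3.0, -0.5, 2.0, 7.0])
    y_pred = np.array([2.5, 0.0, 2.0, 8.0])

    # A perfect prediction scores 1.0 (as asserted in test_losses);
    # imperfect predictions score lower.
    print(explained_variance_score(y_true, y_true))   # 1.0
    print(explained_variance_score(y_true, y_pred))   # < 1.0

    # The score is not symmetric in its arguments (see test_symmetry).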