From 7cc2577761011b94a34cdfebad86b5c925ced098 Mon Sep 17 00:00:00 2001 From: Alexandre Gramfort <alexandre.gramfort@inria.fr> Date: Mon, 15 Nov 2010 18:13:04 +0100 Subject: [PATCH] ENH: adding predict_log_proba to LDA and QDA + tests to reach 100% coverage --- scikits/learn/lda.py | 16 ++++++++++++++++ scikits/learn/qda.py | 17 +++++++++++++++++ scikits/learn/tests/test_lda.py | 8 +++++++- scikits/learn/tests/test_qda.py | 8 +++++++- 4 files changed, 47 insertions(+), 2 deletions(-) diff --git a/scikits/learn/lda.py b/scikits/learn/lda.py index 0b71a4a82a..25b05d8008 100644 --- a/scikits/learn/lda.py +++ b/scikits/learn/lda.py @@ -222,3 +222,19 @@ class LDA(BaseEstimator, ClassifierMixin): # compute posterior probabilities return likelihood / likelihood.sum(axis=1)[:, np.newaxis] + def predict_log_proba(self, X): + """ + This function returns posterior log-probabilities of classification + according to each class on an array of test vectors X. + + Parameters + ---------- + X : array-like, shape = [n_samples, n_features] + + Returns + ------- + C : array, shape = [n_samples, n_classes] + """ + # XXX : can do better to avoid precision overflows + probas_ = self.predict_proba(X) + return np.log(probas_) diff --git a/scikits/learn/qda.py b/scikits/learn/qda.py index f0cf284177..c65fcd8ee1 100644 --- a/scikits/learn/qda.py +++ b/scikits/learn/qda.py @@ -199,3 +199,20 @@ class QDA(BaseEstimator, ClassifierMixin): likelihood = np.exp(values - values.min(axis=1)[:, np.newaxis]) # compute posterior probabilities return likelihood / likelihood.sum(axis=1)[:, np.newaxis] + + def predict_log_proba(self, X): + """ + This function returns posterior log-probabilities of classification + according to each class on an array of test vectors X. 
+ + Parameters + ---------- + X : array-like, shape = [n_samples, n_features] + + Returns + ------- + C : array, shape = [n_samples, n_classes] + """ + # XXX : can do better to avoid precision overflows + probas_ = self.predict_proba(X) + return np.log(probas_) diff --git a/scikits/learn/tests/test_lda.py b/scikits/learn/tests/test_lda.py index 0fb1583f5f..5bb82f2ab7 100644 --- a/scikits/learn/tests/test_lda.py +++ b/scikits/learn/tests/test_lda.py @@ -1,5 +1,5 @@ import numpy as np -from numpy.testing import assert_array_equal +from numpy.testing import assert_array_equal, assert_array_almost_equal from nose.tools import assert_true from .. import lda @@ -29,6 +29,12 @@ def test_lda(): y_pred1 = clf.fit(X1, y).predict(X1) assert_array_equal(y_pred1, y) + # Test probas estimates + y_proba_pred1 = clf.predict_proba(X1) + assert_array_equal((y_proba_pred1[:,1] > 0.5) + 1, y) + y_log_proba_pred1 = clf.predict_log_proba(X1) + assert_array_almost_equal(np.exp(y_log_proba_pred1), y_proba_pred1, 8) + # Primarily test for commit 2f34950 -- "reuse" of priors y_pred3 = clf.fit(X, y3).predict(X) # LDA shouldn't be able to separate those diff --git a/scikits/learn/tests/test_qda.py b/scikits/learn/tests/test_qda.py index 8c699240e7..39352b2bec 100644 --- a/scikits/learn/tests/test_qda.py +++ b/scikits/learn/tests/test_qda.py @@ -1,5 +1,5 @@ import numpy as np -from numpy.testing import assert_array_equal +from numpy.testing import assert_array_equal, assert_array_almost_equal from nose.tools import assert_true from .. 
import qda @@ -29,6 +29,12 @@ def test_qda(): y_pred1 = clf.fit(X1, y).predict(X1) assert_array_equal(y_pred1, y) + # Test probas estimates + y_proba_pred1 = clf.predict_proba(X1) + assert_array_equal((y_proba_pred1[:,1] > 0.5) + 1, y) + y_log_proba_pred1 = clf.predict_log_proba(X1) + assert_array_almost_equal(np.exp(y_log_proba_pred1), y_proba_pred1, 8) + y_pred3 = clf.fit(X, y3).predict(X) # QDA shouldn't be able to separate those assert_true(np.any(y_pred3 != y3)) -- GitLab