From 7cc2577761011b94a34cdfebad86b5c925ced098 Mon Sep 17 00:00:00 2001
From: Alexandre Gramfort <alexandre.gramfort@inria.fr>
Date: Mon, 15 Nov 2010 18:13:04 +0100
Subject: [PATCH] ENH: adding predict_log_proba to LDA and QDA + tests to reach
 100% coverage

---
 scikits/learn/lda.py            | 16 ++++++++++++++++
 scikits/learn/qda.py            | 17 +++++++++++++++++
 scikits/learn/tests/test_lda.py |  8 +++++++-
 scikits/learn/tests/test_qda.py |  8 +++++++-
 4 files changed, 47 insertions(+), 2 deletions(-)

diff --git a/scikits/learn/lda.py b/scikits/learn/lda.py
index 0b71a4a82a..25b05d8008 100644
--- a/scikits/learn/lda.py
+++ b/scikits/learn/lda.py
@@ -222,3 +222,19 @@ class LDA(BaseEstimator, ClassifierMixin):
         # compute posterior probabilities
         return likelihood / likelihood.sum(axis=1)[:, np.newaxis]
 
+    def predict_log_proba(self, X):
+        """
+        This function returns posterior log-probabilities of classification
+        according to each class on an array of test vectors X.
+
+        Parameters
+        ----------
+        X : array-like, shape = [n_samples, n_features]
+
+        Returns
+        -------
+        C : array, shape = [n_samples, n_classes]
+        """
+        # XXX : can do better to avoid precision overflows
+        probas_ = self.predict_proba(X)
+        return np.log(probas_)
diff --git a/scikits/learn/qda.py b/scikits/learn/qda.py
index f0cf284177..c65fcd8ee1 100644
--- a/scikits/learn/qda.py
+++ b/scikits/learn/qda.py
@@ -199,3 +199,20 @@ class QDA(BaseEstimator, ClassifierMixin):
         likelihood = np.exp(values - values.min(axis=1)[:, np.newaxis])
         # compute posterior probabilities
         return likelihood / likelihood.sum(axis=1)[:, np.newaxis]
+
+    def predict_log_proba(self, X):
+        """
+        This function returns posterior log-probabilities of classification
+        according to each class on an array of test vectors X.
+
+        Parameters
+        ----------
+        X : array-like, shape = [n_samples, n_features]
+
+        Returns
+        -------
+        C : array, shape = [n_samples, n_classes]
+        """
+        # XXX : can do better to avoid precision overflows
+        probas_ = self.predict_proba(X)
+        return np.log(probas_)
diff --git a/scikits/learn/tests/test_lda.py b/scikits/learn/tests/test_lda.py
index 0fb1583f5f..5bb82f2ab7 100644
--- a/scikits/learn/tests/test_lda.py
+++ b/scikits/learn/tests/test_lda.py
@@ -1,5 +1,5 @@
 import numpy as np
-from numpy.testing import assert_array_equal
+from numpy.testing import assert_array_equal, assert_array_almost_equal
 from nose.tools import assert_true
 
 from .. import lda
@@ -29,6 +29,12 @@ def test_lda():
     y_pred1 = clf.fit(X1, y).predict(X1)
     assert_array_equal(y_pred1, y)
 
+    # Test probas estimates
+    y_proba_pred1 = clf.predict_proba(X1)
+    assert_array_equal((y_proba_pred1[:,1] > 0.5) + 1, y)
+    y_log_proba_pred1 = clf.predict_log_proba(X1)
+    assert_array_almost_equal(np.exp(y_log_proba_pred1), y_proba_pred1, 8)
+
     # Primarily test for commit 2f34950 -- "reuse" of priors
     y_pred3 = clf.fit(X, y3).predict(X)
     # LDA shouldn't be able to separate those
diff --git a/scikits/learn/tests/test_qda.py b/scikits/learn/tests/test_qda.py
index 8c699240e7..39352b2bec 100644
--- a/scikits/learn/tests/test_qda.py
+++ b/scikits/learn/tests/test_qda.py
@@ -1,5 +1,5 @@
 import numpy as np
-from numpy.testing import assert_array_equal
+from numpy.testing import assert_array_equal, assert_array_almost_equal
 from nose.tools import assert_true
 
 from .. import qda
@@ -29,6 +29,12 @@ def test_qda():
     y_pred1 = clf.fit(X1, y).predict(X1)
     assert_array_equal(y_pred1, y)
 
+    # Test probas estimates
+    y_proba_pred1 = clf.predict_proba(X1)
+    assert_array_equal((y_proba_pred1[:,1] > 0.5) + 1, y)
+    y_log_proba_pred1 = clf.predict_log_proba(X1)
+    assert_array_almost_equal(np.exp(y_log_proba_pred1), y_proba_pred1, 8)
+
     y_pred3 = clf.fit(X, y3).predict(X)
     # QDA shouldn't be able to separate those
     assert_true(np.any(y_pred3 != y3))
-- 
GitLab