From 43e9934b45a86c376268b4f01ad7702632aef49d Mon Sep 17 00:00:00 2001
From: Jake Vanderplas <vanderplas@astro.washington.edu>
Date: Tue, 20 Dec 2011 03:10:07 -0800
Subject: [PATCH] Change logsum to logsumexp for compatibility with scipy

---
 doc/developers/utilities.rst | 14 ++++++++++----
 sklearn/hmm.py               | 16 ++++++++--------
 sklearn/lda.py               |  4 ++--
 sklearn/mixture/gmm.py       |  4 ++--
 sklearn/naive_bayes.py       |  4 ++--
 sklearn/tests/test_hmm.py    |  6 +++---
 sklearn/utils/extmath.py     |  6 +++---
 7 files changed, 30 insertions(+), 24 deletions(-)

diff --git a/doc/developers/utilities.rst b/doc/developers/utilities.rst
index 5bb4a6a8f5..05325badd1 100644
--- a/doc/developers/utilities.rst
+++ b/doc/developers/utilities.rst
@@ -89,7 +89,8 @@ Efficient Linear Algebra & Array Operations
 - :func:`arrayfuncs.min_pos`: (used in ``sklearn.linear_model.least_angle``)
   Find the minimum of the positive values within an array.
 
-- :func:`extmath.norm`: computes vector norm by directly calling the BLAS
+- :func:`extmath.norm`: computes Euclidean (L2) vector norm
+  by directly calling the BLAS
   ``nrm2`` function.  This is more stable than ``scipy.linalg.norm``.  See
   `Fabian's blog post
   <http://fseoane.net/blog/2011/computing-the-vector-norm/>`_ for a discussion.
@@ -103,9 +104,14 @@ Efficient Linear Algebra & Array Operations
   ``scipy.sparse`` inputs.  If the inputs are dense, it is equivalent to
   ``numpy.dot``.
 
-- :func:`extmath.logsum`: compute the sum of X assuming X is in the log domain.
-  This is equivalent to calling ``np.log(np.sum(np.exp(X)))``, but is
-  robust to overflow/underflow errors.
+- :func:`extmath.logsumexp`: compute the sum of X assuming X is in the log
+  domain. This is equivalent to calling ``np.log(np.sum(np.exp(X)))``, but is
+  robust to overflow/underflow errors.  Note that ``np.logaddexp.reduce``
+  offers similar functionality, but because it accumulates pairwise it is
+  slower for large arrays.
+  SciPy has a similar routine in ``scipy.misc.logsumexp`` (in scipy versions
+  < 0.10 it is found in ``scipy.maxentropy.logsumexp``),
+  but the scipy version does not accept an ``axis`` keyword.
 
 - :func:`extmath.weighted_mode`: an extension of ``scipy.stats.mode`` which
   allows each item to have a real-valued weight.
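
For reference, the overflow robustness and the ``axis`` keyword described in
the ``extmath.logsumexp`` bullet above can be seen in a short sketch (this
assumes the patch is applied, so that ``sklearn.utils.extmath.logsumexp`` is
importable):

    # exp() of values around 1000 overflows float64, but the log-domain sum
    # is still well defined and close to the largest element.
    import numpy as np
    from sklearn.utils.extmath import logsumexp

    a = np.array([1000., 1001., 1002.])
    print(np.log(np.sum(np.exp(a))))   # naive form overflows -> inf
    print(logsumexp(a))                # ~1002.408, computed stably

    # axis=1 reduces each row separately; each row below already sums to 1
    # in the linear domain, so its log-sum is 0.
    X = np.log(np.array([[0.25, 0.75], [0.5, 0.5]]))
    print(logsumexp(X, axis=1))        # -> [ 0.  0.]
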
diff --git a/sklearn/hmm.py b/sklearn/hmm.py
index 00cc9eea39..c1ad18f1c0 100644
--- a/sklearn/hmm.py
+++ b/sklearn/hmm.py
@@ -15,7 +15,7 @@ import string
 import numpy as np
 
 from .utils import check_random_state
-from .utils.extmath import logsum
+from .utils.extmath import logsumexp
 from .base import BaseEstimator
 from .mixture import (GMM, lmvnpdf, normalize, sample_gaussian,
                  _distribute_covar_matrix_to_match_cvtype, _validate_covars)
@@ -147,7 +147,7 @@ class _BaseHMM(BaseEstimator):
         # all frames, unless we do approximate inference using pruning.
         # So, we will normalize each frame explicitly in case we
         # pruned too aggressively.
-        posteriors = np.exp(gamma.T - logsum(gamma, axis=1)).T
+        posteriors = np.exp(gamma.T - logsumexp(gamma, axis=1)).T
         posteriors += np.finfo(np.float32).eps
         posteriors /= np.sum(posteriors, axis=1).reshape((-1, 1))
         return logprob, posteriors
@@ -364,7 +364,7 @@ class _BaseHMM(BaseEstimator):
                 bwdlattice = self._do_backward_pass(framelogprob, fwdlattice,
                                                    maxrank, beamlogprob)
                 gamma = fwdlattice + bwdlattice
-                posteriors = np.exp(gamma.T - logsum(gamma, axis=1)).T
+                posteriors = np.exp(gamma.T - logsumexp(gamma, axis=1)).T
                 curr_logprob += lpr
                 self._accumulate_sufficient_statistics(
                     stats, seq, framelogprob, posteriors, fwdlattice,
@@ -445,12 +445,12 @@ class _BaseHMM(BaseEstimator):
         fwdlattice[0] = self._log_startprob + framelogprob[0]
         for n in xrange(1, nobs):
             idx = self._prune_states(fwdlattice[n - 1], maxrank, beamlogprob)
-            fwdlattice[n] = (logsum(self._log_transmat[idx].T
+            fwdlattice[n] = (logsumexp(self._log_transmat[idx].T
                                     + fwdlattice[n - 1, idx], axis=1)
                              + framelogprob[n])
         fwdlattice[fwdlattice <= ZEROLOGPROB] = -np.Inf
 
-        return logsum(fwdlattice[-1]), fwdlattice
+        return logsumexp(fwdlattice[-1]), fwdlattice
 
     def _do_backward_pass(self, framelogprob, fwdlattice, maxrank=None,
                           beamlogprob=-np.Inf):
@@ -466,7 +466,7 @@ class _BaseHMM(BaseEstimator):
                                      -50)
                                      #beamlogprob)
                                      #-np.Inf)
-            bwdlattice[n - 1] = logsum(self._log_transmat[:, idx] +
+            bwdlattice[n - 1] = logsumexp(self._log_transmat[:, idx] +
                                        bwdlattice[n, idx] +
                                        framelogprob[n, idx],
                                        axis=1)
@@ -479,7 +479,7 @@ class _BaseHMM(BaseEstimator):
         after rank and beam pruning.
         """
         # Beam pruning
-        threshlogprob = logsum(lattice_frame) + beamlogprob
+        threshlogprob = logsumexp(lattice_frame) + beamlogprob
         # Rank pruning
         if maxrank:
             # How big should our rank pruning histogram be?
@@ -534,7 +534,7 @@ class _BaseHMM(BaseEstimator):
             for t in xrange(len(framelogprob)):
                 zeta = (fwdlattice[t - 1][:, np.newaxis] + self._log_transmat
                         + framelogprob[t] + bwdlattice[t])
-                stats['trans'] += np.exp(zeta - logsum(zeta))
+                stats['trans'] += np.exp(zeta - logsumexp(zeta))
 
     def _do_mstep(self, stats, params, **kwargs):
         # Based on Huang, Acero, Hon, "Spoken Language Processing",
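
The forward-pass hunk above reduces to the following sketch once the rank and
beam pruning is stripped out (a simplified illustration, not code from the
patch; ``log_startprob``, ``log_transmat`` and ``framelogprob`` are assumed
to already be in the log domain):

    import numpy as np
    from sklearn.utils.extmath import logsumexp

    def forward_pass(log_startprob, log_transmat, framelogprob):
        nobs, ncomponents = framelogprob.shape
        fwdlattice = np.zeros((nobs, ncomponents))
        fwdlattice[0] = log_startprob + framelogprob[0]
        for n in range(1, nobs):
            # alpha[n, j] = log sum_i exp(alpha[n-1, i] + log A[i, j]) + log b_j(o_n)
            fwdlattice[n] = (logsumexp(log_transmat.T + fwdlattice[n - 1],
                                       axis=1)
                             + framelogprob[n])
        # total log-likelihood of the observation sequence
        return logsumexp(fwdlattice[-1]), fwdlattice
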
diff --git a/sklearn/lda.py b/sklearn/lda.py
index 8b3a6f640d..24a54326c7 100644
--- a/sklearn/lda.py
+++ b/sklearn/lda.py
@@ -10,7 +10,7 @@ import numpy as np
 from scipy import linalg, ndimage
 
 from .base import BaseEstimator, ClassifierMixin, TransformerMixin
-from .utils.extmath import logsum
+from .utils.extmath import logsumexp
 
 
 class LDA(BaseEstimator, ClassifierMixin, TransformerMixin):
@@ -263,5 +263,5 @@ class LDA(BaseEstimator, ClassifierMixin, TransformerMixin):
         """
         values = self.decision_function(X)
         loglikelihood = (values - values.max(axis=1)[:, np.newaxis])
-        normalization = logsum(loglikelihood, axis=1)
+        normalization = logsumexp(loglikelihood, axis=1)
         return loglikelihood - normalization[:, np.newaxis]
diff --git a/sklearn/mixture/gmm.py b/sklearn/mixture/gmm.py
index 0d30af268a..ae51e0247f 100644
--- a/sklearn/mixture/gmm.py
+++ b/sklearn/mixture/gmm.py
@@ -10,7 +10,7 @@ import numpy as np
 
 from ..base import BaseEstimator
 from ..utils import check_random_state
-from ..utils.extmath import logsum
+from ..utils.extmath import logsumexp
 from .. import cluster
 
 
@@ -332,7 +332,7 @@ class GMM(BaseEstimator):
         obs = np.asarray(obs)
         lpr = (lmvnpdf(obs, self._means, self._covars, self._cvtype)
                + self._log_weights)
-        logprob = logsum(lpr, axis=1)
+        logprob = logsumexp(lpr, axis=1)
         posteriors = np.exp(lpr - logprob[:, np.newaxis])
         return logprob, posteriors
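
The GMM hunk above normalizes per-component joint log-probabilities into
posterior responsibilities; the same idiom appears in ``lda.py`` and
``naive_bayes.py`` below.  A toy sketch of the pattern (not code from the
patch), where ``lpr[i, k]`` plays the role of ``log P(x_i, component_k)``:

    import numpy as np
    from sklearn.utils.extmath import logsumexp

    lpr = np.log(np.array([[0.1, 0.3],
                           [0.2, 0.2]]))   # toy joint log-probabilities
    logprob = logsumexp(lpr, axis=1)       # log P(x_i), summed over components
    posteriors = np.exp(lpr - logprob[:, np.newaxis])
    print(posteriors)                      # rows: [0.25, 0.75] and [0.5, 0.5]
    print(posteriors.sum(axis=1))          # each row sums to 1
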
 
diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py
index bc5e365384..676d1fee80 100644
--- a/sklearn/naive_bayes.py
+++ b/sklearn/naive_bayes.py
@@ -23,7 +23,7 @@ from scipy.sparse import issparse
 from .base import BaseEstimator, ClassifierMixin
 from .preprocessing import binarize, LabelBinarizer
 from .utils import array2d, atleast2d_or_csr
-from .utils.extmath import safe_sparse_dot, logsum
+from .utils.extmath import safe_sparse_dot, logsumexp
 
 
 class BaseNB(BaseEstimator, ClassifierMixin):
@@ -74,7 +74,7 @@ class BaseNB(BaseEstimator, ClassifierMixin):
         """
         jll = self._joint_log_likelihood(X)
         # normalize by P(x) = P(f_1, ..., f_n)
-        log_prob_x = logsum(jll, axis=1)
+        log_prob_x = logsumexp(jll, axis=1)
         return jll - np.atleast_2d(log_prob_x).T
 
     def predict_proba(self, X):
diff --git a/sklearn/tests/test_hmm.py b/sklearn/tests/test_hmm.py
index 1b432085a8..2105e8bfd8 100644
--- a/sklearn/tests/test_hmm.py
+++ b/sklearn/tests/test_hmm.py
@@ -5,7 +5,7 @@ from unittest import TestCase
 
 from sklearn.datasets.samples_generator import make_spd_matrix
 from sklearn import hmm
-from sklearn.utils.extmath import logsum
+from sklearn.utils.extmath import logsumexp
 
 
 np.seterr(all='warn')
@@ -156,7 +156,7 @@ class TestBaseHMM(SeedRandomNumberGeneratorTestCase):
 
         assert_array_almost_equal(hmmposteriors.sum(axis=1), np.ones(nobs))
 
-        norm = logsum(framelogprob, axis=1)[:, np.newaxis]
+        norm = logsumexp(framelogprob, axis=1)[:, np.newaxis]
         gmmposteriors = np.exp(framelogprob - np.tile(norm, (1, n_components)))
         assert_array_almost_equal(hmmposteriors, gmmposteriors)
 
@@ -175,7 +175,7 @@ class TestBaseHMM(SeedRandomNumberGeneratorTestCase):
         # posteriors, not likelihoods).
         viterbi_ll, state_sequence = h.decode([])
 
-        norm = logsum(framelogprob, axis=1)[:, np.newaxis]
+        norm = logsumexp(framelogprob, axis=1)[:, np.newaxis]
         gmmposteriors = np.exp(framelogprob - np.tile(norm, (1, n_components)))
         gmmstate_sequence = gmmposteriors.argmax(axis=1)
         assert_array_equal(state_sequence, gmmstate_sequence)
diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py
index 1ab8c528c6..b6cf041424 100644
--- a/sklearn/utils/extmath.py
+++ b/sklearn/utils/extmath.py
@@ -205,7 +205,7 @@ def fast_svd(M, k, p=None, n_iterations=0, transpose='auto', random_state=0):
         return U[:, :k], s[:k], V[:k, :]
 
 
-def logsum(arr, axis=0):
+def logsumexp(arr, axis=0):
     """ Computes the sum of arr assuming arr is in the log domain.
 
     Returns log(sum(exp(arr))) while minimizing the possibility of
@@ -215,11 +215,11 @@ def logsum(arr, axis=0):
     ========
 
     >>> import numpy as np
-    >>> from sklearn.utils.extmath import logsum
+    >>> from sklearn.utils.extmath import logsumexp
     >>> a = np.arange(10)
     >>> np.log(np.sum(np.exp(a)))
     9.4586297444267107
-    >>> logsum(a)
+    >>> logsumexp(a)
     9.4586297444267107
     """
     arr = np.rollaxis(arr, axis)
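
The hunk above ends before the body of ``logsumexp``; the remaining lines of
the function are not shown in this patch.  A minimal sketch of the usual
max-shift implementation such a function relies on (an assumption, not the
patched source):

    import numpy as np

    def logsumexp(arr, axis=0):
        """Stable log(sum(exp(arr))) along the given axis."""
        arr = np.rollaxis(arr, axis)   # bring the reduction axis to the front
        vmax = arr.max(axis=0)         # shift by the maximum so exp() cannot overflow
        out = np.log(np.sum(np.exp(arr - vmax), axis=0))
        return out + vmax
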
-- 
GitLab