From fa598738570e03c50df5836a4e0a4285b91c957c Mon Sep 17 00:00:00 2001
From: yangarbiter <yangarbiter@users.noreply.github.com>
Date: Mon, 17 Oct 2016 06:49:23 -0700
Subject: [PATCH] [MRG+1] FIX unstable cumsum (#7376)

* FIX unstable cumsum in utils.random

* equal_nan=True for isclose
since sum is as unstable as cumsum for numpy < 1.9, fall back to np.cumsum

* added axis parameter to stable_cumsum

* FIX unstable cumsum in ensemble.weight_boosting and utils.stats

* FIX axis problem in stable_cumsum

* FIX unstable cumsum in mixture.gmm and mixture.dpgmm

* FIX unstable cumsum in cluster.k_means_, decomposition.pca, and manifold.locally_linear

* FIX unstable cumsum in datasets.samples_generator

* added docstring for parameter axis of stable_cumsum

* added comment for why fall back to np.cumsum when np version < 1.9

* remove unneeded stable_cumsum

* added stable_cumsum's axis testing

* FIX numpy docstring for make_sparse_spd_matrix

* change stable_cumsum from error to warning
---
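Note (reviewer commentary after the ---, ignored by git am): the failure mode
this patch targets is easiest to see standalone. Below is a minimal sketch,
assuming numpy >= 1.9, that mirrors the patched stable_cumsum; the name
stable_cumsum_sketch is hypothetical, and it warns with RuntimeWarning to stay
dependency-free where the patched helper in sklearn.utils.extmath uses
ConvergenceWarning:

    import warnings

    import numpy as np

    def stable_cumsum_sketch(arr, axis=None, rtol=1e-05, atol=1e-08):
        # Accumulate in float64, then verify the final cumulative value
        # against np.sum, which uses pairwise summation in numpy >= 1.9
        # and is therefore more accurate than a sequential cumsum.
        out = np.cumsum(arr, axis=axis, dtype=np.float64)
        expected = np.sum(arr, axis=axis, dtype=np.float64)
        if not np.all(np.isclose(out.take(-1, axis=axis), expected,
                                 rtol=rtol, atol=atol, equal_nan=True)):
            warnings.warn('cumsum was found to be unstable: '
                          'its last element does not correspond to sum',
                          RuntimeWarning)
        return out

    r = np.random.RandomState(0).rand(100000)
    # A float32 cumsum drifts visibly away from an accurate float64 sum.
    print(r.astype(np.float32).cumsum()[-1] - np.sum(r, dtype=np.float64))
    # With zero tolerances even the float64 cumsum fails the exact check,
    # so a warning is emitted where the pre-patch code raised RuntimeError;
    # this is what the updated test asserts via assert_warns.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        stable_cumsum_sketch(r, rtol=0, atol=0)
    print(len(caught))  # 1
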
 sklearn/cluster/k_means_.py           |  5 +++--
 sklearn/datasets/samples_generator.py |  2 +-
 sklearn/decomposition/pca.py          |  3 ++-
 sklearn/ensemble/weight_boosting.py   |  5 +++--
 sklearn/manifold/locally_linear.py    |  3 ++-
 sklearn/mixture/dpgmm.py              |  4 ++--
 sklearn/utils/extmath.py              | 23 ++++++++++++++++-------
 sklearn/utils/stats.py                |  3 ++-
 sklearn/utils/tests/test_extmath.py   | 13 +++++++++----
 9 files changed, 40 insertions(+), 21 deletions(-)
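
A second sketch (again commentary, not part of the applied diff) for the new
axis parameter: out.take(-1, axis=axis) selects the final cumulative slice
along axis, and that slice is compared against np.sum along the same axis.
Integer input sums exactly in float64, which is why the new axis tests in
test_extmath.py can assert equality with np.cumsum; the loop below uses the
same seed and shape as those tests:

    import numpy as np

    A = np.random.RandomState(36).randint(1000, size=(5, 5, 5))
    for axis in range(A.ndim):
        out = np.cumsum(A, axis=axis, dtype=np.float64)
        last = out.take(-1, axis=axis)   # shape (5, 5): the final totals
        expected = np.sum(A, axis=axis, dtype=np.float64)
        assert np.all(np.isclose(last, expected, equal_nan=True))
        assert np.array_equal(out, np.cumsum(A, axis=axis))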

diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py
index 7abd105926..2596a47307 100644
--- a/sklearn/cluster/k_means_.py
+++ b/sklearn/cluster/k_means_.py
@@ -18,7 +18,7 @@ import scipy.sparse as sp
 
 from ..base import BaseEstimator, ClusterMixin, TransformerMixin
 from ..metrics.pairwise import euclidean_distances
-from ..utils.extmath import row_norms, squared_norm
+from ..utils.extmath import row_norms, squared_norm, stable_cumsum
 from ..utils.sparsefuncs_fast import assign_rows_csr
 from ..utils.sparsefuncs import mean_variance_axis
 from ..utils.fixes import astype
@@ -106,7 +106,8 @@ def _k_init(X, n_clusters, x_squared_norms, random_state, n_local_trials=None):
         # Choose center candidates by sampling with probability proportional
         # to the squared distance to the closest existing center
         rand_vals = random_state.random_sample(n_local_trials) * current_pot
-        candidate_ids = np.searchsorted(closest_dist_sq.cumsum(), rand_vals)
+        candidate_ids = np.searchsorted(stable_cumsum(closest_dist_sq),
+                                        rand_vals)
 
         # Compute distances to center candidates
         distance_to_candidates = euclidean_distances(
diff --git a/sklearn/datasets/samples_generator.py b/sklearn/datasets/samples_generator.py
index 53ee8987ba..acd0733754 100644
--- a/sklearn/datasets/samples_generator.py
+++ b/sklearn/datasets/samples_generator.py
@@ -1194,7 +1194,7 @@ def make_sparse_spd_matrix(dim=1, alpha=0.95, norm_diag=False,
         The size of the random matrix to generate.
 
     alpha : float between 0 and 1, optional (default=0.95)
-        The probability that a coefficient is zero (see notes). Larger values 
+        The probability that a coefficient is zero (see notes). Larger values
         enforce more sparsity.
 
     random_state : int, RandomState instance or None, optional (default=None)
diff --git a/sklearn/decomposition/pca.py b/sklearn/decomposition/pca.py
index b9b171301f..d07a0bf2ed 100644
--- a/sklearn/decomposition/pca.py
+++ b/sklearn/decomposition/pca.py
@@ -24,6 +24,7 @@ from ..utils import deprecated
 from ..utils import check_random_state, as_float_array
 from ..utils import check_array
 from ..utils.extmath import fast_dot, fast_logdet, randomized_svd, svd_flip
+from ..utils.extmath import stable_cumsum
 from ..utils.validation import check_is_fitted
 from ..utils.arpack import svds
 
@@ -393,7 +394,7 @@ class PCA(_BasePCA):
         elif 0 < n_components < 1.0:
             # number of components for which the cumulated explained
             # variance percentage is superior to the desired threshold
-            ratio_cumsum = explained_variance_ratio_.cumsum()
+            ratio_cumsum = stable_cumsum(explained_variance_ratio_)
             n_components = np.searchsorted(ratio_cumsum, n_components) + 1
 
         # Compute noise covariance using Probabilistic PCA model
diff --git a/sklearn/ensemble/weight_boosting.py b/sklearn/ensemble/weight_boosting.py
index 56d7d6ff80..16afc4311e 100644
--- a/sklearn/ensemble/weight_boosting.py
+++ b/sklearn/ensemble/weight_boosting.py
@@ -38,6 +38,7 @@ from ..tree import DecisionTreeClassifier, DecisionTreeRegressor
 from ..tree.tree import BaseDecisionTree
 from ..tree._tree import DTYPE
 from ..utils import check_array, check_X_y, check_random_state
+from ..utils.extmath import stable_cumsum
 from ..metrics import accuracy_score, r2_score
 from sklearn.utils.validation import has_fit_parameter, check_is_fitted
 
@@ -1002,7 +1003,7 @@ class AdaBoostRegressor(BaseWeightBoosting, RegressorMixin):
 
         # Weighted sampling of the training set with replacement
         # For NumPy >= 1.7.0 use np.random.choice
-        cdf = sample_weight.cumsum()
+        cdf = stable_cumsum(sample_weight)
         cdf /= cdf[-1]
         uniform_samples = random_state.random_sample(X.shape[0])
         bootstrap_idx = cdf.searchsorted(uniform_samples, side='right')
@@ -1059,7 +1060,7 @@ class AdaBoostRegressor(BaseWeightBoosting, RegressorMixin):
         sorted_idx = np.argsort(predictions, axis=1)
 
         # Find index of median prediction for each sample
-        weight_cdf = self.estimator_weights_[sorted_idx].cumsum(axis=1)
+        weight_cdf = stable_cumsum(self.estimator_weights_[sorted_idx], axis=1)
         median_or_above = weight_cdf >= 0.5 * weight_cdf[:, -1][:, np.newaxis]
         median_idx = median_or_above.argmax(axis=1)
 
diff --git a/sklearn/manifold/locally_linear.py b/sklearn/manifold/locally_linear.py
index a1940333e5..f5a383d58a 100644
--- a/sklearn/manifold/locally_linear.py
+++ b/sklearn/manifold/locally_linear.py
@@ -10,6 +10,7 @@ from scipy.sparse import eye, csr_matrix
 from ..base import BaseEstimator, TransformerMixin
 from ..utils import check_random_state, check_array
 from ..utils.arpack import eigsh
+from ..utils.extmath import stable_cumsum
 from ..utils.validation import check_is_fitted
 from ..utils.validation import FLOAT_DTYPES
 from ..neighbors import NearestNeighbors
@@ -420,7 +421,7 @@ def locally_linear_embedding(
         # this is the size of the largest set of eigenvalues
         # such that Sum[v; v in set]/Sum[v; v not in set] < eta
         s_range = np.zeros(N, dtype=int)
-        evals_cumsum = np.cumsum(evals, 1)
+        evals_cumsum = stable_cumsum(evals, 1)
         eta_range = evals_cumsum[:, -1:] / evals_cumsum[:, :-1] - 1
         for i in range(N):
             s_range[i] = np.searchsorted(eta_range[i, ::-1], eta)
diff --git a/sklearn/mixture/dpgmm.py b/sklearn/mixture/dpgmm.py
index 0b7f11affe..1b119b8b72 100644
--- a/sklearn/mixture/dpgmm.py
+++ b/sklearn/mixture/dpgmm.py
@@ -24,7 +24,7 @@ from scipy.spatial.distance import cdist
 
 from ..externals.six.moves import xrange
 from ..utils import check_random_state, check_array, deprecated
-from ..utils.extmath import logsumexp, pinvh, squared_norm
+from ..utils.extmath import logsumexp, pinvh, squared_norm, stable_cumsum
 from ..utils.validation import check_is_fitted
 from .. import cluster
 from .gmm import _GMMBase
@@ -462,7 +462,7 @@ class _DPGMMBase(_GMMBase):
         dg1 = digamma(self.gamma_.T[1]) - dg12
         dg2 = digamma(self.gamma_.T[2]) - dg12
 
-        cz = np.cumsum(z[:, ::-1], axis=-1)[:, -2::-1]
+        cz = stable_cumsum(z[:, ::-1], axis=-1)[:, -2::-1]
         logprior = np.sum(cz * dg2[:-1]) + np.sum(z * dg1)
         del cz  # Save memory
         z_non_zeros = z[z > np.finfo(np.float32).eps]
diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py
index 1857a27adf..741601531d 100644
--- a/sklearn/utils/extmath.py
+++ b/sklearn/utils/extmath.py
@@ -25,7 +25,7 @@ from ._logistic_sigmoid import _log_logistic_sigmoid
 from ..externals.six.moves import xrange
 from .sparsefuncs_fast import csr_row_norms
 from .validation import check_array
-from ..exceptions import NonBLASDotWarning
+from ..exceptions import ConvergenceWarning, NonBLASDotWarning
 
 
 def norm(x):
@@ -844,21 +844,30 @@ def _deterministic_vector_sign_flip(u):
     return u
 
 
-def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
+def stable_cumsum(arr, axis=None, rtol=1e-05, atol=1e-08):
     """Use high precision for cumsum and check that final value matches sum
 
     Parameters
     ----------
     arr : array-like
         To be cumulatively summed as flat
+    axis : int, optional
+        Axis along which the cumulative sum is computed.
+        The default (None) is to compute the cumsum over the flattened array.
     rtol : float
         Relative tolerance, see ``np.allclose``
     atol : float
         Absolute tolerance, see ``np.allclose``
     """
-    out = np.cumsum(arr, dtype=np.float64)
-    expected = np.sum(arr, dtype=np.float64)
-    if not np.allclose(out[-1], expected, rtol=rtol, atol=atol):
-        raise RuntimeError('cumsum was found to be unstable: '
-                           'its last element does not correspond to sum')
+    # sum is as unstable as cumsum for numpy < 1.9
+    if np_version < (1, 9):
+        return np.cumsum(arr, axis=axis, dtype=np.float64)
+
+    out = np.cumsum(arr, axis=axis, dtype=np.float64)
+    expected = np.sum(arr, axis=axis, dtype=np.float64)
+    if not np.all(np.isclose(out.take(-1, axis=axis), expected, rtol=rtol,
+                             atol=atol, equal_nan=True)):
+        warnings.warn('cumsum was found to be unstable: '
+                      'its last element does not correspond to sum',
+                      ConvergenceWarning)
     return out
diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py
index 463146d038..265d193e6b 100644
--- a/sklearn/utils/stats.py
+++ b/sklearn/utils/stats.py
@@ -1,6 +1,7 @@
 import numpy as np
 from scipy.stats import rankdata as _sp_rankdata
 from .fixes import bincount
+from ..utils.extmath import stable_cumsum
 
 
 # To remove when we support scipy 0.13
@@ -53,7 +54,7 @@ def _weighted_percentile(array, sample_weight, percentile=50):
     sorted_idx = np.argsort(array)
 
     # Find index of median prediction for each sample
-    weight_cdf = sample_weight[sorted_idx].cumsum()
+    weight_cdf = stable_cumsum(sample_weight[sorted_idx])
     percentile_idx = np.searchsorted(
         weight_cdf, (percentile / 100.) * weight_cdf[-1])
     return array[sorted_idx[percentile_idx]]
diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py
index 55f96cdf15..49a0b4abee 100644
--- a/sklearn/utils/tests/test_extmath.py
+++ b/sklearn/utils/tests/test_extmath.py
@@ -18,6 +18,7 @@ from sklearn.utils.testing import assert_false
 from sklearn.utils.testing import assert_greater
 from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_raise_message
+from sklearn.utils.testing import assert_warns
 from sklearn.utils.testing import skip_if_32bit
 from sklearn.utils.testing import SkipTest
 from sklearn.utils.fixes import np_version
@@ -36,6 +37,7 @@ from sklearn.utils.extmath import _incremental_mean_and_var
 from sklearn.utils.extmath import _deterministic_vector_sign_flip
 from sklearn.utils.extmath import softmax
 from sklearn.utils.extmath import stable_cumsum
+from sklearn.exceptions import ConvergenceWarning
 from sklearn.datasets.samples_generator import make_low_rank_matrix
 
 
@@ -654,7 +656,10 @@ def test_stable_cumsum():
         raise SkipTest("Sum is as unstable as cumsum for numpy < 1.9")
     assert_array_equal(stable_cumsum([1, 2, 3]), np.cumsum([1, 2, 3]))
     r = np.random.RandomState(0).rand(100000)
-    assert_raise_message(RuntimeError,
-                         'cumsum was found to be unstable: its last element '
-                         'does not correspond to sum',
-                         stable_cumsum, r, rtol=0, atol=0)
+    assert_warns(ConvergenceWarning, stable_cumsum, r, rtol=0, atol=0)
+
+    # test axis parameter
+    A = np.random.RandomState(36).randint(1000, size=(5, 5, 5))
+    assert_array_equal(stable_cumsum(A, axis=0), np.cumsum(A, axis=0))
+    assert_array_equal(stable_cumsum(A, axis=1), np.cumsum(A, axis=1))
+    assert_array_equal(stable_cumsum(A, axis=2), np.cumsum(A, axis=2))
-- 
GitLab