From fa598738570e03c50df5836a4e0a4285b91c957c Mon Sep 17 00:00:00 2001
From: yangarbiter <yangarbiter@users.noreply.github.com>
Date: Mon, 17 Oct 2016 06:49:23 -0700
Subject: [PATCH] [MRG+1] FIX unstable cumsum (#7376)

* FIX unstable cumsum in utils.random
* equal_nan = true for isclose since numpy < 1.9 sum is as unstable as cumsum, fallback to np.cumsum
* added axis parameter to stable_cumsum
* FIX unstable cumsum in ensemble.weight_boosting and utils.stats
* FIX axis problem in stable_cumsum
* FIX unstable cumsum in mixture.gmm and mixture.dpgmm
* FIX unstable cumsum in cluster.k_means_, decomposition.pca, and manifold.locally_linear
* FIX unstable cumsum in datasets.samples_generator
* added docstring for parameter axis of stable_cumsum
* added comment for why fall back to np.cumsum when np version < 1.9
* remove unneeded stable_cumsum
* added stable_cumsum's axis testing
* FIX numpy docstring for make_sparse_spd_matrix
* change stable_cumsum from error to warning
---
 sklearn/cluster/k_means_.py           |  5 +++--
 sklearn/datasets/samples_generator.py |  2 +-
 sklearn/decomposition/pca.py          |  3 ++-
 sklearn/ensemble/weight_boosting.py   |  5 +++--
 sklearn/manifold/locally_linear.py    |  3 ++-
 sklearn/mixture/dpgmm.py              |  4 ++--
 sklearn/utils/extmath.py              | 23 ++++++++++++++++-------
 sklearn/utils/stats.py                |  3 ++-
 sklearn/utils/tests/test_extmath.py   | 13 +++++++++----
 9 files changed, 40 insertions(+), 21 deletions(-)

diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py
index 7abd105926..2596a47307 100644
--- a/sklearn/cluster/k_means_.py
+++ b/sklearn/cluster/k_means_.py
@@ -18,7 +18,7 @@ import scipy.sparse as sp
 
 from ..base import BaseEstimator, ClusterMixin, TransformerMixin
 from ..metrics.pairwise import euclidean_distances
-from ..utils.extmath import row_norms, squared_norm
+from ..utils.extmath import row_norms, squared_norm, stable_cumsum
 from ..utils.sparsefuncs_fast import assign_rows_csr
 from ..utils.sparsefuncs import mean_variance_axis
 from ..utils.fixes import astype
@@ -106,7 +106,8 @@ def _k_init(X, n_clusters, x_squared_norms, random_state, n_local_trials=None):
         # Choose center candidates by sampling with probability proportional
         # to the squared distance to the closest existing center
         rand_vals = random_state.random_sample(n_local_trials) * current_pot
-        candidate_ids = np.searchsorted(closest_dist_sq.cumsum(), rand_vals)
+        candidate_ids = np.searchsorted(stable_cumsum(closest_dist_sq),
+                                        rand_vals)
 
         # Compute distances to center candidates
         distance_to_candidates = euclidean_distances(
diff --git a/sklearn/datasets/samples_generator.py b/sklearn/datasets/samples_generator.py
index 53ee8987ba..acd0733754 100644
--- a/sklearn/datasets/samples_generator.py
+++ b/sklearn/datasets/samples_generator.py
@@ -1194,7 +1194,7 @@ def make_sparse_spd_matrix(dim=1, alpha=0.95, norm_diag=False,
         The size of the random matrix to generate.
 
     alpha : float between 0 and 1, optional (default=0.95)
-        The probability that a coefficient is zero (see notes).  Larger values
+        The probability that a coefficient is zero (see notes). Larger values
         enforce more sparsity.
 
     random_state : int, RandomState instance or None, optional (default=None)
diff --git a/sklearn/decomposition/pca.py b/sklearn/decomposition/pca.py
index b9b171301f..d07a0bf2ed 100644
--- a/sklearn/decomposition/pca.py
+++ b/sklearn/decomposition/pca.py
@@ -24,6 +24,7 @@ from ..utils import deprecated
 from ..utils import check_random_state, as_float_array
 from ..utils import check_array
 from ..utils.extmath import fast_dot, fast_logdet, randomized_svd, svd_flip
+from ..utils.extmath import stable_cumsum
 from ..utils.validation import check_is_fitted
 from ..utils.arpack import svds
 
@@ -393,7 +394,7 @@ class PCA(_BasePCA):
         elif 0 < n_components < 1.0:
             # number of components for which the cumulated explained
             # variance percentage is superior to the desired threshold
-            ratio_cumsum = explained_variance_ratio_.cumsum()
+            ratio_cumsum = stable_cumsum(explained_variance_ratio_)
             n_components = np.searchsorted(ratio_cumsum, n_components) + 1
 
         # Compute noise covariance using Probabilistic PCA model
diff --git a/sklearn/ensemble/weight_boosting.py b/sklearn/ensemble/weight_boosting.py
index 56d7d6ff80..16afc4311e 100644
--- a/sklearn/ensemble/weight_boosting.py
+++ b/sklearn/ensemble/weight_boosting.py
@@ -38,6 +38,7 @@ from ..tree import DecisionTreeClassifier, DecisionTreeRegressor
 from ..tree.tree import BaseDecisionTree
 from ..tree._tree import DTYPE
 from ..utils import check_array, check_X_y, check_random_state
+from ..utils.extmath import stable_cumsum
 from ..metrics import accuracy_score, r2_score
 from sklearn.utils.validation import has_fit_parameter, check_is_fitted
 
@@ -1002,7 +1003,7 @@ class AdaBoostRegressor(BaseWeightBoosting, RegressorMixin):
 
         # Weighted sampling of the training set with replacement
         # For NumPy >= 1.7.0 use np.random.choice
-        cdf = sample_weight.cumsum()
+        cdf = stable_cumsum(sample_weight)
         cdf /= cdf[-1]
         uniform_samples = random_state.random_sample(X.shape[0])
         bootstrap_idx = cdf.searchsorted(uniform_samples, side='right')
@@ -1059,7 +1060,7 @@ class AdaBoostRegressor(BaseWeightBoosting, RegressorMixin):
         sorted_idx = np.argsort(predictions, axis=1)
 
         # Find index of median prediction for each sample
-        weight_cdf = self.estimator_weights_[sorted_idx].cumsum(axis=1)
+        weight_cdf = stable_cumsum(self.estimator_weights_[sorted_idx], axis=1)
         median_or_above = weight_cdf >= 0.5 * weight_cdf[:, -1][:, np.newaxis]
         median_idx = median_or_above.argmax(axis=1)
 
diff --git a/sklearn/manifold/locally_linear.py b/sklearn/manifold/locally_linear.py
index a1940333e5..f5a383d58a 100644
--- a/sklearn/manifold/locally_linear.py
+++ b/sklearn/manifold/locally_linear.py
@@ -10,6 +10,7 @@ from scipy.sparse import eye, csr_matrix
 from ..base import BaseEstimator, TransformerMixin
 from ..utils import check_random_state, check_array
 from ..utils.arpack import eigsh
+from ..utils.extmath import stable_cumsum
 from ..utils.validation import check_is_fitted
 from ..utils.validation import FLOAT_DTYPES
 from ..neighbors import NearestNeighbors
@@ -420,7 +421,7 @@ def locally_linear_embedding(
         # this is the size of the largest set of eigenvalues
         # such that Sum[v; v in set]/Sum[v; v not in set] < eta
         s_range = np.zeros(N, dtype=int)
-        evals_cumsum = np.cumsum(evals, 1)
+        evals_cumsum = stable_cumsum(evals, 1)
         eta_range = evals_cumsum[:, -1:] / evals_cumsum[:, :-1] - 1
         for i in range(N):
             s_range[i] = np.searchsorted(eta_range[i, ::-1], eta)
diff --git a/sklearn/mixture/dpgmm.py b/sklearn/mixture/dpgmm.py
index 0b7f11affe..1b119b8b72 100644
--- a/sklearn/mixture/dpgmm.py
+++ b/sklearn/mixture/dpgmm.py
@@ -24,7 +24,7 @@ from scipy.spatial.distance import cdist
 
 from ..externals.six.moves import xrange
 from ..utils import check_random_state, check_array, deprecated
-from ..utils.extmath import logsumexp, pinvh, squared_norm
+from ..utils.extmath import logsumexp, pinvh, squared_norm, stable_cumsum
 from ..utils.validation import check_is_fitted
 from .. import cluster
 from .gmm import _GMMBase
@@ -462,7 +462,7 @@ class _DPGMMBase(_GMMBase):
         dg1 = digamma(self.gamma_.T[1]) - dg12
         dg2 = digamma(self.gamma_.T[2]) - dg12
 
-        cz = np.cumsum(z[:, ::-1], axis=-1)[:, -2::-1]
+        cz = stable_cumsum(z[:, ::-1], axis=-1)[:, -2::-1]
         logprior = np.sum(cz * dg2[:-1]) + np.sum(z * dg1)
         del cz  # Save memory
         z_non_zeros = z[z > np.finfo(np.float32).eps]
diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py
index 1857a27adf..741601531d 100644
--- a/sklearn/utils/extmath.py
+++ b/sklearn/utils/extmath.py
@@ -25,7 +25,7 @@ from ._logistic_sigmoid import _log_logistic_sigmoid
 from ..externals.six.moves import xrange
 from .sparsefuncs_fast import csr_row_norms
 from .validation import check_array
-from ..exceptions import NonBLASDotWarning
+from ..exceptions import ConvergenceWarning, NonBLASDotWarning
 
 
 def norm(x):
@@ -844,21 +844,30 @@ def _deterministic_vector_sign_flip(u):
     return u
 
 
-def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
+def stable_cumsum(arr, axis=None, rtol=1e-05, atol=1e-08):
     """Use high precision for cumsum and check that final value matches sum
 
     Parameters
     ----------
     arr : array-like
         To be cumulatively summed as flat
+    axis : int, optional
+        Axis along which the cumulative sum is computed.
+        The default (None) is to compute the cumsum over the flattened array.
     rtol : float
         Relative tolerance, see ``np.allclose``
     atol : float
         Absolute tolerance, see ``np.allclose``
     """
-    out = np.cumsum(arr, dtype=np.float64)
-    expected = np.sum(arr, dtype=np.float64)
-    if not np.allclose(out[-1], expected, rtol=rtol, atol=atol):
-        raise RuntimeError('cumsum was found to be unstable: '
-                           'its last element does not correspond to sum')
+    # sum is as unstable as cumsum for numpy < 1.9
+    if np_version < (1, 9):
+        return np.cumsum(arr, axis=axis, dtype=np.float64)
+
+    out = np.cumsum(arr, axis=axis, dtype=np.float64)
+    expected = np.sum(arr, axis=axis, dtype=np.float64)
+    if not np.all(np.isclose(out.take(-1, axis=axis), expected, rtol=rtol,
+                             atol=atol, equal_nan=True)):
+        warnings.warn('cumsum was found to be unstable: '
+                      'its last element does not correspond to sum',
+                      ConvergenceWarning)
     return out
diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py
index 463146d038..265d193e6b 100644
--- a/sklearn/utils/stats.py
+++ b/sklearn/utils/stats.py
@@ -1,6 +1,7 @@
 import numpy as np
 from scipy.stats import rankdata as _sp_rankdata
 from .fixes import bincount
+from ..utils.extmath import stable_cumsum
 
 
 # To remove when we support scipy 0.13
@@ -53,7 +54,7 @@ def _weighted_percentile(array, sample_weight, percentile=50):
     sorted_idx = np.argsort(array)
 
     # Find index of median prediction for each sample
-    weight_cdf = sample_weight[sorted_idx].cumsum()
+    weight_cdf = stable_cumsum(sample_weight[sorted_idx])
     percentile_idx = np.searchsorted(
         weight_cdf, (percentile / 100.) * weight_cdf[-1])
     return array[sorted_idx[percentile_idx]]
diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py
index 55f96cdf15..49a0b4abee 100644
--- a/sklearn/utils/tests/test_extmath.py
+++ b/sklearn/utils/tests/test_extmath.py
@@ -18,6 +18,7 @@ from sklearn.utils.testing import assert_false
 from sklearn.utils.testing import assert_greater
 from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_raise_message
+from sklearn.utils.testing import assert_warns
 from sklearn.utils.testing import skip_if_32bit
 from sklearn.utils.testing import SkipTest
 from sklearn.utils.fixes import np_version
@@ -36,6 +37,7 @@ from sklearn.utils.extmath import _incremental_mean_and_var
 from sklearn.utils.extmath import _deterministic_vector_sign_flip
 from sklearn.utils.extmath import softmax
 from sklearn.utils.extmath import stable_cumsum
+from sklearn.exceptions import ConvergenceWarning
 from sklearn.datasets.samples_generator import make_low_rank_matrix
 
 
@@ -654,7 +656,10 @@ def test_stable_cumsum():
         raise SkipTest("Sum is as unstable as cumsum for numpy < 1.9")
     assert_array_equal(stable_cumsum([1, 2, 3]), np.cumsum([1, 2, 3]))
     r = np.random.RandomState(0).rand(100000)
-    assert_raise_message(RuntimeError,
-                         'cumsum was found to be unstable: its last element '
-                         'does not correspond to sum',
-                         stable_cumsum, r, rtol=0, atol=0)
+    assert_warns(ConvergenceWarning, stable_cumsum, r, rtol=0, atol=0)
+
+    # test axis parameter
+    A = np.random.RandomState(36).randint(1000, size=(5, 5, 5))
+    assert_array_equal(stable_cumsum(A, axis=0), np.cumsum(A, axis=0))
+    assert_array_equal(stable_cumsum(A, axis=1), np.cumsum(A, axis=1))
+    assert_array_equal(stable_cumsum(A, axis=2), np.cumsum(A, axis=2))
--
GitLab
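
Appendix (illustration only, not part of the patch): the change boils down to computing the cumulative sum in float64 and checking its last element against the direct sum. The sketch below re-implements that idea outside scikit-learn so it can be run standalone. The name stable_cumsum_sketch is made up for this example, a plain RuntimeWarning stands in for the ConvergenceWarning the patch uses, and the NumPy < 1.9 fallback is omitted; the real helper is sklearn.utils.extmath.stable_cumsum.

import warnings

import numpy as np


def stable_cumsum_sketch(arr, axis=None, rtol=1e-05, atol=1e-08):
    """Cumulative sum in float64, checked against the direct sum."""
    out = np.cumsum(arr, axis=axis, dtype=np.float64)
    expected = np.sum(arr, axis=axis, dtype=np.float64)
    # The last element along the summed axis should equal the total sum;
    # if it does not (within tolerance), the accumulation was unstable.
    if not np.all(np.isclose(out.take(-1, axis=axis), expected,
                             rtol=rtol, atol=atol, equal_nan=True)):
        warnings.warn('cumsum was found to be unstable: '
                      'its last element does not correspond to sum',
                      RuntimeWarning)
    return out


if __name__ == '__main__':
    r = np.random.RandomState(0).rand(100000).astype(np.float32)

    # Plain np.cumsum keeps the float32 dtype, so the running total can
    # drift away from the float64 reference sum on long arrays ...
    print(np.cumsum(r)[-1], np.sum(r, dtype=np.float64))

    # ... while accumulating in float64 stays consistent (no warning here).
    print(stable_cumsum_sketch(r)[-1])

    # Zero tolerance triggers the warning, as in the new test_stable_cumsum.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        stable_cumsum_sketch(r, rtol=0, atol=0)
    print(len(caught), 'warning(s) raised')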