From 25b7bffed1d28b98471315aaf404195435c89dbb Mon Sep 17 00:00:00 2001 From: Olivier Grisel <olivier.grisel@ensta.org> Date: Fri, 10 Dec 2010 19:56:33 +0100 Subject: [PATCH] extract the randomized SVD implementation as a toplevel class able to handle sparse data as well --- scikits/learn/pca.py | 221 ++++++++++++++++++++++++-------- scikits/learn/tests/test_pca.py | 74 +++++++---- 2 files changed, 211 insertions(+), 84 deletions(-) diff --git a/scikits/learn/pca.py b/scikits/learn/pca.py index 1a19ebfaca..3e475079a8 100644 --- a/scikits/learn/pca.py +++ b/scikits/learn/pca.py @@ -2,14 +2,16 @@ """ # Author: Alexandre Gramfort <alexandre.gramfort@inria.fr> +# Olivier Grisel <olivier.grisel@ensta.org> # License: BSD Style. -import warnings import numpy as np from scipy import linalg from .base import BaseEstimator -from .utils.extmath import fast_logdet, fast_svd +from .utils.extmath import fast_logdet +from .utils.extmath import fast_svd +from .utils.extmath import safe_sparse_dot def _assess_dimension_(spectrum, rank, n_samples, dim): @@ -86,7 +88,6 @@ def _infer_dimension_(spectrum, n, p): return ll.argmax() -################################################################################ class PCA(BaseEstimator): """Principal component analysis (PCA) @@ -94,51 +95,47 @@ class PCA(BaseEstimator): data and keeping only the most significant singular vectors to project the data to a lower dimensional space. + This implementation uses the scipy.linalg implementation of the singular + value decomposition. It only works for dense arrays and is not scalable to + large dimensional data. + + The time complexity of this implementation is O(n ** 3) assuming + n ~ n_samples ~ n_features. + Parameters ---------- X: array-like, shape (n_samples, n_features) Training vector, where n_samples in the number of samples and n_features is the number of features. - Attributes - ---------- - n_comp: int, none or string - Number of components - if n_comp is not set all components are kept - if n_comp=='mle', Minka's MLE is used to guess the dimension + n_components: int, none or string + Number of components to keep. + if n_components is not set all components are kept: + n_components == min(n_samples, n_features) + if n_components == 'mle', Minka's MLE is used to guess the dimension copy: bool If False, data passed to fit are overwritten - components_: array, [n_features, n_comp] - Components with maximum variance. - - do_fast_svd: bool, optional - If True, the k-truncated SVD is computed using random projections - which speeds up the computation on large arrays. If all the - components are to be computed (as in n_comp=None or - n_comp='mle'), this option has no effects. Note that the solution will - be correct only if the requested n_comp is as large as the approximate - effective rank of the data. - - explained_variance_: array, [n_comp] - Percentage of variance explained by each of the selected components. - k is not set then all components are stored and the sum of - explained variances is equal to 1.0 - whiten: bool, optional When True (False by default) the components_ vectors are divided - by the singular values to ensure uncorrelated outputs with unit - component-wise variances. + by n_samples times singular values to ensure uncorrelated outputs + with unit component-wise variances. 
Whitening will remove some information from the transformed signal (the relative variance scales of the components) but can sometimes improve the predictive accuracy of the downstream estimators by making their data respect some hard-wired assumptions. - iterated_power: int, optional - Number of iteration for the power method if do_fast_svd is True. 3 by - default. + Attributes + ---------- + components_: array, [n_features, n_components] + Components with maximum variance. + + explained_variance_ratio_: array, [n_components] + Percentage of variance explained by each of the selected components. + If n_components is not set then all components are stored and the sum of + explained variances is equal to 1.0 Notes ----- @@ -150,23 +147,21 @@ class PCA(BaseEstimator): >>> import numpy as np >>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]) >>> from scikits.learn.pca import PCA - >>> pca = PCA(n_comp=2) + >>> pca = PCA(n_components=2) >>> pca.fit(X) - PCA(do_fast_svd=False, n_comp=2, copy=True, whiten=False, iterated_power=3) + PCA(copy=True, n_components=2, whiten=False) >>> print pca.explained_variance_ratio_ [ 0.99244289 0.00755711] See also -------- ProbabilisticPCA + RandomizedPCA """ - def __init__(self, n_comp=None, copy=True, do_fast_svd=False, - iterated_power=3, whiten=False): - self.n_comp = n_comp + def __init__(self, n_components=None, copy=True, whiten=False): + self.n_components = n_components self.copy = copy - self.do_fast_svd = do_fast_svd - self.iterated_power = iterated_power self.whiten = whiten def fit(self, X, **params): @@ -179,15 +174,7 @@ class PCA(BaseEstimator): # Center data self.mean_ = np.mean(X, axis=0) X -= self.mean_ - if self.do_fast_svd: - if self.n_comp == "mle" or self.n_comp is None: - warnings.warn('All components are to be computed' - 'Not using fast truncated SVD') - U, S, V = linalg.svd(X, full_matrices=False) - else: - U, S, V = fast_svd(X, self.n_comp, q=self.iterated_power) - else: - U, S, V = linalg.svd(X, full_matrices=False) + U, S, V = linalg.svd(X, full_matrices=False) self.explained_variance_ = (S ** 2) / n_samples self.explained_variance_ratio_ = self.explained_variance_ / \ self.explained_variance_.sum() @@ -198,15 +185,16 @@ class PCA(BaseEstimator): else: self.components_ = V.T - if self.n_comp == 'mle': - self.n_comp = _infer_dimension_(self.explained_variance_, + if self.n_components == 'mle': + self.n_components = _infer_dimension_(self.explained_variance_, n_samples, X.shape[1]) - if self.n_comp is not None: - self.components_ = self.components_[:, :self.n_comp] - self.explained_variance_ = self.explained_variance_[:self.n_comp] - self.explained_variance_ratio_ = self.explained_variance_ratio_[ - :self.n_comp] + if self.n_components is not None: + self.components_ = self.components_[:, :self.n_components] + self.explained_variance_ = \ + self.explained_variance_[:self.n_components] + self.explained_variance_ratio_ = \ + self.explained_variance_ratio_[:self.n_components] return self @@ -217,7 +205,6 @@ class PCA(BaseEstimator): return Xr -################################################################################ class ProbabilisticPCA(PCA): """Additional layer on top of PCA that adds a probabilistic evaluation @@ -238,14 +225,14 @@ class ProbabilisticPCA(PCA): Xr = X - self.mean_ Xr -= np.dot(np.dot(Xr, self.components_), self.components_.T) n_samples = X.shape[0] - if self.dim <= self.n_comp: + if self.dim <= self.n_components: delta = np.zeros(self.dim) elif homoscedastic: delta = (Xr ** 2).sum() / (n_samples*(self.dim)) * 
np.ones(self.dim) else: - delta = (Xr ** 2).mean(0) / (self.dim - self.n_comp) + delta = (Xr ** 2).mean(0) / (self.dim - self.n_components) self.covariance_ = np.diag(delta) - for k in range(self.n_comp): + for k in range(self.n_components): add_cov = np.dot( self.components_[:, k:k+1], self.components_[:, k:k+1].T) self.covariance_ += self.explained_variance_[k] * add_cov @@ -272,3 +259,123 @@ class ProbabilisticPCA(PCA): log_like += fast_logdet(self.precision_) - \ self.dim / 2 * np.log(2 * np.pi) return log_like + + +class RandomizedPCA(BaseEstimator): + """Principal component analysis (PCA) using randomized SVD + + Linear dimensionality reduction using approximated Singular Value + Decomposition of the data and keeping only the most significant + singular vectors to project the data to a lower dimensional space. + + This implementation uses a randomized SVD and can + handle both scipy.sparse and numpy dense arrays as input. + + Parameters + ---------- + X: array-like or scipy.sparse matrix, shape (n_samples, n_features) + Training vector, where n_samples is the number of samples and + n_features is the number of features. + + n_components: int + Maximum number of components to keep: default is 50. + + copy: bool + If False, data passed to fit are overwritten + + iterated_power: int, optional + Number of iterations for the power method. 3 by default. + + whiten: bool, optional + When True (False by default) the components_ vectors are divided + by the singular values to ensure uncorrelated outputs with unit + component-wise variances. + + Whitening will remove some information from the transformed signal + (the relative variance scales of the components) but can sometimes + improve the predictive accuracy of the downstream estimators by + making their data respect some hard-wired assumptions. + + Attributes + ---------- + components_: array, [n_features, n_components] + Components with maximum variance. + + explained_variance_ratio_: array, [n_components] + Percentage of variance explained by each of the selected components. 
+ If n_components is not set then all components are stored and the sum of + explained variances is equal to 1.0 + + References + ---------- + Finding structure with randomness: Stochastic algorithms for constructing + approximate matrix decompositions + Halko, et al., 2009 (arXiv:909) + + A randomized algorithm for the decomposition of matrices + Per-Gunnar Martinsson, Vladimir Rokhlin and Mark Tygert + + Examples + -------- + >>> import numpy as np + >>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]) + >>> from scikits.learn.pca import RandomizedPCA + >>> pca = RandomizedPCA(n_components=2) + >>> pca.fit(X) + RandomizedPCA(copy=True, n_components=2, iterated_power=3, whiten=False) + >>> print pca.explained_variance_ratio_ + [ 0.99244289 0.00755711] + + See also + -------- + PCA + ProbabilisticPCA + """ + + def __init__(self, n_components, copy=True, iterated_power=3, + whiten=False): + self.n_components = n_components + self.copy = copy + self.iterated_power = iterated_power + self.whiten = whiten + self.mean_ = None + + def fit(self, X, **params): + """Fit the model to the data X""" + self._set_params(**params) + n_samples = X.shape[0] + + if self.copy: + X = X.copy() + + if not hasattr(X, 'todense'): + # not a sparse matrix, ensure this is a 2D array + X = np.atleast_2d(X) + + # Center data + self.mean_ = np.mean(X, axis=0) + X -= self.mean_ + + U, S, V = fast_svd(X, self.n_components, q=self.iterated_power) + + self.explained_variance_ = (S ** 2) / n_samples + self.explained_variance_ratio_ = self.explained_variance_ / \ + self.explained_variance_.sum() + + if self.whiten: + n = X.shape[0] + self.components_ = np.dot(V.T, np.diag(1.0 / S)) * np.sqrt(n) + else: + self.components_ = V.T + + return self + + def transform(self, X): + """Apply the dimension reduction learned on the training data.""" + if self.mean_ is not None: + X = X - self.mean_ + + X = safe_sparse_dot(X, self.components_) + return X + + diff --git a/scikits/learn/tests/test_pca.py b/scikits/learn/tests/test_pca.py index 1282213625..2e63012e47 100644 --- a/scikits/learn/tests/test_pca.py +++ b/scikits/learn/tests/test_pca.py @@ -3,17 +3,22 @@ from numpy.random import randn from nose.tools import assert_true from nose.tools import assert_equal +from scipy.sparse import csr_matrix from numpy.testing import assert_almost_equal from .. 
import datasets -from ..pca import PCA, ProbabilisticPCA, _assess_dimension_, _infer_dimension_ +from ..pca import PCA +from ..pca import ProbabilisticPCA +from ..pca import RandomizedPCA +from ..pca import _assess_dimension_ +from ..pca import _infer_dimension_ iris = datasets.load_iris() def test_pca(): """PCA on dense arrays""" - pca = PCA(n_comp=2) + pca = PCA(n_components=2) X = iris.data X_r = pca.fit(X).transform(X) np.testing.assert_equal(X_r.shape[1], 2) @@ -28,7 +33,7 @@ def test_whitening(): np.random.seed(0) n_samples = 100 n_features = 80 - n_components = 30 + n_components = 30 rank = 50 # some low rank data with correlated features @@ -45,18 +50,18 @@ def test_whitening(): assert_almost_equal(X.std(axis=0).std(), 43.9, 1) # whiten the data while projecting to the lower dim subspace - pca = PCA(n_comp=n_components, whiten=True).fit(X) + pca = PCA(n_components=n_components, whiten=True).fit(X) X_whitened = pca.transform(X) - assert_equal(X_whitened.shape, (n_samples, n_components)) + assert_equal(X_whitened.shape, (n_samples, n_components)) # all output components have unit variances - assert_almost_equal(X_whitened.std(axis=0), np.ones(n_components)) + assert_almost_equal(X_whitened.std(axis=0), np.ones(n_components)) # it is possible to project on the low dim space without scaling by the # singular values - pca = PCA(n_comp=n_components, whiten=False).fit(X) + pca = PCA(n_components=n_components, whiten=False).fit(X) X_unwhitened = pca.transform(X) - assert_equal(X_unwhitened.shape, (n_samples, n_components)) + assert_equal(X_unwhitened.shape, (n_samples, n_components)) # in that case the output components still have varying variances assert_almost_equal(X_unwhitened.std(axis=0).std(), 74.1, 1) @@ -67,24 +72,39 @@ def test_pca_check_projection(): n, p = 100, 3 X = randn(n, p) * .1 X[:10] += np.array([3, 4, 5]) - pca = PCA(n_comp=2) - pca.fit(X) - Xt = 0.1* randn(1, p) + np.array([3, 4, 5]) - Yt = pca.transform(Xt) + Xt = 0.1 * randn(1, p) + np.array([3, 4, 5]) + + Yt = PCA(n_components=2).fit(X).transform(Xt) Yt /= np.sqrt((Yt**2).sum()) + np.testing.assert_almost_equal(np.abs(Yt[0][0]), 1., 1) -def test_fast_pca_check_projection(): - """Test that the projection of data is correct""" +def test_randomized_pca_check_projection(): + """Test that the projection by RandomizedPCA on dense data is correct""" n, p = 100, 3 X = randn(n, p) * .1 X[:10] += np.array([3, 4, 5]) - pca = PCA(n_comp=2, do_fast_svd=True) - pca.fit(X) - Xt = 0.1* randn(1, p) + np.array([3, 4, 5]) - Yt = pca.transform(Xt) + Xt = 0.1 * randn(1, p) + np.array([3, 4, 5]) + + Yt = RandomizedPCA(n_components=2).fit(X).transform(Xt) + Yt /= np.sqrt((Yt ** 2).sum()) + + np.testing.assert_almost_equal(np.abs(Yt[0][0]), 1., 1) + + +def test_sparse_randomized_pca_check_projection(): + """Test that the projection by RandomizedPCA on sparse data is correct""" + n, p = 100, 3 + X = randn(n, p) * .1 + X[:10] += np.array([3, 4, 5]) + X = csr_matrix(X) + Xt = 0.1 * randn(1, p) + np.array([3, 4, 5]) + Xt = csr_matrix(Xt) + + Yt = RandomizedPCA(n_components=2).fit(X).transform(Xt) Yt /= np.sqrt((Yt ** 2).sum()) + np.testing.assert_almost_equal(np.abs(Yt[0][0]), 1., 1) @@ -93,9 +113,9 @@ def test_pca_dim(): n, p = 100, 5 X = randn(n, p) * .1 X[:10] += np.array([3, 4, 5, 1, 2]) - pca = PCA(n_comp='mle') + pca = PCA(n_components='mle') pca.fit(X) - assert_true(pca.n_comp == 1) + assert_true(pca.n_components == 1) def test_infer_dim_1(): @@ -106,7 +126,7 @@ def test_infer_dim_1(): n, p = 1000, 5 X = 
randn(n, p) * .1 + randn(n, 1) * np.array([3, 4, 5, 1, 2]) \ + np.array([1, 0, 7, 4, 6]) - pca = PCA(n_comp=p) + pca = PCA(n_components=p) pca.fit(X) spect = pca.explained_variance_ ll = [] @@ -125,7 +145,7 @@ def test_infer_dim_2(): X = randn(n, p) * .1 X[:10] += np.array([3, 4, 5, 1, 2]) X[10:20] += np.array([6, 0, 7, 2, -1]) - pca = PCA(n_comp=p) + pca = PCA(n_components=p) pca.fit(X) spect = pca.explained_variance_ assert_true(_infer_dimension_(spect, n, p) > 1) @@ -139,7 +159,7 @@ def test_infer_dim_3(): X[:10] += np.array([3, 4, 5, 1, 2]) X[10:20] += np.array([6, 0, 7, 2, -1]) X[30:40] += 2*np.array([-1, 1, -1, 1, -1]) - pca = PCA(n_comp=p) + pca = PCA(n_components=p) pca.fit(X) spect = pca.explained_variance_ assert_true(_infer_dimension_(spect, n, p) > 2) @@ -149,7 +169,7 @@ def test_probabilistic_pca_1(): """Test that probabilistic PCA yields a reasonable score""" n, p = 1000, 3 X = randn(n, p)*.1 + np.array([3, 4, 5]) - ppca = ProbabilisticPCA(n_comp=2) + ppca = ProbabilisticPCA(n_components=2) ppca.fit(X) ll1 = ppca.score(X) h = 0.5 * np.log(2 * np.pi * np.exp(1) / 0.1**2) * p @@ -160,7 +180,7 @@ def test_probabilistic_pca_2(): """Test that probabilistic PCA correctly separated different datasets""" n, p = 100, 3 X = randn(n, p) * .1 + np.array([3, 4, 5]) - ppca = ProbabilisticPCA(n_comp=2) + ppca = ProbabilisticPCA(n_components=2) ppca.fit(X) ll1 = ppca.score(X) ll2 = ppca.score(randn(n, p) * .2 + np.array([3, 4, 5])) @@ -173,7 +193,7 @@ def test_probabilistic_pca_3(): """ n, p = 100, 3 X = randn(n, p)*.1 + np.array([3, 4, 5]) - ppca = ProbabilisticPCA(n_comp=2) + ppca = ProbabilisticPCA(n_components=2) ppca.fit(X) ll1 = ppca.score(X) ppca.fit(X, False) @@ -188,7 +208,7 @@ def test_probabilistic_pca_4(): Xt = randn(n, p) + randn(n, 1)*np.array([3, 4, 5]) + np.array([1, 0, 7]) ll = np.zeros(p) for k in range(p): - ppca = ProbabilisticPCA(n_comp=k) + ppca = ProbabilisticPCA(n_components=k) ppca.fit(Xl) ll[k] = ppca.score(Xt).mean() -- GitLab
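For readers who want to try the new estimator, here is a minimal usage sketch of the RandomizedPCA class introduced by this patch, pieced together from its docstring and the new tests. The import path and constructor arguments come from the patch itself; the array shapes and the number of components are illustrative assumptions only.

import numpy as np
from scipy.sparse import csr_matrix
from scikits.learn.pca import RandomizedPCA

# dense input: keep the 5 leading components, using the default
# 3 power iterations for the randomized SVD
X = np.random.randn(100, 20)
pca = RandomizedPCA(n_components=5, iterated_power=3)
X_r = pca.fit(X).transform(X)    # projected data, shape (100, 5)

# per test_sparse_randomized_pca_check_projection, the same estimator
# also accepts scipy.sparse input
X_sparse = csr_matrix(X)
X_r_sparse = RandomizedPCA(n_components=5).fit(X_sparse).transform(X_sparse)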