From ea494a0a275a8b5660fd479a6edab5c0127b2fab Mon Sep 17 00:00:00 2001
From: Olivier Grisel <olivier.grisel@ensta.org>
Date: Tue, 7 Dec 2010 21:20:06 +0100
Subject: [PATCH] ENH: make the PCA transformer perform variance scaling by
 default + update the face recognition accordingly

---
 .../applications/plot_face_recognition.py | 19 +++++----
 scikits/learn/pca.py                      | 25 ++++++++---
 scikits/learn/tests/test_pca.py           | 41 +++++++++++++++++--
 3 files changed, 68 insertions(+), 17 deletions(-)

diff --git a/examples/applications/plot_face_recognition.py b/examples/applications/plot_face_recognition.py
index e869733a3a..4a9ebaa35b 100644
--- a/examples/applications/plot_face_recognition.py
+++ b/examples/applications/plot_face_recognition.py
@@ -14,13 +14,13 @@ Expected results for the top 5 most represented people in the dataset::
 
                      precision    recall  f1-score   support
 
-     George_W_Bush       0.84      0.88      0.86       129
-      Colin_Powell       0.80      0.84      0.82        58
-        Tony_Blair       0.66      0.62      0.64        34
-   Donald_Rumsfeld       0.87      0.79      0.83        33
- Gerhard_Schroeder       0.75      0.64      0.69        28
+ Gerhard_Schroeder       0.87      0.71      0.78        28
+   Donald_Rumsfeld       0.94      0.88      0.91        33
+        Tony_Blair       0.78      0.85      0.82        34
+      Colin_Powell       0.84      0.88      0.86        58
+     George_W_Bush       0.91      0.91      0.91       129
 
-       avg / total       0.81      0.81      0.81       282
+       avg / total       0.88      0.88      0.88       282
 
 """
 print __doc__
@@ -109,8 +109,9 @@ X_train, X_test = X[:split], X[split:]
 y_train, y_test = y[:split], y[split:]
 
 ################################################################################
-# Compute a PCA (eigenfaces) on the training set
-n_components = 200
+# Compute a PCA (eigenfaces) on the face dataset (treated as unlabeled
+# dataset): unsupervised feature extraction / dimensionality reduction
+n_components = 150
 
 print "Extracting the top %d eigenfaces" % n_components
 pca = PCA(n_comp=n_components, do_fast_svd=True).fit(X_train)
@@ -126,7 +127,7 @@ X_test_pca = pca.transform(X_test)
 # Train a SVM classification model
 
 print "Fitting the classifier to the training set"
-clf = SVC(C=100).fit(X_train_pca, y_train, class_weight="auto")
+clf = SVC(C=1, gamma=5).fit(X_train_pca, y_train, class_weight="auto")
 
 
 ################################################################################
diff --git a/scikits/learn/pca.py b/scikits/learn/pca.py
index 8b4796c4c8..e8bb5120d4 100644
--- a/scikits/learn/pca.py
+++ b/scikits/learn/pca.py
@@ -90,6 +90,10 @@ def _infer_dimension_(spectrum, n, p):
 class PCA(BaseEstimator):
     """Principal component analysis (PCA)
 
+    Linear dimensionality reduction using Singular Value Decomposition of the
+    data and keeping only the most significant singular vectors to project the
+    data to a lower dimensional space.
+
     Parameters
     ----------
     X: array-like, shape (n_samples, n_features)
@@ -107,7 +111,7 @@ class PCA(BaseEstimator):
         If False, data passed to fit are overwritten
 
     components_: array, [n_features, n_comp]
-        Components with maximum variance
+        Components with maximum variance.
 
     do_fast_svd: bool, optional
         If True, the k-truncated SVD is computed using random projections
@@ -122,6 +126,11 @@ class PCA(BaseEstimator):
         k is not set then all components are stored and the sum of explained
         variances is equal to 1.0
 
+    whiten: bool, optional
+        If True (default) the components_ vectors are divided by the
+        singular values to ensure uncorrelated outputs with identical
+        component-wise variances.
+
     iterated_power: int, optional
        Number of iteration for the power method if do_fast_svd is True.
        3 by default.
@@ -138,7 +147,7 @@ class PCA(BaseEstimator):
     >>> from scikits.learn.pca import PCA
     >>> pca = PCA(n_comp=2)
     >>> pca.fit(X)
-    PCA(do_fast_svd=False, n_comp=2, copy=True, iterated_power=3)
+    PCA(do_fast_svd=False, n_comp=2, copy=True, whiten=True, iterated_power=3)
     >>> print pca.explained_variance_ratio_
     [ 0.99244289  0.00755711]
     ...
     """
 
     def __init__(self, n_comp=None, copy=True, do_fast_svd=False,
-                 iterated_power=3):
+                 iterated_power=3, whiten=True):
         self.n_comp = n_comp
         self.copy = copy
         self.do_fast_svd = do_fast_svd
         self.iterated_power = iterated_power
+        self.whiten = whiten
 
     def fit(self, X, **params):
         """Fit the model to the data X"""
@@ -176,8 +186,13 @@ class PCA(BaseEstimator):
         self.explained_variance_ = (S ** 2) / n_samples
         self.explained_variance_ratio_ = self.explained_variance_ / \
                                          self.explained_variance_.sum()
-        self.components_ = V.T
-        if self.n_comp=='mle':
+
+        if self.whiten:
+            self.components_ = np.dot(V.T, np.diag(1.0 / S))
+        else:
+            self.components_ = V.T
+
+        if self.n_comp == 'mle':
             self.n_comp = _infer_dimension_(self.explained_variance_,
                                             n_samples, X.shape[1])
 
diff --git a/scikits/learn/tests/test_pca.py b/scikits/learn/tests/test_pca.py
index dfd060538b..f65694e774 100644
--- a/scikits/learn/tests/test_pca.py
+++ b/scikits/learn/tests/test_pca.py
@@ -1,6 +1,9 @@
 import numpy as np
 from numpy.random import randn
 from nose.tools import assert_true
+from nose.tools import assert_equal
+
+from numpy.testing import assert_almost_equal
 
 from .. import datasets
 from ..pca import PCA, ProbabilisticPCA, _assess_dimension_, _infer_dimension_
@@ -17,8 +20,40 @@ def test_pca():
     pca = PCA()
     pca.fit(X)
-    np.testing.assert_almost_equal(pca.explained_variance_ratio_.sum(),
-                                   1.0, 3)
+    assert_almost_equal(pca.explained_variance_ratio_.sum(), 1.0, 3)
+
+
+def test_whitening():
+    """Check that PCA output has unit-variance"""
+    np.random.seed(0)
+
+    # some low rank data with correlated features
+    X = np.dot(randn(100, 50),
+               np.dot(np.diag(np.linspace(10.0, 1.0, 50)),
+                      randn(50, 80)))
+    # the component-wise variance of the first 50 features is 3 times the
+    # mean component-wise variance of the remaining 30 features
+    X[:, :50] *= 3
+
+    assert_equal(X.shape, (100, 80))
+
+    # the component-wise variance is thus highly varying:
+    assert_almost_equal(X.std(axis=0).std(), 43.9, 1)
+
+    # whiten by default
+    X_whitened = PCA(n_comp=30).fit(X).transform(X)
+    assert_equal(X_whitened.shape, (100, 30))
+
+    # all output components have identical variance
+    assert_almost_equal(X_whitened.std(axis=0).std(), 0.0, 3)
+
+    # it is possible to project on the low-dim space without scaling by the
+    # singular values
+    X_unwhitened = PCA(n_comp=30, whiten=False).fit(X).transform(X)
+    assert_equal(X_unwhitened.shape, (100, 30))
+
+    # in that case the output components still have varying variances
+    assert_almost_equal(X_unwhitened.std(axis=0).std(), 74.1, 1)
 
 
 def test_pca_check_projection():
@@ -147,7 +182,7 @@ def test_probabilistic_pca_4():
     Xt = randn(n, p) + randn(n, 1)*np.array([3, 4, 5]) + np.array([1, 0, 7])
     ll = np.zeros(p)
     for k in range(p):
-        ppca = ProbabilisticPCA(n_comp=k)
+        ppca = ProbabilisticPCA(n_comp=k, whiten=False)
         ppca.fit(Xl)
         ll[k] = ppca.score(Xt).mean()
--
GitLab
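
Usage note (not part of the patch above): a minimal sketch of what the new default means for callers, assuming the scikits.learn API as modified in this patch (PCA lives in scikits.learn.pca and takes the n_comp / whiten keyword arguments); the toy data and printed values are purely illustrative:

    import numpy as np
    from scikits.learn.pca import PCA

    np.random.seed(0)
    # toy data: 100 samples, 80 features with very different scales
    X = np.random.randn(100, 80) * np.linspace(10.0, 1.0, 80)

    # whitening is now on by default: the projected components all end up
    # with identical (not necessarily unit) component-wise variances
    X_white = PCA(n_comp=10).fit(X).transform(X)
    print X_white.std(axis=0)

    # pass whiten=False to recover the previous, unscaled projection
    X_raw = PCA(n_comp=10, whiten=False).fit(X).transform(X)
    print X_raw.std(axis=0)

Whitening discards the relative scale of the components in the transformed output (it remains available via explained_variance_), which is why callers that need the unscaled projection, such as the ProbabilisticPCA test above, now pass whiten=False explicitly.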