diff --git a/examples/applications/plot_face_recognition.py b/examples/applications/plot_face_recognition.py
index e869733a3a75817b4cf5e54d52b935a7177d09ff..4a9ebaa35be0a9a1ac917bd5caf9277da9ac97d2 100644
--- a/examples/applications/plot_face_recognition.py
+++ b/examples/applications/plot_face_recognition.py
@@ -14,13 +14,13 @@ Expected results for the top 5 most represented people in the dataset::
 
                      precision    recall  f1-score   support
 
-      George_W_Bush       0.84      0.88      0.86       129
-       Colin_Powell       0.80      0.84      0.82        58
-         Tony_Blair       0.66      0.62      0.64        34
-    Donald_Rumsfeld       0.87      0.79      0.83        33
-  Gerhard_Schroeder       0.75      0.64      0.69        28
+  Gerhard_Schroeder       0.87      0.71      0.78        28
+    Donald_Rumsfeld       0.94      0.88      0.91        33
+         Tony_Blair       0.78      0.85      0.82        34
+       Colin_Powell       0.84      0.88      0.86        58
+      George_W_Bush       0.91      0.91      0.91       129
 
-        avg / total       0.81      0.81      0.81       282
+        avg / total       0.88      0.88      0.88       282
 
 """
 print __doc__
@@ -109,8 +109,9 @@ X_train, X_test = X[:split], X[split:]
 y_train, y_test = y[:split], y[split:]
 
 ################################################################################
-# Compute a PCA (eigenfaces) on the training set
-n_components = 200
+# Compute a PCA (eigenfaces) on the face dataset (treated as an unlabeled
+# dataset): unsupervised feature extraction / dimensionality reduction
+n_components = 150
 
 print "Extracting the top %d eigenfaces" % n_components
 pca = PCA(n_comp=n_components, do_fast_svd=True).fit(X_train)
@@ -126,7 +127,7 @@ X_test_pca = pca.transform(X_test)
 # Train a SVM classification model
 
 print "Fitting the classifier to the training set"
-clf = SVC(C=100).fit(X_train_pca, y_train, class_weight="auto")
+clf = SVC(C=1, gamma=5).fit(X_train_pca, y_train, class_weight="auto")
 
 
 ################################################################################
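
A minimal sketch of the pipeline this example now uses, written against the
scikits.learn API of this changeset (PCA taking n_comp / do_fast_svd, SVC
taking class_weight in fit) and with random stand-in arrays; the shapes, the
number of classes and the 64x64 crop size are hypothetical, the real example
builds X_train / X_test from the cropped LFW faces::

    import numpy as np
    from scikits.learn.pca import PCA
    from scikits.learn.svm import SVC

    # hypothetical stand-in data shaped like flattened 64x64 face crops
    X_train = np.random.randn(282, 64 * 64)
    y_train = np.random.randint(0, 5, 282)
    X_test = np.random.randn(100, 64 * 64)

    # unsupervised feature extraction: project on the top 150 eigenfaces,
    # whitened by default after this change
    pca = PCA(n_comp=150, do_fast_svd=True).fit(X_train)
    X_train_pca = pca.transform(X_train)
    X_test_pca = pca.transform(X_test)

    # RBF SVC with the C / gamma values introduced above; class_weight is
    # passed to fit as in the example script
    clf = SVC(C=1, gamma=5).fit(X_train_pca, y_train, class_weight="auto")
    print clf.predict(X_test_pca)
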
diff --git a/scikits/learn/pca.py b/scikits/learn/pca.py
index 8b4796c4c8ac770dc79d5d469d218c0e15d68695..e8bb5120d4b2d7778648b243ff86db9f6c344925 100644
--- a/scikits/learn/pca.py
+++ b/scikits/learn/pca.py
@@ -90,6 +90,10 @@ def _infer_dimension_(spectrum, n, p):
 class PCA(BaseEstimator):
     """Principal component analysis (PCA)
 
+    Linear dimensionality reduction using Singular Value Decomposition of the
+    data and keeping only the most significant singular vectors to project the
+    data to a lower dimensional space.
+
     Parameters
     ----------
     X: array-like, shape (n_samples, n_features)
@@ -107,7 +111,7 @@ class PCA(BaseEstimator):
         If False, data passed to fit are overwritten
 
     components_: array, [n_features, n_comp]
-        Components with maximum variance
+        Components with maximum variance.
 
     do_fast_svd: bool, optional
         If True, the k-truncated SVD is computed using random projections
@@ -122,6 +126,11 @@ class PCA(BaseEstimator):
         k is not set then all components are stored and the sum of
         explained variances is equal to 1.0
 
+    whiten: bool, optional
+        If True (the default), the components_ vectors are divided by the
+        singular values to ensure uncorrelated outputs with identical
+        component-wise variances.
+
     iterated_power: int, optional
         Number of iteration for the power method if do_fast_svd is True. 3 by
         default.
@@ -138,7 +147,7 @@ class PCA(BaseEstimator):
     >>> from scikits.learn.pca import PCA
     >>> pca = PCA(n_comp=2)
     >>> pca.fit(X)
-    PCA(do_fast_svd=False, n_comp=2, copy=True, iterated_power=3)
+    PCA(do_fast_svd=False, n_comp=2, copy=True, whiten=True, iterated_power=3)
     >>> print pca.explained_variance_ratio_
     [ 0.99244289  0.00755711]
 
@@ -148,11 +157,12 @@ class PCA(BaseEstimator):
 
     """
     def __init__(self, n_comp=None, copy=True, do_fast_svd=False,
-                 iterated_power=3):
+                 iterated_power=3, whiten=True):
         self.n_comp = n_comp
         self.copy = copy
         self.do_fast_svd = do_fast_svd
         self.iterated_power = iterated_power
+        self.whiten = whiten
 
     def fit(self, X, **params):
         """Fit the model to the data X"""
@@ -176,8 +186,13 @@ class PCA(BaseEstimator):
         self.explained_variance_ = (S ** 2) / n_samples
         self.explained_variance_ratio_ = self.explained_variance_ / \
                                         self.explained_variance_.sum()
-        self.components_ = V.T
-        if self.n_comp=='mle':
+
+        if self.whiten:
+            self.components_ = np.dot(V.T, np.diag(1.0 / S))
+        else:
+            self.components_ = V.T
+
+        if self.n_comp == 'mle':
             self.n_comp = _infer_dimension_(self.explained_variance_,
                                             n_samples, X.shape[1])
 
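The whitening added to fit rescales each principal axis by the inverse of its
singular value, so that the projected data has the same variance along every
kept component. A numpy-only sketch of that computation, independent of the
class above (variable names are illustrative)::

    import numpy as np

    rng = np.random.RandomState(0)
    # toy data whose features have very different scales
    X = rng.randn(100, 5) * np.array([10.0, 5.0, 2.0, 1.0, 0.5])
    Xc = X - X.mean(axis=0)

    # thin SVD of the centered data, as in PCA.fit
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)

    components = Vt.T                                      # whiten=False
    components_whitened = np.dot(Vt.T, np.diag(1.0 / S))   # whiten=True

    print np.dot(Xc, components).std(axis=0)           # follows the singular values
    print np.dot(Xc, components_whitened).std(axis=0)  # all equal (up to rounding)
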
diff --git a/scikits/learn/tests/test_pca.py b/scikits/learn/tests/test_pca.py
index dfd060538bbcde55f3432d80ed0396f50706cac2..f65694e77457b5a7c58173987a959f6271b16784 100644
--- a/scikits/learn/tests/test_pca.py
+++ b/scikits/learn/tests/test_pca.py
@@ -1,6 +1,9 @@
 import numpy as np
 from numpy.random import randn
 from nose.tools import assert_true
+from nose.tools import assert_equal
+
+from numpy.testing import assert_almost_equal
 
 from .. import datasets
 from ..pca import PCA, ProbabilisticPCA, _assess_dimension_, _infer_dimension_
@@ -17,8 +20,40 @@ def test_pca():
 
     pca = PCA()
     pca.fit(X)
-    np.testing.assert_almost_equal(pca.explained_variance_ratio_.sum(),
-                                   1.0, 3)
+    assert_almost_equal(pca.explained_variance_ratio_.sum(), 1.0, 3)
+
+
+def test_whitening():
+    """Check that PCA output has unit-variance"""
+    np.random.seed(0)
+
+    # some low rank data with correlated features
+    X = np.dot(randn(100, 50),
+               np.dot(np.diag(np.linspace(10.0, 1.0, 50)),
+                      randn(50, 80)))
+    # the component-wise standard deviation of the first 50 features is 3
+    # times that of the remaining 30 features
+    X[:, :50] *= 3
+
+    assert_equal(X.shape, (100, 80))
+
+    # the component-wise standard deviations are thus highly varying:
+    assert_almost_equal(X.std(axis=0).std(), 43.9, 1)
+
+    # whiten by default
+    X_whitened = PCA(n_comp=30).fit(X).transform(X)
+    assert_equal(X_whitened.shape, (100, 30))
+
+    # all output components have identical variance
+    assert_almost_equal(X_whitened.std(axis=0).std(), 0.0, 3)
+
+    # it is possible to project onto the low dimensional space without
+    # scaling by the singular values
+    X_unwhitened = PCA(n_comp=30, whiten=False).fit(X).transform(X)
+    assert_equal(X_unwhitened.shape, (100, 30))
+
+    # in that case the output components still have varying variances
+    assert_almost_equal(X_unwhitened.std(axis=0).std(), 74.1, 1)
 
 
 def test_pca_check_projection():
@@ -147,7 +182,7 @@ def test_probabilistic_pca_4():
     Xt = randn(n, p) + randn(n, 1)*np.array([3, 4, 5]) + np.array([1, 0, 7])
     ll = np.zeros(p)
     for k in range(p):
-        ppca = ProbabilisticPCA(n_comp=k)
+        ppca = ProbabilisticPCA(n_comp=k, whiten=False)
         ppca.fit(Xl)
         ll[k] = ppca.score(Xt).mean()
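
For a quick interactive check mirroring test_whitening above, a sketch against
the estimator API of this changeset (the exact printed values depend on the
random toy data)::

    import numpy as np
    from scikits.learn.pca import PCA

    rng = np.random.RandomState(0)
    X = rng.randn(200, 10) * np.linspace(10.0, 1.0, 10)  # anisotropic toy data

    X_w = PCA(n_comp=4).fit(X).transform(X)               # whiten=True by default
    X_u = PCA(n_comp=4, whiten=False).fit(X).transform(X)

    print X_w.std(axis=0)   # constant across the 4 kept components
    print X_u.std(axis=0)   # decreasing, following the singular values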