diff --git a/scikits/learn/datasets/samples_generator.py b/scikits/learn/datasets/samples_generator.py
index be11dff5d86f869a00712ba26076fcb5e950de53..1cb9be8a2852aee6909efa7b2d5bca0e04af4f3d 100644
--- a/scikits/learn/datasets/samples_generator.py
+++ b/scikits/learn/datasets/samples_generator.py
@@ -7,6 +7,7 @@ Generate samples of synthetic data sets.
 
 import numpy as np
 import numpy.random as nr
+from scipy import linalg
 
 
 def test_dataset_classif(n_samples=100, n_features=100, param=[1,1],
@@ -179,3 +180,67 @@ def friedman(n_samples=100, n_features=10, noise_std=1):
     y += noise_std * nr.normal(loc=0, scale=1, size=n_samples)
 
     return X, y
+
+
+def low_rank_fat_tail(n_samples=100, n_features=100, effective_rank=10,
+                      tail_strength=0.5, seed=None):
+    """Mostly low rank random matrix with bell-shaped singular values profile
+
+    Most of the variance can be explained by a bell-shaped curve of width
+    effective_rank: the low rank part of the singular values profile is::
+
+        (1 - tail_strength) * exp(-1.0 * (i / effective_rank) ** 2)
+
+    The remaining singular values' tail is fat, decreasing as::
+
+        tail_strength * exp(-0.1 * i / effective_rank)
+
+    The low rank part of the profile can be considered the structured
+    signal part of the data while the tail can be considered the noisy
+    part of the data that cannot be summarized by a small number of
+    linear components (singular vectors).
+
+    This kind of singular profile is often seen in practice, for instance:
+     - gray level pictures of faces
+     - TF-IDF vectors of text documents crawled from the web
+
+    Parameters
+    ----------
+    n_samples : int
+        number of samples (default is 100)
+
+    n_features : int
+        number of features (default is 100)
+
+    effective_rank : int
+        approximate number of singular vectors required to explain most of
+        the data by linear combinations (default is 10)
+
+    tail_strength : float between 0.0 and 1.0
+        relative importance of the fat noisy tail of the singular values
+        profile (default is 0.5)
+
+    """
+    if isinstance(seed, np.random.RandomState):
+        random = seed
+    elif seed is not None:
+        random = np.random.RandomState(seed)
+    else:
+        random = np.random
+
+    n = min(n_samples, n_features)
+
+    # random (orthonormal) vectors
+    u = linalg.qr(random.randn(n_samples, n))[0][:, :n]
+    v = linalg.qr(random.randn(n_features, n))[0][:, :n].T
+
+    # index of the singular values
+    i = np.arange(n, dtype=np.float64)
+
+    # build the singular profile by assembling signal and noise components
+    low_rank = (1 - tail_strength) * np.exp(-1.0 * (i / effective_rank) ** 2)
+    tail = tail_strength * np.exp(-0.1 * i / effective_rank)
+    s = np.identity(n) * (low_rank + tail)
+
+    return np.dot(np.dot(u, s), v)
+
diff --git a/scikits/learn/utils/extmath.py b/scikits/learn/utils/extmath.py
index 1ec6cc1858d7254311fa5234f27ef66baa152925..217e7268bd8793176324bd206a0c1404e189d03e 100644
--- a/scikits/learn/utils/extmath.py
+++ b/scikits/learn/utils/extmath.py
@@ -129,8 +129,8 @@ def fast_svd(M, k, p=None, rng=0, q=0):
 
     References
     ==========
-    Finding structure with randomness: Stochastic
-    algorithms for constructing approximate matrix decompositions,
+    Finding structure with randomness: Stochastic algorithms for constructing
+    approximate matrix decompositions
     Halko, et al., 2009 (arXiv:909)
 
     A randomized algorithm for the decomposition of matrices
diff --git a/scikits/learn/utils/tests/test_svd.py b/scikits/learn/utils/tests/test_svd.py
index 96d9dc0a028076663fbc42a6138706e4081144e7..b1f96d9561372969c5ab3570ef15120bde5c85e0 100644
--- a/scikits/learn/utils/tests/test_svd.py
+++ b/scikits/learn/utils/tests/test_svd.py
@@ -8,7 +8,8 @@ from scipy import linalg
 from numpy.testing import assert_equal
 from numpy.testing import assert_almost_equal
 
-from ..extmath import fast_svd
+from scikits.learn.utils.extmath import fast_svd
+from scikits.learn.datasets.samples_generator import low_rank_fat_tail
 
 
 def test_fast_svd():
@@ -16,12 +17,12 @@ def test_fast_svd():
     n_samples = 100
     n_features = 500
     rank = 5
-    k = 100
+    k = 10
 
-    # generate a matrix X of rank `rank`
-    np.random.seed(42)
-    X = np.dot(np.random.randn(n_samples, rank),
-               np.random.randn(rank, n_features))
+    # generate a matrix X of approximate effective rank `rank` and no noise
+    # component (very structured signal):
+    X = low_rank_fat_tail(n_samples, n_features, effective_rank=rank,
+                          tail_strength=0.0, seed=0)
     assert_equal(X.shape, (n_samples, n_features))
 
     # compute the singular values of X using the slow exact method
@@ -49,3 +50,34 @@
     assert_almost_equal(s[:rank], sa[:rank])
 
 
+def test_fast_svd_with_noise():
+    """Check that extmath.fast_svd can handle noisy matrices"""
+    n_samples = 100
+    n_features = 500
+    rank = 5
+    k = 10
+
+    # generate a matrix X with structured approximate rank `rank` and an
+    # important noisy component
+    X = low_rank_fat_tail(n_samples, n_features, effective_rank=rank,
+                          tail_strength=0.5, seed=0)
+    assert_equal(X.shape, (n_samples, n_features))
+
+    # compute the singular values of X using the slow exact method
+    _, s, _ = linalg.svd(X, full_matrices=False)
+
+    # compute the singular values of X using the fast approximate method
+    # without the iterated power method
+    _, sa, _ = fast_svd(X, k, q=0)
+
+    # the approximation does not tolerate the noise:
+    assert np.abs(s[:rank] - sa[:rank]).max() > 0.1
+
+    # compute the singular values of X using the fast approximate method
+    # with the iterated power method
+    _, sap, _ = fast_svd(X, k, q=3)
+
+    # the iterated power method helps to get rid of the noise:
+    assert_almost_equal(s[:rank], sap[:rank], decimal=5)
+
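The new test exercises `fast_svd`'s iterated power method (the `q` parameter) against the fat singular tail that `low_rank_fat_tail` generates, but the diff itself does not show `fast_svd`'s internals. Below is a minimal standalone sketch, in the spirit of the Halko et al. reference cited in `extmath.py`, of how a randomized SVD with oversampling and power iterations behaves on such a matrix. The name `randomized_svd_sketch`, the oversampling default `p=5`, and the inline reconstruction of the singular profile are illustrative assumptions, not code from the patch:

```python
import numpy as np
from scipy import linalg


def randomized_svd_sketch(M, k, p=5, q=0, seed=0):
    # NOTE: illustrative sketch, not the fast_svd implementation itself.
    rng = np.random.RandomState(seed)

    # sample the range of M with a Gaussian projection, oversampling by p
    Y = np.dot(M, rng.randn(M.shape[1], k + p))

    # iterated power method: each pass squares the singular values of the
    # sampled operator, shrinking the fat tail relative to the signal
    # (fine for small q; a large q would need re-orthonormalization
    # between passes for numerical stability)
    for _ in range(q):
        Y = np.dot(M, np.dot(M.T, Y))

    # orthonormal basis of the sampled range, then an exact SVD of the
    # small projected matrix, mapped back to the original space
    Q, _ = linalg.qr(Y, mode='economic')
    Uhat, s, Vt = linalg.svd(np.dot(Q.T, M), full_matrices=False)
    return np.dot(Q, Uhat)[:, :k], s[:k], Vt[:k]


if __name__ == '__main__':
    # rebuild a 100 x 500 matrix with the same bell + fat tail singular
    # profile as low_rank_fat_tail(effective_rank=5, tail_strength=0.5)
    rng = np.random.RandomState(0)
    u, _ = linalg.qr(rng.randn(100, 100), mode='economic')
    v, _ = linalg.qr(rng.randn(500, 100), mode='economic')
    i = np.arange(100, dtype=np.float64)
    profile = 0.5 * np.exp(-1.0 * (i / 5) ** 2) + 0.5 * np.exp(-0.1 * i / 5)
    X = np.dot(u * profile, v.T)

    _, s, _ = linalg.svd(X, full_matrices=False)
    _, s0, _ = randomized_svd_sketch(X, k=10, q=0)
    _, s3, _ = randomized_svd_sketch(X, k=10, q=3)

    print('exact:', s[:5])
    print('q=0  :', s0[:5])  # visibly off: the fat tail leaks in
    print('q=3  :', s3[:5])  # close to the exact values
```

Running the sketch shows the same qualitative behavior the new `test_fast_svd_with_noise` asserts: with `q=0` the sampled subspace mixes in the slowly decaying tail and the leading singular values come out noticeably low, while a few power iterations suppress the tail and bring them within tight tolerance of the exact values.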