diff --git a/benchmarks/bench_plot_svd.py b/benchmarks/bench_plot_svd.py
index 6b24964ed3936e84a73cf8b0216ee2e5477ba532..9f76a23dcc88117ce6a553cbf033cc762090d276 100644
--- a/benchmarks/bench_plot_svd.py
+++ b/benchmarks/bench_plot_svd.py
@@ -8,11 +8,11 @@ import numpy as np
 from collections import defaultdict
 
 from scipy.linalg import svd
-from sklearn.utils.extmath import fast_svd
+from sklearn.utils.extmath import randomized_svd
 from sklearn.datasets.samples_generator import make_low_rank_matrix
 
 
-def compute_bench(samples_range, features_range, q=3, rank=50):
+def compute_bench(samples_range, features_range, n_iterations=3, rank=50):
 
     it = 0
 
@@ -36,16 +36,19 @@ def compute_bench(samples_range, features_range, q=3, rank=50):
             results['scipy svd'].append(time() - tstart)
 
             gc.collect()
-            print "benching scikit-learn fast_svd: q=0"
+            print "benching scikit-learn randomized_svd: n_iterations=0"
             tstart = time()
-            fast_svd(X, rank, q=0)
-            results['scikit-learn fast_svd (q=0)'].append(time() - tstart)
+            randomized_svd(X, rank, n_iterations=0)
+            results['scikit-learn randomized_svd (n_iterations=0)'].append(
+                time() - tstart)
 
             gc.collect()
-            print "benching scikit-learn fast_svd: q=%d " % q
+            print ("benching scikit-learn randomized_svd: n_iterations=%d "
+                   % n_iterations)
             tstart = time()
-            fast_svd(X, rank, q=q)
-            results['scikit-learn fast_svd (q=%d)' % q].append(time() - tstart)
+            randomized_svd(X, rank, n_iterations=n_iterations)
+            results['scikit-learn randomized_svd (n_iterations=%d)'
+                    % n_iterations].append(time() - tstart)
 
     return results
 
diff --git a/doc/developers/utilities.rst b/doc/developers/utilities.rst
index a7dd46009fcb8202cd562b849e12a0968137203e..f1fa601790060130e5410eb9a918eca350c13ce8 100644
--- a/doc/developers/utilities.rst
+++ b/doc/developers/utilities.rst
@@ -76,9 +76,9 @@ Efficient Linear Algebra & Array Operations
 
 - :func:`extmath.randomized_range_finder`: construct an orthonormal matrix
   whose range approximates the range of the input.  This is used in
-  :func:`extmath.fast_svd`, below.
+  :func:`extmath.randomized_svd`, below.
 
-- :func:`extmath.fast_svd`: compute the k-truncated randomized SVD.
+- :func:`extmath.randomized_svd`: compute the k-truncated randomized SVD.
   This algorithm finds the exact truncated singular values decomposition
   using randomization to speed up the computations. It is particularly
   fast on large matrices on which you wish to extract only a small
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 61d08dd4b6f41a1e5b6cf3832f7b7763470e823d..a39259fa649ef504e56117f2fca12a69c53fdb16 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -136,6 +136,12 @@ version 0.9:
 
   - ``BaseDictionaryLearning`` class replaced by ``SparseCodingMixin``.
 
+  - :func:`sklearn.utils.extmath.fast_svd` has been renamed
+    :func:`sklearn.utils.extmath.randomized_svd` and the default
+    oversampling is now fixed to 10 additional random vectors instead
+    of doubling the number of components to extract. The new behavior
+    follows the reference paper.
+
 
 .. _changes_0_9:
 
diff --git a/examples/applications/wikipedia_principal_eigenvector.py b/examples/applications/wikipedia_principal_eigenvector.py
index 729f30793c2890aee5f09dd023ba36eb56a8def4..cbf4e5be80b75d6d2990b4ae5197403a70446987 100644
--- a/examples/applications/wikipedia_principal_eigenvector.py
+++ b/examples/applications/wikipedia_principal_eigenvector.py
@@ -43,7 +43,7 @@ import numpy as np
 
 from scipy import sparse
 
-from sklearn.utils.extmath import fast_svd
+from sklearn.utils.extmath import randomized_svd
 from sklearn.externals.joblib import Memory
 
 
@@ -170,9 +170,9 @@ X, redirects, index_map = get_adjacency_matrix(
     redirects_filename, page_links_filename, limit=5000000)
 names = dict((i, name) for name, i in index_map.iteritems())
 
-print "Computing the principal singular vectors using fast_svd"
+print "Computing the principal singular vectors using randomized_svd"
 t0 = time()
-U, s, V = fast_svd(X, 5, q=3)
+U, s, V = randomized_svd(X, 5, q=3)
 print "done in %0.3fs" % (time() - t0)
 
 # print the names of the wikipedia related strongest compenents of the the
diff --git a/sklearn/decomposition/dict_learning.py b/sklearn/decomposition/dict_learning.py
index 8fdd3208a2707bcd2bae84339b99b8b8a589591d..20afe75089f0b3eefbbb4de8c2643b7acdc02927 100644
--- a/sklearn/decomposition/dict_learning.py
+++ b/sklearn/decomposition/dict_learning.py
@@ -17,7 +17,7 @@ from numpy.lib.stride_tricks import as_strided
 from ..base import BaseEstimator, TransformerMixin
 from ..externals.joblib import Parallel, delayed, cpu_count
 from ..utils import array2d, check_random_state, gen_even_slices, deprecated
-from ..utils.extmath import fast_svd
+from ..utils.extmath import randomized_svd
 from ..linear_model import Lasso, orthogonal_mp_gram, lars_path
 
 
@@ -622,7 +622,7 @@ def dict_learning_online(X, n_atoms, alpha, n_iter=100, return_code=True,
     if dict_init is not None:
         dictionary = dict_init
     else:
-        _, S, dictionary = fast_svd(X, n_atoms)
+        _, S, dictionary = randomized_svd(X, n_atoms)
         dictionary = S[:, np.newaxis] * dictionary
     r = len(dictionary)
     if n_atoms <= r:
diff --git a/sklearn/decomposition/nmf.py b/sklearn/decomposition/nmf.py
index 5a8a74ceb14254bc603838b26de5d13b9e2b1c19..e1ace319d266527bd5719907f271c1e9c12903e3 100644
--- a/sklearn/decomposition/nmf.py
+++ b/sklearn/decomposition/nmf.py
@@ -12,7 +12,7 @@ from __future__ import division
 
 from ..base import BaseEstimator, TransformerMixin
 from ..utils import atleast2d_or_csr, check_random_state
-from ..utils.extmath import fast_svd, safe_sparse_dot
+from ..utils.extmath import randomized_svd, safe_sparse_dot
 
 import numpy as np
 from scipy.optimize import nnls
@@ -106,7 +106,7 @@ def _initialize_nmf(X, n_components, variant=None, eps=1e-6,
     if variant not in (None, 'a', 'ar'):
         raise ValueError("Invalid variant name")
 
-    U, S, V = fast_svd(X, n_components)
+    U, S, V = randomized_svd(X, n_components)
     W, H = np.zeros(U.shape), np.zeros(V.shape)
 
     # The leading singular triplet is non-negative
diff --git a/sklearn/decomposition/pca.py b/sklearn/decomposition/pca.py
index 976ff79daafeaccbe6dfd2cb111840b0f3093c91..bfd67f4b5ef68be7570b78ae43475570fabe097e 100644
--- a/sklearn/decomposition/pca.py
+++ b/sklearn/decomposition/pca.py
@@ -11,7 +11,9 @@ from scipy import linalg
 
 from ..base import BaseEstimator, TransformerMixin
 from ..utils import array2d, check_random_state, as_float_array
-from ..utils.extmath import fast_logdet, fast_svd, safe_sparse_dot
+from ..utils.extmath import fast_logdet
+from ..utils.extmath import safe_sparse_dot
+from ..utils.extmath import randomized_svd
 
 
 def _assess_dimension_(spectrum, rank, n_samples, dim):
@@ -458,9 +460,9 @@ class RandomizedPCA(BaseEstimator, TransformerMixin):
             self.mean_ = np.mean(X, axis=0)
             X -= self.mean_
 
-        U, S, V = fast_svd(X, self.n_components,
-                           n_iterations=self.iterated_power,
-                           random_state=self.random_state)
+        U, S, V = randomized_svd(X, self.n_components,
+                                 n_iterations=self.iterated_power,
+                                 random_state=self.random_state)
 
         self.explained_variance_ = (S ** 2) / n_samples
         self.explained_variance_ratio_ = self.explained_variance_ / \
diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py
index d27d046d5504c15dcb8e2652586db24d6612679d..ad30eec6032e88ce42a7c26e480d3bc30cad7719 100644
--- a/sklearn/utils/extmath.py
+++ b/sklearn/utils/extmath.py
@@ -4,10 +4,12 @@ Extended math utilities.
 # Authors: G. Varoquaux, A. Gramfort, A. Passos, O. Grisel
 # License: BSD
 
+import warnings
 import numpy as np
 from scipy import linalg
 
 from . import check_random_state
+from . import deprecated
 from .fixes import qr_economic
 
 
@@ -115,8 +117,8 @@ def randomized_range_finder(A, size, n_iterations, random_state=None):
     Y = safe_sparse_dot(A, R)
     del R
 
-    # apply q power iterations on Y to make to further 'imprint' the top
-    # singular values of A in Y
+    # perform power iterations with Y to further 'imprint' the top
+    # singular vectors of A in Y
     for i in xrange(n_iterations):
         Y = safe_sparse_dot(A, safe_sparse_dot(A.T, Y))
 
@@ -125,20 +127,22 @@ def randomized_range_finder(A, size, n_iterations, random_state=None):
     return Q
 
 
-def fast_svd(M, k, p=None, n_iterations=0, transpose='auto', random_state=0):
-    """Computes the k-truncated randomized SVD
+def randomized_svd(M, n_components, n_oversamples=10, n_iterations=0,
+                   transpose='auto', random_state=0):
+    """Computes a truncated randomized SVD
 
     Parameters
     ----------
     M: ndarray or sparse matrix
         Matrix to decompose
 
-    k: int
+    n_components: int
         Number of singular values and vectors to extract.
 
-    p: int (default is k)
-        Additional number of samples of the range of M to ensure proper
-        conditioning. See the notes below.
+    n_oversamples: int (default is 10)
+        Additional number of random vectors to sample the range of M so as
+        to ensure proper conditioning. The total number of random vectors
+        used to find the range of M is n_components + n_oversamples.
 
     n_iterations: int (default is 0)
         Number of power iterations (can be used to deal with very noisy
@@ -156,14 +160,10 @@ def fast_svd(M, k, p=None, n_iterations=0, transpose='auto', random_state=0):
 
     Notes
     -----
-    This algorithm finds the exact truncated singular values decomposition
-    using randomization to speed up the computations. It is particularly
-    fast on large matrices on which you whish to extract only a small
-    number of components.
-
-    (k + p) should be strictly higher than the rank of M. This can be
-    checked by ensuring that the lowest extracted singular value is on
-    the order of the machine precision of floating points.
+    This algorithm finds a (usually very good) approximate truncated
+    singular value decomposition using randomization to speed up the
+    computations. It is particularly fast on large matrices on which
+    you wish to extract only a small number of components.
 
     **References**:
 
@@ -174,10 +174,8 @@ def fast_svd(M, k, p=None, n_iterations=0, transpose='auto', random_state=0):
     * A randomized algorithm for the decomposition of matrices
       Per-Gunnar Martinsson, Vladimir Rokhlin and Mark Tygert
     """
-    if p == None:
-        p = k
-
     random_state = check_random_state(random_state)
+    n_random = n_components + n_oversamples
     n_samples, n_features = M.shape
 
     if transpose == 'auto' and n_samples > n_features:
@@ -186,7 +184,7 @@ def fast_svd(M, k, p=None, n_iterations=0, transpose='auto', random_state=0):
         # this implementation is a bit faster with smaller shape[1]
         M = M.T
 
-    Q = randomized_range_finder(M, k + p, n_iterations, random_state)
+    Q = randomized_range_finder(M, n_random, n_iterations, random_state)
 
     # project M to the (k + p) dimensional space using the basis vectors
     B = safe_sparse_dot(Q.T, M)
@@ -199,9 +197,16 @@ def fast_svd(M, k, p=None, n_iterations=0, transpose='auto', random_state=0):
 
     if transpose:
         # transpose back the results according to the input convention
-        return V[:k, :].T, s[:k], U[:, :k].T
+        return V[:n_components, :].T, s[:n_components], U[:, :n_components].T
     else:
-        return U[:, :k], s[:k], V[:k, :]
+        return U[:, :n_components], s[:n_components], V[:n_components, :]
+
+
+@deprecated("fast_svd is deprecated in 0.10 and will be removed in 0.12: "
+            "use randomized_svd instead")
+def fast_svd(M, k, p=10, n_iterations=0, transpose='auto', random_state=0):
+    return randomized_svd(M, k, n_oversamples=p, n_iterations=n_iterations,
+                          transpose='auto', random_state=random_state)
 
 
 def logsumexp(arr, axis=0):
diff --git a/sklearn/utils/tests/test_svd.py b/sklearn/utils/tests/test_svd.py
index d714a1f8ec38f93fdd63fff4ac56d77887163a6e..ea4b125e78b16fbd3d129e03622bddf654e73cbe 100644
--- a/sklearn/utils/tests/test_svd.py
+++ b/sklearn/utils/tests/test_svd.py
@@ -8,12 +8,12 @@ from scipy import linalg
 from numpy.testing import assert_equal
 from numpy.testing import assert_almost_equal
 
-from sklearn.utils.extmath import fast_svd
+from sklearn.utils.extmath import randomized_svd
 from sklearn.datasets.samples_generator import make_low_rank_matrix
 
 
-def test_fast_svd_low_rank():
-    """Check that extmath.fast_svd is consistent with linalg.svd"""
+def test_randomized_svd_low_rank():
+    """Check that extmath.randomized_svd is consistent with linalg.svd"""
     n_samples = 100
     n_features = 500
     rank = 5
@@ -29,7 +29,7 @@ def test_fast_svd_low_rank():
     U, s, V = linalg.svd(X, full_matrices=False)
 
     # compute the singular values of X using the fast approximate method
-    Ua, sa, Va = fast_svd(X, k)
+    Ua, sa, Va = randomized_svd(X, k)
     assert_equal(Ua.shape, (n_samples, k))
     assert_equal(sa.shape, (k,))
     assert_equal(Va.shape, (k, n_features))
@@ -45,12 +45,12 @@ def test_fast_svd_low_rank():
     X = sparse.csr_matrix(X)
 
     # compute the singular values of X using the fast approximate method
-    Ua, sa, Va = fast_svd(X, k)
+    Ua, sa, Va = randomized_svd(X, k)
     assert_almost_equal(s[:rank], sa[:rank])
 
 
-def test_fast_svd_low_rank_with_noise():
-    """Check that extmath.fast_svd can handle noisy matrices"""
+def test_randomized_svd_low_rank_with_noise():
+    """Check that extmath.randomized_svd can handle noisy matrices"""
     n_samples = 100
     n_features = 500
     rank = 5
@@ -67,21 +67,21 @@ def test_fast_svd_low_rank_with_noise():
 
     # compute the singular values of X using the fast approximate method
     # without the iterated power method
-    _, sa, _ = fast_svd(X, k, n_iterations=0)
+    _, sa, _ = randomized_svd(X, k, n_iterations=0)
 
     # the approximation does not tolerate the noise:
     assert np.abs(s[:k] - sa).max() > 0.05
 
     # compute the singular values of X using the fast approximate method with
     # iterated power method
-    _, sap, _ = fast_svd(X, k, n_iterations=5)
+    _, sap, _ = randomized_svd(X, k, n_iterations=5)
 
     # the iterated power method is helping getting rid of the noise:
     assert_almost_equal(s[:k], sap, decimal=3)
 
 
-def test_fast_svd_infinite_rank():
-    """Check that extmath.fast_svd can handle noisy matrices"""
+def test_randomized_svd_infinite_rank():
+    """Check that extmath.randomized_svd can handle noisy matrices"""
     n_samples = 100
     n_features = 500
     rank = 5
@@ -98,21 +98,21 @@ def test_fast_svd_infinite_rank():
 
     # compute the singular values of X using the fast approximate method
     # without the iterated power method
-    _, sa, _ = fast_svd(X, k, n_iterations=0)
+    _, sa, _ = randomized_svd(X, k, n_iterations=0)
 
     # the approximation does not tolerate the noise:
     assert np.abs(s[:k] - sa).max() > 0.1
 
     # compute the singular values of X using the fast approximate method with
     # iterated power method
-    _, sap, _ = fast_svd(X, k, n_iterations=5)
+    _, sap, _ = randomized_svd(X, k, n_iterations=5)
 
     # the iterated power method is still managing to get most of the structure
     # at the requested rank
     assert_almost_equal(s[:k], sap, decimal=3)
 
 
-def test_fast_svd_transpose_consistency():
+def test_randomized_svd_transpose_consistency():
     """Check that transposing the design matrix has limit impact"""
     n_samples = 100
     n_features = 500
@@ -123,9 +123,12 @@ def test_fast_svd_transpose_consistency():
         effective_rank=rank, tail_strength=0.5, random_state=0)
     assert_equal(X.shape, (n_samples, n_features))
 
-    U1, s1, V1 = fast_svd(X, k, n_iterations=3, transpose=False, random_state=0)
-    U2, s2, V2 = fast_svd(X, k, n_iterations=3, transpose=True, random_state=0)
-    U3, s3, V3 = fast_svd(X, k, n_iterations=3, transpose='auto', random_state=0)
+    U1, s1, V1 = randomized_svd(X, k, n_iterations=3, transpose=False,
+                                random_state=0)
+    U2, s2, V2 = randomized_svd(X, k, n_iterations=3, transpose=True,
+                                random_state=0)
+    U3, s3, V3 = randomized_svd(X, k, n_iterations=3, transpose='auto',
+                                random_state=0)
     U4, s4, V4 = linalg.svd(X, full_matrices=False)
 
     assert_almost_equal(s1, s4[:k], decimal=3)