diff --git a/scikits/learn/pca.py b/scikits/learn/pca.py
index 051879147cf18a55d7f4f13a92b1ed6d7fe3615f..6a37bc360247fff5f7d5a2b1cefba94e8fe40a84 100644
--- a/scikits/learn/pca.py
+++ b/scikits/learn/pca.py
@@ -5,7 +5,7 @@ import numpy as np
 from scipy import linalg
 
 from .base import BaseEstimator
-from .utils.extmath import fast_logdet
+from .utils.extmath import fast_logdet, fast_svd
 
 
 def _assess_dimension_(spect, rk, n_samples, dim):
@@ -117,7 +117,7 @@ class PCA(BaseEstimator):
     >>> from scikits.learn.pca import PCA
     >>> pca = PCA(n_comp=2)
     >>> pca.fit(X)
-    PCA(n_comp=2, copy=True)
+    PCA(do_fast_svd=False, n_comp=2, copy=True)
     >>> print pca.explained_variance_ratio_
     [ 0.99244289  0.00755711]
 
@@ -126,9 +126,10 @@ class PCA(BaseEstimator):
     ProbabilisticPCA
 
     """
-    def __init__(self, n_comp=None, copy=True):
+    def __init__(self, n_comp=None, copy=True, do_fast_svd=False):
         self.n_comp = n_comp
         self.copy = copy
+        self.do_fast_svd = do_fast_svd
 
     def fit(self, X, **params):
         """ Fit the model to the data X
@@ -141,7 +142,13 @@
         # Center data
         self.mean_ = np.mean(X, axis=0)
         X -= self.mean_
-        U, S, V = linalg.svd(X, full_matrices=False)
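+        # Randomized SVD: approximate, but faster when n_comp is small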
+        if self.do_fast_svd:
+            if self.n_comp == "mle":
+                raise NotImplementedError("'mle' not supported with fast SVD")
+            U, S, V = fast_svd(X, self.n_comp)
+        else:
+            U, S, V = linalg.svd(X, full_matrices=False)
         self.explained_variance_ = (S**2)/n_samples
         self.explained_variance_ratio_ = self.explained_variance_ / \
                                         self.explained_variance_.sum()
@@ -165,7 +171,6 @@ class PCA(BaseEstimator):
         Xr = np.dot(Xr, self.components_)
         return Xr
 
-
 ################################################################################
 class ProbabilisticPCA(PCA):
     """ Additional layer on top of PCA that add a probabilistic evaluation
diff --git a/scikits/learn/tests/test_pca.py b/scikits/learn/tests/test_pca.py
index e84e9e9c15f8d139431b9e51acde420c9a184d13..ec5b0d3375b86540fca5d24b3d366d70040f759c 100644
--- a/scikits/learn/tests/test_pca.py
+++ b/scikits/learn/tests/test_pca.py
@@ -37,6 +37,20 @@ def test_pca_check_projection():
     np.testing.assert_almost_equal(np.abs(Yt[0][0]), 1., 1)
 
 
+def test_fast_pca_check_projection():
+    """test that the projection of data is correct
+    """
+    n, p = 100, 3
+    X = randn(n, p) * .1
+    X[:10] += np.array([3, 4, 5])
+    pca = PCA(n_comp=2, do_fast_svd=True)
+    pca.fit(X)
+    Xt = 0.1 * randn(1, p) + np.array([3, 4, 5])
+    Yt = pca.transform(Xt)
+    Yt /= np.sqrt((Yt**2).sum())
+    np.testing.assert_almost_equal(np.abs(Yt[0][0]), 1., 1)
+
+
 def test_pca_dim():
     """
     """
diff --git a/scikits/learn/utils/extmath.py b/scikits/learn/utils/extmath.py
index c24bc894ea5d99f2a7f6b95e4362f00080ec71ce..e8755e5dee9b4f37a48aa7fdb07ca655e2e02fe9 100644
--- a/scikits/learn/utils/extmath.py
+++ b/scikits/learn/utils/extmath.py
@@ -1,6 +1,9 @@
 import sys
 import math
 import numpy as np
+import numpy.linalg as linalg
+import scipy.sparse
+
 
 #XXX: We should have a function with numpy's slogdet API
 def _fast_logdet(A):
@@ -76,3 +79,39 @@ def density(w, **kwargs):
     """
     d = 0 if w is None else float((w != 0).sum()) / w.size
     return d
+
+
+def fast_svd(M, k):
+    """Computes the k-truncated SVD of the matrix M using the random
+    projections algorithm from
+
+@article{halko2009finding,
+  title={{Finding structure with randomness: Stochastic 
+             algorithms for constructing approximate matrix decompositions}},
+  author={Halko, N. and Martinsson, P.G. and Tropp, J.A.},
+  journal={arXiv},
+  volume={909},
+  year={2009}
+}
+
+This computes an approximate truncated singular value decomposition,
+using randomization to speed up the computations."""
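+    # Oversample the requested rank by a few dimensions for accuracy.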
+    p = k + 5
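+    # Sample the range of M with a random Gaussian test matrix.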
+    r = np.random.normal(size=(M.shape[1], p))
+    if scipy.sparse.issparse(M):
+        Y = M * r
+    else:
+        Y = np.dot(M, r)
+    del r
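+    # Orthonormalize the samples, then compute the SVD of M in that basis.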
+    Q, R = linalg.qr(Y)
+    if scipy.sparse.issparse(M):
+        B = Q.T * M
+    else:
+        B = np.dot(Q.T, M)
+    Uhat, s, v = linalg.svd(B, full_matrices=False)
+    del B
+    U = np.dot(Q, Uhat)
+    return np.asfortranarray(U.T[:k]).T, s[:k], np.asfortranarray(v[:k])