diff --git a/scikits/learn/cluster/spectral.py b/scikits/learn/cluster/spectral.py
index 920c2ac5fc95adb89d27d6541906c17d00262e4d..82ff5089e27b3376f4bc5235c55b8f19c2c3b3c0 100644
--- a/scikits/learn/cluster/spectral.py
+++ b/scikits/learn/cluster/spectral.py
@@ -40,7 +40,7 @@ def spectral_embedding(adjacency, k=8, mode=None):
     """
 
     from scipy import sparse
-    from scipy.sparse.linalg.eigen.arpack import eigen_symmetric
+    from ..utils.fixes import arpack_eigsh
     from scipy.sparse.linalg import lobpcg
     try:
         from pyamg import smoothed_aggregation_solver
@@ -78,7 +78,7 @@ def spectral_embedding(adjacency, k=8, mode=None):
                 # csr has the fastest matvec and is thus best suited to
                 # arpack
                 laplacian = laplacian.tocsr()
-        lambdas, diffusion_map = eigen_symmetric(-laplacian, k=k, which='LA')
+        lambdas, diffusion_map = arpack_eigsh(-laplacian, k=k, which='LA')
         embedding = diffusion_map.T[::-1]*dd
     elif mode == 'amg':
         # Use AMG to get a preconditionner and speed up the eigenvalue
diff --git a/scikits/learn/datasets/samples_generator.py b/scikits/learn/datasets/samples_generator.py
index 05af562ede9e1865826b42be08a58f6bb2e83f97..67963c2cbfe62a350e1b0e0fac1927895d14fccc 100644
--- a/scikits/learn/datasets/samples_generator.py
+++ b/scikits/learn/datasets/samples_generator.py
@@ -7,7 +7,6 @@ Generate samples of synthetic data sets.
 
 import numpy as np
 import numpy.random as nr
-from scipy import linalg
 
 
 def test_dataset_classif(n_samples=100, n_features=100, param=[1,1],
@@ -233,18 +232,20 @@ def low_rank_fat_tail(n_samples=100, n_features=100, effective_rank=10,
     n = min(n_samples, n_features)
 
     # random (ortho normal) vectors
-    u = linalg.qr(random.randn(n_samples, n), econ=True)[0]
-    v = linalg.qr(random.randn(n_features, n), econ=True)[0].T
+    from ..utils.fixes import qr_economic
+    u, _ = qr_economic(random.randn(n_samples, n))
+    v, _ = qr_economic(random.randn(n_features, n))
 
     # index of the singular values
-    i = np.arange(n, dtype=np.float64)
+    singular_ind = np.arange(n, dtype=np.float64)
 
     # build the singular profile by assembling signal and noise components
-    low_rank = (1 - tail_strength) * np.exp(-1.0 * (i / effective_rank) ** 2)
-    tail = tail_strength * np.exp(-0.1 * i / effective_rank)
+    low_rank = (1 - tail_strength) * \
+               np.exp(-1.0 * (singular_ind / effective_rank) ** 2)
+    tail = tail_strength * np.exp(-0.1 * singular_ind / effective_rank)
     s = np.identity(n) * (low_rank + tail)
 
-    return np.dot(np.dot(u, s), v)
+    return np.dot(np.dot(u, s), v.T)
 
 
 def make_regression_dataset(n_train_samples=100, n_test_samples=100,
diff --git a/scikits/learn/utils/extmath.py b/scikits/learn/utils/extmath.py
index 27aa8d5a38358e7ddb4893f077501548ddff8e04..495ceb271c9e1f61477ee0bd66ab2ad3c0cd1ce9 100644
--- a/scikits/learn/utils/extmath.py
+++ b/scikits/learn/utils/extmath.py
@@ -8,7 +8,6 @@ import sys
 import math
 
 import numpy as np
-from scipy import linalg
 
 #XXX: We should have a function with numpy's slogdet API
 def _fast_logdet(A):
@@ -20,6 +19,7 @@ def _fast_logdet(A):
     """
     # XXX: Should be implemented as in numpy, using ATLAS
     # http://projects.scipy.org/numpy/browser/trunk/numpy/linalg/linalg.py#L1559
+    from scipy import linalg
     ld = np.sum(np.log(np.diag(A)))
     a = np.exp(ld/A.shape[0])
     d = np.linalg.det(A/a)
@@ -35,6 +35,7 @@ def _fast_logdet_numpy(A):
     but more robust
     It returns -Inf if det(A) is non positive or is not defined.
     """
+    from scipy import linalg
     sign, ld = np.linalg.slogdet(A)
     if not sign > 0:
         return -np.inf
@@ -175,15 +176,16 @@ def fast_svd(M, k, p=None, q=0, transpose='auto', rng=0):
     for i in xrange(q):
         Y = safe_sparse_dot(M, safe_sparse_dot(M.T, Y))
 
-    # extracting an orthonormal basis of the M range samples: econ=True raises a
-    # deprecation warning but as of today there is no way to avoid it...
-    Q, R = linalg.qr(Y, econ=True)
+    # extracting an orthonormal basis of the M range samples
+    from .fixes import qr_economic
+    Q, R = qr_economic(Y)
     del R
 
     # project M to the (k + p) dimensional space using the basis vectors
     B = safe_sparse_dot(Q.T, M)
 
     # compute the SVD on the thin matrix: (k + p) wide
+    from scipy import linalg
     Uhat, s, V = linalg.svd(B, full_matrices=False)
     del B
     U = np.dot(Q, Uhat)
diff --git a/scikits/learn/utils/fixes.py b/scikits/learn/utils/fixes.py
index 52557f62fe66c224ac9fba0b800db852fdb8562a..d83ed43795a55dc07713f11c2bd86670d4f56825 100644
--- a/scikits/learn/utils/fixes.py
+++ b/scikits/learn/utils/fixes.py
@@ -75,12 +75,38 @@ def _in1d(ar1, ar2, assume_unique=False):
         return flag[indx][rev_idx]
 
 
-if np.__version__ >= '1.4':
-    from numpy import in1d, copysign, unique
-else:
+def qr_economic(A, **kwargs):
+    """
+    Scipy 0.9 changed the keyword econ=True to mode='economic'
+    """
+    import scipy.linalg
+    version = scipy.__version__.split('.')
+    if version[0] < 1 and version[1] < 9:
+        return scipy.linalg.qr(A, econ=True, **kwargs)
+    else:
+        return scipy.linalg.qr(A, mode='economic', **kwargs)
+        
+
+def arpack_eigsh(A, **kwargs):
+    """
+    Scipy 0.9 renamed eigen_symmetric to eigsh in
+    scipy.sparse.linalg.eigen.arpack
+    """
+    from scipy.sparse.linalg.eigen import arpack
+    if hasattr(arpack, 'eigsh'):
+        return arpack.eigsh(A, **kwargs)
+    else:
+        return arpack.eigen_symmetric(A, **kwargs)
+
+# export fixes for np =< 1.4
+np_version = np.__version__.split('.')
+if np_version[0] < 2 and np_version[1] < 5:
     in1d = _in1d
     copysign = _copysign
     unique = _unique
+else:
+    from numpy import in1d, copysign, unique
+