diff --git a/scikits/learn/cluster/spectral.py b/scikits/learn/cluster/spectral.py
index 920c2ac5fc95adb89d27d6541906c17d00262e4d..82ff5089e27b3376f4bc5235c55b8f19c2c3b3c0 100644
--- a/scikits/learn/cluster/spectral.py
+++ b/scikits/learn/cluster/spectral.py
@@ -40,7 +40,7 @@ def spectral_embedding(adjacency, k=8, mode=None):
     """
 
     from scipy import sparse
-    from scipy.sparse.linalg.eigen.arpack import eigen_symmetric
+    from ..utils.fixes import arpack_eigsh
     from scipy.sparse.linalg import lobpcg
     try:
         from pyamg import smoothed_aggregation_solver
@@ -78,7 +78,7 @@ def spectral_embedding(adjacency, k=8, mode=None):
         # csr has the fastest matvec and is thus best suited to
         # arpack
         laplacian = laplacian.tocsr()
-        lambdas, diffusion_map = eigen_symmetric(-laplacian, k=k, which='LA')
+        lambdas, diffusion_map = arpack_eigsh(-laplacian, k=k, which='LA')
         embedding = diffusion_map.T[::-1]*dd
     elif mode == 'amg':
         # Use AMG to get a preconditionner and speed up the eigenvalue
diff --git a/scikits/learn/datasets/samples_generator.py b/scikits/learn/datasets/samples_generator.py
index 05af562ede9e1865826b42be08a58f6bb2e83f97..67963c2cbfe62a350e1b0e0fac1927895d14fccc 100644
--- a/scikits/learn/datasets/samples_generator.py
+++ b/scikits/learn/datasets/samples_generator.py
@@ -7,7 +7,6 @@ Generate samples of synthetic data sets.
 
 import numpy as np
 import numpy.random as nr
-from scipy import linalg
 
 
 def test_dataset_classif(n_samples=100, n_features=100, param=[1,1],
@@ -233,18 +232,20 @@ def low_rank_fat_tail(n_samples=100, n_features=100, effective_rank=10,
     n = min(n_samples, n_features)
 
     # random (ortho normal) vectors
-    u = linalg.qr(random.randn(n_samples, n), econ=True)[0]
-    v = linalg.qr(random.randn(n_features, n), econ=True)[0].T
+    from ..utils.fixes import qr_economic
+    u, _ = qr_economic(random.randn(n_samples, n))
+    v, _ = qr_economic(random.randn(n_features, n))
 
     # index of the singular values
-    i = np.arange(n, dtype=np.float64)
+    singular_ind = np.arange(n, dtype=np.float64)
 
     # build the singular profile by assembling signal and noise components
-    low_rank = (1 - tail_strength) * np.exp(-1.0 * (i / effective_rank) ** 2)
-    tail = tail_strength * np.exp(-0.1 * i / effective_rank)
+    low_rank = (1 - tail_strength) * \
+        np.exp(-1.0 * (singular_ind / effective_rank) ** 2)
+    tail = tail_strength * np.exp(-0.1 * singular_ind / effective_rank)
     s = np.identity(n) * (low_rank + tail)
 
-    return np.dot(np.dot(u, s), v)
+    return np.dot(np.dot(u, s), v.T)
 
 
 def make_regression_dataset(n_train_samples=100, n_test_samples=100,
diff --git a/scikits/learn/utils/extmath.py b/scikits/learn/utils/extmath.py
index 27aa8d5a38358e7ddb4893f077501548ddff8e04..495ceb271c9e1f61477ee0bd66ab2ad3c0cd1ce9 100644
--- a/scikits/learn/utils/extmath.py
+++ b/scikits/learn/utils/extmath.py
@@ -8,7 +8,6 @@ import sys
 import math
 
 import numpy as np
-from scipy import linalg
 
 #XXX: We should have a function with numpy's slogdet API
 def _fast_logdet(A):
@@ -20,6 +19,7 @@ def _fast_logdet(A):
     """
     # XXX: Should be implemented as in numpy, using ATLAS
     # http://projects.scipy.org/numpy/browser/trunk/numpy/linalg/linalg.py#L1559
+    from scipy import linalg
     ld = np.sum(np.log(np.diag(A)))
     a = np.exp(ld/A.shape[0])
     d = np.linalg.det(A/a)
@@ -35,6 +35,7 @@ def _fast_logdet_numpy(A):
     but more robust
     It returns -Inf if det(A) is non positive or is not defined.
     """
+    from scipy import linalg
     sign, ld = np.linalg.slogdet(A)
     if not sign > 0:
         return -np.inf
@@ -175,15 +176,16 @@ def fast_svd(M, k, p=None, q=0, transpose='auto', rng=0):
     for i in xrange(q):
         Y = safe_sparse_dot(M, safe_sparse_dot(M.T, Y))
 
-    # extracting an orthonormal basis of the M range samples: econ=True raises a
-    # deprecation warning but as of today there is no way to avoid it...
-    Q, R = linalg.qr(Y, econ=True)
+    # extracting an orthonormal basis of the M range samples
+    from .fixes import qr_economic
+    Q, R = qr_economic(Y)
     del R
 
     # project M to the (k + p) dimensional space using the basis vectors
     B = safe_sparse_dot(Q.T, M)
 
     # compute the SVD on the thin matrix: (k + p) wide
+    from scipy import linalg
     Uhat, s, V = linalg.svd(B, full_matrices=False)
     del B
     U = np.dot(Q, Uhat)
diff --git a/scikits/learn/utils/fixes.py b/scikits/learn/utils/fixes.py
index 52557f62fe66c224ac9fba0b800db852fdb8562a..d83ed43795a55dc07713f11c2bd86670d4f56825 100644
--- a/scikits/learn/utils/fixes.py
+++ b/scikits/learn/utils/fixes.py
@@ -75,12 +75,49 @@ def _in1d(ar1, ar2, assume_unique=False):
         return flag[indx][rev_idx]
 
 
-if np.__version__ >= '1.4':
-    from numpy import in1d, copysign, unique
-else:
+def _major_minor(version):
+    """Return the leading (major, minor) numbers of a version string.
+
+    Version components are strings and must be converted to integers
+    before they can be compared with numbers.
+    """
+    import re
+    return tuple(int(part) for part in re.findall(r'\d+', version)[:2])
+
+
+def qr_economic(A, **kwargs):
+    """Economy-size QR decomposition that works across scipy versions.
+
+    Scipy 0.9 changed the keyword econ=True to mode='economic'. Any extra
+    keyword argument is passed through to scipy.linalg.qr.
+    """
+    import scipy.linalg
+    if _major_minor(scipy.__version__) < (0, 9):
+        return scipy.linalg.qr(A, econ=True, **kwargs)
+    else:
+        return scipy.linalg.qr(A, mode='economic', **kwargs)
+
+
+def arpack_eigsh(A, **kwargs):
+    """Symmetric ARPACK eigensolver that works across scipy versions.
+
+    Scipy 0.9 renamed eigen_symmetric to eigsh in
+    scipy.sparse.linalg.eigen.arpack. Keyword arguments are passed through.
+    """
+    from scipy.sparse.linalg.eigen import arpack
+    if hasattr(arpack, 'eigsh'):
+        return arpack.eigsh(A, **kwargs)
+    else:
+        return arpack.eigen_symmetric(A, **kwargs)
+
+# export fixes for np <= 1.4
+if _major_minor(np.__version__) < (1, 5):
     in1d = _in1d
     copysign = _copysign
     unique = _unique
+else:
+    from numpy import in1d, copysign, unique
+