diff --git a/sklearn/decomposition/dict_learning.py b/sklearn/decomposition/dict_learning.py index b211c122e79c012ea6c0e13bfbc69e85509b1cd8..b8e74087b68daf0741486602856079e21a6ee73f 100644 --- a/sklearn/decomposition/dict_learning.py +++ b/sklearn/decomposition/dict_learning.py @@ -323,9 +323,8 @@ def _update_dict(dictionary, Y, code, verbose=False, return_r2=False, def dict_learning(X, n_atoms, alpha, max_iter=100, tol=1e-8, - method='lasso_lars', n_jobs=1, dict_init=None, - code_init=None, callback=None, verbose=False, - random_state=None): + method='lars', n_jobs=1, dict_init=None, code_init=None, + callback=None, verbose=False, random_state=None): """Solves a dictionary learning matrix factorization problem. Finds the best dictionary and the corresponding sparse code for @@ -354,10 +353,10 @@ def dict_learning(X, n_atoms, alpha, max_iter=100, tol=1e-8, tol: float, Tolerance for the stopping condition. - method: {'lasso_lars', 'lasso_cd'} - lasso_lars: uses the least angle regression method + method: {'lars', 'cd'} + lars: uses the least angle regression method to solve the lasso problem (linear_model.lars_path) - lasso_cd: uses the coordinate descent method to compute the + cd: uses the coordinate descent method to compute the Lasso solution (linear_model.Lasso). Lars will be faster if the estimated components are sparse. @@ -391,8 +390,10 @@ def dict_learning(X, n_atoms, alpha, max_iter=100, tol=1e-8, Vector of errors at each iteration. """ - if method not in ('lasso_lars', 'lasso_cd'): + if method not in ('lars', 'cd'): raise ValueError('Coding method not supported as a fit algorithm.') + method = 'lasso_' + method + t0 = time.time() n_features = X.shape[1] # Avoid integer division problems @@ -474,8 +475,7 @@ def dict_learning(X, n_atoms, alpha, max_iter=100, tol=1e-8, def dict_learning_online(X, n_atoms, alpha, n_iter=100, return_code=True, dict_init=None, callback=None, chunk_size=3, verbose=False, shuffle=True, n_jobs=1, - method='lasso_lars', iter_offset=0, - random_state=None): + method='lars', iter_offset=0, random_state=None): """Solves a dictionary learning matrix factorization problem online. Finds the best dictionary and the corresponding sparse code for @@ -524,10 +524,10 @@ def dict_learning_online(X, n_atoms, alpha, n_iter=100, return_code=True, n_jobs: int, number of parallel jobs to run, or -1 to autodetect. - method: {'lasso_lars', 'lasso_cd'} - lasso_lars: uses the least angle regression method + method: {'lars', 'cd'} + lars: uses the least angle regression method to solve the lasso problem (linear_model.lars_path) - lasso_cd: uses the coordinate descent method to compute the + cd: uses the coordinate descent method to compute the Lasso solution (linear_model.Lasso). Lars will be faster if the estimated components are sparse. @@ -546,8 +546,10 @@ def dict_learning_online(X, n_atoms, alpha, n_iter=100, return_code=True, code: array of shape (n_samples, n_atoms), the sparse code (only returned if `return_code=True`) """ - if method not in ('lasso_lars', 'lasso_cd'): + if method not in ('lars', 'cd'): raise ValueError('Coding method not supported as a fit algorithm.') + method = 'lasso_' + method + t0 = time.time() n_samples, n_features = X.shape # Avoid integer division problems @@ -715,10 +717,10 @@ class DictionaryLearning(BaseDictionaryLearning): tol: float, tolerance for numerical error - fit_algorithm: {'lasso_lars', 'lasso_cd'} - lasso_lars: uses the least angle regression method + fit_algorithm: {'lars', 'cd'} + lars: uses the least angle regression method to solve the lasso problem (linear_model.lars_path) - lasso_cd: uses the coordinate descent method to compute the + cd: uses the coordinate descent method to compute the Lasso solution (linear_model.Lasso). Lars will be faster if the estimated components are sparse. @@ -783,7 +785,7 @@ class DictionaryLearning(BaseDictionaryLearning): """ def __init__(self, n_atoms, alpha=1, max_iter=1000, tol=1e-8, - fit_algorithm='lasso_lars', transform_algorithm='omp', + fit_algorithm='lars', transform_algorithm='omp', transform_n_nonzero_coefs=None, transform_alpha=None, n_jobs=1, code_init=None, dict_init=None, verbose=False, split_sign=False, random_state=None): @@ -851,7 +853,8 @@ class MiniBatchDictionaryLearning(BaseDictionaryLearning): total number of iterations to perform fit_algorithm: {'lars', 'cd'} - lars: uses the least angle regression method (linear_model.lars_path) + lars: uses the least angle regression method to solve the lasso problem + (linear_model.lars_path) cd: uses the coordinate descent method to compute the Lasso solution (linear_model.Lasso). Lars will be faster if the estimated components are sparse. @@ -875,7 +878,7 @@ class MiniBatchDictionaryLearning(BaseDictionaryLearning): transform_alpha: float, 1. by default If `algorithm='lasso_lars'` or `algorithm='lasso_cd'`, `alpha` is the penalty applied to the L1 norm. - If `algorithm='threhold'`, `alpha` is the absolute value of the + If `algorithm='threshold'`, `alpha` is the absolute value of the threshold below which coefficients will be squashed to zero. If `algorithm='omp'`, `alpha` is the tolerance parameter: the value of the reconstruction error targeted. In this case, it overrides @@ -917,7 +920,7 @@ class MiniBatchDictionaryLearning(BaseDictionaryLearning): """ def __init__(self, n_atoms, alpha=1, n_iter=1000, - fit_algorithm='lasso_lars', n_jobs=1, chunk_size=3, + fit_algorithm='lars', n_jobs=1, chunk_size=3, shuffle=True, dict_init=None, transform_algorithm='omp', transform_n_nonzero_coefs=None, transform_alpha=None, verbose=False, split_sign=False, random_state=None): diff --git a/sklearn/decomposition/sparse_pca.py b/sklearn/decomposition/sparse_pca.py index 7bcf2b729b34aa853c3c1cdc636e2ea34b966a13..018a891f14c99fe5b4b848d2200a20246495ec4a 100644 --- a/sklearn/decomposition/sparse_pca.py +++ b/sklearn/decomposition/sparse_pca.py @@ -36,10 +36,10 @@ class SparsePCA(BaseEstimator, TransformerMixin): tol: float, Tolerance for the stopping condition. - method: {'lasso_lars', 'lasso_cd'} - lasso_lars: uses the least angle regression method + method: {'lars', 'cd'} + lars: uses the least angle regression method to solve the lasso problem (linear_model.lars_path) - lasso_cd: uses the coordinate descent method to compute the + cd: uses the coordinate descent method to compute the Lasso solution (linear_model.Lasso). Lars will be faster if the estimated components are sparse. @@ -72,7 +72,7 @@ class SparsePCA(BaseEstimator, TransformerMixin): """ def __init__(self, n_components, alpha=1, ridge_alpha=0.01, max_iter=1000, - tol=1e-8, method='lasso_lars', n_jobs=1, U_init=None, + tol=1e-8, method='lars', n_jobs=1, U_init=None, V_init=None, verbose=False, random_state=None): self.n_components = n_components self.alpha = alpha @@ -185,10 +185,10 @@ class MiniBatchSparsePCA(SparsePCA): n_jobs: int, number of parallel jobs to run, or -1 to autodetect. - method: {'lasso_lars', 'lasso_cd'} - lasso_lars: uses the least angle regression method + method: {'lars', 'cd'} + lars: uses the least angle regression method to solve the lasso problem (linear_model.lars_path) - lasso_cd: uses the coordinate descent method to compute the + cd: uses the coordinate descent method to compute the Lasso solution (linear_model.Lasso). Lars will be faster if the estimated components are sparse. @@ -198,7 +198,7 @@ class MiniBatchSparsePCA(SparsePCA): """ def __init__(self, n_components, alpha=1, ridge_alpha=0.01, n_iter=100, callback=None, chunk_size=3, verbose=False, shuffle=True, - n_jobs=1, method='lasso_lars', random_state=None): + n_jobs=1, method='lars', random_state=None): self.n_components = n_components self.alpha = alpha self.ridge_alpha = ridge_alpha diff --git a/sklearn/decomposition/tests/test_sparse_pca.py b/sklearn/decomposition/tests/test_sparse_pca.py index e8e3cda667d0117a8c2b9e74f5dbdfa613c984b0..82c57e55c975b6ad873185e1ddc2122b72aaad27 100644 --- a/sklearn/decomposition/tests/test_sparse_pca.py +++ b/sklearn/decomposition/tests/test_sparse_pca.py @@ -52,7 +52,7 @@ def test_correct_shapes(): def test_fit_transform(): rng = np.random.RandomState(0) Y, _, _ = generate_toy_data(3, 10, (8, 8), random_state=rng) # wide array - spca_lars = SparsePCA(n_components=3, method='lasso_lars', + spca_lars = SparsePCA(n_components=3, method='lars', random_state=rng) spca_lars.fit(Y) U1 = spca_lars.transform(Y) @@ -71,7 +71,7 @@ def test_fit_transform(): U2 = spca.transform(Y) assert_array_almost_equal(U1, U2) # Test that CD gives similar results - spca_lasso = SparsePCA(n_components=3, method='lasso_cd', random_state=rng) + spca_lasso = SparsePCA(n_components=3, method='cd', random_state=rng) spca_lasso.fit(Y) assert_array_almost_equal(spca_lasso.components_, spca_lars.components_) @@ -79,10 +79,10 @@ def test_fit_transform(): def test_fit_transform_tall(): rng = np.random.RandomState(0) Y, _, _ = generate_toy_data(3, 65, (8, 8), random_state=rng) # tall array - spca_lars = SparsePCA(n_components=3, method='lasso_lars', + spca_lars = SparsePCA(n_components=3, method='lars', random_state=rng) U1 = spca_lars.fit_transform(Y) - spca_lasso = SparsePCA(n_components=3, method='lasso_cd', random_state=rng) + spca_lasso = SparsePCA(n_components=3, method='cd', random_state=rng) U2 = spca_lasso.fit(Y).transform(Y) assert_array_almost_equal(U1, U2) @@ -131,6 +131,6 @@ def test_mini_batch_fit_transform(): random_state=rng).fit(Y).transform(Y) assert_array_almost_equal(U1, U2) # Test that CD gives similar results - spca_lasso = MiniBatchSparsePCA(n_components=3, method='lasso_cd', + spca_lasso = MiniBatchSparsePCA(n_components=3, method='cd', random_state=rng).fit(Y) assert_array_almost_equal(spca_lasso.components_, spca_lars.components_)