diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 8c7df440f4ddb2d707054bf94ca3995968ba0c49..76eb410a37d9215ec87b0b2b13885db135ac1d90 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -377,6 +377,8 @@ Signal Decomposition
    decomposition.NMF
    decomposition.SparsePCA
    decomposition.MiniBatchSparsePCA
+   decomposition.DictionaryLearning
+   decomposition.DictionaryLearningOnline
 
 .. autosummary::
    :toctree: generated/
diff --git a/doc/modules/decomposition.rst b/doc/modules/decomposition.rst
index 4b5633003c7ed0743eea25d1fd82f05adcabc97f..059f09cd45d79583280cb43a43d4fad602b55698 100644
--- a/doc/modules/decomposition.rst
+++ b/doc/modules/decomposition.rst
@@ -347,3 +347,105 @@ of the data.
       matrix factorization"
       <http://www.cs.rpi.edu/~boutsc/files/nndsvd.pdf>`_
       C. Boutsidis, E. Gallopoulos, 2008
+
+
+
+.. _DictionaryLearning:
+
+Dictionary Learning
+===================
+
+Generic dictionary learning
+---------------------------
+
+Dictionary learning (:class:`DictionaryLearning`) is a matrix factorization
+problem that amounts to finding a (usually overcomplete) dictionary that
+performs well at sparsely encoding the fitted data.
+
+Representing data as sparse combinations of atoms from an overcomplete
+dictionary is suggested to be the way the mammalian primary visual cortex
+works. Consequently, dictionary learning applied to image patches has been
+shown to give good results in image processing tasks such as image completion,
+inpainting and denoising, as well as in supervised recognition tasks.
+
+Dictionary learning is an optimization problem solved by alternately updating
+the sparse code, as a solution to multiple Lasso problems with the dictionary
+held fixed, and then updating the dictionary to best fit the sparse code.
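+
+As a rough illustration of this alternating procedure, a minimal sketch using
+the :func:`dict_learning` function on a small random data matrix (the toy
+shapes below are arbitrary assumptions) could look like::
+
+    import numpy as np
+    from sklearn.decomposition import dict_learning
+
+    X = np.random.randn(20, 10)  # toy data: 20 samples, 10 features
+    # alternately update the sparse code and the dictionary
+    code, dictionary, errors = dict_learning(X, n_atoms=15, alpha=1)
+    # code has shape (20, 15); dictionary has shape (15, 10)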
+
+After using such a procedure to fit the dictionary, the fitted object can be
+used to transform new data. The transformation amounts to a sparse coding
+problem: finding a representation of the data as a linear combination of as
+few dictionary atoms as possible. All variations of dictionary learning
+implement the following transform methods, controllable via the
+`transform_algorithm` initialization parameter:
+
+
+* Orthogonal matching pursuit (:ref:`omp`)
+
+* Least-angle regression (:ref:`least_angle_regression`)
+
+* Lasso computed by least-angle regression
+
+* Lasso using coordinate descent (:ref:`lasso`)
+
+* Thresholding
+
+Thresholding is very fast but it does not yield accurate reconstructions. It
+has nevertheless been shown useful in the literature for classification tasks.
+For image reconstruction tasks, orthogonal matching pursuit yields the most
+accurate, unbiased reconstruction.
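+
+For illustration, a minimal sketch of fitting :class:`DictionaryLearning` and
+encoding data with orthogonal matching pursuit (the toy data below is an
+arbitrary assumption) could look like::
+
+    import numpy as np
+    from sklearn.decomposition import DictionaryLearning
+
+    X = np.random.randn(30, 16)  # toy data: 30 samples, 16 features
+    dico = DictionaryLearning(n_atoms=24, alpha=1, max_iter=20,
+                              transform_algorithm='omp',
+                              transform_n_nonzero_coefs=2)
+    dico.fit(X)
+    # each row of the code uses at most 2 dictionary atoms
+    code = dico.transform(X)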
+
+The dictionary learning objects offer, via the `split_sign` parameter, the
+possibility to separate the positive and negative values in the results of
+sparse coding. This is useful when dictionary learning is used for extracting
+features that will be used for supervised learning, because it allows the
+learning algorithm to assign different weights to the negative loadings of a
+particular atom than to the corresponding positive loadings.
+
+The split code for a single sample has length `2 * n_atoms` and is constructed
+using the following rule: first, the regular code of length `n_atoms` is
+computed. Then, the first `n_atoms` entries of the split code are filled with
+the positive part of the regular code vector. The second half of the split
+code is filled with the negative part of the code vector, but with a positive
+sign. Therefore, the split code is non-negative.
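+
+Equivalently, the split code concatenates the positive part and the flipped
+negative part of the regular code. A short sketch of this construction in
+plain NumPy (the example code vector below is an arbitrary assumption) is::
+
+    import numpy as np
+
+    code = np.array([0.5, 0.0, -1.2, 0.3])  # regular code, length n_atoms
+    split_code = np.hstack([np.maximum(code, 0),
+                            -np.minimum(code, 0)])
+    # split_code == [0.5, 0., 0., 0.3, 0., 0., 1.2, 0.]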
+
+The following image shows a dictionary learned from 4x4 pixel image patches
+extracted from part of the image of Lena.
+
+
+.. figure:: ../auto_examples/decomposition/images/plot_img_denoising_1.png
+    :target: ../auto_examples/decomposition/plot_img_denoising.html
+    :align: center
+    :scale: 50%
+
+
+.. topic:: Examples:
+
+  * :ref:`example_decomposition_plot_img_denoising.py`
+
+
+.. topic:: References:
+
+  * `"Online dictionary learning for sparse coding" 
+    <http://www.di.ens.fr/sierra/pdfs/icml09.pdf>`_
+    J. Mairal, F. Bach, J. Ponce, G. Sapiro, 2009
+
+.. _DictionaryLearningOnline:
+
+Online dictionary learning
+--------------------------
+
+:class:`DictionaryLearningOnline` implements a faster, but less accurate
+version of the dictionary learning algorithm that is better suited for large
+datasets. 
+
+By default, :class:`DictionaryLearningOnline` divides the data into
+mini-batches and optimizes in an online manner by cycling over the mini-batches
+for the specified number of iterations. However, at the moment it does not
+implement a stopping condition.
+
+The estimator also implements `partial_fit`, which updates the dictionary by
+iterating only once over a mini-batch. This can be used for online learning
+when the data is not readily available from the start, or when the data does
+not fit in memory.
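+
+As a minimal sketch of online learning with `partial_fit` (the toy stream of
+mini-batches below is an arbitrary assumption), one could write::
+
+    import numpy as np
+    from sklearn.decomposition import DictionaryLearningOnline
+
+    dico = DictionaryLearningOnline(n_atoms=10, alpha=1, n_iter=10)
+    # toy stream: ten mini-batches of 30 samples with 20 features each
+    for batch in np.array_split(np.random.randn(300, 20), 10):
+        dico.partial_fit(batch)  # update the dictionary with this batch
+    # dico.components_ now has shape (n_atoms, n_features) = (10, 20)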
+
diff --git a/doc/modules/linear_model.rst b/doc/modules/linear_model.rst
index 1d1da635f091502f8d2f5288c0130d5efc6153a8..3038716fbae0b7d36fb79c006b2b0c9619f6c29a 100644
--- a/doc/modules/linear_model.rst
+++ b/doc/modules/linear_model.rst
@@ -358,7 +358,8 @@ column is always zero.
    <http://www-stat.stanford.edu/~hastie/Papers/LARS/LeastAngle_2002.pdf>`_
    by Hastie et al.
 
-.. _OMP:
+
+.. _omp:
 
 Orthogonal Matching Pursuit (OMP)
 =================================
@@ -370,12 +371,12 @@ Being a forward feature selection method like :ref:`least_angle_regression`,
 orthogonal matching pursuit can approximate the optimum solution vector with a
 fixed number of non-zero elements:
 
-.. math:: \text{arg\,min} ||y - X\gamma||_2^2 \text{ subject to } ||\gamma||_0 \leq n_{features}
+.. math:: \text{arg\,min} ||y - X\gamma||_2^2 \text{ subject to } ||\gamma||_0 \leq n_{\text{nonzero\_coefs}}
 
 Alternatively, orthogonal matching pursuit can target a specific error instead
 of a specific number of non-zero coefficients. This can be expressed as:
 
-.. math:: \text{arg\,min} ||\gamma||_0 \text{ subject to } ||y-X\gamma||_2^2 \leq \varepsilon
+.. math:: \text{arg\,min} ||\gamma||_0 \text{ subject to } ||y-X\gamma||_2^2 \leq \text{tol}
 
 
 OMP is based on a greedy algorithm that includes at each step the atom most
diff --git a/examples/decomposition/plot_faces_decomposition.py b/examples/decomposition/plot_faces_decomposition.py
index 397b865df700a94b4e708b1293278031d396da4e..1cd41b9e9fc9a6180b9c2177913a5fdd154d0495 100644
--- a/examples/decomposition/plot_faces_decomposition.py
+++ b/examples/decomposition/plot_faces_decomposition.py
@@ -82,6 +82,11 @@ estimators = [
                                       n_iter=100, chunk_size=3),
      True, False),
 
+    ('Dictionary atoms - DictionaryLearningOnline',
+     decomposition.DictionaryLearningOnline(n_atoms=n_components, alpha=1e-3,
+                                            n_iter=100, chunk_size=3),
+     True, False),
+
     ('Cluster centers - MiniBatchKMeans',
      MiniBatchKMeans(k=n_components, tol=1e-3, chunk_size=20, max_iter=50),
      True, False)
diff --git a/examples/decomposition/plot_img_denoising.py b/examples/decomposition/plot_img_denoising.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c91db2497f735a3e621cb4d520d04da238ec6bb
--- /dev/null
+++ b/examples/decomposition/plot_img_denoising.py
@@ -0,0 +1,151 @@
+"""
+=========================================
+Image denoising using dictionary learning
+=========================================
+
+An example comparing the effect of reconstructing noisy fragments
+of Lena using online :ref:`DictionaryLearning` and various transform methods.
+
+The dictionary is fitted on the non-distorted left half of the image, and
+subsequently used to reconstruct the right half.
+
+A common practice for evaluating the results of image denoising is to look at
+the difference between the reconstruction and the original image. If the
+reconstruction is perfect this difference will look like Gaussian noise.
+
+It can be seen from the plots that the result of :ref:`omp` with two
+non-zero coefficients is a bit less biased than when keeping only one (the
+edges look less prominent). It is, however, farther from the ground truth in
+Frobenius norm.
+
+The result of :ref:`least_angle_regression` is much more strongly biased: the
+difference is reminiscent of the local intensity value of the original image.
+
+Thresholding is clearly not useful for denoising, but it is here to show that
+it can produce a suggestive output with very high speed, and thus be useful
+for other tasks such as object classification, where performance is not
+necessarily related to visualisation.
+
+"""
+print __doc__
+
+from time import time
+
+import pylab as pl
+import scipy as sp
+import numpy as np
+
+from sklearn.decomposition import DictionaryLearningOnline
+from sklearn.feature_extraction.image import extract_patches_2d, \
+                                             reconstruct_from_patches_2d
+
+###############################################################################
+# Load Lena image and extract patches
+lena = sp.lena() / 256.0
+
+# downsample for higher speed
+lena = lena[::2, ::2] + lena[1::2, ::2] + lena[::2, 1::2] + lena[1::2, 1::2]
+lena = lena[::2, ::2] + lena[1::2, ::2] + lena[::2, 1::2] + lena[1::2, 1::2]
+lena /= 16.0
+height, width = lena.shape
+
+# Distort the right half of the image
+print 'Distorting image...'
+distorted = lena.copy()
+distorted[:, height / 2:] += 0.075 * np.random.randn(width, height / 2)
+
+# Extract all clean patches from the left half of the image
+print 'Extracting clean patches...'
+patch_size = (7, 7)
+data = extract_patches_2d(distorted[:, :height / 2], patch_size)
+data = data.reshape(data.shape[0], -1)
+data -= np.mean(data, axis=0)
+data /= np.std(data, axis=0)
+
+###############################################################################
+# Learn the dictionary from clean patches
+print 'Learning the dictionary... ',
+t0 = time()
+dico = DictionaryLearningOnline(n_atoms=100, alpha=1e-2, n_iter=500)
+V = dico.fit(data).components_
+dt = time() - t0
+print 'done in %.2fs.' % dt
+
+pl.figure(figsize=(4.2, 4))
+for i, comp in enumerate(V[:100]):
+    pl.subplot(10, 10, i + 1)
+    pl.imshow(comp.reshape(patch_size), cmap=pl.cm.gray_r,
+              interpolation='nearest')
+    pl.xticks(())
+    pl.yticks(())
+pl.suptitle('Dictionary learned from Lena patches\n' +
+            'Train time %.1fs on %d patches' % (dt, len(data)),
+            fontsize=16)
+pl.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)
+
+
+def show_with_diff(image, reference, title):
+    """Helper function to display denoising"""
+    pl.figure(figsize=(5, 3.3))
+    pl.subplot(1, 2, 1)
+    pl.title('Image')
+    pl.imshow(image, vmin=0, vmax=1, cmap=pl.cm.gray, interpolation='nearest')
+    pl.xticks(())
+    pl.yticks(())
+    pl.subplot(1, 2, 2)
+    difference = image - reference
+
+    pl.title('Difference (norm: %.2f)' % np.sqrt(np.sum(difference ** 2)))
+    pl.imshow(difference, vmin=-0.5, vmax=0.5, cmap=pl.cm.PuOr,
+              interpolation='nearest')
+    pl.xticks(())
+    pl.yticks(())
+    pl.suptitle(title, size=16)
+    pl.subplots_adjust(0.02, 0.02, 0.98, 0.79, 0.02, 0.2)
+
+###############################################################################
+# Display the distorted image
+show_with_diff(distorted, lena, 'Distorted image')
+
+###############################################################################
+# Extract noisy patches and reconstruct them using the dictionary
+print 'Extracting noisy patches... '
+data = extract_patches_2d(distorted[:, height / 2:], patch_size)
+data = data.reshape(data.shape[0], -1)
+intercept = np.mean(data, axis=0)
+data -= intercept
+
+transform_algorithms = [
+    ('Orthogonal Matching Pursuit\n1 atom', 'omp',
+     {'transform_n_nonzero_coefs': 1}),
+    ('Orthogonal Matching Pursuit\n2 atoms', 'omp',
+     {'transform_n_nonzero_coefs': 2}),
+    ('Least-angle regression\n5 atoms', 'lars',
+     {'transform_n_nonzero_coefs': 5}),
+    ('Thresholding\n alpha=0.1', 'threshold', {'transform_alpha': .1})]
+
+reconstructions = {}
+for title, transform_algorithm, kwargs in transform_algorithms:
+    print title, '... ',
+    reconstructions[title] = lena.copy()
+    t0 = time()
+    dico.set_params(transform_algorithm=transform_algorithm, **kwargs)
+    code = dico.transform(data)
+    patches = np.dot(code, V)
+
+    patches += intercept
+    patches = patches.reshape(len(data), *patch_size)
+    if transform_algorithm == 'threshold':
+        patches -= patches.min()
+        patches /= patches.max()
+    reconstructions[title][:, height / 2:] = reconstruct_from_patches_2d(
+        patches, (width, height / 2))
+    dt = time() - t0
+    print 'done in %.2fs.' % dt
+    show_with_diff(reconstructions[title], lena,
+                   title + ' (time: %.1fs)' % dt)
+
+pl.show()
diff --git a/scikits/learn/decomposition/__init__.py b/scikits/learn/decomposition/__init__.py
index 3943c2ed8da19edc7f74144d29dfe640f3b86efb..dae5d9b576d79489365bb33d6b580a1ff9a33e71 100644
--- a/scikits/learn/decomposition/__init__.py
+++ b/scikits/learn/decomposition/__init__.py
@@ -1,3 +1,3 @@
 import warnings
 warnings.warn('scikits.learn namespace is deprecated, please use sklearn instead')
-from sklearn.decomposition import *
\ No newline at end of file
+from sklearn.decomposition import *
diff --git a/sklearn/decomposition/__init__.py b/sklearn/decomposition/__init__.py
index ead01092dab7c78da2abef7501ea9edaa7cc3d1d..fbc9e2f33ad1d13bacfe9561452b41417b2703e8 100644
--- a/sklearn/decomposition/__init__.py
+++ b/sklearn/decomposition/__init__.py
@@ -5,6 +5,7 @@ Matrix decomposition algorithms
 from .nmf import NMF, ProjectedGradientNMF
 from .pca import PCA, RandomizedPCA, ProbabilisticPCA
 from .kernel_pca import KernelPCA
-from .sparse_pca import SparsePCA, MiniBatchSparsePCA, dict_learning, \
-                        dict_learning_online
+from .sparse_pca import SparsePCA, MiniBatchSparsePCA
 from .fastica_ import FastICA, fastica
+from .dict_learning import dict_learning, dict_learning_online, \
+                           DictionaryLearning, DictionaryLearningOnline
diff --git a/sklearn/decomposition/dict_learning.py b/sklearn/decomposition/dict_learning.py
new file mode 100644
index 0000000000000000000000000000000000000000..03365cde25078b4e587c0accabd0f456d0921a0e
--- /dev/null
+++ b/sklearn/decomposition/dict_learning.py
@@ -0,0 +1,993 @@
+""" Dictionary learning
+"""
+# Author: Vlad Niculae, Gael Varoquaux, Alexandre Gramfort
+# License: BSD
+
+import time
+import sys
+import itertools
+
+from math import sqrt, floor, ceil
+
+import numpy as np
+from scipy import linalg
+from numpy.lib.stride_tricks import as_strided
+
+from ..base import BaseEstimator, TransformerMixin
+from ..externals.joblib import Parallel, delayed, cpu_count
+from ..utils import check_random_state
+from ..utils import gen_even_slices
+from ..utils.extmath import fast_svd
+from ..linear_model import Lasso, orthogonal_mp_gram, lars_path
+
+
+def sparse_encode(X, Y, gram=None, cov=None, algorithm='lasso_lars',
+                  n_nonzero_coefs=None, alpha=None,
+                  overwrite_gram=False, overwrite_cov=False, init=None):
+    """Generic sparse coding
+
+    Each column of the result is the solution to a Lasso problem.
+
+    Parameters
+    ----------
+    X: array of shape (n_samples, n_components)
+        Dictionary against which to optimize the sparse code.
+
+    Y: array of shape (n_samples, n_features)
+        Data matrix.
+
+    gram: array, shape=(n_components, n_components)
+        Precomputed Gram matrix, X^T * X
+
+    cov: array, shape=(n_components, n_features)
+        Precomputed covariance, X^T * Y
+
+    algorithm: {'lasso_lars', 'lasso_cd', 'lars', 'omp', 'threshold'}
+        lars: uses the least angle regression method (linear_model.lars_path)
+        lasso_lars: uses Lars to compute the Lasso solution
+        lasso_cd: uses the coordinate descent method to compute the
+        Lasso solution (linear_model.Lasso). lasso_lars will be faster if
+        the estimated components are sparse.
+        omp: uses orthogonal matching pursuit to estimate the sparse solution
+        threshold: squashes to zero all coefficients less than alpha from
+        the projection X.T * Y
+
+    n_nonzero_coefs: int, 0.1 * n_features by default
+        Number of nonzero coefficients to target in each column of the
+        solution. This is only used by `algorithm='lars'` and `algorithm='omp'`
+        and is overridden by `alpha` in the `omp` case.
+
+    alpha: float, 1. by default
+        If `algorithm='lasso_lars'` or `algorithm='lasso_cd'`, `alpha` is the
+        penalty applied to the L1 norm.
+        If `algorithm='threshold'`, `alpha` is the absolute value of the
+        threshold below which coefficients will be squashed to zero.
+        If `algorithm='omp'`, `alpha` is the tolerance parameter: the value of
+        the reconstruction error targeted. In this case, it overrides
+        `n_nonzero_coefs`.
+
+    init: array of shape (n_components, n_features)
+        Initialization value of the sparse codes. Only used if
+        `algorithm='lasso_cd'`.
+
+    overwrite_gram: boolean,
+        Whether to overwrite the precomputed Gram matrix.
+
+    overwrite_cov: boolean,
+        Whether to overwrite the precomputed covariance matrix.
+
+    Returns
+    -------
+    code: array of shape (n_components, n_features)
+        The sparse codes
+    """
+    alpha = float(alpha) if alpha is not None else None
+    X, Y = map(np.asanyarray, (X, Y))
+    if Y.ndim == 1:
+        Y = Y[:, np.newaxis]
+    n_features = Y.shape[1]
+    # This will always use Gram
+    if gram is None:
+        # I think it's never safe to overwrite Gram when n_features > 1
+        # but I'd like to avoid the complicated logic.
+        # The parameter could be removed in this case. Discuss.
+        gram = np.dot(X.T, X)
+    if cov is None and algorithm != 'lasso_cd':
+        # overwrite_cov is safe
+        overwrite_cov = True
+        cov = np.dot(X.T, Y)
+
+    if algorithm == 'lasso_lars':
+        if alpha is None:
+            alpha = 1.
+        try:
+            new_code = np.empty((X.shape[1], n_features))
+            err_mgt = np.seterr(all='ignore')
+            for k in range(n_features):
+                # A huge amount of time is spent in this loop. It needs to be
+                # tight.
+                _, _, coef_path_ = lars_path(X, Y[:, k], Xy=cov[:, k],
+                                             Gram=gram, alpha_min=alpha,
+                                             method='lasso')
+                new_code[:, k] = coef_path_[:, -1]
+        finally:
+            np.seterr(**err_mgt)
+
+    elif algorithm == 'lasso_cd':
+        if alpha is None:
+            alpha = 1.
+        new_code = np.empty((X.shape[1], n_features))
+        clf = Lasso(alpha=alpha, fit_intercept=False, precompute=gram,
+                    max_iter=1000)
+        for k in xrange(n_features):
+            # A huge amount of time is spent in this loop. It needs to be
+            # tight.
+            if init is not None:
+                clf.coef_ = init[:, k]  # Init with previous value of Vk
+            clf.fit(X, Y[:, k])
+            new_code[:, k] = clf.coef_
+
+    elif algorithm == 'lars':
+        if n_nonzero_coefs is None:
+            n_nonzero_coefs = n_features / 10
+        try:
+            new_code = np.empty((X.shape[1], n_features))
+            err_mgt = np.seterr(all='ignore')
+            for k in xrange(n_features):
+                # A huge amount of time is spent in this loop. It needs to be
+                # tight.
+                _, _, coef_path_ = lars_path(X, Y[:, k], Xy=cov[:, k],
+                                             Gram=gram, method='lar',
+                                             max_iter=n_nonzero_coefs)
+                new_code[:, k] = coef_path_[:, -1]
+        finally:
+            np.seterr(**err_mgt)
+
+    elif algorithm == 'threshold':
+        if alpha is None:
+            alpha = 1.
+        new_code = np.sign(cov) * np.maximum(np.abs(cov) - alpha, 0)
+
+    elif algorithm == 'omp':
+        if n_nonzero_coefs is None and alpha is None:
+            n_nonzero_coefs = n_features / 10
+        norms_squared = np.sum((Y ** 2), axis=0)
+        new_code = orthogonal_mp_gram(gram, cov, n_nonzero_coefs, alpha,
+                                      norms_squared,
+                                      overwrite_Xy=overwrite_cov)
+    else:
+        raise ValueError('Sparse coding method %s not implemented' %
+                         algorithm)
+    return new_code
+
+
+def sparse_encode_parallel(X, Y, gram=None, cov=None, algorithm='lasso_lars',
+                           n_nonzero_coefs=None, alpha=None,
+                           overwrite_gram=False, overwrite_cov=False,
+                           init=None, n_jobs=1):
+    """Parallel sparse coding using joblib
+
+    Each column of the result is the solution to a Lasso problem.
+
+    Parameters
+    ----------
+    X: array of shape (n_samples, n_components)
+        Dictionary against which to optimize the sparse code.
+
+    Y: array of shape (n_samples, n_features)
+        Data matrix.
+
+    gram: array, shape=(n_components, n_components)
+        Precomputed Gram matrix, X^T * X
+
+    cov: array, shape=(n_components, n_features)
+        Precomputed covariance, X^T * Y
+
+    algorithm: {'lasso_lars', 'lasso_cd', 'lars', 'omp', 'threshold'}
+        lars: uses the least angle regression method (linear_model.lars_path)
+        lasso_lars: uses Lars to compute the Lasso solution
+        lasso_cd: uses the coordinate descent method to compute the
+        Lasso solution (linear_model.Lasso). lasso_lars will be faster if
+        the estimated components are sparse.
+        omp: uses orthogonal matching pursuit to estimate the sparse solution
+        threshold: squashes to zero all coefficients less than alpha from
+        the projection X.T * Y
+
+    n_nonzero_coefs: int, 0.1 * n_features by default
+        Number of nonzero coefficients to target in each column of the
+        solution. This is only used by `algorithm='lars'` and `algorithm='omp'`
+        and is overridden by `alpha` in the `omp` case.
+
+    alpha: float, 1. by default
+        If `algorithm='lasso_lars'` or `algorithm='lasso_cd'`, `alpha` is the
+        penalty applied to the L1 norm.
+        If `algorithm='threshold'`, `alpha` is the absolute value of the
+        threshold below which coefficients will be squashed to zero.
+        If `algorithm='omp'`, `alpha` is the tolerance parameter: the value of
+        the reconstruction error targeted. In this case, it overrides
+        `n_nonzero_coefs`.
+
+    init: array of shape (n_components, n_features)
+        Initialization value of the sparse codes. Only used if
+        `algorithm='lasso_cd'`.
+
+    overwrite_gram: boolean,
+        Whether to overwrite the precomputed Gram matrix.
+
+    overwrite_cov: boolean,
+        Whether to overwrite the precomputed covariance matrix.
+
+    n_jobs: int,
+        Number of parallel jobs to run.
+
+    Returns
+    -------
+    code: array of shape (n_components, n_features)
+        The sparse codes
+    """
+    n_samples, n_features = Y.shape
+    n_components = X.shape[1]
+    if gram is None:
+        overwrite_gram = True
+        gram = np.dot(X.T, X)
+    if cov is None and algorithm != 'lasso_cd':
+        overwrite_cov = True
+        cov = np.dot(X.T, Y)
+    if n_jobs == 1 or algorithm == 'threshold':
+        return sparse_encode(X, Y, gram, cov, algorithm, n_nonzero_coefs,
+                             alpha, overwrite_gram, overwrite_cov, init)
+    code = np.empty((n_components, n_features))
+    slices = list(gen_even_slices(n_features, n_jobs))
+    code_views = Parallel(n_jobs=n_jobs)(
+                delayed(sparse_encode)(X, Y[:, this_slice], gram,
+                                       cov[:, this_slice], algorithm,
+                                       n_nonzero_coefs, alpha,
+                                       overwrite_gram, overwrite_cov,
+                                       init=init[:, this_slice] if init is not
+                                       None else None)
+                for this_slice in slices)
+    for this_slice, this_view in zip(slices, code_views):
+        code[:, this_slice] = this_view
+    return code
+
+
+def _update_dict(dictionary, Y, code, verbose=False, return_r2=False,
+                 random_state=None):
+    """Update the dense dictionary factor in place.
+
+    Parameters
+    ----------
+    dictionary: array of shape (n_samples, n_components)
+        Value of the dictionary at the previous iteration.
+
+    Y: array of shape (n_samples, n_features)
+        Data matrix.
+
+    code: array of shape (n_components, n_features)
+        Sparse coding of the data against which to optimize the dictionary.
+
+    verbose:
+        Degree of output the procedure will print.
+
+    return_r2: bool
+        Whether to compute and return the residual sum of squares corresponding
+        to the computed solution.
+
+    random_state: int or RandomState
+        Pseudo number generator state used for random sampling.
+
+    Returns
+    -------
+    dictionary: array of shape (n_samples, n_components)
+        Updated dictionary.
+
+    """
+    n_atoms = len(code)
+    n_samples = Y.shape[0]
+    random_state = check_random_state(random_state)
+    # Residuals, computed 'in-place' for efficiency
+    R = -np.dot(dictionary, code)
+    R += Y
+    R = np.asfortranarray(R)
+    ger, = linalg.get_blas_funcs(('ger',), (dictionary, code))
+    for k in xrange(n_atoms):
+        # R <- 1.0 * U_k * V_k^T + R
+        R = ger(1.0, dictionary[:, k], code[k, :], a=R, overwrite_a=True)
+        dictionary[:, k] = np.dot(R, code[k, :].T)
+        # Scale k'th atom
+        atom_norm_square = np.dot(dictionary[:, k], dictionary[:, k])
+        if atom_norm_square < 1e-20:
+            if verbose == 1:
+                sys.stdout.write("+")
+                sys.stdout.flush()
+            elif verbose:
+                print "Adding new random atom"
+            dictionary[:, k] = random_state.randn(n_samples)
+            # Setting corresponding coefs to 0
+            code[k, :] = 0.0
+            dictionary[:, k] /= sqrt(np.dot(dictionary[:, k],
+                                            dictionary[:, k]))
+        else:
+            dictionary[:, k] /= sqrt(atom_norm_square)
+            # R <- -1.0 * U_k * V_k^T + R
+            R = ger(-1.0, dictionary[:, k], code[k, :], a=R, overwrite_a=True)
+    if return_r2:
+        R **= 2
+        # R is fortran-ordered. For numpy version < 1.6, sum does not
+        # follow the quick striding first, and is thus inefficient on
+        # fortran ordered data. We take a flat view of the data with no
+        # striding
+        R = as_strided(R, shape=(R.size, ), strides=(R.dtype.itemsize,))
+        R = np.sum(R)
+        return dictionary, R
+    return dictionary
+
+
+def dict_learning(X, n_atoms, alpha, max_iter=100, tol=1e-8,
+                  method='lasso_lars', n_jobs=1, dict_init=None,
+                  code_init=None, callback=None, verbose=False,
+                  random_state=None):
+    """Solves a dictionary learning matrix factorization problem.
+
+    Finds the best dictionary and the corresponding sparse code for
+    approximating the data matrix X by solving::
+
+    (U^*, V^*) = argmin 0.5 || X - U V ||_2^2 + alpha * || U ||_1
+                 (U,V)
+                with || V_k ||_2 = 1 for all  0 <= k < n_atoms
+
+    where V is the dictionary and U is the sparse code.
+
+    Parameters
+    ----------
+    X: array of shape (n_samples, n_features)
+        Data matrix.
+
+    n_atoms: int,
+        Number of dictionary atoms to extract.
+
+    alpha: int,
+        Sparsity controlling parameter.
+
+    max_iter: int,
+        Maximum number of iterations to perform.
+
+    tol: float,
+        Tolerance for the stopping condition.
+
+    method: {'lasso_lars', 'lasso_cd'}
+        lasso_lars: uses the least angle regression method
+        (linear_model.lars_path)
+        lasso_cd: uses the coordinate descent method to compute the
+        Lasso solution (linear_model.Lasso). Lars will be faster if
+        the estimated components are sparse.
+
+    n_jobs: int,
+        Number of parallel jobs to run, or -1 to autodetect.
+
+    dict_init: array of shape (n_atoms, n_features),
+        Initial value for the dictionary for warm restart scenarios.
+
+    code_init: array of shape (n_samples, n_atoms),
+        Initial value for the sparse code for warm restart scenarios.
+
+    callback:
+        Callable that gets invoked every five iterations.
+
+    verbose:
+        Degree of output the procedure will print.
+
+    random_state: int or RandomState
+        Pseudo number generator state used for random sampling.
+
+    Returns
+    -------
+    code: array of shape (n_samples, n_atoms)
+        The sparse code factor in the matrix factorization.
+
+    dictionary: array of shape (n_atoms, n_features),
+        The dictionary factor in the matrix factorization.
+
+    errors: array
+        Vector of errors at each iteration.
+
+    """
+    if method not in ('lasso_lars', 'lasso_cd'):
+        raise ValueError('Coding method not supported as a fit algorithm.')
+    t0 = time.time()
+    n_features = X.shape[1]
+    # Avoid integer division problems
+    alpha = float(alpha)
+    random_state = check_random_state(random_state)
+
+    if n_jobs == -1:
+        n_jobs = cpu_count()
+
+    # Init U and V with SVD of Y
+    if code_init is not None and dict_init is not None:
+        code = np.array(code_init, order='F')
+        # Don't copy V, it will happen below
+        dictionary = dict_init
+    else:
+        code, S, dictionary = linalg.svd(X, full_matrices=False)
+        dictionary = S[:, np.newaxis] * dictionary
+    r = len(dictionary)
+    if n_atoms <= r:  # True even if n_atoms=None
+        code = code[:, :n_atoms]
+        dictionary = dictionary[:n_atoms, :]
+    else:
+        code = np.c_[code, np.zeros((len(code), n_atoms - r))]
+        dictionary = np.r_[dictionary,
+                           np.zeros((n_atoms - r, dictionary.shape[1]))]
+
+    # Fortran-order dict, as we are going to access its row vectors
+    dictionary = np.array(dictionary, order='F')
+
+    residuals = 0
+
+    errors = []
+    current_cost = np.nan
+
+    if verbose == 1:
+        print '[dict_learning]',
+
+    for ii in xrange(max_iter):
+        dt = (time.time() - t0)
+        if verbose == 1:
+            sys.stdout.write(".")
+            sys.stdout.flush()
+        elif verbose:
+            print ("Iteration % 3i "
+                "(elapsed time: % 3is, % 4.1fmn, current cost % 7.3f)" %
+                    (ii, dt, dt / 60, current_cost))
+
+        # Update code
+        code = sparse_encode_parallel(dictionary.T, X.T, algorithm=method,
+                                      alpha=alpha / n_features,
+                                      init=code.T, n_jobs=n_jobs)
+        code = code.T
+        # Update dictionary
+        dictionary, residuals = _update_dict(dictionary.T, X.T, code.T,
+                                             verbose=verbose, return_r2=True,
+                                             random_state=random_state)
+        dictionary = dictionary.T
+
+        # Cost function
+        current_cost = 0.5 * residuals + alpha * np.sum(np.abs(code))
+        errors.append(current_cost)
+
+        if ii > 0:
+            dE = errors[-2] - errors[-1]
+            # assert(dE >= -tol * errors[-1])
+            if dE < tol * errors[-1]:
+                if verbose == 1:
+                    # A line return
+                    print ""
+                elif verbose:
+                    print "--- Convergence reached after %d iterations" % ii
+                break
+        if ii % 5 == 0 and callback is not None:
+            callback(locals())
+
+    return code, dictionary, errors
+
+
+def dict_learning_online(X, n_atoms, alpha, n_iter=100, return_code=True,
+                         dict_init=None, callback=None, chunk_size=3,
+                         verbose=False, shuffle=True, n_jobs=1,
+                         method='lasso_lars', iter_offset=0,
+                         random_state=None):
+    """Solves a dictionary learning matrix factorization problem online.
+
+    Finds the best dictionary and the corresponding sparse code for
+    approximating the data matrix X by solving:
+
+    (U^*, V^*) = argmin 0.5 || X - U V ||_2^2 + alpha * || U ||_1
+                 (U,V)
+                 with || V_k ||_2 = 1 for all  0 <= k < n_atoms
+
+    where V is the dictionary and U is the sparse code. This is
+    accomplished by repeatedly iterating over mini-batches by slicing
+    the input data.
+
+    Parameters
+    ----------
+    X: array of shape (n_samples, n_features)
+        data matrix
+
+    n_atoms: int,
+        number of dictionary atoms to extract
+
+    alpha: int,
+        sparsity controlling parameter
+
+    n_iter: int,
+        number of iterations to perform
+
+    return_code: boolean,
+        whether to also return the code U or just the dictionary V
+
+    dict_init: array of shape (n_atoms, n_features),
+        initial value for the dictionary for warm restart scenarios
+
+    callback:
+        callable that gets invoked every five iterations
+
+    chunk_size: int,
+        the number of samples to take in each batch
+
+    verbose:
+        degree of output the procedure will print
+
+    shuffle: boolean,
+        whether to shuffle the data before splitting it in batches
+
+    n_jobs: int,
+        number of parallel jobs to run, or -1 to autodetect.
+
+    method: {'lasso_lars', 'lasso_cd'}
+        lasso_lars: uses the least angle regression method
+        (linear_model.lars_path)
+        lasso_cd: uses the coordinate descent method to compute the
+        Lasso solution (linear_model.Lasso). Lars will be faster if
+        the estimated components are sparse.
+
+    iter_offset: int, default 0
+        number of previous iterations completed on the dictionary used for
+        initialization
+
+    random_state: int or RandomState
+        Pseudo number generator state used for random sampling.
+
+    Returns
+    -------
+    dictionary: array of shape (n_atoms, n_features),
+        the solutions to the dictionary learning problem
+
+    code: array of shape (n_samples, n_atoms),
+        the sparse code (only returned if `return_code=True`)
+    """
+    if method not in ('lasso_lars', 'lasso_cd'):
+        raise ValueError('Coding method not supported as a fit algorithm.')
+    t0 = time.time()
+    n_samples, n_features = X.shape
+    # Avoid integer division problems
+    alpha = float(alpha)
+    random_state = check_random_state(random_state)
+
+    if n_jobs == -1:
+        n_jobs = cpu_count()
+
+    # Init V with SVD of X
+    if dict_init is not None:
+        dictionary = dict_init
+    else:
+        _, S, dictionary = fast_svd(X, n_atoms)
+        dictionary = S[:, np.newaxis] * dictionary
+    r = len(dictionary)
+    if n_atoms <= r:
+        dictionary = dictionary[:n_atoms, :]
+    else:
+        dictionary = np.r_[dictionary,
+                           np.zeros((n_atoms - r, dictionary.shape[1]))]
+    dictionary = np.ascontiguousarray(dictionary.T)
+
+    if verbose == 1:
+        print '[dict_learning]',
+
+    n_batches = floor(float(len(X)) / chunk_size)
+    if shuffle:
+        X_train = X.copy()
+        random_state.shuffle(X_train)
+    else:
+        X_train = X
+    batches = np.array_split(X_train, n_batches)
+    batches = itertools.cycle(batches)
+
+    # A: code covariance, accumulated over mini-batches (code * code^T)
+    A = np.zeros((n_atoms, n_atoms))
+    # B: data-code product, accumulated over mini-batches (X^T * code^T)
+    B = np.zeros((n_features, n_atoms))
+
+    for ii, this_X in itertools.izip(xrange(iter_offset, iter_offset + n_iter),
+                                     batches):
+        dt = (time.time() - t0)
+        if verbose == 1:
+            sys.stdout.write(".")
+            sys.stdout.flush()
+        elif verbose:
+            if verbose > 10 or ii % ceil(100. / verbose) == 0:
+                print ("Iteration % 3i (elapsed time: % 3is, % 4.1fmn)" %
+                    (ii, dt, dt / 60))
+
+        this_code = sparse_encode(dictionary, this_X.T, algorithm=method,
+                                  alpha=alpha)
+
+        # Update the auxiliary variables
+        if ii < chunk_size - 1:
+            theta = float((ii + 1) * chunk_size)
+        else:
+            theta = float(chunk_size ** 2 + ii + 1 - chunk_size)
+        beta = (theta + 1 - chunk_size) / (theta + 1)
+
+        A *= beta
+        A += np.dot(this_code, this_code.T)
+        B *= beta
+        B += np.dot(this_X.T, this_code.T)
+
+        # Update dictionary
+        dictionary = _update_dict(dictionary, B, A, verbose=verbose,
+                                  random_state=random_state)
+        # XXX: Can the residuals be of any use?
+
+        # Maybe we need a stopping criteria based on the amount of
+        # modification in the dictionary
+        if callback is not None:
+            callback(locals())
+
+    if return_code:
+        if verbose > 1:
+            print 'Learning code...',
+        elif verbose == 1:
+            print '|',
+        code = sparse_encode_parallel(dictionary, X.T, algorithm=method,
+                                      alpha=alpha, n_jobs=n_jobs)
+        if verbose > 1:
+            dt = (time.time() - t0)
+            print 'done (total time: % 3is, % 4.1fmn)' % (dt, dt / 60)
+        return code.T, dictionary.T
+
+    return dictionary.T
+
+
+class BaseDictionaryLearning(BaseEstimator, TransformerMixin):
+    """Dictionary learning base class"""
+
+    def __init__(self, n_atoms, transform_algorithm='omp',
+                 transform_n_nonzero_coefs=None, transform_alpha=None,
+                 split_sign=False, n_jobs=1):
+        self.n_atoms = n_atoms
+        self.transform_algorithm = transform_algorithm
+        self.transform_n_nonzero_coefs = transform_n_nonzero_coefs
+        self.transform_alpha = transform_alpha
+        self.split_sign = split_sign
+        self.n_jobs = n_jobs
+
+    def transform(self, X, y=None):
+        """Encode the data as a sparse combination of the dictionary atoms.
+
+        Coding method is determined by the object parameter
+        `transform_algorithm`.
+
+        Parameters
+        ----------
+        X: array of shape (n_samples, n_features)
+            Test data to be transformed, must have the same number of
+            features as the data used to train the model.
+
+        Returns
+        -------
+        X_new: array, shape (n_samples, n_atoms) or (n_samples, 2 * n_atoms)
+            Transformed data
+        """
+        # XXX : kwargs is not documented
+        X = np.atleast_2d(X)
+        n_samples, n_features = X.shape
+
+        code = sparse_encode_parallel(
+            self.components_.T, X.T, algorithm=self.transform_algorithm,
+            n_nonzero_coefs=self.transform_n_nonzero_coefs,
+            alpha=self.transform_alpha, n_jobs=self.n_jobs)
+        code = code.T
+
+        if self.split_sign:
+            # feature vector is split into a positive and negative side
+            n_samples, n_features = code.shape
+            split_code = np.empty((n_samples, 2 * n_features))
+            split_code[:, :n_features] = np.maximum(code, 0)
+            split_code[:, n_features:] = -np.minimum(code, 0)
+            code = split_code
+
+        return code
+
+
+class DictionaryLearning(BaseDictionaryLearning):
+    """ Dictionary learning
+
+    Finds a dictionary (a set of atoms) that can best be used to represent data
+    using a sparse code.
+
+    Solves the optimization problem:
+    (U^*,V^*) = argmin 0.5 || Y - U V ||_2^2 + alpha * || U ||_1
+                 (U,V)
+                with || V_k ||_2 = 1 for all  0 <= k < n_atoms
+
+    Parameters
+    ----------
+    n_atoms: int,
+        number of dictionary elements to extract
+
+    alpha: int,
+        sparsity controlling parameter
+
+    max_iter: int,
+        maximum number of iterations to perform
+
+    tol: float,
+        tolerance for numerical error
+
+    fit_algorithm: {'lasso_lars', 'lasso_cd'}
+        lasso_lars: uses the least angle regression method
+        (linear_model.lars_path)
+        lasso_cd: uses the coordinate descent method to compute the
+        Lasso solution (linear_model.Lasso). Lars will be faster if
+        the estimated components are sparse.
+
+    transform_algorithm: {'lasso_lars', 'lasso_cd', 'lars', 'omp', 'threshold'}
+        Algorithm used to transform the data
+        lars: uses the least angle regression method (linear_model.lars_path)
+        lasso_lars: uses Lars to compute the Lasso solution
+        lasso_cd: uses the coordinate descent method to compute the
+        Lasso solution (linear_model.Lasso). lasso_lars will be faster if
+        the estimated components are sparse.
+        omp: uses orthogonal matching pursuit to estimate the sparse solution
+        threshold: squashes to zero all coefficients less than alpha from
+        the projection X.T * Y
+
+    transform_n_nonzero_coefs: int, 0.1 * n_features by default
+        Number of nonzero coefficients to target in each column of the
+        solution. This is only used by `algorithm='lars'` and `algorithm='omp'`
+        and is overridden by `alpha` in the `omp` case.
+
+    transform_alpha: float, 1. by default
+        If `algorithm='lasso_lars'` or `algorithm='lasso_cd'`, `alpha` is the
+        penalty applied to the L1 norm.
+        If `algorithm='threshold'`, `alpha` is the absolute value of the
+        threshold below which coefficients will be squashed to zero.
+        If `algorithm='omp'`, `alpha` is the tolerance parameter: the value of
+        the reconstruction error targeted. In this case, it overrides
+        `n_nonzero_coefs`.
+
+    n_jobs: int,
+        number of parallel jobs to run
+
+    code_init: array of shape (n_samples, n_atoms),
+        initial value for the code, for warm restart
+
+    dict_init: array of shape (n_atoms, n_features),
+        initial values for the dictionary, for warm restart
+
+    verbose:
+        degree of verbosity of the printed output
+
+    random_state: int or RandomState
+        Pseudo number generator state used for random sampling.
+
+    Attributes
+    ----------
+    components_: array, [n_atoms, n_features]
+        dictionary atoms extracted from the data
+
+    error_: array
+        vector of errors at each iteration
+
+    References
+    ----------
+    J. Mairal, F. Bach, J. Ponce, G. Sapiro, 2009: Online dictionary learning
+    for sparse coding (http://www.di.ens.fr/sierra/pdfs/icml09.pdf)
+
+
+    See also
+    --------
+    :class:`sklearn.decomposition.SparsePCA` which solves the transposed
+    problem, finding sparse components to represent data.
+
+    """
+    def __init__(self, n_atoms, alpha=1, max_iter=1000, tol=1e-8,
+                 fit_algorithm='lasso_lars', transform_algorithm='omp',
+                 transform_n_nonzero_coefs=None, transform_alpha=None,
+                 n_jobs=1, code_init=None, dict_init=None, verbose=False,
+                 split_sign=False, random_state=None):
+        BaseDictionaryLearning.__init__(self, n_atoms, transform_algorithm,
+                 transform_n_nonzero_coefs, transform_alpha, split_sign,
+                 n_jobs)
+        self.alpha = alpha
+        self.max_iter = max_iter
+        self.tol = tol
+        self.fit_algorithm = fit_algorithm
+        self.code_init = code_init
+        self.dict_init = dict_init
+        self.verbose = verbose
+        self.random_state = random_state
+
+    def fit(self, X, y=None):
+        """Fit the model from data in X.
+
+        Parameters
+        ----------
+        X: array-like, shape (n_samples, n_features)
+            Training vector, where n_samples is the number of samples
+            and n_features is the number of features.
+
+        Returns
+        -------
+        self: object
+            Returns the object itself
+        """
+        self.random_state = check_random_state(self.random_state)
+        X = np.asanyarray(X)
+        V, U, E = dict_learning(X, self.n_atoms, self.alpha,
+                                tol=self.tol, max_iter=self.max_iter,
+                                method=self.fit_algorithm,
+                                n_jobs=self.n_jobs,
+                                code_init=self.code_init,
+                                dict_init=self.dict_init,
+                                verbose=self.verbose,
+                                random_state=self.random_state)
+        self.components_ = U
+        self.error_ = E
+        return self
+
+
+class DictionaryLearningOnline(BaseDictionaryLearning):
+    """ Online dictionary learning
+
+    Finds a dictionary (a set of atoms) that can best be used to represent data
+    using a sparse code.
+
+    Solves the optimization problem:
+    (U^*,V^*) = argmin 0.5 || Y - U V ||_2^2 + alpha * || U ||_1
+                 (U,V)
+                with || V_k ||_2 = 1 for all  0 <= k < n_atoms
+
+    Parameters
+    ----------
+    n_atoms: int,
+        number of dictionary elements to extract
+
+    alpha: int,
+        sparsity controlling parameter
+
+    n_iter: int,
+        total number of iterations to perform
+
+    fit_algorithm: {'lasso_lars', 'lasso_cd'}
+        lasso_lars: uses the least angle regression method
+        (linear_model.lars_path)
+        lasso_cd: uses the coordinate descent method to compute the
+        Lasso solution (linear_model.Lasso). Lars will be faster if
+        the estimated components are sparse.
+
+    transform_algorithm: {'lasso_lars', 'lasso_cd', 'lars', 'omp', 'threshold'}
+        Algorithm used to transform the data.
+        lars: uses the least angle regression method (linear_model.lars_path)
+        lasso_lars: uses Lars to compute the Lasso solution
+        lasso_cd: uses the coordinate descent method to compute the
+        Lasso solution (linear_model.Lasso). lasso_lars will be faster if
+        the estimated components are sparse.
+        omp: uses orthogonal matching pursuit to estimate the sparse solution
+        threshold: squashes to zero all coefficients less than alpha from
+        the projection X.T * Y
+
+    transform_n_nonzero_coefs: int, 0.1 * n_features by default
+        Number of nonzero coefficients to target in each column of the
+        solution. This is only used by `algorithm='lars'` and `algorithm='omp'`
+        and is overridden by `alpha` in the `omp` case.
+
+    transform_alpha: float, 1. by default
+        If `algorithm='lasso_lars'` or `algorithm='lasso_cd'`, `alpha` is the
+        penalty applied to the L1 norm.
+        If `algorithm='threshold'`, `alpha` is the absolute value of the
+        threshold below which coefficients will be squashed to zero.
+        If `algorithm='omp'`, `alpha` is the tolerance parameter: the value of
+        the reconstruction error targeted. In this case, it overrides
+        `n_nonzero_coefs`.
+
+    n_jobs: int,
+        number of parallel jobs to run
+
+    dict_init: array of shape (n_atoms, n_features),
+        initial value of the dictionary for warm restart scenarios
+
+    verbose:
+        degree of verbosity of the printed output
+
+    chunk_size: int,
+        number of samples in each mini-batch
+
+    shuffle: bool,
+        whether to shuffle the samples before forming batches
+
+    random_state: int or RandomState
+        Pseudo number generator state used for random sampling.
+
+    Attributes
+    ----------
+    components_: array, [n_atoms, n_features]
+        components extracted from the data
+
+    References
+    ----------
+    J. Mairal, F. Bach, J. Ponce, G. Sapiro, 2009: Online dictionary learning
+    for sparse coding (http://www.di.ens.fr/sierra/pdfs/icml09.pdf)
+
+
+    See also
+    --------
+    :class:`sklearn.decomposition.SparsePCA` which solves the transposed
+    problem, finding sparse components to represent data.
+
+    """
+    def __init__(self, n_atoms, alpha=1, n_iter=1000,
+                 fit_algorithm='lasso_lars', n_jobs=1, chunk_size=3,
+                 shuffle=True, dict_init=None, transform_algorithm='omp',
+                 transform_n_nonzero_coefs=None, transform_alpha=None,
+                 verbose=False, split_sign=False, random_state=None):
+        BaseDictionaryLearning.__init__(self, n_atoms, transform_algorithm,
+                 transform_n_nonzero_coefs, transform_alpha, split_sign,
+                 n_jobs)
+        self.alpha = alpha
+        self.n_iter = n_iter
+        self.fit_algorithm = fit_algorithm
+        self.dict_init = dict_init
+        self.verbose = verbose
+        self.shuffle = shuffle
+        self.chunk_size = chunk_size
+        self.split_sign = split_sign
+        self.random_state = random_state
+
+    def fit(self, X, y=None):
+        """Fit the model from data in X.
+
+        Parameters
+        ----------
+        X: array-like, shape (n_samples, n_features)
+            Training vector, where n_samples is the number of samples
+            and n_features is the number of features.
+
+        Returns
+        -------
+        self : object
+            Returns the instance itself.
+        """
+        self.random_state = check_random_state(self.random_state)
+        X = np.asanyarray(X)
+        U = dict_learning_online(X, self.n_atoms, self.alpha,
+                                 n_iter=self.n_iter, return_code=False,
+                                 method=self.fit_algorithm,
+                                 n_jobs=self.n_jobs,
+                                 dict_init=self.dict_init,
+                                 chunk_size=self.chunk_size,
+                                 shuffle=self.shuffle, verbose=self.verbose,
+                                 random_state=self.random_state)
+        self.components_ = U
+        return self
+
+    def partial_fit(self, X, y=None, iter_offset=0):
+        """Updates the model using the data in X as a mini-batch.
+
+        Parameters
+        ----------
+        X: array-like, shape (n_samples, n_features)
+            Training vector, where n_samples is the number of samples
+            and n_features is the number of features.
+
+        Returns
+        -------
+        self : object
+            Returns the instance itself.
+        """
+        self.random_state = check_random_state(self.random_state)
+        X = np.atleast_2d(X)
+        if hasattr(self, 'components_'):
+            dict_init = self.components_
+        else:
+            dict_init = self.dict_init
+        U = dict_learning_online(X, self.n_atoms, self.alpha,
+                                 n_iter=self.n_iter,
+                                 method=self.fit_algorithm,
+                                 n_jobs=self.n_jobs, dict_init=dict_init,
+                                 chunk_size=len(X), shuffle=False,
+                                 verbose=self.verbose, return_code=False,
+                                 iter_offset=iter_offset,
+                                 random_state=self.random_state)
+        self.components_ = U
+        return self
diff --git a/sklearn/decomposition/sparse_pca.py b/sklearn/decomposition/sparse_pca.py
index f6b02dc1d4ac78f35dbd5b4e12c6d3e308d85073..7bcf2b729b34aa853c3c1cdc636e2ea34b966a13 100644
--- a/sklearn/decomposition/sparse_pca.py
+++ b/sklearn/decomposition/sparse_pca.py
@@ -2,529 +2,12 @@
 # Author: Vlad Niculae, Gael Varoquaux, Alexandre Gramfort
 # License: BSD
 
-import time
-import sys
-
-from math import sqrt, floor, ceil
-import itertools
-
 import numpy as np
-from numpy.lib.stride_tricks import as_strided
-from scipy import linalg
 
 from ..utils import check_random_state
-from ..utils import gen_even_slices
-from ..utils.extmath import fast_svd
-from ..linear_model import Lasso, lars_path, ridge_regression
-from ..externals.joblib import Parallel, delayed, cpu_count
+from ..linear_model import ridge_regression
 from ..base import BaseEstimator, TransformerMixin
-
-
-def _update_code(dictionary, Y, alpha, code=None, Gram=None, method='lars',
-                 tol=1e-8):
-    """Update the sparse code factor in the sparse_pca loop.
-
-    Each column of the result is the solution to a Lasso problem.
-
-    Parameters
-    ----------
-    dictionary: array of shape (n_samples, n_components)
-        Dictionary against which to optimize the sparse code.
-
-    Y: array of shape (n_samples, n_features)
-        Data matrix.
-
-    alpha: float
-        Regularization parameter for the Lasso problem.
-
-    code: array of shape (n_components, n_features)
-        Value of the sparse codes at the previous iteration.
-
-    Gram: array of shape (n_features, n_features)
-        Precomputed Gram matrix, (Y^T * Y).
-
-    method: {'lars', 'cd'}
-        lars: uses the least angle regression method (linear_model.lars_path)
-        cd: uses the coordinate descent method to compute the
-        Lasso solution (linear_model.Lasso). Lars will be faster if
-        the estimated components are sparse.
-
-    tol: float
-        Numerical tolerance for coordinate descent Lasso convergence.
-        Only used if `method='cd'`
-
-    Returns
-    -------
-    new_code : array of shape (n_components, n_features)
-        The sparse codes precomputed using this iteration's dictionary
-    """
-    n_features = Y.shape[1]
-    n_atoms = dictionary.shape[1]
-    new_code = np.empty((n_atoms, n_features))
-    if Gram is None:
-        Gram = np.dot(dictionary.T, dictionary)
-    if method == 'lars':
-        XY = np.dot(dictionary.T, Y)
-        try:
-            err_mgt = np.seterr(all='ignore')
-            for k in range(n_features):
-                # A huge amount of time is spent in this loop. It needs to be
-                # tight.
-                _, _, coef_path_ = lars_path(dictionary, Y[:, k], Xy=XY[:, k],
-                                             Gram=Gram, alpha_min=alpha,
-                                             method='lasso')
-                new_code[:, k] = coef_path_[:, -1]
-        finally:
-            np.seterr(**err_mgt)
-    elif method == 'cd':
-        clf = Lasso(alpha=alpha, fit_intercept=False, precompute=Gram,
-                    max_iter=1000, tol=tol)
-        for k in range(n_features):
-            # A huge amount of time is spent in this loop. It needs to be
-            # tight.
-            if code is not None:
-                clf.coef_ = code[:, k]  # Init with previous value of Vk
-            clf.fit(dictionary, Y[:, k])
-            new_code[:, k] = clf.coef_
-    else:
-        raise NotImplemented("Lasso method %s is not implemented." % method)
-    return new_code
-
-
-def _update_code_parallel(dictionary, Y, alpha, code=None, Gram=None,
-                          method='lars', n_jobs=1, tol=1e-8):
-    """Update the sparse factor V in the sparse_pca loop in parallel.
-
-    The computation is spread over all the available cores.
-
-    Parameters
-    ----------
-    dictionary: array of shape (n_samples, n_components)
-        Dictionary against which to optimize the sparse code.
-
-    Y: array of shape (n_samples, n_features)
-        Data matrix.
-
-    alpha: float
-        Regularization parameter for the Lasso problem.
-
-    code: array of shape (n_components, n_features)
-        Previous iteration of the sparse code.
-
-    Gram: array of shape (n_features, n_features)
-        Precomputed Gram matrix, (Y^T * Y).
-
-    method: 'lars' | 'cd'
-        lars: uses the least angle regression method (linear_model.lars_path)
-        cd: uses the coordinate descent method to compute the
-        lasso solution (linear_model.Lasso). Lars will be faster if
-        the components extracted are sparse.
-
-    n_jobs: int
-        Number of parallel jobs to run.
-
-    tol: float
-        Numerical tolerance for coordinate descent Lasso convergence.
-        Only used if `method='cd`.
-
-    """
-    n_samples, n_features = Y.shape
-    n_atoms = dictionary.shape[1]
-    if Gram is None:
-        Gram = np.dot(dictionary.T, dictionary)
-    if n_jobs == 1:
-        return _update_code(dictionary, Y, alpha, code=code, Gram=Gram,
-                            method=method)
-    if code is None:
-        code = np.empty((n_atoms, n_features))
-    slices = list(gen_even_slices(n_features, n_jobs))
-    code_views = Parallel(n_jobs=n_jobs)(
-                delayed(_update_code)(dictionary, Y[:, this_slice],
-                                      code=code[:, this_slice], alpha=alpha,
-                                      Gram=Gram, method=method, tol=tol)
-                for this_slice in slices)
-    for this_slice, this_view in zip(slices, code_views):
-        code[:, this_slice] = this_view
-    return code
-
-
-def _update_dict(dictionary, Y, code, verbose=False, return_r2=False,
-                 random_state=None):
-    """Update the dense dictionary factor in place.
-
-    Parameters
-    ----------
-    dictionary: array of shape (n_samples, n_components)
-        Value of the dictionary at the previous iteration.
-
-    Y: array of shape (n_samples, n_features)
-        Data matrix.
-
-    code: array of shape (n_components, n_features)
-        Sparse coding of the data against which to optimize the dictionary.
-
-    verbose:
-        Degree of output the procedure will print.
-
-    return_r2: bool
-        Whether to compute and return the residual sum of squares corresponding
-        to the computed solution.
-
-    random_state: int or RandomState
-        Pseudo number generator state used for random sampling.
-
-    Returns
-    -------
-    dictionary: array of shape (n_samples, n_components)
-        Updated dictionary.
-
-    """
-    n_atoms = len(code)
-    n_samples = Y.shape[0]
-    random_state = check_random_state(random_state)
-    # Residuals, computed 'in-place' for efficiency
-    R = -np.dot(dictionary, code)
-    R += Y
-    R = np.asfortranarray(R)
-    ger, = linalg.get_blas_funcs(('ger',), (dictionary, code))
-    for k in xrange(n_atoms):
-        # R <- 1.0 * U_k * V_k^T + R
-        R = ger(1.0, dictionary[:, k], code[k, :], a=R, overwrite_a=True)
-        dictionary[:, k] = np.dot(R, code[k, :].T)
-        # Scale k'th atom
-        atom_norm_square = np.dot(dictionary[:, k], dictionary[:, k])
-        if atom_norm_square < 1e-20:
-            if verbose == 1:
-                sys.stdout.write("+")
-                sys.stdout.flush()
-            elif verbose:
-                print "Adding new random atom"
-            dictionary[:, k] = random_state.randn(n_samples)
-            # Setting corresponding coefs to 0
-            code[k, :] = 0.0
-            dictionary[:, k] /= sqrt(np.dot(dictionary[:, k],
-                                            dictionary[:, k]))
-        else:
-            dictionary[:, k] /= sqrt(atom_norm_square)
-            # R <- -1.0 * U_k * V_k^T + R
-            R = ger(-1.0, dictionary[:, k], code[k, :], a=R, overwrite_a=True)
-    if return_r2:
-        R **= 2
-        # R is fortran-ordered. For numpy version < 1.6, sum does not
-        # follow the quick striding first, and is thus inefficient on
-        # fortran ordered data. We take a flat view of the data with no
-        # striding
-        R = as_strided(R, shape=(R.size, ), strides=(R.dtype.itemsize,))
-        R = np.sum(R)
-        return dictionary, R
-    return dictionary
-
-
-def dict_learning(X, n_atoms, alpha, max_iter=100, tol=1e-8, method='lars',
-                  n_jobs=1, dict_init=None, code_init=None, callback=None,
-                  verbose=False, random_state=None):
-    """Solves a dictionary learning matrix factorization problem.
-
-    Finds the best dictionary and the corresponding sparse code for
-    approximating the data matrix X by solving::
-
-    (U^*, V^*) = argmin 0.5 || X - U V ||_2^2 + alpha * || U ||_1
-                 (U,V)
-                with || V_k ||_2 = 1 for all  0 <= k < n_atoms
-
-    where V is the dictionary and U is the sparse code.
-
-    Parameters
-    ----------
-    X: array of shape (n_samples, n_features)
-        Data matrix.
-
-    n_atoms: int,
-        Number of dictionary atoms to extract.
-
-    alpha: int,
-        Sparsity controlling parameter.
-
-    max_iter: int,
-        Maximum number of iterations to perform.
-
-    tol: float,
-        Tolerance for the stopping condition.
-
-    method: {'lars', 'cd'}
-        lars: uses the least angle regression method (linear_model.lars_path)
-        cd: uses the coordinate descent method to compute the
-        Lasso solution (linear_model.Lasso). Lars will be faster if
-        the estimated components are sparse.
-
-    n_jobs: int,
-        Number of parallel jobs to run, or -1 to autodetect.
-
-    dict_init: array of shape (n_atoms, n_features),
-        Initial value for the dictionary for warm restart scenarios.
-
-    code_init: array of shape (n_samples, n_atoms),
-        Initial value for the sparse code for warm restart scenarios.
-
-    callback:
-        Callable that gets invoked every five iterations.
-
-    verbose:
-        Degree of output the procedure will print.
-
-    random_state: int or RandomState
-        Pseudo number generator state used for random sampling.
-
-    Returns
-    -------
-    code: array of shape (n_samples, n_atoms)
-        The sparse code factor in the matrix factorization.
-
-    dictionary: array of shape (n_atoms, n_features),
-        The dictionary factor in the matrix factorization.
-
-    errors: array
-        Vector of errors at each iteration.
-
-    """
-    t0 = time.time()
-    n_features = X.shape[1]
-    # Avoid integer division problems
-    alpha = float(alpha)
-    random_state = check_random_state(random_state)
-
-    if n_jobs == -1:
-        n_jobs = cpu_count()
-
-    # Init U and V with SVD of Y
-    if code_init is not None and code_init is not None:
-        code = np.array(code_init, order='F')
-        # Don't copy V, it will happen below
-        dictionary = dict_init
-    else:
-        code, S, dictionary = linalg.svd(X, full_matrices=False)
-        dictionary = S[:, np.newaxis] * dictionary
-    r = len(dictionary)
-    if n_atoms <= r:  # True even if n_atoms=None
-        code = code[:, :n_atoms]
-        dictionary = dictionary[:n_atoms, :]
-    else:
-        code = np.c_[code, np.zeros((len(code), n_atoms - r))]
-        dictionary = np.r_[dictionary,
-                           np.zeros((n_atoms - r, dictionary.shape[1]))]
-
-    # Fortran-order dict, as we are going to access its row vectors
-    dictionary = np.array(dictionary, order='F')
-
-    residuals = 0
-
-    errors = []
-    current_cost = np.nan
-
-    if verbose == 1:
-        print '[dict_learning]',
-
-    for ii in xrange(max_iter):
-        dt = (time.time() - t0)
-        if verbose == 1:
-            sys.stdout.write(".")
-            sys.stdout.flush()
-        elif verbose:
-            print ("Iteration % 3i "
-                "(elapsed time: % 3is, % 4.1fmn, current cost % 7.3f)" %
-                    (ii, dt, dt / 60, current_cost))
-
-        # Update code
-        code = _update_code_parallel(dictionary.T, X.T, alpha / n_features,
-                                     code.T, method=method, n_jobs=n_jobs)
-        code = code.T
-        # Update dictionary
-        dictionary, residuals = _update_dict(dictionary.T, X.T, code.T,
-                                             verbose=verbose, return_r2=True,
-                                             random_state=random_state)
-        dictionary = dictionary.T
-
-        # Cost function
-        current_cost = 0.5 * residuals + alpha * np.sum(np.abs(code))
-        errors.append(current_cost)
-
-        if ii > 0:
-            dE = errors[-2] - errors[-1]
-            assert(dE >= -tol * errors[-1])
-            if dE < tol * errors[-1]:
-                if verbose == 1:
-                    # A line return
-                    print ""
-                elif verbose:
-                    print "--- Convergence reached after %d iterations" % ii
-                break
-        if ii % 5 == 0 and callback is not None:
-            callback(locals())
-
-    return code, dictionary, errors
-
-
-def dict_learning_online(X, n_atoms, alpha, n_iter=100, return_code=True,
-                         dict_init=None, callback=None, chunk_size=3,
-                         verbose=False, shuffle=True, n_jobs=1,
-                         method='lars', iter_offset=0, random_state=None):
-    """Solves a dictionary learning matrix factorization problem online.
-
-    Finds the best dictionary and the corresponding sparse code for
-    approximating the data matrix X by solving:
-
-    (U^*, V^*) = argmin 0.5 || X - U V ||_2^2 + alpha * || U ||_1
-                 (U,V)
-                 with || V_k ||_2 = 1 for all  0 <= k < n_atoms
-
-    where V is the dictionary and U is the sparse code. This is
-    accomplished by repeatedly iterating over mini-batches by slicing
-    the input data.
-
-    Parameters
-    ----------
-    X: array of shape (n_samples, n_features)
-        data matrix
-
-    n_atoms: int,
-        number of dictionary atoms to extract
-
-    alpha: int,
-        sparsity controlling parameter
-
-    n_iter: int,
-        number of iterations to perform
-
-    return_code: boolean,
-        whether to also return the code U or just the dictionary V
-
-    dict_init: array of shape (n_atoms, n_features),
-        initial value for the dictionary for warm restart scenarios
-
-    callback:
-        callable that gets invoked every five iterations
-
-    chunk_size: int,
-        the number of samples to take in each batch
-
-    verbose:
-        degree of output the procedure will print
-
-    shuffle: boolean,
-        whether to shuffle the data before splitting it in batches
-
-    n_jobs: int,
-        number of parallel jobs to run, or -1 to autodetect.
-
-    method: {'lars', 'cd'}
-        lars: uses the least angle regression method (linear_model.lars_path)
-        cd: uses the coordinate descent method to compute the
-        Lasso solution (linear_model.Lasso). Lars will be faster if
-        the estimated components are sparse.
-
-    iter_offset: int, default 0
-        number of previous iterations completed on the dictionary used for
-        initialization
-
-    random_state: int or RandomState
-        Pseudo number generator state used for random sampling.
-
-    Returns
-    -------
-    dictionary: array of shape (n_atoms, n_features),
-        the solutions to the dictionary learning problem
-
-    code: array of shape (n_samples, n_atoms),
-        the sparse code (only returned if `return_code=True`)
-    """
-    t0 = time.time()
-    n_samples, n_features = X.shape
-    # Avoid integer division problems
-    alpha = float(alpha)
-    random_state = check_random_state(random_state)
-
-    if n_jobs == -1:
-        n_jobs = cpu_count()
-
-    # Init V with SVD of X
-    if dict_init is not None:
-        dictionary = dict_init
-    else:
-        _, S, dictionary = fast_svd(X, n_atoms)
-        dictionary = S[:, np.newaxis] * dictionary
-    r = len(dictionary)
-    if n_atoms <= r:
-        dictionary = dictionary[:n_atoms, :]
-    else:
-        dictionary = np.r_[dictionary,
-                           np.zeros((n_atoms - r, dictionary.shape[1]))]
-    dictionary = np.ascontiguousarray(dictionary.T)
-
-    if verbose == 1:
-        print '[dict_learning]',
-
-    n_batches = floor(float(len(X)) / chunk_size)
-    if shuffle:
-        X_train = X.copy()
-        random_state.shuffle(X_train)
-    else:
-        X_train = X
-    batches = np.array_split(X_train, n_batches)
-    batches = itertools.cycle(batches)
-
-    # The covariance of the dictionary
-    A = np.zeros((n_atoms, n_atoms))
-    # The data approximation
-    B = np.zeros((n_features, n_atoms))
-
-    for ii, this_X in itertools.izip(xrange(iter_offset, iter_offset + n_iter),
-                                     batches):
-        dt = (time.time() - t0)
-        if verbose == 1:
-            sys.stdout.write(".")
-            sys.stdout.flush()
-        elif verbose:
-            if verbose > 10 or ii % ceil(100. / verbose) == 0:
-                print ("Iteration % 3i (elapsed time: % 3is, % 4.1fmn)" %
-                    (ii, dt, dt / 60))
-
-        this_code = _update_code(dictionary, this_X.T, alpha, method=method)
-
-        # Update the auxiliary variables
-        if ii < chunk_size - 1:
-            theta = float((ii + 1) * chunk_size)
-        else:
-            theta = float(chunk_size ** 2 + ii + 1 - chunk_size)
-        beta = (theta + 1 - chunk_size) / (theta + 1)
-
-        A *= beta
-        A += np.dot(this_code, this_code.T)
-        B *= beta
-        B += np.dot(this_X.T, this_code.T)
-
-        # Update dictionary
-        dictionary = _update_dict(dictionary, B, A, verbose=verbose,
-                                  random_state=random_state)
-        # XXX: Can the residuals be of any use?
-
-        # Maybe we need a stopping criteria based on the amount of
-        # modification in the dictionary
-        if callback is not None:
-            callback(locals())
-
-    if return_code:
-        if verbose > 1:
-            print 'Learning code...',
-        elif verbose == 1:
-            print '|',
-        code = _update_code_parallel(dictionary, X.T, alpha, n_jobs=n_jobs,
-                    method=method)
-        if verbose > 1:
-            dt = (time.time() - t0)
-            print 'done (total time: % 3is, % 4.1fmn)' % (dt, dt / 60)
-        return code.T, dictionary.T
-
-    return dictionary.T
+from .dict_learning import dict_learning, dict_learning_online
 
 
 class SparsePCA(BaseEstimator, TransformerMixin):
@@ -553,9 +36,10 @@ class SparsePCA(BaseEstimator, TransformerMixin):
     tol: float,
         Tolerance for the stopping condition.
 
-    method: {'lars', 'cd'}
-        lars: uses the least angle regression method (linear_model.lars_path)
-        cd: uses the coordinate descent method to compute the
+    method: {'lasso_lars', 'lasso_cd'}
+        lasso_lars: uses Lars to compute the Lasso solution
+        (linear_model.lars_path)
+        lasso_cd: uses the coordinate descent method to compute the
         Lasso solution (linear_model.Lasso). Lars will be faster if
         the estimated components are sparse.
 
@@ -588,8 +72,8 @@ class SparsePCA(BaseEstimator, TransformerMixin):
 
     """
     def __init__(self, n_components, alpha=1, ridge_alpha=0.01, max_iter=1000,
-                 tol=1e-8, method='lars', n_jobs=1, U_init=None, V_init=None,
-                 verbose=False, random_state=None):
+                 tol=1e-8, method='lasso_lars', n_jobs=1, U_init=None,
+                 V_init=None, verbose=False, random_state=None):
         self.n_components = n_components
         self.alpha = alpha
         self.ridge_alpha = ridge_alpha
@@ -701,9 +185,10 @@ class MiniBatchSparsePCA(SparsePCA):
     n_jobs: int,
         number of parallel jobs to run, or -1 to autodetect.
 
-    method: {'lars', 'cd'}
-        lars: uses the least angle regression method (linear_model.lars_path)
-        cd: uses the coordinate descent method to compute the
+    method: {'lasso_lars', 'lasso_cd'}
+        lasso_lars: uses Lars to compute the Lasso solution
+        (linear_model.lars_path)
+        lasso_cd: uses the coordinate descent method to compute the
         Lasso solution (linear_model.Lasso). Lars will be faster if
         the estimated components are sparse.
 
@@ -713,7 +198,7 @@ class MiniBatchSparsePCA(SparsePCA):
     """
     def __init__(self, n_components, alpha=1, ridge_alpha=0.01, n_iter=100,
                  callback=None, chunk_size=3, verbose=False, shuffle=True,
-                 n_jobs=1, method='lars', random_state=None):
+                 n_jobs=1, method='lasso_lars', random_state=None):
         self.n_components = n_components
         self.alpha = alpha
         self.ridge_alpha = ridge_alpha
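
For orientation, here is a minimal sketch of how the factorization routines that this file now imports are meant to be called, assuming `dict_learning` and `dict_learning_online` keep the signatures and return values documented in the removed docstrings above after their move to `sklearn.decomposition.dict_learning` (exact keyword defaults may differ):

import numpy as np
from sklearn.decomposition import dict_learning, dict_learning_online

rng = np.random.RandomState(0)
X = rng.randn(20, 8)  # (n_samples, n_features)

# Batch solver: alternates sparse coding and dictionary updates until the
# cost 0.5 * ||X - U V||_2^2 + alpha * ||U||_1 stops improving.
code, dictionary, errors = dict_learning(X, n_atoms=5, alpha=1.0,
                                         max_iter=50, random_state=rng)
# code.shape == (20, 5), dictionary.shape == (5, 8)
X_approx = np.dot(code, dictionary)

# Online solver: cycles over mini-batches of `chunk_size` samples instead of
# recomputing the full sparse code at every iteration.
code_o, dictionary_o = dict_learning_online(X, n_atoms=5, alpha=1.0,
                                            n_iter=100, chunk_size=3,
                                            random_state=rng)
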
diff --git a/sklearn/decomposition/tests/test_dict_learning.py b/sklearn/decomposition/tests/test_dict_learning.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcde71495e44ed6a8f2c52b31f873e7591255b59
--- /dev/null
+++ b/sklearn/decomposition/tests/test_dict_learning.py
@@ -0,0 +1,122 @@
+import numpy as np
+from numpy.testing import assert_array_almost_equal, assert_array_equal, \
+                          assert_equal
+
+from .. import DictionaryLearning, DictionaryLearningOnline, \
+               dict_learning_online
+from ..dict_learning import sparse_encode, sparse_encode_parallel
+
+rng = np.random.RandomState(0)
+n_samples, n_features = 10, 8
+X = rng.randn(n_samples, n_features)
+
+
+def test_dict_learning_shapes():
+    n_atoms = 5
+    dico = DictionaryLearning(n_atoms).fit(X)
+    assert dico.components_.shape == (n_atoms, n_features)
+
+
+def test_dict_learning_overcomplete():
+    n_atoms = 12
+    X = rng.randn(n_samples, n_features)
+    dico = DictionaryLearning(n_atoms).fit(X)
+    assert dico.components_.shape == (n_atoms, n_features)
+
+
+def test_dict_learning_reconstruction():
+    n_atoms = 12
+    dico = DictionaryLearning(n_atoms, transform_algorithm='omp',
+                              transform_alpha=0.001, random_state=0)
+    code = dico.fit(X).transform(X)
+    assert_array_almost_equal(np.dot(code, dico.components_), X)
+
+    dico.set_params(transform_algorithm='lasso_lars')
+    code = dico.transform(X)
+    assert_array_almost_equal(np.dot(code, dico.components_), X, decimal=2)
+
+    dico.set_params(transform_algorithm='lars')
+    code = dico.transform(X)
+    assert_array_almost_equal(np.dot(code, dico.components_), X, decimal=2)
+
+
+def test_dict_learning_nonzero_coefs():
+    n_atoms = 4
+    dico = DictionaryLearning(n_atoms, transform_algorithm='lars',
+                              transform_n_nonzero_coefs=3, random_state=0)
+    code = dico.fit(X).transform(X[0])
+    assert len(np.flatnonzero(code)) == 3
+
+    dico.set_params(transform_algorithm='omp')
+    code = dico.transform(X[0])
+    assert len(np.flatnonzero(code)) == 3
+
+
+def test_dict_learning_split():
+    n_atoms = 5
+    dico = DictionaryLearning(n_atoms, transform_algorithm='threshold')
+    code = dico.fit(X).transform(X)
+    dico.split_sign = True
+    split_code = dico.transform(X)
+
+    assert_array_equal(split_code[:, :n_atoms] - split_code[:, n_atoms:], code)
+
+
+def test_dict_learning_online_shapes():
+    rng = np.random.RandomState(0)
+    X = rng.randn(12, 10)
+    dictionaryT, codeT = dict_learning_online(X.T, n_atoms=8, alpha=1,
+                                              random_state=rng)
+    assert_equal(codeT.shape, (8, 12))
+    assert_equal(dictionaryT.shape, (10, 8))
+    assert_equal(np.dot(codeT.T, dictionaryT.T).shape, X.shape)
+
+
+def test_dict_learning_online_estimator_shapes():
+    n_atoms = 5
+    dico = DictionaryLearningOnline(n_atoms, n_iter=20).fit(X)
+    assert dico.components_.shape == (n_atoms, n_features)
+
+
+def test_dict_learning_online_overcomplete():
+    n_atoms = 12
+    dico = DictionaryLearningOnline(n_atoms, n_iter=20).fit(X)
+    assert dico.components_.shape == (n_atoms, n_features)
+
+
+def test_dict_learning_online_initialization():
+    n_atoms = 12
+    V = rng.randn(n_atoms, n_features)
+    dico = DictionaryLearningOnline(n_atoms, n_iter=0, dict_init=V).fit(X)
+    assert_array_equal(dico.components_, V)
+
+
+def test_dict_learning_online_partial_fit():
+    n_atoms = 12
+    V = rng.randn(n_atoms, n_features)  # random init
+    rng1 = np.random.RandomState(0)
+    rng2 = np.random.RandomState(0)
+    dico1 = DictionaryLearningOnline(n_atoms, n_iter=10, chunk_size=1,
+                                     shuffle=False, dict_init=V,
+                                     transform_algorithm='threshold',
+                                     random_state=rng1).fit(X)
+    dico2 = DictionaryLearningOnline(n_atoms, n_iter=1, dict_init=V,
+                                     transform_algorithm='threshold',
+                                     random_state=rng2)
+    for ii, sample in enumerate(X):
+        dico2.partial_fit(sample, iter_offset=ii * dico2.n_iter)
+
+    assert_array_equal(dico1.components_, dico2.components_)
+
+
+def test_sparse_code():
+    rng = np.random.RandomState(0)
+    dictionary = rng.randn(10, 3)
+    real_code = np.zeros((3, 5))
+    real_code.ravel()[rng.randint(15, size=6)] = 1.0
+    Y = np.dot(dictionary, real_code)
+    est_code_1 = sparse_encode(dictionary, Y, alpha=1.0)
+    est_code_2 = sparse_encode_parallel(dictionary, Y, alpha=1.0)
+    assert_equal(est_code_1.shape, real_code.shape)
+    assert_equal(est_code_1, est_code_2)
+    assert_equal(est_code_1.nonzero(), real_code.nonzero())
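
The identity checked in `test_dict_learning_split` only holds if the split code stacks the positive part of the plain code next to the magnitude of its negative part. A numpy-only sketch of that construction, independent of the estimator and shown only to make the assertion concrete:

import numpy as np

code = np.array([[0.7, -0.2, 0.0, 1.3, -0.5]])  # one sample, n_atoms = 5
n_atoms = code.shape[1]

# Positive loadings in the first half, magnitudes of the negative loadings
# in the second half.
split_code = np.hstack([np.maximum(code, 0), np.maximum(-code, 0)])

# The difference of the two halves recovers the original code, which is
# exactly what the test asserts.
assert np.allclose(split_code[:, :n_atoms] - split_code[:, n_atoms:], code)
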
diff --git a/sklearn/decomposition/tests/test_sparse_pca.py b/sklearn/decomposition/tests/test_sparse_pca.py
index f7da34a75ae9c682f5bc3f677fa1ed89d0550065..e8e3cda667d0117a8c2b9e74f5dbdfa613c984b0 100644
--- a/sklearn/decomposition/tests/test_sparse_pca.py
+++ b/sklearn/decomposition/tests/test_sparse_pca.py
@@ -6,8 +6,7 @@ import sys
 import numpy as np
 from numpy.testing import assert_array_almost_equal, assert_equal
 
-from .. import SparsePCA, MiniBatchSparsePCA, dict_learning_online
-from ..sparse_pca import _update_code, _update_code_parallel
+from .. import SparsePCA, MiniBatchSparsePCA
 from ...utils import check_random_state
 
 
@@ -53,7 +52,8 @@ def test_correct_shapes():
 def test_fit_transform():
     rng = np.random.RandomState(0)
     Y, _, _ = generate_toy_data(3, 10, (8, 8), random_state=rng)  # wide array
-    spca_lars = SparsePCA(n_components=3, method='lars', random_state=rng)
+    spca_lars = SparsePCA(n_components=3, method='lasso_lars',
+                          random_state=rng)
     spca_lars.fit(Y)
     U1 = spca_lars.transform(Y)
     # Test multiple CPUs
@@ -71,7 +71,7 @@ def test_fit_transform():
         U2 = spca.transform(Y)
     assert_array_almost_equal(U1, U2)
     # Test that CD gives similar results
-    spca_lasso = SparsePCA(n_components=3, method='cd', random_state=rng)
+    spca_lasso = SparsePCA(n_components=3, method='lasso_cd', random_state=rng)
     spca_lasso.fit(Y)
     assert_array_almost_equal(spca_lasso.components_, spca_lars.components_)
 
@@ -79,26 +79,14 @@ def test_fit_transform():
 def test_fit_transform_tall():
     rng = np.random.RandomState(0)
     Y, _, _ = generate_toy_data(3, 65, (8, 8), random_state=rng)  # tall array
-    spca_lars = SparsePCA(n_components=3, method='lars', random_state=rng)
+    spca_lars = SparsePCA(n_components=3, method='lasso_lars',
+                          random_state=rng)
     U1 = spca_lars.fit_transform(Y)
-    spca_lasso = SparsePCA(n_components=3, method='cd', random_state=rng)
+    spca_lasso = SparsePCA(n_components=3, method='lasso_cd', random_state=rng)
     U2 = spca_lasso.fit(Y).transform(Y)
     assert_array_almost_equal(U1, U2)
 
 
-def test_sparse_code():
-    rng = np.random.RandomState(0)
-    dictionary = rng.randn(10, 3)
-    real_code = np.zeros((3, 5))
-    real_code.ravel()[rng.randint(15, size=6)] = 1.0
-    Y = np.dot(dictionary, real_code)
-    est_code_1 = _update_code(dictionary, Y, alpha=1.0)
-    est_code_2 = _update_code_parallel(dictionary, Y, alpha=1.0)
-    assert_equal(est_code_1.shape, real_code.shape)
-    assert_equal(est_code_1, est_code_2)
-    assert_equal(est_code_1.nonzero(), real_code.nonzero())
-
-
 def test_initialization():
     rng = np.random.RandomState(0)
     U_init = rng.randn(5, 3)
@@ -109,16 +97,6 @@ def test_initialization():
     assert_equal(model.components_, V_init)
 
 
-def test_dict_learning_online_shapes():
-    rng = np.random.RandomState(0)
-    X = rng.randn(12, 10)
-    dictionaryT, codeT = dict_learning_online(X.T, n_atoms=8, alpha=1,
-                                              random_state=rng)
-    assert_equal(codeT.shape, (8, 12))
-    assert_equal(dictionaryT.shape, (10, 8))
-    assert_equal(np.dot(codeT.T, dictionaryT.T).shape, X.shape)
-
-
 def test_mini_batch_correct_shapes():
     rng = np.random.RandomState(0)
     X = rng.randn(12, 10)
@@ -153,6 +131,6 @@ def test_mini_batch_fit_transform():
                                 random_state=rng).fit(Y).transform(Y)
     assert_array_almost_equal(U1, U2)
     # Test that CD gives similar results
-    spca_lasso = MiniBatchSparsePCA(n_components=3, method='cd',
+    spca_lasso = MiniBatchSparsePCA(n_components=3, method='lasso_cd',
                                     random_state=rng).fit(Y)
     assert_array_almost_equal(spca_lasso.components_, spca_lars.components_)
diff --git a/sklearn/feature_extraction/image.py b/sklearn/feature_extraction/image.py
index 8ec6ab0b8cb15abf6cfcc749bd0a982cae378fc4..c8ed4277c5f076a0c0f0ea61279fb1ee484e8b8f 100644
--- a/sklearn/feature_extraction/image.py
+++ b/sklearn/feature_extraction/image.py
@@ -360,7 +360,6 @@ class PatchExtractor(BaseEstimator):
              `n_patches` is either `n_samples * max_patches` or the total
              number of patches that can be extracted.
 
-
         """
         self.random_state = check_random_state(self.random_state)
         n_images, i_h, i_w = X.shape[:3]
diff --git a/sklearn/linear_model/omp.py b/sklearn/linear_model/omp.py
index 3c6e7bda6d6031b86b81831024f71109773b26b2..81b4f38ed8582e6c71e22747dab6d0c1f62d9fa2 100644
--- a/sklearn/linear_model/omp.py
+++ b/sklearn/linear_model/omp.py
@@ -20,7 +20,7 @@ dependence in the dictionary. The requested precision might not have been met.
 """
 
 
-def _cholesky_omp(X, y, n_nonzero_coefs, eps=None, overwrite_X=False):
+def _cholesky_omp(X, y, n_nonzero_coefs, tol=None, overwrite_X=False):
     """Orthogonal Matching Pursuit step using the Cholesky decomposition.
 
     Parameters:
@@ -34,7 +34,7 @@ def _cholesky_omp(X, y, n_nonzero_coefs, eps=None, overwrite_X=False):
     n_nonzero_coefs: int
         Targeted number of non-zero elements
 
-    eps: float
+    tol: float
         Targeted squared error, if not None overrides n_nonzero_coefs.
 
     overwrite_X: bool,
@@ -66,7 +66,7 @@ def _cholesky_omp(X, y, n_nonzero_coefs, eps=None, overwrite_X=False):
     n_active = 0
     idx = []
 
-    max_features = X.shape[1] if eps is not None else n_nonzero_coefs
+    max_features = X.shape[1] if tol is not None else n_nonzero_coefs
     L = np.empty((max_features, max_features), dtype=X.dtype)
     L[0, 0] = 1.
 
@@ -94,7 +94,7 @@ def _cholesky_omp(X, y, n_nonzero_coefs, eps=None, overwrite_X=False):
                          overwrite_b=False)
 
         residual = y - np.dot(X[:, :n_active], gamma)
-        if eps is not None and nrm2(residual) ** 2 <= eps:
+        if tol is not None and nrm2(residual) ** 2 <= tol:
             break
         elif n_active == max_features:
             break
@@ -102,7 +102,7 @@ def _cholesky_omp(X, y, n_nonzero_coefs, eps=None, overwrite_X=False):
     return gamma, idx
 
 
-def _gram_omp(Gram, Xy, n_nonzero_coefs, eps_0=None, eps=None,
+def _gram_omp(Gram, Xy, n_nonzero_coefs, tol_0=None, tol=None,
               overwrite_gram=False, overwrite_Xy=False):
     """Orthogonal Matching Pursuit step on a precomputed Gram matrix.
 
@@ -119,10 +119,10 @@ def _gram_omp(Gram, Xy, n_nonzero_coefs, eps_0=None, eps=None,
     n_nonzero_coefs: int
         Targeted number of non-zero elements
 
-    eps_0: float
-        Squared norm of y, required if eps is not None.
+    tol_0: float
+        Squared norm of y, required if tol is not None.
 
-    eps: float
+    tol: float
         Targeted squared error, if not None overrides n_nonzero_coefs.
 
     overwrite_gram: bool,
@@ -157,11 +157,11 @@ def _gram_omp(Gram, Xy, n_nonzero_coefs, eps_0=None, eps=None,
 
     idx = []
     alpha = Xy
-    eps_curr = eps_0
+    tol_curr = tol_0
     delta = 0
     n_active = 0
 
-    max_features = len(Gram) if eps is not None else n_nonzero_coefs
+    max_features = len(Gram) if tol is not None else n_nonzero_coefs
     L = np.empty((max_features, max_features), dtype=Gram.dtype)
     L[0, 0] = 1.
 
@@ -190,11 +190,11 @@ def _gram_omp(Gram, Xy, n_nonzero_coefs, eps_0=None, eps=None,
 
         beta = np.dot(Gram[:, :n_active], gamma)
         alpha = Xy - beta
-        if eps is not None:
-            eps_curr += delta
+        if tol is not None:
+            tol_curr += delta
             delta = np.inner(gamma, beta[:n_active])
-            eps_curr -= delta
-            if eps_curr <= eps:
+            tol_curr -= delta
+            if tol_curr <= tol:
                 break
         elif n_active == max_features:
             break
@@ -202,7 +202,7 @@ def _gram_omp(Gram, Xy, n_nonzero_coefs, eps_0=None, eps=None,
     return gamma, idx
 
 
-def orthogonal_mp(X, y, n_nonzero_coefs=None, eps=None, precompute_gram=False,
+def orthogonal_mp(X, y, n_nonzero_coefs=None, tol=None, precompute_gram=False,
                   overwrite_X=False):
     """Orthogonal Matching Pursuit (OMP)
 
@@ -213,8 +213,8 @@ def orthogonal_mp(X, y, n_nonzero_coefs=None, eps=None, precompute_gram=False,
     `n_nonzero_coefs`:
     argmin ||y - X\gamma||^2 subject to ||\gamma||_0 <= n_{nonzero coefs}
 
-    When parametrized by error using the parameter `eps`:
-    argmin ||\gamma||_0 subject to ||y - X\gamma||^2 <= \epsilon
+    When parametrized by error using the parameter `tol`:
+    argmin ||\gamma||_0 subject to ||y - X\gamma||^2 <= tol
 
     Parameters
     ----------
@@ -228,7 +228,7 @@ def orthogonal_mp(X, y, n_nonzero_coefs=None, eps=None, precompute_gram=False,
         Desired number of non-zero entries in the solution. If None (by
         default) this value is set to 10% of n_features.
 
-    eps: float
+    tol: float
         Maximum norm of the residual. If not None, overrides n_nonzero_coefs.
 
     precompute_gram: {True, False, 'auto'},
@@ -274,13 +274,13 @@ def orthogonal_mp(X, y, n_nonzero_coefs=None, eps=None, precompute_gram=False,
         X = np.asfortranarray(X)
     if y.shape[1] > 1:  # subsequent targets will be affected
         overwrite_X = False
-    if n_nonzero_coefs == None and eps == None:
+    if n_nonzero_coefs is None and tol is None:
         n_nonzero_coefs = int(0.1 * X.shape[1])
-    if eps is not None and eps < 0:
+    if tol is not None and tol < 0:
         raise ValueError("Epsilon cannot be negative")
-    if eps is None and n_nonzero_coefs <= 0:
+    if tol is None and n_nonzero_coefs <= 0:
         raise ValueError("The number of atoms must be positive")
-    if eps is None and n_nonzero_coefs > X.shape[1]:
+    if tol is None and n_nonzero_coefs > X.shape[1]:
         raise ValueError("The number of atoms cannot be more than the number \
                           of features")
     if precompute_gram == 'auto':
@@ -289,23 +289,23 @@ def orthogonal_mp(X, y, n_nonzero_coefs=None, eps=None, precompute_gram=False,
         G = np.dot(X.T, X)
         G = np.asfortranarray(G)
         Xy = np.dot(X.T, y)
-        if eps is not None:
+        if tol is not None:
             norms_squared = np.sum((y ** 2), axis=0)
         else:
             norms_squared = None
-        return orthogonal_mp_gram(G, Xy, n_nonzero_coefs, eps, norms_squared,
+        return orthogonal_mp_gram(G, Xy, n_nonzero_coefs, tol, norms_squared,
                                   overwrite_gram=overwrite_X,
                                   overwrite_Xy=True)
 
     coef = np.zeros((X.shape[1], y.shape[1]))
     for k in xrange(y.shape[1]):
-        x, idx = _cholesky_omp(X, y[:, k], n_nonzero_coefs, eps,
+        x, idx = _cholesky_omp(X, y[:, k], n_nonzero_coefs, tol,
                                overwrite_X=overwrite_X)
         coef[idx, k] = x
     return np.squeeze(coef)
 
 
-def orthogonal_mp_gram(Gram, Xy, n_nonzero_coefs=None, eps=None,
+def orthogonal_mp_gram(Gram, Xy, n_nonzero_coefs=None, tol=None,
                        norms_squared=None, overwrite_gram=False,
                        overwrite_Xy=False):
     """Gram Orthogonal Matching Pursuit (OMP)
@@ -325,11 +325,11 @@ def orthogonal_mp_gram(Gram, Xy, n_nonzero_coefs=None, eps=None,
         Desired number of non-zero entries in the solution. If None (by
         default) this value is set to 10% of n_features.
 
-    eps: float
+    tol: float
         Maximum norm of the residual. If not None, overrides n_nonzero_coefs.
 
     norms_squared: array-like, shape = (n_targets,)
-        Squared L2 norms of the lines of y. Required if eps is not None.
+        Squared L2 norms of the lines of y. Required if tol is not None.
 
     overwrite_gram: bool,
         Whether the gram matrix can be overwritten by the algorithm. This
@@ -366,25 +366,25 @@ def orthogonal_mp_gram(Gram, Xy, n_nonzero_coefs=None, eps=None,
     Gram, Xy = map(np.asanyarray, (Gram, Xy))
     if Xy.ndim == 1:
         Xy = Xy[:, np.newaxis]
-        if eps is not None:
+        if tol is not None:
             norms_squared = [norms_squared]
 
-    if n_nonzero_coefs == None and eps is None:
+    if n_nonzero_coefs is None and tol is None:
         n_nonzero_coefs = int(0.1 * len(Gram))
-    if eps is not None and norms_squared == None:
+    if tol is not None and norms_squared is None:
         raise ValueError('Gram OMP needs the precomputed norms in order \
                           to evaluate the error sum of squares.')
-    if eps is not None and eps < 0:
+    if tol is not None and tol < 0:
         raise ValueError("Epsilon cennot be negative")
-    if eps is None and n_nonzero_coefs <= 0:
+    if tol is None and n_nonzero_coefs <= 0:
         raise ValueError("The number of atoms must be positive")
-    if eps is None and n_nonzero_coefs > len(Gram):
+    if tol is None and n_nonzero_coefs > len(Gram):
         raise ValueError("The number of atoms cannot be more than the number \
                           of features")
     coef = np.zeros((len(Gram), Xy.shape[1]))
     for k in range(Xy.shape[1]):
         x, idx = _gram_omp(Gram, Xy[:, k], n_nonzero_coefs,
-                           norms_squared[k] if eps is not None else None, eps,
+                           norms_squared[k] if tol is not None else None, tol,
                            overwrite_gram=overwrite_gram,
                            overwrite_Xy=overwrite_Xy)
         coef[idx, k] = x
@@ -400,7 +400,7 @@ class OrthogonalMatchingPursuit(LinearModel):
         Desired number of non-zero entries in the solution. If None (by
         default) this value is set to 10% of n_features.
 
-    eps: float, optional
+    tol: float, optional
         Maximum norm of the residual. If not None, overrides n_nonzero_coefs.
 
     fit_intercept: boolean, optional
@@ -462,17 +462,16 @@ class OrthogonalMatchingPursuit(LinearModel):
 
     """
     def __init__(self, overwrite_X=False, overwrite_gram=False,
-            overwrite_Xy=False, n_nonzero_coefs=None, eps=None,
+            overwrite_Xy=False, n_nonzero_coefs=None, tol=None,
             fit_intercept=True, normalize=True, precompute_gram=False):
         self.n_nonzero_coefs = n_nonzero_coefs
-        self.eps = eps
+        self.tol = tol
         self.fit_intercept = fit_intercept
         self.normalize = normalize
         self.precompute_gram = precompute_gram
         self.overwrite_gram = overwrite_gram
         self.overwrite_Xy = overwrite_Xy
         self.overwrite_X = overwrite_X
-        self.eps = eps
 
     def fit(self, X, y, Gram=None, Xy=None):
         """Fit the model using X, y as training data.
@@ -507,7 +506,7 @@ class OrthogonalMatchingPursuit(LinearModel):
                                                         self.fit_intercept,
                                                         self.normalize)
 
-        if self.n_nonzero_coefs == None and self.eps is None:
+        if self.n_nonzero_coefs is None and self.tol is None:
             self.n_nonzero_coefs = int(0.1 * n_features)
 
         if Gram is not None:
@@ -538,15 +537,15 @@ class OrthogonalMatchingPursuit(LinearModel):
                 Gram /= X_std
                 Gram /= X_std[:, np.newaxis]
 
-            norms_sq = np.sum(y ** 2, axis=0) if self.eps is not None else None
+            norms_sq = np.sum(y ** 2, axis=0) if self.tol is not None else None
             self.coef_ = orthogonal_mp_gram(Gram, Xy, self.n_nonzero_coefs,
-                                            self.eps, norms_sq,
+                                            self.tol, norms_sq,
                                             overwrite_gram, True).T
         else:
             precompute_gram = self.precompute_gram
             if precompute_gram == 'auto':
                 precompute_gram = X.shape[0] > X.shape[1]
-            self.coef_ = orthogonal_mp(X, y, self.n_nonzero_coefs, self.eps,
+            self.coef_ = orthogonal_mp(X, y, self.n_nonzero_coefs, self.tol,
                                        precompute_gram=self.precompute_gram,
                                        overwrite_X=self.overwrite_X).T
 
diff --git a/sklearn/linear_model/tests/test_omp.py b/sklearn/linear_model/tests/test_omp.py
index 44d73ef7e48d74c516c95cc27c4f3e3deee68dda..71f22b78a334c4ee363ba6d46d8c536396f2663b 100644
--- a/sklearn/linear_model/tests/test_omp.py
+++ b/sklearn/linear_model/tests/test_omp.py
@@ -39,12 +39,12 @@ def test_n_nonzero_coefs():
                                        precompute_gram=True)) <= 5
 
 
-def test_eps():
-    eps = 0.5
-    gamma = orthogonal_mp(X, y[:, 0], eps=eps)
-    gamma_gram = orthogonal_mp(X, y[:, 0], eps=eps, precompute_gram=True)
-    assert np.sum((y[:, 0] - np.dot(X, gamma)) ** 2) <= eps
-    assert np.sum((y[:, 0] - np.dot(X, gamma_gram)) ** 2) <= eps
+def test_tol():
+    tol = 0.5
+    gamma = orthogonal_mp(X, y[:, 0], tol=tol)
+    gamma_gram = orthogonal_mp(X, y[:, 0], tol=tol, precompute_gram=True)
+    assert np.sum((y[:, 0] - np.dot(X, gamma)) ** 2) <= tol
+    assert np.sum((y[:, 0] - np.dot(X, gamma_gram)) ** 2) <= tol
 
 
 def test_with_without_gram():
@@ -53,32 +53,32 @@ def test_with_without_gram():
         orthogonal_mp(X, y, n_nonzero_coefs=5, precompute_gram=True))
 
 
-def test_with_without_gram_eps():
+def test_with_without_gram_tol():
     assert_array_almost_equal(
-        orthogonal_mp(X, y, eps=1.),
-        orthogonal_mp(X, y, eps=1., precompute_gram=True))
+        orthogonal_mp(X, y, tol=1.),
+        orthogonal_mp(X, y, tol=1., precompute_gram=True))
 
 
 def test_unreachable_accuracy():
     with warnings.catch_warnings(record=True) as w:
         warnings.simplefilter('always')
         assert_array_almost_equal(
-            orthogonal_mp(X, y, eps=0),
+            orthogonal_mp(X, y, tol=0),
             orthogonal_mp(X, y, n_nonzero_coefs=n_features))
 
         assert_array_almost_equal(
-            orthogonal_mp(X, y, eps=0, precompute_gram=True),
+            orthogonal_mp(X, y, tol=0, precompute_gram=True),
             orthogonal_mp(X, y, precompute_gram=True,
                           n_nonzero_coefs=n_features))
         assert len(w) > 0  # warnings should be raised
 
 
 def test_bad_input():
-    assert_raises(ValueError, orthogonal_mp, X, y, eps=-1)
+    assert_raises(ValueError, orthogonal_mp, X, y, tol=-1)
     assert_raises(ValueError, orthogonal_mp, X, y, n_nonzero_coefs=-1)
     assert_raises(ValueError, orthogonal_mp, X, y,
                   n_nonzero_coefs=n_features + 1)
-    assert_raises(ValueError, orthogonal_mp_gram, G, Xy, eps=-1)
+    assert_raises(ValueError, orthogonal_mp_gram, G, Xy, tol=-1)
     assert_raises(ValueError, orthogonal_mp_gram, G, Xy, n_nonzero_coefs=-1)
     assert_raises(ValueError, orthogonal_mp_gram, G, Xy,
                   n_nonzero_coefs=n_features + 1)
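
The estimator-level counterpart of the same rename, assuming `OrthogonalMatchingPursuit` is exported from `sklearn.linear_model` (a sketch, not part of this patch):

import numpy as np
from sklearn.linear_model import OrthogonalMatchingPursuit

rng = np.random.RandomState(0)
X = rng.randn(30, 15)
y = np.dot(X, rng.randn(15))

# `tol` replaces the old `eps` keyword; the number of selected atoms is then
# chosen automatically instead of being fixed by n_nonzero_coefs.
omp = OrthogonalMatchingPursuit(tol=0.5)
omp.fit(X, y)
sparse_coef = omp.coef_
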