diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 8c7df440f4ddb2d707054bf94ca3995968ba0c49..76eb410a37d9215ec87b0b2b13885db135ac1d90 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -377,6 +377,8 @@ Signal Decomposition
    decomposition.NMF
    decomposition.SparsePCA
    decomposition.MiniBatchSparsePCA
+   decomposition.DictionaryLearning
+   decomposition.DictionaryLearningOnline

 .. autosummary::
    :toctree: generated/
diff --git a/doc/modules/decomposition.rst b/doc/modules/decomposition.rst
index 4b5633003c7ed0743eea25d1fd82f05adcabc97f..059f09cd45d79583280cb43a43d4fad602b55698 100644
--- a/doc/modules/decomposition.rst
+++ b/doc/modules/decomposition.rst
@@ -347,3 +347,105 @@ of the data.
     matrix factorization"
     <http://www.cs.rpi.edu/~boutsc/files/nndsvd.pdf>`_
     C. Boutsidis, E. Gallopoulos, 2008
+
+
+
+.. _DictionaryLearning:
+
+Dictionary Learning
+===================
+
+Generic dictionary learning
+---------------------------
+
+Dictionary learning (:class:`DictionaryLearning`) is a matrix factorization
+problem that amounts to finding a (usually overcomplete) dictionary that
+performs well at sparsely encoding the fitted data.
+
+Representing data as sparse combinations of atoms from an overcomplete
+dictionary is suggested to be the way the mammalian primary visual cortex
+works. Consequently, dictionary learning applied to image patches has been
+shown to give good results in image processing tasks such as image completion,
+inpainting and denoising, as well as for supervised recognition tasks.
+
+Dictionary learning is an optimization problem solved by alternately updating
+the sparse code, as a solution to multiple Lasso problems, considering the
+dictionary fixed, and then updating the dictionary to best fit the sparse code.
+
+After using such a procedure to fit the dictionary, the fitted object can be
+used to transform new data. The transformation amounts to a sparse coding
+problem: finding a representation of the data as a linear combination of as few
+dictionary atoms as possible. All variations of dictionary learning implement
+the following transform methods, controllable via the `transform_algorithm`
+initialization parameter:
+
+
+* Orthogonal matching pursuit (:ref:`omp`)
+
+* Least-angle regression (:ref:`least_angle_regression`)
+
+* Lasso computed by least-angle regression
+
+* Lasso using coordinate descent (:ref:`lasso`)
+
+* Thresholding
+
+Thresholding is very fast but it does not yield accurate reconstructions.
+It has nevertheless been shown useful in the literature for classification
+tasks. For image reconstruction tasks, orthogonal matching pursuit yields the
+most accurate, unbiased reconstruction.
+
+The dictionary learning objects offer, via the `split_sign` parameter, the
+possibility to separate the positive and negative values in the results of
+sparse coding. This is useful when dictionary learning is used for extracting
+features that will be used for supervised learning, because it allows the
+learning algorithm to assign a different weight to the negative loading of a
+particular atom than to the corresponding positive loading.
+
+The split code for a single sample has length `2 * n_atoms`
+and is constructed using the following rule: First, the regular code of length
+`n_atoms` is computed. Then, the first `n_atoms` entries of the split code are
+filled with the positive part of the regular code vector. The second half of
+the split code is filled with the negative part of the code vector, only with
+a positive sign. Therefore, the split code is non-negative.
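For illustration only (this block is not part of the patch), a minimal usage
sketch of the transform and split-code behaviour described above. The class
and keyword names come from the code added by this patch; the data and the
parameter values are made up::

    import numpy as np
    from sklearn.decomposition import DictionaryLearning

    X = np.random.randn(100, 64)              # hypothetical training data
    dico = DictionaryLearning(n_atoms=32, max_iter=10,
                              transform_algorithm='omp',
                              transform_n_nonzero_coefs=3,
                              split_sign=True)
    code = dico.fit(X).transform(X)
    # With split_sign=True the code has shape (100, 2 * 32): the first half
    # holds the positive part of each coefficient, the second half holds the
    # magnitude of the negative part, so every entry is non-negative.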
+
+The following image shows what a dictionary learned from 4x4 pixel image
+patches, extracted from part of the image of Lena, looks like.
+
+
+.. figure:: ../auto_examples/decomposition/images/plot_img_denoising_1.png
+   :target: ../auto_examples/decomposition/plot_img_denoising.html
+   :align: center
+   :scale: 50%
+
+
+.. topic:: Examples:
+
+   * :ref:`example_decomposition_plot_img_denoising.py`
+
+
+.. topic:: References:
+
+   * `"Online dictionary learning for sparse coding"
+     <http://www.di.ens.fr/sierra/pdfs/icml09.pdf>`_
+     J. Mairal, F. Bach, J. Ponce, G. Sapiro, 2009
+
+.. _DictionaryLearningOnline:
+
+Online dictionary learning
+--------------------------
+
+:class:`DictionaryLearningOnline` implements a faster but less accurate
+version of the dictionary learning algorithm that is better suited for large
+datasets.
+
+By default, :class:`DictionaryLearningOnline` divides the data into
+mini-batches and optimizes in an online manner by cycling over the mini-batches
+for the specified number of iterations. However, at the moment it does not
+implement a stopping condition.
+
+The estimator also implements `partial_fit`, which updates the dictionary by
+iterating only once over a mini-batch. This can be used for online learning
+when the data is not readily available from the start, or when the data does
+not fit into memory.
+
diff --git a/doc/modules/linear_model.rst b/doc/modules/linear_model.rst
index 1d1da635f091502f8d2f5288c0130d5efc6153a8..3038716fbae0b7d36fb79c006b2b0c9619f6c29a 100644
--- a/doc/modules/linear_model.rst
+++ b/doc/modules/linear_model.rst
@@ -358,7 +358,8 @@ column is always zero.
    <http://www-stat.stanford.edu/~hastie/Papers/LARS/LeastAngle_2002.pdf>`_
    by Hastie et al.

-.. _OMP:
+
+.. _omp:

 Orthogonal Matching Pursuit (OMP)
 =================================
@@ -370,12 +371,12 @@ Being a forward feature selection method like :ref:`least_angle_regression`,
 orthogonal matching pursuit can approximate the optimum solution vector with a
 fixed number of non-zero elements:

-.. math:: \text{arg\,min} ||y - X\gamma||_2^2 \text{ subject to } ||\gamma||_0 \leq n_{features}
+.. math:: \text{arg\,min} ||y - X\gamma||_2^2 \text{ subject to } ||\gamma||_0 \leq n_{\text{nonzero\_coefs}}

 Alternatively, orthogonal matching pursuit can target a specific error instead
 of a specific number of non-zero coefficients. This can be expressed as:

-.. math:: \text{arg\,min} ||\gamma||_0 \text{ subject to } ||y-X\gamma||_2^2 \leq \varepsilon
+..
math:: \text{arg\,min} ||\gamma||_0 \text{ subject to } ||y-X\gamma||_2^2 \leq \text{tol} OMP is based on a greedy algorithm that includes at each step the atom most diff --git a/examples/decomposition/plot_faces_decomposition.py b/examples/decomposition/plot_faces_decomposition.py index 397b865df700a94b4e708b1293278031d396da4e..1cd41b9e9fc9a6180b9c2177913a5fdd154d0495 100644 --- a/examples/decomposition/plot_faces_decomposition.py +++ b/examples/decomposition/plot_faces_decomposition.py @@ -82,6 +82,11 @@ estimators = [ n_iter=100, chunk_size=3), True, False), + ('Dictionary atoms - DictionaryLearningOnline', + decomposition.DictionaryLearningOnline(n_atoms=n_components, alpha=1e-3, + n_iter=100, chunk_size=3), + True, False), + ('Cluster centers - MiniBatchKMeans', MiniBatchKMeans(k=n_components, tol=1e-3, chunk_size=20, max_iter=50), True, False) diff --git a/examples/decomposition/plot_img_denoising.py b/examples/decomposition/plot_img_denoising.py new file mode 100644 index 0000000000000000000000000000000000000000..4c91db2497f735a3e621cb4d520d04da238ec6bb --- /dev/null +++ b/examples/decomposition/plot_img_denoising.py @@ -0,0 +1,151 @@ +""" +========================================= +Image denoising using dictionary learning +========================================= + +An example comparing the effect of reconstructing noisy fragments +of Lena using online :ref:`DictionaryLearning` and various transform methods. + +The dictionary is fitted on the non-distorted left half of the image, and +subsequently used to reconstruct the right half. + +A common practice for evaluating the results of image denoising is by looking +at the difference between the reconstruction and the original image. If the +reconstruction is perfect this will look like gaussian noise. + +It can be seen from the plots that the results of :ref:`omp` with two +non-zero coefficients is a bit less biased than when keeping only one (the +edges look less prominent). However, it is farther from the ground truth in +Frobenius norm. + +The result of :ref:`least_angle_regression` is much more strongly biased: the +difference is reminiscent of the local intensity value of the original image. + +Thresholding is clearly not useful for denoising, but it is here to show that +it can produce a suggestive output with very high speed, and thus be useful +for other tasks such as object classification, where performance is not +necessarily related to visualisation. + +""" +print __doc__ + +from time import time + +import pylab as pl +import scipy as sp +import numpy as np + +from sklearn.decomposition import DictionaryLearningOnline +from sklearn.feature_extraction.image import extract_patches_2d, \ + reconstruct_from_patches_2d + +############################################################################### +# Load Lena image and extract patches +lena = sp.lena() / 256.0 + +# downsample for higher speed +lena = lena[::2, ::2] + lena[1::2, ::2] + lena[::2, 1::2] + lena[1::2, 1::2] +lena = lena[::2, ::2] + lena[1::2, ::2] + lena[::2, 1::2] + lena[1::2, 1::2] +lena /= 16.0 +height, width = lena.shape + +# Distort the right half of the image +print 'Distorting image...' +distorted = lena.copy() +distorted[:, height / 2:] += 0.075 * np.random.randn(width, height / 2) + +# Extract all clean patches from the left half of the image +print 'Extracting clean patches...' 
+patch_size = (7, 7) +data = extract_patches_2d(distorted[:, :height / 2], patch_size) +data = data.reshape(data.shape[0], -1) +data -= np.mean(data, axis=0) +data /= np.std(data, axis=0) + +############################################################################### +# Learn the dictionary from clean patches +print 'Learning the dictionary... ', +t0 = time() +dico = DictionaryLearningOnline(n_atoms=100, alpha=1e-2, n_iter=500) +V = dico.fit(data).components_ +dt = time() - t0 +print 'done in %.2f.' % dt + +pl.figure(figsize=(4.2, 4)) +for i, comp in enumerate(V[:100]): + pl.subplot(10, 10, i + 1) + pl.imshow(comp.reshape(patch_size), cmap=pl.cm.gray_r, + interpolation='nearest') + pl.xticks(()) + pl.yticks(()) +pl.suptitle('Dictionary learned from Lena patches\n' + + 'Train time %.1fs on %d patches' % (dt, len(data)), + fontsize=16) +pl.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23) + + +def show_with_diff(image, reference, title): + """Helper function to display denoising""" + pl.figure(figsize=(5, 3.3)) + pl.subplot(1, 2, 1) + pl.title('Image') + pl.imshow(image, vmin=0, vmax=1, cmap=pl.cm.gray, interpolation='nearest') + pl.xticks(()) + pl.yticks(()) + pl.subplot(1, 2, 2) + difference = image - reference + + pl.title('Difference (norm: %.2f)' % np.sqrt(np.sum(difference ** 2))) + pl.imshow(difference, vmin=-0.5, vmax=0.5, cmap=pl.cm.PuOr, + interpolation='nearest') + pl.xticks(()) + pl.yticks(()) + pl.suptitle(title, size=16) + pl.subplots_adjust(0.02, 0.02, 0.98, 0.79, 0.02, 0.2) + +############################################################################### +# Display the distorted image +show_with_diff(distorted, lena, 'Distorted image') + +############################################################################### +# Extract noisy patches and reconstruct them using the dictionary +print 'Extracting noisy patches... ' +data = extract_patches_2d(distorted[:, height / 2:], patch_size) +data = data.reshape(data.shape[0], -1) +intercept = np.mean(data, axis=0) +data -= intercept + +transform_algorithms = [ + ('Orthogonal Matching Pursuit\n1 atom', 'omp', + {'transform_n_nonzero_coefs': 1}), + ('Orthogonal Matching Pursuit\n2 atoms', 'omp', + {'transform_n_nonzero_coefs': 2}), + ('Least-angle regression\n5 atoms', 'lars', {'transform_n_nonzero_coefs': 5}), + ('Thresholding\n alpha=0.1', 'threshold', {'transform_alpha': .1})] + +reconstructions = {} +for title, transform_algorithm, kwargs in transform_algorithms: + print title, '... ', + reconstructions[title] = lena.copy() + t0 = time() + dico.set_params(transform_algorithm=transform_algorithm, **kwargs) + code = dico.transform(data) + patches = np.dot(code, V) + + if transform_algorithm == 'threshold': + patches -= patches.min() + patches /= patches.max() + + patches += intercept + patches = patches.reshape(len(data), *patch_size) + if transform_algorithm == 'threshold': + patches -= patches.min() + patches /= patches.max() + reconstructions[title][:, height / 2:] = reconstruct_from_patches_2d( + patches, (width, height / 2)) + dt = time() - t0 + print 'done in %.2f.' 
% dt + show_with_diff(reconstructions[title], lena, + title + ' (time: %.1fs)' % dt) + +pl.show() diff --git a/scikits/learn/decomposition/__init__.py b/scikits/learn/decomposition/__init__.py index 3943c2ed8da19edc7f74144d29dfe640f3b86efb..dae5d9b576d79489365bb33d6b580a1ff9a33e71 100644 --- a/scikits/learn/decomposition/__init__.py +++ b/scikits/learn/decomposition/__init__.py @@ -1,3 +1,3 @@ import warnings warnings.warn('scikits.learn namespace is deprecated, please use sklearn instead') -from sklearn.decomposition import * \ No newline at end of file +from sklearn.decomposition import * diff --git a/sklearn/decomposition/__init__.py b/sklearn/decomposition/__init__.py index ead01092dab7c78da2abef7501ea9edaa7cc3d1d..fbc9e2f33ad1d13bacfe9561452b41417b2703e8 100644 --- a/sklearn/decomposition/__init__.py +++ b/sklearn/decomposition/__init__.py @@ -5,6 +5,7 @@ Matrix decomposition algorithms from .nmf import NMF, ProjectedGradientNMF from .pca import PCA, RandomizedPCA, ProbabilisticPCA from .kernel_pca import KernelPCA -from .sparse_pca import SparsePCA, MiniBatchSparsePCA, dict_learning, \ - dict_learning_online +from .sparse_pca import SparsePCA, MiniBatchSparsePCA from .fastica_ import FastICA, fastica +from .dict_learning import dict_learning, dict_learning_online, \ + DictionaryLearning, DictionaryLearningOnline diff --git a/sklearn/decomposition/dict_learning.py b/sklearn/decomposition/dict_learning.py new file mode 100644 index 0000000000000000000000000000000000000000..03365cde25078b4e587c0accabd0f456d0921a0e --- /dev/null +++ b/sklearn/decomposition/dict_learning.py @@ -0,0 +1,993 @@ +""" Dictionary learning +""" +# Author: Vlad Niculae, Gael Varoquaux, Alexandre Gramfort +# License: BSD + +import time +import sys +import itertools + +from math import sqrt, floor, ceil + +import numpy as np +from scipy import linalg +from numpy.lib.stride_tricks import as_strided + +from ..base import BaseEstimator, TransformerMixin +from ..externals.joblib import Parallel, delayed, cpu_count +from ..utils import check_random_state +from ..utils import gen_even_slices +from ..utils.extmath import fast_svd +from ..linear_model import Lasso, orthogonal_mp_gram, lars_path + + +def sparse_encode(X, Y, gram=None, cov=None, algorithm='lasso_lars', + n_nonzero_coefs=None, alpha=None, + overwrite_gram=False, overwrite_cov=False, init=None): + """Generic sparse coding + + Each column of the result is the solution to a Lasso problem. + + Parameters + ---------- + X: array of shape (n_samples, n_components) + Dictionary against which to optimize the sparse code. + + Y: array of shape (n_samples, n_features) + Data matrix. + + gram: array, shape=(n_components, n_components) + Precomputed Gram matrix, X^T * X + + cov: array, shape=(n_components, n_features) + Precomputed covariance, X^T * Y + + algorithm: {'lasso_lars', 'lasso_cd', 'lars', 'omp', 'threshold'} + lars: uses the least angle regression method (linear_model.lars_path) + lasso_lars: uses Lars to compute the Lasso solution + lasso_cd: uses the coordinate descent method to compute the + Lasso solution (linear_model.Lasso). lasso_lars will be faster if + the estimated components are sparse. + omp: uses orthogonal matching pursuit to estimate the sparse solution + threshold: squashes to zero all coefficients less than alpha from + the projection X.T * Y + + n_nonzero_coefs: int, 0.1 * n_features by default + Number of nonzero coefficients to target in each column of the + solution. 
This is only used by `algorithm='lars'` and `algorithm='omp'` + and is overridden by `alpha` in the `omp` case. + + alpha: float, 1. by default + If `algorithm='lasso_lars'` or `algorithm='lasso_cd'`, `alpha` is the + penalty applied to the L1 norm. + If `algorithm='threhold'`, `alpha` is the absolute value of the + threshold below which coefficients will be squashed to zero. + If `algorithm='omp'`, `alpha` is the tolerance parameter: the value of + the reconstruction error targeted. In this case, it overrides + `n_nonzero_coefs`. + + init: array of shape (n_components, n_features) + Initialization value of the sparse codes. Only used if + `algorithm='lasso_cd'`. + + overwrite_gram: boolean, + Whether to overwrite the precomputed Gram matrix. + + overwrite_cov: boolean, + Whether to overwrite the precomputed covariance matrix. + + Returns + ------- + code: array of shape (n_components, n_features) + The sparse codes + """ + alpha = float(alpha) if alpha is not None else None + X, Y = map(np.asanyarray, (X, Y)) + if Y.ndim == 1: + Y = Y[:, np.newaxis] + n_features = Y.shape[1] + # This will always use Gram + if gram is None: + # I think it's never safe to overwrite Gram when n_features > 1 + # but I'd like to avoid the complicated logic. + # The parameter could be removed in this case. Discuss. + gram = np.dot(X.T, X) + if cov is None and algorithm != 'lasso_cd': + # overwrite_cov is safe + overwrite_cov = True + cov = np.dot(X.T, Y) + + if algorithm == 'lasso_lars': + if alpha is None: + alpha = 1. + try: + new_code = np.empty((X.shape[1], n_features)) + err_mgt = np.seterr(all='ignore') + for k in range(n_features): + # A huge amount of time is spent in this loop. It needs to be + # tight. + _, _, coef_path_ = lars_path(X, Y[:, k], Xy=cov[:, k], + Gram=gram, alpha_min=alpha, + method='lasso') + new_code[:, k] = coef_path_[:, -1] + finally: + np.seterr(**err_mgt) + + elif algorithm == 'lasso_cd': + if alpha is None: + alpha = 1. + new_code = np.empty((X.shape[1], n_features)) + clf = Lasso(alpha=alpha, fit_intercept=False, precompute=gram, + max_iter=1000) + for k in xrange(n_features): + # A huge amount of time is spent in this loop. It needs to be + # tight. + if init is not None: + clf.coef_ = init[:, k] # Init with previous value of Vk + clf.fit(X, Y[:, k]) + new_code[:, k] = clf.coef_ + + elif algorithm == 'lars': + if n_nonzero_coefs is None: + n_nonzero_coefs = n_features / 10 + try: + new_code = np.empty((X.shape[1], n_features)) + err_mgt = np.seterr(all='ignore') + for k in xrange(n_features): + # A huge amount of time is spent in this loop. It needs to be + # tight. + _, _, coef_path_ = lars_path(X, Y[:, k], Xy=cov[:, k], + Gram=gram, method='lar', + max_iter=n_nonzero_coefs) + new_code[:, k] = coef_path_[:, -1] + finally: + np.seterr(**err_mgt) + + elif algorithm == 'threshold': + if alpha is None: + alpha = 1. 
+ new_code = np.sign(cov) * np.maximum(np.abs(cov) - alpha, 0) + + elif algorithm == 'omp': + if n_nonzero_coefs is None and alpha is None: + n_nonzero_coefs = n_features / 10 + norms_squared = np.sum((Y ** 2), axis=0) + new_code = orthogonal_mp_gram(gram, cov, n_nonzero_coefs, alpha, + norms_squared, overwrite_Xy=overwrite_cov + ) + else: + raise NotImplemented('Sparse coding method %s not implemented' % + algorithm) + return new_code + + +def sparse_encode_parallel(X, Y, gram=None, cov=None, algorithm='lasso_lars', + n_nonzero_coefs=None, alpha=None, overwrite_gram=False, + overwrite_cov=False, init=None, n_jobs=1): + """Parallel sparse coding using joblib + + Each column of the result is the solution to a Lasso problem. + + Parameters + ---------- + X: array of shape (n_samples, n_components) + Dictionary against which to optimize the sparse code. + + Y: array of shape (n_samples, n_features) + Data matrix. + + gram: array, shape=(n_components, n_components) + Precomputed Gram matrix, X^T * X + + cov: array, shape=(n_components, n_features) + Precomputed covariance, X^T * Y + + algorithm: {'lasso_lars', 'lasso_cd', 'lars', 'omp', 'threshold'} + lars: uses the least angle regression method (linear_model.lars_path) + lasso_lars: uses Lars to compute the Lasso solution + lasso_cd: uses the coordinate descent method to compute the + Lasso solution (linear_model.Lasso). lasso_lars will be faster if + the estimated components are sparse. + omp: uses orthogonal matching pursuit to estimate the sparse solution + threshold: squashes to zero all coefficients less than alpha from + the projection X.T * Y + + n_nonzero_coefs: int, 0.1 * n_features by default + Number of nonzero coefficients to target in each column of the + solution. This is only used by `algorithm='lars'` and `algorithm='omp'` + and is overridden by `alpha` in the `omp` case. + + alpha: float, 1. by default + If `algorithm='lasso_lars'` or `algorithm='lasso_cd'`, `alpha` is the + penalty applied to the L1 norm. + If `algorithm='threhold'`, `alpha` is the absolute value of the + threshold below which coefficients will be squashed to zero. + If `algorithm='omp'`, `alpha` is the tolerance parameter: the value of + the reconstruction error targeted. In this case, it overrides + `n_nonzero_coefs`. + + init: array of shape (n_components, n_features) + Initialization value of the sparse codes. Only used if + `algorithm='lasso_cd'`. + + overwrite_gram: boolean, + Whether to overwrite the precomputed Gram matrix. + + overwrite_cov: boolean, + Whether to overwrite the precomputed covariance matrix. + + n_jobs: int, + Number of parallel jobs to run. 
+ + Returns + ------- + code: array of shape (n_components, n_features) + The sparse codes + """ + n_samples, n_features = Y.shape + n_components = X.shape[1] + if gram is None: + overwrite_gram = True + gram = np.dot(X.T, X) + if cov is None and algorithm != 'lasso_cd': + overwrite_cov = True + cov = np.dot(X.T, Y) + if n_jobs == 1 or algorithm == 'threshold': + return sparse_encode(X, Y, gram, cov, algorithm, n_nonzero_coefs, + alpha, overwrite_gram, overwrite_cov, init) + code = np.empty((n_components, n_features)) + slices = list(gen_even_slices(n_features, n_jobs)) + code_views = Parallel(n_jobs=n_jobs)( + delayed(sparse_encode)(X, Y[:, this_slice], gram, + cov[:, this_slice], algorithm, + n_nonzero_coefs, alpha, + overwrite_gram, overwrite_cov, + init=init[:, this_slice] if init is not + None else None) + for this_slice in slices) + for this_slice, this_view in zip(slices, code_views): + code[:, this_slice] = this_view + return code + + +def _update_dict(dictionary, Y, code, verbose=False, return_r2=False, + random_state=None): + """Update the dense dictionary factor in place. + + Parameters + ---------- + dictionary: array of shape (n_samples, n_components) + Value of the dictionary at the previous iteration. + + Y: array of shape (n_samples, n_features) + Data matrix. + + code: array of shape (n_components, n_features) + Sparse coding of the data against which to optimize the dictionary. + + verbose: + Degree of output the procedure will print. + + return_r2: bool + Whether to compute and return the residual sum of squares corresponding + to the computed solution. + + random_state: int or RandomState + Pseudo number generator state used for random sampling. + + Returns + ------- + dictionary: array of shape (n_samples, n_components) + Updated dictionary. + + """ + n_atoms = len(code) + n_samples = Y.shape[0] + random_state = check_random_state(random_state) + # Residuals, computed 'in-place' for efficiency + R = -np.dot(dictionary, code) + R += Y + R = np.asfortranarray(R) + ger, = linalg.get_blas_funcs(('ger',), (dictionary, code)) + for k in xrange(n_atoms): + # R <- 1.0 * U_k * V_k^T + R + R = ger(1.0, dictionary[:, k], code[k, :], a=R, overwrite_a=True) + dictionary[:, k] = np.dot(R, code[k, :].T) + # Scale k'th atom + atom_norm_square = np.dot(dictionary[:, k], dictionary[:, k]) + if atom_norm_square < 1e-20: + if verbose == 1: + sys.stdout.write("+") + sys.stdout.flush() + elif verbose: + print "Adding new random atom" + dictionary[:, k] = random_state.randn(n_samples) + # Setting corresponding coefs to 0 + code[k, :] = 0.0 + dictionary[:, k] /= sqrt(np.dot(dictionary[:, k], + dictionary[:, k])) + else: + dictionary[:, k] /= sqrt(atom_norm_square) + # R <- -1.0 * U_k * V_k^T + R + R = ger(-1.0, dictionary[:, k], code[k, :], a=R, overwrite_a=True) + if return_r2: + R **= 2 + # R is fortran-ordered. For numpy version < 1.6, sum does not + # follow the quick striding first, and is thus inefficient on + # fortran ordered data. We take a flat view of the data with no + # striding + R = as_strided(R, shape=(R.size, ), strides=(R.dtype.itemsize,)) + R = np.sum(R) + return dictionary, R + return dictionary + + +def dict_learning(X, n_atoms, alpha, max_iter=100, tol=1e-8, + method='lasso_lars', n_jobs=1, dict_init=None, + code_init=None, callback=None, verbose=False, + random_state=None): + """Solves a dictionary learning matrix factorization problem. 
+ + Finds the best dictionary and the corresponding sparse code for + approximating the data matrix X by solving:: + + (U^*, V^*) = argmin 0.5 || X - U V ||_2^2 + alpha * || U ||_1 + (U,V) + with || V_k ||_2 = 1 for all 0 <= k < n_atoms + + where V is the dictionary and U is the sparse code. + + Parameters + ---------- + X: array of shape (n_samples, n_features) + Data matrix. + + n_atoms: int, + Number of dictionary atoms to extract. + + alpha: int, + Sparsity controlling parameter. + + max_iter: int, + Maximum number of iterations to perform. + + tol: float, + Tolerance for the stopping condition. + + method: {'lasso_lars', 'lasso_cd'} + lasso_lars: uses the least angle regression method + (linear_model.lars_path) + lasso_cd: uses the coordinate descent method to compute the + Lasso solution (linear_model.Lasso). Lars will be faster if + the estimated components are sparse. + + n_jobs: int, + Number of parallel jobs to run, or -1 to autodetect. + + dict_init: array of shape (n_atoms, n_features), + Initial value for the dictionary for warm restart scenarios. + + code_init: array of shape (n_samples, n_atoms), + Initial value for the sparse code for warm restart scenarios. + + callback: + Callable that gets invoked every five iterations. + + verbose: + Degree of output the procedure will print. + + random_state: int or RandomState + Pseudo number generator state used for random sampling. + + Returns + ------- + code: array of shape (n_samples, n_atoms) + The sparse code factor in the matrix factorization. + + dictionary: array of shape (n_atoms, n_features), + The dictionary factor in the matrix factorization. + + errors: array + Vector of errors at each iteration. + + """ + if method not in ('lasso_lars', 'lasso_cd'): + raise ValueError('Coding method not supported as a fit algorithm.') + t0 = time.time() + n_features = X.shape[1] + # Avoid integer division problems + alpha = float(alpha) + random_state = check_random_state(random_state) + + if n_jobs == -1: + n_jobs = cpu_count() + + # Init U and V with SVD of Y + if code_init is not None and code_init is not None: + code = np.array(code_init, order='F') + # Don't copy V, it will happen below + dictionary = dict_init + else: + code, S, dictionary = linalg.svd(X, full_matrices=False) + dictionary = S[:, np.newaxis] * dictionary + r = len(dictionary) + if n_atoms <= r: # True even if n_atoms=None + code = code[:, :n_atoms] + dictionary = dictionary[:n_atoms, :] + else: + code = np.c_[code, np.zeros((len(code), n_atoms - r))] + dictionary = np.r_[dictionary, + np.zeros((n_atoms - r, dictionary.shape[1]))] + + # Fortran-order dict, as we are going to access its row vectors + dictionary = np.array(dictionary, order='F') + + residuals = 0 + + errors = [] + current_cost = np.nan + + if verbose == 1: + print '[dict_learning]', + + for ii in xrange(max_iter): + dt = (time.time() - t0) + if verbose == 1: + sys.stdout.write(".") + sys.stdout.flush() + elif verbose: + print ("Iteration % 3i " + "(elapsed time: % 3is, % 4.1fmn, current cost % 7.3f)" % + (ii, dt, dt / 60, current_cost)) + + # Update code + code = sparse_encode_parallel(dictionary.T, X.T, algorithm=method, + alpha=alpha / n_features, + init=code.T, n_jobs=n_jobs) + code = code.T + # Update dictionary + dictionary, residuals = _update_dict(dictionary.T, X.T, code.T, + verbose=verbose, return_r2=True, + random_state=random_state) + dictionary = dictionary.T + + # Cost function + current_cost = 0.5 * residuals + alpha * np.sum(np.abs(code)) + errors.append(current_cost) + + if ii > 0: + 
dE = errors[-2] - errors[-1] + # assert(dE >= -tol * errors[-1]) + if dE < tol * errors[-1]: + if verbose == 1: + # A line return + print "" + elif verbose: + print "--- Convergence reached after %d iterations" % ii + break + if ii % 5 == 0 and callback is not None: + callback(locals()) + + return code, dictionary, errors + + +def dict_learning_online(X, n_atoms, alpha, n_iter=100, return_code=True, + dict_init=None, callback=None, chunk_size=3, + verbose=False, shuffle=True, n_jobs=1, + method='lasso_lars', iter_offset=0, + random_state=None): + """Solves a dictionary learning matrix factorization problem online. + + Finds the best dictionary and the corresponding sparse code for + approximating the data matrix X by solving: + + (U^*, V^*) = argmin 0.5 || X - U V ||_2^2 + alpha * || U ||_1 + (U,V) + with || V_k ||_2 = 1 for all 0 <= k < n_atoms + + where V is the dictionary and U is the sparse code. This is + accomplished by repeatedly iterating over mini-batches by slicing + the input data. + + Parameters + ---------- + X: array of shape (n_samples, n_features) + data matrix + + n_atoms: int, + number of dictionary atoms to extract + + alpha: int, + sparsity controlling parameter + + n_iter: int, + number of iterations to perform + + return_code: boolean, + whether to also return the code U or just the dictionary V + + dict_init: array of shape (n_atoms, n_features), + initial value for the dictionary for warm restart scenarios + + callback: + callable that gets invoked every five iterations + + chunk_size: int, + the number of samples to take in each batch + + verbose: + degree of output the procedure will print + + shuffle: boolean, + whether to shuffle the data before splitting it in batches + + n_jobs: int, + number of parallel jobs to run, or -1 to autodetect. + + method: {'lasso_lars', 'lasso_cd'} + lasso_lars: uses the least angle regression method + (linear_model.lars_path) + lasso_cd: uses the coordinate descent method to compute the + Lasso solution (linear_model.Lasso). Lars will be faster if + the estimated components are sparse. + + iter_offset: int, default 0 + number of previous iterations completed on the dictionary used for + initialization + + random_state: int or RandomState + Pseudo number generator state used for random sampling. 
+ + Returns + ------- + dictionary: array of shape (n_atoms, n_features), + the solutions to the dictionary learning problem + + code: array of shape (n_samples, n_atoms), + the sparse code (only returned if `return_code=True`) + """ + if method not in ('lasso_lars', 'lasso_cd'): + raise ValueError('Coding method not supported as a fit algorithm.') + t0 = time.time() + n_samples, n_features = X.shape + # Avoid integer division problems + alpha = float(alpha) + random_state = check_random_state(random_state) + + if n_jobs == -1: + n_jobs = cpu_count() + + # Init V with SVD of X + if dict_init is not None: + dictionary = dict_init + else: + _, S, dictionary = fast_svd(X, n_atoms) + dictionary = S[:, np.newaxis] * dictionary + r = len(dictionary) + if n_atoms <= r: + dictionary = dictionary[:n_atoms, :] + else: + dictionary = np.r_[dictionary, + np.zeros((n_atoms - r, dictionary.shape[1]))] + dictionary = np.ascontiguousarray(dictionary.T) + + if verbose == 1: + print '[dict_learning]', + + n_batches = floor(float(len(X)) / chunk_size) + if shuffle: + X_train = X.copy() + random_state.shuffle(X_train) + else: + X_train = X + batches = np.array_split(X_train, n_batches) + batches = itertools.cycle(batches) + + # The covariance of the dictionary + A = np.zeros((n_atoms, n_atoms)) + # The data approximation + B = np.zeros((n_features, n_atoms)) + + for ii, this_X in itertools.izip(xrange(iter_offset, iter_offset + n_iter), + batches): + dt = (time.time() - t0) + if verbose == 1: + sys.stdout.write(".") + sys.stdout.flush() + elif verbose: + if verbose > 10 or ii % ceil(100. / verbose) == 0: + print ("Iteration % 3i (elapsed time: % 3is, % 4.1fmn)" % + (ii, dt, dt / 60)) + + this_code = sparse_encode(dictionary, this_X.T, algorithm=method, + alpha=alpha) + + # Update the auxiliary variables + if ii < chunk_size - 1: + theta = float((ii + 1) * chunk_size) + else: + theta = float(chunk_size ** 2 + ii + 1 - chunk_size) + beta = (theta + 1 - chunk_size) / (theta + 1) + + A *= beta + A += np.dot(this_code, this_code.T) + B *= beta + B += np.dot(this_X.T, this_code.T) + + # Update dictionary + dictionary = _update_dict(dictionary, B, A, verbose=verbose, + random_state=random_state) + # XXX: Can the residuals be of any use? + + # Maybe we need a stopping criteria based on the amount of + # modification in the dictionary + if callback is not None: + callback(locals()) + + if return_code: + if verbose > 1: + print 'Learning code...', + elif verbose == 1: + print '|', + code = sparse_encode_parallel(dictionary, X.T, algorithm=method, + alpha=alpha, n_jobs=n_jobs) + if verbose > 1: + dt = (time.time() - t0) + print 'done (total time: % 3is, % 4.1fmn)' % (dt, dt / 60) + return code.T, dictionary.T + + return dictionary.T + + +class BaseDictionaryLearning(BaseEstimator, TransformerMixin): + """Dictionary learning base class""" + + def __init__(self, n_atoms, transform_algorithm='omp', + transform_n_nonzero_coefs=None, transform_alpha=None, + split_sign=False, n_jobs=1): + self.n_atoms = n_atoms + self.transform_algorithm = transform_algorithm + self.transform_n_nonzero_coefs = transform_n_nonzero_coefs + self.transform_alpha = transform_alpha + self.split_sign = split_sign + self.n_jobs = n_jobs + + def transform(self, X, y=None): + """Encode the data as a sparse combination of the dictionary atoms. + + Coding method is determined by the object parameter + `transform_algorithm`. 
+ + Parameters + ---------- + X: array of shape (n_samples, n_features) + Test data to be transformed, must have the same number of + features as the data used to train the model. + + Returns + ------- + X_new array, shape (n_samples, n_components) + Transformed data + """ + # XXX : kwargs is not documented + X = np.atleast_2d(X) + n_samples, n_features = X.shape + + code = sparse_encode_parallel( + self.components_.T, X.T, algorithm=self.transform_algorithm, + n_nonzero_coefs=self.transform_n_nonzero_coefs, + alpha=self.transform_alpha, n_jobs=self.n_jobs) + code = code.T + + if self.split_sign: + # feature vector is split into a positive and negative side + n_samples, n_features = code.shape + split_code = np.empty((n_samples, 2 * n_features)) + split_code[:, :n_features] = np.maximum(code, 0) + split_code[:, n_features:] = -np.minimum(code, 0) + code = split_code + + return code + + +class DictionaryLearning(BaseDictionaryLearning): + """ Dictionary learning + + Finds a dictionary (a set of atoms) that can best be used to represent data + using a sparse code. + + Solves the optimization problem: + (U^*,V^*) = argmin 0.5 || Y - U V ||_2^2 + alpha * || U ||_1 + (U,V) + with || V_k ||_2 = 1 for all 0 <= k < n_atoms + + Parameters + ---------- + n_atoms: int, + number of dictionary elements to extract + + alpha: int, + sparsity controlling parameter + + max_iter: int, + maximum number of iterations to perform + + tol: float, + tolerance for numerical error + + fit_algorithm: {'lasso_lars', 'lasso_cd'} + lasso_lars: uses the least angle regression method + (linear_model.lars_path) + lasso_cd: uses the coordinate descent method to compute the + Lasso solution (linear_model.Lasso). Lars will be faster if + the estimated components are sparse. + + transform_algorithm: {'lasso_lars', 'lasso_cd', 'lars', 'omp', 'threshold'} + Algorithm used to transform the data + lars: uses the least angle regression method (linear_model.lars_path) + lasso_lars: uses Lars to compute the Lasso solution + lasso_cd: uses the coordinate descent method to compute the + Lasso solution (linear_model.Lasso). lasso_lars will be faster if + the estimated components are sparse. + omp: uses orthogonal matching pursuit to estimate the sparse solution + threshold: squashes to zero all coefficients less than alpha from + the projection X.T * Y + + transform_n_nonzero_coefs: int, 0.1 * n_features by default + Number of nonzero coefficients to target in each column of the + solution. This is only used by `algorithm='lars'` and `algorithm='omp'` + and is overridden by `alpha` in the `omp` case. + + transform_alpha: float, 1. by default + If `algorithm='lasso_lars'` or `algorithm='lasso_cd'`, `alpha` is the + penalty applied to the L1 norm. + If `algorithm='threhold'`, `alpha` is the absolute value of the + threshold below which coefficients will be squashed to zero. + If `algorithm='omp'`, `alpha` is the tolerance parameter: the value of + the reconstruction error targeted. In this case, it overrides + `n_nonzero_coefs`. + + n_jobs: int, + number of parallel jobs to run + + code_init: array of shape (n_samples, n_atoms), + initial value for the code, for warm restart + + dict_init: array of shape (n_atoms, n_features), + initial values for the dictionary, for warm restart + + verbose: + degree of verbosity of the printed output + + random_state: int or RandomState + Pseudo number generator state used for random sampling. 
+ + Attributes + ---------- + components_: array, [n_atoms, n_features] + dictionary atoms extracted from the data + + error_: array + vector of errors at each iteration + + References + ---------- + J. Mairal, F. Bach, J. Ponce, G. Sapiro, 2009: Online dictionary learning + for sparse coding (http://www.di.ens.fr/sierra/pdfs/icml09.pdf) + + + See also + -------- + :class:`sklearn.decomposition.SparsePCA` which solves the transposed + problem, finding sparse components to represent data. + + """ + def __init__(self, n_atoms, alpha=1, max_iter=1000, tol=1e-8, + fit_algorithm='lasso_lars', transform_algorithm='omp', + transform_n_nonzero_coefs=None, transform_alpha=None, + n_jobs=1, code_init=None, dict_init=None, verbose=False, + split_sign=False, random_state=None): + BaseDictionaryLearning.__init__(self, n_atoms, transform_algorithm, + transform_n_nonzero_coefs, transform_alpha, split_sign, + n_jobs) + self.alpha = alpha + self.max_iter = max_iter + self.tol = tol + self.fit_algorithm = fit_algorithm + self.code_init = code_init + self.dict_init = dict_init + self.verbose = verbose + self.random_state = random_state + + def fit(self, X, y=None): + """Fit the model from data in X. + + Parameters + ---------- + X: array-like, shape (n_samples, n_features) + Training vector, where n_samples in the number of samples + and n_features is the number of features. + + Returns + ------- + self: object + Returns the object itself + """ + self.random_state = check_random_state(self.random_state) + X = np.asanyarray(X) + V, U, E = dict_learning(X, self.n_atoms, self.alpha, + tol=self.tol, max_iter=self.max_iter, + method=self.fit_algorithm, + n_jobs=self.n_jobs, + code_init=self.code_init, + dict_init=self.dict_init, + verbose=self.verbose, + random_state=self.random_state) + self.components_ = U + self.error_ = E + return self + + +class DictionaryLearningOnline(BaseDictionaryLearning): + """ Online dictionary learning + + Finds a dictionary (a set of atoms) that can best be used to represent data + using a sparse code. + + Solves the optimization problem: + (U^*,V^*) = argmin 0.5 || Y - U V ||_2^2 + alpha * || U ||_1 + (U,V) + with || V_k ||_2 = 1 for all 0 <= k < n_atoms + + Parameters + ---------- + n_atoms: int, + number of dictionary elements to extract + + alpha: int, + sparsity controlling parameter + + n_iter: int, + total number of iterations to perform + + fit_algorithm: {'lars', 'cd'} + lars: uses the least angle regression method (linear_model.lars_path) + cd: uses the coordinate descent method to compute the + Lasso solution (linear_model.Lasso). Lars will be faster if + the estimated components are sparse. + + transform_algorithm: {'lasso_lars', 'lasso_cd', 'lars', 'omp', 'threshold'} + Algorithm used to transform the data. + lars: uses the least angle regression method (linear_model.lars_path) + lasso_lars: uses Lars to compute the Lasso solution + lasso_cd: uses the coordinate descent method to compute the + Lasso solution (linear_model.Lasso). lasso_lars will be faster if + the estimated components are sparse. + omp: uses orthogonal matching pursuit to estimate the sparse solution + threshold: squashes to zero all coefficients less than alpha from + the projection X.T * Y + + transform_n_nonzero_coefs: int, 0.1 * n_features by default + Number of nonzero coefficients to target in each column of the + solution. This is only used by `algorithm='lars'` and `algorithm='omp'` + and is overridden by `alpha` in the `omp` case. + + transform_alpha: float, 1. 
by default + If `algorithm='lasso_lars'` or `algorithm='lasso_cd'`, `alpha` is the + penalty applied to the L1 norm. + If `algorithm='threhold'`, `alpha` is the absolute value of the + threshold below which coefficients will be squashed to zero. + If `algorithm='omp'`, `alpha` is the tolerance parameter: the value of + the reconstruction error targeted. In this case, it overrides + `n_nonzero_coefs`. + + n_jobs: int, + number of parallel jobs to run + + dict_init: array of shape (n_atoms, n_features), + initial value of the dictionary for warm restart scenarios + + verbose: + degree of verbosity of the printed output + + chunk_size: int, + number of samples in each mini-batch + + shuffle: bool, + whether to shuffle the samples before forming batches + + random_state: int or RandomState + Pseudo number generator state used for random sampling. + + Attributes + ---------- + components_: array, [n_atoms, n_features] + components extracted from the data + + References + ---------- + J. Mairal, F. Bach, J. Ponce, G. Sapiro, 2009: Online dictionary learning + for sparse coding (http://www.di.ens.fr/sierra/pdfs/icml09.pdf) + + + See also + -------- + :class:`sklearn.decomposition.SparsePCA` which solves the transposed + problem, finding sparse components to represent data. + + """ + def __init__(self, n_atoms, alpha=1, n_iter=1000, + fit_algorithm='lasso_lars', n_jobs=1, chunk_size=3, + shuffle=True, dict_init=None, transform_algorithm='omp', + transform_n_nonzero_coefs=None, transform_alpha=None, + verbose=False, split_sign=False, random_state=None): + BaseDictionaryLearning.__init__(self, n_atoms, transform_algorithm, + transform_n_nonzero_coefs, transform_alpha, split_sign, + n_jobs) + self.alpha = alpha + self.n_iter = n_iter + self.fit_algorithm = fit_algorithm + self.dict_init = dict_init + self.verbose = verbose + self.shuffle = shuffle + self.chunk_size = chunk_size + self.split_sign = split_sign + self.random_state = random_state + + def fit(self, X, y=None): + """Fit the model from data in X. + + Parameters + ---------- + X: array-like, shape (n_samples, n_features) + Training vector, where n_samples in the number of samples + and n_features is the number of features. + + Returns + ------- + self : object + Returns the instance itself. + """ + self.random_state = check_random_state(self.random_state) + X = np.asanyarray(X) + U = dict_learning_online(X, self.n_atoms, self.alpha, + n_iter=self.n_iter, return_code=False, + method=self.fit_algorithm, + n_jobs=self.n_jobs, + dict_init=self.dict_init, + chunk_size=self.chunk_size, + shuffle=self.shuffle, verbose=self.verbose, + random_state=self.random_state) + self.components_ = U + return self + + def partial_fit(self, X, y=None, iter_offset=0): + """Updates the model using the data in X as a mini-batch. + + Parameters + ---------- + X: array-like, shape (n_samples, n_features) + Training vector, where n_samples in the number of samples + and n_features is the number of features. + + Returns + ------- + self : object + Returns the instance itself. 
+ """ + self.random_state = check_random_state(self.random_state) + X = np.atleast_2d(X) + if hasattr(self, 'components_'): + dict_init = self.components_ + else: + dict_init = self.dict_init + U = dict_learning_online(X, self.n_atoms, self.alpha, + n_iter=self.n_iter, + method=self.fit_algorithm, + n_jobs=self.n_jobs, dict_init=dict_init, + chunk_size=len(X), shuffle=False, + verbose=self.verbose, return_code=False, + iter_offset=iter_offset, + random_state=self.random_state) + self.components_ = U + return self diff --git a/sklearn/decomposition/sparse_pca.py b/sklearn/decomposition/sparse_pca.py index f6b02dc1d4ac78f35dbd5b4e12c6d3e308d85073..7bcf2b729b34aa853c3c1cdc636e2ea34b966a13 100644 --- a/sklearn/decomposition/sparse_pca.py +++ b/sklearn/decomposition/sparse_pca.py @@ -2,529 +2,12 @@ # Author: Vlad Niculae, Gael Varoquaux, Alexandre Gramfort # License: BSD -import time -import sys - -from math import sqrt, floor, ceil -import itertools - import numpy as np -from numpy.lib.stride_tricks import as_strided -from scipy import linalg from ..utils import check_random_state -from ..utils import gen_even_slices -from ..utils.extmath import fast_svd -from ..linear_model import Lasso, lars_path, ridge_regression -from ..externals.joblib import Parallel, delayed, cpu_count +from ..linear_model import ridge_regression from ..base import BaseEstimator, TransformerMixin - - -def _update_code(dictionary, Y, alpha, code=None, Gram=None, method='lars', - tol=1e-8): - """Update the sparse code factor in the sparse_pca loop. - - Each column of the result is the solution to a Lasso problem. - - Parameters - ---------- - dictionary: array of shape (n_samples, n_components) - Dictionary against which to optimize the sparse code. - - Y: array of shape (n_samples, n_features) - Data matrix. - - alpha: float - Regularization parameter for the Lasso problem. - - code: array of shape (n_components, n_features) - Value of the sparse codes at the previous iteration. - - Gram: array of shape (n_features, n_features) - Precomputed Gram matrix, (Y^T * Y). - - method: {'lars', 'cd'} - lars: uses the least angle regression method (linear_model.lars_path) - cd: uses the coordinate descent method to compute the - Lasso solution (linear_model.Lasso). Lars will be faster if - the estimated components are sparse. - - tol: float - Numerical tolerance for coordinate descent Lasso convergence. - Only used if `method='cd'` - - Returns - ------- - new_code : array of shape (n_components, n_features) - The sparse codes precomputed using this iteration's dictionary - """ - n_features = Y.shape[1] - n_atoms = dictionary.shape[1] - new_code = np.empty((n_atoms, n_features)) - if Gram is None: - Gram = np.dot(dictionary.T, dictionary) - if method == 'lars': - XY = np.dot(dictionary.T, Y) - try: - err_mgt = np.seterr(all='ignore') - for k in range(n_features): - # A huge amount of time is spent in this loop. It needs to be - # tight. - _, _, coef_path_ = lars_path(dictionary, Y[:, k], Xy=XY[:, k], - Gram=Gram, alpha_min=alpha, - method='lasso') - new_code[:, k] = coef_path_[:, -1] - finally: - np.seterr(**err_mgt) - elif method == 'cd': - clf = Lasso(alpha=alpha, fit_intercept=False, precompute=Gram, - max_iter=1000, tol=tol) - for k in range(n_features): - # A huge amount of time is spent in this loop. It needs to be - # tight. 
- if code is not None: - clf.coef_ = code[:, k] # Init with previous value of Vk - clf.fit(dictionary, Y[:, k]) - new_code[:, k] = clf.coef_ - else: - raise NotImplemented("Lasso method %s is not implemented." % method) - return new_code - - -def _update_code_parallel(dictionary, Y, alpha, code=None, Gram=None, - method='lars', n_jobs=1, tol=1e-8): - """Update the sparse factor V in the sparse_pca loop in parallel. - - The computation is spread over all the available cores. - - Parameters - ---------- - dictionary: array of shape (n_samples, n_components) - Dictionary against which to optimize the sparse code. - - Y: array of shape (n_samples, n_features) - Data matrix. - - alpha: float - Regularization parameter for the Lasso problem. - - code: array of shape (n_components, n_features) - Previous iteration of the sparse code. - - Gram: array of shape (n_features, n_features) - Precomputed Gram matrix, (Y^T * Y). - - method: 'lars' | 'cd' - lars: uses the least angle regression method (linear_model.lars_path) - cd: uses the coordinate descent method to compute the - lasso solution (linear_model.Lasso). Lars will be faster if - the components extracted are sparse. - - n_jobs: int - Number of parallel jobs to run. - - tol: float - Numerical tolerance for coordinate descent Lasso convergence. - Only used if `method='cd`. - - """ - n_samples, n_features = Y.shape - n_atoms = dictionary.shape[1] - if Gram is None: - Gram = np.dot(dictionary.T, dictionary) - if n_jobs == 1: - return _update_code(dictionary, Y, alpha, code=code, Gram=Gram, - method=method) - if code is None: - code = np.empty((n_atoms, n_features)) - slices = list(gen_even_slices(n_features, n_jobs)) - code_views = Parallel(n_jobs=n_jobs)( - delayed(_update_code)(dictionary, Y[:, this_slice], - code=code[:, this_slice], alpha=alpha, - Gram=Gram, method=method, tol=tol) - for this_slice in slices) - for this_slice, this_view in zip(slices, code_views): - code[:, this_slice] = this_view - return code - - -def _update_dict(dictionary, Y, code, verbose=False, return_r2=False, - random_state=None): - """Update the dense dictionary factor in place. - - Parameters - ---------- - dictionary: array of shape (n_samples, n_components) - Value of the dictionary at the previous iteration. - - Y: array of shape (n_samples, n_features) - Data matrix. - - code: array of shape (n_components, n_features) - Sparse coding of the data against which to optimize the dictionary. - - verbose: - Degree of output the procedure will print. - - return_r2: bool - Whether to compute and return the residual sum of squares corresponding - to the computed solution. - - random_state: int or RandomState - Pseudo number generator state used for random sampling. - - Returns - ------- - dictionary: array of shape (n_samples, n_components) - Updated dictionary. 
- - """ - n_atoms = len(code) - n_samples = Y.shape[0] - random_state = check_random_state(random_state) - # Residuals, computed 'in-place' for efficiency - R = -np.dot(dictionary, code) - R += Y - R = np.asfortranarray(R) - ger, = linalg.get_blas_funcs(('ger',), (dictionary, code)) - for k in xrange(n_atoms): - # R <- 1.0 * U_k * V_k^T + R - R = ger(1.0, dictionary[:, k], code[k, :], a=R, overwrite_a=True) - dictionary[:, k] = np.dot(R, code[k, :].T) - # Scale k'th atom - atom_norm_square = np.dot(dictionary[:, k], dictionary[:, k]) - if atom_norm_square < 1e-20: - if verbose == 1: - sys.stdout.write("+") - sys.stdout.flush() - elif verbose: - print "Adding new random atom" - dictionary[:, k] = random_state.randn(n_samples) - # Setting corresponding coefs to 0 - code[k, :] = 0.0 - dictionary[:, k] /= sqrt(np.dot(dictionary[:, k], - dictionary[:, k])) - else: - dictionary[:, k] /= sqrt(atom_norm_square) - # R <- -1.0 * U_k * V_k^T + R - R = ger(-1.0, dictionary[:, k], code[k, :], a=R, overwrite_a=True) - if return_r2: - R **= 2 - # R is fortran-ordered. For numpy version < 1.6, sum does not - # follow the quick striding first, and is thus inefficient on - # fortran ordered data. We take a flat view of the data with no - # striding - R = as_strided(R, shape=(R.size, ), strides=(R.dtype.itemsize,)) - R = np.sum(R) - return dictionary, R - return dictionary - - -def dict_learning(X, n_atoms, alpha, max_iter=100, tol=1e-8, method='lars', - n_jobs=1, dict_init=None, code_init=None, callback=None, - verbose=False, random_state=None): - """Solves a dictionary learning matrix factorization problem. - - Finds the best dictionary and the corresponding sparse code for - approximating the data matrix X by solving:: - - (U^*, V^*) = argmin 0.5 || X - U V ||_2^2 + alpha * || U ||_1 - (U,V) - with || V_k ||_2 = 1 for all 0 <= k < n_atoms - - where V is the dictionary and U is the sparse code. - - Parameters - ---------- - X: array of shape (n_samples, n_features) - Data matrix. - - n_atoms: int, - Number of dictionary atoms to extract. - - alpha: int, - Sparsity controlling parameter. - - max_iter: int, - Maximum number of iterations to perform. - - tol: float, - Tolerance for the stopping condition. - - method: {'lars', 'cd'} - lars: uses the least angle regression method (linear_model.lars_path) - cd: uses the coordinate descent method to compute the - Lasso solution (linear_model.Lasso). Lars will be faster if - the estimated components are sparse. - - n_jobs: int, - Number of parallel jobs to run, or -1 to autodetect. - - dict_init: array of shape (n_atoms, n_features), - Initial value for the dictionary for warm restart scenarios. - - code_init: array of shape (n_samples, n_atoms), - Initial value for the sparse code for warm restart scenarios. - - callback: - Callable that gets invoked every five iterations. - - verbose: - Degree of output the procedure will print. - - random_state: int or RandomState - Pseudo number generator state used for random sampling. - - Returns - ------- - code: array of shape (n_samples, n_atoms) - The sparse code factor in the matrix factorization. - - dictionary: array of shape (n_atoms, n_features), - The dictionary factor in the matrix factorization. - - errors: array - Vector of errors at each iteration. 
- - """ - t0 = time.time() - n_features = X.shape[1] - # Avoid integer division problems - alpha = float(alpha) - random_state = check_random_state(random_state) - - if n_jobs == -1: - n_jobs = cpu_count() - - # Init U and V with SVD of Y - if code_init is not None and code_init is not None: - code = np.array(code_init, order='F') - # Don't copy V, it will happen below - dictionary = dict_init - else: - code, S, dictionary = linalg.svd(X, full_matrices=False) - dictionary = S[:, np.newaxis] * dictionary - r = len(dictionary) - if n_atoms <= r: # True even if n_atoms=None - code = code[:, :n_atoms] - dictionary = dictionary[:n_atoms, :] - else: - code = np.c_[code, np.zeros((len(code), n_atoms - r))] - dictionary = np.r_[dictionary, - np.zeros((n_atoms - r, dictionary.shape[1]))] - - # Fortran-order dict, as we are going to access its row vectors - dictionary = np.array(dictionary, order='F') - - residuals = 0 - - errors = [] - current_cost = np.nan - - if verbose == 1: - print '[dict_learning]', - - for ii in xrange(max_iter): - dt = (time.time() - t0) - if verbose == 1: - sys.stdout.write(".") - sys.stdout.flush() - elif verbose: - print ("Iteration % 3i " - "(elapsed time: % 3is, % 4.1fmn, current cost % 7.3f)" % - (ii, dt, dt / 60, current_cost)) - - # Update code - code = _update_code_parallel(dictionary.T, X.T, alpha / n_features, - code.T, method=method, n_jobs=n_jobs) - code = code.T - # Update dictionary - dictionary, residuals = _update_dict(dictionary.T, X.T, code.T, - verbose=verbose, return_r2=True, - random_state=random_state) - dictionary = dictionary.T - - # Cost function - current_cost = 0.5 * residuals + alpha * np.sum(np.abs(code)) - errors.append(current_cost) - - if ii > 0: - dE = errors[-2] - errors[-1] - assert(dE >= -tol * errors[-1]) - if dE < tol * errors[-1]: - if verbose == 1: - # A line return - print "" - elif verbose: - print "--- Convergence reached after %d iterations" % ii - break - if ii % 5 == 0 and callback is not None: - callback(locals()) - - return code, dictionary, errors - - -def dict_learning_online(X, n_atoms, alpha, n_iter=100, return_code=True, - dict_init=None, callback=None, chunk_size=3, - verbose=False, shuffle=True, n_jobs=1, - method='lars', iter_offset=0, random_state=None): - """Solves a dictionary learning matrix factorization problem online. - - Finds the best dictionary and the corresponding sparse code for - approximating the data matrix X by solving: - - (U^*, V^*) = argmin 0.5 || X - U V ||_2^2 + alpha * || U ||_1 - (U,V) - with || V_k ||_2 = 1 for all 0 <= k < n_atoms - - where V is the dictionary and U is the sparse code. This is - accomplished by repeatedly iterating over mini-batches by slicing - the input data. - - Parameters - ---------- - X: array of shape (n_samples, n_features) - data matrix - - n_atoms: int, - number of dictionary atoms to extract - - alpha: int, - sparsity controlling parameter - - n_iter: int, - number of iterations to perform - - return_code: boolean, - whether to also return the code U or just the dictionary V - - dict_init: array of shape (n_atoms, n_features), - initial value for the dictionary for warm restart scenarios - - callback: - callable that gets invoked every five iterations - - chunk_size: int, - the number of samples to take in each batch - - verbose: - degree of output the procedure will print - - shuffle: boolean, - whether to shuffle the data before splitting it in batches - - n_jobs: int, - number of parallel jobs to run, or -1 to autodetect. 
- - method: {'lars', 'cd'} - lars: uses the least angle regression method (linear_model.lars_path) - cd: uses the coordinate descent method to compute the - Lasso solution (linear_model.Lasso). Lars will be faster if - the estimated components are sparse. - - iter_offset: int, default 0 - number of previous iterations completed on the dictionary used for - initialization - - random_state: int or RandomState - Pseudo number generator state used for random sampling. - - Returns - ------- - dictionary: array of shape (n_atoms, n_features), - the solutions to the dictionary learning problem - - code: array of shape (n_samples, n_atoms), - the sparse code (only returned if `return_code=True`) - """ - t0 = time.time() - n_samples, n_features = X.shape - # Avoid integer division problems - alpha = float(alpha) - random_state = check_random_state(random_state) - - if n_jobs == -1: - n_jobs = cpu_count() - - # Init V with SVD of X - if dict_init is not None: - dictionary = dict_init - else: - _, S, dictionary = fast_svd(X, n_atoms) - dictionary = S[:, np.newaxis] * dictionary - r = len(dictionary) - if n_atoms <= r: - dictionary = dictionary[:n_atoms, :] - else: - dictionary = np.r_[dictionary, - np.zeros((n_atoms - r, dictionary.shape[1]))] - dictionary = np.ascontiguousarray(dictionary.T) - - if verbose == 1: - print '[dict_learning]', - - n_batches = floor(float(len(X)) / chunk_size) - if shuffle: - X_train = X.copy() - random_state.shuffle(X_train) - else: - X_train = X - batches = np.array_split(X_train, n_batches) - batches = itertools.cycle(batches) - - # The covariance of the dictionary - A = np.zeros((n_atoms, n_atoms)) - # The data approximation - B = np.zeros((n_features, n_atoms)) - - for ii, this_X in itertools.izip(xrange(iter_offset, iter_offset + n_iter), - batches): - dt = (time.time() - t0) - if verbose == 1: - sys.stdout.write(".") - sys.stdout.flush() - elif verbose: - if verbose > 10 or ii % ceil(100. / verbose) == 0: - print ("Iteration % 3i (elapsed time: % 3is, % 4.1fmn)" % - (ii, dt, dt / 60)) - - this_code = _update_code(dictionary, this_X.T, alpha, method=method) - - # Update the auxiliary variables - if ii < chunk_size - 1: - theta = float((ii + 1) * chunk_size) - else: - theta = float(chunk_size ** 2 + ii + 1 - chunk_size) - beta = (theta + 1 - chunk_size) / (theta + 1) - - A *= beta - A += np.dot(this_code, this_code.T) - B *= beta - B += np.dot(this_X.T, this_code.T) - - # Update dictionary - dictionary = _update_dict(dictionary, B, A, verbose=verbose, - random_state=random_state) - # XXX: Can the residuals be of any use? - - # Maybe we need a stopping criteria based on the amount of - # modification in the dictionary - if callback is not None: - callback(locals()) - - if return_code: - if verbose > 1: - print 'Learning code...', - elif verbose == 1: - print '|', - code = _update_code_parallel(dictionary, X.T, alpha, n_jobs=n_jobs, - method=method) - if verbose > 1: - dt = (time.time() - t0) - print 'done (total time: % 3is, % 4.1fmn)' % (dt, dt / 60) - return code.T, dictionary.T - - return dictionary.T +from .dict_learning import dict_learning, dict_learning_online class SparsePCA(BaseEstimator, TransformerMixin): @@ -553,9 +36,10 @@ class SparsePCA(BaseEstimator, TransformerMixin): tol: float, Tolerance for the stopping condition. 
- method: {'lars', 'cd'} - lars: uses the least angle regression method (linear_model.lars_path) - cd: uses the coordinate descent method to compute the + method: {'lasso_lars', 'lasso_cd'} + lasso_lars: uses the least angle regression method + (linear_model.lars_path) + lasso_cd: uses the coordinate descent method to compute the Lasso solution (linear_model.Lasso). Lars will be faster if the estimated components are sparse. @@ -588,8 +72,8 @@ class SparsePCA(BaseEstimator, TransformerMixin): """ def __init__(self, n_components, alpha=1, ridge_alpha=0.01, max_iter=1000, - tol=1e-8, method='lars', n_jobs=1, U_init=None, V_init=None, - verbose=False, random_state=None): + tol=1e-8, method='lasso_lars', n_jobs=1, U_init=None, + V_init=None, verbose=False, random_state=None): self.n_components = n_components self.alpha = alpha self.ridge_alpha = ridge_alpha @@ -701,9 +185,10 @@ class MiniBatchSparsePCA(SparsePCA): n_jobs: int, number of parallel jobs to run, or -1 to autodetect. - method: {'lars', 'cd'} - lars: uses the least angle regression method (linear_model.lars_path) - cd: uses the coordinate descent method to compute the + method: {'lasso_lars', 'lasso_cd'} + lasso_lars: uses the least angle regression method + (linear_model.lars_path) + lasso_cd: uses the coordinate descent method to compute the Lasso solution (linear_model.Lasso). Lars will be faster if the estimated components are sparse. @@ -713,7 +198,7 @@ class MiniBatchSparsePCA(SparsePCA): """ def __init__(self, n_components, alpha=1, ridge_alpha=0.01, n_iter=100, callback=None, chunk_size=3, verbose=False, shuffle=True, - n_jobs=1, method='lars', random_state=None): + n_jobs=1, method='lasso_lars', random_state=None): self.n_components = n_components self.alpha = alpha self.ridge_alpha = ridge_alpha diff --git a/sklearn/decomposition/tests/test_dict_learning.py b/sklearn/decomposition/tests/test_dict_learning.py new file mode 100644 index 0000000000000000000000000000000000000000..fcde71495e44ed6a8f2c52b31f873e7591255b59 --- /dev/null +++ b/sklearn/decomposition/tests/test_dict_learning.py @@ -0,0 +1,122 @@ +import numpy as np +from numpy.testing import assert_array_almost_equal, assert_array_equal, \ + assert_equal + +from .. 
import DictionaryLearning, DictionaryLearningOnline, \ + dict_learning_online +from ..dict_learning import sparse_encode, sparse_encode_parallel + +rng = np.random.RandomState(0) +n_samples, n_features = 10, 8 +X = rng.randn(n_samples, n_features) + + +def test_dict_learning_shapes(): + n_atoms = 5 + dico = DictionaryLearning(n_atoms).fit(X) + assert dico.components_.shape == (n_atoms, n_features) + + +def test_dict_learning_overcomplete(): + n_atoms = 12 + X = rng.randn(n_samples, n_features) + dico = DictionaryLearning(n_atoms).fit(X) + assert dico.components_.shape == (n_atoms, n_features) + + +def test_dict_learning_reconstruction(): + n_atoms = 12 + dico = DictionaryLearning(n_atoms, transform_algorithm='omp', + transform_alpha=0.001, random_state=0) + code = dico.fit(X).transform(X) + assert_array_almost_equal(np.dot(code, dico.components_), X) + + dico.set_params(transform_algorithm='lasso_lars') + code = dico.transform(X) + assert_array_almost_equal(np.dot(code, dico.components_), X, decimal=2) + + dico.set_params(transform_algorithm='lars') + code = dico.transform(X) + assert_array_almost_equal(np.dot(code, dico.components_), X, decimal=2) + + +def test_dict_learning_nonzero_coefs(): + n_atoms = 4 + dico = DictionaryLearning(n_atoms, transform_algorithm='lars', + transform_n_nonzero_coefs=3, random_state=0) + code = dico.fit(X).transform(X[0]) + assert len(np.flatnonzero(code)) == 3 + + dico.set_params(transform_algorithm='omp') + code = dico.transform(X[0]) + assert len(np.flatnonzero(code)) == 3 + + +def test_dict_learning_split(): + n_atoms = 5 + dico = DictionaryLearning(n_atoms, transform_algorithm='threshold') + code = dico.fit(X).transform(X) + dico.split_sign = True + split_code = dico.transform(X) + + assert_array_equal(split_code[:, :n_atoms] - split_code[:, n_atoms:], code) + + +def test_dict_learning_online_shapes(): + rng = np.random.RandomState(0) + X = rng.randn(12, 10) + dictionaryT, codeT = dict_learning_online(X.T, n_atoms=8, alpha=1, + random_state=rng) + assert_equal(codeT.shape, (8, 12)) + assert_equal(dictionaryT.shape, (10, 8)) + assert_equal(np.dot(codeT.T, dictionaryT.T).shape, X.shape) + + +def test_dict_learning_online_estimator_shapes(): + n_atoms = 5 + dico = DictionaryLearningOnline(n_atoms, n_iter=20).fit(X) + assert dico.components_.shape == (n_atoms, n_features) + + +def test_dict_learning_online_overcomplete(): + n_atoms = 12 + dico = DictionaryLearningOnline(n_atoms, n_iter=20).fit(X) + assert dico.components_.shape == (n_atoms, n_features) + + +def test_dict_learning_online_initialization(): + n_atoms = 12 + V = rng.randn(n_atoms, n_features) + dico = DictionaryLearningOnline(n_atoms, n_iter=0, dict_init=V).fit(X) + assert_array_equal(dico.components_, V) + + +def test_dict_learning_online_partial_fit(): + n_atoms = 12 + V = rng.randn(n_atoms, n_features) # random init + rng1 = np.random.RandomState(0) + rng2 = np.random.RandomState(0) + dico1 = DictionaryLearningOnline(n_atoms, n_iter=10, chunk_size=1, + shuffle=False, dict_init=V, + transform_algorithm='threshold', + random_state=rng1).fit(X) + dico2 = DictionaryLearningOnline(n_atoms, n_iter=1, dict_init=V, + transform_algorithm='threshold', + random_state=rng2) + for ii, sample in enumerate(X): + dico2.partial_fit(sample, iter_offset=ii * dico2.n_iter) + + assert_array_equal(dico1.components_, dico2.components_) + + +def test_sparse_code(): + rng = np.random.RandomState(0) + dictionary = rng.randn(10, 3) + real_code = np.zeros((3, 5)) + real_code.ravel()[rng.randint(15, size=6)] = 1.0 + Y = 
np.dot(dictionary, real_code) + est_code_1 = sparse_encode(dictionary, Y, alpha=1.0) + est_code_2 = sparse_encode_parallel(dictionary, Y, alpha=1.0) + assert_equal(est_code_1.shape, real_code.shape) + assert_equal(est_code_1, est_code_2) + assert_equal(est_code_1.nonzero(), real_code.nonzero()) diff --git a/sklearn/decomposition/tests/test_sparse_pca.py b/sklearn/decomposition/tests/test_sparse_pca.py index f7da34a75ae9c682f5bc3f677fa1ed89d0550065..e8e3cda667d0117a8c2b9e74f5dbdfa613c984b0 100644 --- a/sklearn/decomposition/tests/test_sparse_pca.py +++ b/sklearn/decomposition/tests/test_sparse_pca.py @@ -6,8 +6,7 @@ import sys import numpy as np from numpy.testing import assert_array_almost_equal, assert_equal -from .. import SparsePCA, MiniBatchSparsePCA, dict_learning_online -from ..sparse_pca import _update_code, _update_code_parallel +from .. import SparsePCA, MiniBatchSparsePCA from ...utils import check_random_state @@ -53,7 +52,8 @@ def test_correct_shapes(): def test_fit_transform(): rng = np.random.RandomState(0) Y, _, _ = generate_toy_data(3, 10, (8, 8), random_state=rng) # wide array - spca_lars = SparsePCA(n_components=3, method='lars', random_state=rng) + spca_lars = SparsePCA(n_components=3, method='lasso_lars', + random_state=rng) spca_lars.fit(Y) U1 = spca_lars.transform(Y) # Test multiple CPUs @@ -71,7 +71,7 @@ def test_fit_transform(): U2 = spca.transform(Y) assert_array_almost_equal(U1, U2) # Test that CD gives similar results - spca_lasso = SparsePCA(n_components=3, method='cd', random_state=rng) + spca_lasso = SparsePCA(n_components=3, method='lasso_cd', random_state=rng) spca_lasso.fit(Y) assert_array_almost_equal(spca_lasso.components_, spca_lars.components_) @@ -79,26 +79,14 @@ def test_fit_transform(): def test_fit_transform_tall(): rng = np.random.RandomState(0) Y, _, _ = generate_toy_data(3, 65, (8, 8), random_state=rng) # tall array - spca_lars = SparsePCA(n_components=3, method='lars', random_state=rng) + spca_lars = SparsePCA(n_components=3, method='lasso_lars', + random_state=rng) U1 = spca_lars.fit_transform(Y) - spca_lasso = SparsePCA(n_components=3, method='cd', random_state=rng) + spca_lasso = SparsePCA(n_components=3, method='lasso_cd', random_state=rng) U2 = spca_lasso.fit(Y).transform(Y) assert_array_almost_equal(U1, U2) -def test_sparse_code(): - rng = np.random.RandomState(0) - dictionary = rng.randn(10, 3) - real_code = np.zeros((3, 5)) - real_code.ravel()[rng.randint(15, size=6)] = 1.0 - Y = np.dot(dictionary, real_code) - est_code_1 = _update_code(dictionary, Y, alpha=1.0) - est_code_2 = _update_code_parallel(dictionary, Y, alpha=1.0) - assert_equal(est_code_1.shape, real_code.shape) - assert_equal(est_code_1, est_code_2) - assert_equal(est_code_1.nonzero(), real_code.nonzero()) - - def test_initialization(): rng = np.random.RandomState(0) U_init = rng.randn(5, 3) @@ -109,16 +97,6 @@ def test_initialization(): assert_equal(model.components_, V_init) -def test_dict_learning_online_shapes(): - rng = np.random.RandomState(0) - X = rng.randn(12, 10) - dictionaryT, codeT = dict_learning_online(X.T, n_atoms=8, alpha=1, - random_state=rng) - assert_equal(codeT.shape, (8, 12)) - assert_equal(dictionaryT.shape, (10, 8)) - assert_equal(np.dot(codeT.T, dictionaryT.T).shape, X.shape) - - def test_mini_batch_correct_shapes(): rng = np.random.RandomState(0) X = rng.randn(12, 10) @@ -153,6 +131,6 @@ def test_mini_batch_fit_transform(): random_state=rng).fit(Y).transform(Y) assert_array_almost_equal(U1, U2) # Test that CD gives similar results - spca_lasso = 
MiniBatchSparsePCA(n_components=3, method='cd', + spca_lasso = MiniBatchSparsePCA(n_components=3, method='lasso_cd', random_state=rng).fit(Y) assert_array_almost_equal(spca_lasso.components_, spca_lars.components_) diff --git a/sklearn/feature_extraction/image.py b/sklearn/feature_extraction/image.py index 8ec6ab0b8cb15abf6cfcc749bd0a982cae378fc4..c8ed4277c5f076a0c0f0ea61279fb1ee484e8b8f 100644 --- a/sklearn/feature_extraction/image.py +++ b/sklearn/feature_extraction/image.py @@ -360,7 +360,6 @@ class PatchExtractor(BaseEstimator): `n_patches` is either `n_samples * max_patches` or the total number of patches that can be extracted. - """ self.random_state = check_random_state(self.random_state) n_images, i_h, i_w = X.shape[:3] diff --git a/sklearn/linear_model/omp.py b/sklearn/linear_model/omp.py index 3c6e7bda6d6031b86b81831024f71109773b26b2..81b4f38ed8582e6c71e22747dab6d0c1f62d9fa2 100644 --- a/sklearn/linear_model/omp.py +++ b/sklearn/linear_model/omp.py @@ -20,7 +20,7 @@ dependence in the dictionary. The requested precision might not have been met. """ -def _cholesky_omp(X, y, n_nonzero_coefs, eps=None, overwrite_X=False): +def _cholesky_omp(X, y, n_nonzero_coefs, tol=None, overwrite_X=False): """Orthogonal Matching Pursuit step using the Cholesky decomposition. Parameters: @@ -34,7 +34,7 @@ def _cholesky_omp(X, y, n_nonzero_coefs, eps=None, overwrite_X=False): n_nonzero_coefs: int Targeted number of non-zero elements - eps: float + tol: float Targeted squared error, if not None overrides n_nonzero_coefs. overwrite_X: bool, @@ -66,7 +66,7 @@ def _cholesky_omp(X, y, n_nonzero_coefs, eps=None, overwrite_X=False): n_active = 0 idx = [] - max_features = X.shape[1] if eps is not None else n_nonzero_coefs + max_features = X.shape[1] if tol is not None else n_nonzero_coefs L = np.empty((max_features, max_features), dtype=X.dtype) L[0, 0] = 1. @@ -94,7 +94,7 @@ def _cholesky_omp(X, y, n_nonzero_coefs, eps=None, overwrite_X=False): overwrite_b=False) residual = y - np.dot(X[:, :n_active], gamma) - if eps is not None and nrm2(residual) ** 2 <= eps: + if tol is not None and nrm2(residual) ** 2 <= tol: break elif n_active == max_features: break @@ -102,7 +102,7 @@ def _cholesky_omp(X, y, n_nonzero_coefs, eps=None, overwrite_X=False): return gamma, idx -def _gram_omp(Gram, Xy, n_nonzero_coefs, eps_0=None, eps=None, +def _gram_omp(Gram, Xy, n_nonzero_coefs, tol_0=None, tol=None, overwrite_gram=False, overwrite_Xy=False): """Orthogonal Matching Pursuit step on a precomputed Gram matrix. @@ -119,10 +119,10 @@ def _gram_omp(Gram, Xy, n_nonzero_coefs, eps_0=None, eps=None, n_nonzero_coefs: int Targeted number of non-zero elements - eps_0: float - Squared norm of y, required if eps is not None. + tol_0: float + Squared norm of y, required if tol is not None. - eps: float + tol: float Targeted squared error, if not None overrides n_nonzero_coefs. overwrite_gram: bool, @@ -157,11 +157,11 @@ def _gram_omp(Gram, Xy, n_nonzero_coefs, eps_0=None, eps=None, idx = [] alpha = Xy - eps_curr = eps_0 + tol_curr = tol_0 delta = 0 n_active = 0 - max_features = len(Gram) if eps is not None else n_nonzero_coefs + max_features = len(Gram) if tol is not None else n_nonzero_coefs L = np.empty((max_features, max_features), dtype=Gram.dtype) L[0, 0] = 1. 
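# Illustrative sketch, separate from the diff hunks around it: the renamed
# `tol` parameter of the low-level OMP solvers above bounds the *squared*
# residual norm, which is exactly what test_tol further down in this patch
# exercises at the user level. The toy data here is made up for illustration.
import numpy as np
from sklearn.linear_model.omp import orthogonal_mp

rng = np.random.RandomState(0)
X = rng.randn(30, 20)
y = np.dot(X, rng.randn(20))

# Request the sparsest solution whose squared residual is at most 0.5.
gamma = orthogonal_mp(X, y, tol=0.5)
assert np.sum((y - np.dot(X, gamma)) ** 2) <= 0.5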
@@ -190,11 +190,11 @@ def _gram_omp(Gram, Xy, n_nonzero_coefs, eps_0=None, eps=None, beta = np.dot(Gram[:, :n_active], gamma) alpha = Xy - beta - if eps is not None: - eps_curr += delta + if tol is not None: + tol_curr += delta delta = np.inner(gamma, beta[:n_active]) - eps_curr -= delta - if eps_curr <= eps: + tol_curr -= delta + if tol_curr <= tol: break elif n_active == max_features: break @@ -202,7 +202,7 @@ def _gram_omp(Gram, Xy, n_nonzero_coefs, eps_0=None, eps=None, return gamma, idx -def orthogonal_mp(X, y, n_nonzero_coefs=None, eps=None, precompute_gram=False, +def orthogonal_mp(X, y, n_nonzero_coefs=None, tol=None, precompute_gram=False, overwrite_X=False): """Orthogonal Matching Pursuit (OMP) @@ -213,8 +213,8 @@ def orthogonal_mp(X, y, n_nonzero_coefs=None, eps=None, precompute_gram=False, `n_nonzero_coefs`: argmin ||y - X\gamma||^2 subject to ||\gamma||_0 <= n_{nonzero coefs} - When parametrized by error using the parameter `eps`: - argmin ||\gamma||_0 subject to ||y - X\gamma||^2 <= \epsilon + When parametrized by error using the parameter `tol`: + argmin ||\gamma||_0 subject to ||y - X\gamma||^2 <= tol Parameters ---------- @@ -228,7 +228,7 @@ def orthogonal_mp(X, y, n_nonzero_coefs=None, eps=None, precompute_gram=False, Desired number of non-zero entries in the solution. If None (by default) this value is set to 10% of n_features. - eps: float + tol: float Maximum norm of the residual. If not None, overrides n_nonzero_coefs. precompute_gram: {True, False, 'auto'}, @@ -274,13 +274,13 @@ def orthogonal_mp(X, y, n_nonzero_coefs=None, eps=None, precompute_gram=False, X = np.asfortranarray(X) if y.shape[1] > 1: # subsequent targets will be affected overwrite_X = False - if n_nonzero_coefs == None and eps == None: + if n_nonzero_coefs == None and tol == None: n_nonzero_coefs = int(0.1 * X.shape[1]) - if eps is not None and eps < 0: + if tol is not None and tol < 0: raise ValueError("Epsilon cannot be negative") - if eps is None and n_nonzero_coefs <= 0: + if tol is None and n_nonzero_coefs <= 0: raise ValueError("The number of atoms must be positive") - if eps is None and n_nonzero_coefs > X.shape[1]: + if tol is None and n_nonzero_coefs > X.shape[1]: raise ValueError("The number of atoms cannot be more than the number \ of features") if precompute_gram == 'auto': @@ -289,23 +289,23 @@ def orthogonal_mp(X, y, n_nonzero_coefs=None, eps=None, precompute_gram=False, G = np.dot(X.T, X) G = np.asfortranarray(G) Xy = np.dot(X.T, y) - if eps is not None: + if tol is not None: norms_squared = np.sum((y ** 2), axis=0) else: norms_squared = None - return orthogonal_mp_gram(G, Xy, n_nonzero_coefs, eps, norms_squared, + return orthogonal_mp_gram(G, Xy, n_nonzero_coefs, tol, norms_squared, overwrite_gram=overwrite_X, overwrite_Xy=True) coef = np.zeros((X.shape[1], y.shape[1])) for k in xrange(y.shape[1]): - x, idx = _cholesky_omp(X, y[:, k], n_nonzero_coefs, eps, + x, idx = _cholesky_omp(X, y[:, k], n_nonzero_coefs, tol, overwrite_X=overwrite_X) coef[idx, k] = x return np.squeeze(coef) -def orthogonal_mp_gram(Gram, Xy, n_nonzero_coefs=None, eps=None, +def orthogonal_mp_gram(Gram, Xy, n_nonzero_coefs=None, tol=None, norms_squared=None, overwrite_gram=False, overwrite_Xy=False): """Gram Orthogonal Matching Pursuit (OMP) @@ -325,11 +325,11 @@ def orthogonal_mp_gram(Gram, Xy, n_nonzero_coefs=None, eps=None, Desired number of non-zero entries in the solution. If None (by default) this value is set to 10% of n_features. - eps: float + tol: float Maximum norm of the residual. 
If not None, overrides n_nonzero_coefs. norms_squared: array-like, shape = (n_targets,) - Squared L2 norms of the lines of y. Required if eps is not None. + Squared L2 norms of the lines of y. Required if tol is not None. overwrite_gram: bool, Whether the gram matrix can be overwritten by the algorithm. This @@ -366,25 +366,25 @@ def orthogonal_mp_gram(Gram, Xy, n_nonzero_coefs=None, eps=None, Gram, Xy = map(np.asanyarray, (Gram, Xy)) if Xy.ndim == 1: Xy = Xy[:, np.newaxis] - if eps is not None: + if tol is not None: norms_squared = [norms_squared] - if n_nonzero_coefs == None and eps is None: + if n_nonzero_coefs == None and tol is None: n_nonzero_coefs = int(0.1 * len(Gram)) - if eps is not None and norms_squared == None: + if tol is not None and norms_squared == None: raise ValueError('Gram OMP needs the precomputed norms in order \ to evaluate the error sum of squares.') - if eps is not None and eps < 0: + if tol is not None and tol < 0: raise ValueError("Epsilon cennot be negative") - if eps is None and n_nonzero_coefs <= 0: + if tol is None and n_nonzero_coefs <= 0: raise ValueError("The number of atoms must be positive") - if eps is None and n_nonzero_coefs > len(Gram): + if tol is None and n_nonzero_coefs > len(Gram): raise ValueError("The number of atoms cannot be more than the number \ of features") coef = np.zeros((len(Gram), Xy.shape[1])) for k in range(Xy.shape[1]): x, idx = _gram_omp(Gram, Xy[:, k], n_nonzero_coefs, - norms_squared[k] if eps is not None else None, eps, + norms_squared[k] if tol is not None else None, tol, overwrite_gram=overwrite_gram, overwrite_Xy=overwrite_Xy) coef[idx, k] = x @@ -400,7 +400,7 @@ class OrthogonalMatchingPursuit(LinearModel): Desired number of non-zero entries in the solution. If None (by default) this value is set to 10% of n_features. - eps: float, optional + tol: float, optional Maximum norm of the residual. If not None, overrides n_nonzero_coefs. fit_intercept: boolean, optional @@ -462,17 +462,16 @@ class OrthogonalMatchingPursuit(LinearModel): """ def __init__(self, overwrite_X=False, overwrite_gram=False, - overwrite_Xy=False, n_nonzero_coefs=None, eps=None, + overwrite_Xy=False, n_nonzero_coefs=None, tol=None, fit_intercept=True, normalize=True, precompute_gram=False): self.n_nonzero_coefs = n_nonzero_coefs - self.eps = eps + self.tol = tol self.fit_intercept = fit_intercept self.normalize = normalize self.precompute_gram = precompute_gram self.overwrite_gram = overwrite_gram self.overwrite_Xy = overwrite_Xy self.overwrite_X = overwrite_X - self.eps = eps def fit(self, X, y, Gram=None, Xy=None): """Fit the model using X, y as training data. 
@@ -507,7 +506,7 @@ class OrthogonalMatchingPursuit(LinearModel):
                                                         self.fit_intercept,
                                                         self.normalize)

-        if self.n_nonzero_coefs == None and self.eps is None:
+        if self.n_nonzero_coefs == None and self.tol is None:
             self.n_nonzero_coefs = int(0.1 * n_features)

         if Gram is not None:
@@ -538,15 +537,15 @@ class OrthogonalMatchingPursuit(LinearModel):
                 Gram /= X_std
                 Gram /= X_std[:, np.newaxis]

-            norms_sq = np.sum(y ** 2, axis=0) if self.eps is not None else None
+            norms_sq = np.sum(y ** 2, axis=0) if self.tol is not None else None
             self.coef_ = orthogonal_mp_gram(Gram, Xy, self.n_nonzero_coefs,
-                                            self.eps, norms_sq,
+                                            self.tol, norms_sq,
                                             overwrite_gram, True).T
         else:
             precompute_gram = self.precompute_gram
             if precompute_gram == 'auto':
                 precompute_gram = X.shape[0] > X.shape[1]
-            self.coef_ = orthogonal_mp(X, y, self.n_nonzero_coefs, self.eps,
+            self.coef_ = orthogonal_mp(X, y, self.n_nonzero_coefs, self.tol,
                                        precompute_gram=self.precompute_gram,
                                        overwrite_X=self.overwrite_X).T

diff --git a/sklearn/linear_model/tests/test_omp.py b/sklearn/linear_model/tests/test_omp.py
index 44d73ef7e48d74c516c95cc27c4f3e3deee68dda..71f22b78a334c4ee363ba6d46d8c536396f2663b 100644
--- a/sklearn/linear_model/tests/test_omp.py
+++ b/sklearn/linear_model/tests/test_omp.py
@@ -39,12 +39,12 @@ def test_n_nonzero_coefs():
                                        precompute_gram=True)) <= 5


-def test_eps():
-    eps = 0.5
-    gamma = orthogonal_mp(X, y[:, 0], eps=eps)
-    gamma_gram = orthogonal_mp(X, y[:, 0], eps=eps, precompute_gram=True)
-    assert np.sum((y[:, 0] - np.dot(X, gamma)) ** 2) <= eps
-    assert np.sum((y[:, 0] - np.dot(X, gamma_gram)) ** 2) <= eps
+def test_tol():
+    tol = 0.5
+    gamma = orthogonal_mp(X, y[:, 0], tol=tol)
+    gamma_gram = orthogonal_mp(X, y[:, 0], tol=tol, precompute_gram=True)
+    assert np.sum((y[:, 0] - np.dot(X, gamma)) ** 2) <= tol
+    assert np.sum((y[:, 0] - np.dot(X, gamma_gram)) ** 2) <= tol


 def test_with_without_gram():
@@ -53,32 +53,32 @@ def test_with_without_gram():
         orthogonal_mp(X, y, n_nonzero_coefs=5, precompute_gram=True))


-def test_with_without_gram_eps():
+def test_with_without_gram_tol():
     assert_array_almost_equal(
-        orthogonal_mp(X, y, eps=1.),
-        orthogonal_mp(X, y, eps=1., precompute_gram=True))
+        orthogonal_mp(X, y, tol=1.),
+        orthogonal_mp(X, y, tol=1., precompute_gram=True))


 def test_unreachable_accuracy():
     with warnings.catch_warnings(record=True) as w:
         warnings.simplefilter('always')
         assert_array_almost_equal(
-            orthogonal_mp(X, y, eps=0),
+            orthogonal_mp(X, y, tol=0),
             orthogonal_mp(X, y, n_nonzero_coefs=n_features))

         assert_array_almost_equal(
-            orthogonal_mp(X, y, eps=0, precompute_gram=True),
+            orthogonal_mp(X, y, tol=0, precompute_gram=True),
             orthogonal_mp(X, y, precompute_gram=True,
                           n_nonzero_coefs=n_features))
         assert len(w) > 0  # warnings should be raised


 def test_bad_input():
-    assert_raises(ValueError, orthogonal_mp, X, y, eps=-1)
+    assert_raises(ValueError, orthogonal_mp, X, y, tol=-1)
     assert_raises(ValueError, orthogonal_mp, X, y, n_nonzero_coefs=-1)
     assert_raises(ValueError, orthogonal_mp, X, y,
                   n_nonzero_coefs=n_features + 1)
-    assert_raises(ValueError, orthogonal_mp_gram, G, Xy, eps=-1)
+    assert_raises(ValueError, orthogonal_mp_gram, G, Xy, tol=-1)
     assert_raises(ValueError, orthogonal_mp_gram, G, Xy, n_nonzero_coefs=-1)
     assert_raises(ValueError, orthogonal_mp_gram, G, Xy,
                   n_nonzero_coefs=n_features + 1)
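# Illustrative sketch of the new DictionaryLearning estimator, mirroring
# test_dict_learning_reconstruction in this patch. The import path follows
# the classes.rst entry and the package-level import used by the tests;
# the toy data is made up.
import numpy as np
from sklearn.decomposition import DictionaryLearning

rng = np.random.RandomState(0)
X = rng.randn(10, 8)                        # (n_samples, n_features)

dico = DictionaryLearning(12, transform_algorithm='omp',
                          transform_alpha=0.001, random_state=0)
code = dico.fit(X).transform(X)             # sparse codes, shape (10, 12)
X_hat = np.dot(code, dico.components_)      # approximate reconstruction of X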
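# Illustrative sketch of the online variant and its partial_fit path,
# following test_dict_learning_online_estimator_shapes and
# test_dict_learning_online_partial_fit; sizes and data are made up.
import numpy as np
from sklearn.decomposition import DictionaryLearningOnline

rng = np.random.RandomState(0)
X = rng.randn(10, 8)

# One-shot fit: cycle over mini-batches for a fixed number of iterations.
dico = DictionaryLearningOnline(12, n_iter=20, chunk_size=3,
                                random_state=0).fit(X)
assert dico.components_.shape == (12, 8)    # (n_atoms, n_features)

# Streaming updates: feed one mini-batch (here a single sample) at a time.
dico2 = DictionaryLearningOnline(12, n_iter=1, random_state=0)
for ii, sample in enumerate(X):
    dico2.partial_fit(sample, iter_offset=ii * dico2.n_iter)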