diff --git a/benchmarks/bench_glmnet.py b/benchmarks/bench_glmnet.py
index 8f3bc35b03d8b46d74689c616587dc6f8d0be136..0576300d3c6a8982eda8dbb9382d198fed63ecf1 100644
--- a/benchmarks/bench_glmnet.py
+++ b/benchmarks/bench_glmnet.py
@@ -19,8 +19,8 @@ In both cases, only 10% of the features are informative.
 import numpy as np
 import gc
 from time import time
+from scikits.learn.datasets.samples_generator import make_regression_dataset
 
-# alpha = 1.0
 alpha = 0.1
 # alpha = 0.01
 
@@ -29,34 +29,6 @@ def rmse(a, b):
     return np.sqrt(np.mean((a - b) ** 2))
 
 
-def make_data(n_samples=100, n_tests=100, n_features=100, k=10,
-              noise=0.1, seed=0):
-
-    # deterministic test
-    np.random.seed(seed)
-
-    # generate random input set
-    X = np.random.randn(n_samples, n_features)
-    X_test = np.random.randn(n_tests, n_features)
-
-    # generate a ground truth model with only the first 10 features being non
-    # zeros (the other features are not correlated to Y and should be ignored by
-    # the L1 regularizer)
-    coef_ = np.random.randn(n_features)
-    coef_[k:] = 0.0
-
-    # generate the ground truth Y from the reference model and X
-    Y = np.dot(X, coef_)
-    if noise > 0.0:
-        Y += np.random.normal(scale=noise, size=Y.shape)
-
-    Y_test = np.dot(X_test, coef_)
-    if noise > 0.0:
-        Y_test += np.random.normal(scale=noise, size=Y_test.shape)
-
-    return X, Y, X_test, Y_test, coef_
-
-
 def bench(factory, X, Y, X_test, Y_test, ref_coef):
     gc.collect()
 
@@ -66,7 +38,7 @@ def bench(factory, X, Y, X_test, Y_test, ref_coef):
     delta = (time() - tstart)
     # stop time
 
-    print "duration: %fms" % (delta * 1000)
+    print "duration: %0.3fs" % delta
     print "rmse: %f" % rmse(Y_test, clf.predict(X_test))
     print "mean coef abs diff: %f" % abs(ref_coef - clf.coef_.ravel()).mean()
     return delta
@@ -74,7 +46,7 @@ def bench(factory, X, Y, X_test, Y_test, ref_coef):
 
 if __name__ == '__main__':
     from glmnet.elastic_net import Lasso as GlmnetLasso
-    from scikits.learn.glm import Lasso as ScikitLasso
+    from scikits.learn.linear_model import Lasso as ScikitLasso
 
     # Delayed import of pylab
     import pylab as pl
@@ -83,20 +55,20 @@ if __name__ == '__main__':
     n = 20
     step = 500
     n_features = 1000
-    k = n_features / 10
-    n_tests = 1000
+    n_informative = n_features / 10
+    n_test_samples = 1000
     for i in range(1, n + 1):
         print '=================='
         print 'Iteration %s of %s' % (i, n)
         print '=================='
 
-        X, Y, X_test, Y_test, coef_ = make_data(
-            n_samples=(i * step), n_tests=n_tests, n_features=n_features,
-            noise=0.1, k=k)
+        X, Y, X_test, Y_test, coef = make_regression_dataset(
+            n_train_samples=(i * step), n_test_samples=n_test_samples,
+            n_features=n_features, noise=0.1, n_informative=n_informative)
         print "benching scikit: "
-        scikit_results.append(bench(ScikitLasso, X, Y, X_test, Y_test, coef_))
+        scikit_results.append(bench(ScikitLasso, X, Y, X_test, Y_test, coef))
         print "benching glmnet: "
-        glmnet_results.append(bench(GlmnetLasso, X, Y, X_test, Y_test, coef_))
+        glmnet_results.append(bench(GlmnetLasso, X, Y, X_test, Y_test, coef))
 
     pl.clf()
     xx = range(0, n*step, step)
@@ -122,10 +94,10 @@ if __name__ == '__main__':
         print 'Iteration %02d of %02d' % (i, n)
         print '=================='
         n_features = i * step
-        k = n_features / 10
-        X, Y, X_test, Y_test, coef_ = make_data(
-            n_samples=n_samples, n_tests=n_tests, n_features=n_features,
-            noise=0.1, k=k)
+        n_informative = n_features / 10
+        X, Y, X_test, Y_test, coef_ = make_regression_dataset(
+            n_train_samples=n_samples, n_test_samples=n_test_samples,
+            n_features=n_features, noise=0.1, n_informative=n_informative)
 
         print "benching scikit: "
         scikit_results.append(bench(ScikitLasso, X, Y, X_test, Y_test, coef_))
diff --git a/benchmarks/bench_lasso.py b/benchmarks/bench_lasso.py
index 69d63fad404dc20b329c0a124d222acbe7279c9a..802bd78c2aa54c1421fbe6071fd8267403fa1617 100644
--- a/benchmarks/bench_lasso.py
+++ b/benchmarks/bench_lasso.py
@@ -15,7 +15,7 @@
 import gc
 from time import time
 import numpy as np
-from bench_glmnet import make_data
+from scikits.learn.datasets.samples_generator import make_regression_dataset
 
 
 def compute_bench(alpha, n_samples, n_features, precompute):
@@ -23,7 +23,7 @@ def compute_bench(alpha, n_samples, n_features, precompute):
     lasso_results = []
     larslasso_results = []
 
-    n_tests = 1000
+    n_test_samples = 1000
     it = 0
 
     for ns in n_samples:
@@ -33,10 +33,10 @@ def compute_bench(alpha, n_samples, n_features, precompute):
             print 'Iteration %s of %s' % (it, max(len(n_samples), len(n_features)))
             print '=================='
 
-            k = nf // 10
-            X, Y, X_test, Y_test, coef_ = make_data(
-                n_samples=ns, n_tests=n_tests, n_features=nf,
-                noise=0.1, k=k)
+            n_informative = nf // 10
+            X, Y, X_test, Y_test, coef = make_regression_dataset(
+                n_train_samples=ns, n_test_samples=n_test_samples,
+                n_features=nf, noise=0.1, n_informative=n_informative)
 
             X /= np.sqrt(np.sum(X**2, axis=0)) # Normalize data
 
diff --git a/scikits/learn/datasets/samples_generator.py b/scikits/learn/datasets/samples_generator.py
index 1cb9be8a2852aee6909efa7b2d5bca0e04af4f3d..61ed5f524e74add3109bca875b758096f7485aa6 100644
--- a/scikits/learn/datasets/samples_generator.py
+++ b/scikits/learn/datasets/samples_generator.py
@@ -182,7 +182,7 @@ def friedman(n_samples=100, n_features=10, noise_std=1):
 
 
 def low_rank_fat_tail(n_samples=100, n_features=100, effective_rank=10,
-                      tail_strength=0.5, seed=None):
+                      tail_strength=0.5, seed=0):
     """Mostly low rank random matrix with bell-shaped singular values profile
 
     Most of the variance can be explained by a bell-shaped curve of width
@@ -217,7 +217,10 @@ def low_rank_fat_tail(n_samples=100, n_features=100, effective_rank=10,
 
     tail_strength: float between 0.0 and 1.0
         relative importance of the fat noisy tail of the singular values
-        profile.
+        profile (default is 0.5).
+
+    seed: int or RandomState or None
+        how to seed the random number generator (default is 0)
 
     """
     if isinstance(seed, np.random.RandomState):
@@ -244,3 +247,101 @@ def low_rank_fat_tail(n_samples=100, n_features=100, effective_rank=10,
 
     return np.dot(np.dot(u, s), v)
 
+
+def make_regression_dataset(n_train_samples=100, n_test_samples=100,
+                            n_features=100, n_informative=10,
+                            effective_rank=None, tail_strength=0.5,
+                            bias=0., noise=0.05, seed=0):
+    """Generate a regression train + test set
+
+    The input set can be well conditioned (by default) or have a low rank-fat
+    tail singular profile. See the low_rank_fat_tail docstring for more
+    details.
+
+    The output is generated by applying a (potentially biased) random linear
+    regression model with n_informative nonzero regressors to the previously
+    generated input and some centered gaussian noise with adjustable
+    scale.
+
+    Parameters
+    ----------
+    n_train_samples : int
+        number of samples for the training set (default is 100)
+
+    n_test_samples : int
+        number of samples for the testing set (default is 100)
+
+    n_features : int
+        number of features (default is 100)
+
+    n_informative : int
+        number of informative features, i.e. nonzero regressors in the ground
+        truth linear model used to generate the output (default is 10)
+
+    effective_rank : int or None
+        if not None:
+            approximate number of singular vectors required to explain most of
+            the data by linear combinations on the input sets. Using this kind
+            of singular spectrum in the input allows the data generator to
+            reproduce the kind of correlations often observed in practice.
+        if None:
+            the input sets are well conditioned, centered gaussians with unit
+            variance
+
+    tail_strength: float between 0.0 and 1.0
+        relative importance of the fat noisy tail of the singular values
+        profile if effective_rank is not None
+
+    bias: float
+        bias for the ground truth model (default is 0.0)
+
+    noise: float
+        standard deviation of the gaussian noise on the output (default is 0.05)
+
+    seed: int or RandomState or None
+        how to seed the random number generator (default is 0)
+
+    """
+    # allow for reproducible samples generation by explicit random number
+    # generator seeding
+    if isinstance(seed, np.random.RandomState):
+        random = seed
+    elif seed is not None:
+        random = np.random.RandomState(seed)
+    else:
+        random = np.random
+
+    if effective_rank is None:
+        # randomly generate a well conditioned input set
+        X_train = random.randn(n_train_samples, n_features)
+        X_test = random.randn(n_test_samples, n_features)
+    else:
+        # randomly generate a low rank, fat tail input set
+        X_train = low_rank_fat_tail(
+            n_samples=n_train_samples, n_features=n_features,
+            effective_rank=effective_rank, tail_strength=tail_strength,
+            seed=random)
+
+        X_test = low_rank_fat_tail(
+            n_samples=n_test_samples, n_features=n_features,
+            effective_rank=effective_rank, tail_strength=tail_strength,
+            seed=random)
+
+    # generate a ground truth model with only n_informative features being non
+    # zero (the other features are not correlated to Y and should be ignored
+    # by sparsifying regularizers such as L1 or the elastic net)
+    ground_truth = np.zeros(n_features)
+    ground_truth[:n_informative] = random.randn(n_informative)
+    random.shuffle(ground_truth)
+
+    # generate the ground truth Y from the reference model and X
+    Y_train = np.dot(X_train, ground_truth) + bias
+    Y_test = np.dot(X_test, ground_truth) + bias
+
+    if noise > 0.0:
+        # apply some gaussian noise to the output
+        Y_train += random.normal(scale=noise, size=Y_train.shape)
+        Y_test += random.normal(scale=noise, size=Y_test.shape)
+
+    return X_train, Y_train, X_test, Y_test, ground_truth
+
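
Not part of the patch: a minimal usage sketch of the new helper, with parameter values mirroring the first loop of bench_glmnet.py (500 training samples, 1000 features, 10% informative); the import path is the one the updated benchmarks use.

# usage sketch (illustrative only, not included in the patch)
import numpy as np
from scikits.learn.datasets.samples_generator import make_regression_dataset

n_features = 1000
n_informative = n_features / 10  # Python 2 integer division, as in the benchmarks
X, Y, X_test, Y_test, coef = make_regression_dataset(
    n_train_samples=500, n_test_samples=1000, n_features=n_features,
    noise=0.1, n_informative=n_informative)

# the ground truth model is sparse: only n_informative coefficients are nonzero
assert X.shape == (500, n_features)
assert X_test.shape == (1000, n_features)
assert np.sum(coef != 0) == n_informative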