diff --git a/scikits/learn/datasets/samples_generator.py b/scikits/learn/datasets/samples_generator.py
index 11628fac836c5318ea019d38a2a64ca7c353bc05..472b4955c78017130c744c8eaf63ac85a7d27092 100644
--- a/scikits/learn/datasets/samples_generator.py
+++ b/scikits/learn/datasets/samples_generator.py
@@ -1,6 +1,115 @@
+"""
+Samples generator
+"""
+
+# Author: B. Thirion, G. Varoquaux, A. Gramfort, V. Michel
+# License: BSD 3 clause
+
 import numpy as np
 import numpy.random as nr
+
+
+######################################################################
+# Generate Dataset for test
+######################################################################
+
+def test_dataset_classif(n_samples=100, n_features=100, param=[1, 1],
+                         k=0, seed=None):
+    """
+    Generate an SNP matrix
+
+    Parameters
+    ----------
+    n_samples : 100, int,
+        the number of subjects
+    n_features : 100, int,
+        the number of features
+    param : [1, 1], list,
+        parameter of a Dirichlet density
+        that is used to generate multinomial densities
+        from which the class labels will be sampled
+    k : 0, int,
+        number of informative features
+    seed : None, int or np.random.RandomState
+        if seed is an instance of np.random.RandomState,
+        it is used to initialize the random generator
+
+    Returns
+    -------
+    x : array of shape (n_samples, n_features),
+        the design matrix
+    y : array of shape (n_samples),
+        the subject labels
+    """
+    assert k <= n_features, ValueError('cannot have %d informative features'
+                                       ' and %d features' % (k, n_features))
+    if isinstance(seed, np.random.RandomState):
+        random = seed
+    elif seed is not None:
+        random = np.random.RandomState(seed)
+    else:
+        random = np.random
+
+    x = random.randn(n_samples, n_features)
+    y = np.zeros(n_samples)
+    param = np.ravel(np.array(param)).astype(np.float)
+    for n in range(n_samples):
+        y[n] = np.nonzero(random.multinomial(1, param / param.sum()))[0]
+    x[:, :k] += 3 * y[:, np.newaxis]
+    return x, y
+
+
+def test_dataset_reg(n_samples=100, n_features=100, k=0, seed=None):
+    """
+    Generate an SNP matrix
+
+    Parameters
+    ----------
+    n_samples : 100, int,
+        the number of subjects
+    n_features : 100, int,
+        the number of features
+    k : 0, int,
+        number of informative features
+    seed : None, int or np.random.RandomState
+        if seed is an instance of np.random.RandomState,
+        it is used to initialize the random generator
+
+    Returns
+    -------
+    x : array of shape (n_samples, n_features),
+        the design matrix
+    y : array of shape (n_samples),
+        the subject data
+    """
+    assert k < n_features, ValueError('cannot have %d informative features'
+                                      ' and %d features' % (k, n_features))
+    if isinstance(seed, np.random.RandomState):
+        random = seed
+    elif seed is not None:
+        random = np.random.RandomState(seed)
+    else:
+        random = np.random
+
+    x = random.randn(n_samples, n_features)
+    y = random.randn(n_samples)
+    x[:, :k] += y[:, np.newaxis]
+    return x, y
+
+
+######################################################################
+# Generate Dataset for regression
+######################################################################
+
+
 def sparse_uncorrelated(nb_samples=100, nb_features=10):
     """
     Function creating simulated data with sparse uncorrelated design.
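A quick sanity check of the relocated generators; this is a minimal sketch assuming only the module path and signatures introduced in the hunk above (the seed value is arbitrary):

```python
import scikits.learn.datasets.samples_generator as sg

# classification: the first k columns are shifted by 3 * label
X, y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5, seed=0)
assert X.shape == (50, 20) and y.shape == (50,)

# regression: the first k columns are shifted by the continuous target y
X, y = sg.test_dataset_reg(n_samples=50, n_features=20, k=5, seed=0)
```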
diff --git a/scikits/learn/feature_selection/tests/test_feature_select.py b/scikits/learn/feature_selection/tests/test_feature_select.py
index 0f5dab91e0e6ff3bf4d0483546b87eca8da75628..b67510a9b3c645044e83116768fad86702980b81 100644
--- a/scikits/learn/feature_selection/tests/test_feature_select.py
+++ b/scikits/learn/feature_selection/tests/test_feature_select.py
@@ -2,54 +2,22 @@
 Todo: cross-check the F-value with stats model
 """
-from scikits.learn.feature_selection import univ_selection as fs
+from scikits.learn.feature_selection import univariate_selection as us
 import numpy as np
 from numpy.testing import assert_array_equal, \
                           assert_array_almost_equal, \
                           assert_raises
+import scikits.learn.datasets.samples_generator as sg
 
-def make_dataset(n_samples=50, n_features=20, k=5, seed=None, classif=True,
-                 param=[1, 1]):
-    """
-    Create a generic dataset for various tests
-    """
-    if classif:
-        # classification
-        x, y = fs.generate_dataset_classif(n_samples, n_features, k=k,
-                                           seed=seed)
-    else:
-        # regression
-        x, y = fs.generate_dataset_reg(n_samples, n_features, k=k, seed=seed)
-
-    return x, y
-
-#
-# this test is commented because it depends on scikits.statsmodels
-#
-# def test_compare_with_statsmodels():
-#     """
-#     Test whether the F test yields the same results as scikits.statmodels
-#     """
-#     x, y = make_dataset(classif=False)
-#     F, pv = fs.f_regression(x, y)
-
-#     import scikits.statsmodels as sm
-#     nsubj = y.shape[0]
-#     nfeature = x.shape[1]
-#     q = np.zeros(nfeature)
-#     contrast = np.array([1,0])
-#     for i in range(nfeature):
-#         q[i] = sm.OLS(x[:,i], np.vstack((y,np.ones(nsubj))).T).fit().f_test(contrast).pvalue
-#     assert_array_almost_equal(pv,q,1.e-6)
-
 def test_F_test_classif():
     """
     Test whether the F test yields meaningful results
     on a simple simulated classification problem
     """
-    x, y = make_dataset()
-    F, pv = fs.f_classif(x, y)
+    X, Y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5,
+                                   seed=None)
+    F, pv = us.f_classif(X, Y)
     assert(F>0).all()
     assert(pv>0).all()
     assert(pv<1).all()
@@ -61,8 +29,9 @@ def test_F_test_reg():
     """
     Test whether the F test yields meaningful results
     on a simple simulated regression problem
     """
-    x, y = make_dataset(classif=False)
-    F, pv = fs.f_regression(x, y)
+    X, Y = sg.test_dataset_reg(n_samples=50, n_features=20, k=5,
+                               seed=None)
+    F, pv = us.f_regression(X, Y)
     assert(F>0).all()
     assert(pv>0).all()
     assert(pv<1).all()
@@ -74,8 +43,9 @@ def test_F_test_multi_class():
     """
     Test whether the F test yields meaningful results
     on a simple simulated classification problem
     """
-    x, y = make_dataset(param=[1,1,1])
-    F, pv = fs.f_classif(x, y)
+    X, Y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5,
+                                   seed=None, param=[1, 1, 1])
+    F, pv = us.f_classif(X, Y)
     assert(F>0).all()
     assert(pv>0).all()
     assert(pv<1).all()
@@ -88,27 +58,11 @@ def test_univ_fs_percentile_classif():
     gets the correct items in a simple classification problem
     with the percentile heuristic
     """
-    x, y = make_dataset()
-    univ_selection = fs.UnivSelection(score_func=fs.f_classif,
-                                      select_args=(25,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
-    gtruth = np.zeros(20)
-    gtruth[:5]=1
-    assert_array_equal(result, gtruth)
-
-def test_univ_fs_percentile_classif2():
-    """
-    Test whether the relative univariate feature selection
-    gets the correct items in a simple classification problem
-    with the percentile heuristic
-    """
-    x, y = make_dataset()
-    univ_selection = fs.UnivSelection(score_func=fs.f_classif,
-                                      select_func=fs.select_percentile,
-                                      select_args=(25,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5,
+                                   seed=None)
+    univariate_filter = us.SelectPercentile(us.f_classif)
+    X_r = univariate_filter.fit(X, Y).transform(X, percentile=25)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
     assert_array_equal(result, gtruth)
@@ -119,12 +73,11 @@ def test_univ_fs_kbest_classif():
     gets the correct items in a simple classification problem
     with the k best heuristic
     """
-    x, y = make_dataset()
-    univ_selection = fs.UnivSelection(score_func=fs.f_classif,
-                                      select_func=fs.select_k_best,
-                                      select_args=(5,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5,
+                                   seed=None)
+    univariate_filter = us.SelectKBest(us.f_classif)
+    X_r = univariate_filter.fit(X, Y).transform(X, k=5)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
     assert_array_equal(result, gtruth)
@@ -135,12 +88,11 @@ def test_univ_fs_fpr_classif():
     gets the correct items in a simple classification problem
     with the fpr heuristic
     """
-    x, y = make_dataset()
-    univ_selection = fs.UnivSelection(score_func=fs.f_classif,
-                                      select_func=fs.select_fpr,
-                                      select_args=(0.0001,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5,
+                                   seed=None)
+    univariate_filter = us.SelectFpr(us.f_classif)
+    X_r = univariate_filter.fit(X, Y).transform(X, alpha=0.0001)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
     assert_array_equal(result, gtruth)
@@ -151,12 +103,11 @@ def test_univ_fs_fdr_classif():
     gets the correct items in a simple classification problem
     with the fpr heuristic
     """
-    x, y = make_dataset(seed=3)
-    univ_selection = fs.UnivSelection(score_func=fs.f_classif,
-                                      select_func=fs.select_fdr,
-                                      select_args=(0.01,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5,
+                                   seed=3)
+    univariate_filter = us.SelectFdr(us.f_classif)
+    X_r = univariate_filter.fit(X, Y).transform(X, alpha=0.01)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
     assert_array_equal(result, gtruth)
@@ -167,44 +118,32 @@ def test_univ_fs_fwe_classif():
     gets the correct items in a simple classification problem
     with the fpr heuristic
     """
-    x, y = make_dataset()
-    univ_selection = fs.UnivSelection(score_func=fs.f_classif,
-                                      select_func=fs.select_fwe,
-                                      select_args=(0.01,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5,
+                                   seed=None)
+    univariate_filter = us.SelectFwe(us.f_classif)
+    X_r = univariate_filter.fit(X, Y).transform(X, alpha=0.01)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
     assert(np.sum(np.abs(result-gtruth))<2)
 
-def test_univ_fs_percentile_regression():
-    """
-    Test whether the relative univariate feature selection
-    gets the correct items in a simple regression problem
-    with the percentile heuristic
-    """
-    x, y = make_dataset(classif=False)
-    univ_selection = fs.UnivSelection(score_func=fs.f_regression,
-                                      select_args=(25,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
-    gtruth = np.zeros(20)
-    gtruth[:5]=1
-    assert_array_equal(result, gtruth)
-def test_univ_fs_percentile_regression2():
+
+
+
+
+def test_univ_fs_percentile_regression():
     """
     Test whether the relative univariate feature selection
     gets the correct items in a simple regression problem
     with the percentile heuristic
     """
-    x, y = make_dataset(classif=False)
-    univ_selection = fs.UnivSelection(score_func=fs.f_regression,
-                                      select_func=fs.select_percentile,
-                                      select_args=(25,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_reg(n_samples=50, n_features=20, k=5,
+                               seed=None)
+    univariate_filter = us.SelectPercentile(us.f_regression)
+    X_r = univariate_filter.fit(X, Y).transform(X, percentile=25)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
     assert_array_equal(result, gtruth)
@@ -214,12 +153,11 @@ def test_univ_fs_full_percentile_regression():
     Test whether the relative univariate feature selection
     selects all features when '100%' is asked.
     """
-    x, y = make_dataset(classif=False)
-    univ_selection = fs.UnivSelection(score_func=fs.f_regression,
-                                      select_func=fs.select_percentile,
-                                      select_args=(100,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_reg(n_samples=50, n_features=20, k=5,
+                               seed=None)
+    univariate_filter = us.SelectPercentile(us.f_regression)
+    X_r = univariate_filter.fit(X, Y).transform(X, percentile=100)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.ones(20)
     assert_array_equal(result, gtruth)
@@ -229,12 +167,11 @@ def test_univ_fs_kbest_regression():
     gets the correct items in a simple regression problem
     with the k best heuristic
     """
-    x, y = make_dataset(classif=False)
-    univ_selection = fs.UnivSelection(score_func=fs.f_regression,
-                                      select_func=fs.select_k_best,
-                                      select_args=(5,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_reg(n_samples=50, n_features=20, k=5,
+                               seed=None)
+    univariate_filter = us.SelectKBest(us.f_regression)
+    X_r = univariate_filter.fit(X, Y).transform(X, k=5)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
     assert_array_equal(result, gtruth)
@@ -245,12 +182,11 @@ def test_univ_fs_fpr_regression():
     gets the correct items in a simple regression problem
     with the fpr heuristic
     """
-    x, y = make_dataset(classif=False)
-    univ_selection = fs.UnivSelection(score_func=fs.f_regression,
-                                      select_func=fs.select_fpr,
-                                      select_args=(0.01,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_reg(n_samples=50, n_features=20, k=5,
+                               seed=None)
+    univariate_filter = us.SelectFpr(us.f_regression)
+    X_r = univariate_filter.fit(X, Y).transform(X, alpha=0.01)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
     assert(result[:5]==1).all()
@@ -262,12 +198,11 @@ def test_univ_fs_fdr_regression():
     gets the correct items in a simple regression problem
     with the fpr heuristic
     """
-    x, y = make_dataset(seed=2, classif=False)
-    univ_selection = fs.UnivSelection(score_func=fs.f_regression,
-                                      select_func=fs.select_fdr,
-                                      select_args=(0.01,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_reg(n_samples=50, n_features=20, k=5,
+                               seed=2)
+    univariate_filter = us.SelectFdr(us.f_regression)
+    X_r = univariate_filter.fit(X, Y).transform(X, alpha=0.01)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
     assert_array_equal(result, gtruth)
@@ -278,14 +213,12 @@ def test_univ_fs_fwe_regression():
     gets the correct items in a simple regression problem
     with the fpr heuristic
     """
-    x, y = make_dataset(classif=False)
-    univ_selection = fs.UnivSelection(score_func=fs.f_regression,
-                                      select_func=fs.select_fwe,
-                                      select_args=(0.01,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_reg(n_samples=50, n_features=20, k=5,
+                               seed=None)
+    univariate_filter = us.SelectFwe(us.f_regression)
+    X_r = univariate_filter.fit(X, Y).transform(X, alpha=0.01)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
-    assert(result[:5]==1).all()
     assert(np.sum(result[5:]==1)<2)
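All the rewritten tests above share one pattern: `fit(X, Y)` computes the scores and p-values once, and `transform(X, ...)` applies the selection heuristic, whose parameter is supplied at transform time so it can be varied without re-scoring. A minimal sketch of that pattern, assuming only the modules introduced in this patch (the seed is arbitrary):

```python
import scikits.learn.datasets.samples_generator as sg
from scikits.learn.feature_selection import univariate_selection as us

X, Y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5, seed=0)

# score the features once ...
univariate_filter = us.SelectKBest(us.f_classif).fit(X, Y)
# ... then select with different k without re-fitting
X_5 = univariate_filter.transform(X, k=5)  # shape (50, 5)
X_2 = univariate_filter.transform(X, k=2)  # shape (50, 2)
```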
diff --git a/scikits/learn/feature_selection/univ_selection.py b/scikits/learn/feature_selection/univ_selection.py
deleted file mode 100644
index 93dc2b324a0fb37c051759ab419b682d46683d50..0000000000000000000000000000000000000000
--- a/scikits/learn/feature_selection/univ_selection.py
+++ /dev/null
@@ -1,349 +0,0 @@
-"""
-Univariate features selection.
-
-"""
-
-# Author: B. Thirion, G. Varoquaux, A. Gramfort
-# License: BSD 3 clause
-
-import numpy as np
-from scipy import stats
-
-
-######################################################################
-# Generate Dataset
-######################################################################
-
-def generate_dataset_classif(n_samples=100, n_features=100, param=[1,1],
-                             k=0, seed=None):
-    """
-    Generate an snp matrix
-
-    Parameters
-    ----------
-    n_samples : 100, int,
-        the number of subjects
-    n_features : 100, int,
-        the number of featyres
-    param : [1,1], list,
-        parameter of a dirichlet density
-        that is used to generate multinomial densities
-        from which the n_featuress will be samples
-    k : 0, int,
-        number of informative features
-    seed : None, int or np.random.RandomState
-        if seed is an instance of np.random.RandomState,
-        it is used to initialize the random generator
-
-    Returns
-    -------
-    x : array of shape(n_samples, n_features),
-        the design matrix
-    y : array of shape (n_samples),
-        the subject labels
-
-    """
-    assert k<=n_features, ValueError('cannot have %d informative features and'
-                                     ' %d features' % (k, n_features))
-    if isinstance(seed, np.random.RandomState):
-        random = seed
-    elif seed is not None:
-        random = np.random.RandomState(seed)
-    else:
-        random = np.random
-
-    x = random.randn(n_samples, n_features)
-    y = np.zeros(n_samples)
-    param = np.ravel(np.array(param)).astype(np.float)
-    for n in range(n_samples):
-        y[n] = np.nonzero(random.multinomial(1, param/param.sum()))[0]
-    x[:,:k] += 3*y[:,np.newaxis]
-    return x, y
-
-def generate_dataset_reg(n_samples=100, n_features=100, k=0, seed=None):
-    """
-    Generate an snp matrix
-
-    Parameters
-    ----------
-    n_samples : 100, int,
-        the number of subjects
-    n_features : 100, int,
-        the number of features
-    k : 0, int,
-        number of informative features
-    seed : None, int or np.random.RandomState
-        if seed is an instance of np.random.RandomState,
-        it is used to initialize the random generator
-
-    Returns
-    -------
-    x : array of shape(n_samples, n_features),
-        the design matrix
-    y : array of shape (n_samples),
-        the subject data
-
-    """
-    assert k<n_features, ValueError('cannot have %d informative fetaures and'
-                                    ' %d features' % (k, n_features))
-    if isinstance(seed, np.random.RandomState):
-        random = seed
-    elif seed is not None:
-        random = np.random.RandomState(seed)
-    else:
-        random = np.random
-
-    x = random.randn(n_samples, n_features)
-    y = random.randn(n_samples)
-    x[:,:k] += y[:, np.newaxis]
-    return x, y
-
-
-######################################################################
-# Scoring functions
-######################################################################
-
-def f_classif(x, y):
-    """
-    Compute the Anova F-value for the provided sample
-
-    Parameters
-    ----------
-    x : array of shape (n_samples, n_features)
-        the set of regressors sthat will tested sequentially
-    y : array of shape(n_samples)
-        the data matrix
-
-    Returns
-    -------
-    F : array of shape (m),
-        the set of F values
-    pval : array of shape(m),
-        the set of p-values
-    """
-    x = np.asanyarray(x)
-    args = [x[y==k] for k in np.unique(y)]
-    return stats.f_oneway(*args)
-
-
-def f_regression(x, y, center=True):
-    """
-    Quick linear model for testing the effect of a single regressor,
-    sequentially for many regressors
-    This is done in 3 steps:
-    1. the regressor of interest and the data are orthogonalized
-       wrt constant regressors
-    2. the cross correlation between data and regressors is computed
-    3. it is converted to an F score then to a p-value
-
-    Parameters
-    ----------
-    x : array of shape (n_samples, n_features)
-        the set of regressors sthat will tested sequentially
-    y : array of shape(n_samples)
-        the data matrix
-
-    center : True, bool,
-        If true, x and y are centered
-
-    Returns
-    -------
-    F : array of shape (m),
-        the set of F values
-    pval : array of shape(m)
-        the set of p-values
-    """
-
-    # orthogonalize everything wrt to confounds
-    y = y.copy()
-    x = x.copy()
-    if center:
-        y -= np.mean(y)
-        x -= np.mean(x, 0)
-
-    # compute the correlation
-    x /= np.sqrt(np.sum(x**2,0))
-    y /= np.sqrt(np.sum(y**2))
-    corr = np.dot(y, x)
-
-    # convert to p-value
-    dof = y.size-2
-    F = corr**2/(1-corr**2)*dof
-    pv = stats.f.sf(F, 1, dof)
-    return F, pv
-
-
-######################################################################
-# Selection function
-######################################################################
-
-def select_percentile(p_values, percentile):
-    """ Select the best percentile of the p_values
-    """
-    assert percentile<=100, ValueError('percentile should be \
-                            between 0 and 100 (%f given)' \
-                            %(percentile))
-    alpha = stats.scoreatpercentile(p_values, percentile)
-    return (p_values <= alpha)
-
-def select_k_best(p_values, k):
-    """Select the k lowest p-values
-    """
-    assert k<=len(p_values), ValueError('cannot select %d features'
-                                        ' among %d ' % (k, len(p_values)))
-    #alpha = stats.scoreatpercentile(p_values, 100.*k/len(p_values))
-    alpha = np.sort(p_values)[k-1]
-    return (p_values <= alpha)
-
-
-def select_fpr(p_values, alpha):
-    """Select the pvalues below alpha
-    """
-    return (p_values < alpha)
-
-def select_fdr(p_values, alpha):
-    """
-    Select the p-values corresponding to an estimated false discovery rate
-    of alpha
-    This uses the Benjamini-Hochberg procedure
-    """
-    sv = np.sort(p_values)
-    threshold = sv[sv < alpha*np.arange(len(p_values))].max()
-    return (p_values < threshold)
-
-def select_fwe(p_values, alpha):
-    """
-    Select the p-values corresponding to a corrected p-value of alpha
-    """
-    return (p_values<alpha/len(p_values))
-
-
-
-######################################################################
-# Univariate Selection
-######################################################################
-
-class UnivSelection(object):
-
-    def __init__(self, estimator=None,
-                       score_func=f_regression,
-                       select_func=None, select_args=(10,)):
-        """ An object to do univariate selection before using a
-            classifier.
-
-            Parameters
-            -----------
-            estimator: None or an estimator instance
-                If an estimator is given, it is used to predict on the
-                features selected.
-            score_func: A callable
-                The function used to score features. Should be::
-
-                    _, p_values = score_func(x, y)
-
-                The first output argument is ignored.
-            select_func: A callable
-                The function used to select features. Should be::
-
-                    support = select_func(p_values, *select_args)
-
-                If None is passed, the 10% lowest p_values are
-                selected.
-            select_args: A list or tuple
-                The arguments passed to select_func
-        """
-        if not hasattr(select_args, '__iter__'):
-            select_args = list(select_args)
-        assert callable(score_func), ValueError(
-            "The score function should be a callable, '%s' (type %s) "
-            "was passed." % (score_func, type(score_func))
-            )
-        if select_func is None:
-            select_func = select_percentile
-        assert callable(select_func), ValueError(
-            "The score function should be a callable, '%s' (type %s) "
-            "was passed." % (select_func, type(select_func))
-            )
-        self.estimator = estimator
-        self.score_func = score_func
-        self.select_func = select_func
-        self.select_args = select_args
-
-
-    #--------------------------------------------------------------------------
-    # Estimator interface
-    #--------------------------------------------------------------------------
-
-    def fit(self, x, y):
-        _, p_values_ = self.score_func(x, y)
-        self.support_ = self.select_func(p_values_, *self.select_args)
-        self.p_values_ = p_values_
-        if self.estimator is not None:
-            self.estimator.fit(x[:, self.support_], y)
-        return self
-
-    def predict(self, x=None):
-        # FIXME : support estimate is done again in predict too in
-        # case select_args have changed
-        self.support_ = self.select_func(self.p_values_, *self.select_args)
-        support_ = self.support_
-        if x is None or self.estimator is None:
-            return support_
-        else:
-            return self.estimator.predict(x[:, support_])
-
-    def predict_proba(self, X):
-        self.support_ = self.select_func(self.p_values_, *self.select_args)
-        support_ = self.support_
-        return self.estimator.predict_proba(X[:, support_])
-
-class UnivSelect(object):
-
-    def __init__(self, score_func=f_regression,
-                       select_func=None):
-        """ An object to do univariate selection before using a
-            classifier.
-
-            Implements fit and reduce methods
-
-            The reduce method returns the support of the selected
-            feature set.
-
-            Parameters
-            -----------
-            score_func: A callable
-                The function used to score features. Should be::
-
-                    _, p_values = score_func(x, y)
-
-                The first output argument is ignored.
-            select_func: A callable
-                The function used to select features. Should be::
-
-                    support = select_func(p_values, *select_args)
-
-                If None is passed, the 10% lowest p_values are
-                selected.
-        """
-        if select_func is None:
-            select_func = select_percentile
-        assert callable(select_func), ValueError(
-            "The score function should be a callable, '%s' (type %s) "
-            "was passed." % (select_func, type(select_func))
-            )
-        self.score_func = score_func
-        self.select_func = select_func
-
-    #--------------------------------------------------------------------------
-    # Interface
-    #--------------------------------------------------------------------------
-
-    def fit(self, x, y):
-        self.x = x
-        self.y = y
-        _, p_values_ = self.score_func(x, y)
-        self.p_values_ = p_values_
-        return self
-
-    def reduce(self, n_features):
-        support = self.select_func(self.p_values_, n_features)
-        return support
-
diff --git a/scikits/learn/feature_selection/univariate_selection.py b/scikits/learn/feature_selection/univariate_selection.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2b7e4d84da4ec601a92cc72bbd5cb06890b78ea
--- /dev/null
+++ b/scikits/learn/feature_selection/univariate_selection.py
@@ -0,0 +1,203 @@
+"""
+Univariate features selection.
+"""
+
+# Authors: V. Michel, B. Thirion, G. Varoquaux, A. Gramfort, E. Duchesnay
+# License: BSD 3 clause
+
+import numpy as np
+from scipy import stats
+
+######################################################################
+# Scoring functions
+######################################################################
+
+def f_classif(X, y):
+    """
+    Compute the ANOVA F-value for the provided sample
+
+    Parameters
+    ----------
+    X : array of shape (n_samples, n_features)
+        the set of regressors that will be tested sequentially
+    y : array of shape (n_samples)
+        the data matrix
+
+    Returns
+    -------
+    F : array of shape (m),
+        the set of F values
+    pval : array of shape (m),
+        the set of p-values
+    """
+    X = np.asanyarray(X)
+    args = [X[y == k] for k in np.unique(y)]
+    return stats.f_oneway(*args)
+
+
+def f_regression(X, y, center=True):
+    """
+    Quick linear model for testing the effect of a single regressor,
+    sequentially for many regressors.
+    This is done in 3 steps:
+    1. the regressor of interest and the data are orthogonalized
+       wrt constant regressors
+    2. the cross correlation between data and regressors is computed
+    3. it is converted to an F score then to a p-value
+
+    Parameters
+    ----------
+    X : array of shape (n_samples, n_features)
+        the set of regressors that will be tested sequentially
+    y : array of shape (n_samples)
+        the data matrix
+    center : True, bool,
+        If true, X and y are centered
+
+    Returns
+    -------
+    F : array of shape (m),
+        the set of F values
+    pval : array of shape (m)
+        the set of p-values
+    """
+    # orthogonalize everything wrt to confounds
+    y = y.copy()
+    X = X.copy()
+    if center:
+        y -= np.mean(y)
+        X -= np.mean(X, 0)
+
+    # compute the correlation
+    X /= np.sqrt(np.sum(X ** 2, 0))
+    y /= np.sqrt(np.sum(y ** 2))
+    corr = np.dot(y, X)
+
+    # convert to p-value
+    dof = y.size - 2
+    F = corr ** 2 / (1 - corr ** 2) * dof
+    pv = stats.f.sf(F, 1, dof)
+    return F, pv
+
+
+######################################################################
+# General class for filter univariate selection
+######################################################################
+
+
+class UnivariateFilter(object):
+
+    def __init__(self, score_func):
+        """
+        Initialize the univariate feature selection.
+
+        score_func : callable
+            Function taking two arrays X and y, and returning both
+            scores and p-values.
+        """
+        assert callable(score_func), ValueError(
+            "The score function should be a callable, '%s' (type %s) "
+            "was passed." % (score_func, type(score_func))
+            )
+        self.score_func = score_func
+
+    def fit(self, X, y):
+        """
+        Evaluate the score function on the data (X, y); store the
+        resulting scores and p-values.
+        """
+        _scores = self.score_func(X, y)
+        self._scores = _scores[0]
+        self._pvalues = _scores[1]
+        #self._rank = np.argsort(self._pvalues)
+        return self
+
+    def get_selected_features(self):
+        """
+        Return the boolean mask of the selected features.
+        """
+        return self.support_
+
+    #def transform(self):
+    #    """
+    #    Returns the indices of the selected features
+    #    """
+    #    raise("Error : Not implemented")
+
+
+######################################################################
+# Specific filters
+######################################################################
+
+class SelectPercentile(UnivariateFilter):
+    """
+    Filter : Select the best percentile of the p_values
+    """
+    def transform(self, X, percentile):
+        """
+        Transform the data.
+        """
+        assert percentile <= 100, ValueError('percentile should be '
+                            'between 0 and 100 (%f given)' % percentile)
+        alpha = stats.scoreatpercentile(self._pvalues, percentile)
+        self.support_ = (self._pvalues <= alpha)
+        return X[:, self.support_]
+
+
+class SelectKBest(UnivariateFilter):
+    """
+    Filter : Select the k lowest p-values
+    """
+    def transform(self, X, k):
+        assert k <= len(self._pvalues), ValueError('cannot select %d features'
+                            ' among %d ' % (k, len(self._pvalues)))
+        alpha = np.sort(self._pvalues)[k - 1]
+        self.support_ = (self._pvalues <= alpha)
+        return X[:, self.support_]
+
+
+class SelectFpr(UnivariateFilter):
+    """
+    Filter : Select the pvalues below alpha
+    """
+    def transform(self, X, alpha):
+        self.support_ = (self._pvalues < alpha)
+        return X[:, self.support_]
+
+
+class SelectFdr(UnivariateFilter):
+    """
+    Filter : Select the p-values corresponding to an estimated false
+    discovery rate of alpha. This uses the Benjamini-Hochberg procedure
+    """
+    def transform(self, X, alpha):
+        sv = np.sort(self._pvalues)
+        threshold = sv[sv < alpha * np.arange(len(self._pvalues))].max()
+        self.support_ = (self._pvalues < threshold)
+        return X[:, self.support_]
+
+
+class SelectFwe(UnivariateFilter):
+    """
+    Filter : Select the p-values corresponding to a corrected p-value of alpha
+    """
+    def transform(self, X, alpha):
+        self.support_ = (self._pvalues < alpha / len(self._pvalues))
+        return X[:, self.support_]
+
+
+if __name__ == "__main__":
+    import scikits.learn.datasets.samples_generator as sg
+    from scikits.learn.svm import SVR, SVC
+
+    X, y = sg.sparse_uncorrelated(50, 100)
+    univariate_filter = SelectKBest(f_regression)
+    X_r = univariate_filter.fit(X, y).transform(X, k=5)
+    clf = SVR(kernel='linear', C=1.)
+    y_ = clf.fit(X_r, y).predict(X_r)
+    print univariate_filter.support_.astype(int)
+
+    ### now change k
+    X_r = univariate_filter.transform(X, k=2)
+    y_ = clf.fit(X_r, y).predict(X_r)
+    print univariate_filter.support_.astype(int)
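For downstream code that still uses the deleted `UnivSelection` class, the mapping to the new API appears mechanical: `score_func` becomes the constructor argument of the matching `Select*` filter, and `select_args` becomes the keyword argument of `transform`. A before/after sketch based only on the two files above (the data generation and variable names are illustrative):

```python
import scikits.learn.datasets.samples_generator as sg
from scikits.learn.feature_selection import univariate_selection as us

X, y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5, seed=0)

# before (univ_selection.py, deleted above):
#   selector = fs.UnivSelection(score_func=fs.f_classif,
#                               select_func=fs.select_percentile,
#                               select_args=(25,))
#   selector.fit(X, y)
#   support = selector.support_

# after (univariate_selection.py, added above):
selector = us.SelectPercentile(us.f_classif)
X_r = selector.fit(X, y).transform(X, percentile=25)  # keeps selected columns
support = selector.support_                           # boolean mask over features
```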