diff --git a/scikits/learn/datasets/samples_generator.py b/scikits/learn/datasets/samples_generator.py
index 11628fac836c5318ea019d38a2a64ca7c353bc05..472b4955c78017130c744c8eaf63ac85a7d27092 100644
--- a/scikits/learn/datasets/samples_generator.py
+++ b/scikits/learn/datasets/samples_generator.py
@@ -1,6 +1,115 @@
+"""
+Samples generator
+"""
+
+# Author: B. Thirion, G. Varoquaux, A. Gramfort, V. Michel
+# License: BSD 3 clause
+
 import numpy as np
 import numpy.random as nr
+
+
+######################################################################
+# Generate Dataset for test
+######################################################################
+
+def test_dataset_classif(n_samples=100, n_features=100, param=[1, 1],
+                         k=0, seed=None):
+    """
+    Generate an SNP matrix
+
+    Parameters
+    ----------
+    n_samples : 100, int,
+        the number of subjects
+    n_features : 100, int,
+        the number of features
+    param : [1, 1], list,
+        parameter of a Dirichlet density
+        that is used to generate multinomial densities
+        from which the class labels will be sampled
+    k : 0, int,
+        number of informative features
+    seed : None, int or np.random.RandomState
+        if seed is an instance of np.random.RandomState,
+        it is used to initialize the random generator
+
+    Returns
+    -------
+    x : array of shape (n_samples, n_features),
+        the design matrix
+    y : array of shape (n_samples),
+        the subject labels
+    """
+    assert k <= n_features, ValueError('cannot have %d informative features'
+                                       ' and %d features' % (k, n_features))
+    if isinstance(seed, np.random.RandomState):
+        random = seed
+    elif seed is not None:
+        random = np.random.RandomState(seed)
+    else:
+        random = np.random
+
+    x = random.randn(n_samples, n_features)
+    y = np.zeros(n_samples)
+    param = np.ravel(np.array(param)).astype(np.float)
+    for n in range(n_samples):
+        y[n] = np.nonzero(random.multinomial(1, param / param.sum()))[0]
+    x[:, :k] += 3 * y[:, np.newaxis]
+    return x, y
+
+
+def test_dataset_reg(n_samples=100, n_features=100, k=0, seed=None):
+    """
+    Generate an SNP matrix
+
+    Parameters
+    ----------
+    n_samples : 100, int,
+        the number of subjects
+    n_features : 100, int,
+        the number of features
+    k : 0, int,
+        number of informative features
+    seed : None, int or np.random.RandomState
+        if seed is an instance of np.random.RandomState,
+        it is used to initialize the random generator
+
+    Returns
+    -------
+    x : array of shape (n_samples, n_features),
+        the design matrix
+    y : array of shape (n_samples),
+        the subject data
+    """
+    assert k < n_features, ValueError('cannot have %d informative features'
+                                      ' and %d features' % (k, n_features))
+    if isinstance(seed, np.random.RandomState):
+        random = seed
+    elif seed is not None:
+        random = np.random.RandomState(seed)
+    else:
+        random = np.random
+
+    x = random.randn(n_samples, n_features)
+    y = random.randn(n_samples)
+    x[:, :k] += y[:, np.newaxis]
+    return x, y
+
+
+######################################################################
+# Generate Dataset for regression
+######################################################################
+
+
 def sparse_uncorrelated(nb_samples=100, nb_features=10):
     """
     Function creating simulated data with sparse uncorrelated design.
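A quick sanity check of the relocated generators; this is a minimal sketch assuming only the module path and signatures introduced in the hunk above (the seed value is arbitrary):

```python
import scikits.learn.datasets.samples_generator as sg

# classification: the first k columns are shifted by 3 * label
X, y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5, seed=0)
assert X.shape == (50, 20) and y.shape == (50,)

# regression: the first k columns are shifted by the continuous target y
X, y = sg.test_dataset_reg(n_samples=50, n_features=20, k=5, seed=0)
```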
diff --git a/scikits/learn/feature_selection/tests/test_feature_select.py b/scikits/learn/feature_selection/tests/test_feature_select.py
index 0f5dab91e0e6ff3bf4d0483546b87eca8da75628..b67510a9b3c645044e83116768fad86702980b81 100644
--- a/scikits/learn/feature_selection/tests/test_feature_select.py
+++ b/scikits/learn/feature_selection/tests/test_feature_select.py
@@ -2,54 +2,22 @@
 Todo: cross-check the F-value with stats model
 """
-from scikits.learn.feature_selection import univ_selection as fs
+from scikits.learn.feature_selection import univariate_selection as us
 import numpy as np
 from numpy.testing import assert_array_equal, \
                           assert_array_almost_equal, \
                           assert_raises
+import scikits.learn.datasets.samples_generator as sg
 
-def make_dataset(n_samples=50, n_features=20, k=5, seed=None, classif=True,
-                 param=[1, 1]):
-    """
-    Create a generic dataset for various tests
-    """
-    if classif:
-        # classification
-        x, y = fs.generate_dataset_classif(n_samples, n_features, k=k,
-                                           seed=seed)
-    else:
-        # regression
-        x, y = fs.generate_dataset_reg(n_samples, n_features, k=k, seed=seed)
-
-    return x, y
-
-#
-# this test is commented because it depends on scikits.statsmodels
-#
-# def test_compare_with_statsmodels():
-#     """
-#     Test whether the F test yields the same results as scikits.statmodels
-#     """
-#     x, y = make_dataset(classif=False)
-#     F, pv = fs.f_regression(x, y)
-
-#     import scikits.statsmodels as sm
-#     nsubj = y.shape[0]
-#     nfeature = x.shape[1]
-#     q = np.zeros(nfeature)
-#     contrast = np.array([1,0])
-#     for i in range(nfeature):
-#         q[i] = sm.OLS(x[:,i], np.vstack((y,np.ones(nsubj))).T).fit().f_test(contrast).pvalue
-#     assert_array_almost_equal(pv,q,1.e-6)
-
 def test_F_test_classif():
     """
     Test whether the F test yields meaningful results
     on a simple simulated classification problem
     """
-    x, y = make_dataset()
-    F, pv = fs.f_classif(x, y)
+    X, Y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5,
+                                   seed=None)
+    F, pv = us.f_classif(X, Y)
     assert(F>0).all()
     assert(pv>0).all()
     assert(pv<1).all()
@@ -61,8 +29,9 @@ def test_F_test_reg():
     """
     Test whether the F test yields meaningful results
     on a simple simulated regression problem
     """
-    x, y = make_dataset(classif=False)
-    F, pv = fs.f_regression(x, y)
+    X, Y = sg.test_dataset_reg(n_samples=50, n_features=20, k=5,
+                               seed=None)
+    F, pv = us.f_regression(X, Y)
     assert(F>0).all()
     assert(pv>0).all()
     assert(pv<1).all()
@@ -74,8 +43,9 @@ def test_F_test_multi_class():
     """
     Test whether the F test yields meaningful results
     on a simple simulated classification problem
     """
-    x, y = make_dataset(param=[1,1,1])
-    F, pv = fs.f_classif(x, y)
+    X, Y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5,
+                                   seed=None, param=[1, 1, 1])
+    F, pv = us.f_classif(X, Y)
     assert(F>0).all()
     assert(pv>0).all()
     assert(pv<1).all()
@@ -88,27 +58,11 @@ def test_univ_fs_percentile_classif():
     gets the correct items in a simple classification problem
     with the percentile heuristic
     """
-    x, y = make_dataset()
-    univ_selection = fs.UnivSelection(score_func=fs.f_classif,
-                                      select_args=(25,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
-    gtruth = np.zeros(20)
-    gtruth[:5]=1
-    assert_array_equal(result, gtruth)
-
-def test_univ_fs_percentile_classif2():
-    """
-    Test whether the relative univariate feature selection
-    gets the correct items in a simple classification problem
-    with the percentile heuristic
-    """
-    x, y = make_dataset()
-    univ_selection = fs.UnivSelection(score_func=fs.f_classif,
-                                      select_func=fs.select_percentile,
-                                      select_args=(25,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5,
+                                   seed=None)
+    univariate_filter = us.SelectPercentile(us.f_classif)
+    X_r = univariate_filter.fit(X, Y).transform(X, percentile=25)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
     assert_array_equal(result, gtruth)
@@ -119,12 +73,11 @@ def test_univ_fs_kbest_classif():
     gets the correct items in a simple classification problem
     with the k best heuristic
     """
-    x, y = make_dataset()
-    univ_selection = fs.UnivSelection(score_func=fs.f_classif,
-                                      select_func=fs.select_k_best,
-                                      select_args=(5,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5,
+                                   seed=None)
+    univariate_filter = us.SelectKBest(us.f_classif)
+    X_r = univariate_filter.fit(X, Y).transform(X, k=5)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
     assert_array_equal(result, gtruth)
@@ -135,12 +88,11 @@ def test_univ_fs_fpr_classif():
     gets the correct items in a simple classification problem
     with the fpr heuristic
     """
-    x, y = make_dataset()
-    univ_selection = fs.UnivSelection(score_func=fs.f_classif,
-                                      select_func=fs.select_fpr,
-                                      select_args=(0.0001,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5,
+                                   seed=None)
+    univariate_filter = us.SelectFpr(us.f_classif)
+    X_r = univariate_filter.fit(X, Y).transform(X, alpha=0.0001)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
     assert_array_equal(result, gtruth)
@@ -151,12 +103,11 @@ def test_univ_fs_fdr_classif():
     gets the correct items in a simple classification problem
     with the fpr heuristic
     """
-    x, y = make_dataset(seed=3)
-    univ_selection = fs.UnivSelection(score_func=fs.f_classif,
-                                      select_func=fs.select_fdr,
-                                      select_args=(0.01,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5,
+                                   seed=3)
+    univariate_filter = us.SelectFdr(us.f_classif)
+    X_r = univariate_filter.fit(X, Y).transform(X, alpha=0.01)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
     assert_array_equal(result, gtruth)
@@ -167,44 +118,32 @@ def test_univ_fs_fwe_classif():
     gets the correct items in a simple classification problem
     with the fpr heuristic
     """
-    x, y = make_dataset()
-    univ_selection = fs.UnivSelection(score_func=fs.f_classif,
-                                      select_func=fs.select_fwe,
-                                      select_args=(0.01,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5,
+                                   seed=None)
+    univariate_filter = us.SelectFwe(us.f_classif)
+    X_r = univariate_filter.fit(X, Y).transform(X, alpha=0.01)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
     assert(np.sum(np.abs(result-gtruth))<2)
 
-def test_univ_fs_percentile_regression():
-    """
-    Test whether the relative univariate feature selection
-    gets the correct items in a simple regression problem
-    with the percentile heuristic
-    """
-    x, y = make_dataset(classif=False)
-    univ_selection = fs.UnivSelection(score_func=fs.f_regression,
-                                      select_args=(25,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
-    gtruth = np.zeros(20)
-    gtruth[:5]=1
-    assert_array_equal(result, gtruth)
-def test_univ_fs_percentile_regression2():
+
+
+
+
+def test_univ_fs_percentile_regression():
     """
     Test whether the relative univariate feature selection
     gets the correct items in a simple regression problem
     with the percentile heuristic
     """
-    x, y = make_dataset(classif=False)
-    univ_selection = fs.UnivSelection(score_func=fs.f_regression,
-                                      select_func=fs.select_percentile,
-                                      select_args=(25,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_reg(n_samples=50, n_features=20, k=5,
+                               seed=None)
+    univariate_filter = us.SelectPercentile(us.f_regression)
+    X_r = univariate_filter.fit(X, Y).transform(X, percentile=25)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
     assert_array_equal(result, gtruth)
@@ -214,12 +153,11 @@ def test_univ_fs_full_percentile_regression():
     Test whether the relative univariate feature selection
     selects all features when '100%' is asked.
     """
-    x, y = make_dataset(classif=False)
-    univ_selection = fs.UnivSelection(score_func=fs.f_regression,
-                                      select_func=fs.select_percentile,
-                                      select_args=(100,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_reg(n_samples=50, n_features=20, k=5,
+                               seed=None)
+    univariate_filter = us.SelectPercentile(us.f_regression)
+    X_r = univariate_filter.fit(X, Y).transform(X, percentile=100)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.ones(20)
     assert_array_equal(result, gtruth)
@@ -229,12 +167,11 @@ def test_univ_fs_kbest_regression():
     gets the correct items in a simple regression problem
     with the k best heuristic
     """
-    x, y = make_dataset(classif=False)
-    univ_selection = fs.UnivSelection(score_func=fs.f_regression,
-                                      select_func=fs.select_k_best,
-                                      select_args=(5,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_reg(n_samples=50, n_features=20, k=5,
+                               seed=None)
+    univariate_filter = us.SelectKBest(us.f_regression)
+    X_r = univariate_filter.fit(X, Y).transform(X, k=5)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
     assert_array_equal(result, gtruth)
@@ -245,12 +182,11 @@ def test_univ_fs_fpr_regression():
     gets the correct items in a simple regression problem
     with the fpr heuristic
     """
-    x, y = make_dataset(classif=False)
-    univ_selection = fs.UnivSelection(score_func=fs.f_regression,
-                                      select_func=fs.select_fpr,
-                                      select_args=(0.01,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_reg(n_samples=50, n_features=20, k=5,
+                               seed=None)
+    univariate_filter = us.SelectFpr(us.f_regression)
+    X_r = univariate_filter.fit(X, Y).transform(X, alpha=0.01)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
     assert(result[:5]==1).all()
@@ -262,12 +198,11 @@ def test_univ_fs_fdr_regression():
     gets the correct items in a simple regression problem
     with the fpr heuristic
     """
-    x, y = make_dataset(seed=2, classif=False)
-    univ_selection = fs.UnivSelection(score_func=fs.f_regression,
-                                      select_func=fs.select_fdr,
-                                      select_args=(0.01,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_reg(n_samples=50, n_features=20, k=5,
+                               seed=2)
+    univariate_filter = us.SelectFdr(us.f_regression)
+    X_r = univariate_filter.fit(X, Y).transform(X, alpha=0.01)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
     assert_array_equal(result, gtruth)
@@ -278,14 +213,12 @@ def test_univ_fs_fwe_regression():
     gets the correct items in a simple regression problem
     with the fpr heuristic
     """
-    x, y = make_dataset(classif=False)
-    univ_selection = fs.UnivSelection(score_func=fs.f_regression,
-                                      select_func=fs.select_fwe,
-                                      select_args=(0.01,))
-    univ_selection.fit(x, y)
-    result = univ_selection.support_.astype(int)
+    X, Y = sg.test_dataset_reg(n_samples=50, n_features=20, k=5,
+                               seed=None)
+    univariate_filter = us.SelectFwe(us.f_regression)
+    X_r = univariate_filter.fit(X, Y).transform(X, alpha=0.01)
+    result = univariate_filter.support_.astype(int)
     gtruth = np.zeros(20)
     gtruth[:5]=1
-    assert(result[:5]==1).all()
     assert(np.sum(result[5:]==1)<2)
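All the rewritten tests above share one pattern: `fit(X, Y)` computes the scores and p-values once, and `transform(X, ...)` applies the selection heuristic, whose parameter is supplied at transform time so it can be varied without re-scoring. A minimal sketch of that pattern, assuming only the modules introduced in this patch (the seed is arbitrary):

```python
import scikits.learn.datasets.samples_generator as sg
from scikits.learn.feature_selection import univariate_selection as us

X, Y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5, seed=0)

# score the features once ...
univariate_filter = us.SelectKBest(us.f_classif).fit(X, Y)
# ... then select with different k without re-fitting
X_5 = univariate_filter.transform(X, k=5)  # shape (50, 5)
X_2 = univariate_filter.transform(X, k=2)  # shape (50, 2)
```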
diff --git a/scikits/learn/feature_selection/univ_selection.py b/scikits/learn/feature_selection/univ_selection.py
deleted file mode 100644
index 93dc2b324a0fb37c051759ab419b682d46683d50..0000000000000000000000000000000000000000
--- a/scikits/learn/feature_selection/univ_selection.py
+++ /dev/null
@@ -1,349 +0,0 @@
-"""
-Univariate features selection.
-
-"""
-
-# Author: B. Thirion, G. Varoquaux, A. Gramfort
-# License: BSD 3 clause
-
-import numpy as np
-from scipy import stats
-
-
-######################################################################
-# Generate Dataset
-######################################################################
-
-def generate_dataset_classif(n_samples=100, n_features=100, param=[1,1],
-                             k=0, seed=None):
-    """
-    Generate an snp matrix
-
-    Parameters
-    ----------
-    n_samples : 100, int,
-        the number of subjects
-    n_features : 100, int,
-        the number of featyres
-    param : [1,1], list,
-        parameter of a dirichlet density
-        that is used to generate multinomial densities
-        from which the n_featuress will be samples
-    k : 0, int,
-        number of informative features
-    seed : None, int or np.random.RandomState
-        if seed is an instance of np.random.RandomState,
-        it is used to initialize the random generator
-
-    Returns
-    -------
-    x : array of shape(n_samples, n_features),
-        the design matrix
-    y : array of shape (n_samples),
-        the subject labels
-
-    """
-    assert k<=n_features, ValueError('cannot have %d informative features and'
-                                     ' %d features' % (k, n_features))
-    if isinstance(seed, np.random.RandomState):
-        random = seed
-    elif seed is not None:
-        random = np.random.RandomState(seed)
-    else:
-        random = np.random
-
-    x = random.randn(n_samples, n_features)
-    y = np.zeros(n_samples)
-    param = np.ravel(np.array(param)).astype(np.float)
-    for n in range(n_samples):
-        y[n] = np.nonzero(random.multinomial(1, param/param.sum()))[0]
-    x[:,:k] += 3*y[:,np.newaxis]
-    return x, y
-
-def generate_dataset_reg(n_samples=100, n_features=100, k=0, seed=None):
-    """
-    Generate an snp matrix
-
-    Parameters
-    ----------
-    n_samples : 100, int,
-        the number of subjects
-    n_features : 100, int,
-        the number of features
-    k : 0, int,
-        number of informative features
-    seed : None, int or np.random.RandomState
-        if seed is an instance of np.random.RandomState,
-        it is used to initialize the random generator
-
-    Returns
-    -------
-    x : array of shape(n_samples, n_features),
-        the design matrix
-    y : array of shape (n_samples),
-        the subject data
-
-    """
-    assert k<n_features, ValueError('cannot have %d informative fetaures and'
-                                    ' %d features' % (k, n_features))
-    if isinstance(seed, np.random.RandomState):
-        random = seed
-    elif seed is not None:
-        random = np.random.RandomState(seed)
-    else:
-        random = np.random
-
-    x = random.randn(n_samples, n_features)
-    y = random.randn(n_samples)
-    x[:,:k] += y[:, np.newaxis]
-    return x, y
-
-
-######################################################################
-# Scoring functions
-######################################################################
-
-def f_classif(x, y):
-    """
-    Compute the Anova F-value for the provided sample
-
-    Parameters
-    ----------
-    x : array of shape (n_samples, n_features)
-        the set of regressors sthat will tested sequentially
-    y : array of shape(n_samples)
-        the data matrix
-
-    Returns
-    -------
-    F : array of shape (m),
-        the set of F values
-    pval : array of shape(m),
-        the set of p-values
-    """
-    x = np.asanyarray(x)
-    args = [x[y==k] for k in np.unique(y)]
-    return stats.f_oneway(*args)
-
-
-def f_regression(x, y, center=True):
-    """
-    Quick linear model for testing the effect of a single regressor,
-    sequentially for many regressors
-    This is done in 3 steps:
-    1. the regressor of interest and the data are orthogonalized
-       wrt constant regressors
-    2. the cross correlation between data and regressors is computed
-    3. it is converted to an F score then to a p-value
-
-    Parameters
-    ----------
-    x : array of shape (n_samples, n_features)
-        the set of regressors sthat will tested sequentially
-    y : array of shape(n_samples)
-        the data matrix
-
-    center : True, bool,
-        If true, x and y are centered
-
-    Returns
-    -------
-    F : array of shape (m),
-        the set of F values
-    pval : array of shape(m)
-        the set of p-values
-    """
-
-    # orthogonalize everything wrt to confounds
-    y = y.copy()
-    x = x.copy()
-    if center:
-        y -= np.mean(y)
-        x -= np.mean(x, 0)
-
-    # compute the correlation
-    x /= np.sqrt(np.sum(x**2,0))
-    y /= np.sqrt(np.sum(y**2))
-    corr = np.dot(y, x)
-
-    # convert to p-value
-    dof = y.size-2
-    F = corr**2/(1-corr**2)*dof
-    pv = stats.f.sf(F, 1, dof)
-    return F, pv
-
-
-######################################################################
-# Selection function
-######################################################################
-
-def select_percentile(p_values, percentile):
-    """ Select the best percentile of the p_values
-    """
-    assert percentile<=100, ValueError('percentile should be \
-                            between 0 and 100 (%f given)' \
-                            %(percentile))
-    alpha = stats.scoreatpercentile(p_values, percentile)
-    return (p_values <= alpha)
-
-def select_k_best(p_values, k):
-    """Select the k lowest p-values
-    """
-    assert k<=len(p_values), ValueError('cannot select %d features'
-                                        ' among %d ' % (k, len(p_values)))
-    #alpha = stats.scoreatpercentile(p_values, 100.*k/len(p_values))
-    alpha = np.sort(p_values)[k-1]
-    return (p_values <= alpha)
-
-
-def select_fpr(p_values, alpha):
-    """Select the pvalues below alpha
-    """
-    return (p_values < alpha)
-
-def select_fdr(p_values, alpha):
-    """
-    Select the p-values corresponding to an estimated false discovery rate
-    of alpha
-    This uses the Benjamini-Hochberg procedure
-    """
-    sv = np.sort(p_values)
-    threshold = sv[sv < alpha*np.arange(len(p_values))].max()
-    return (p_values < threshold)
-
-def select_fwe(p_values, alpha):
-    """
-    Select the p-values corresponding to a corrected p-value of alpha
-    """
-    return (p_values<alpha/len(p_values))
-
-
-
-######################################################################
-# Univariate Selection
-######################################################################
-
-class UnivSelection(object):
-
-    def __init__(self, estimator=None,
-                       score_func=f_regression,
-                       select_func=None, select_args=(10,)):
-        """ An object to do univariate selection before using a
-            classifier.
-
-            Parameters
-            -----------
-            estimator: None or an estimator instance
-                If an estimator is given, it is used to predict on the
-                features selected.
-            score_func: A callable
-                The function used to score features. Should be::
-
-                    _, p_values = score_func(x, y)
-
-                The first output argument is ignored.
-            select_func: A callable
-                The function used to select features. Should be::
-
-                    support = select_func(p_values, *select_args)
-
-                If None is passed, the 10% lowest p_values are
-                selected.
-            select_args: A list or tuple
-                The arguments passed to select_func
-        """
-        if not hasattr(select_args, '__iter__'):
-            select_args = list(select_args)
-        assert callable(score_func), ValueError(
-            "The score function should be a callable, '%s' (type %s) "
-            "was passed." % (score_func, type(score_func))
-            )
-        if select_func is None:
-            select_func = select_percentile
-        assert callable(select_func), ValueError(
-            "The score function should be a callable, '%s' (type %s) "
-            "was passed." % (select_func, type(select_func))
-            )
-        self.estimator = estimator
-        self.score_func = score_func
-        self.select_func = select_func
-        self.select_args = select_args
-
-
-    #--------------------------------------------------------------------------
-    # Estimator interface
-    #--------------------------------------------------------------------------
-
-    def fit(self, x, y):
-        _, p_values_ = self.score_func(x, y)
-        self.support_ = self.select_func(p_values_, *self.select_args)
-        self.p_values_ = p_values_
-        if self.estimator is not None:
-            self.estimator.fit(x[:, self.support_], y)
-        return self
-
-    def predict(self, x=None):
-        # FIXME : support estimate is done again in predict too in
-        # case select_args have changed
-        self.support_ = self.select_func(self.p_values_, *self.select_args)
-        support_ = self.support_
-        if x is None or self.estimator is None:
-            return support_
-        else:
-            return self.estimator.predict(x[:, support_])
-
-    def predict_proba(self, X):
-        self.support_ = self.select_func(self.p_values_, *self.select_args)
-        support_ = self.support_
-        return self.estimator.predict_proba(X[:, support_])
-
-class UnivSelect(object):
-
-    def __init__(self, score_func=f_regression,
-                       select_func=None):
-        """ An object to do univariate selection before using a
-            classifier.
-
-            Implements fit and reduce methods
-
-            The reduce method returns the support of the selected
-            feature set.
-
-            Parameters
-            -----------
-            score_func: A callable
-                The function used to score features. Should be::
-
-                    _, p_values = score_func(x, y)
-
-                The first output argument is ignored.
-            select_func: A callable
-                The function used to select features. Should be::
-
-                    support = select_func(p_values, *select_args)
-
-                If None is passed, the 10% lowest p_values are
-                selected.
-        """
-        if select_func is None:
-            select_func = select_percentile
-        assert callable(select_func), ValueError(
-            "The score function should be a callable, '%s' (type %s) "
-            "was passed." % (select_func, type(select_func))
-            )
-        self.score_func = score_func
-        self.select_func = select_func
-
-    #--------------------------------------------------------------------------
-    # Interface
-    #--------------------------------------------------------------------------
-
-    def fit(self, x, y):
-        self.x = x
-        self.y = y
-        _, p_values_ = self.score_func(x, y)
-        self.p_values_ = p_values_
-        return self
-
-    def reduce(self, n_features):
-        support = self.select_func(self.p_values_, n_features)
-        return support
-
diff --git a/scikits/learn/feature_selection/univariate_selection.py b/scikits/learn/feature_selection/univariate_selection.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2b7e4d84da4ec601a92cc72bbd5cb06890b78ea
--- /dev/null
+++ b/scikits/learn/feature_selection/univariate_selection.py
@@ -0,0 +1,203 @@
+"""
+Univariate features selection.
+"""
+
+# Authors: V. Michel, B. Thirion, G. Varoquaux, A. Gramfort, E. Duchesnay
+# License: BSD 3 clause
+
+import numpy as np
+from scipy import stats
+
+######################################################################
+# Scoring functions
+######################################################################
+
+def f_classif(X, y):
+    """
+    Compute the ANOVA F-value for the provided sample
+
+    Parameters
+    ----------
+    X : array of shape (n_samples, n_features)
+        the set of regressors that will be tested sequentially
+    y : array of shape (n_samples)
+        the data matrix
+
+    Returns
+    -------
+    F : array of shape (m),
+        the set of F values
+    pval : array of shape (m),
+        the set of p-values
+    """
+    X = np.asanyarray(X)
+    args = [X[y == k] for k in np.unique(y)]
+    return stats.f_oneway(*args)
+
+
+def f_regression(X, y, center=True):
+    """
+    Quick linear model for testing the effect of a single regressor,
+    sequentially for many regressors.
+    This is done in 3 steps:
+    1. the regressor of interest and the data are orthogonalized
+       wrt constant regressors
+    2. the cross correlation between data and regressors is computed
+    3. it is converted to an F score then to a p-value
+
+    Parameters
+    ----------
+    X : array of shape (n_samples, n_features)
+        the set of regressors that will be tested sequentially
+    y : array of shape (n_samples)
+        the data matrix
+    center : True, bool,
+        If true, X and y are centered
+
+    Returns
+    -------
+    F : array of shape (m),
+        the set of F values
+    pval : array of shape (m)
+        the set of p-values
+    """
+    # orthogonalize everything wrt to confounds
+    y = y.copy()
+    X = X.copy()
+    if center:
+        y -= np.mean(y)
+        X -= np.mean(X, 0)
+
+    # compute the correlation
+    X /= np.sqrt(np.sum(X ** 2, 0))
+    y /= np.sqrt(np.sum(y ** 2))
+    corr = np.dot(y, X)
+
+    # convert to p-value
+    dof = y.size - 2
+    F = corr ** 2 / (1 - corr ** 2) * dof
+    pv = stats.f.sf(F, 1, dof)
+    return F, pv
+
+
+######################################################################
+# General class for filter univariate selection
+######################################################################
+
+
+class UnivariateFilter(object):
+
+    def __init__(self, score_func):
+        """
+        Initialize the univariate feature selection.
+
+        score_func : callable
+            Function taking two arrays X and y, and returning both
+            scores and p-values.
+        """
+        assert callable(score_func), ValueError(
+            "The score function should be a callable, '%s' (type %s) "
+            "was passed." % (score_func, type(score_func))
+            )
+        self.score_func = score_func
+
+    def fit(self, X, y):
+        """
+        Evaluate the score function on the data (X, y); store the
+        resulting scores and p-values.
+        """
+        _scores = self.score_func(X, y)
+        self._scores = _scores[0]
+        self._pvalues = _scores[1]
+        #self._rank = np.argsort(self._pvalues)
+        return self
+
+    def get_selected_features(self):
+        """
+        Return the boolean mask of the selected features.
+        """
+        return self.support_
+
+    #def transform(self):
+    #    """
+    #    Returns the indices of the selected features
+    #    """
+    #    raise("Error : Not implemented")
+
+
+######################################################################
+# Specific filters
+######################################################################
+
+class SelectPercentile(UnivariateFilter):
+    """
+    Filter : Select the best percentile of the p_values
+    """
+    def transform(self, X, percentile):
+        """
+        Transform the data.
+        """
+        assert percentile <= 100, ValueError('percentile should be '
+                            'between 0 and 100 (%f given)' % percentile)
+        alpha = stats.scoreatpercentile(self._pvalues, percentile)
+        self.support_ = (self._pvalues <= alpha)
+        return X[:, self.support_]
+
+
+class SelectKBest(UnivariateFilter):
+    """
+    Filter : Select the k lowest p-values
+    """
+    def transform(self, X, k):
+        assert k <= len(self._pvalues), ValueError('cannot select %d features'
+                            ' among %d ' % (k, len(self._pvalues)))
+        alpha = np.sort(self._pvalues)[k - 1]
+        self.support_ = (self._pvalues <= alpha)
+        return X[:, self.support_]
+
+
+class SelectFpr(UnivariateFilter):
+    """
+    Filter : Select the pvalues below alpha
+    """
+    def transform(self, X, alpha):
+        self.support_ = (self._pvalues < alpha)
+        return X[:, self.support_]
+
+
+class SelectFdr(UnivariateFilter):
+    """
+    Filter : Select the p-values corresponding to an estimated false
+    discovery rate of alpha. This uses the Benjamini-Hochberg procedure
+    """
+    def transform(self, X, alpha):
+        sv = np.sort(self._pvalues)
+        threshold = sv[sv < alpha * np.arange(len(self._pvalues))].max()
+        self.support_ = (self._pvalues < threshold)
+        return X[:, self.support_]
+
+
+class SelectFwe(UnivariateFilter):
+    """
+    Filter : Select the p-values corresponding to a corrected p-value of alpha
+    """
+    def transform(self, X, alpha):
+        self.support_ = (self._pvalues < alpha / len(self._pvalues))
+        return X[:, self.support_]
+
+
+if __name__ == "__main__":
+    import scikits.learn.datasets.samples_generator as sg
+    from scikits.learn.svm import SVR, SVC
+
+    X, y = sg.sparse_uncorrelated(50, 100)
+    univariate_filter = SelectKBest(f_regression)
+    X_r = univariate_filter.fit(X, y).transform(X, k=5)
+    clf = SVR(kernel='linear', C=1.)
+    y_ = clf.fit(X_r, y).predict(X_r)
+    print univariate_filter.support_.astype(int)
+
+    ### now change k
+    X_r = univariate_filter.transform(X, k=2)
+    y_ = clf.fit(X_r, y).predict(X_r)
+    print univariate_filter.support_.astype(int)
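For downstream code that still uses the deleted `UnivSelection` class, the mapping to the new API appears mechanical: `score_func` becomes the constructor argument of the matching `Select*` filter, and `select_args` becomes the keyword argument of `transform`. A before/after sketch based only on the two files above (the data generation and variable names are illustrative):

```python
import scikits.learn.datasets.samples_generator as sg
from scikits.learn.feature_selection import univariate_selection as us

X, y = sg.test_dataset_classif(n_samples=50, n_features=20, k=5, seed=0)

# before (univ_selection.py, deleted above):
#   selector = fs.UnivSelection(score_func=fs.f_classif,
#                               select_func=fs.select_percentile,
#                               select_args=(25,))
#   selector.fit(X, y)
#   support = selector.support_

# after (univariate_selection.py, added above):
selector = us.SelectPercentile(us.f_classif)
X_r = selector.fit(X, y).transform(X, percentile=25)  # keeps selected columns
support = selector.support_                           # boolean mask over features
```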