Skip to content
Snippets Groups Projects
Commit b661a9c8 authored by Joel Nothman's avatar Joel Nothman Committed by Andreas Müller
Browse files

TST Improve SelectFromModel tests (#9733)

Should fix one of the issues in #9393
parent 0e1d261c
Branches
Tags 0.19.1
No related merge requests found
......@@ -40,7 +40,6 @@ def test_input_estimator_unchanged():
assert_true(transformer.estimator is est)
@skip_if_32bit
def test_feature_importances():
X, y = datasets.make_classification(
n_samples=1000, n_features=10, n_informative=3, n_redundant=0,
......@@ -59,17 +58,33 @@ def test_feature_importances():
feature_mask = np.abs(importances) > func(importances)
assert_array_almost_equal(X_new, X[:, feature_mask])
def test_sample_weight():
# Ensure sample weights are passed to underlying estimator
X, y = datasets.make_classification(
n_samples=100, n_features=10, n_informative=3, n_redundant=0,
n_repeated=0, shuffle=False, random_state=0)
# Check with sample weights
sample_weight = np.ones(y.shape)
sample_weight[y == 1] *= 100
est = RandomForestClassifier(n_estimators=50, random_state=0)
est = LogisticRegression(random_state=0, fit_intercept=False)
transformer = SelectFromModel(estimator=est)
transformer.fit(X, y, sample_weight=None)
mask = transformer._get_support_mask()
transformer.fit(X, y, sample_weight=sample_weight)
importances = transformer.estimator_.feature_importances_
weighted_mask = transformer._get_support_mask()
assert not np.all(weighted_mask == mask)
transformer.fit(X, y, sample_weight=3 * sample_weight)
importances_bis = transformer.estimator_.feature_importances_
assert_almost_equal(importances, importances_bis)
reweighted_mask = transformer._get_support_mask()
assert np.all(weighted_mask == reweighted_mask)
def test_coef_default_threshold():
X, y = datasets.make_classification(
n_samples=100, n_features=10, n_informative=3, n_redundant=0,
n_repeated=0, shuffle=False, random_state=0)
# For the Lasso and related models, the threshold defaults to 1e-5
transformer = SelectFromModel(estimator=Lasso(alpha=0.1))
......@@ -80,7 +95,7 @@ def test_feature_importances():
@skip_if_32bit
def test_feature_importances_2d_coef():
def test_2d_coef():
X, y = datasets.make_classification(
n_samples=1000, n_features=10, n_informative=3, n_redundant=0,
n_repeated=0, shuffle=False, random_state=0, n_classes=4)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment