From 046141ba5ea18c08b39df8ffaa68cef1286b1b67 Mon Sep 17 00:00:00 2001 From: RAKOTOARISON Herilalaina <rkt.herilalaina@gmail.com> Date: Thu, 8 Jun 2017 14:12:57 +0200 Subject: [PATCH] [MRG+1] Make classification (dimensions > 30) (#9045) * Change _generate_hypercube into rng.randint * Improve unit test * Test if each row is unique --- sklearn/datasets/samples_generator.py | 2 +- sklearn/datasets/tests/test_samples_generator.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sklearn/datasets/samples_generator.py b/sklearn/datasets/samples_generator.py index 3ba9dfd487..82ae355a7f 100644 --- a/sklearn/datasets/samples_generator.py +++ b/sklearn/datasets/samples_generator.py @@ -25,7 +25,7 @@ def _generate_hypercube(samples, dimensions, rng): """Returns distinct binary samples of length dimensions """ if dimensions > 30: - return np.hstack([_generate_hypercube(samples, dimensions - 30, rng), + return np.hstack([rng.randint(2, size=(samples, dimensions - 30)), _generate_hypercube(samples, 30, rng)]) out = sample_without_replacement(2 ** dimensions, samples, random_state=rng).astype(dtype='>u4', diff --git a/sklearn/datasets/tests/test_samples_generator.py b/sklearn/datasets/tests/test_samples_generator.py index cd4d4148c0..7e0bcff90d 100644 --- a/sklearn/datasets/tests/test_samples_generator.py +++ b/sklearn/datasets/tests/test_samples_generator.py @@ -50,6 +50,17 @@ def test_make_classification(): assert_equal(sum(y == 1), 25, "Unexpected number of samples in class #1") assert_equal(sum(y == 2), 65, "Unexpected number of samples in class #2") + # Test for n_features > 30 + X, y = make_classification(n_samples=2000, n_features=31, n_informative=31, + n_redundant=0, n_repeated=0, hypercube=True, + scale=0.5, random_state=0) + + assert_equal(X.shape, (2000, 31), "X shape mismatch") + assert_equal(y.shape, (2000,), "y shape mismatch") + assert_equal(np.unique(X.view([('', X.dtype)]*X.shape[1])).view(X.dtype) + .reshape(-1, X.shape[1]).shape[0], 2000, + "Unexpected number of unique rows") + def test_make_classification_informative_features(): """Test the construction of informative features in make_classification -- GitLab