From 046141ba5ea18c08b39df8ffaa68cef1286b1b67 Mon Sep 17 00:00:00 2001
From: RAKOTOARISON Herilalaina <rkt.herilalaina@gmail.com>
Date: Thu, 8 Jun 2017 14:12:57 +0200
Subject: [PATCH] [MRG+1] Make classification (dimensions > 30) (#9045)

* Change _generate_hypercube into rng.randint

* Improve unit test

* Test if each row is unique
---
 sklearn/datasets/samples_generator.py            |  2 +-
 sklearn/datasets/tests/test_samples_generator.py | 11 +++++++++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/sklearn/datasets/samples_generator.py b/sklearn/datasets/samples_generator.py
index 3ba9dfd487..82ae355a7f 100644
--- a/sklearn/datasets/samples_generator.py
+++ b/sklearn/datasets/samples_generator.py
@@ -25,7 +25,7 @@ def _generate_hypercube(samples, dimensions, rng):
     """Returns distinct binary samples of length dimensions
     """
     if dimensions > 30:
-        return np.hstack([_generate_hypercube(samples, dimensions - 30, rng),
+        return np.hstack([rng.randint(2, size=(samples, dimensions - 30)),
                           _generate_hypercube(samples, 30, rng)])
     out = sample_without_replacement(2 ** dimensions, samples,
                                      random_state=rng).astype(dtype='>u4',
diff --git a/sklearn/datasets/tests/test_samples_generator.py b/sklearn/datasets/tests/test_samples_generator.py
index cd4d4148c0..7e0bcff90d 100644
--- a/sklearn/datasets/tests/test_samples_generator.py
+++ b/sklearn/datasets/tests/test_samples_generator.py
@@ -50,6 +50,17 @@ def test_make_classification():
     assert_equal(sum(y == 1), 25, "Unexpected number of samples in class #1")
     assert_equal(sum(y == 2), 65, "Unexpected number of samples in class #2")
 
+    # Test for n_features > 30
+    X, y = make_classification(n_samples=2000, n_features=31, n_informative=31,
+                               n_redundant=0, n_repeated=0, hypercube=True,
+                               scale=0.5, random_state=0)
+
+    assert_equal(X.shape, (2000, 31), "X shape mismatch")
+    assert_equal(y.shape, (2000,), "y shape mismatch")
+    assert_equal(np.unique(X.view([('', X.dtype)]*X.shape[1])).view(X.dtype)
+                 .reshape(-1, X.shape[1]).shape[0], 2000,
+                 "Unexpected number of unique rows")
+
 
 def test_make_classification_informative_features():
     """Test the construction of informative features in make_classification
-- 
GitLab