From b18b1611cd9ab129047fc5f0080b1eb836404d3a Mon Sep 17 00:00:00 2001
From: Raghav RV <rvraghav93@gmail.com>
Date: Mon, 12 Sep 2016 17:44:16 +0200
Subject: [PATCH] [MRG+1] TST Stronger test for _check_is_permutation (#7395)

* TST Stronger test for _check_is_permutation

* TST Ensure additional duplicate indices are caught
---
 sklearn/model_selection/_validation.py           | 4 ++--
 sklearn/model_selection/tests/test_validation.py | 5 +++++
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py
index 1ed36d9b7b..bf1d70b081 100644
--- a/sklearn/model_selection/_validation.py
+++ b/sklearn/model_selection/_validation.py
@@ -476,11 +476,11 @@ def _check_is_permutation(indices, n_samples):
     Returns
     -------
     is_partition : bool
-        True iff sorted(locs) is range(n)
+        True iff sorted(indices) is np.arange(n)
     """
     if len(indices) != n_samples:
         return False
-    hit = np.zeros(n_samples, bool)
+    hit = np.zeros(n_samples, dtype=bool)
     hit[indices] = True
     if not np.all(hit):
         return False
diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py
index 67937711ec..4f45a079fb 100644
--- a/sklearn/model_selection/tests/test_validation.py
+++ b/sklearn/model_selection/tests/test_validation.py
@@ -731,13 +731,18 @@ def test_validation_curve():
 
 
 def test_check_is_permutation():
+    rng = np.random.RandomState(0)
     p = np.arange(100)
+    rng.shuffle(p)
     assert_true(_check_is_permutation(p, 100))
     assert_false(_check_is_permutation(np.delete(p, 23), 100))
 
     p[0] = 23
     assert_false(_check_is_permutation(p, 100))
 
+    # Check if the additional duplicate indices are caught
+    assert_false(_check_is_permutation(np.hstack((p, 0)), 100))
+
 
 def test_cross_val_predict_sparse_prediction():
     # check that cross_val_predict gives same result for sparse and dense input
-- 
GitLab