From b18b1611cd9ab129047fc5f0080b1eb836404d3a Mon Sep 17 00:00:00 2001 From: Raghav RV <rvraghav93@gmail.com> Date: Mon, 12 Sep 2016 17:44:16 +0200 Subject: [PATCH] [MRG+1] TST Stronger test for _check_is_permutation (#7395) * TST Stronger test for _check_is_permutation * TST Ensure additional duplicate indices are caught --- sklearn/model_selection/_validation.py | 4 ++-- sklearn/model_selection/tests/test_validation.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index 1ed36d9b7b..bf1d70b081 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -476,11 +476,11 @@ def _check_is_permutation(indices, n_samples): Returns ------- is_partition : bool - True iff sorted(locs) is range(n) + True iff sorted(indices) is np.arange(n) """ if len(indices) != n_samples: return False - hit = np.zeros(n_samples, bool) + hit = np.zeros(n_samples, dtype=bool) hit[indices] = True if not np.all(hit): return False diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index 67937711ec..4f45a079fb 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -731,13 +731,18 @@ def test_validation_curve(): def test_check_is_permutation(): + rng = np.random.RandomState(0) p = np.arange(100) + rng.shuffle(p) assert_true(_check_is_permutation(p, 100)) assert_false(_check_is_permutation(np.delete(p, 23), 100)) p[0] = 23 assert_false(_check_is_permutation(p, 100)) + # Check if the additional duplicate indices are caught + assert_false(_check_is_permutation(np.hstack((p, 0)), 100)) + def test_cross_val_predict_sparse_prediction(): # check that cross_val_predict gives same result for sparse and dense input -- GitLab