diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index 1ed36d9b7b88024c29d07126101796a27b1af086..bf1d70b0819c898bc85db11d9c5127abb55b1509 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -476,11 +476,11 @@ def _check_is_permutation(indices, n_samples): Returns ------- is_partition : bool - True iff sorted(locs) is range(n) + True iff sorted(indices) is np.arange(n) """ if len(indices) != n_samples: return False - hit = np.zeros(n_samples, bool) + hit = np.zeros(n_samples, dtype=bool) hit[indices] = True if not np.all(hit): return False diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index 67937711ec2d8d4f353ac468d76ca374ae7a2302..4f45a079fbd7f82dc0317eab57165143d107e7d9 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -731,13 +731,18 @@ def test_validation_curve(): def test_check_is_permutation(): + rng = np.random.RandomState(0) p = np.arange(100) + rng.shuffle(p) assert_true(_check_is_permutation(p, 100)) assert_false(_check_is_permutation(np.delete(p, 23), 100)) p[0] = 23 assert_false(_check_is_permutation(p, 100)) + # Check if the additional duplicate indices are caught + assert_false(_check_is_permutation(np.hstack((p, 0)), 100)) + def test_cross_val_predict_sparse_prediction(): # check that cross_val_predict gives same result for sparse and dense input