From b57837188dd55d37e13ddcc8f9f548e9dbd3f0c1 Mon Sep 17 00:00:00 2001
From: AishwaryaRK <aishwarya.kaneri@gmail.com>
Date: Sat, 29 Apr 2017 17:03:32 +0530
Subject: [PATCH] [MRG] Fixes #8736 add get_n_splits for RepeatedKFold and RepeatedStratifiedKFold (#8802)

---
NOTE(review): this copy of the patch had been flattened onto a single line.
Line breaks in the _split.py hunk were rebuilt from its hunk header
(6 context + 27 added lines); leading whitespace was inferred from Python
syntax -- confirm the context lines against the target tree before applying.
NOTE(review): the new docstring calls X and y "always ignored", but the method
forwards X, y and groups to the wrapped splitter's get_n_splits -- presumably
that splitter ignores them; verify.

 sklearn/model_selection/_split.py           | 27 +++++++++++++++++++++
 sklearn/model_selection/tests/test_split.py | 16 ++++++++++++
 2 files changed, 43 insertions(+)

diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py
index 0eb51be93f..151bbafd62 100644
--- a/sklearn/model_selection/_split.py
+++ b/sklearn/model_selection/_split.py
@@ -997,6 +997,33 @@ class _RepeatedSplits(with_metaclass(ABCMeta)):
             for train_index, test_index in cv.split(X, y, groups):
                 yield train_index, test_index
 
+    def get_n_splits(self, X=None, y=None, groups=None):
+        """Returns the number of splitting iterations in the cross-validator
+
+        Parameters
+        ----------
+        X : object
+            Always ignored, exists for compatibility.
+            ``np.zeros(n_samples)`` may be used as a placeholder.
+
+        y : object
+            Always ignored, exists for compatibility.
+            ``np.zeros(n_samples)`` may be used as a placeholder.
+
+        groups : array-like, with shape (n_samples,), optional
+            Group labels for the samples used while splitting the dataset into
+            train/test set.
+
+        Returns
+        -------
+        n_splits : int
+            Returns the number of splitting iterations in the cross-validator.
+        """
+        rng = check_random_state(self.random_state)
+        cv = self.cv(random_state=rng, shuffle=True,
+                     **self.cvargs)
+        return cv.get_n_splits(X, y, groups) * self.n_repeats
+
 
 class RepeatedKFold(_RepeatedSplits):
     """Repeated K-Fold cross validator.
diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py
index c997ac9d73..fcd0160ca7 100644
--- a/sklearn/model_selection/tests/test_split.py
+++ b/sklearn/model_selection/tests/test_split.py
@@ -844,6 +844,22 @@ def test_repeated_kfold_determinstic_split():
         assert_raises(StopIteration, next, splits)
 
 
+def test_get_n_splits_for_repeated_kfold():
+    n_splits = 3
+    n_repeats = 4
+    rkf = RepeatedKFold(n_splits, n_repeats)
+    expected_n_splits = n_splits * n_repeats
+    assert_equal(expected_n_splits, rkf.get_n_splits())
+
+
+def test_get_n_splits_for_repeated_stratified_kfold():
+    n_splits = 3
+    n_repeats = 4
+    rskf = RepeatedStratifiedKFold(n_splits, n_repeats)
+    expected_n_splits = n_splits * n_repeats
+    assert_equal(expected_n_splits, rskf.get_n_splits())
+
+
 def test_repeated_stratified_kfold_determinstic_split():
     X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
     y = [1, 1, 1, 0, 0]
-- 
GitLab