diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index 0eb51be93f5bb266b4780ea14e0b5f41b8a2420d..151bbafd622206288a9be69846ffdb860483c7de 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -997,6 +997,33 @@ class _RepeatedSplits(with_metaclass(ABCMeta)): for train_index, test_index in cv.split(X, y, groups): yield train_index, test_index + def get_n_splits(self, X=None, y=None, groups=None): + """Returns the number of splitting iterations in the cross-validator + + Parameters + ---------- + X : object + Always ignored, exists for compatibility. + ``np.zeros(n_samples)`` may be used as a placeholder. + + y : object + Always ignored, exists for compatibility. + ``np.zeros(n_samples)`` may be used as a placeholder. + + groups : array-like, with shape (n_samples,), optional + Group labels for the samples used while splitting the dataset into + train/test set. + + Returns + ------- + n_splits : int + Returns the number of splitting iterations in the cross-validator. + """ + rng = check_random_state(self.random_state) + cv = self.cv(random_state=rng, shuffle=True, + **self.cvargs) + return cv.get_n_splits(X, y, groups) * self.n_repeats + class RepeatedKFold(_RepeatedSplits): """Repeated K-Fold cross validator. diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py index c997ac9d73e5d55ef88c1244995c52f811dbfe0f..fcd0160ca74eeb3aeca0419721105bcab72b2665 100644 --- a/sklearn/model_selection/tests/test_split.py +++ b/sklearn/model_selection/tests/test_split.py @@ -844,6 +844,22 @@ def test_repeated_kfold_determinstic_split(): assert_raises(StopIteration, next, splits) +def test_get_n_splits_for_repeated_kfold(): + n_splits = 3 + n_repeats = 4 + rkf = RepeatedKFold(n_splits, n_repeats) + expected_n_splits = n_splits * n_repeats + assert_equal(expected_n_splits, rkf.get_n_splits()) + + +def test_get_n_splits_for_repeated_stratified_kfold(): + n_splits = 3 + n_repeats = 4 + rskf = RepeatedStratifiedKFold(n_splits, n_repeats) + expected_n_splits = n_splits * n_repeats + assert_equal(expected_n_splits, rskf.get_n_splits()) + + def test_repeated_stratified_kfold_determinstic_split(): X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] y = [1, 1, 1, 0, 0]