From b57837188dd55d37e13ddcc8f9f548e9dbd3f0c1 Mon Sep 17 00:00:00 2001
From: AishwaryaRK <aishwarya.kaneri@gmail.com>
Date: Sat, 29 Apr 2017 17:03:32 +0530
Subject: [PATCH] [MRG] Fixes #8736 add get_n_splits for RepeatedKFold and
 RepeatedStratifiedKFold (#8802)

---
 sklearn/model_selection/_split.py           | 27 +++++++++++++++++++++
 sklearn/model_selection/tests/test_split.py | 16 ++++++++++++
 2 files changed, 43 insertions(+)

diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py
index 0eb51be93f..151bbafd62 100644
--- a/sklearn/model_selection/_split.py
+++ b/sklearn/model_selection/_split.py
@@ -997,6 +997,33 @@ class _RepeatedSplits(with_metaclass(ABCMeta)):
             for train_index, test_index in cv.split(X, y, groups):
                 yield train_index, test_index
 
+    def get_n_splits(self, X=None, y=None, groups=None):
+        """Returns the number of splitting iterations in the cross-validator
+
+        Parameters
+        ----------
+        X : object
+            Always ignored, exists for compatibility.
+            ``np.zeros(n_samples)`` may be used as a placeholder.
+
+        y : object
+            Always ignored, exists for compatibility.
+            ``np.zeros(n_samples)`` may be used as a placeholder.
+
+        groups : array-like, with shape (n_samples,), optional
+            Group labels for the samples used while splitting the dataset into
+            train/test set.
+
+        Returns
+        -------
+        n_splits : int
+            Returns the number of splitting iterations in the cross-validator.
+        """
+        rng = check_random_state(self.random_state)
+        cv = self.cv(random_state=rng, shuffle=True,
+                     **self.cvargs)
+        return cv.get_n_splits(X, y, groups) * self.n_repeats
+
 
 class RepeatedKFold(_RepeatedSplits):
     """Repeated K-Fold cross validator.
diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py
index c997ac9d73..fcd0160ca7 100644
--- a/sklearn/model_selection/tests/test_split.py
+++ b/sklearn/model_selection/tests/test_split.py
@@ -844,6 +844,22 @@ def test_repeated_kfold_determinstic_split():
         assert_raises(StopIteration, next, splits)
 
 
+def test_get_n_splits_for_repeated_kfold():
+    n_splits = 3
+    n_repeats = 4
+    rkf = RepeatedKFold(n_splits, n_repeats)
+    expected_n_splits = n_splits * n_repeats
+    assert_equal(expected_n_splits, rkf.get_n_splits())
+
+
+def test_get_n_splits_for_repeated_stratified_kfold():
+    n_splits = 3
+    n_repeats = 4
+    rskf = RepeatedStratifiedKFold(n_splits, n_repeats)
+    expected_n_splits = n_splits * n_repeats
+    assert_equal(expected_n_splits, rskf.get_n_splits())
+
+
 def test_repeated_stratified_kfold_determinstic_split():
     X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
     y = [1, 1, 1, 0, 0]
-- 
GitLab