diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 3aee8f258b9d159ad381ad7b2db540e9ed58a71b..3101488fd66614a75bc6d25623cd742a7785670a 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -170,6 +170,8 @@ Splitter Classes
    model_selection.LeavePGroupsOut
    model_selection.LeaveOneOut
    model_selection.LeavePOut
+   model_selection.RepeatedKFold
+   model_selection.RepeatedStratifiedKFold
    model_selection.ShuffleSplit
    model_selection.GroupShuffleSplit
    model_selection.StratifiedShuffleSplit
diff --git a/doc/modules/cross_validation.rst b/doc/modules/cross_validation.rst
index 3f7cf95e5941747190ce2e3c39c7af729caf6244..4b9a36e979d4dffb32339c181dbf22d45501dddd 100644
--- a/doc/modules/cross_validation.rst
+++ b/doc/modules/cross_validation.rst
@@ -263,6 +263,33 @@ Thus, one can create the training/test sets using numpy indexing::
   >>> X_train, X_test, y_train, y_test = X[train], X[test], y[train], y[test]
 
 
+Repeated K-Fold
+---------------
+
+:class:`RepeatedKFold` repeats K-Fold n times. It can be used when one
+needs to run :class:`KFold` n times, producing different splits in each
+repetition.
+
+Example of 2-fold K-Fold repeated 2 times::
+
+  >>> import numpy as np
+  >>> from sklearn.model_selection import RepeatedKFold
+  >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
+  >>> random_state = 12883823
+  >>> rkf = RepeatedKFold(n_splits=2, n_repeats=2, random_state=random_state)
+  >>> for train, test in rkf.split(X):
+  ...     print("%s %s" % (train, test))
+  ...
+  [2 3] [0 1]
+  [0 1] [2 3]
+  [0 2] [1 3]
+  [1 3] [0 2]
+
+
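+Like any other cross-validation splitter, a repeated splitter can be passed
+directly as the ``cv`` argument of model evaluation utilities such as
+:func:`cross_val_score`. A minimal sketch, reusing ``X`` and ``rkf`` defined
+above (the estimator and labels are chosen only for illustration)::
+
+  >>> from sklearn.model_selection import cross_val_score
+  >>> from sklearn.tree import DecisionTreeClassifier
+  >>> y = np.array([0, 1, 0, 1])
+  >>> scores = cross_val_score(DecisionTreeClassifier(), X, y, cv=rkf)
+  >>> len(scores)  # one score per generated split
+  4
+
+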
+Similarly, :class:`RepeatedStratifiedKFold` repeats Stratified K-Fold n times
+with different randomization in each repetition.
+
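+A minimal sketch (the toy data and ``random_state`` are arbitrary) showing
+that each repetition yields a full set of stratified folds, i.e.
+``n_splits * n_repeats`` (train, test) pairs with both classes represented
+in every test fold::
+
+  >>> import numpy as np
+  >>> from sklearn.model_selection import RepeatedStratifiedKFold
+  >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
+  >>> y = np.array([0, 1, 0, 1])
+  >>> rskf = RepeatedStratifiedKFold(n_splits=2, n_repeats=2,
+  ...                                random_state=0)
+  >>> splits = list(rskf.split(X, y))
+  >>> len(splits)
+  4
+  >>> all(len(set(y[test])) == 2 for _, test in splits)
+  True
+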
+
 Leave One Out (LOO)
 -------------------
 
@@ -409,6 +436,10 @@ two slightly unbalanced classes::
   [0 1 3 4 5 8 9] [2 6 7]
   [0 1 2 4 5 6 7] [3 8 9]
 
+:class:`RepeatedStratifiedKFold` can be used to repeat Stratified K-Fold n times
+with different randomization in each repetition.
+
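+A minimal sketch (the class sizes and ``random_state`` are arbitrary) showing
+that the class proportions are preserved in every test fold across all
+repetitions::
+
+  >>> from sklearn.model_selection import RepeatedStratifiedKFold
+  >>> X, y = np.ones(10), np.array([0] * 4 + [1] * 6)
+  >>> rskf = RepeatedStratifiedKFold(n_splits=2, n_repeats=3, random_state=0)
+  >>> all(np.bincount(y[test]).tolist() == [2, 3]
+  ...     for _, test in rskf.split(X, y))
+  True
+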
+
 Stratified Shuffle Split
 ------------------------
 
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 7d2fa8a5628874e91c59767e183a9e9a7f282daa..a5a7b369bf89af83c9dbc22d6e7908edac0aa658 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -41,6 +41,10 @@ New features
      Kullback-Leibler divergence and the Itakura-Saito divergence.
      By `Tom Dupre la Tour`_.
 
+   - Added :class:`sklearn.model_selection.RepeatedKFold` and
+     :class:`sklearn.model_selection.RepeatedStratifiedKFold`, which repeat
+     K-Fold and Stratified K-Fold n times with different randomization in
+     each repetition.
+     :issue:`8120` by `Neeraj Gangwar`_.
+
    - Added :func:`metrics.mean_squared_log_error`, which computes
      the mean square error of the logarithmic transformation of targets,
      particularly useful for targets with an exponential trend.
@@ -5004,3 +5008,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
 .. _Vincent Pham: https://github.com/vincentpham1991
 
 .. _Denis Engemann: http://denis-engemann.de
+
+.. _Neeraj Gangwar: http://neerajgangwar.in
diff --git a/sklearn/model_selection/__init__.py b/sklearn/model_selection/__init__.py
index f5ab0d7526ccf503456130c0462b36dfc960f05c..73c842e706df8fc3d3ffd781475d8107af195013 100644
--- a/sklearn/model_selection/__init__.py
+++ b/sklearn/model_selection/__init__.py
@@ -7,6 +7,8 @@ from ._split import LeaveOneGroupOut
 from ._split import LeaveOneOut
 from ._split import LeavePGroupsOut
 from ._split import LeavePOut
+from ._split import RepeatedKFold
+from ._split import RepeatedStratifiedKFold
 from ._split import ShuffleSplit
 from ._split import GroupShuffleSplit
 from ._split import StratifiedShuffleSplit
@@ -36,6 +38,8 @@ __all__ = ('BaseCrossValidator',
            'LeaveOneOut',
            'LeavePGroupsOut',
            'LeavePOut',
+           'RepeatedKFold',
+           'RepeatedStratifiedKFold',
            'ParameterGrid',
            'ParameterSampler',
            'PredefinedSplit',
diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py
index b2ed060e31717e205ba47c51804daaceeb6d9c34..992c4f6d81e6a4ef4566339770d70f8e4f8e622d 100644
--- a/sklearn/model_selection/_split.py
+++ b/sklearn/model_selection/_split.py
@@ -41,6 +41,8 @@ __all__ = ['BaseCrossValidator',
            'LeaveOneOut',
            'LeavePGroupsOut',
            'LeavePOut',
+           'RepeatedStratifiedKFold',
+           'RepeatedKFold',
            'ShuffleSplit',
            'GroupShuffleSplit',
            'StratifiedKFold',
@@ -397,6 +399,8 @@ class KFold(_BaseKFold):
         classification tasks).
 
     GroupKFold: K-fold iterator variant with non-overlapping groups.
+
+    RepeatedKFold: Repeats K-Fold n times.
     """
 
     def __init__(self, n_splits=3, shuffle=False,
@@ -553,6 +557,9 @@ class StratifiedKFold(_BaseKFold):
     All the folds have size ``trunc(n_samples / n_splits)``, the last one has
     the complementary.
 
+    See also
+    --------
+    RepeatedStratifiedKFold: Repeats Stratified K-Fold n times.
     """
 
     def __init__(self, n_splits=3, shuffle=False, random_state=None):
@@ -913,6 +920,170 @@ class LeavePGroupsOut(BaseCrossValidator):
         return int(comb(len(np.unique(groups)), self.n_groups, exact=True))
 
 
+class _RepeatedSplits(with_metaclass(ABCMeta)):
+    """Repeated splits for an arbitrary randomized CV splitter.
+
+    Repeats splits for cross-validators n times with different randomization
+    in each repetition.
+
+    Parameters
+    ----------
+    cv : callable
+        Cross-validator class.
+
+    n_repeats : int, default=10
+        Number of times the cross-validator needs to be repeated.
+
+    random_state : None, int or RandomState, default=None
+        Random state used to generate a different randomization in each
+        repetition.
+
+    **cvargs : additional params
+        Constructor parameters for cv. Must not contain random_state
+        or shuffle.
+    """
+    def __init__(self, cv, n_repeats=10, random_state=None, **cvargs):
+        if not isinstance(n_repeats, (np.integer, numbers.Integral)):
+            raise ValueError("Number of repetitions must be of Integral type.")
+
+        if n_repeats <= 1:
+            raise ValueError("Number of repetitions must be greater than 1.")
+
+        if any(key in cvargs for key in ('random_state', 'shuffle')):
+            raise ValueError(
+                "cvargs must not contain random_state or shuffle.")
+
+        self.cv = cv
+        self.n_repeats = n_repeats
+        self.random_state = random_state
+        self.cvargs = cvargs
+
+    def split(self, X, y=None, groups=None):
+        """Generates indices to split data into training and test set.
+
+        Parameters
+        ----------
+        X : array-like, shape (n_samples, n_features)
+            Training data, where n_samples is the number of samples
+            and n_features is the number of features.
+
+        y : array-like, of length n_samples
+            The target variable for supervised learning problems.
+
+        groups : array-like, with shape (n_samples,), optional
+            Group labels for the samples used while splitting the dataset into
+            train/test set.
+
+        Returns
+        -------
+        train : ndarray
+            The training set indices for that split.
+
+        test : ndarray
+            The testing set indices for that split.
+        """
+        n_repeats = self.n_repeats
+        rng = check_random_state(self.random_state)
+
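+        # The same RandomState instance is shared by all repetitions: each
+        # constructed splitter draws a fresh shuffle from it, so every
+        # repetition produces different splits.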
+        for idx in range(n_repeats):
+            cv = self.cv(random_state=rng, shuffle=True,
+                         **self.cvargs)
+            for train_index, test_index in cv.split(X, y, groups):
+                yield train_index, test_index
+
+
+class RepeatedKFold(_RepeatedSplits):
+    """Repeated K-Fold cross validator.
+
+    Repeats K-Fold n times with different randomization in each repetition.
+
+    Read more in the :ref:`User Guide <cross_validation>`.
+
+    Parameters
+    ----------
+    n_splits : int, default=5
+        Number of folds. Must be at least 2.
+
+    n_repeats : int, default=10
+        Number of times the cross-validator needs to be repeated.
+
+    random_state : None, int or RandomState, default=None
+        Random state used to generate a different randomization in each
+        repetition.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from sklearn.model_selection import RepeatedKFold
+    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
+    >>> y = np.array([0, 0, 1, 1])
+    >>> rkf = RepeatedKFold(n_splits=2, n_repeats=2, random_state=2652124)
+    >>> for train_index, test_index in rkf.split(X):
+    ...     print("TRAIN:", train_index, "TEST:", test_index)
+    ...     X_train, X_test = X[train_index], X[test_index]
+    ...     y_train, y_test = y[train_index], y[test_index]
+    ...
+    TRAIN: [0 1] TEST: [2 3]
+    TRAIN: [2 3] TEST: [0 1]
+    TRAIN: [1 2] TEST: [0 3]
+    TRAIN: [0 3] TEST: [1 2]
+
+
+    See also
+    --------
+    RepeatedStratifiedKFold: Repeats Stratified K-Fold n times.
+    """
+    def __init__(self, n_splits=5, n_repeats=10, random_state=None):
+        super(RepeatedKFold, self).__init__(
+            KFold, n_repeats, random_state, n_splits=n_splits)
+
+
+class RepeatedStratifiedKFold(_RepeatedSplits):
+    """Repeated Stratified K-Fold cross validator.
+
+    Repeats Stratified K-Fold n times with different randomization in each
+    repetition.
+
+    Read more in the :ref:`User Guide <cross_validation>`.
+
+    Parameters
+    ----------
+    n_splits : int, default=5
+        Number of folds. Must be at least 2.
+
+    n_repeats : int, default=10
+        Number of times the cross-validator needs to be repeated.
+
+    random_state : None, int or RandomState, default=None
+        Random state used to generate a different randomization in each
+        repetition.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from sklearn.model_selection import RepeatedStratifiedKFold
+    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
+    >>> y = np.array([0, 0, 1, 1])
+    >>> rskf = RepeatedStratifiedKFold(n_splits=2, n_repeats=2,
+    ...     random_state=36851234)
+    >>> for train_index, test_index in rskf.split(X, y):
+    ...     print("TRAIN:", train_index, "TEST:", test_index)
+    ...     X_train, X_test = X[train_index], X[test_index]
+    ...     y_train, y_test = y[train_index], y[test_index]
+    ...
+    TRAIN: [1 2] TEST: [0 3]
+    TRAIN: [0 3] TEST: [1 2]
+    TRAIN: [1 3] TEST: [0 2]
+    TRAIN: [0 2] TEST: [1 3]
+
+
+    See also
+    --------
+    RepeatedKFold: Repeats K-Fold n times.
+    """
+    def __init__(self, n_splits=5, n_repeats=10, random_state=None):
+        super(RepeatedStratifiedKFold, self).__init__(
+            StratifiedKFold, n_repeats, random_state, n_splits=n_splits)
+
+
 class BaseShuffleSplit(with_metaclass(ABCMeta)):
     """Base class for ShuffleSplit and StratifiedShuffleSplit"""
 
diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py
index 601e9b259c537fb89d9fafc9614735dd8ee40b04..c997ac9d73e5d55ef88c1244995c52f811dbfe0f 100644
--- a/sklearn/model_selection/tests/test_split.py
+++ b/sklearn/model_selection/tests/test_split.py
@@ -42,6 +42,8 @@ from sklearn.model_selection import PredefinedSplit
 from sklearn.model_selection import check_cv
 from sklearn.model_selection import train_test_split
 from sklearn.model_selection import GridSearchCV
+from sklearn.model_selection import RepeatedKFold
+from sklearn.model_selection import RepeatedStratifiedKFold
 
 from sklearn.linear_model import Ridge
 
@@ -804,6 +806,76 @@ def test_leave_one_p_group_out_error_on_fewer_number_of_groups():
                          LeavePGroupsOut(n_groups=3).split(X, y, groups))
 
 
+def test_repeated_cv_value_errors():
+    # n_repeats is not integer or <= 1
+    for cv in (RepeatedKFold, RepeatedStratifiedKFold):
+        assert_raises(ValueError, cv, n_repeats=1)
+        assert_raises(ValueError, cv, n_repeats=1.5)
+
+
+def test_repeated_kfold_deterministic_split():
+    X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
+    random_state = 258173307
+    rkf = RepeatedKFold(
+        n_splits=2,
+        n_repeats=2,
+        random_state=random_state)
+
+    # split should produce the same, deterministic splits on
+    # each call
+    for _ in range(3):
+        splits = rkf.split(X)
+        train, test = next(splits)
+        assert_array_equal(train, [2, 4])
+        assert_array_equal(test, [0, 1, 3])
+
+        train, test = next(splits)
+        assert_array_equal(train, [0, 1, 3])
+        assert_array_equal(test, [2, 4])
+
+        train, test = next(splits)
+        assert_array_equal(train, [0, 1])
+        assert_array_equal(test, [2, 3, 4])
+
+        train, test = next(splits)
+        assert_array_equal(train, [2, 3, 4])
+        assert_array_equal(test, [0, 1])
+
+        assert_raises(StopIteration, next, splits)
+
+
+def test_repeated_stratified_kfold_deterministic_split():
+    X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
+    y = [1, 1, 1, 0, 0]
+    random_state = 1944695409
+    rskf = RepeatedStratifiedKFold(
+        n_splits=2,
+        n_repeats=2,
+        random_state=random_state)
+
+    # split should produce the same, deterministic splits on
+    # each call
+    for _ in range(3):
+        splits = rskf.split(X, y)
+        train, test = next(splits)
+        assert_array_equal(train, [1, 4])
+        assert_array_equal(test, [0, 2, 3])
+
+        train, test = next(splits)
+        assert_array_equal(train, [0, 2, 3])
+        assert_array_equal(test, [1, 4])
+
+        train, test = next(splits)
+        assert_array_equal(train, [2, 3])
+        assert_array_equal(test, [0, 1, 4])
+
+        train, test = next(splits)
+        assert_array_equal(train, [0, 1, 4])
+        assert_array_equal(test, [2, 3])
+
+        assert_raises(StopIteration, next, splits)
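+
+
+def test_repeated_cv_yields_n_splits_times_n_repeats():
+    # Sanity check with illustrative values: split() yields exactly
+    # n_splits * n_repeats (train, test) pairs for both repeated splitters.
+    X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]
+    y = [0, 0, 0, 1, 1, 1]
+    for cv in (RepeatedKFold(n_splits=3, n_repeats=2, random_state=0),
+               RepeatedStratifiedKFold(n_splits=3, n_repeats=2,
+                                       random_state=0)):
+        assert_equal(len(list(cv.split(X, y))), 6)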
+
+
 def test_train_test_split_errors():
     assert_raises(ValueError, train_test_split)
     assert_raises(ValueError, train_test_split, range(3), train_size=1.1)