diff --git a/doc/modules/cross_validation.rst b/doc/modules/cross_validation.rst
index 4b9a36e979d4dffb32339c181dbf22d45501dddd..187eb4020178d1e8c94772e2b2fd93b31fdcd6f5 100644
--- a/doc/modules/cross_validation.rst
+++ b/doc/modules/cross_validation.rst
@@ -464,7 +464,7 @@ In this case we would like to know if a model trained on a particular set of
 groups generalizes well to the unseen groups. To measure this, we need to
 ensure that all the samples in the validation fold come from groups that are
 not represented at all in the paired training fold.
- 
+
 The following cross-validation splitters can be used to do that.
 The grouping identifier for the samples is specified via the ``groups``
 parameter.
@@ -601,29 +601,29 @@ samples that are part of the validation set, and to -1 for all other samples.
 Cross validation of time series data
 ====================================
 
-Time series data is characterised by the correlation between observations 
-that are near in time (*autocorrelation*). However, classical 
-cross-validation techniques such as :class:`KFold` and 
-:class:`ShuffleSplit` assume the samples are independent and 
-identically distributed, and would result in unreasonable correlation 
-between training and testing instances (yielding poor estimates of 
-generalisation error) on time series data. Therefore, it is very important 
-to evaluate our model for time series data on the "future" observations 
-least like those that are used to train the model. To achieve this, one 
+Time series data is characterised by the correlation between observations
+that are near in time (*autocorrelation*). However, classical
+cross-validation techniques such as :class:`KFold` and
+:class:`ShuffleSplit` assume the samples are independent and
+identically distributed, and would result in unreasonable correlation
+between training and testing instances (yielding poor estimates of
+generalisation error) on time series data. Therefore, it is very important
+to evaluate our model for time series data on the "future" observations
+least like those that are used to train the model. To achieve this, one
 solution is provided by :class:`TimeSeriesSplit`.
 
 
 Time Series Split
 -----------------
 
-:class:`TimeSeriesSplit` is a variation of *k-fold* which 
-returns first :math:`k` folds as train set and the :math:`(k+1)` th 
-fold as test set. Note that unlike standard cross-validation methods, 
+:class:`TimeSeriesSplit` is a variation of *k-fold* which
+returns the first :math:`k` folds as the train set and the :math:`(k+1)` th
+fold as the test set. Note that unlike standard cross-validation methods,
 successive training sets are supersets of those that come before them.
 Also, it adds all surplus data to the first training partition, which
 is always used to train the model.
 
-This class can be used to cross-validate time series data samples 
+This class can be used to cross-validate time series data samples
 that are observed at fixed time intervals.
 
 Example of 3-split time series cross-validation on a dataset with 6 samples::
@@ -634,7 +634,7 @@ Example of 3-split time series cross-validation on a dataset with 6 samples::
   >>> y = np.array([1, 2, 3, 4, 5, 6])
   >>> tscv = TimeSeriesSplit(n_splits=3)
   >>> print(tscv)  # doctest: +NORMALIZE_WHITESPACE
-  TimeSeriesSplit(n_splits=3)
+  TimeSeriesSplit(max_train_size=None, n_splits=3)
   >>> for train, test in tscv.split(X):
   ...     print("%s %s" % (train, test))
   [0 1 2] [3]
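
For context, a minimal sketch of how the new ``max_train_size`` keyword bounds
the training window on the same 6-sample dataset. The expected output below is
traced by hand from the patched ``split`` logic rather than taken from a
doctest run, and it reuses ``X`` from the example above::

  >>> tscv = TimeSeriesSplit(n_splits=3, max_train_size=2)
  >>> for train, test in tscv.split(X):
  ...     print("%s %s" % (train, test))
  [1 2] [3]
  [2 3] [4]
  [3 4] [5]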
diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py
index d51487cd1cb4e6f3954ec9f0faf4b7790107f7e8..f58d1857aad0fa73cf935a5a78fe6653262ac9a9 100644
--- a/sklearn/model_selection/_split.py
+++ b/sklearn/model_selection/_split.py
@@ -673,6 +673,9 @@ class TimeSeriesSplit(_BaseKFold):
     n_splits : int, default=3
         Number of splits. Must be at least 1.
 
+    max_train_size : int, optional
+        Maximum size for a single training set.
+
     Examples
     --------
     >>> from sklearn.model_selection import TimeSeriesSplit
@@ -680,7 +683,7 @@ class TimeSeriesSplit(_BaseKFold):
     >>> y = np.array([1, 2, 3, 4])
     >>> tscv = TimeSeriesSplit(n_splits=3)
     >>> print(tscv)  # doctest: +NORMALIZE_WHITESPACE
-    TimeSeriesSplit(n_splits=3)
+    TimeSeriesSplit(max_train_size=None, n_splits=3)
     >>> for train_index, test_index in tscv.split(X):
     ...    print("TRAIN:", train_index, "TEST:", test_index)
     ...    X_train, X_test = X[train_index], X[test_index]
@@ -696,10 +699,11 @@ class TimeSeriesSplit(_BaseKFold):
     with a test set of size ``n_samples//(n_splits + 1)``,
     where ``n_samples`` is the number of samples.
     """
-    def __init__(self, n_splits=3):
+    def __init__(self, n_splits=3, max_train_size=None):
         super(TimeSeriesSplit, self).__init__(n_splits,
                                               shuffle=False,
                                               random_state=None)
+        self.max_train_size = max_train_size
 
     def split(self, X, y=None, groups=None):
         """Generate indices to split data into training and test set.
@@ -738,8 +742,12 @@ class TimeSeriesSplit(_BaseKFold):
         test_starts = range(test_size + n_samples % n_folds,
                             n_samples, test_size)
         for test_start in test_starts:
-            yield (indices[:test_start],
-                   indices[test_start:test_start + test_size])
+            if self.max_train_size and self.max_train_size < test_start:
+                yield (indices[test_start - self.max_train_size:test_start],
+                       indices[test_start:test_start + test_size])
+            else:
+                yield (indices[:test_start],
+                       indices[test_start:test_start + test_size])
 
 
 class LeaveOneGroupOut(BaseCrossValidator):
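
The patched loop above is easiest to read in isolation. Below is a minimal
standalone sketch of the same index arithmetic; the helper name
``_sketch_time_series_splits`` is illustrative and not part of the patch::

    import numpy as np

    def _sketch_time_series_splits(n_samples, n_splits, max_train_size=None):
        # Fixed-size test folds walk forward in time; the training window
        # grows from the start unless it is capped at max_train_size samples.
        indices = np.arange(n_samples)
        n_folds = n_splits + 1
        test_size = n_samples // n_folds
        test_starts = range(test_size + n_samples % n_folds,
                            n_samples, test_size)
        for test_start in test_starts:
            if max_train_size and max_train_size < test_start:
                train = indices[test_start - max_train_size:test_start]
            else:
                train = indices[:test_start]
            yield train, indices[test_start:test_start + test_size]

    for train, test in _sketch_time_series_splits(6, n_splits=3,
                                                  max_train_size=3):
        print(train, test)  # [0 1 2] [3], then [1 2 3] [4], then [2 3 4] [5]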
diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py
index 546609f413c1549939c02583cf3e3ad01c17097f..98a6d106721b39506ca0f81e1caf1f21db242265 100644
--- a/sklearn/model_selection/tests/test_split.py
+++ b/sklearn/model_selection/tests/test_split.py
@@ -1289,6 +1289,29 @@ def test_time_series_cv():
     assert_equal(n_splits_actual, 2)
 
 
+def _check_time_series_max_train_size(splits, check_splits, max_train_size):
+    for (train, test), (check_train, check_test) in zip(splits, check_splits):
+        assert_array_equal(test, check_test)
+        assert_true(len(check_train) <= max_train_size)
+        suffix_start = max(len(train) - max_train_size, 0)
+        assert_array_equal(check_train, train[suffix_start:])
+
+
+def test_time_series_max_train_size():
+    X = np.zeros((6, 1))
+    splits = list(TimeSeriesSplit(n_splits=3).split(X))  # reused below
+    check_splits = TimeSeriesSplit(n_splits=3, max_train_size=3).split(X)
+    _check_time_series_max_train_size(splits, check_splits, max_train_size=3)
+
+    # Test for the case where the size of a fold is greater than max_train_size
+    check_splits = TimeSeriesSplit(n_splits=3, max_train_size=2).split(X)
+    _check_time_series_max_train_size(splits, check_splits, max_train_size=2)
+
+    # Test for the case where the size of each fold is at most max_train_size
+    check_splits = TimeSeriesSplit(n_splits=3, max_train_size=5).split(X)
+    _check_time_series_max_train_size(splits, check_splits, max_train_size=5)
+
+
 def test_nested_cv():
     # Test if nested cross validation works with different combinations of cv
     rng = np.random.RandomState(0)
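
Outside the patch itself, the invariant that ``_check_time_series_max_train_size``
enforces can be stated in a few lines. This sketch assumes the patched
``TimeSeriesSplit`` (with the new ``max_train_size`` keyword) is importable::

    import numpy as np
    from sklearn.model_selection import TimeSeriesSplit

    X = np.zeros((6, 1))
    full = list(TimeSeriesSplit(n_splits=3).split(X))
    capped = list(TimeSeriesSplit(n_splits=3, max_train_size=3).split(X))
    for (train, test), (ctrain, ctest) in zip(full, capped):
        # Test folds are unchanged; train folds keep only the trailing window.
        assert np.array_equal(test, ctest)
        assert len(ctrain) <= 3
        assert np.array_equal(ctrain, train[max(len(train) - 3, 0):])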