diff --git a/doc/modules/cross_validation.rst b/doc/modules/cross_validation.rst
index 96baa81e8db27ece25db41c73d449b21cc9812ef..61b27774599d55c3b075ce7ea401e30bef0daaac 100644
--- a/doc/modules/cross_validation.rst
+++ b/doc/modules/cross_validation.rst
@@ -521,6 +521,50 @@ See also
 stratified splits, *i.e* which creates splits by preserving the same
 percentage for each target class as in the complete set.
 
+Cross validation of time series data
+====================================
+
+Time series data is characterised by the correlation between observations 
+that are near in time (*autocorrelation*). However, classical 
+cross-validation techniques such as :class:`KFold` and 
+:class:`ShuffleSplit` assume the samples are independent and 
+identically distributed, and would result in unreasonable correlation 
+between training and testing instances (yielding poor estimates of 
+generalisation error) on time series data. Therefore, it is very important 
+to evaluate our model for time series data on the "future" observations 
+least like those that are used to train the model. To achieve this, one 
+solution is provided by :class:`TimeSeriesCV`.
+
+
+TimeSeriesCV
+------------
+
+:class:`TimeSeriesCV` is a variation of *k-fold*: in the :math:`k` th split,
+it returns the first :math:`k` folds as train set and the :math:`(k+1)` th
+fold as test set. Note that unlike standard cross-validation methods,
+successive training sets are supersets of those that come before them.
+It also adds all surplus data to the first training partition, which
+is always used to train the model.
+
+This class can be used to cross-validate time series data samples 
+that are observed at fixed time intervals.
+
+Example of 3-split time series cross-validation on a dataset with 6 samples::
+
+  >>> import numpy as np
+  >>> from sklearn.model_selection import TimeSeriesCV
+
+  >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])
+  >>> y = np.array([1, 2, 3, 4, 5, 6])
+  >>> tscv = TimeSeriesCV(n_splits=3)
+  >>> print(tscv)  # doctest: +NORMALIZE_WHITESPACE
+  TimeSeriesCV(n_splits=3)
+  >>> for train, test in tscv.split(X):
+  ...     print("%s %s" % (train, test))
+  [0 1 2] [3]
+  [0 1 2 3] [4]
+  [0 1 2 3 4] [5]
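+
+The splitter returned by :class:`TimeSeriesCV` can also be passed as the
+``cv`` argument of evaluation utilities such as :func:`cross_val_score`.
+The following sketch reuses ``X`` and ``y`` from the example above and picks
+:class:`~sklearn.linear_model.Ridge` purely for illustration; any estimator
+appropriate for the data could be substituted::
+
+  >>> from sklearn.linear_model import Ridge
+  >>> from sklearn.model_selection import cross_val_score
+  >>> scores = cross_val_score(Ridge(), X, y, cv=TimeSeriesCV(n_splits=3))
+  >>> len(scores)  # one score per train/test split
+  3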
+
+
 A note on shuffling
 ===================
 
diff --git a/sklearn/model_selection/__init__.py b/sklearn/model_selection/__init__.py
index caf0fe70ff4935a3c9e2f81f1f2beb776cc3d3d6..7942ef3cc25e5c247544e41bb8091dee85998ebf 100644
--- a/sklearn/model_selection/__init__.py
+++ b/sklearn/model_selection/__init__.py
@@ -2,6 +2,7 @@ from ._split import BaseCrossValidator
 from ._split import KFold
 from ._split import LabelKFold
 from ._split import StratifiedKFold
+from ._split import TimeSeriesCV
 from ._split import LeaveOneLabelOut
 from ._split import LeaveOneOut
 from ._split import LeavePLabelOut
@@ -27,6 +28,7 @@ from ._search import fit_grid_point
 
 __all__ = ('BaseCrossValidator',
            'GridSearchCV',
+           'TimeSeriesCV',
            'KFold',
            'LabelKFold',
            'LabelShuffleSplit',
diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py
index 9491b788edb7113d5c976167e37de1c15d076e66..236e57af1435d36cd72ba41a7e1fdcdd4999cf14 100644
--- a/sklearn/model_selection/_split.py
+++ b/sklearn/model_selection/_split.py
@@ -635,6 +635,98 @@ class StratifiedKFold(_BaseKFold):
         return super(StratifiedKFold, self).split(X, y, labels)
 
 
+class TimeSeriesCV(_BaseKFold):
+    """Time Series cross-validator
+
+    Provides train/test indices to split time series data samples
+    that are observed at fixed time intervals, in train/test sets.
+    In each split, test indices must be higher than in the previous split,
+    and thus shuffling in this cross-validator is inappropriate.
+
+    This cross-validation object is a variation of :class:`KFold`.
+    In the kth split, it returns the first k folds as train set and the
+    (k+1)th fold as test set.
+
+    Note that unlike standard cross-validation methods, successive
+    training sets are supersets of those that come before them.
+
+    Read more in the :ref:`User Guide <cross_validation>`.
+
+    Parameters
+    ----------
+    n_splits : int, default=3
+        Number of splits. Must be at least 2.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from sklearn.model_selection import TimeSeriesCV
+    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
+    >>> y = np.array([1, 2, 3, 4])
+    >>> tscv = TimeSeriesCV(n_splits=3)
+    >>> print(tscv)  # doctest: +NORMALIZE_WHITESPACE
+    TimeSeriesCV(n_splits=3)
+    >>> for train_index, test_index in tscv.split(X):
+    ...    print("TRAIN:", train_index, "TEST:", test_index)
+    ...    X_train, X_test = X[train_index], X[test_index]
+    ...    y_train, y_test = y[train_index], y[test_index]
+    TRAIN: [0] TEST: [1]
+    TRAIN: [0 1] TEST: [2]
+    TRAIN: [0 1 2] TEST: [3]
+
+    Notes
+    -----
+    The training set has size ``i * n_samples // (n_splits + 1)
+    + n_samples % (n_splits + 1)`` in the ``i``th split,
+    with a test set of size ``n_samples // (n_splits + 1)``,
+    where ``n_samples`` is the number of samples.
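+
+    For example, with ``n_samples = 7`` and ``n_splits = 2`` each test set
+    has size ``7 // 3 = 2``, and the training sets have sizes ``3`` and
+    ``5`` (the surplus sample, ``7 % 3 = 1``, stays in the first training
+    set).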
+    """
+    def __init__(self, n_splits=3):
+        super(TimeSeriesCV, self).__init__(n_splits,
+                                           shuffle=False,
+                                           random_state=None)
+
+    def split(self, X, y=None, labels=None):
+        """Generate indices to split data into training and test set.
+
+        Parameters
+        ----------
+        X : array-like, shape (n_samples, n_features)
+            Training data, where n_samples is the number of samples
+            and n_features is the number of features.
+
+        y : array-like, shape (n_samples,)
+            The target variable for supervised learning problems.
+
+        labels : array-like, with shape (n_samples,), optional
+            Group labels for the samples used while splitting the dataset into
+            train/test set.
+
+        Returns
+        -------
+        train : ndarray
+            The training set indices for that split.
+
+        test : ndarray
+            The testing set indices for that split.
+        """
+        X, y, labels = indexable(X, y, labels)
+        n_samples = _num_samples(X)
+        n_splits = self.n_splits
+        n_folds = n_splits + 1
+        if n_folds > n_samples:
+            raise ValueError(
+                ("Cannot have number of folds ={0} greater"
+                 " than the number of samples: {1}.").format(n_folds,
+                                                             n_samples))
+        indices = np.arange(n_samples)
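+        # Each test set is a contiguous block of `test_size` samples; the
+        # surplus samples (n_samples % n_folds) are absorbed into the first
+        # training set so that every test set has the same size.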
+        test_size = (n_samples // n_folds)
+        test_starts = range(test_size + n_samples % n_folds,
+                            n_samples, test_size)
+        for test_start in test_starts:
+            yield (indices[:test_start],
+                   indices[test_start:test_start + test_size])
+
+
 class LeaveOneLabelOut(BaseCrossValidator):
     """Leave One Label Out cross-validator
 
diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py
index d9b4277aae38fd0d5e24725eb720d4434a3eca5b..ef43c15a17505ae1916b3b7f29953c9ac6e44839 100644
--- a/sklearn/model_selection/tests/test_split.py
+++ b/sklearn/model_selection/tests/test_split.py
@@ -30,6 +30,7 @@ from sklearn.model_selection import cross_val_score
 from sklearn.model_selection import KFold
 from sklearn.model_selection import StratifiedKFold
 from sklearn.model_selection import LabelKFold
+from sklearn.model_selection import TimeSeriesCV
 from sklearn.model_selection import LeaveOneOut
 from sklearn.model_selection import LeaveOneLabelOut
 from sklearn.model_selection import LeavePOut
@@ -997,6 +998,44 @@ def test_label_kfold():
                          next, LabelKFold(n_splits=3).split(X, y, labels))
 
 
+def test_time_series_cv():
+    X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14]]
+
+    # Should fail if there are more folds than samples
+    assert_raises_regexp(ValueError, "Cannot have number of folds.*greater",
+                         next,
+                         TimeSeriesCV(n_splits=7).split(X))
+
+    tscv = TimeSeriesCV(2)
+
+    # Manually check that Time Series CV preserves the data
+    # ordering on toy datasets
+    splits = tscv.split(X[:-1])
+    train, test = next(splits)
+    assert_array_equal(train, [0, 1])
+    assert_array_equal(test, [2, 3])
+
+    train, test = next(splits)
+    assert_array_equal(train, [0, 1, 2, 3])
+    assert_array_equal(test, [4, 5])
+
+    splits = TimeSeriesCV(2).split(X)
+
+    train, test = next(splits)
+    assert_array_equal(train, [0, 1, 2])
+    assert_array_equal(test, [3, 4])
+
+    train, test = next(splits)
+    assert_array_equal(train, [0, 1, 2, 3, 4])
+    assert_array_equal(test, [5, 6])
+
+    # Check get_n_splits returns the correct number of splits
+    splits = TimeSeriesCV(2).split(X)
+    n_splits_actual = len(list(splits))
+    assert_equal(n_splits_actual, tscv.get_n_splits())
+    assert_equal(n_splits_actual, 2)
+
+
 def test_nested_cv():
     # Test if nested cross validation works with different combinations of cv
     rng = np.random.RandomState(0)