diff --git a/doc/modules/cross_validation.rst b/doc/modules/cross_validation.rst
index 96baa81e8db27ece25db41c73d449b21cc9812ef..61b27774599d55c3b075ce7ea401e30bef0daaac 100644
--- a/doc/modules/cross_validation.rst
+++ b/doc/modules/cross_validation.rst
@@ -521,6 +521,50 @@ See also
 stratified splits, *i.e* which creates splits by preserving the same
 percentage for each target class as in the complete set.
 
+Cross validation of time series data
+====================================
+
+Time series data is characterised by the correlation between observations
+that are near in time (*autocorrelation*). However, classical
+cross-validation techniques such as :class:`KFold` and
+:class:`ShuffleSplit` assume the samples are independent and
+identically distributed, and would result in unreasonable correlation
+between training and testing instances (yielding poor estimates of
+generalisation error) on time series data. Therefore, it is important
+to evaluate a model for time series data on "future" observations that
+are least like those used to train it. To achieve this, one solution is
+provided by :class:`TimeSeriesCV`.
+
+
+TimeSeriesCV
+------------
+
+:class:`TimeSeriesCV` is a variation of *k-fold* which returns the
+first :math:`k` folds as train set and the :math:`(k+1)`-th fold
+as test set. Note that unlike standard cross-validation methods,
+successive training sets are supersets of those that come before them.
+Also, it adds all surplus data to the first training partition, which
+is always used to train the model.
+
+This class can be used to cross-validate time series data samples
+that are observed at fixed time intervals.
+
+Example of 3-split time series cross-validation on a dataset with 6 samples::
+
+  >>> from sklearn.model_selection import TimeSeriesCV
+
+  >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])
+  >>> y = np.array([1, 2, 3, 4, 5, 6])
+  >>> tscv = TimeSeriesCV(n_splits=3)
+  >>> print(tscv)  # doctest: +NORMALIZE_WHITESPACE
+  TimeSeriesCV(n_splits=3)
+  >>> for train, test in tscv.split(X):
+  ...     print("%s %s" % (train, test))
+  [0 1 2] [3]
+  [0 1 2 3] [4]
+  [0 1 2 3 4] [5]
+
+
 A note on shuffling
 ===================
 
diff --git a/sklearn/model_selection/__init__.py b/sklearn/model_selection/__init__.py
index caf0fe70ff4935a3c9e2f81f1f2beb776cc3d3d6..7942ef3cc25e5c247544e41bb8091dee85998ebf 100644
--- a/sklearn/model_selection/__init__.py
+++ b/sklearn/model_selection/__init__.py
@@ -2,6 +2,7 @@ from ._split import BaseCrossValidator
 from ._split import KFold
 from ._split import LabelKFold
 from ._split import StratifiedKFold
+from ._split import TimeSeriesCV
 from ._split import LeaveOneLabelOut
 from ._split import LeaveOneOut
 from ._split import LeavePLabelOut
@@ -27,6 +28,7 @@ from ._search import fit_grid_point
 
 __all__ = ('BaseCrossValidator',
            'GridSearchCV',
+           'TimeSeriesCV',
            'KFold',
            'LabelKFold',
            'LabelShuffleSplit',
diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py
index 9491b788edb7113d5c976167e37de1c15d076e66..236e57af1435d36cd72ba41a7e1fdcdd4999cf14 100644
--- a/sklearn/model_selection/_split.py
+++ b/sklearn/model_selection/_split.py
@@ -635,6 +635,98 @@ class StratifiedKFold(_BaseKFold):
         return super(StratifiedKFold, self).split(X, y, labels)
 
 
+class TimeSeriesCV(_BaseKFold):
+    """Time Series cross-validator
+
+    Provides train/test indices to split time series data samples
+    that are observed at fixed time intervals, in train/test sets.
+    In each split, test indices must be higher than in the previous split,
+    and thus shuffling within the cross-validator is inappropriate.
+
+    This cross-validation object is a variation of :class:`KFold`.
+    In the kth split, it returns the first k folds as the train set and
+    the (k+1)th fold as the test set.
+
+    Note that unlike standard cross-validation methods, successive
+    training sets are supersets of those that come before them.
+
+    Read more in the :ref:`User Guide <cross_validation>`.
+
+    Parameters
+    ----------
+    n_splits : int, default=3
+        Number of splits. Must be at least 1.
+
+    Examples
+    --------
+    >>> from sklearn.model_selection import TimeSeriesCV
+    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
+    >>> y = np.array([1, 2, 3, 4])
+    >>> tscv = TimeSeriesCV(n_splits=3)
+    >>> print(tscv)  # doctest: +NORMALIZE_WHITESPACE
+    TimeSeriesCV(n_splits=3)
+    >>> for train_index, test_index in tscv.split(X):
+    ...    print("TRAIN:", train_index, "TEST:", test_index)
+    ...    X_train, X_test = X[train_index], X[test_index]
+    ...    y_train, y_test = y[train_index], y[test_index]
+    TRAIN: [0] TEST: [1]
+    TRAIN: [0 1] TEST: [2]
+    TRAIN: [0 1 2] TEST: [3]
+
+    Notes
+    -----
+    The training set has size ``i * n_samples // (n_splits + 1)
+    + n_samples % (n_splits + 1)`` in the ``i``th split,
+    with a test set of size ``n_samples // (n_splits + 1)``,
+    where ``n_samples`` is the number of samples.
+    """
+    def __init__(self, n_splits=3):
+        super(TimeSeriesCV, self).__init__(n_splits,
+                                           shuffle=False,
+                                           random_state=None)
+
+    def split(self, X, y=None, labels=None):
+        """Generate indices to split data into training and test set.
+
+        Parameters
+        ----------
+        X : array-like, shape (n_samples, n_features)
+            Training data, where n_samples is the number of samples
+            and n_features is the number of features.
+
+        y : array-like, shape (n_samples,)
+            The target variable for supervised learning problems.
+
+        labels : array-like, with shape (n_samples,), optional
+            Group labels for the samples used while splitting the dataset into
+            train/test set.
+
+        Returns
+        -------
+        train : ndarray
+            The training set indices for that split.
+
+        test : ndarray
+            The testing set indices for that split.
+ """ + X, y, labels = indexable(X, y, labels) + n_samples = _num_samples(X) + n_splits = self.n_splits + n_folds = n_splits + 1 + if n_folds > n_samples: + raise ValueError( + ("Cannot have number of folds ={0} greater" + " than the number of samples: {1}.").format(n_folds, + n_samples)) + indices = np.arange(n_samples) + test_size = (n_samples // n_folds) + test_starts = range(test_size + n_samples % n_folds, + n_samples, test_size) + for test_start in test_starts: + yield (indices[:test_start], + indices[test_start:test_start + test_size]) + + class LeaveOneLabelOut(BaseCrossValidator): """Leave One Label Out cross-validator diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py index d9b4277aae38fd0d5e24725eb720d4434a3eca5b..ef43c15a17505ae1916b3b7f29953c9ac6e44839 100644 --- a/sklearn/model_selection/tests/test_split.py +++ b/sklearn/model_selection/tests/test_split.py @@ -30,6 +30,7 @@ from sklearn.model_selection import cross_val_score from sklearn.model_selection import KFold from sklearn.model_selection import StratifiedKFold from sklearn.model_selection import LabelKFold +from sklearn.model_selection import TimeSeriesCV from sklearn.model_selection import LeaveOneOut from sklearn.model_selection import LeaveOneLabelOut from sklearn.model_selection import LeavePOut @@ -997,6 +998,44 @@ def test_label_kfold(): next, LabelKFold(n_splits=3).split(X, y, labels)) +def test_time_series_cv(): + X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14]] + + # Should fail if there are more folds than samples + assert_raises_regexp(ValueError, "Cannot have number of folds.*greater", + next, + TimeSeriesCV(n_splits=7).split(X)) + + tscv = TimeSeriesCV(2) + + # Manually check that Time Series CV preserves the data + # ordering on toy datasets + splits = tscv.split(X[:-1]) + train, test = next(splits) + assert_array_equal(train, [0, 1]) + assert_array_equal(test, [2, 3]) + + train, test = next(splits) + assert_array_equal(train, [0, 1, 2, 3]) + assert_array_equal(test, [4, 5]) + + splits = TimeSeriesCV(2).split(X) + + train, test = next(splits) + assert_array_equal(train, [0, 1, 2]) + assert_array_equal(test, [3, 4]) + + train, test = next(splits) + assert_array_equal(train, [0, 1, 2, 3, 4]) + assert_array_equal(test, [5, 6]) + + # Check get_n_splits returns the correct number of splits + splits = TimeSeriesCV(2).split(X) + n_splits_actual = len(list(splits)) + assert_equal(n_splits_actual, tscv.get_n_splits()) + assert_equal(n_splits_actual, 2) + + def test_nested_cv(): # Test if nested cross validation works with different combinations of cv rng = np.random.RandomState(0)