diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index f58d1857aad0fa73cf935a5a78fe6653262ac9a9..49b7874facf2aa7397aa4747127b91a5f2a5af18 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -1450,7 +1450,6 @@ class StratifiedShuffleSplit(BaseShuffleSplit): If None, the random number generator is the RandomState instance used by `np.random`. - Examples -------- >>> from sklearn.model_selection import StratifiedShuffleSplit @@ -1860,6 +1859,10 @@ def train_test_split(*arrays, **options): If None, the random number generator is the RandomState instance used by `np.random`. + shuffle : boolean, optional (default=True) + Whether or not to shuffle the data before splitting. If shuffle=False + then stratify must be None. + stratify : array-like or None (default is None) If not None, data is split in a stratified fashion, using this as the class labels. @@ -1903,6 +1906,9 @@ def train_test_split(*arrays, **options): >>> y_test [1, 4] + >>> train_test_split(y, shuffle=False) + [[0, 1, 2], [3, 4]] + """ n_arrays = len(arrays) if n_arrays == 0: @@ -1911,6 +1917,7 @@ def train_test_split(*arrays, **options): train_size = options.pop('train_size', None) random_state = options.pop('random_state', None) stratify = options.pop('stratify', None) + shuffle = options.pop('shuffle', True) if options: raise TypeError("Invalid parameters passed: %s" % str(options)) @@ -1920,22 +1927,38 @@ def train_test_split(*arrays, **options): arrays = indexable(*arrays) - if stratify is not None: - CVClass = StratifiedShuffleSplit + if shuffle is False: + if stratify is not None: + raise NotImplementedError( + "Stratified train/test split is not implemented for " + "shuffle=False") + + n_samples = _num_samples(arrays[0]) + n_train, n_test = _validate_shuffle_split(n_samples, test_size, + train_size) + + train = np.arange(n_train) + test = np.arange(n_train, n_train + n_test) + else: - CVClass = ShuffleSplit + if stratify is not None: + CVClass = StratifiedShuffleSplit + else: + CVClass = ShuffleSplit - cv = CVClass(test_size=test_size, - train_size=train_size, - random_state=random_state) + cv = CVClass(test_size=test_size, + train_size=train_size, + random_state=random_state) + + train, test = next(cv.split(X=arrays[0], y=stratify)) - train, test = next(cv.split(X=arrays[0], y=stratify)) return list(chain.from_iterable((safe_indexing(a, train), safe_indexing(a, test)) for a in arrays)) train_test_split.__test__ = False # to avoid a pb with nosetests + def _build_repr(self): # XXX This is copied from BaseEstimator's get_params cls = self.__class__ diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py index 98a6d106721b39506ca0f81e1caf1f21db242265..d6efff7b2b0fc834fc51fe9e9e7921dc997cca8e 100644 --- a/sklearn/model_selection/tests/test_split.py +++ b/sklearn/model_selection/tests/test_split.py @@ -931,6 +931,8 @@ def test_train_test_split_errors(): assert_raises(TypeError, train_test_split, range(3), some_argument=1.1) assert_raises(ValueError, train_test_split, range(3), range(42)) + assert_raises(NotImplementedError, train_test_split, range(10), + shuffle=False, stratify=True) def test_train_test_split(): @@ -973,6 +975,13 @@ def test_train_test_split(): # check the 1:1 ratio of ones and twos in the data is preserved assert_equal(np.sum(train == 1), np.sum(train == 2)) + # test unshuffled split + y = np.arange(10) + for test_size in [2, 0.2]: + train, test = train_test_split(y, shuffle=False, test_size=test_size) + assert_array_equal(test, [8, 9]) + assert_array_equal(train, [0, 1, 2, 3, 4, 5, 6, 7]) + @ignore_warnings def train_test_split_pandas():