diff --git a/doc/whats_new.rst b/doc/whats_new.rst index da2926c536017da414f6b92b28b48da143188785..f6ea4127c9bba273b664177e39fc471fb382283b 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -139,6 +139,11 @@ Bug fixes ``transform`` or ``predict_proba`` are called on the non-fitted estimator. by `Sebastian Raschka`_. + - Fixed bug in :class:`model_selection.StratifiedShuffleSplit` + where train and test sample could overlap in some edge cases, + see `#6121 <https://github.com/scikit-learn/scikit-learn/issues/6121>`_ for + more details. By `Loic Esteve`_. + API changes summary ------------------- diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py index 15843bc7269fd1e9c688fe29ab4d0c90e6381699..7c284b486eac1862d727d82dbb85374921da57bc 100644 --- a/sklearn/cross_validation.py +++ b/sklearn/cross_validation.py @@ -1010,14 +1010,19 @@ class StratifiedShuffleSplit(BaseShuffleSplit): # Because of rounding issues (as n_train and n_test are not # dividers of the number of elements per class), we may end # up here with less samples in train and test than asked for. - if len(train) < self.n_train or len(test) < self.n_test: + if len(train) + len(test) < self.n_train + self.n_test: # We complete by affecting randomly the missing indexes missing_idx = np.where(bincount(train + test, minlength=len(self.y)) == 0, )[0] missing_idx = rng.permutation(missing_idx) - train.extend(missing_idx[:(self.n_train - len(train))]) - test.extend(missing_idx[-(self.n_test - len(test)):]) + n_missing_train = self.n_train - len(train) + n_missing_test = self.n_test - len(test) + + if n_missing_train > 0: + train.extend(missing_idx[:n_missing_train]) + if n_missing_test > 0: + test.extend(missing_idx[-n_missing_test:]) train = rng.permutation(train) test = rng.permutation(test) diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index 6f97017625ca31c7cd85866fc05114c61d38a312..4bec6003a244f870275b5a1d212ce1d26d8756e7 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -1110,13 +1110,18 @@ class StratifiedShuffleSplit(BaseShuffleSplit): # Because of rounding issues (as n_train and n_test are not # dividers of the number of elements per class), we may end # up here with less samples in train and test than asked for. - if len(train) < n_train or len(test) < n_test: + if len(train) + len(test) < n_train + n_test: # We complete by affecting randomly the missing indexes missing_indices = np.where(bincount(train + test, minlength=len(y)) == 0)[0] missing_indices = rng.permutation(missing_indices) - train.extend(missing_indices[:(n_train - len(train))]) - test.extend(missing_indices[-(n_test - len(test)):]) + n_missing_train = n_train - len(train) + n_missing_test = n_test - len(test) + + if n_missing_train > 0: + train.extend(missing_indices[:n_missing_train]) + if n_missing_test > 0: + test.extend(missing_indices[-n_missing_test:]) train = rng.permutation(train) test = rng.permutation(test) diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py index 907ea5815583052b4c830df3c3ea2425525cbc1c..69749f8e4c0aaf1fc2e07908af2f1591eddfdedc 100644 --- a/sklearn/model_selection/tests/test_split.py +++ b/sklearn/model_selection/tests/test_split.py @@ -610,6 +610,20 @@ def test_stratified_shuffle_split_even(): assert_counts_are_ok(test_counts, ex_test_p) +def test_stratified_shuffle_split_overlap_train_test_bug(): + # See https://github.com/scikit-learn/scikit-learn/issues/6121 for + # the original bug report + y = [0, 1, 2, 3] * 3 + [4, 5] * 5 + X = np.ones_like(y) + + splits = StratifiedShuffleSplit(n_iter=1, + test_size=0.5, random_state=0) + + train, test = next(iter(splits.split(X=X, y=y))) + + assert_array_equal(np.intersect1d(train, test), []) + + def test_predefinedsplit_with_kfold_split(): # Check that PredefinedSplit can reproduce a split generated by Kfold. folds = -1 * np.ones(10) diff --git a/sklearn/tests/test_cross_validation.py b/sklearn/tests/test_cross_validation.py index 7886ea212793adda31ac7dbbaea74ed59f660854..027fe1ee458ed8f591ccd8726b204e4d590ede15 100644 --- a/sklearn/tests/test_cross_validation.py +++ b/sklearn/tests/test_cross_validation.py @@ -546,6 +546,18 @@ def test_stratified_shuffle_split_even(): assert_counts_are_ok(test_counts, ex_test_p) +def test_stratified_shuffle_split_overlap_train_test_bug(): + # See https://github.com/scikit-learn/scikit-learn/issues/6121 for + # the original bug report + labels = [0, 1, 2, 3] * 3 + [4, 5] * 5 + + splits = cval.StratifiedShuffleSplit(labels, n_iter=1, + test_size=0.5, random_state=0) + train, test = next(iter(splits)) + + assert_array_equal(np.intersect1d(train, test), []) + + def test_predefinedsplit_with_kfold_split(): # Check that PredefinedSplit can reproduce a split generated by Kfold. folds = -1 * np.ones(10)