diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index d4cd2537e524035343f011990e35e0ecc648a232..cc4eb7746578a6cc03a968deb665caaab5040aec 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -1478,6 +1478,11 @@ class StratifiedShuffleSplit(BaseShuffleSplit): y = check_array(y, ensure_2d=False, dtype=None) n_train, n_test = _validate_shuffle_split(n_samples, self.test_size, self.train_size) + + if y.ndim == 2: + # for multi-label y, map each distinct row to its string repr: + y = np.array([str(row) for row in y]) + classes, y_indices = np.unique(y, return_inverse=True) n_classes = classes.shape[0] diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py index 93f0dff1891d1c5098335ccbc4125dcf265b25c7..0135465e0ffd2720fdeec8bc74b03cca8eebaa0f 100644 --- a/sklearn/model_selection/tests/test_split.py +++ b/sklearn/model_selection/tests/test_split.py @@ -663,10 +663,37 @@ def test_stratified_shuffle_split_overlap_train_test_bug(): sss = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0) - train, test = next(iter(sss.split(X=X, y=y))) + train, test = next(sss.split(X=X, y=y)) + # no overlap assert_array_equal(np.intersect1d(train, test), []) + # complete partition + assert_array_equal(np.union1d(train, test), np.arange(len(y))) + + +def test_stratified_shuffle_split_multilabel(): + # fix for issue 9037 + for y in [np.array([[0, 1], [1, 0], [1, 0], [0, 1]]), + np.array([[0, 1], [1, 1], [1, 1], [0, 1]])]: + X = np.ones_like(y) + sss = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0) + train, test = next(sss.split(X=X, y=y)) + y_train = y[train] + y_test = y[test] + + # no overlap + assert_array_equal(np.intersect1d(train, test), []) + + # complete partition + assert_array_equal(np.union1d(train, test), np.arange(len(y))) + + # correct stratification of entire rows + # (by design, here y[:, 0] uniquely determines the entire row of y) + expected_ratio = np.mean(y[:, 0]) + assert_equal(expected_ratio, np.mean(y_train[:, 0])) + assert_equal(expected_ratio, np.mean(y_test[:, 0])) + def test_predefinedsplit_with_kfold_split(): # Check that PredefinedSplit can reproduce a split generated by Kfold.