From d7c956afc8c67642bbd42fd9bacb007b5fbc443a Mon Sep 17 00:00:00 2001
From: Raghav RV <rvraghav93@gmail.com>
Date: Thu, 3 Nov 2016 23:48:47 +0100
Subject: [PATCH] [MRG] FIX Validate and convert X, y and groups to ndarray
 before splitting (#7593)

---
 sklearn/model_selection/_split.py           |  57 ++++--
 sklearn/model_selection/tests/test_split.py | 195 ++++++++++++++++++--
 2 files changed, 215 insertions(+), 37 deletions(-)

diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py
index 642e8107e1..7d26a0d5b0 100644
--- a/sklearn/model_selection/_split.py
+++ b/sklearn/model_selection/_split.py
@@ -25,6 +25,7 @@ import numpy as np
 from scipy.misc import comb
 from ..utils import indexable, check_random_state, safe_indexing
 from ..utils.validation import _num_samples, column_or_1d
+from ..utils.validation import check_array
 from ..utils.multiclass import type_of_target
 from ..externals.six import with_metaclass
 from ..externals.six.moves import zip
@@ -472,6 +473,7 @@ class GroupKFold(_BaseKFold):
     def _iter_test_indices(self, X, y, groups):
         if groups is None:
             raise ValueError("The groups parameter should not be None")
+        groups = check_array(groups, ensure_2d=False, dtype=None)
 
         unique_groups, groups = np.unique(groups, return_inverse=True)
         n_groups = len(unique_groups)
@@ -618,12 +620,16 @@ class StratifiedKFold(_BaseKFold):
             Training data, where n_samples is the number of samples
             and n_features is the number of features.
 
+            Note that providing ``y`` is sufficient to generate the splits and
+            hence ``np.zeros(n_samples)`` may be used as a placeholder for
+            ``X`` instead of actual training data.
+
         y : array-like, shape (n_samples,)
             The target variable for supervised learning problems.
+            Stratification is done based on the y labels.
 
-        groups : array-like, with shape (n_samples,), optional
-            Group labels for the samples used while splitting the dataset into
-            train/test set.
+        groups : object
+            Always ignored, exists for compatibility.
 
         Returns
         -------
@@ -633,6 +639,7 @@ class StratifiedKFold(_BaseKFold):
         test : ndarray
             The testing set indices for that split.
         """
+        y = check_array(y, ensure_2d=False, dtype=None)
         return super(StratifiedKFold, self).split(X, y, groups)
 
 
@@ -696,11 +703,10 @@ class TimeSeriesSplit(_BaseKFold):
             and n_features is the number of features.
 
         y : array-like, shape (n_samples,)
-            The target variable for supervised learning problems.
+            Always ignored, exists for compatibility.
 
         groups : array-like, with shape (n_samples,), optional
-            Group labels for the samples used while splitting the dataset into
-            train/test set.
+            Always ignored, exists for compatibility.
 
         Returns
         -------
@@ -746,12 +752,12 @@ class LeaveOneGroupOut(BaseCrossValidator):
     >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
     >>> y = np.array([1, 2, 1, 2])
     >>> groups = np.array([1, 1, 2, 2])
-    >>> lol = LeaveOneGroupOut()
-    >>> lol.get_n_splits(X, y, groups)
+    >>> logo = LeaveOneGroupOut()
+    >>> logo.get_n_splits(X, y, groups)
     2
-    >>> print(lol)
+    >>> print(logo)
     LeaveOneGroupOut()
-    >>> for train_index, test_index in lol.split(X, y, groups):
+    >>> for train_index, test_index in logo.split(X, y, groups):
     ...    print("TRAIN:", train_index, "TEST:", test_index)
     ...    X_train, X_test = X[train_index], X[test_index]
     ...    y_train, y_test = y[train_index], y[test_index]
@@ -771,7 +777,7 @@ class LeaveOneGroupOut(BaseCrossValidator):
         if groups is None:
             raise ValueError("The groups parameter should not be None")
         # We make a copy of groups to avoid side-effects during iteration
-        groups = np.array(groups, copy=True)
+        groups = check_array(groups, copy=True, ensure_2d=False, dtype=None)
         unique_groups = np.unique(groups)
         if len(unique_groups) <= 1:
             raise ValueError(
@@ -833,12 +839,12 @@ class LeavePGroupsOut(BaseCrossValidator):
     >>> X = np.array([[1, 2], [3, 4], [5, 6]])
     >>> y = np.array([1, 2, 1])
     >>> groups = np.array([1, 2, 3])
-    >>> lpl = LeavePGroupsOut(n_groups=2)
-    >>> lpl.get_n_splits(X, y, groups)
+    >>> lpgo = LeavePGroupsOut(n_groups=2)
+    >>> lpgo.get_n_splits(X, y, groups)
     3
-    >>> print(lpl)
+    >>> print(lpgo)
     LeavePGroupsOut(n_groups=2)
-    >>> for train_index, test_index in lpl.split(X, y, groups):
+    >>> for train_index, test_index in lpgo.split(X, y, groups):
     ...    print("TRAIN:", train_index, "TEST:", test_index)
     ...    X_train, X_test = X[train_index], X[test_index]
     ...    y_train, y_test = y[train_index], y[test_index]
@@ -864,7 +870,7 @@ class LeavePGroupsOut(BaseCrossValidator):
     def _iter_test_masks(self, X, y, groups):
         if groups is None:
             raise ValueError("The groups parameter should not be None")
-        groups = np.array(groups, copy=True)
+        groups = check_array(groups, copy=True, ensure_2d=False, dtype=None)
         unique_groups = np.unique(groups)
         if self.n_groups >= len(unique_groups):
             raise ValueError(
@@ -886,9 +892,11 @@ class LeavePGroupsOut(BaseCrossValidator):
         ----------
         X : object
             Always ignored, exists for compatibility.
+            ``np.zeros(n_samples)`` may be used as a placeholder.
 
         y : object
             Always ignored, exists for compatibility.
+            ``np.zeros(n_samples)`` may be used as a placeholder.
 
         groups : array-like, with shape (n_samples,), optional
             Group labels for the samples used while splitting the dataset into
@@ -901,6 +909,8 @@ class LeavePGroupsOut(BaseCrossValidator):
         """
         if groups is None:
             raise ValueError("The groups parameter should not be None")
+        groups = check_array(groups, ensure_2d=False, dtype=None)
+        X, y, groups = indexable(X, y, groups)
         return int(comb(len(np.unique(groups)), self.n_groups, exact=True))
 
 
@@ -1097,6 +1107,7 @@ class GroupShuffleSplit(ShuffleSplit):
     def _iter_indices(self, X, y, groups):
         if groups is None:
             raise ValueError("The groups parameter should not be None")
+        groups = check_array(groups, ensure_2d=False, dtype=None)
         classes, group_indices = np.unique(groups, return_inverse=True)
         for group_train, group_test in super(
                 GroupShuffleSplit, self)._iter_indices(X=classes):
@@ -1237,6 +1248,7 @@ class StratifiedShuffleSplit(BaseShuffleSplit):
 
     def _iter_indices(self, X, y, groups=None):
         n_samples = _num_samples(X)
+        y = check_array(y, ensure_2d=False, dtype=None)
         n_train, n_test = _validate_shuffle_split(n_samples, self.test_size,
                                                   self.train_size)
         classes, y_indices = np.unique(y, return_inverse=True)
@@ -1290,12 +1302,16 @@ class StratifiedShuffleSplit(BaseShuffleSplit):
             Training data, where n_samples is the number of samples
             and n_features is the number of features.
 
+            Note that providing ``y`` is sufficient to generate the splits and
+            hence ``np.zeros(n_samples)`` may be used as a placeholder for
+            ``X`` instead of actual training data.
+
         y : array-like, shape (n_samples,)
             The target variable for supervised learning problems.
+            Stratification is done based on the y labels.
 
-        groups : array-like, with shape (n_samples,), optional
-            Group labels for the samples used while splitting the dataset into
-            train/test set.
+        groups : object
+            Always ignored, exists for compatibility.
 
         Returns
         -------
@@ -1305,6 +1321,7 @@ class StratifiedShuffleSplit(BaseShuffleSplit):
         test : ndarray
             The testing set indices for that split.
         """
+        y = check_array(y, ensure_2d=False, dtype=None)
         return super(StratifiedShuffleSplit, self).split(X, y, groups)
 
 
@@ -1613,7 +1630,7 @@ def train_test_split(*arrays, **options):
 
     stratify : array-like or None (default is None)
         If not None, data is split in a stratified fashion, using this as
-        the groups array.
+        the class labels.
 
     Returns
     -------
diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py
index 936abf03ac..660b0b1781 100644
--- a/sklearn/model_selection/tests/test_split.py
+++ b/sklearn/model_selection/tests/test_split.py
@@ -59,9 +59,80 @@ from sklearn.svm import SVC
 
 X = np.ones(10)
 y = np.arange(10) // 2
+P_sparse = coo_matrix(np.eye(5))
+test_groups = (
+    np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
+    np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
+    np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]),
+    np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4]),
+    [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3],
+    ['1', '1', '1', '1', '2', '2', '2', '3', '3', '3', '3', '3'])
 digits = load_digits()
 
 
+class MockClassifier(object):
+    """Dummy classifier to test the cross-validation"""
+
+    def __init__(self, a=0, allow_nd=False):
+        self.a = a
+        self.allow_nd = allow_nd
+
+    def fit(self, X, Y=None, sample_weight=None, class_prior=None,
+            sparse_sample_weight=None, sparse_param=None, dummy_int=None,
+            dummy_str=None, dummy_obj=None, callback=None):
+        """The dummy arguments are to test that this fit function can
+        accept non-array arguments through cross-validation, such as:
+            - int
+            - str (this is actually array-like)
+            - object
+            - function
+        """
+        self.dummy_int = dummy_int
+        self.dummy_str = dummy_str
+        self.dummy_obj = dummy_obj
+        if callback is not None:
+            callback(self)
+
+        if self.allow_nd:
+            X = X.reshape(len(X), -1)
+        if X.ndim >= 3 and not self.allow_nd:
+            raise ValueError('X cannot be d')
+        if sample_weight is not None:
+            assert_true(sample_weight.shape[0] == X.shape[0],
+                        'MockClassifier extra fit_param sample_weight.shape[0]'
+                        ' is {0}, should be {1}'.format(sample_weight.shape[0],
+                                                        X.shape[0]))
+        if class_prior is not None:
+            assert_true(class_prior.shape[0] == len(np.unique(y)),
+                        'MockClassifier extra fit_param class_prior.shape[0]'
+                        ' is {0}, should be {1}'.format(class_prior.shape[0],
+                                                        len(np.unique(y))))
+        if sparse_sample_weight is not None:
+            fmt = ('MockClassifier extra fit_param sparse_sample_weight'
+                   '.shape[0] is {0}, should be {1}')
+            assert_true(sparse_sample_weight.shape[0] == X.shape[0],
+                        fmt.format(sparse_sample_weight.shape[0], X.shape[0]))
+        if sparse_param is not None:
+            fmt = ('MockClassifier extra fit_param sparse_param.shape '
+                   'is ({0}, {1}), should be ({2}, {3})')
+            assert_true(sparse_param.shape == P_sparse.shape,
+                        fmt.format(sparse_param.shape[0],
+                                   sparse_param.shape[1],
+                                   P_sparse.shape[0], P_sparse.shape[1]))
+        return self
+
+    def predict(self, T):
+        if self.allow_nd:
+            T = T.reshape(len(T), -1)
+        return T[:, 0]
+
+    def score(self, X=None, Y=None):
+        return 1. / (1 + np.abs(self.a))
+
+    def get_params(self, deep=False):
+        return {'a': self.a, 'allow_nd': self.allow_nd}
+
+
 @ignore_warnings
 def test_cross_validator_with_default_params():
     n_samples = 4
@@ -263,6 +334,14 @@ def test_stratified_kfold_no_shuffle():
     # Check if get_n_splits returns the number of folds
     assert_equal(5, StratifiedKFold(5).get_n_splits(X, y))
 
+    # Make sure string labels are also supported
+    X = np.ones(7)
+    y1 = ['1', '1', '1', '0', '0', '0', '0']
+    y2 = [1, 1, 1, 0, 0, 0, 0]
+    np.testing.assert_equal(
+        list(StratifiedKFold(2).split(X, y1)),
+        list(StratifiedKFold(2).split(X, y2)))
+
 
 def test_stratified_kfold_ratios():
     # Check that stratified kfold preserves class ratios in individual splits
@@ -485,12 +564,15 @@ def test_stratified_shuffle_split_iter():
           np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] * 2),
           np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4]),
           np.array([-1] * 800 + [1] * 50),
-          np.concatenate([[i] * (100 + i) for i in range(11)])
+          np.concatenate([[i] * (100 + i) for i in range(11)]),
+          [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3],
+          ['1', '1', '1', '1', '2', '2', '2', '3', '3', '3', '3', '3'],
           ]
 
     for y in ys:
         sss = StratifiedShuffleSplit(6, test_size=0.33,
                                      random_state=0).split(np.ones(len(y)), y)
+        y = np.asanyarray(y)  # To make it indexable for y[train]
         # this is how test-size is computed internally
         # in _validate_shuffle_split
         test_size = np.ceil(0.33 * len(y))
@@ -598,13 +680,8 @@ def test_predefinedsplit_with_kfold_split():
 
 
 def test_group_shuffle_split():
-    groups = [np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
-              np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
-              np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]),
-              np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4])]
-
-    for l in groups:
-        X = y = np.ones(len(l))
+    for groups_i in test_groups:
+        X = y = np.ones(len(groups_i))
         n_splits = 6
         test_size = 1./3
         slo = GroupShuffleSplit(n_splits, test_size=test_size, random_state=0)
@@ -613,11 +690,12 @@ def test_group_shuffle_split():
         repr(slo)
 
         # Test that the length is correct
-        assert_equal(slo.get_n_splits(X, y, groups=l), n_splits)
+        assert_equal(slo.get_n_splits(X, y, groups=groups_i), n_splits)
 
-        l_unique = np.unique(l)
+        l_unique = np.unique(groups_i)
+        l = np.asarray(groups_i)
 
-        for train, test in slo.split(X, y, groups=l):
+        for train, test in slo.split(X, y, groups=groups_i):
             # First test: no train group is in the test set and vice versa
             l_train_unique = np.unique(l[train])
             l_test_unique = np.unique(l[test])
@@ -638,6 +716,46 @@ def test_group_shuffle_split():
                             round((1.0 - test_size) * len(l_unique))) <= 1)
 
 
+def test_leave_one_p_group_out():
+    logo = LeaveOneGroupOut()
+    lpgo_1 = LeavePGroupsOut(n_groups=1)
+    lpgo_2 = LeavePGroupsOut(n_groups=2)
+
+    # Make sure the repr works
+    assert_equal(repr(logo), 'LeaveOneGroupOut()')
+    assert_equal(repr(lpgo_1), 'LeavePGroupsOut(n_groups=1)')
+    assert_equal(repr(lpgo_2), 'LeavePGroupsOut(n_groups=2)')
+    assert_equal(repr(LeavePGroupsOut(n_groups=3)),
+                 'LeavePGroupsOut(n_groups=3)')
+
+    for j, (cv, p_groups_out) in enumerate(((logo, 1), (lpgo_1, 1),
+                                            (lpgo_2, 2))):
+        for i, groups_i in enumerate(test_groups):
+            n_groups = len(np.unique(groups_i))
+            n_splits = (n_groups if p_groups_out == 1
+                        else n_groups * (n_groups - 1) / 2)
+            X = y = np.ones(len(groups_i))
+
+            # Test that the length is correct
+            assert_equal(cv.get_n_splits(X, y, groups=groups_i), n_splits)
+
+            groups_arr = np.asarray(groups_i)
+
+            # Split using the original list / array / list of string groups_i
+            for train, test in cv.split(X, y, groups=groups_i):
+                # First test: no train group is in the test set and vice versa
+                assert_array_equal(np.intersect1d(groups_arr[train],
+                                                  groups_arr[test]).tolist(),
+                                   [])
+
+                # Second test: train and test add up to all the data
+                assert_equal(len(train) + len(test), len(groups_i))
+
+                # Third test:
+                # The number of groups in test must be equal to p_groups_out
+                assert_true(np.unique(groups_arr[test]).shape[0], p_groups_out)
+
+
 def test_leave_group_out_changing_groups():
     # Check that LeaveOneGroupOut and LeavePGroupsOut work normally if
     # the groups variable is changed before calling split
@@ -655,16 +773,17 @@ def test_leave_group_out_changing_groups():
             assert_array_equal(test, test_chan)
 
     # n_splits = no of 2 (p) group combinations of the unique groups = 3C2 = 3
-    assert_equal(3, LeavePGroupsOut(n_groups=2).get_n_splits(X, y, groups))
+    assert_equal(
+        3, LeavePGroupsOut(n_groups=2).get_n_splits(X, y=X,
+                                                    groups=groups))
     # n_splits = no of unique groups (C(uniq_lbls, 1) = n_unique_groups)
-    assert_equal(3, LeaveOneGroupOut().get_n_splits(X, y, groups))
+    assert_equal(3, LeaveOneGroupOut().get_n_splits(X, y=X,
+                                                    groups=groups))
 
 
 def test_leave_one_p_group_out_error_on_fewer_number_of_groups():
     X = y = groups = np.ones(0)
-    msg = ("The groups parameter contains fewer than 2 unique groups ([]). "
-           "LeaveOneGroupOut expects at least 2.")
-    assert_raise_message(ValueError, msg, next,
+    assert_raise_message(ValueError, "Found array with 0 sample(s)", next,
                          LeaveOneGroupOut().split(X, y, groups))
     X = y = groups = np.ones(1)
     msg = ("The groups parameter contains fewer than 2 unique groups ([ 1.]). "
@@ -780,6 +899,27 @@ def train_test_split_mock_pandas():
     X_train_arr, X_test_arr = train_test_split(X_df)
 
 
+def train_test_split_list_input():
+    # Check that when y is a list / list of string labels, it works.
+    X = np.ones(7)
+    y1 = ['1'] * 4 + ['0'] * 3
+    y2 = np.hstack((np.ones(4), np.zeros(3)))
+    y3 = y2.tolist()
+
+    for stratify in (True, False):
+        X_train1, X_test1, y_train1, y_test1 = train_test_split(
+            X, y1, stratify=y1 if stratify else None, random_state=0)
+        X_train2, X_test2, y_train2, y_test2 = train_test_split(
+            X, y2, stratify=y2 if stratify else None, random_state=0)
+        X_train3, X_test3, y_train3, y_test3 = train_test_split(
+            X, y3, stratify=y3 if stratify else None, random_state=0)
+
+        np.testing.assert_equal(X_train1, X_train2)
+        np.testing.assert_equal(y_train2, y_train3)
+        np.testing.assert_equal(X_test1, X_test3)
+        np.testing.assert_equal(y_test3, y_test2)
+
+
 def test_shufflesplit_errors():
     # When the {test|train}_size is a float/invalid, error is raised at init
     assert_raises(ValueError, ShuffleSplit, test_size=None, train_size=None)
@@ -804,6 +944,20 @@ def test_shufflesplit_reproducible():
                        list(a for a, b in ss.split(X)))
 
 
+def test_stratifiedshufflesplit_list_input():
+    # Check that when y is a list / list of string labels, it works.
+    sss = StratifiedShuffleSplit(test_size=2, random_state=42)
+    X = np.ones(7)
+    y1 = ['1'] * 4 + ['0'] * 3
+    y2 = np.hstack((np.ones(4), np.zeros(3)))
+    y3 = y2.tolist()
+
+    np.testing.assert_equal(list(sss.split(X, y1)),
+                            list(sss.split(X, y2)))
+    np.testing.assert_equal(list(sss.split(X, y3)),
+                            list(sss.split(X, y2)))
+
+
 def test_train_test_split_allow_nans():
     # Check that train_test_split allows input data with NaNs
     X = np.arange(200, dtype=np.float64).reshape(10, -1)
@@ -831,7 +985,7 @@ def test_check_cv():
 
     X = np.ones(5)
     y_multilabel = np.array([[0, 0, 0, 0], [0, 1, 1, 0], [0, 0, 0, 1],
-                            [1, 1, 0, 1], [0, 0, 1, 0]])
+                             [1, 1, 0, 1], [0, 0, 1, 0]])
     cv = check_cv(3, y_multilabel, classifier=True)
     np.testing.assert_equal(list(KFold(3).split(X)), list(cv.split(X)))
 
@@ -963,6 +1117,13 @@ def test_group_kfold():
     for train, test in lkf.split(X, y, groups):
         assert_equal(len(np.intersect1d(groups[train], groups[test])), 0)
 
+    # groups can also be a list
+    cv_iter = list(lkf.split(X, y, groups.tolist()))
+    for (train1, test1), (train2, test2) in zip(lkf.split(X, y, groups),
+                                                cv_iter):
+        assert_array_equal(train1, train2)
+        assert_array_equal(test1, test2)
+
     # Should fail if there are more folds than groups
     groups = np.array([1, 1, 1, 2, 2])
     X = y = np.ones(len(groups))
-- 
GitLab