From d7c956afc8c67642bbd42fd9bacb007b5fbc443a Mon Sep 17 00:00:00 2001
From: Raghav RV <rvraghav93@gmail.com>
Date: Thu, 3 Nov 2016 23:48:47 +0100
Subject: [PATCH] [MRG] FIX Validate and convert X, y and groups to ndarray
 before splitting (#7593)

---
 sklearn/model_selection/_split.py           |  57 ++++--
 sklearn/model_selection/tests/test_split.py | 195 ++++++++++++++++++--
 2 files changed, 215 insertions(+), 37 deletions(-)

diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py
index 642e8107e1..7d26a0d5b0 100644
--- a/sklearn/model_selection/_split.py
+++ b/sklearn/model_selection/_split.py
@@ -25,6 +25,7 @@ import numpy as np
 from scipy.misc import comb
 from ..utils import indexable, check_random_state, safe_indexing
 from ..utils.validation import _num_samples, column_or_1d
+from ..utils.validation import check_array
 from ..utils.multiclass import type_of_target
 from ..externals.six import with_metaclass
 from ..externals.six.moves import zip
@@ -472,6 +473,7 @@ class GroupKFold(_BaseKFold):
     def _iter_test_indices(self, X, y, groups):
         if groups is None:
             raise ValueError("The groups parameter should not be None")
+        groups = check_array(groups, ensure_2d=False, dtype=None)
         unique_groups, groups = np.unique(groups, return_inverse=True)
         n_groups = len(unique_groups)
@@ -618,12 +620,16 @@ class StratifiedKFold(_BaseKFold):
             Training data, where n_samples is the number of samples
             and n_features is the number of features.

+            Note that providing ``y`` is sufficient to generate the splits and
+            hence ``np.zeros(n_samples)`` may be used as a placeholder for
+            ``X`` instead of actual training data.
+
         y : array-like, shape (n_samples,)
             The target variable for supervised learning problems.
+            Stratification is done based on the y labels.

-        groups : array-like, with shape (n_samples,), optional
-            Group labels for the samples used while splitting the dataset into
-            train/test set.
+        groups : object
+            Always ignored, exists for compatibility.

         Returns
         -------
@@ -633,6 +639,7 @@ class StratifiedKFold(_BaseKFold):
         test : ndarray
             The testing set indices for that split.
         """
+        y = check_array(y, ensure_2d=False, dtype=None)
         return super(StratifiedKFold, self).split(X, y, groups)


@@ -696,11 +703,10 @@ class TimeSeriesSplit(_BaseKFold):
             and n_features is the number of features.

         y : array-like, shape (n_samples,)
-            The target variable for supervised learning problems.
+            Always ignored, exists for compatibility.

         groups : array-like, with shape (n_samples,), optional
-            Group labels for the samples used while splitting the dataset into
-            train/test set.
+            Always ignored, exists for compatibility.

         Returns
         -------
@@ -746,12 +752,12 @@ class LeaveOneGroupOut(BaseCrossValidator):
     >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
     >>> y = np.array([1, 2, 1, 2])
     >>> groups = np.array([1, 1, 2, 2])
-    >>> lol = LeaveOneGroupOut()
-    >>> lol.get_n_splits(X, y, groups)
+    >>> logo = LeaveOneGroupOut()
+    >>> logo.get_n_splits(X, y, groups)
     2
-    >>> print(lol)
+    >>> print(logo)
     LeaveOneGroupOut()
-    >>> for train_index, test_index in lol.split(X, y, groups):
+    >>> for train_index, test_index in logo.split(X, y, groups):
     ...    print("TRAIN:", train_index, "TEST:", test_index)
     ...    X_train, X_test = X[train_index], X[test_index]
     ...    y_train, y_test = y[train_index], y[test_index]
@@ -771,7 +777,7 @@ class LeaveOneGroupOut(BaseCrossValidator):
         if groups is None:
             raise ValueError("The groups parameter should not be None")
         # We make a copy of groups to avoid side-effects during iteration
-        groups = np.array(groups, copy=True)
+        groups = check_array(groups, copy=True, ensure_2d=False, dtype=None)
         unique_groups = np.unique(groups)
         if len(unique_groups) <= 1:
             raise ValueError(
@@ -833,12 +839,12 @@ class LeavePGroupsOut(BaseCrossValidator):
     >>> X = np.array([[1, 2], [3, 4], [5, 6]])
     >>> y = np.array([1, 2, 1])
     >>> groups = np.array([1, 2, 3])
-    >>> lpl = LeavePGroupsOut(n_groups=2)
-    >>> lpl.get_n_splits(X, y, groups)
+    >>> lpgo = LeavePGroupsOut(n_groups=2)
+    >>> lpgo.get_n_splits(X, y, groups)
     3
-    >>> print(lpl)
+    >>> print(lpgo)
     LeavePGroupsOut(n_groups=2)
-    >>> for train_index, test_index in lpl.split(X, y, groups):
+    >>> for train_index, test_index in lpgo.split(X, y, groups):
     ...    print("TRAIN:", train_index, "TEST:", test_index)
     ...    X_train, X_test = X[train_index], X[test_index]
     ...    y_train, y_test = y[train_index], y[test_index]
@@ -864,7 +870,7 @@ class LeavePGroupsOut(BaseCrossValidator):
     def _iter_test_masks(self, X, y, groups):
         if groups is None:
             raise ValueError("The groups parameter should not be None")
-        groups = np.array(groups, copy=True)
+        groups = check_array(groups, copy=True, ensure_2d=False, dtype=None)
         unique_groups = np.unique(groups)
         if self.n_groups >= len(unique_groups):
             raise ValueError(
@@ -886,9 +892,11 @@ class LeavePGroupsOut(BaseCrossValidator):
         ----------
         X : object
             Always ignored, exists for compatibility.
+            ``np.zeros(n_samples)`` may be used as a placeholder.

         y : object
             Always ignored, exists for compatibility.
+            ``np.zeros(n_samples)`` may be used as a placeholder.

         groups : array-like, with shape (n_samples,), optional
             Group labels for the samples used while splitting the dataset into
@@ -901,6 +909,8 @@ class LeavePGroupsOut(BaseCrossValidator):
         """
         if groups is None:
             raise ValueError("The groups parameter should not be None")
+        groups = check_array(groups, ensure_2d=False, dtype=None)
+        X, y, groups = indexable(X, y, groups)
         return int(comb(len(np.unique(groups)), self.n_groups, exact=True))
@@ -1097,6 +1107,7 @@ class GroupShuffleSplit(ShuffleSplit):
     def _iter_indices(self, X, y, groups):
         if groups is None:
             raise ValueError("The groups parameter should not be None")
+        groups = check_array(groups, ensure_2d=False, dtype=None)
         classes, group_indices = np.unique(groups, return_inverse=True)
         for group_train, group_test in super(
                 GroupShuffleSplit, self)._iter_indices(X=classes):
@@ -1237,6 +1248,7 @@ class StratifiedShuffleSplit(BaseShuffleSplit):
     def _iter_indices(self, X, y, groups=None):
         n_samples = _num_samples(X)
+        y = check_array(y, ensure_2d=False, dtype=None)
         n_train, n_test = _validate_shuffle_split(n_samples, self.test_size,
                                                   self.train_size)
         classes, y_indices = np.unique(y, return_inverse=True)
@@ -1290,12 +1302,16 @@ class StratifiedShuffleSplit(BaseShuffleSplit):
             Training data, where n_samples is the number of samples
             and n_features is the number of features.

+            Note that providing ``y`` is sufficient to generate the splits and
+            hence ``np.zeros(n_samples)`` may be used as a placeholder for
+            ``X`` instead of actual training data.
+
         y : array-like, shape (n_samples,)
             The target variable for supervised learning problems.
+            Stratification is done based on the y labels.
-        groups : array-like, with shape (n_samples,), optional
-            Group labels for the samples used while splitting the dataset into
-            train/test set.
+        groups : object
+            Always ignored, exists for compatibility.

         Returns
         -------
@@ -1305,6 +1321,7 @@ class StratifiedShuffleSplit(BaseShuffleSplit):
         test : ndarray
             The testing set indices for that split.
         """
+        y = check_array(y, ensure_2d=False, dtype=None)
         return super(StratifiedShuffleSplit, self).split(X, y, groups)


@@ -1613,7 +1630,7 @@ def train_test_split(*arrays, **options):

     stratify : array-like or None (default is None)
         If not None, data is split in a stratified fashion, using this as
-        the groups array.
+        the class labels.

     Returns
     -------
diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py
index 936abf03ac..660b0b1781 100644
--- a/sklearn/model_selection/tests/test_split.py
+++ b/sklearn/model_selection/tests/test_split.py
@@ -59,9 +59,80 @@ from sklearn.svm import SVC

 X = np.ones(10)
 y = np.arange(10) // 2
+P_sparse = coo_matrix(np.eye(5))
+test_groups = (
+    np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
+    np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
+    np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]),
+    np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4]),
+    [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3],
+    ['1', '1', '1', '1', '2', '2', '2', '3', '3', '3', '3', '3'])
 digits = load_digits()


+class MockClassifier(object):
+    """Dummy classifier to test the cross-validation"""
+
+    def __init__(self, a=0, allow_nd=False):
+        self.a = a
+        self.allow_nd = allow_nd
+
+    def fit(self, X, Y=None, sample_weight=None, class_prior=None,
+            sparse_sample_weight=None, sparse_param=None, dummy_int=None,
+            dummy_str=None, dummy_obj=None, callback=None):
+        """The dummy arguments are to test that this fit function can
+        accept non-array arguments through cross-validation, such as:
+            - int
+            - str (this is actually array-like)
+            - object
+            - function
+        """
+        self.dummy_int = dummy_int
+        self.dummy_str = dummy_str
+        self.dummy_obj = dummy_obj
+        if callback is not None:
+            callback(self)
+
+        if self.allow_nd:
+            X = X.reshape(len(X), -1)
+        if X.ndim >= 3 and not self.allow_nd:
+            raise ValueError('X cannot be d')
+        if sample_weight is not None:
+            assert_true(sample_weight.shape[0] == X.shape[0],
+                        'MockClassifier extra fit_param sample_weight.shape[0]'
+                        ' is {0}, should be {1}'.format(sample_weight.shape[0],
+                                                        X.shape[0]))
+        if class_prior is not None:
+            assert_true(class_prior.shape[0] == len(np.unique(y)),
+                        'MockClassifier extra fit_param class_prior.shape[0]'
+                        ' is {0}, should be {1}'.format(class_prior.shape[0],
+                                                        len(np.unique(y))))
+        if sparse_sample_weight is not None:
+            fmt = ('MockClassifier extra fit_param sparse_sample_weight'
+                   '.shape[0] is {0}, should be {1}')
+            assert_true(sparse_sample_weight.shape[0] == X.shape[0],
+                        fmt.format(sparse_sample_weight.shape[0], X.shape[0]))
+        if sparse_param is not None:
+            fmt = ('MockClassifier extra fit_param sparse_param.shape '
+                   'is ({0}, {1}), should be ({2}, {3})')
+            assert_true(sparse_param.shape == P_sparse.shape,
+                        fmt.format(sparse_param.shape[0],
+                                   sparse_param.shape[1],
+                                   P_sparse.shape[0], P_sparse.shape[1]))
+        return self
+
+    def predict(self, T):
+        if self.allow_nd:
+            T = T.reshape(len(T), -1)
+        return T[:, 0]
+
+    def score(self, X=None, Y=None):
+        return 1. / (1 + np.abs(self.a))
+
+    def get_params(self, deep=False):
+        return {'a': self.a, 'allow_nd': self.allow_nd}
+
+
 @ignore_warnings
 def test_cross_validator_with_default_params():
     n_samples = 4
@@ -263,6 +334,14 @@ def test_stratified_kfold_no_shuffle():
     # Check if get_n_splits returns the number of folds
     assert_equal(5, StratifiedKFold(5).get_n_splits(X, y))

+    # Make sure string labels are also supported
+    X = np.ones(7)
+    y1 = ['1', '1', '1', '0', '0', '0', '0']
+    y2 = [1, 1, 1, 0, 0, 0, 0]
+    np.testing.assert_equal(
+        list(StratifiedKFold(2).split(X, y1)),
+        list(StratifiedKFold(2).split(X, y2)))
+

 def test_stratified_kfold_ratios():
     # Check that stratified kfold preserves class ratios in individual splits
@@ -485,12 +564,15 @@ def test_stratified_shuffle_split_iter():
           np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] * 2),
           np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4]),
          np.array([-1] * 800 + [1] * 50),
-          np.concatenate([[i] * (100 + i) for i in range(11)])
+          np.concatenate([[i] * (100 + i) for i in range(11)]),
+          [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3],
+          ['1', '1', '1', '1', '2', '2', '2', '3', '3', '3', '3', '3'],
           ]

     for y in ys:
         sss = StratifiedShuffleSplit(6, test_size=0.33,
                                      random_state=0).split(np.ones(len(y)), y)
+        y = np.asanyarray(y)  # To make it indexable for y[train]
         # this is how test-size is computed internally
         # in _validate_shuffle_split
         test_size = np.ceil(0.33 * len(y))
@@ -598,13 +680,8 @@ def test_predefinedsplit_with_kfold_split():


 def test_group_shuffle_split():
-    groups = [np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
-              np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
-              np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]),
-              np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4])]
-
-    for l in groups:
-        X = y = np.ones(len(l))
+    for groups_i in test_groups:
+        X = y = np.ones(len(groups_i))
         n_splits = 6
         test_size = 1./3
         slo = GroupShuffleSplit(n_splits, test_size=test_size, random_state=0)
@@ -613,11 +690,12 @@ def test_group_shuffle_split():
         repr(slo)

         # Test that the length is correct
-        assert_equal(slo.get_n_splits(X, y, groups=l), n_splits)
+        assert_equal(slo.get_n_splits(X, y, groups=groups_i), n_splits)

-        l_unique = np.unique(l)
+        l_unique = np.unique(groups_i)
+        l = np.asarray(groups_i)

-        for train, test in slo.split(X, y, groups=l):
+        for train, test in slo.split(X, y, groups=groups_i):
             # First test: no train group is in the test set and vice versa
             l_train_unique = np.unique(l[train])
             l_test_unique = np.unique(l[test])
@@ -638,6 +716,46 @@ def test_group_shuffle_split():
                         round((1.0 - test_size) * len(l_unique))) <= 1)


+def test_leave_one_p_group_out():
+    logo = LeaveOneGroupOut()
+    lpgo_1 = LeavePGroupsOut(n_groups=1)
+    lpgo_2 = LeavePGroupsOut(n_groups=2)
+
+    # Make sure the repr works
+    assert_equal(repr(logo), 'LeaveOneGroupOut()')
+    assert_equal(repr(lpgo_1), 'LeavePGroupsOut(n_groups=1)')
+    assert_equal(repr(lpgo_2), 'LeavePGroupsOut(n_groups=2)')
+    assert_equal(repr(LeavePGroupsOut(n_groups=3)),
+                 'LeavePGroupsOut(n_groups=3)')
+
+    for j, (cv, p_groups_out) in enumerate(((logo, 1), (lpgo_1, 1),
+                                            (lpgo_2, 2))):
+        for i, groups_i in enumerate(test_groups):
+            n_groups = len(np.unique(groups_i))
+            n_splits = (n_groups if p_groups_out == 1
+                        else n_groups * (n_groups - 1) / 2)
+            X = y = np.ones(len(groups_i))
+
+            # Test that the length is correct
+            assert_equal(cv.get_n_splits(X, y, groups=groups_i), n_splits)
+
+            groups_arr = np.asarray(groups_i)
+
+            # Split using the original list / array / list of string groups_i
+            for train, test in cv.split(X, y, groups=groups_i):
+                # First test: no train group is in the test set and vice versa
+                assert_array_equal(np.intersect1d(groups_arr[train],
+                                                  groups_arr[test]).tolist(),
+                                   [])
+
+                # Second test: train and test add up to all the data
+                assert_equal(len(train) + len(test), len(groups_i))
+
+                # Third test:
+                # The number of groups in test must be equal to p_groups_out
+                assert_equal(np.unique(groups_arr[test]).shape[0],
+                             p_groups_out)
+
+
 def test_leave_group_out_changing_groups():
     # Check that LeaveOneGroupOut and LeavePGroupsOut work normally if
     # the groups variable is changed before calling split
@@ -655,16 +773,17 @@ def test_leave_group_out_changing_groups():
         assert_array_equal(test, test_chan)

     # n_splits = no of 2 (p) group combinations of the unique groups = 3C2 = 3
-    assert_equal(3, LeavePGroupsOut(n_groups=2).get_n_splits(X, y, groups))
+    assert_equal(
+        3, LeavePGroupsOut(n_groups=2).get_n_splits(X, y=X,
+                                                    groups=groups))
     # n_splits = no of unique groups (C(uniq_lbls, 1) = n_unique_groups)
-    assert_equal(3, LeaveOneGroupOut().get_n_splits(X, y, groups))
+    assert_equal(3, LeaveOneGroupOut().get_n_splits(X, y=X,
+                                                    groups=groups))


 def test_leave_one_p_group_out_error_on_fewer_number_of_groups():
     X = y = groups = np.ones(0)
-    msg = ("The groups parameter contains fewer than 2 unique groups ([]). "
-           "LeaveOneGroupOut expects at least 2.")
-    assert_raise_message(ValueError, msg, next,
+    assert_raise_message(ValueError, "Found array with 0 sample(s)", next,
                          LeaveOneGroupOut().split(X, y, groups))
     X = y = groups = np.ones(1)
     msg = ("The groups parameter contains fewer than 2 unique groups ([ 1.]). "
@@ -780,6 +899,27 @@ def train_test_split_mock_pandas():
     X_train_arr, X_test_arr = train_test_split(X_df)


+def train_test_split_list_input():
+    # Check that when y is a list / list of string labels, it works.
+    X = np.ones(7)
+    y1 = ['1'] * 4 + ['0'] * 3
+    y2 = np.hstack((np.ones(4), np.zeros(3)))
+    y3 = y2.tolist()
+
+    for stratify in (True, False):
+        X_train1, X_test1, y_train1, y_test1 = train_test_split(
+            X, y1, stratify=y1 if stratify else None, random_state=0)
+        X_train2, X_test2, y_train2, y_test2 = train_test_split(
+            X, y2, stratify=y2 if stratify else None, random_state=0)
+        X_train3, X_test3, y_train3, y_test3 = train_test_split(
+            X, y3, stratify=y3 if stratify else None, random_state=0)
+
+        np.testing.assert_equal(X_train1, X_train2)
+        np.testing.assert_equal(y_train2, y_train3)
+        np.testing.assert_equal(X_test1, X_test3)
+        np.testing.assert_equal(y_test3, y_test2)
+
+
 def test_shufflesplit_errors():
     # When the {test|train}_size is a float/invalid, error is raised at init
     assert_raises(ValueError, ShuffleSplit, test_size=None, train_size=None)
@@ -804,6 +944,20 @@ def test_shufflesplit_reproducible():
         list(a for a, b in ss.split(X)))


+def test_stratifiedshufflesplit_list_input():
+    # Check that when y is a list / list of string labels, it works.
+    sss = StratifiedShuffleSplit(test_size=2, random_state=42)
+    X = np.ones(7)
+    y1 = ['1'] * 4 + ['0'] * 3
+    y2 = np.hstack((np.ones(4), np.zeros(3)))
+    y3 = y2.tolist()
+
+    np.testing.assert_equal(list(sss.split(X, y1)),
+                            list(sss.split(X, y2)))
+    np.testing.assert_equal(list(sss.split(X, y3)),
+                            list(sss.split(X, y2)))
+
+
 def test_train_test_split_allow_nans():
     # Check that train_test_split allows input data with NaNs
     X = np.arange(200, dtype=np.float64).reshape(10, -1)
@@ -831,7 +985,7 @@ def test_check_cv():
     X = np.ones(5)
     y_multilabel = np.array([[0, 0, 0, 0], [0, 1, 1, 0], [0, 0, 0, 1],
-                            [1, 1, 0, 1], [0, 0, 1, 0]])
+                             [1, 1, 0, 1], [0, 0, 1, 0]])

     cv = check_cv(3, y_multilabel, classifier=True)
     np.testing.assert_equal(list(KFold(3).split(X)), list(cv.split(X)))
@@ -963,6 +1117,13 @@ def test_group_kfold():
     for train, test in lkf.split(X, y, groups):
         assert_equal(len(np.intersect1d(groups[train], groups[test])), 0)

+    # groups can also be a list
+    cv_iter = list(lkf.split(X, y, groups.tolist()))
+    for (train1, test1), (train2, test2) in zip(lkf.split(X, y, groups),
+                                                cv_iter):
+        assert_array_equal(train1, train2)
+        assert_array_equal(test1, test2)
+
     # Should fail if there are more folds than groups
     groups = np.array([1, 1, 1, 2, 2])
     X = y = np.ones(len(groups))
-- 
GitLab
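
The core of the fix is the ``check_array(..., ensure_2d=False, dtype=None)`` call added before each ``np.unique``: with ``ensure_2d=False`` a 1-D input is accepted, and with ``dtype=None`` the input's own dtype (e.g. a string dtype) is preserved rather than coerced to float. A minimal sketch of that conversion in isolation -- the group labels below are made up for illustration; only the ``check_array`` call mirrors the patch:

    import numpy as np
    from sklearn.utils.validation import check_array

    groups = ['a', 'a', 'b', 'b', 'c', 'c']  # hypothetical list input
    groups = check_array(groups, ensure_2d=False, dtype=None)
    # groups is now a 1-D ndarray of dtype '<U1'; the string dtype is
    # kept, so np.unique(groups) and fancy indexing behave as expected.
    print(groups.dtype, np.unique(groups))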
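A short usage sketch of the behaviour the new tests pin down, assuming a scikit-learn release that contains this change (0.18.1 or later); the data is illustrative:

    import numpy as np
    from sklearn.model_selection import StratifiedKFold, LeaveOneGroupOut

    X = np.ones(7)  # placeholder; y alone determines the stratified splits
    y_str = ['1', '1', '1', '0', '0', '0', '0']
    y_int = [1, 1, 1, 0, 0, 0, 0]
    # String and integer labels now yield identical splits.
    for (tr_s, te_s), (tr_i, te_i) in zip(StratifiedKFold(2).split(X, y_str),
                                          StratifiedKFold(2).split(X, y_int)):
        assert (tr_s == tr_i).all() and (te_s == te_i).all()

    # groups may likewise be a plain list of strings.
    X = y = np.ones(6)
    groups = ['a', 'a', 'b', 'b', 'c', 'c']
    for train, test in LeaveOneGroupOut().split(X, y, groups=groups):
        print("TRAIN:", train, "TEST:", test)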