diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index cf109e621626b9ac3e90601b653075d60cd46837..0064830c9a952d58b692abaa0d3fdf35ab9d5091 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -773,6 +773,10 @@ class LeaveOneGroupOut(BaseCrossValidator): # We make a copy of groups to avoid side-effects during iteration groups = np.array(groups, copy=True) unique_groups = np.unique(groups) + if len(unique_groups) <= 1: + raise ValueError( + "The groups parameter contains fewer than 2 unique groups " + "(%s). LeaveOneGroupOut expects at least 2." % unique_groups) for i in unique_groups: yield groups == i @@ -862,6 +866,12 @@ class LeavePGroupsOut(BaseCrossValidator): raise ValueError("The groups parameter should not be None") groups = np.array(groups, copy=True) unique_groups = np.unique(groups) + if self.n_groups >= len(unique_groups): + raise ValueError( + "The groups parameter contains fewer than (or equal to) " + "n_groups (%d) numbers of unique groups (%s). LeavePGroupsOut " + "expects that at least n_groups + 1 (%d) unique groups be " + "present" % (self.n_groups, unique_groups, self.n_groups + 1)) combi = combinations(range(len(unique_groups)), self.n_groups) for indices in combi: test_index = np.zeros(_num_samples(X), dtype=np.bool) diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py index 4dcd8f55038d8c29dbd8c28bb6ed973a1218f4dd..b547ac641556352e8270e1a2abd5916d389e80dc 100644 --- a/sklearn/model_selection/tests/test_split.py +++ b/sklearn/model_selection/tests/test_split.py @@ -724,6 +724,31 @@ def test_leave_group_out_changing_groups(): assert_equal(3, LeaveOneGroupOut().get_n_splits(X, y, groups)) +def test_leave_one_p_group_out_error_on_fewer_number_of_groups(): + X = y = groups = np.ones(0) + msg = ("The groups parameter contains fewer than 2 unique groups ([]). " + "LeaveOneGroupOut expects at least 2.") + assert_raise_message(ValueError, msg, next, + LeaveOneGroupOut().split(X, y, groups)) + X = y = groups = np.ones(1) + msg = ("The groups parameter contains fewer than 2 unique groups ([ 1.]). " + "LeaveOneGroupOut expects at least 2.") + assert_raise_message(ValueError, msg, next, + LeaveOneGroupOut().split(X, y, groups)) + X = y = groups = np.ones(1) + msg = ("The groups parameter contains fewer than (or equal to) n_groups " + "(3) numbers of unique groups ([ 1.]). LeavePGroupsOut expects " + "that at least n_groups + 1 (4) unique groups be present") + assert_raise_message(ValueError, msg, next, + LeavePGroupsOut(n_groups=3).split(X, y, groups)) + X = y = groups = np.arange(3) + msg = ("The groups parameter contains fewer than (or equal to) n_groups " + "(3) numbers of unique groups ([0 1 2]). LeavePGroupsOut expects " + "that at least n_groups + 1 (4) unique groups be present") + assert_raise_message(ValueError, msg, next, + LeavePGroupsOut(n_groups=3).split(X, y, groups)) + + def test_train_test_split_errors(): assert_raises(ValueError, train_test_split) assert_raises(ValueError, train_test_split, range(3), train_size=1.1)