From 73d3f03cfc83aa35edb89173a8450c4059000fce Mon Sep 17 00:00:00 2001 From: polmauri <polmauri@gmail.com> Date: Tue, 25 Oct 2016 08:42:39 -0700 Subject: [PATCH] [MRG + 1] FIX raise an error message when n_groups > number of groups (#7681) (#7683) * FIX raise an error message when n_groups > actual number of groups (#7681) This change addresses issue #7681: - Raise ValueError when n_groups > actual number of unique groups in LeaveOneGroupOut and LeavePGroupsOut. - Add unit test. * Make requested changes - Check error message with `assert_raise_message` - Pass parameters to `assert_raise_message` instead of defining functions * Update condition and exception message --- sklearn/model_selection/_split.py | 10 +++++++++ sklearn/model_selection/tests/test_split.py | 25 +++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index cf109e6216..0064830c9a 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -773,6 +773,10 @@ class LeaveOneGroupOut(BaseCrossValidator): # We make a copy of groups to avoid side-effects during iteration groups = np.array(groups, copy=True) unique_groups = np.unique(groups) + if len(unique_groups) <= 1: + raise ValueError( + "The groups parameter contains fewer than 2 unique groups " + "(%s). LeaveOneGroupOut expects at least 2." % unique_groups) for i in unique_groups: yield groups == i @@ -862,6 +866,12 @@ class LeavePGroupsOut(BaseCrossValidator): raise ValueError("The groups parameter should not be None") groups = np.array(groups, copy=True) unique_groups = np.unique(groups) + if self.n_groups >= len(unique_groups): + raise ValueError( + "The groups parameter contains fewer than (or equal to) " + "n_groups (%d) numbers of unique groups (%s). LeavePGroupsOut " + "expects that at least n_groups + 1 (%d) unique groups be " + "present" % (self.n_groups, unique_groups, self.n_groups + 1)) combi = combinations(range(len(unique_groups)), self.n_groups) for indices in combi: test_index = np.zeros(_num_samples(X), dtype=np.bool) diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py index 4dcd8f5503..b547ac6415 100644 --- a/sklearn/model_selection/tests/test_split.py +++ b/sklearn/model_selection/tests/test_split.py @@ -724,6 +724,31 @@ def test_leave_group_out_changing_groups(): assert_equal(3, LeaveOneGroupOut().get_n_splits(X, y, groups)) +def test_leave_one_p_group_out_error_on_fewer_number_of_groups(): + X = y = groups = np.ones(0) + msg = ("The groups parameter contains fewer than 2 unique groups ([]). " + "LeaveOneGroupOut expects at least 2.") + assert_raise_message(ValueError, msg, next, + LeaveOneGroupOut().split(X, y, groups)) + X = y = groups = np.ones(1) + msg = ("The groups parameter contains fewer than 2 unique groups ([ 1.]). " + "LeaveOneGroupOut expects at least 2.") + assert_raise_message(ValueError, msg, next, + LeaveOneGroupOut().split(X, y, groups)) + X = y = groups = np.ones(1) + msg = ("The groups parameter contains fewer than (or equal to) n_groups " + "(3) numbers of unique groups ([ 1.]). LeavePGroupsOut expects " + "that at least n_groups + 1 (4) unique groups be present") + assert_raise_message(ValueError, msg, next, + LeavePGroupsOut(n_groups=3).split(X, y, groups)) + X = y = groups = np.arange(3) + msg = ("The groups parameter contains fewer than (or equal to) n_groups " + "(3) numbers of unique groups ([0 1 2]). LeavePGroupsOut expects " + "that at least n_groups + 1 (4) unique groups be present") + assert_raise_message(ValueError, msg, next, + LeavePGroupsOut(n_groups=3).split(X, y, groups)) + + def test_train_test_split_errors(): assert_raises(ValueError, train_test_split) assert_raises(ValueError, train_test_split, range(3), train_size=1.1) -- GitLab