From 754109c78d90cc792f0944aaaa11371175b09277 Mon Sep 17 00:00:00 2001 From: Toshihiro Kamishima <mail@kamishima.net> Date: Fri, 12 May 2017 21:56:13 +0900 Subject: [PATCH] [MRG+1] enable to use get_n_splits of LeaveOneGroupOut and LeavePGroupsOut with dummy parameters (#8794) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * remove needless argument checking * add parameter checking as in LeavePGroupsOut * add examples with dummy inputs * add unittest for a get_n_splits method in LeaveOneGroupOut and LeavePGroupsOut classes * X and y can be omitted in a get_n_splits function. * fix error messages * update examples * fix test for an error message * Revert "fix test for an error message" This reverts commit 68b984207c704de1a3411d9ded68a11eba1e56f3. * fix test for an error message * fix error messages * remove trailing white spaces * add periods to messages * test for ValueError’s of get_n_splits methods of LeaveOneOut / LeavePOut classes * fix documents: * parameter name: group -> groups * modify white space --- sklearn/model_selection/_split.py | 44 +++++++++++-------- sklearn/model_selection/tests/test_search.py | 2 +- sklearn/model_selection/tests/test_split.py | 25 +++++++++++ .../model_selection/tests/test_validation.py | 4 +- 4 files changed, 53 insertions(+), 22 deletions(-) diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index 151bbafd62..de889fab0b 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -188,7 +188,7 @@ class LeaveOneOut(BaseCrossValidator): Returns the number of splitting iterations in the cross-validator. """ if X is None: - raise ValueError("The X parameter should not be None") + raise ValueError("The 'X' parameter should not be None.") return _num_samples(X) @@ -259,7 +259,7 @@ class LeavePOut(BaseCrossValidator): Always ignored, exists for compatibility. 
""" if X is None: - raise ValueError("The X parameter should not be None") + raise ValueError("The 'X' parameter should not be None.") return int(comb(_num_samples(X), self.p, exact=True)) @@ -477,7 +477,7 @@ class GroupKFold(_BaseKFold): def _iter_test_indices(self, X, y, groups): if groups is None: - raise ValueError("The groups parameter should not be None") + raise ValueError("The 'groups' parameter should not be None.") groups = check_array(groups, ensure_2d=False, dtype=None) unique_groups, groups = np.unique(groups, return_inverse=True) @@ -765,6 +765,8 @@ class LeaveOneGroupOut(BaseCrossValidator): >>> logo = LeaveOneGroupOut() >>> logo.get_n_splits(X, y, groups) 2 + >>> logo.get_n_splits(groups=groups) # 'groups' is always required + 2 >>> print(logo) LeaveOneGroupOut() >>> for train_index, test_index in logo.split(X, y, groups): @@ -785,7 +787,7 @@ class LeaveOneGroupOut(BaseCrossValidator): def _iter_test_masks(self, X, y, groups): if groups is None: - raise ValueError("The groups parameter should not be None") + raise ValueError("The 'groups' parameter should not be None.") # We make a copy of groups to avoid side-effects during iteration groups = check_array(groups, copy=True, ensure_2d=False, dtype=None) unique_groups = np.unique(groups) @@ -796,20 +798,22 @@ class LeaveOneGroupOut(BaseCrossValidator): for i in unique_groups: yield groups == i - def get_n_splits(self, X, y, groups): + def get_n_splits(self, X=None, y=None, groups=None): """Returns the number of splitting iterations in the cross-validator Parameters ---------- - X : object + X : object, optional Always ignored, exists for compatibility. - y : object + y : object, optional Always ignored, exists for compatibility. groups : array-like, with shape (n_samples,), optional Group labels for the samples used while splitting the dataset into - train/test set. + train/test set. 
This 'groups' parameter must always be specified to + calculate the number of splits, though the other parameters can be + omitted. Returns ------- @@ -817,7 +821,8 @@ class LeaveOneGroupOut(BaseCrossValidator): Returns the number of splitting iterations in the cross-validator. """ if groups is None: - raise ValueError("The groups parameter should not be None") + raise ValueError("The 'groups' parameter should not be None.") + groups = check_array(groups, ensure_2d=False, dtype=None) return len(np.unique(groups)) @@ -852,6 +857,8 @@ class LeavePGroupsOut(BaseCrossValidator): >>> lpgo = LeavePGroupsOut(n_groups=2) >>> lpgo.get_n_splits(X, y, groups) 3 + >>> lpgo.get_n_splits(groups=groups) # 'groups' is always required + 3 >>> print(lpgo) LeavePGroupsOut(n_groups=2) >>> for train_index, test_index in lpgo.split(X, y, groups): @@ -879,7 +886,7 @@ class LeavePGroupsOut(BaseCrossValidator): def _iter_test_masks(self, X, y, groups): if groups is None: - raise ValueError("The groups parameter should not be None") + raise ValueError("The 'groups' parameter should not be None.") groups = check_array(groups, copy=True, ensure_2d=False, dtype=None) unique_groups = np.unique(groups) if self.n_groups >= len(unique_groups): @@ -895,22 +902,22 @@ class LeavePGroupsOut(BaseCrossValidator): test_index[groups == l] = True yield test_index - def get_n_splits(self, X, y, groups): + def get_n_splits(self, X=None, y=None, groups=None): """Returns the number of splitting iterations in the cross-validator Parameters ---------- - X : object + X : object, optional Always ignored, exists for compatibility. - ``np.zeros(n_samples)`` may be used as a placeholder. - y : object + y : object, optional Always ignored, exists for compatibility. - ``np.zeros(n_samples)`` may be used as a placeholder. groups : array-like, with shape (n_samples,), optional Group labels for the samples used while splitting the dataset into - train/test set. + train/test set. 
This 'groups' parameter must always be specified to + calculate the number of splits, though the other parameters can be + omitted. Returns ------- @@ -918,9 +925,8 @@ class LeavePGroupsOut(BaseCrossValidator): Returns the number of splitting iterations in the cross-validator. """ if groups is None: - raise ValueError("The groups parameter should not be None") + raise ValueError("The 'groups' parameter should not be None.") groups = check_array(groups, ensure_2d=False, dtype=None) - X, y, groups = indexable(X, y, groups) return int(comb(len(np.unique(groups)), self.n_groups, exact=True)) @@ -1318,7 +1324,7 @@ class GroupShuffleSplit(ShuffleSplit): def _iter_indices(self, X, y, groups): if groups is None: - raise ValueError("The groups parameter should not be None") + raise ValueError("The 'groups' parameter should not be None.") groups = check_array(groups, ensure_2d=False, dtype=None) classes, group_indices = np.unique(groups, return_inverse=True) for group_train, group_test in super( diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 055a4c061a..3f804d414b 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -317,7 +317,7 @@ def test_grid_search_groups(): for cv in group_cvs: gs = GridSearchCV(clf, grid, cv=cv) assert_raise_message(ValueError, - "The groups parameter should not be None", + "The 'groups' parameter should not be None.", gs.fit, X, y) gs.fit(X, y, groups=groups) diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py index fcd0160ca7..e97fdce5e1 100644 --- a/sklearn/model_selection/tests/test_split.py +++ b/sklearn/model_selection/tests/test_split.py @@ -189,6 +189,13 @@ def test_cross_validator_with_default_params(): # Test if the repr works without any errors assert_equal(cv_repr, repr(cv)) + # ValueError for get_n_splits methods + msg = "The 'X' parameter should not be None." 
+ assert_raise_message(ValueError, msg, + loo.get_n_splits, None, y, groups) + assert_raise_message(ValueError, msg, + lpo.get_n_splits, None, y, groups) + def check_valid_split(train, test, n_samples=None): # Use python sets to get more informative assertion failure messages @@ -757,6 +764,24 @@ def test_leave_one_p_group_out(): # The number of groups in test must be equal to p_groups_out assert_true(np.unique(groups_arr[test]).shape[0], p_groups_out) + # check get_n_splits() with dummy parameters + assert_equal(logo.get_n_splits(None, None, ['a', 'b', 'c', 'b', 'c']), 3) + assert_equal(logo.get_n_splits(groups=[1.0, 1.1, 1.0, 1.2]), 3) + assert_equal(lpgo_2.get_n_splits(None, None, np.arange(4)), 6) + assert_equal(lpgo_1.get_n_splits(groups=np.arange(4)), 4) + + # raise ValueError if a `groups` parameter is illegal + with assert_raises(ValueError): + logo.get_n_splits(None, None, [0.0, np.nan, 0.0]) + with assert_raises(ValueError): + lpgo_2.get_n_splits(None, None, [0.0, np.inf, 0.0]) + + msg = "The 'groups' parameter should not be None." 
+ assert_raise_message(ValueError, msg, + logo.get_n_splits, None, None, None) + assert_raise_message(ValueError, msg, + lpgo_1.get_n_splits, None, None, None) + def test_leave_group_out_changing_groups(): # Check that LeaveOneGroupOut and LeavePGroupsOut work normally if diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index cc6f5973a0..9228837e1b 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -259,10 +259,10 @@ def test_cross_val_score_predict_groups(): GroupShuffleSplit()] for cv in group_cvs: assert_raise_message(ValueError, - "The groups parameter should not be None", + "The 'groups' parameter should not be None.", cross_val_score, estimator=clf, X=X, y=y, cv=cv) assert_raise_message(ValueError, - "The groups parameter should not be None", + "The 'groups' parameter should not be None.", cross_val_predict, estimator=clf, X=X, y=y, cv=cv) -- GitLab