diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index 151bbafd622206288a9be69846ffdb860483c7de..de889fab0bda5cfffb645145edfefc9d622f7ddc 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -188,7 +188,7 @@ class LeaveOneOut(BaseCrossValidator): Returns the number of splitting iterations in the cross-validator. """ if X is None: - raise ValueError("The X parameter should not be None") + raise ValueError("The 'X' parameter should not be None.") return _num_samples(X) @@ -259,7 +259,7 @@ class LeavePOut(BaseCrossValidator): Always ignored, exists for compatibility. """ if X is None: - raise ValueError("The X parameter should not be None") + raise ValueError("The 'X' parameter should not be None.") return int(comb(_num_samples(X), self.p, exact=True)) @@ -477,7 +477,7 @@ class GroupKFold(_BaseKFold): def _iter_test_indices(self, X, y, groups): if groups is None: - raise ValueError("The groups parameter should not be None") + raise ValueError("The 'groups' parameter should not be None.") groups = check_array(groups, ensure_2d=False, dtype=None) unique_groups, groups = np.unique(groups, return_inverse=True) @@ -765,6 +765,8 @@ class LeaveOneGroupOut(BaseCrossValidator): >>> logo = LeaveOneGroupOut() >>> logo.get_n_splits(X, y, groups) 2 + >>> logo.get_n_splits(groups=groups) # 'groups' is always required + 2 >>> print(logo) LeaveOneGroupOut() >>> for train_index, test_index in logo.split(X, y, groups): @@ -785,7 +787,7 @@ class LeaveOneGroupOut(BaseCrossValidator): def _iter_test_masks(self, X, y, groups): if groups is None: - raise ValueError("The groups parameter should not be None") + raise ValueError("The 'groups' parameter should not be None.") # We make a copy of groups to avoid side-effects during iteration groups = check_array(groups, copy=True, ensure_2d=False, dtype=None) unique_groups = np.unique(groups) @@ -796,20 +798,22 @@ class LeaveOneGroupOut(BaseCrossValidator): for i in unique_groups: yield groups == i - def get_n_splits(self, X, y, groups): + def get_n_splits(self, X=None, y=None, groups=None): """Returns the number of splitting iterations in the cross-validator Parameters ---------- - X : object + X : object, optional Always ignored, exists for compatibility. - y : object + y : object, optional Always ignored, exists for compatibility. groups : array-like, with shape (n_samples,), optional Group labels for the samples used while splitting the dataset into - train/test set. + train/test set. This 'groups' parameter must always be specified to + calculate the number of splits, though the other parameters can be + omitted. Returns ------- @@ -817,7 +821,8 @@ class LeaveOneGroupOut(BaseCrossValidator): Returns the number of splitting iterations in the cross-validator. """ if groups is None: - raise ValueError("The groups parameter should not be None") + raise ValueError("The 'groups' parameter should not be None.") + groups = check_array(groups, ensure_2d=False, dtype=None) return len(np.unique(groups)) @@ -852,6 +857,8 @@ class LeavePGroupsOut(BaseCrossValidator): >>> lpgo = LeavePGroupsOut(n_groups=2) >>> lpgo.get_n_splits(X, y, groups) 3 + >>> lpgo.get_n_splits(groups=groups) # 'groups' is always required + 3 >>> print(lpgo) LeavePGroupsOut(n_groups=2) >>> for train_index, test_index in lpgo.split(X, y, groups): @@ -879,7 +886,7 @@ class LeavePGroupsOut(BaseCrossValidator): def _iter_test_masks(self, X, y, groups): if groups is None: - raise ValueError("The groups parameter should not be None") + raise ValueError("The 'groups' parameter should not be None.") groups = check_array(groups, copy=True, ensure_2d=False, dtype=None) unique_groups = np.unique(groups) if self.n_groups >= len(unique_groups): @@ -895,22 +902,22 @@ class LeavePGroupsOut(BaseCrossValidator): test_index[groups == l] = True yield test_index - def get_n_splits(self, X, y, groups): + def get_n_splits(self, X=None, y=None, groups=None): """Returns the number of splitting iterations in the cross-validator Parameters ---------- - X : object + X : object, optional Always ignored, exists for compatibility. - ``np.zeros(n_samples)`` may be used as a placeholder. - y : object + y : object, optional Always ignored, exists for compatibility. - ``np.zeros(n_samples)`` may be used as a placeholder. groups : array-like, with shape (n_samples,), optional Group labels for the samples used while splitting the dataset into - train/test set. + train/test set. This 'groups' parameter must always be specified to + calculate the number of splits, though the other parameters can be + omitted. Returns ------- @@ -918,9 +925,8 @@ class LeavePGroupsOut(BaseCrossValidator): Returns the number of splitting iterations in the cross-validator. """ if groups is None: - raise ValueError("The groups parameter should not be None") + raise ValueError("The 'groups' parameter should not be None.") groups = check_array(groups, ensure_2d=False, dtype=None) - X, y, groups = indexable(X, y, groups) return int(comb(len(np.unique(groups)), self.n_groups, exact=True)) @@ -1318,7 +1324,7 @@ class GroupShuffleSplit(ShuffleSplit): def _iter_indices(self, X, y, groups): if groups is None: - raise ValueError("The groups parameter should not be None") + raise ValueError("The 'groups' parameter should not be None.") groups = check_array(groups, ensure_2d=False, dtype=None) classes, group_indices = np.unique(groups, return_inverse=True) for group_train, group_test in super( diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 055a4c061a7c031412d2f82db8bb06ea9b918b48..3f804d414b75010d1bb7013c07d55e70bc1490f2 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -317,7 +317,7 @@ def test_grid_search_groups(): for cv in group_cvs: gs = GridSearchCV(clf, grid, cv=cv) assert_raise_message(ValueError, - "The groups parameter should not be None", + "The 'groups' parameter should not be None.", gs.fit, X, y) gs.fit(X, y, groups=groups) diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py index fcd0160ca74eeb3aeca0419721105bcab72b2665..e97fdce5e1e5aaec7d72540998a819c3fd3f11c8 100644 --- a/sklearn/model_selection/tests/test_split.py +++ b/sklearn/model_selection/tests/test_split.py @@ -189,6 +189,13 @@ def test_cross_validator_with_default_params(): # Test if the repr works without any errors assert_equal(cv_repr, repr(cv)) + # ValueError for get_n_splits methods + msg = "The 'X' parameter should not be None." + assert_raise_message(ValueError, msg, + loo.get_n_splits, None, y, groups) + assert_raise_message(ValueError, msg, + lpo.get_n_splits, None, y, groups) + def check_valid_split(train, test, n_samples=None): # Use python sets to get more informative assertion failure messages @@ -757,6 +764,24 @@ def test_leave_one_p_group_out(): # The number of groups in test must be equal to p_groups_out assert_true(np.unique(groups_arr[test]).shape[0], p_groups_out) + # check get_n_splits() with dummy parameters + assert_equal(logo.get_n_splits(None, None, ['a', 'b', 'c', 'b', 'c']), 3) + assert_equal(logo.get_n_splits(groups=[1.0, 1.1, 1.0, 1.2]), 3) + assert_equal(lpgo_2.get_n_splits(None, None, np.arange(4)), 6) + assert_equal(lpgo_1.get_n_splits(groups=np.arange(4)), 4) + + # raise ValueError if a `groups` parameter is illegal + with assert_raises(ValueError): + logo.get_n_splits(None, None, [0.0, np.nan, 0.0]) + with assert_raises(ValueError): + lpgo_2.get_n_splits(None, None, [0.0, np.inf, 0.0]) + + msg = "The 'groups' parameter should not be None." + assert_raise_message(ValueError, msg, + logo.get_n_splits, None, None, None) + assert_raise_message(ValueError, msg, + lpgo_1.get_n_splits, None, None, None) + def test_leave_group_out_changing_groups(): # Check that LeaveOneGroupOut and LeavePGroupsOut work normally if diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index cc6f5973a0b096c81fe1c9487242636ef4e8fa75..9228837e1be1a900bac71d81500be215b6556872 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -259,10 +259,10 @@ def test_cross_val_score_predict_groups(): GroupShuffleSplit()] for cv in group_cvs: assert_raise_message(ValueError, - "The groups parameter should not be None", + "The 'groups' parameter should not be None.", cross_val_score, estimator=clf, X=X, y=y, cv=cv) assert_raise_message(ValueError, - "The groups parameter should not be None", + "The 'groups' parameter should not be None.", cross_val_predict, estimator=clf, X=X, y=y, cv=cv)