From 754109c78d90cc792f0944aaaa11371175b09277 Mon Sep 17 00:00:00 2001
From: Toshihiro Kamishima <mail@kamishima.net>
Date: Fri, 12 May 2017 21:56:13 +0900
Subject: [PATCH] [MRG+1] enable to use get_n_splits of LeaveOneGroupOut and
 LeavePGroupsOut with dummy parameters (#8794)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* remove needless argument checking

* add parameter checking as in LeavePGroupsOut

* add examples with dummy inputs

* add unittest for a get_n_splits method in LeaveOneGroupOut and LeavePGroupsOut classes

* X and y can be ommited in a get_n_splits function.

* fix error messages

* update examples

* fix test for an error message

* Revert "fix test for an error message"

This reverts commit 68b984207c704de1a3411d9ded68a11eba1e56f3.

* fix test for an error message

* fix error messages

* remove tailing white spaces

* add periods to messages

* test for ValueError’s of get_n_splits methods of LeaveOneOut / LeavePOut classes

* fix documents:
* parameter name: group -> groups
* modfy white space
---
 sklearn/model_selection/_split.py             | 44 +++++++++++--------
 sklearn/model_selection/tests/test_search.py  |  2 +-
 sklearn/model_selection/tests/test_split.py   | 25 +++++++++++
 .../model_selection/tests/test_validation.py  |  4 +-
 4 files changed, 53 insertions(+), 22 deletions(-)

diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py
index 151bbafd62..de889fab0b 100644
--- a/sklearn/model_selection/_split.py
+++ b/sklearn/model_selection/_split.py
@@ -188,7 +188,7 @@ class LeaveOneOut(BaseCrossValidator):
             Returns the number of splitting iterations in the cross-validator.
         """
         if X is None:
-            raise ValueError("The X parameter should not be None")
+            raise ValueError("The 'X' parameter should not be None.")
         return _num_samples(X)
 
 
@@ -259,7 +259,7 @@ class LeavePOut(BaseCrossValidator):
             Always ignored, exists for compatibility.
         """
         if X is None:
-            raise ValueError("The X parameter should not be None")
+            raise ValueError("The 'X' parameter should not be None.")
         return int(comb(_num_samples(X), self.p, exact=True))
 
 
@@ -477,7 +477,7 @@ class GroupKFold(_BaseKFold):
 
     def _iter_test_indices(self, X, y, groups):
         if groups is None:
-            raise ValueError("The groups parameter should not be None")
+            raise ValueError("The 'groups' parameter should not be None.")
         groups = check_array(groups, ensure_2d=False, dtype=None)
 
         unique_groups, groups = np.unique(groups, return_inverse=True)
@@ -765,6 +765,8 @@ class LeaveOneGroupOut(BaseCrossValidator):
     >>> logo = LeaveOneGroupOut()
     >>> logo.get_n_splits(X, y, groups)
     2
+    >>> logo.get_n_splits(groups=groups) # 'groups' is always required
+    2
     >>> print(logo)
     LeaveOneGroupOut()
     >>> for train_index, test_index in logo.split(X, y, groups):
@@ -785,7 +787,7 @@ class LeaveOneGroupOut(BaseCrossValidator):
 
     def _iter_test_masks(self, X, y, groups):
         if groups is None:
-            raise ValueError("The groups parameter should not be None")
+            raise ValueError("The 'groups' parameter should not be None.")
         # We make a copy of groups to avoid side-effects during iteration
         groups = check_array(groups, copy=True, ensure_2d=False, dtype=None)
         unique_groups = np.unique(groups)
@@ -796,20 +798,22 @@ class LeaveOneGroupOut(BaseCrossValidator):
         for i in unique_groups:
             yield groups == i
 
-    def get_n_splits(self, X, y, groups):
+    def get_n_splits(self, X=None, y=None, groups=None):
         """Returns the number of splitting iterations in the cross-validator
 
         Parameters
         ----------
-        X : object
+        X : object, optional
             Always ignored, exists for compatibility.
 
-        y : object
+        y : object, optional
             Always ignored, exists for compatibility.
 
         groups : array-like, with shape (n_samples,), optional
             Group labels for the samples used while splitting the dataset into
-            train/test set.
+            train/test set. This 'groups' parameter must always be specified to
+            calculate the number of splits, though the other parameters can be
+            omitted.
 
         Returns
         -------
@@ -817,7 +821,8 @@ class LeaveOneGroupOut(BaseCrossValidator):
             Returns the number of splitting iterations in the cross-validator.
         """
         if groups is None:
-            raise ValueError("The groups parameter should not be None")
+            raise ValueError("The 'groups' parameter should not be None.")
+        groups = check_array(groups, ensure_2d=False, dtype=None)
         return len(np.unique(groups))
 
 
@@ -852,6 +857,8 @@ class LeavePGroupsOut(BaseCrossValidator):
     >>> lpgo = LeavePGroupsOut(n_groups=2)
     >>> lpgo.get_n_splits(X, y, groups)
     3
+    >>> lpgo.get_n_splits(groups=groups)  # 'groups' is always required
+    3
     >>> print(lpgo)
     LeavePGroupsOut(n_groups=2)
     >>> for train_index, test_index in lpgo.split(X, y, groups):
@@ -879,7 +886,7 @@ class LeavePGroupsOut(BaseCrossValidator):
 
     def _iter_test_masks(self, X, y, groups):
         if groups is None:
-            raise ValueError("The groups parameter should not be None")
+            raise ValueError("The 'groups' parameter should not be None.")
         groups = check_array(groups, copy=True, ensure_2d=False, dtype=None)
         unique_groups = np.unique(groups)
         if self.n_groups >= len(unique_groups):
@@ -895,22 +902,22 @@ class LeavePGroupsOut(BaseCrossValidator):
                 test_index[groups == l] = True
             yield test_index
 
-    def get_n_splits(self, X, y, groups):
+    def get_n_splits(self, X=None, y=None, groups=None):
         """Returns the number of splitting iterations in the cross-validator
 
         Parameters
         ----------
-        X : object
+        X : object, optional
             Always ignored, exists for compatibility.
-            ``np.zeros(n_samples)`` may be used as a placeholder.
 
-        y : object
+        y : object, optional
             Always ignored, exists for compatibility.
-            ``np.zeros(n_samples)`` may be used as a placeholder.
 
         groups : array-like, with shape (n_samples,), optional
             Group labels for the samples used while splitting the dataset into
-            train/test set.
+            train/test set. This 'groups' parameter must always be specified to
+            calculate the number of splits, though the other parameters can be
+            omitted.
 
         Returns
         -------
@@ -918,9 +925,8 @@ class LeavePGroupsOut(BaseCrossValidator):
             Returns the number of splitting iterations in the cross-validator.
         """
         if groups is None:
-            raise ValueError("The groups parameter should not be None")
+            raise ValueError("The 'groups' parameter should not be None.")
         groups = check_array(groups, ensure_2d=False, dtype=None)
-        X, y, groups = indexable(X, y, groups)
         return int(comb(len(np.unique(groups)), self.n_groups, exact=True))
 
 
@@ -1318,7 +1324,7 @@ class GroupShuffleSplit(ShuffleSplit):
 
     def _iter_indices(self, X, y, groups):
         if groups is None:
-            raise ValueError("The groups parameter should not be None")
+            raise ValueError("The 'groups' parameter should not be None.")
         groups = check_array(groups, ensure_2d=False, dtype=None)
         classes, group_indices = np.unique(groups, return_inverse=True)
         for group_train, group_test in super(
diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py
index 055a4c061a..3f804d414b 100644
--- a/sklearn/model_selection/tests/test_search.py
+++ b/sklearn/model_selection/tests/test_search.py
@@ -317,7 +317,7 @@ def test_grid_search_groups():
     for cv in group_cvs:
         gs = GridSearchCV(clf, grid, cv=cv)
         assert_raise_message(ValueError,
-                             "The groups parameter should not be None",
+                             "The 'groups' parameter should not be None.",
                              gs.fit, X, y)
         gs.fit(X, y, groups=groups)
 
diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py
index fcd0160ca7..e97fdce5e1 100644
--- a/sklearn/model_selection/tests/test_split.py
+++ b/sklearn/model_selection/tests/test_split.py
@@ -189,6 +189,13 @@ def test_cross_validator_with_default_params():
         # Test if the repr works without any errors
         assert_equal(cv_repr, repr(cv))
 
+    # ValueError for get_n_splits methods
+    msg = "The 'X' parameter should not be None."
+    assert_raise_message(ValueError, msg,
+                         loo.get_n_splits, None, y, groups)
+    assert_raise_message(ValueError, msg,
+                         lpo.get_n_splits, None, y, groups)
+
 
 def check_valid_split(train, test, n_samples=None):
     # Use python sets to get more informative assertion failure messages
@@ -757,6 +764,24 @@ def test_leave_one_p_group_out():
                 # The number of groups in test must be equal to p_groups_out
                 assert_true(np.unique(groups_arr[test]).shape[0], p_groups_out)
 
+    # check get_n_splits() with dummy parameters
+    assert_equal(logo.get_n_splits(None, None, ['a', 'b', 'c', 'b', 'c']), 3)
+    assert_equal(logo.get_n_splits(groups=[1.0, 1.1, 1.0, 1.2]), 3)
+    assert_equal(lpgo_2.get_n_splits(None, None, np.arange(4)), 6)
+    assert_equal(lpgo_1.get_n_splits(groups=np.arange(4)), 4)
+
+    # raise ValueError if a `groups` parameter is illegal
+    with assert_raises(ValueError):
+        logo.get_n_splits(None, None, [0.0, np.nan, 0.0])
+    with assert_raises(ValueError):
+        lpgo_2.get_n_splits(None, None, [0.0, np.inf, 0.0])
+
+    msg = "The 'groups' parameter should not be None."
+    assert_raise_message(ValueError, msg,
+                         logo.get_n_splits, None, None, None)
+    assert_raise_message(ValueError, msg,
+                         lpgo_1.get_n_splits, None, None, None)
+
 
 def test_leave_group_out_changing_groups():
     # Check that LeaveOneGroupOut and LeavePGroupsOut work normally if
diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py
index cc6f5973a0..9228837e1b 100644
--- a/sklearn/model_selection/tests/test_validation.py
+++ b/sklearn/model_selection/tests/test_validation.py
@@ -259,10 +259,10 @@ def test_cross_val_score_predict_groups():
                  GroupShuffleSplit()]
     for cv in group_cvs:
         assert_raise_message(ValueError,
-                             "The groups parameter should not be None",
+                             "The 'groups' parameter should not be None.",
                              cross_val_score, estimator=clf, X=X, y=y, cv=cv)
         assert_raise_message(ValueError,
-                             "The groups parameter should not be None",
+                             "The 'groups' parameter should not be None.",
                              cross_val_predict, estimator=clf, X=X, y=y, cv=cv)
 
 
-- 
GitLab