From 38f6a91566bc643e2a8f76beb16f3e673faab848 Mon Sep 17 00:00:00 2001
From: Raghav RV <rvraghav93@gmail.com>
Date: Sun, 30 Oct 2016 22:49:17 +0100
Subject: [PATCH] [MRG + 2] FIX Be robust to non-re-entrant / non-deterministic
 cv.split calls (#7660)

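The cross-validation splits are now computed exactly once and reused for
every parameter setting and every fold evaluation.  Previously,
GridSearchCV, cross_val_score, cross_val_predict, learning_curve and
validation_curve re-invoked ``cv.split(X, y, groups)`` inside their
parallel loops (in GridSearchCV, once per parameter setting), so a
non-re-entrant splitter (e.g. a plain generator) produced no folds after
the first pass, and a non-deterministic splitter (e.g. shuffling without a
fixed random_state) produced different folds for different parameter
settings.  The fix materializes the folds with
``cv_iter = list(cv.split(X, y, groups))`` and iterates over that list;
``_CVIterableWrapper`` likewise stores ``list(cv)``.

A minimal sketch of the failure mode (illustrative only, not part of the
patch; variable names are arbitrary):

    import numpy as np
    from sklearn.model_selection import KFold

    X = np.ones((10, 2))

    # A generator-based cv is exhausted after one pass: the second
    # iteration yields no train/test indices at all.
    splits = KFold(n_splits=5).split(X)
    assert len(list(splits)) == 5
    assert len(list(splits)) == 0   # non-re-entrant

    # With shuffle=True and no fixed random_state, successive calls to
    # split may yield different folds, so two parameter settings would be
    # scored on different train/test indices.
    cv = KFold(n_splits=5, shuffle=True)
    folds_a = [test.tolist() for _, test in cv.split(X)]
    folds_b = [test.tolist() for _, test in cv.split(X)]
    # folds_a and folds_b need not be equal; materializing the splits
    # once, as this patch does, guarantees consistency:
    cv_iter = list(cv.split(X))
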
---
 sklearn/model_selection/_search.py            |   3 +-
 sklearn/model_selection/_split.py             |   2 +-
 sklearn/model_selection/_validation.py        |  18 +--
 sklearn/model_selection/tests/common.py       |  23 ++++
 sklearn/model_selection/tests/test_search.py  |  57 ++++++++
 sklearn/model_selection/tests/test_split.py   |  80 +++---------
 .../model_selection/tests/test_validation.py  | 123 +++++++++++++++++-
 7 files changed, 226 insertions(+), 80 deletions(-)
 create mode 100644 sklearn/model_selection/tests/common.py

diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index 82516f1e6b..d2f5542ebd 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -550,6 +550,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         base_estimator = clone(self.estimator)
         pre_dispatch = self.pre_dispatch
 
+        cv_iter = list(cv.split(X, y, groups))
         out = Parallel(
             n_jobs=self.n_jobs, verbose=self.verbose,
             pre_dispatch=pre_dispatch
@@ -561,7 +562,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
                                   return_times=True, return_parameters=True,
                                   error_score=self.error_score)
           for parameters in parameter_iterable
-          for train, test in cv.split(X, y, groups))
+          for train, test in cv_iter)
 
         # if one choose to see train score, "out" will contain train score info
         if self.return_train_score:
diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py
index 0064830c9a..aecff7be39 100644
--- a/sklearn/model_selection/_split.py
+++ b/sklearn/model_selection/_split.py
@@ -1477,7 +1477,7 @@ class PredefinedSplit(BaseCrossValidator):
 class _CVIterableWrapper(BaseCrossValidator):
     """Wrapper class for old style cv objects and iterables."""
     def __init__(self, cv):
-        self.cv = cv
+        self.cv = list(cv)
 
     def get_n_splits(self, X=None, y=None, groups=None):
         """Returns the number of splitting iterations in the cross-validator
diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py
index b8546d804e..23db2a9ceb 100644
--- a/sklearn/model_selection/_validation.py
+++ b/sklearn/model_selection/_validation.py
@@ -1,4 +1,3 @@
-
 """
 The :mod:`sklearn.model_selection._validation` module includes classes and
 functions to validate the model.
@@ -129,6 +128,7 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
     X, y, groups = indexable(X, y, groups)
 
     cv = check_cv(cv, y, classifier=is_classifier(estimator))
+    cv_iter = list(cv.split(X, y, groups))
     scorer = check_scoring(estimator, scoring=scoring)
     # We clone the estimator to make sure that all the folds are
     # independent, and that it is pickle-able.
@@ -137,7 +137,7 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
     scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y, scorer,
                                               train, test, verbose, None,
                                               fit_params)
-                      for train, test in cv.split(X, y, groups))
+                      for train, test in cv_iter)
     return np.array(scores)[:, 0]
 
 
@@ -385,6 +385,7 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
     X, y, groups = indexable(X, y, groups)
 
     cv = check_cv(cv, y, classifier=is_classifier(estimator))
+    cv_iter = list(cv.split(X, y, groups))
 
     # Ensure the estimator has implemented the passed decision function
     if not callable(getattr(estimator, method)):
@@ -397,7 +398,7 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
                         pre_dispatch=pre_dispatch)
     prediction_blocks = parallel(delayed(_fit_and_predict)(
         clone(estimator), X, y, train, test, verbose, fit_params, method)
-        for train, test in cv.split(X, y, groups))
+        for train, test in cv_iter)
 
     # Concatenate the predictions
     predictions = [pred_block_i for pred_block_i, _ in prediction_blocks]
@@ -751,9 +752,8 @@ def learning_curve(estimator, X, y, groups=None,
     X, y, groups = indexable(X, y, groups)
 
     cv = check_cv(cv, y, classifier=is_classifier(estimator))
-    cv_iter = cv.split(X, y, groups)
     # Make a list since we will be iterating multiple times over the folds
-    cv_iter = list(cv_iter)
+    cv_iter = list(cv.split(X, y, groups))
     scorer = check_scoring(estimator, scoring=scoring)
 
     n_max_training_samples = len(cv_iter[0][0])
@@ -776,9 +776,8 @@ def learning_curve(estimator, X, y, groups=None,
     if exploit_incremental_learning:
         classes = np.unique(y) if is_classifier(estimator) else None
         out = parallel(delayed(_incremental_fit_estimator)(
-            clone(estimator), X, y, classes, train,
-            test, train_sizes_abs, scorer, verbose)
-            for train, test in cv_iter)
+            clone(estimator), X, y, classes, train, test, train_sizes_abs,
+            scorer, verbose) for train, test in cv_iter)
     else:
         train_test_proportions = []
         for train, test in cv_iter:
@@ -962,6 +961,7 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
     X, y, groups = indexable(X, y, groups)
 
     cv = check_cv(cv, y, classifier=is_classifier(estimator))
+    cv_iter = list(cv.split(X, y, groups))
 
     scorer = check_scoring(estimator, scoring=scoring)
 
@@ -970,7 +970,7 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
     out = parallel(delayed(_fit_and_score)(
         estimator, X, y, scorer, train, test, verbose,
         parameters={param_name: v}, fit_params=None, return_train_score=True)
-        for train, test in cv.split(X, y, groups) for v in param_range)
+        for train, test in cv_iter for v in param_range)
 
     out = np.asarray(out)
     n_params = len(param_range)
diff --git a/sklearn/model_selection/tests/common.py b/sklearn/model_selection/tests/common.py
new file mode 100644
index 0000000000..13549eef37
--- /dev/null
+++ b/sklearn/model_selection/tests/common.py
@@ -0,0 +1,23 @@
+"""
+Common utilities for testing model selection.
+"""
+
+import numpy as np
+
+from sklearn.model_selection import KFold
+
+
+class OneTimeSplitter:
+    """A wrapper to make KFold single entry cv iterator"""
+    def __init__(self, n_splits=4, n_samples=99):
+        self.n_splits = n_splits
+        self.n_samples = n_samples
+        self.indices = iter(KFold(n_splits=n_splits).split(np.ones(n_samples)))
+
+    def split(self, X=None, y=None, groups=None):
+        """Split can be called only once"""
+        for index in self.indices:
+            yield index
+
+    def get_n_splits(self, X=None, y=None, groups=None):
+        return self.n_splits
diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py
index 36e6965a11..1ce2875507 100644
--- a/sklearn/model_selection/tests/test_search.py
+++ b/sklearn/model_selection/tests/test_search.py
@@ -60,6 +60,8 @@ from sklearn.preprocessing import Imputer
 from sklearn.pipeline import Pipeline
 from sklearn.linear_model import SGDClassifier
 
+from sklearn.model_selection.tests.common import OneTimeSplitter
+
 
 # Neither of the following two estimators inherit from BaseEstimator,
 # to test hyperparameter search on user-defined classifiers.
@@ -1154,3 +1156,58 @@ def test_search_train_scores_set_to_false():
     gs = GridSearchCV(clf, param_grid={'C': [0.1, 0.2]},
                       return_train_score=False)
     gs.fit(X, y)
+
+
+def test_grid_search_cv_splits_consistency():
+    # Check that a one-time iterable is accepted as a cv parameter.
+    n_samples = 100
+    n_splits = 5
+    X, y = make_classification(n_samples=n_samples, random_state=0)
+
+    gs = GridSearchCV(LinearSVC(random_state=0),
+                      param_grid={'C': [0.1, 0.2, 0.3]},
+                      cv=OneTimeSplitter(n_splits=n_splits,
+                                         n_samples=n_samples))
+    gs.fit(X, y)
+
+    gs2 = GridSearchCV(LinearSVC(random_state=0),
+                       param_grid={'C': [0.1, 0.2, 0.3]},
+                       cv=KFold(n_splits=n_splits))
+    gs2.fit(X, y)
+
+    def _pop_time_keys(cv_results):
+        for key in ('mean_fit_time', 'std_fit_time',
+                    'mean_score_time', 'std_score_time'):
+            cv_results.pop(key)
+        return cv_results
+
+    # OneTimeSplitter is a non-re-entrant cv splitter whose split can be
+    # called only once. If ``cv.split`` were called once per param setting
+    # in GridSearchCV.fit, the 2nd and 3rd param settings would not be
+    # evaluated, as no train/test indices would be generated for the 2nd
+    # and subsequent cv.split calls.
+    # This checks that cv.split is not called once per param setting.
+    np.testing.assert_equal(_pop_time_keys(gs.cv_results_),
+                            _pop_time_keys(gs2.cv_results_))
+
+    # Check consistency of folds across the parameters
+    gs = GridSearchCV(LinearSVC(random_state=0),
+                      param_grid={'C': [0.1, 0.1, 0.2, 0.2]},
+                      cv=KFold(n_splits=n_splits, shuffle=True))
+    gs.fit(X, y)
+
+    # As the first two param settings (C=0.1) and the next two param
+    # settings (C=0.2) are the same, the test and train scores must also
+    # be the same as long as the same train/test indices are generated
+    # for all the cv splits, for both param settings.
+    for score_type in ('train', 'test'):
+        per_param_scores = {}
+        for param_i in range(4):
+            per_param_scores[param_i] = list(
+                gs.cv_results_['split%d_%s_score' % (s, score_type)][param_i]
+                for s in range(5))
+
+        assert_array_almost_equal(per_param_scores[0],
+                                  per_param_scores[1])
+        assert_array_almost_equal(per_param_scores[2],
+                                  per_param_scores[3])
diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py
index b547ac6415..936abf03ac 100644
--- a/sklearn/model_selection/tests/test_split.py
+++ b/sklearn/model_selection/tests/test_split.py
@@ -59,73 +59,9 @@ from sklearn.svm import SVC
 
 X = np.ones(10)
 y = np.arange(10) // 2
-P_sparse = coo_matrix(np.eye(5))
 digits = load_digits()
 
 
-class MockClassifier(object):
-    """Dummy classifier to test the cross-validation"""
-
-    def __init__(self, a=0, allow_nd=False):
-        self.a = a
-        self.allow_nd = allow_nd
-
-    def fit(self, X, Y=None, sample_weight=None, class_prior=None,
-            sparse_sample_weight=None, sparse_param=None, dummy_int=None,
-            dummy_str=None, dummy_obj=None, callback=None):
-        """The dummy arguments are to test that this fit function can
-        accept non-array arguments through cross-validation, such as:
-            - int
-            - str (this is actually array-like)
-            - object
-            - function
-        """
-        self.dummy_int = dummy_int
-        self.dummy_str = dummy_str
-        self.dummy_obj = dummy_obj
-        if callback is not None:
-            callback(self)
-
-        if self.allow_nd:
-            X = X.reshape(len(X), -1)
-        if X.ndim >= 3 and not self.allow_nd:
-            raise ValueError('X cannot be d')
-        if sample_weight is not None:
-            assert_true(sample_weight.shape[0] == X.shape[0],
-                        'MockClassifier extra fit_param sample_weight.shape[0]'
-                        ' is {0}, should be {1}'.format(sample_weight.shape[0],
-                                                        X.shape[0]))
-        if class_prior is not None:
-            assert_true(class_prior.shape[0] == len(np.unique(y)),
-                        'MockClassifier extra fit_param class_prior.shape[0]'
-                        ' is {0}, should be {1}'.format(class_prior.shape[0],
-                                                        len(np.unique(y))))
-        if sparse_sample_weight is not None:
-            fmt = ('MockClassifier extra fit_param sparse_sample_weight'
-                   '.shape[0] is {0}, should be {1}')
-            assert_true(sparse_sample_weight.shape[0] == X.shape[0],
-                        fmt.format(sparse_sample_weight.shape[0], X.shape[0]))
-        if sparse_param is not None:
-            fmt = ('MockClassifier extra fit_param sparse_param.shape '
-                   'is ({0}, {1}), should be ({2}, {3})')
-            assert_true(sparse_param.shape == P_sparse.shape,
-                        fmt.format(sparse_param.shape[0],
-                                   sparse_param.shape[1],
-                                   P_sparse.shape[0], P_sparse.shape[1]))
-        return self
-
-    def predict(self, T):
-        if self.allow_nd:
-            T = T.reshape(len(T), -1)
-        return T[:, 0]
-
-    def score(self, X=None, Y=None):
-        return 1. / (1 + np.abs(self.a))
-
-    def get_params(self, deep=False):
-        return {'a': self.a, 'allow_nd': self.allow_nd}
-
-
 @ignore_warnings
 def test_cross_validator_with_default_params():
     n_samples = 4
@@ -933,6 +869,22 @@ def test_cv_iterable_wrapper():
     # Check if get_n_splits works correctly
     assert_equal(len(cv), wrapped_old_skf.get_n_splits())
 
+    kf_iter = KFold(n_splits=5).split(X, y)
+    kf_iter_wrapped = check_cv(kf_iter)
+    # Since the wrapped iterable is converted to a list and stored,
+    # split can be called any number of times to produce
+    # consistent results.
+    assert_array_equal(list(kf_iter_wrapped.split(X, y)),
+                       list(kf_iter_wrapped.split(X, y)))
+    # Even when the underlying splits are randomized, the wrapped splitter
+    # stores them as a list, so repeated split calls yield the same result.
+    kf_randomized_iter = KFold(n_splits=5, shuffle=True).split(X, y)
+    kf_randomized_iter_wrapped = check_cv(kf_randomized_iter)
+    assert_array_equal(list(kf_randomized_iter_wrapped.split(X, y)),
+                       list(kf_randomized_iter_wrapped.split(X, y)))
+    assert_true(np.any(np.array(list(kf_iter_wrapped.split(X, y))) !=
+                       np.array(list(kf_randomized_iter_wrapped.split(X, y)))))
+
 
 def test_group_kfold():
     rng = np.random.RandomState(0)
diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py
index 26af0f76e6..31c5fc8257 100644
--- a/sklearn/model_selection/tests/test_validation.py
+++ b/sklearn/model_selection/tests/test_validation.py
@@ -60,7 +60,7 @@ from sklearn.utils import shuffle
 from sklearn.datasets import make_classification
 from sklearn.datasets import make_multilabel_classification
 
-from sklearn.model_selection.tests.test_split import MockClassifier
+from sklearn.model_selection.tests.common import OneTimeSplitter
 
 
 try:
@@ -131,6 +131,69 @@ class MockEstimatorWithParameter(BaseEstimator):
         return X is self.X_subset
 
 
+class MockClassifier(object):
+    """Dummy classifier to test the cross-validation"""
+
+    def __init__(self, a=0, allow_nd=False):
+        self.a = a
+        self.allow_nd = allow_nd
+
+    def fit(self, X, Y=None, sample_weight=None, class_prior=None,
+            sparse_sample_weight=None, sparse_param=None, dummy_int=None,
+            dummy_str=None, dummy_obj=None, callback=None):
+        """The dummy arguments are to test that this fit function can
+        accept non-array arguments through cross-validation, such as:
+            - int
+            - str (this is actually array-like)
+            - object
+            - function
+        """
+        self.dummy_int = dummy_int
+        self.dummy_str = dummy_str
+        self.dummy_obj = dummy_obj
+        if callback is not None:
+            callback(self)
+
+        if self.allow_nd:
+            X = X.reshape(len(X), -1)
+        if X.ndim >= 3 and not self.allow_nd:
+            raise ValueError('X cannot be d')
+        if sample_weight is not None:
+            assert_true(sample_weight.shape[0] == X.shape[0],
+                        'MockClassifier extra fit_param sample_weight.shape[0]'
+                        ' is {0}, should be {1}'.format(sample_weight.shape[0],
+                                                        X.shape[0]))
+        if class_prior is not None:
+            assert_true(class_prior.shape[0] == len(np.unique(y)),
+                        'MockClassifier extra fit_param class_prior.shape[0]'
+                        ' is {0}, should be {1}'.format(class_prior.shape[0],
+                                                        len(np.unique(y))))
+        if sparse_sample_weight is not None:
+            fmt = ('MockClassifier extra fit_param sparse_sample_weight'
+                   '.shape[0] is {0}, should be {1}')
+            assert_true(sparse_sample_weight.shape[0] == X.shape[0],
+                        fmt.format(sparse_sample_weight.shape[0], X.shape[0]))
+        if sparse_param is not None:
+            fmt = ('MockClassifier extra fit_param sparse_param.shape '
+                   'is ({0}, {1}), should be ({2}, {3})')
+            assert_true(sparse_param.shape == P_sparse.shape,
+                        fmt.format(sparse_param.shape[0],
+                                   sparse_param.shape[1],
+                                   P_sparse.shape[0], P_sparse.shape[1]))
+        return self
+
+    def predict(self, T):
+        if self.allow_nd:
+            T = T.reshape(len(T), -1)
+        return T[:, 0]
+
+    def score(self, X=None, Y=None):
+        return 1. / (1 + np.abs(self.a))
+
+    def get_params(self, deep=False):
+        return {'a': self.a, 'allow_nd': self.allow_nd}
+
+
 # XXX: use 2D array, since 1D X is being detected as a single sample in
 # check_consistent_length
 X = np.ones((10, 2))
@@ -139,6 +202,7 @@ y = np.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])
 # The number of samples per class needs to be > n_splits,
 # for StratifiedKFold(n_splits=3)
 y2 = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3, 3])
+P_sparse = coo_matrix(np.eye(5))
 
 
 def test_cross_val_score():
@@ -556,14 +620,17 @@ def test_cross_val_score_sparse_fit_params():
 
 
 def test_learning_curve():
-    X, y = make_classification(n_samples=30, n_features=1, n_informative=1,
-                               n_redundant=0, n_classes=2,
+    n_samples = 30
+    n_splits = 3
+    X, y = make_classification(n_samples=n_samples, n_features=1,
+                               n_informative=1, n_redundant=0, n_classes=2,
                                n_clusters_per_class=1, random_state=0)
-    estimator = MockImprovingEstimator(20)
+    estimator = MockImprovingEstimator(n_samples * (n_splits - 1) // n_splits)
     for shuffle_train in [False, True]:
         with warnings.catch_warnings(record=True) as w:
             train_sizes, train_scores, test_scores = learning_curve(
-                estimator, X, y, cv=3, train_sizes=np.linspace(0.1, 1.0, 10),
+                estimator, X, y, cv=KFold(n_splits=n_splits),
+                train_sizes=np.linspace(0.1, 1.0, 10),
                 shuffle=shuffle_train)
         if len(w) > 0:
             raise RuntimeError("Unexpected warning: %r" % w[0].message)
@@ -575,6 +642,18 @@ def test_learning_curve():
         assert_array_almost_equal(test_scores.mean(axis=1),
                                   np.linspace(0.1, 1.0, 10))
 
+        # Test a custom cv splitter that can iterate only once
+        with warnings.catch_warnings(record=True) as w:
+            train_sizes2, train_scores2, test_scores2 = learning_curve(
+                estimator, X, y,
+                cv=OneTimeSplitter(n_splits=n_splits, n_samples=n_samples),
+                train_sizes=np.linspace(0.1, 1.0, 10),
+                shuffle=shuffle_train)
+        if len(w) > 0:
+            raise RuntimeError("Unexpected warning: %r" % w[0].message)
+        assert_array_almost_equal(train_scores2, train_scores)
+        assert_array_almost_equal(test_scores2, test_scores)
+
 
 def test_learning_curve_unsupervised():
     X, _ = make_classification(n_samples=30, n_features=1, n_informative=1,
@@ -766,6 +845,40 @@ def test_validation_curve():
     assert_array_almost_equal(test_scores.mean(axis=1), 1 - param_range)
 
 
+def test_validation_curve_cv_splits_consistency():
+    n_samples = 100
+    n_splits = 5
+    X, y = make_classification(n_samples=n_samples, random_state=0)
+
+    scores1 = validation_curve(SVC(kernel='linear', random_state=0), X, y,
+                               'C', [0.1, 0.1, 0.2, 0.2],
+                               cv=OneTimeSplitter(n_splits=n_splits,
+                                                  n_samples=n_samples))
+    # The OneTimeSplitter is a non-re-entrant cv splitter. Unless `split`
+    # is (incorrectly) called once per param setting, the following should
+    # produce identical results for param settings 1 and 2, as both have
+    # the same C value.
+    assert_array_almost_equal(*np.vsplit(np.hstack(scores1)[(0, 2, 1, 3), :],
+                                         2))
+
+    scores2 = validation_curve(SVC(kernel='linear', random_state=0), X, y,
+                               'C', [0.1, 0.1, 0.2, 0.2],
+                               cv=KFold(n_splits=n_splits, shuffle=True))
+
+    # For scores2, compare the scores of the 1st and 2nd param settings
+    # (since the C value for the first two param settings is 0.1, they must
+    # be consistent unless the train/test folds differ between the settings).
+    assert_array_almost_equal(*np.vsplit(np.hstack(scores2)[(0, 2, 1, 3), :],
+                                         2))
+
+    scores3 = validation_curve(SVC(kernel='linear', random_state=0), X, y,
+                               'C', [0.1, 0.1, 0.2, 0.2],
+                               cv=KFold(n_splits=n_splits))
+
+    # OneTimeSplitter is basically unshuffled KFold(n_splits=5). Sanity check.
+    assert_array_almost_equal(np.array(scores3), np.array(scores1))
+
+
 def test_check_is_permutation():
     rng = np.random.RandomState(0)
     p = np.arange(100)
-- 
GitLab