From 829efa5929c129a4112e539699f9034152332ee4 Mon Sep 17 00:00:00 2001
From: NarineK <narine.kokhlikyan@gmail.com>
Date: Wed, 19 Oct 2016 12:18:07 -0700
Subject: [PATCH] [MRG+1] Learning curve: Add an option to randomly choose
 indices for different training sizes (#7506)

* Randomly choose the indices for different training sizes (see the usage sketch after this list)

* Bring back deleted line

* Rewrote the description of the 'shuffle' parameter

* Use random.sample instead of np.random.choice

* Replace tabs with spaces

* Merge with master

* Added shuffle in model-selection's learning_curve method

* Added shuffle for incremental learning + addressed Joel's comment

* Shorten long lines

* Add 2 blank lines between test cases

* Addressed Joel's review comments

* Added 2 blank lines between methods

* Added a non-regression test for learning_curve with shuffle

* Fixed indentation

* Fixed space issues

* Modified test cases + small code improvements

* Fix some style issues

* Addressed Joel's comments: removed _shuffle_train_indices, added more test cases and a new entry under 0.19 enhancements

* Added some modifications in whats_new.rst
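
A minimal usage sketch of the new parameters (the dataset and estimator
here are illustrative placeholders, not part of this patch):

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.linear_model import PassiveAggressiveClassifier
    from sklearn.model_selection import learning_curve

    X, y = make_classification(n_samples=60, random_state=0)
    # With shuffle=True, each CV training fold is permuted once before
    # prefixes of it are taken for the different train sizes.
    train_sizes, train_scores, test_scores = learning_curve(
        PassiveAggressiveClassifier(), X, y, cv=3,
        train_sizes=np.linspace(0.1, 1.0, 5),
        shuffle=True, random_state=0)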
---
 doc/whats_new.rst                             |  8 ++
 sklearn/model_selection/_validation.py        | 37 +++++++---
 .../model_selection/tests/test_validation.py  | 76 ++++++++++++-----
 3 files changed, 94 insertions(+), 27 deletions(-)

diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 00790f304f..6dd0a1d060 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -41,6 +41,12 @@ Enhancements
      (`#6101 <https://github.com/scikit-learn/scikit-learn/pull/6101>`_)
      By `Ibraim Ganiev`_.
 
+   - Added ``shuffle`` and ``random_state`` parameters to
+     :func:`model_selection.learning_curve`: the training data can now be
+     shuffled before prefixes of it are taken based on ``train_sizes``.
+     (`#7506 <https://github.com/scikit-learn/scikit-learn/pull/7506>`_) by
+     `Narine Kokhlikyan`_.
+
 Bug fixes
 .........
 
@@ -4861,3 +4867,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
 .. _Utkarsh Upadhyay: https://github.com/musically-ut
 
 .. _Eugene Chen: https://github.com/eyc88
+
+.. _Narine Kokhlikyan: https://github.com/NarineK
diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py
index 9745cb9dec..cc77d7c284 100644
--- a/sklearn/model_selection/_validation.py
+++ b/sklearn/model_selection/_validation.py
@@ -642,7 +642,8 @@ def _shuffle(y, groups, random_state):
 def learning_curve(estimator, X, y, groups=None,
                    train_sizes=np.linspace(0.1, 1.0, 5), cv=None, scoring=None,
                    exploit_incremental_learning=False, n_jobs=1,
-                   pre_dispatch="all", verbose=0):
+                   pre_dispatch="all", verbose=0, shuffle=False,
+                   random_state=None):
     """Learning curve.
 
     Determines cross-validated training and test scores for different training
@@ -718,7 +719,15 @@
     verbose : integer, optional
         Controls the verbosity: the higher, the more messages.
 
-    Returns
+    shuffle : boolean, optional
+        Whether to shuffle training data before taking prefixes of it
+        based on ``train_sizes``.
+
+    random_state : None, int or RandomState
+        When ``shuffle`` is True, pseudo-random number generator state used
+        for shuffling. If None, use the default numpy RNG for shuffling.
+
+    Returns
     -------
     train_sizes_abs : array, shape = (n_unique_ticks,), dtype int
         Numbers of training examples that has been used to generate the
@@ -759,17 +768,31 @@
 
     parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch,
                         verbose=verbose)
+
+    if shuffle:
+        rng = check_random_state(random_state)
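+        # Permute the samples of each training fold; prefixes of the
+        # permuted indices then yield random subsets for every train size.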
+        cv_iter = ((rng.permutation(train), test) for train, test in cv_iter)
+
     if exploit_incremental_learning:
         classes = np.unique(y) if is_classifier(estimator) else None
         out = parallel(delayed(_incremental_fit_estimator)(
-            clone(estimator), X, y, classes, train, test, train_sizes_abs,
-            scorer, verbose) for train, test in cv.split(X, y, groups))
+            clone(estimator), X, y, classes, train,
+            test, train_sizes_abs, scorer, verbose)
+            for train, test in cv_iter)
     else:
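+        # Build one (train-prefix, test) pair per CV split and train size
+        # so that all fits can be dispatched in a single parallel call.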
+        train_test_proportions = []
+        for train, test in cv_iter:
+            for n_train_samples in train_sizes_abs:
+                train_test_proportions.append((train[:n_train_samples], test))
+
         out = parallel(delayed(_fit_and_score)(
-            clone(estimator), X, y, scorer, train[:n_train_samples], test,
+            clone(estimator), X, y, scorer, train, test,
             verbose, parameters=None, fit_params=None, return_train_score=True)
-            for train, test in cv_iter
-            for n_train_samples in train_sizes_abs)
+            for train, test in train_test_proportions)
         out = np.array(out)
         n_cv_folds = out.shape[0] // n_unique_ticks
         out = out.reshape(n_cv_folds, n_unique_ticks, 2)
diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py
index 4f45a079fb..eb29be1a2a 100644
--- a/sklearn/model_selection/tests/test_validation.py
+++ b/sklearn/model_selection/tests/test_validation.py
@@ -560,18 +560,20 @@ def test_learning_curve():
                                n_redundant=0, n_classes=2,
                                n_clusters_per_class=1, random_state=0)
     estimator = MockImprovingEstimator(20)
-    with warnings.catch_warnings(record=True) as w:
-        train_sizes, train_scores, test_scores = learning_curve(
-            estimator, X, y, cv=3, train_sizes=np.linspace(0.1, 1.0, 10))
-    if len(w) > 0:
-        raise RuntimeError("Unexpected warning: %r" % w[0].message)
-    assert_equal(train_scores.shape, (10, 3))
-    assert_equal(test_scores.shape, (10, 3))
-    assert_array_equal(train_sizes, np.linspace(2, 20, 10))
-    assert_array_almost_equal(train_scores.mean(axis=1),
-                              np.linspace(1.9, 1.0, 10))
-    assert_array_almost_equal(test_scores.mean(axis=1),
-                              np.linspace(0.1, 1.0, 10))
+    for shuffle_train in [False, True]:
+        with warnings.catch_warnings(record=True) as w:
+            train_sizes, train_scores, test_scores = learning_curve(
+                estimator, X, y, cv=3, train_sizes=np.linspace(0.1, 1.0, 10),
+                shuffle=shuffle_train)
+        if len(w) > 0:
+            raise RuntimeError("Unexpected warning: %r" % w[0].message)
+        assert_equal(train_scores.shape, (10, 3))
+        assert_equal(test_scores.shape, (10, 3))
+        assert_array_equal(train_sizes, np.linspace(2, 20, 10))
+        assert_array_almost_equal(train_scores.mean(axis=1),
+                                  np.linspace(1.9, 1.0, 10))
+        assert_array_almost_equal(test_scores.mean(axis=1),
+                                  np.linspace(0.1, 1.0, 10))
 
 
 def test_learning_curve_unsupervised():
@@ -622,14 +624,15 @@ def test_learning_curve_incremental_learning():
                                n_redundant=0, n_classes=2,
                                n_clusters_per_class=1, random_state=0)
     estimator = MockIncrementalImprovingEstimator(20)
-    train_sizes, train_scores, test_scores = learning_curve(
-        estimator, X, y, cv=3, exploit_incremental_learning=True,
-        train_sizes=np.linspace(0.1, 1.0, 10))
-    assert_array_equal(train_sizes, np.linspace(2, 20, 10))
-    assert_array_almost_equal(train_scores.mean(axis=1),
-                              np.linspace(1.9, 1.0, 10))
-    assert_array_almost_equal(test_scores.mean(axis=1),
-                              np.linspace(0.1, 1.0, 10))
+    for shuffle_train in [False, True]:
+        train_sizes, train_scores, test_scores = learning_curve(
+            estimator, X, y, cv=3, exploit_incremental_learning=True,
+            train_sizes=np.linspace(0.1, 1.0, 10), shuffle=shuffle_train)
+        assert_array_equal(train_sizes, np.linspace(2, 20, 10))
+        assert_array_almost_equal(train_scores.mean(axis=1),
+                                  np.linspace(1.9, 1.0, 10))
+        assert_array_almost_equal(test_scores.mean(axis=1),
+                                  np.linspace(0.1, 1.0, 10))
 
 
 def test_learning_curve_incremental_learning_unsupervised():
@@ -713,6 +716,39 @@ def test_learning_curve_with_boolean_indices():
                               np.linspace(0.1, 1.0, 10))
 
 
+def test_learning_curve_with_shuffle():
+    """Following test case was designed this way to verify the code
+    changes made in pull request: #7506."""
+    X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [11, 12], [13, 14], [15, 16],
+                 [17, 18], [19, 20], [7, 8], [9, 10], [11, 12], [13, 14],
+                 [15, 16], [17, 18]])
+    y = np.array([1, 1, 1, 2, 3, 4, 1, 1, 2, 3, 4, 1, 2, 3, 4])
+    groups = np.array([1, 1, 1, 1, 1, 1, 3, 3, 3, 3, 3, 4, 4, 4, 4])
+    estimator = PassiveAggressiveClassifier(shuffle=False)
+
+    cv = GroupKFold(n_splits=2)
+    train_sizes_batch, train_scores_batch, test_scores_batch = learning_curve(
+        estimator, X, y, cv=cv, n_jobs=1, train_sizes=np.linspace(0.3, 1.0, 3),
+        groups=groups, shuffle=True, random_state=2)
+    assert_array_almost_equal(train_scores_batch.mean(axis=1),
+                              np.array([0.75, 0.3, 0.36111111]))
+    assert_array_almost_equal(test_scores_batch.mean(axis=1),
+                              np.array([0.36111111, 0.25, 0.25]))
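+    # Without shuffle, the first training prefix contains a single class,
+    # so fitting the classifier raises a ValueError.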
+    assert_raises(ValueError, learning_curve, estimator, X, y, cv=cv, n_jobs=1,
+                  train_sizes=np.linspace(0.3, 1.0, 3), groups=groups)
+
+    train_sizes_inc, train_scores_inc, test_scores_inc = learning_curve(
+        estimator, X, y, cv=cv, n_jobs=1, train_sizes=np.linspace(0.3, 1.0, 3),
+        groups=groups, shuffle=True, random_state=2,
+        exploit_incremental_learning=True)
+    assert_array_almost_equal(train_scores_inc.mean(axis=1),
+                              train_scores_batch.mean(axis=1))
+    assert_array_almost_equal(test_scores_inc.mean(axis=1),
+                              test_scores_batch.mean(axis=1))
+
+
 def test_validation_curve():
     X, y = make_classification(n_samples=2, n_features=1, n_informative=1,
                                n_redundant=0, n_classes=2,
-- 
GitLab