From 986a49bbe018aa8060f53c146fc06f278f80b7a6 Mon Sep 17 00:00:00 2001
From: Stijn Tonk <equialgo@gmail.com>
Date: Thu, 29 Dec 2016 02:46:53 +0100
Subject: [PATCH] FIX Split data using _safe_split in _permutaion_test_score
 (#5697)

Squashed commits:
[94fd9f4] split data using _safe_split in _permutaion_test_scorer
[522053b] adding test case test_permutation_test_score_pandas() to check if permutation_test_score plays nice with pandas dataframe/series
[21b23ce] running test_permutation_test_score_pandas on iris data to prevent warnings.
[15a48bf] adding safe_indexing to _shuffle function
[9ea5c9e] adding test case test_permutation_test_score_pandas() to check if permutation_test_score plays nice with pandas dataframe/series
[3cf5e8f] split  data using _safe_split in _permutaion_test_scorer to fix error when using Pandas DataFrame/Series
---
 sklearn/cross_validation.py                   |  8 +++++---
 sklearn/model_selection/_validation.py        |  8 +++++---
 .../model_selection/tests/test_validation.py  | 19 +++++++++++++++++++
 3 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py
index a4a1e3d65c..03c74b88f5 100644
--- a/sklearn/cross_validation.py
+++ b/sklearn/cross_validation.py
@@ -1756,8 +1756,10 @@ def _permutation_test_score(estimator, X, y, cv, scorer):
     """Auxiliary function for permutation_test_score"""
     avg_score = []
     for train, test in cv:
-        estimator.fit(X[train], y[train])
-        avg_score.append(scorer(estimator, X[test], y[test]))
+        X_train, y_train = _safe_split(estimator, X, y, train)
+        X_test, y_test = _safe_split(estimator, X, y, test, train)
+        estimator.fit(X_train, y_train)
+        avg_score.append(scorer(estimator, X_test, y_test))
     return np.mean(avg_score)
 
 
@@ -1770,7 +1772,7 @@ def _shuffle(y, labels, random_state):
         for label in np.unique(labels):
             this_mask = (labels == label)
             ind[this_mask] = random_state.permutation(ind[this_mask])
-    return y[ind]
+    return safe_indexing(y, ind)
 
 
 def check_cv(cv, X=None, y=None, classifier=False):
diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py
index 88c3922f99..91f60366f8 100644
--- a/sklearn/model_selection/_validation.py
+++ b/sklearn/model_selection/_validation.py
@@ -622,8 +622,10 @@ def _permutation_test_score(estimator, X, y, groups, cv, scorer):
     """Auxiliary function for permutation_test_score"""
     avg_score = []
     for train, test in cv.split(X, y, groups):
-        estimator.fit(X[train], y[train])
-        avg_score.append(scorer(estimator, X[test], y[test]))
+        X_train, y_train = _safe_split(estimator, X, y, train)
+        X_test, y_test = _safe_split(estimator, X, y, test, train)
+        estimator.fit(X_train, y_train)
+        avg_score.append(scorer(estimator, X_test, y_test))
     return np.mean(avg_score)
 
 
@@ -636,7 +638,7 @@ def _shuffle(y, groups, random_state):
         for group in np.unique(groups):
             this_mask = (groups == group)
             indices[this_mask] = random_state.permutation(indices[this_mask])
-    return y[indices]
+    return safe_indexing(y, indices)
 
 
 def learning_curve(estimator, X, y, groups=None,
diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py
index 830a079a0f..d1f83b469d 100644
--- a/sklearn/model_selection/tests/test_validation.py
+++ b/sklearn/model_selection/tests/test_validation.py
@@ -966,3 +966,22 @@ def test_score_memmap():
                 break
             except WindowsError:
                 sleep(1.)
+
+
+def test_permutation_test_score_pandas():
+    # check permutation_test_score doesn't destroy pandas dataframe
+    types = [(MockDataFrame, MockDataFrame)]
+    try:
+        from pandas import Series, DataFrame
+        types.append((Series, DataFrame))
+    except ImportError:
+        pass
+    for TargetType, InputFeatureType in types:
+        # X dataframe, y series
+        iris = load_iris()
+        X, y = iris.data, iris.target
+        X_df, y_ser = InputFeatureType(X), TargetType(y)
+        check_df = lambda x: isinstance(x, InputFeatureType)
+        check_series = lambda x: isinstance(x, TargetType)
+        clf = CheckingClassifier(check_X=check_df, check_y=check_series)
+        permutation_test_score(clf, X_df, y_ser)
-- 
GitLab