diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py index a4a1e3d65c7ca773f734a9d4f6232e096a5817cf..03c74b88f5f286ee9bb8a5fe75c8226d58d396db 100644 --- a/sklearn/cross_validation.py +++ b/sklearn/cross_validation.py @@ -1756,8 +1756,10 @@ def _permutation_test_score(estimator, X, y, cv, scorer): """Auxiliary function for permutation_test_score""" avg_score = [] for train, test in cv: - estimator.fit(X[train], y[train]) - avg_score.append(scorer(estimator, X[test], y[test])) + X_train, y_train = _safe_split(estimator, X, y, train) + X_test, y_test = _safe_split(estimator, X, y, test, train) + estimator.fit(X_train, y_train) + avg_score.append(scorer(estimator, X_test, y_test)) return np.mean(avg_score) @@ -1770,7 +1772,7 @@ def _shuffle(y, labels, random_state): for label in np.unique(labels): this_mask = (labels == label) ind[this_mask] = random_state.permutation(ind[this_mask]) - return y[ind] + return safe_indexing(y, ind) def check_cv(cv, X=None, y=None, classifier=False): diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index 88c3922f99363236c357243801756e7f793d1d9b..91f60366f871703ac0dbcf4256a9aee0e4b3caee 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -622,8 +622,10 @@ def _permutation_test_score(estimator, X, y, groups, cv, scorer): """Auxiliary function for permutation_test_score""" avg_score = [] for train, test in cv.split(X, y, groups): - estimator.fit(X[train], y[train]) - avg_score.append(scorer(estimator, X[test], y[test])) + X_train, y_train = _safe_split(estimator, X, y, train) + X_test, y_test = _safe_split(estimator, X, y, test, train) + estimator.fit(X_train, y_train) + avg_score.append(scorer(estimator, X_test, y_test)) return np.mean(avg_score) @@ -636,7 +638,7 @@ def _shuffle(y, groups, random_state): for group in np.unique(groups): this_mask = (groups == group) indices[this_mask] = random_state.permutation(indices[this_mask]) - return y[indices] + return safe_indexing(y, indices) def learning_curve(estimator, X, y, groups=None, diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index 830a079a0fc6d68c185c534a0df47076e3ee1137..d1f83b469d6c8dd65a19a7039a63b22788815e29 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -966,3 +966,22 @@ def test_score_memmap(): break except WindowsError: sleep(1.) + + +def test_permutation_test_score_pandas(): + # check permutation_test_score doesn't destroy pandas dataframe + types = [(MockDataFrame, MockDataFrame)] + try: + from pandas import Series, DataFrame + types.append((Series, DataFrame)) + except ImportError: + pass + for TargetType, InputFeatureType in types: + # X dataframe, y series + iris = load_iris() + X, y = iris.data, iris.target + X_df, y_ser = InputFeatureType(X), TargetType(y) + check_df = lambda x: isinstance(x, InputFeatureType) + check_series = lambda x: isinstance(x, TargetType) + clf = CheckingClassifier(check_X=check_df, check_y=check_series) + permutation_test_score(clf, X_df, y_ser)