From 54b0e4bf62152ac9abb5b93780a0af3e6f5d1c44 Mon Sep 17 00:00:00 2001
From: Russell Smith <rsmith54@users.noreply.github.com>
Date: Tue, 20 Sep 2016 07:46:45 -0400
Subject: [PATCH] Add OneVs{One,All}Classifier._pairwise: fix for #7306 (#7350)

---
 doc/whats_new.rst                           |   7 ++
 sklearn/model_selection/_split.py           |  34 ------
 sklearn/model_selection/_validation.py      |   4 +-
 sklearn/model_selection/tests/test_split.py |  22 ----
 sklearn/multiclass.py                       | 125 +++++++++-----------
 sklearn/svm/base.py                         |   2 +-
 sklearn/tests/test_multiclass.py            |  46 ++++++-
 sklearn/utils/metaestimators.py             |  38 +++++-
 sklearn/utils/multiclass.py                 |  53 ++++++++-
 sklearn/utils/tests/test_multiclass.py      |  29 ++++-
 10 files changed, 227 insertions(+), 133 deletions(-)

diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 8db58b884d..174d2e6804 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -462,6 +462,11 @@ Bug fixes
        <https://github.com/scikit-learn/scikit-learn/pull/7159>`_)
        by `Yichuan Liu <https://github.com/yl565>`_.
 
+    - Cross-validation of :class:`OneVsOneClassifier` and
+      :class:`OneVsRestClassifier` now works with precomputed kernels.
+      (`#7350 <https://github.com/scikit-learn/scikit-learn/pull/7350/>`_)
+      By `Russell Smith`_.
+
 API changes summary
 -------------------
 
@@ -4638,3 +4643,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
 .. _Robert McGibbon: https://github.com/rmcgibbo
 
 .. _Gregory Stupp: https://github.com/stuppie
+
+.. _Russell Smith: https://github.com/rsmith54
diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py
index 3570d688ff..cf109e6216 100644
--- a/sklearn/model_selection/_split.py
+++ b/sklearn/model_selection/_split.py
@@ -1677,40 +1677,6 @@ def train_test_split(*arrays, **options):
 train_test_split.__test__ = False  # to avoid a pb with nosetests
 
 
-def _safe_split(estimator, X, y, indices, train_indices=None):
-    """Create subset of dataset and properly handle kernels."""
-    if (hasattr(estimator, 'kernel') and callable(estimator.kernel) and
-            not isinstance(estimator.kernel, GPKernel)):
-        # cannot compute the kernel values with custom function
-        raise ValueError("Cannot use a custom kernel function. "
-                         "Precompute the kernel matrix instead.")
-
-    if not hasattr(X, "shape"):
-        if getattr(estimator, "_pairwise", False):
-            raise ValueError("Precomputed kernels or affinity matrices have "
-                             "to be passed as arrays or sparse matrices.")
-        X_subset = [X[index] for index in indices]
-    else:
-        if getattr(estimator, "_pairwise", False):
-            # X is a precomputed square kernel matrix
-            if X.shape[0] != X.shape[1]:
-                raise ValueError("X should be a square kernel matrix")
-            if train_indices is None:
-                X_subset = X[np.ix_(indices, indices)]
-            else:
-                X_subset = X[np.ix_(indices, train_indices)]
-        else:
-            X_subset = safe_indexing(X, indices)
-
-    if y is not None:
-        y_subset = safe_indexing(y, indices)
-    else:
-        y_subset = None
-
-    return X_subset, y_subset
-
-
 def _build_repr(self):
     # XXX This is copied from BaseEstimator's get_params
     cls = self.__class__
diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py
index bf1d70b081..d82a62707e 100644
--- a/sklearn/model_selection/_validation.py
+++ b/sklearn/model_selection/_validation.py
@@ -23,11 +23,11 @@ from ..base import is_classifier, clone
 from ..utils import indexable, check_random_state, safe_indexing
 from ..utils.fixes import astype
 from ..utils.validation import _is_arraylike, _num_samples
+from ..utils.metaestimators import _safe_split
 from ..externals.joblib import Parallel, delayed, logger
 from ..metrics.scorer import check_scoring
 from ..exceptions import FitFailedWarning
-
-from ._split import check_cv, _safe_split
+from ._split import check_cv
 
 __all__ = ['cross_val_score', 'cross_val_predict', 'permutation_test_score',
            'learning_curve', 'validation_curve']
diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py
index ec4f07aef8..4dcd8f5503 100644
--- a/sklearn/model_selection/tests/test_split.py
+++ b/sklearn/model_selection/tests/test_split.py
@@ -45,13 +45,11 @@ from sklearn.model_selection import GridSearchCV
 
 from sklearn.linear_model import Ridge
 
-from sklearn.model_selection._split import _safe_split
 from sklearn.model_selection._split import _validate_shuffle_split
 from sklearn.model_selection._split import _CVIterableWrapper
 from sklearn.model_selection._split import _build_repr
 
 from sklearn.datasets import load_digits
-from sklearn.datasets import load_iris
 from sklearn.datasets import make_classification
 
 from sklearn.externals import six
@@ -62,7 +60,6 @@ from sklearn.svm import SVC
 X = np.ones(10)
 y = np.arange(10) // 2
 P_sparse = coo_matrix(np.eye(5))
-iris = load_iris()
 digits = load_digits()
 
 
@@ -846,25 +843,6 @@ def test_shufflesplit_reproducible():
         list(a for a, b in ss.split(X)))
 
 
-def test_safe_split_with_precomputed_kernel():
-    clf = SVC()
-    clfp = SVC(kernel="precomputed")
-
-    X, y = iris.data, iris.target
-    K = np.dot(X, X.T)
-
-    cv = ShuffleSplit(test_size=0.25, random_state=0)
-    tr, te = list(cv.split(X))[0]
-
-    X_tr, y_tr = _safe_split(clf, X, y, tr)
-    K_tr, y_tr2 = _safe_split(clfp, K, y, tr)
-    assert_array_almost_equal(K_tr, np.dot(X_tr, X_tr.T))
-
-    X_te, y_te = _safe_split(clf, X, y, te, tr)
-    K_te, y_te2 = _safe_split(clfp, K, y, te, tr)
-    assert_array_almost_equal(K_te, np.dot(X_te, X_tr.T))
-
-
 def test_train_test_split_allow_nans():
     # Check that train_test_split allows input data with NaNs
     X = np.arange(200, dtype=np.float64).reshape(10, -1)
diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py
index 6a25e1191c..e3fad7e08e 100644
--- a/sklearn/multiclass.py
+++ b/sklearn/multiclass.py
@@ -37,6 +37,7 @@ import array
 import numpy as np
 import warnings
 import scipy.sparse as sp
+import itertools
 
 from .base import BaseEstimator, ClassifierMixin, clone, is_classifier
 from .base import MetaEstimatorMixin, is_regressor
@@ -47,7 +48,10 @@ from .utils.validation import _num_samples
 from .utils.validation import check_is_fitted
 from .utils.validation import check_X_y
 from .utils.multiclass import (_check_partial_fit_first_call,
-                               check_classification_targets)
+                               check_classification_targets,
+                               _ovr_decision_function)
+from .utils.metaestimators import _safe_split
+
 from .externals.joblib import Parallel
 from .externals.joblib import delayed
 from .externals.six.moves import zip as izip
@@ -257,9 +261,9 @@ class OneVsRestClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
 
         self.estimators_ = Parallel(n_jobs=self.n_jobs)(delayed(
             _partial_fit_binary)(self.estimators_[i],
-                X, next(columns) if self.classes_[i] in
-                self.label_binarizer_.classes_ else
-                np.zeros((1, len(y))))
+                                 X, next(columns) if self.classes_[i] in
+                                 self.label_binarizer_.classes_ else
+                                 np.zeros((1, len(y))))
             for i in range(self.n_classes_))
 
         return self
@@ -391,6 +395,11 @@ class OneVsRestClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
                 "Base estimator doesn't have an intercept_ attribute.")
         return np.array([e.intercept_.ravel() for e in self.estimators_])
 
+    @property
+    def _pairwise(self):
+        """Indicate if wrapped estimator is using a precomputed Gram matrix"""
+        return getattr(self.estimator, "_pairwise", False)
+
 
 def _fit_ovo_binary(estimator, X, y, i, j):
     """Fit a single binary estimator (one-vs-one)."""
@@ -399,8 +408,10 @@ def _fit_ovo_binary(estimator, X, y, i, j):
     y_binary = np.empty(y.shape, np.int)
     y_binary[y == i] = 0
     y_binary[y == j] = 1
-    ind = np.arange(X.shape[0])
-    return _fit_binary(estimator, X[ind[cond]], y_binary, classes=[i, j])
+    indcond = np.arange(X.shape[0])[cond]
+    return _fit_binary(estimator,
+                       _safe_split(estimator, X, None, indices=indcond)[0],
+                       y_binary, classes=[i, j]), indcond
 
 
 def _partial_fit_ovo_binary(estimator, X, y, i, j):
@@ -472,10 +483,17 @@ class OneVsOneClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
 
         self.classes_ = np.unique(y)
         n_classes = self.classes_.shape[0]
-        self.estimators_ = Parallel(n_jobs=self.n_jobs)(
-            delayed(_fit_ovo_binary)(
-                self.estimator, X, y, self.classes_[i], self.classes_[j])
-            for i in range(n_classes) for j in range(i + 1, n_classes))
+        estimators_indices = list(zip(*(Parallel(n_jobs=self.n_jobs)(
+            delayed(_fit_ovo_binary)
+            (self.estimator, X, y, self.classes_[i], self.classes_[j])
+            for i in range(n_classes) for j in range(i + 1, n_classes)))))
+
+        self.estimators_ = estimators_indices[0]
+        try:
+            self.pairwise_indices_ = estimators_indices[1] \
+                if self._pairwise else None
+        except AttributeError:
+            self.pairwise_indices_ = None
 
         return self
 
@@ -509,16 +527,20 @@ class OneVsOneClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
         if _check_partial_fit_first_call(self, classes):
             self.estimators_ = [clone(self.estimator) for i in
                                 range(self.n_classes_ *
-                                      (self.n_classes_-1) // 2)]
+                                      (self.n_classes_ - 1) // 2)]
 
         X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
         check_classification_targets(y)
-        self.estimators_ = Parallel(n_jobs=self.n_jobs)(
-            delayed(_partial_fit_ovo_binary)(
-                estimator, X, y, self.classes_[i], self.classes_[j])
-            for estimator, (i, j) in izip(self.estimators_, ((i, j) for i
-                in range(self.n_classes_) for j in range
-                (i + 1, self.n_classes_))))
+        combinations = itertools.combinations(range(self.n_classes_), 2)
+        self.estimators_ = Parallel(
+            n_jobs=self.n_jobs)(
+                delayed(_partial_fit_ovo_binary)(
+                    estimator, X, y, self.classes_[i], self.classes_[j])
+                for estimator, (i, j) in izip(
+                    self.estimators_, (combinations)))
+
+        self.pairwise_indices_ = None
+
         return self
 
     def predict(self, X):
@@ -559,62 +581,29 @@ class OneVsOneClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
         """
         check_is_fitted(self, 'estimators_')
 
-        predictions = np.vstack([est.predict(X) for est in self.estimators_]).T
-        confidences = np.vstack([_predict_binary(est, X) for est in self.estimators_]).T
-        return _ovr_decision_function(predictions, confidences,
-                                      len(self.classes_))
+        indices = self.pairwise_indices_
+        if indices is None:
+            Xs = [X] * len(self.estimators_)
+        else:
+            Xs = [X[:, idx] for idx in indices]
+
+        predictions = np.vstack([est.predict(Xi)
+                                 for est, Xi in zip(self.estimators_, Xs)]).T
+        confidences = np.vstack([_predict_binary(est, Xi)
+                                 for est, Xi in zip(self.estimators_, Xs)]).T
+        Y = _ovr_decision_function(predictions,
+                                   confidences, len(self.classes_))
+
+        return Y
 
     @property
     def n_classes_(self):
         return len(self.classes_)
 
-
-def _ovr_decision_function(predictions, confidences, n_classes):
-    """Compute a continuous, tie-breaking ovr decision function.
-
-    It is important to include a continuous value, not only votes,
-    to make computing AUC or calibration meaningful.
-
-    Parameters
-    ----------
-    predictions : array-like, shape (n_samples, n_classifiers)
-        Predicted classes for each binary classifier.
-
-    confidences : array-like, shape (n_samples, n_classifiers)
-        Decision functions or predicted probabilities for positive class
-        for each binary classifier.
-
-    n_classes : int
-        Number of classes. n_classifiers must be
-        ``n_classes * (n_classes - 1 ) / 2``
-    """
-    n_samples = predictions.shape[0]
-    votes = np.zeros((n_samples, n_classes))
-    sum_of_confidences = np.zeros((n_samples, n_classes))
-
-    k = 0
-    for i in range(n_classes):
-        for j in range(i + 1, n_classes):
-            sum_of_confidences[:, i] -= confidences[:, k]
-            sum_of_confidences[:, j] += confidences[:, k]
-            votes[predictions[:, k] == 0, i] += 1
-            votes[predictions[:, k] == 1, j] += 1
-            k += 1
-
-    max_confidences = sum_of_confidences.max()
-    min_confidences = sum_of_confidences.min()
-
-    if max_confidences == min_confidences:
-        return votes
-
-    # Scale the sum_of_confidences to (-0.5, 0.5) and add it with votes.
-    # The motivation is to use confidence levels as a way to break ties in
-    # the votes without switching any decision made based on a difference
-    # of 1 vote.
-    eps = np.finfo(sum_of_confidences.dtype).eps
-    max_abs_confidence = max(abs(max_confidences), abs(min_confidences))
-    scale = (0.5 - eps) / max_abs_confidence
-    return votes + sum_of_confidences * scale
+    @property
+    def _pairwise(self):
+        """Indicate if wrapped estimator is using a precomputed Gram matrix"""
+        return getattr(self.estimator, "_pairwise", False)
 
 
 class OutputCodeClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
diff --git a/sklearn/svm/base.py b/sklearn/svm/base.py
index ac6231c4b0..ddf7672993 100644
--- a/sklearn/svm/base.py
+++ b/sklearn/svm/base.py
@@ -9,7 +9,7 @@ from . import libsvm, liblinear
 from . import libsvm_sparse
 from ..base import BaseEstimator, ClassifierMixin
 from ..preprocessing import LabelEncoder
-from ..multiclass import _ovr_decision_function
+from ..utils.multiclass import _ovr_decision_function
 from ..utils import check_array, check_consistent_length, check_random_state
 from ..utils import column_or_1d, check_X_y
 from ..utils import compute_class_weight, deprecated
diff --git a/sklearn/tests/test_multiclass.py b/sklearn/tests/test_multiclass.py
index ef0489696f..5bdc13f8d5 100644
--- a/sklearn/tests/test_multiclass.py
+++ b/sklearn/tests/test_multiclass.py
@@ -24,7 +24,7 @@ from sklearn.naive_bayes import MultinomialNB
 from sklearn.linear_model import (LinearRegression, Lasso, ElasticNet, Ridge,
                                   Perceptron, LogisticRegression)
 from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
-from sklearn.model_selection import GridSearchCV
+from sklearn.model_selection import GridSearchCV, cross_val_score
 from sklearn.pipeline import Pipeline
 from sklearn import svm
 from sklearn import datasets
@@ -605,3 +605,47 @@ def test_ecoc_gridsearch():
     cv.fit(iris.data, iris.target)
     best_C = cv.best_estimator_.estimators_[0].C
     assert_true(best_C in Cs)
+
+
+def test_pairwise_indices():
+    clf_precomputed = svm.SVC(kernel='precomputed')
+    X, y = iris.data, iris.target
+
+    ovr_false = OneVsOneClassifier(clf_precomputed)
+    linear_kernel = np.dot(X, X.T)
+    ovr_false.fit(linear_kernel, y)
+
+    n_estimators = len(ovr_false.estimators_)
+    precomputed_indices = ovr_false.pairwise_indices_
+
+    for idx in precomputed_indices:
+        assert_equal(idx.shape[0] * n_estimators / (n_estimators - 1),
+                     linear_kernel.shape[0])
+
+
+def test_pairwise_attribute():
+    clf_precomputed = svm.SVC(kernel='precomputed')
+    clf_notprecomputed = svm.SVC()
+
+    for MultiClassClassifier in [OneVsRestClassifier, OneVsOneClassifier]:
+        ovr_false = MultiClassClassifier(clf_notprecomputed)
+        assert_false(ovr_false._pairwise)
+
+        ovr_true = MultiClassClassifier(clf_precomputed)
+        assert_true(ovr_true._pairwise)
+
+
+def test_pairwise_cross_val_score():
+    clf_precomputed = svm.SVC(kernel='precomputed')
+    clf_notprecomputed = svm.SVC(kernel='linear')
+
+    X, y = iris.data, iris.target
+
+    for MultiClassClassifier in [OneVsRestClassifier, OneVsOneClassifier]:
+        ovr_false = MultiClassClassifier(clf_notprecomputed)
+        ovr_true = MultiClassClassifier(clf_precomputed)
+
+        linear_kernel = np.dot(X, X.T)
+        score_precomputed = cross_val_score(ovr_true, linear_kernel, y)
+        score_linear = cross_val_score(ovr_false, X, y)
+        assert_array_equal(score_precomputed, score_linear)
diff --git a/sklearn/utils/metaestimators.py b/sklearn/utils/metaestimators.py
index 346064448b..d34c62b185 100644
--- a/sklearn/utils/metaestimators.py
+++ b/sklearn/utils/metaestimators.py
@@ -5,7 +5,8 @@
 
 from operator import attrgetter
 from functools import update_wrapper
-
+import numpy as np
+from ..utils import safe_indexing
 
 __all__ = ['if_delegate_has_method']
 
@@ -77,3 +78,38 @@ def if_delegate_has_method(delegate):
 
     return lambda fn: _IffHasAttrDescriptor(fn, delegate,
                                             attribute_name=fn.__name__)
+
+
+def _safe_split(estimator, X, y, indices, train_indices=None):
+    """Create subset of dataset and properly handle kernels."""
+    from ..gaussian_process.kernels import Kernel as GPKernel
+
+    if (hasattr(estimator, 'kernel') and callable(estimator.kernel) and
+            not isinstance(estimator.kernel, GPKernel)):
+        # cannot compute the kernel values with custom function
+        raise ValueError("Cannot use a custom kernel function. "
+                         "Precompute the kernel matrix instead.")
+
+    if not hasattr(X, "shape"):
+        if getattr(estimator, "_pairwise", False):
+            raise ValueError("Precomputed kernels or affinity matrices have "
+                             "to be passed as arrays or sparse matrices.")
+        X_subset = [X[index] for index in indices]
+    else:
+        if getattr(estimator, "_pairwise", False):
+            # X is a precomputed square kernel matrix
+            if X.shape[0] != X.shape[1]:
+                raise ValueError("X should be a square kernel matrix")
+            if train_indices is None:
+                X_subset = X[np.ix_(indices, indices)]
+            else:
+                X_subset = X[np.ix_(indices, train_indices)]
+        else:
+            X_subset = safe_indexing(X, indices)
+
+    if y is not None:
+        y_subset = safe_indexing(y, indices)
+    else:
+        y_subset = None
+
+    return X_subset, y_subset
diff --git a/sklearn/utils/multiclass.py b/sklearn/utils/multiclass.py
index 1ce9f22fcc..2a2cfe1c30 100644
--- a/sklearn/utils/multiclass.py
+++ b/sklearn/utils/multiclass.py
@@ -23,7 +23,6 @@ from .validation import check_array
 from ..utils.fixes import bincount
 from ..utils.fixes import array_equal
 
-
 def _unique_multiclass(y):
     if hasattr(y, '__array__'):
         return np.unique(np.asarray(y))
@@ -160,7 +159,7 @@ def check_classification_targets(y):
     """Ensure that target y is of a non-regression type.
 
     Only the following target types (as defined in type_of_target) are
     allowed:
-        'binary', 'multiclass', 'multiclass-multioutput', 
+        'binary', 'multiclass', 'multiclass-multioutput',
         'multilabel-indicator', 'multilabel-sequences'
 
     Parameters
@@ -168,7 +167,7 @@ def check_classification_targets(y):
     ----------
     y : array-like
     """
     y_type = type_of_target(y)
-    if y_type not in ['binary', 'multiclass', 'multiclass-multioutput', 
+    if y_type not in ['binary', 'multiclass', 'multiclass-multioutput',
                       'multilabel-indicator', 'multilabel-sequences']:
         raise ValueError("Unknown label type: %r" % y_type)
@@ -386,3 +385,51 @@ def class_distribution(y, sample_weight=None):
             class_prior.append(class_prior_k / class_prior_k.sum())
 
     return (classes, n_classes, class_prior)
+
+
+def _ovr_decision_function(predictions, confidences, n_classes):
+    """Compute a continuous, tie-breaking ovr decision function.
+
+    It is important to include a continuous value, not only votes,
+    to make computing AUC or calibration meaningful.
+
+    Parameters
+    ----------
+    predictions : array-like, shape (n_samples, n_classifiers)
+        Predicted classes for each binary classifier.
+
+    confidences : array-like, shape (n_samples, n_classifiers)
+        Decision functions or predicted probabilities for positive class
+        for each binary classifier.
+
+    n_classes : int
+        Number of classes. n_classifiers must be
+        ``n_classes * (n_classes - 1 ) / 2``
+    """
+    n_samples = predictions.shape[0]
+    votes = np.zeros((n_samples, n_classes))
+    sum_of_confidences = np.zeros((n_samples, n_classes))
+
+    k = 0
+    for i in range(n_classes):
+        for j in range(i + 1, n_classes):
+            sum_of_confidences[:, i] -= confidences[:, k]
+            sum_of_confidences[:, j] += confidences[:, k]
+            votes[predictions[:, k] == 0, i] += 1
+            votes[predictions[:, k] == 1, j] += 1
+            k += 1
+
+    max_confidences = sum_of_confidences.max()
+    min_confidences = sum_of_confidences.min()
+
+    if max_confidences == min_confidences:
+        return votes
+
+    # Scale the sum_of_confidences to (-0.5, 0.5) and add it with votes.
+    # The motivation is to use confidence levels as a way to break ties in
+    # the votes without switching any decision made based on a difference
+    # of 1 vote.
+    eps = np.finfo(sum_of_confidences.dtype).eps
+    max_abs_confidence = max(abs(max_confidences), abs(min_confidences))
+    scale = (0.5 - eps) / max_abs_confidence
+    return votes + sum_of_confidences * scale
diff --git a/sklearn/utils/tests/test_multiclass.py b/sklearn/utils/tests/test_multiclass.py
index d103a177d7..34f60ffec8 100644
--- a/sklearn/utils/tests/test_multiclass.py
+++ b/sklearn/utils/tests/test_multiclass.py
@@ -28,6 +28,11 @@ from sklearn.utils.multiclass import type_of_target
 from sklearn.utils.multiclass import class_distribution
 from sklearn.utils.multiclass import check_classification_targets
 
+from sklearn.utils.metaestimators import _safe_split
+from sklearn.model_selection import ShuffleSplit
+from sklearn.svm import SVC
+from sklearn import datasets
+
 
 class NotAnArray(object):
     """An object that is convertable to an array. This is useful to
@@ -266,7 +271,7 @@ def test_check_classification_targets():
         if y_type in ["unknown", "continuous", 'continuous-multioutput']:
             for example in EXAMPLES[y_type]:
                 msg = 'Unknown label type: '
-                assert_raises_regex(ValueError, msg,
+                assert_raises_regex(ValueError, msg,
                                     check_classification_targets, example)
         else:
             for example in EXAMPLES[y_type]:
@@ -345,3 +350,25 @@ def test_class_distribution():
         assert_array_almost_equal(classes_sp[k], classes_expected[k])
         assert_array_almost_equal(n_classes_sp[k], n_classes_expected[k])
         assert_array_almost_equal(class_prior_sp[k], class_prior_expected[k])
+
+
+def test_safe_split_with_precomputed_kernel():
+    clf = SVC()
+    clfp = SVC(kernel="precomputed")
+
+    iris = datasets.load_iris()
+    X, y = iris.data, iris.target
+    K = np.dot(X, X.T)
+
+    cv = ShuffleSplit(test_size=0.25, random_state=0)
+    train, test = list(cv.split(X))[0]
+
+    X_train, y_train = _safe_split(clf, X, y, train)
+    K_train, y_train2 = _safe_split(clfp, K, y, train)
+    assert_array_almost_equal(K_train, np.dot(X_train, X_train.T))
+    assert_array_almost_equal(y_train, y_train2)
+
+    X_test, y_test = _safe_split(clf, X, y, test, train)
+    K_test, y_test2 = _safe_split(clfp, K, y, test, train)
+    assert_array_almost_equal(K_test, np.dot(X_test, X_train.T))
+    assert_array_almost_equal(y_test, y_test2)
--
GitLab
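A usage sketch of what this change enables (illustrative only, not part of the commit; assumes a scikit-learn build with this patch applied). The new _pairwise property is delegated from the wrapped estimator, so the cross-validation machinery slices a precomputed Gram matrix on both axes instead of row-wise only, mirroring test_pairwise_cross_val_score above:

import numpy as np
from sklearn import datasets, svm
from sklearn.model_selection import cross_val_score
from sklearn.multiclass import OneVsOneClassifier

iris = datasets.load_iris()
X, y = iris.data, iris.target
K = np.dot(X, X.T)  # precomputed linear kernel, shape (n_samples, n_samples)

# OneVsOneClassifier._pairwise is True here because the wrapped SVC uses a
# precomputed kernel; before this fix, each CV fold handed the underlying
# SVC a non-square slice of K and fitting failed (issue #7306).
ovo = OneVsOneClassifier(svm.SVC(kernel='precomputed'))
scores = cross_val_score(ovo, K, y)
print(scores)  # same scores as OneVsOneClassifier(svm.SVC(kernel='linear')) on X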
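For reference, a sketch of the _safe_split contract for pairwise estimators as moved into sklearn/utils/metaestimators.py above; the toy data is made up, and the expected shapes follow the np.ix_ indexing inside the helper:

import numpy as np
from sklearn.svm import SVC
from sklearn.utils.metaestimators import _safe_split

rng = np.random.RandomState(0)
X = rng.rand(6, 2)
y = np.array([0, 0, 1, 1, 2, 2])
K = np.dot(X, X.T)  # square Gram matrix over all 6 samples

train = np.array([0, 1, 2, 3])
test = np.array([4, 5])
clf = SVC(kernel='precomputed')  # clf._pairwise is True

# Training fold: rows and columns both restricted to the training indices.
K_train, y_train = _safe_split(clf, K, y, train)
# Evaluation fold: rows follow the test indices, columns stay on train,
# because test samples must be expressed against the training samples.
K_test, y_test = _safe_split(clf, K, y, test, train)

assert K_train.shape == (4, 4)
assert K_test.shape == (2, 4)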
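Finally, a small numeric illustration (values chosen for the example) of the relocated _ovr_decision_function: the raw pairwise votes tie three ways here, and the summed pairwise confidences, rescaled into (-0.5, 0.5), break the tie without ever overturning a class that leads by a full vote:

import numpy as np
from sklearn.utils.multiclass import _ovr_decision_function

# One sample, 3 classes; column k covers the pairs (0, 1), (0, 2), (1, 2).
predictions = np.array([[0, 1, 0]])          # 0 votes for class i, 1 for class j
confidences = np.array([[-0.5, 0.1, -2.0]])  # signed margin of each pairwise clf

dec = _ovr_decision_function(predictions, confidences, 3)
print(dec)            # approx. [[1.105, 1.395, 0.5]]; votes alone were [1, 1, 1]
print(dec.argmax(1))  # [1]: the three-way vote tie is broken toward class 1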