diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index 51d5e9011939b47444edfd7d4eedbb6be2f585e7..1f32a2992fcc971357730ba60f5552415de17a0f 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -56,6 +56,11 @@ class BaseCrossValidator(with_metaclass(ABCMeta)): Implementations must define `_iter_test_masks` or `_iter_test_indices`. """ + def __init__(self): + # We need this for the build_repr to work properly in py2.7 + # see #6304 + pass + def split(self, X, y=None, labels=None): """Generate indices to split data into training and test set. diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py index ea72546f924a58738110f9d70425c0e4cdbde6dd..907ea5815583052b4c830df3c3ea2425525cbc1c 100644 --- a/sklearn/model_selection/tests/test_split.py +++ b/sklearn/model_selection/tests/test_split.py @@ -129,7 +129,7 @@ class MockClassifier(object): @ignore_warnings -def test_cross_validator_with_default_indices(): +def test_cross_validator_with_default_params(): n_samples = 4 n_unique_labels = 4 n_folds = 2 @@ -149,10 +149,23 @@ def test_cross_validator_with_default_indices(): ss = ShuffleSplit(random_state=0) ps = PredefinedSplit([1, 1, 2, 2]) # n_splits = np of unique folds = 2 + loo_repr = "LeaveOneOut()" + lpo_repr = "LeavePOut(p=2)" + kf_repr = "KFold(n_folds=2, random_state=None, shuffle=False)" + skf_repr = "StratifiedKFold(n_folds=2, random_state=None, shuffle=False)" + lolo_repr = "LeaveOneLabelOut()" + lopo_repr = "LeavePLabelOut(n_labels=2)" + ss_repr = ("ShuffleSplit(n_iter=10, random_state=0, test_size=0.1, " + "train_size=None)") + ps_repr = "PredefinedSplit(test_fold=array([1, 1, 2, 2]))" + n_splits = [n_samples, comb(n_samples, p), n_folds, n_folds, n_unique_labels, comb(n_unique_labels, p), n_iter, 2] - for i, cv in enumerate([loo, lpo, kf, skf, lolo, lopo, ss, ps]): + for i, (cv, cv_repr) in enumerate(zip( + [loo, lpo, kf, skf, lolo, lopo, ss, ps], + [loo_repr, lpo_repr, kf_repr, skf_repr, lolo_repr, lopo_repr, + ss_repr, ps_repr])): # Test if get_n_splits works correctly assert_equal(n_splits[i], cv.get_n_splits(X, y, labels)) @@ -165,6 +178,9 @@ def test_cross_validator_with_default_indices(): assert_equal(np.asarray(train).dtype.kind, 'i') assert_equal(np.asarray(train).dtype.kind, 'i') + # Test if the repr works without any errors + assert_equal(cv_repr, repr(cv)) + def check_valid_split(train, test, n_samples=None): # Use python sets to get more informative assertion failure messages