From af1796ef68d5193f0ca8573a7c2e71c97c97e9ff Mon Sep 17 00:00:00 2001
From: Neeraj Gangwar <y.neeraj2008@gmail.com>
Date: Sun, 5 Mar 2017 02:09:04 +0530
Subject: [PATCH] [MRG+1] Repeated K-Fold and Repeated Stratified K-Fold
 (#8120)

* Add _RepeatedSplits and RepeatedKFold class

* Add RepeatedStratifiedKFold and doc for repeated cvs

* Change default value of n_repeats

* Change input parameters of repeated cv constructor to n_splits, n_repeats,
  random_state

* Generate random states in split function rather than store it beforehand

* Doc changes, inheriting RepeatedKFold, RepeatedStratifiedKFold from
  _RepeatedSplits and other review changes

* Remove blank line, put testcases for deterministic split in loop and add
  StopIteration check in testcase

* Using rng directly as random_state param to create cv instance and added a
  check for cvargs

* Fix pep8 warnings

* Changing default values for n_splits and n_repeats and add entry in
  changelog

* Adding name to the feature

* Missing space
---
 doc/modules/classes.rst                     |   2 +
 doc/modules/cross_validation.rst            |  31 ++++
 doc/whats_new.rst                           |   6 +
 sklearn/model_selection/__init__.py         |   4 +
 sklearn/model_selection/_split.py           | 171 ++++++++++++++++++++
 sklearn/model_selection/tests/test_split.py |  72 +++++++++
 6 files changed, 286 insertions(+)

diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 3aee8f258b..3101488fd6 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -170,6 +170,8 @@ Splitter Classes
    model_selection.LeavePGroupsOut
    model_selection.LeaveOneOut
    model_selection.LeavePOut
+   model_selection.RepeatedKFold
+   model_selection.RepeatedStratifiedKFold
    model_selection.ShuffleSplit
    model_selection.GroupShuffleSplit
    model_selection.StratifiedShuffleSplit
diff --git a/doc/modules/cross_validation.rst b/doc/modules/cross_validation.rst
index 3f7cf95e59..4b9a36e979 100644
--- a/doc/modules/cross_validation.rst
+++ b/doc/modules/cross_validation.rst
@@ -263,6 +263,33 @@ Thus, one can create the training/test sets using numpy indexing::
 
   >>> X_train, X_test, y_train, y_test = X[train], X[test], y[train], y[test]
 
+Repeated K-Fold
+---------------
+
+:class:`RepeatedKFold` repeats K-Fold n times. It can be used when one
+needs to run :class:`KFold` n times, producing different splits in
+each repetition.
+
+Example of 2-fold K-Fold repeated 2 times::
+
+  >>> import numpy as np
+  >>> from sklearn.model_selection import RepeatedKFold
+  >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
+  >>> random_state = 12883823
+  >>> rkf = RepeatedKFold(n_splits=2, n_repeats=2, random_state=random_state)
+  >>> for train, test in rkf.split(X):
+  ...     print("%s %s" % (train, test))
+  ...
+  [2 3] [0 1]
+  [0 1] [2 3]
+  [0 2] [1 3]
+  [1 3] [0 2]
+
+
+Similarly, :class:`RepeatedStratifiedKFold` repeats Stratified K-Fold n times
+with different randomization in each repetition.
+
+
 Leave One Out (LOO)
 -------------------
 
@@ -409,6 +436,10 @@ two slightly unbalanced classes::
    [0 1 3 4 5 8 9] [2 6 7]
    [0 1 2 4 5 6 7] [3 8 9]
 
+:class:`RepeatedStratifiedKFold` can be used to repeat Stratified K-Fold n times
+with different randomization in each repetition.
+
+
 Stratified Shuffle Split
 ------------------------
 
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 7d2fa8a562..a5a7b369bf 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -41,6 +41,10 @@ New features
      Kullback-Leibler divergence and the Itakura-Saito divergence.
      By `Tom Dupre la Tour`_.
 
+   - Added the :class:`sklearn.model_selection.RepeatedKFold` and
+     :class:`sklearn.model_selection.RepeatedStratifiedKFold`.
+     :issue:`8120` by `Neeraj Gangwar`_.
+
    - Added :func:`metrics.mean_squared_log_error`, which computes the mean
      square error of the logarithmic transformation of targets,
      particularly useful for targets with an exponential trend.
@@ -5004,3 +5008,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
 .. _Vincent Pham: https://github.com/vincentpham1991
 
 .. _Denis Engemann: http://denis-engemann.de
+
+.. _Neeraj Gangwar: http://neerajgangwar.in
diff --git a/sklearn/model_selection/__init__.py b/sklearn/model_selection/__init__.py
index f5ab0d7526..73c842e706 100644
--- a/sklearn/model_selection/__init__.py
+++ b/sklearn/model_selection/__init__.py
@@ -7,6 +7,8 @@ from ._split import LeaveOneGroupOut
 from ._split import LeaveOneOut
 from ._split import LeavePGroupsOut
 from ._split import LeavePOut
+from ._split import RepeatedKFold
+from ._split import RepeatedStratifiedKFold
 from ._split import ShuffleSplit
 from ._split import GroupShuffleSplit
 from ._split import StratifiedShuffleSplit
@@ -36,6 +38,8 @@ __all__ = ('BaseCrossValidator',
            'LeaveOneOut',
            'LeavePGroupsOut',
            'LeavePOut',
+           'RepeatedKFold',
+           'RepeatedStratifiedKFold',
            'ParameterGrid',
            'ParameterSampler',
            'PredefinedSplit',
diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py
index b2ed060e31..992c4f6d81 100644
--- a/sklearn/model_selection/_split.py
+++ b/sklearn/model_selection/_split.py
@@ -41,6 +41,8 @@ __all__ = ['BaseCrossValidator',
            'LeaveOneOut',
            'LeavePGroupsOut',
            'LeavePOut',
+           'RepeatedStratifiedKFold',
+           'RepeatedKFold',
            'ShuffleSplit',
            'GroupShuffleSplit',
            'StratifiedKFold',
@@ -397,6 +399,8 @@ class KFold(_BaseKFold):
         classification tasks).
 
     GroupKFold: K-fold iterator variant with non-overlapping groups.
+
+    RepeatedKFold: Repeats K-Fold n times.
     """
 
     def __init__(self, n_splits=3, shuffle=False,
@@ -553,6 +557,9 @@ class StratifiedKFold(_BaseKFold):
     All the folds have size ``trunc(n_samples / n_splits)``, the last one has
     the complementary.
 
+    See also
+    --------
+    RepeatedStratifiedKFold: Repeats Stratified K-Fold n times.
     """
 
     def __init__(self, n_splits=3, shuffle=False, random_state=None):
@@ -913,6 +920,170 @@ class LeavePGroupsOut(BaseCrossValidator):
         return int(comb(len(np.unique(groups)), self.n_groups, exact=True))
 
 
+class _RepeatedSplits(with_metaclass(ABCMeta)):
+    """Repeated splits for an arbitrary randomized CV splitter.
+
+    Repeats splits for cross-validators n times with different randomization
+    in each repetition.
+
+    Parameters
+    ----------
+    cv : callable
+        Cross-validator class.
+
+    n_repeats : int, default=10
+        Number of times cross-validator needs to be repeated.
+
+    random_state : None, int or RandomState, default=None
+        Random state to be used to generate random state for each
+        repetition.
+
+    **cvargs : additional params
+        Constructor parameters for cv. Must not contain random_state
+        and shuffle.
+    """
+    def __init__(self, cv, n_repeats=10, random_state=None, **cvargs):
+        if not isinstance(n_repeats, (np.integer, numbers.Integral)):
+            raise ValueError("Number of repetitions must be of Integral type.")
+
+        if n_repeats <= 1:
+            raise ValueError("Number of repetitions must be greater than 1.")
+
+        if any(key in cvargs for key in ('random_state', 'shuffle')):
+            raise ValueError(
+                "cvargs must not contain random_state or shuffle.")
+
+        self.cv = cv
+        self.n_repeats = n_repeats
+        self.random_state = random_state
+        self.cvargs = cvargs
+
+    def split(self, X, y=None, groups=None):
+        """Generates indices to split data into training and test set.
+
+        Parameters
+        ----------
+        X : array-like, shape (n_samples, n_features)
+            Training data, where n_samples is the number of samples
+            and n_features is the number of features.
+
+        y : array-like, of length n_samples
+            The target variable for supervised learning problems.
+
+        groups : array-like, with shape (n_samples,), optional
+            Group labels for the samples used while splitting the dataset into
+            train/test set.
+
+        Returns
+        -------
+        train : ndarray
+            The training set indices for that split.
+
+        test : ndarray
+            The testing set indices for that split.
+        """
+        n_repeats = self.n_repeats
+        rng = check_random_state(self.random_state)
+
+        for idx in range(n_repeats):
+            cv = self.cv(random_state=rng, shuffle=True,
+                         **self.cvargs)
+            for train_index, test_index in cv.split(X, y, groups):
+                yield train_index, test_index
+
+
+class RepeatedKFold(_RepeatedSplits):
+    """Repeated K-Fold cross validator.
+
+    Repeats K-Fold n times with different randomization in each repetition.
+
+    Read more in the :ref:`User Guide <cross_validation>`.
+
+    Parameters
+    ----------
+    n_splits : int, default=5
+        Number of folds. Must be at least 2.
+
+    n_repeats : int, default=10
+        Number of times cross-validator needs to be repeated.
+
+    random_state : None, int or RandomState, default=None
+        Random state to be used to generate random state for each
+        repetition.
+
+    Examples
+    --------
+    >>> from sklearn.model_selection import RepeatedKFold
+    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
+    >>> y = np.array([0, 0, 1, 1])
+    >>> rkf = RepeatedKFold(n_splits=2, n_repeats=2, random_state=2652124)
+    >>> for train_index, test_index in rkf.split(X):
+    ...     print("TRAIN:", train_index, "TEST:", test_index)
+    ...     X_train, X_test = X[train_index], X[test_index]
+    ...     y_train, y_test = y[train_index], y[test_index]
+    ...
+    TRAIN: [0 1] TEST: [2 3]
+    TRAIN: [2 3] TEST: [0 1]
+    TRAIN: [1 2] TEST: [0 3]
+    TRAIN: [0 3] TEST: [1 2]
+
+
+    See also
+    --------
+    RepeatedStratifiedKFold: Repeats Stratified K-Fold n times.
+    """
+    def __init__(self, n_splits=5, n_repeats=10, random_state=None):
+        super(RepeatedKFold, self).__init__(
+            KFold, n_repeats, random_state, n_splits=n_splits)
+
+
+class RepeatedStratifiedKFold(_RepeatedSplits):
+    """Repeated Stratified K-Fold cross validator.
+
+    Repeats Stratified K-Fold n times with different randomization in each
+    repetition.
+
+    Read more in the :ref:`User Guide <cross_validation>`.
+
+    Parameters
+    ----------
+    n_splits : int, default=5
+        Number of folds. Must be at least 2.
+
+    n_repeats : int, default=10
+        Number of times cross-validator needs to be repeated.
+
+    random_state : None, int or RandomState, default=None
+        Random state to be used to generate random state for each
+        repetition.
+
+    Examples
+    --------
+    >>> from sklearn.model_selection import RepeatedStratifiedKFold
+    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
+    >>> y = np.array([0, 0, 1, 1])
+    >>> rskf = RepeatedStratifiedKFold(n_splits=2, n_repeats=2,
+    ...     random_state=36851234)
+    >>> for train_index, test_index in rskf.split(X, y):
+    ...     print("TRAIN:", train_index, "TEST:", test_index)
+    ...     X_train, X_test = X[train_index], X[test_index]
+    ...     y_train, y_test = y[train_index], y[test_index]
+    ...
+    TRAIN: [1 2] TEST: [0 3]
+    TRAIN: [0 3] TEST: [1 2]
+    TRAIN: [1 3] TEST: [0 2]
+    TRAIN: [0 2] TEST: [1 3]
+
+
+    See also
+    --------
+    RepeatedKFold: Repeats K-Fold n times.
+    """
+    def __init__(self, n_splits=5, n_repeats=10, random_state=None):
+        super(RepeatedStratifiedKFold, self).__init__(
+            StratifiedKFold, n_repeats, random_state, n_splits=n_splits)
+
+
 class BaseShuffleSplit(with_metaclass(ABCMeta)):
     """Base class for ShuffleSplit and StratifiedShuffleSplit"""
 
diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py
index 601e9b259c..c997ac9d73 100644
--- a/sklearn/model_selection/tests/test_split.py
+++ b/sklearn/model_selection/tests/test_split.py
@@ -42,6 +42,8 @@ from sklearn.model_selection import PredefinedSplit
 from sklearn.model_selection import check_cv
 from sklearn.model_selection import train_test_split
 from sklearn.model_selection import GridSearchCV
+from sklearn.model_selection import RepeatedKFold
+from sklearn.model_selection import RepeatedStratifiedKFold
 
 from sklearn.linear_model import Ridge
 
@@ -804,6 +806,76 @@ def test_leave_one_p_group_out_error_on_fewer_number_of_groups():
                   LeavePGroupsOut(n_groups=3).split(X, y, groups))
 
 
+def test_repeated_cv_value_errors():
+    # n_repeats is not integer or <= 1
+    for cv in (RepeatedKFold, RepeatedStratifiedKFold):
+        assert_raises(ValueError, cv, n_repeats=1)
+        assert_raises(ValueError, cv, n_repeats=1.5)
+
+
+def test_repeated_kfold_deterministic_split():
+    X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
+    random_state = 258173307
+    rkf = RepeatedKFold(
+        n_splits=2,
+        n_repeats=2,
+        random_state=random_state)
+
+    # split should produce same and deterministic splits on
+    # each call
+    for _ in range(3):
+        splits = rkf.split(X)
+        train, test = next(splits)
+        assert_array_equal(train, [2, 4])
+        assert_array_equal(test, [0, 1, 3])
+
+        train, test = next(splits)
+        assert_array_equal(train, [0, 1, 3])
+        assert_array_equal(test, [2, 4])
+
+        train, test = next(splits)
+        assert_array_equal(train, [0, 1])
+        assert_array_equal(test, [2, 3, 4])
+
+        train, test = next(splits)
+        assert_array_equal(train, [2, 3, 4])
+        assert_array_equal(test, [0, 1])
+
+        assert_raises(StopIteration, next, splits)
+
+
+def test_repeated_stratified_kfold_deterministic_split():
+    X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
+    y = [1, 1, 1, 0, 0]
+    random_state = 1944695409
+    rskf = RepeatedStratifiedKFold(
+        n_splits=2,
+        n_repeats=2,
+        random_state=random_state)
+
+    # split should produce same and deterministic splits on
+    # each call
+    for _ in range(3):
+        splits = rskf.split(X, y)
+        train, test = next(splits)
+        assert_array_equal(train, [1, 4])
+        assert_array_equal(test, [0, 2, 3])
+
+        train, test = next(splits)
+        assert_array_equal(train, [0, 2, 3])
+        assert_array_equal(test, [1, 4])
+
+        train, test = next(splits)
+        assert_array_equal(train, [2, 3])
+        assert_array_equal(test, [0, 1, 4])
+
+        train, test = next(splits)
+        assert_array_equal(train, [0, 1, 4])
+        assert_array_equal(test, [2, 3])
+
+        assert_raises(StopIteration, next, splits)
+
+
 def test_train_test_split_errors():
     assert_raises(ValueError, train_test_split)
     assert_raises(ValueError, train_test_split, range(3), train_size=1.1)
--
GitLab
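
A quick usage sketch, not part of the patch above: once merged, the new splitters
can be passed anywhere scikit-learn accepts a ``cv`` object, for example
``cross_val_score`` (existing API). The iris dataset and logistic regression
estimator below are arbitrary illustrative choices, not taken from the patch.

    # Usage sketch only -- dataset and estimator are placeholder choices.
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import cross_val_score, RepeatedStratifiedKFold

    X, y = load_iris(return_X_y=True)

    # 5-fold stratified CV repeated 10 times with a different shuffle each
    # repetition: yields n_splits * n_repeats = 50 scores, giving a more
    # stable estimate (and a spread) than a single K-Fold run.
    cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=10, random_state=0)
    scores = cross_val_score(LogisticRegression(), X, y, cv=cv)
    print(scores.mean(), scores.std())

Fixing ``random_state`` on the splitter keeps all 50 splits reproducible across
calls, which is what the deterministic-split tests in this patch exercise.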