diff --git a/doc/modules/cross_validation.rst b/doc/modules/cross_validation.rst
index 872844cd0cc2916a6494c380f8e85c71e202d0c8..cb81b2d34017b0d20cdddd6afc3a167ba8560af4 100644
--- a/doc/modules/cross_validation.rst
+++ b/doc/modules/cross_validation.rst
@@ -1,3 +1,4 @@
+
 .. _cross_validation:
 
 ===================================================
@@ -209,6 +210,28 @@ The following sections list utilities to generate indices
 that can be used to generate dataset splits according to different cross
 validation strategies.
 
+.. _iid_cv:
+
+Cross-validation iterators for i.i.d. data
+==========================================
+
+Assuming that some data is Independent and Identically Distributed (i.i.d.)
+amounts to assuming that all samples stem from the same generative process and
+that the generative process has no memory of past generated samples.
+
+The following cross-validators can be used in such cases.
+
+.. note::
+
+   While i.i.d. data is a common assumption in machine learning theory, it
+   rarely holds in practice. If one knows that the samples have been generated
+   using a time-dependent process, it is safer to use a
+   :ref:`time-series aware cross-validation scheme <timeseries_cv>`.
+   Similarly, if we know that the generative process has a group structure
+   (samples collected from different subjects, experiments, measurement
+   devices), it is safer to use :ref:`group-wise cross-validation <group_cv>`.
+
 
 K-fold
 ------
@@ -239,58 +262,7 @@ Thus, one can create the training/test sets using numpy indexing::
   >>> X_train, X_test, y_train, y_test = X[train], X[test], y[train], y[test]
 
 
-Stratified k-fold
------------------
-
-:class:`StratifiedKFold` is a variation of *k-fold* which returns *stratified*
-folds: each set contains approximately the same percentage of samples of each
-target class as the complete set.
-
-Example of stratified 3-fold cross-validation on a dataset with 10 samples from
-two slightly unbalanced classes::
-
-  >>> from sklearn.model_selection import StratifiedKFold
-
-  >>> X = np.ones(10)
-  >>> y = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
-  >>> skf = StratifiedKFold(n_splits=3)
-  >>> for train, test in skf.split(X, y):
-  ...     print("%s %s" % (train, test))
-  [2 3 6 7 8 9] [0 1 4 5]
-  [0 1 3 4 5 8 9] [2 6 7]
-  [0 1 2 4 5 6 7] [3 8 9]
-
-
-Label k-fold
-------------
-
-:class:`LabelKFold` is a variation of *k-fold* which ensures that the same
-label is not in both testing and training sets. This is necessary for example
-if you obtained data from different subjects and you want to avoid over-fitting
-(i.e., learning person specific features) by testing and training on different
-subjects.
-
-Imagine you have three subjects, each with an associated number from 1 to 3::
-
-  >>> from sklearn.model_selection import LabelKFold
-
-  >>> X = [0.1, 0.2, 2.2, 2.4, 2.3, 4.55, 5.8, 8.8, 9, 10]
-  >>> y = ["a", "b", "b", "b", "c", "c", "c", "d", "d", "d"]
-  >>> labels = [1, 1, 1, 2, 2, 2, 3, 3, 3, 3]
-
-  >>> lkf = LabelKFold(n_splits=3)
-  >>> for train, test in lkf.split(X, y, labels):
-  ...     print("%s %s" % (train, test))
-  [0 1 2 3 4 5] [6 7 8 9]
-  [0 1 2 6 7 8 9] [3 4 5]
-  [3 4 5 6 7 8 9] [0 1 2]
-
-Each subject is in a different testing fold, and the same subject is never in
-both testing and training. Notice that the folds do not have exactly the same
-size due to the imbalance in the data.
-
-
-Leave-One-Out - LOO
+Leave One Out (LOO)
 -------------------
 
 :class:`LeaveOneOut` (or LOO) is a simple cross-validation. Each learning
@@ -348,7 +320,7 @@ fold cross validation should be preferred to LOO.
    Statistical Learning <http://www-bcf.usc.edu/~gareth/ISL>`_, Springer 2013.
 
 
-Leave-P-Out - LPO
+Leave P Out (LPO)
 -----------------
 
 :class:`LeavePOut` is very similar to :class:`LeaveOneOut` as it creates all
@@ -373,68 +345,6 @@ Example of Leave-2-Out on a dataset with 4 samples::
   [0 1] [2 3]
 
 
-Leave-One-Label-Out - LOLO
---------------------------
-
-:class:`LeaveOneLabelOut` (LOLO) is a cross-validation scheme which holds out
-the samples according to a third-party provided array of integer labels. This
-label information can be used to encode arbitrary domain specific pre-defined
-cross-validation folds.
-
-Each training set is thus constituted by all the samples except the ones
-related to a specific label.
-
-For example, in the cases of multiple experiments, LOLO can be used to
-create a cross-validation based on the different experiments: we create
-a training set using the samples of all the experiments except one::
-
-  >>> from sklearn.model_selection import LeaveOneLabelOut
-
-  >>> X = [1, 5, 10, 50]
-  >>> y = [0, 1, 1, 2]
-  >>> labels = [1, 1, 2, 2]
-  >>> lolo = LeaveOneLabelOut()
-  >>> for train, test in lolo.split(X, y, labels):
-  ...     print("%s %s" % (train, test))
-  [2 3] [0 1]
-  [0 1] [2 3]
-
-Another common application is to use time information: for instance the
-labels could be the year of collection of the samples and thus allow
-for cross-validation against time-based splits.
-
-.. warning::
-
-  Contrary to :class:`StratifiedKFold`,
-  the ``labels`` of :class:`LeaveOneLabelOut` should not encode
-  the target class to predict: the goal of :class:`StratifiedKFold`
-  is to rebalance dataset classes across
-  the train / test split to ensure that the train and test folds have
-  approximately the same percentage of samples of each class while
-  :class:`LeaveOneLabelOut` will do the opposite by ensuring that the samples
-  of the train and test fold will not share the same label value.
-
-
-Leave-P-Label-Out
------------------
-
-:class:`LeavePLabelOut` is similar as *Leave-One-Label-Out*, but removes
-samples related to :math:`P` labels for each training/test set.
-
-Example of Leave-2-Label Out::
-
-  >>> from sklearn.model_selection import LeavePLabelOut
-
-  >>> X = np.arange(6)
-  >>> y = [1, 1, 1, 2, 2, 2]
-  >>> labels = [1, 1, 2, 2, 3, 3]
-  >>> lplo = LeavePLabelOut(n_labels=2)
-  >>> for train, test in lplo.split(X, y, labels):
-  ...     print("%s %s" % (train, test))
-  [4 5] [0 1 2 3]
-  [2 3] [0 1 4 5]
-  [0 1] [2 3 4 5]
-
 .. _ShuffleSplit:
 
 Random permutations cross-validation a.k.a. Shuffle & Split
@@ -467,26 +377,167 @@ Here is a usage example::
 validation that allows a finer control on the number of iterations and
 the proportion of samples on each side of the train / test split.
 
+Cross-validation iterators with stratification based on class labels
+=====================================================================
+
+Some classification problems can exhibit a large imbalance in the distribution
+of the target classes: for instance there could be several times more negative
+samples than positive samples. In such cases it is recommended to use
+stratified sampling as implemented in :class:`StratifiedKFold` and
+:class:`StratifiedShuffleSplit` to ensure that relative class frequencies are
+approximately preserved in each train and validation fold.
 
-Label-Shuffle-Split
+Stratified k-fold
+-----------------
+
+:class:`StratifiedKFold` is a variation of *k-fold* which returns *stratified*
+folds: each set contains approximately the same percentage of samples of each
+target class as the complete set.
+
+Example of stratified 3-fold cross-validation on a dataset with 10 samples from
+two slightly unbalanced classes::
+
+  >>> from sklearn.model_selection import StratifiedKFold
+
+  >>> X = np.ones(10)
+  >>> y = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
+  >>> skf = StratifiedKFold(n_splits=3)
+  >>> for train, test in skf.split(X, y):
+  ...     print("%s %s" % (train, test))
+  [2 3 6 7 8 9] [0 1 4 5]
+  [0 1 3 4 5 8 9] [2 6 7]
+  [0 1 2 4 5 6 7] [3 8 9]
+
+Stratified Shuffle Split
+------------------------
+
+:class:`StratifiedShuffleSplit` is a variation of *ShuffleSplit*, which returns
+stratified splits, *i.e.*, which creates splits by preserving the same
+percentage for each target class as in the complete set.
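+
+For instance, a minimal usage sketch with illustrative parameter values (the
+exact indices of each split depend on ``random_state`` and are therefore not
+shown here)::
+
+  >>> from sklearn.model_selection import StratifiedShuffleSplit
+
+  >>> X = np.ones(10)
+  >>> y = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
+  >>> sss = StratifiedShuffleSplit(n_splits=3, test_size=0.5, random_state=0)
+  >>> sss.get_n_splits(X, y)
+  3
+  >>> for train, test in sss.split(X, y):  # doctest: +SKIP
+  ...     print("%s %s" % (train, test))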
+
+.. _group_cv:
+
+Cross-validation iterators for grouped data
+============================================
+
+The i.i.d. assumption is broken if the underlying generative process yields
+groups of dependent samples.
+
+Such a grouping of data is domain specific. An example would be medical data
+collected from multiple patients, with multiple samples taken from each
+patient. Such data is likely to be dependent on the individual group. In this
+example, the patient id for each sample would be its group identifier.
+
+In this case we would like to know if a model trained on a particular set of
+groups generalizes well to the unseen groups. To measure this, we need to
+ensure that all the samples in the validation fold come from groups that are
+not represented at all in the paired training fold.
+
+The following cross-validation splitters can be used to do that.
+The grouping identifier for the samples is specified via the ``groups``
+parameter.
+
+
+Group k-fold
+------------
+
+:class:`GroupKFold` is a variation of *k-fold* which ensures that the same
+group is not represented in both testing and training sets. For example, if
+the data is obtained from different subjects with several samples per subject,
+and if the model is flexible enough to learn from highly person specific
+features, it could fail to generalize to new subjects. :class:`GroupKFold`
+makes it possible to detect this kind of overfitting situation.
+
+Imagine you have three subjects, each with an associated number from 1 to 3::
+
+  >>> from sklearn.model_selection import GroupKFold
+
+  >>> X = [0.1, 0.2, 2.2, 2.4, 2.3, 4.55, 5.8, 8.8, 9, 10]
+  >>> y = ["a", "b", "b", "b", "c", "c", "c", "d", "d", "d"]
+  >>> groups = [1, 1, 1, 2, 2, 2, 3, 3, 3, 3]
+
+  >>> gkf = GroupKFold(n_splits=3)
+  >>> for train, test in gkf.split(X, y, groups=groups):
+  ...     print("%s %s" % (train, test))
+  [0 1 2 3 4 5] [6 7 8 9]
+  [0 1 2 6 7 8 9] [3 4 5]
+  [3 4 5 6 7 8 9] [0 1 2]
+
+Each subject is in a different testing fold, and the same subject is never in
+both testing and training. Notice that the folds do not have exactly the same
+size due to the imbalance in the data.
+
+
+Leave One Group Out
 -------------------
 
-:class:`LabelShuffleSplit`
+:class:`LeaveOneGroupOut` is a cross-validation scheme which holds out
+the samples according to a third-party provided array of integer groups. This
+group information can be used to encode arbitrary domain specific pre-defined
+cross-validation folds.
+
+Each training set is thus constituted by all the samples except the ones
+related to a specific group.
+
+For example, in the cases of multiple experiments, :class:`LeaveOneGroupOut`
+can be used to create a cross-validation based on the different experiments:
+we create a training set using the samples of all the experiments except one::
+
+  >>> from sklearn.model_selection import LeaveOneGroupOut
+
+  >>> X = [1, 5, 10, 50, 60, 70, 80]
+  >>> y = [0, 1, 1, 2, 2, 2, 2]
+  >>> groups = [1, 1, 2, 2, 3, 3, 3]
+  >>> logo = LeaveOneGroupOut()
+  >>> for train, test in logo.split(X, y, groups=groups):
+  ...     print("%s %s" % (train, test))
+  [2 3 4 5 6] [0 1]
+  [0 1 4 5 6] [2 3]
+  [0 1 2 3] [4 5 6]
+
+Another common application is to use time information: for instance the
+groups could be the year of collection of the samples and thus allow
+for cross-validation against time-based splits.
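+
+For instance, a minimal sketch reusing ``X`` and ``y`` from above, with
+hypothetical collection years as the groups::
+
+  >>> years = [2014, 2014, 2015, 2015, 2016, 2016, 2016]
+  >>> logo.get_n_splits(X, y, groups=years)
+  3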
+
+Leave P Groups Out
+------------------
 
-The :class:`LabelShuffleSplit` iterator behaves as a combination of
-:class:`ShuffleSplit` and :class:`LeavePLabelsOut`, and generates a
-sequence of randomized partitions in which a subset of labels are held
+:class:`LeavePGroupsOut` is similar as :class:`LeaveOneGroupOut`, but removes
+samples related to :math:`P` groups for each training/test set.
+
+Example of Leave-2-Group Out::
+
+  >>> from sklearn.model_selection import LeavePGroupsOut
+
+  >>> X = np.arange(6)
+  >>> y = [1, 1, 1, 2, 2, 2]
+  >>> groups = [1, 1, 2, 2, 3, 3]
+  >>> lpgo = LeavePGroupsOut(n_groups=2)
+  >>> for train, test in lpgo.split(X, y, groups=groups):
+  ...     print("%s %s" % (train, test))
+  [4 5] [0 1 2 3]
+  [2 3] [0 1 4 5]
+  [0 1] [2 3 4 5]
+
+Group Shuffle Split
+-------------------
+
+:class:`GroupShuffleSplit`
+
+The :class:`GroupShuffleSplit` iterator behaves as a combination of
+:class:`ShuffleSplit` and :class:`LeavePGroupsOut`, and generates a
+sequence of randomized partitions in which a subset of groups are held
 out for each split.
 
 Here is a usage example::
 
-  >>> from sklearn.model_selection import LabelShuffleSplit
+  >>> from sklearn.model_selection import GroupShuffleSplit
 
   >>> X = [0.1, 0.2, 2.2, 2.4, 2.3, 4.55, 5.8, 0.001]
   >>> y = ["a", "b", "b", "b", "c", "c", "c", "a"]
-  >>> labels = [1, 1, 2, 2, 3, 3, 4, 4]
-  >>> lss = LabelShuffleSplit(n_splits=4, test_size=0.5, random_state=0)
-  >>> for train, test in lss.split(X, y, labels):
+  >>> groups = [1, 1, 2, 2, 3, 3, 4, 4]
+  >>> gss = GroupShuffleSplit(n_splits=4, test_size=0.5, random_state=0)
+  >>> for train, test in gss.split(X, y, groups=groups):
   ...     print("%s %s" % (train, test))
   ...
   [0 1 2 3] [4 5 6 7]
@@ -494,17 +545,16 @@ Here is a usage example::
   [2 3 4 5] [0 1 6 7]
   [4 5 6 7] [0 1 2 3]
 
-This class is useful when the behavior of :class:`LeavePLabelsOut` is
-desired, but the number of labels is large enough that generating all
-possible partitions with :math:`P` labels withheld would be prohibitively
-expensive.  In such a scenario, :class:`LabelShuffleSplit` provides
+This class is useful when the behavior of :class:`LeavePGroupsOut` is
+desired, but the number of groups is large enough that generating all
+possible partitions with :math:`P` groups withheld would be prohibitively
+expensive.  In such a scenario, :class:`GroupShuffleSplit` provides
 a random sample (with replacement) of the train / test splits
-generated by :class:`LeavePLabelsOut`.
-
+generated by :class:`LeavePGroupsOut`.
 
 
 Predefined Fold-Splits / Validation-Sets
-----------------------------------------
+========================================
 
 For some datasets, a pre-defined split of the data into training- and
 validation fold or into several cross-validation folds already
@@ -514,12 +564,7 @@ e.g. when searching for hyperparameters.
 For example, when using a validation set, set the ``test_fold`` to 0 for all
 samples that are part of the validation set, and to -1 for all other samples.
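+
+For instance, a minimal sketch with four samples where the first two form the
+predefined validation set::
+
+  >>> from sklearn.model_selection import PredefinedSplit
+  >>> test_fold = [0, 0, -1, -1]
+  >>> ps = PredefinedSplit(test_fold)
+  >>> for train, test in ps.split():
+  ...     print("%s %s" % (train, test))
+  [2 3] [0 1]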
 
-
-See also
---------
-:class:`StratifiedShuffleSplit` is a variation of *ShuffleSplit*, which returns
-stratified splits, *i.e* which creates splits by preserving the same
-percentage for each target class as in the complete set.
+.. _timeseries_cv:
 
 Cross validation of time series data
 ====================================
@@ -536,8 +581,8 @@ least like those that are used to train the model. To achieve this, one
 solution is provided by :class:`TimeSeriesSplit`.
 
 
-TimeSeriesSplit
------------------------
+Time Series Split
+-----------------
 
 :class:`TimeSeriesSplit` is a variation of *k-fold* which 
 returns first :math:`k` folds as train set and the :math:`(k+1)` th 
@@ -568,8 +613,8 @@ Example of 3-split time series cross-validation on a dataset with 6 samples::
 A note on shuffling
 ===================
 
-If the data ordering is not arbitrary (e.g. samples with the same label are
-contiguous), shuffling it first may be essential to get a meaningful cross-
+If the data ordering is not arbitrary (e.g. samples with the same class label
+are contiguous), shuffling it first may be essential to get a meaningful cross-
 validation result. However, the opposite may be true if the samples are not
 independently and identically distributed. For example, if samples correspond
 to news articles, and are ordered by their time of publication, then shuffling
diff --git a/doc/tutorial/statistical_inference/model_selection.rst b/doc/tutorial/statistical_inference/model_selection.rst
index ef3568ffb0bcd4aa13a0bf4436c17df3e20c133a..6158846b27fc9e5db2695650c078f12426528ca5 100644
--- a/doc/tutorial/statistical_inference/model_selection.rst
+++ b/doc/tutorial/statistical_inference/model_selection.rst
@@ -110,7 +110,7 @@ scoring method.
 
     - :class:`StratifiedKFold` **(n_iter, test_size, train_size, random_state)**
 
-    - :class:`LabelKFold` **(n_splits, shuffle, random_state)**
+    - :class:`GroupKFold` **(n_splits, shuffle, random_state)**
 
 
    *
@@ -119,7 +119,7 @@ scoring method.
 
     - Same as K-Fold but preserves the class distribution within each fold.
 
-    - Ensures that the same label is not in both testing and training sets.
+    - Ensures that the same group is not in both testing and training sets.
 
 
 .. list-table::
@@ -130,7 +130,7 @@ scoring method.
 
     - :class:`StratifiedShuffleSplit`
 
-    - :class:`LabelShuffleSplit`
+    - :class:`GroupShuffleSplit`
 
    *
 
@@ -138,16 +138,16 @@ scoring method.
 
     - Same as shuffle split but preserves the class distribution within each iteration.
 
-    - Ensures that the same label is not in both testing and training sets.
+    - Ensures that the same group is not in both testing and training sets.
 
 
 .. list-table::
 
    *
 
-    - :class:`LeaveOneLabelOut` **()**
+    - :class:`LeaveOneGroupOut` **()**
 
-    - :class:`LeavePLabelOut`  **(p)**
+    - :class:`LeavePGroupsOut`  **(p)**
 
     - :class:`LeaveOneOut` **()**
 
@@ -155,9 +155,9 @@ scoring method.
 
    *
 
-    - Takes a label array to group observations.
+    - Takes a group array to group observations.
 
-    - Leave P labels out.
+    - Leave P groups out.
 
     - Leave one observation out.
 
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 3ae2a27b6e7f2cfa7b691477292ce26dce606a94..9e4e9bfebde862e9f900b186fbe12b8e9f3f8993 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -64,16 +64,41 @@ Model Selection Enhancements and API Changes
   - **Parameters ``n_folds`` and ``n_iter`` renamed to ``n_splits``**
 
     Some parameter names have changed:
-    The ``n_folds`` parameter in :class:`model_selection.KFold`,
-    :class:`model_selection.LabelKFold`, and
-    :class:`model_selection.StratifiedKFold` is now renamed to ``n_splits``.
-    The ``n_iter`` parameter in :class:`model_selection.ShuffleSplit`,
-    :class:`model_selection.LabelShuffleSplit`,
-    and :class:`model_selection.StratifiedShuffleSplit` is now renamed
-    to ``n_splits``.
+    The ``n_folds`` parameter in the new :class:`model_selection.KFold`,
+    :class:`model_selection.GroupKFold` (see below for the name change),
+    and :class:`model_selection.StratifiedKFold` is now renamed to
+    ``n_splits``. The ``n_iter`` parameter in
+    :class:`model_selection.ShuffleSplit`, the new class
+    :class:`model_selection.GroupShuffleSplit` and
+    :class:`model_selection.StratifiedShuffleSplit` is now renamed to
+    ``n_splits``.
+
+  - **Rename of splitter classes which accept group labels along with data**
+
+    The cross-validation splitters ``LabelKFold``,
+    ``LabelShuffleSplit``, ``LeaveOneLabelOut`` and ``LeavePLabelOut`` have
+    been renamed to :class:`model_selection.GroupKFold`,
+    :class:`model_selection.GroupShuffleSplit`,
+    :class:`model_selection.LeaveOneGroupOut` and
+    :class:`model_selection.LeavePGroupsOut` respectively.
+
+    Note the change from singular to plural form in
+    :class:`model_selection.LeavePGroupsOut`.
+
+  - **Fit parameter ``labels`` renamed to ``groups``**
+
+    The ``labels`` parameter in the :func:`split` method of the newly renamed
+    splitters :class:`model_selection.GroupKFold`,
+    :class:`model_selection.LeaveOneGroupOut`,
+    :class:`model_selection.LeavePGroupsOut` and
+    :class:`model_selection.GroupShuffleSplit` is renamed to ``groups``
+    following the new nomenclature of their class names.
+
+  - **Parameter ``n_labels`` renamed to ``n_groups``**
+
+    The parameter ``n_labels`` in the newly renamed
+    :class:`model_selection.LeavePGroupsOut` is changed to ``n_groups``.
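+
+    A minimal usage sketch with the new names and parameters (the toy data
+    below is purely illustrative)::
+
+        >>> import numpy as np
+        >>> from sklearn.model_selection import LeavePGroupsOut
+        >>> X, y = np.ones((4, 1)), np.arange(4)
+        >>> groups = [1, 1, 2, 2]
+        >>> lpgo = LeavePGroupsOut(n_groups=1)  # formerly n_labels
+        >>> splits = list(lpgo.split(X, y, groups=groups))  # formerly labels
+        >>> len(splits)
+        2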
 
-Changelog
----------
 
 New features
 ............
@@ -464,6 +489,20 @@ API changes summary
       :func:`metrics.classification.hamming_loss`.
       (`#7260 <https://github.com/scikit-learn/scikit-learn/pull/7260>`_) by
       `Sebastián Vanrell`_.
+
+    - The splitter classes ``LabelKFold``, ``LabelShuffleSplit``,
+      ``LeaveOneLabelOut`` and ``LeavePLabelOut`` are renamed to
+      :class:`model_selection.GroupKFold`,
+      :class:`model_selection.GroupShuffleSplit`,
+      :class:`model_selection.LeaveOneGroupOut`
+      and :class:`model_selection.LeavePGroupsOut` respectively.
+      Also the parameter ``labels`` in the :func:`split` method of the newly
+      renamed splitters :class:`model_selection.LeaveOneGroupOut` and
+      :class:`model_selection.LeavePGroupsOut` is renamed to
+      ``groups``. Additionally in :class:`model_selection.LeavePGroupsOut`,
+      the parameter ``n_labels`` is renamed to ``n_groups``.
+      (`#6660 <https://github.com/scikit-learn/scikit-learn/pull/6660>`_)
+      by `Raghav RV`_.
 
 
 .. currentmodule:: sklearn
diff --git a/sklearn/model_selection/__init__.py b/sklearn/model_selection/__init__.py
index c2acf500aff52265dbcdeac3e7f0030e709ddedf..f5ab0d7526ccf503456130c0462b36dfc960f05c 100644
--- a/sklearn/model_selection/__init__.py
+++ b/sklearn/model_selection/__init__.py
@@ -1,14 +1,14 @@
 from ._split import BaseCrossValidator
 from ._split import KFold
-from ._split import LabelKFold
+from ._split import GroupKFold
 from ._split import StratifiedKFold
 from ._split import TimeSeriesSplit
-from ._split import LeaveOneLabelOut
+from ._split import LeaveOneGroupOut
 from ._split import LeaveOneOut
-from ._split import LeavePLabelOut
+from ._split import LeavePGroupsOut
 from ._split import LeavePOut
 from ._split import ShuffleSplit
-from ._split import LabelShuffleSplit
+from ._split import GroupShuffleSplit
 from ._split import StratifiedShuffleSplit
 from ._split import PredefinedSplit
 from ._split import train_test_split
@@ -30,11 +30,11 @@ __all__ = ('BaseCrossValidator',
            'GridSearchCV',
            'TimeSeriesSplit',
            'KFold',
-           'LabelKFold',
-           'LabelShuffleSplit',
-           'LeaveOneLabelOut',
+           'GroupKFold',
+           'GroupShuffleSplit',
+           'LeaveOneGroupOut',
            'LeaveOneOut',
-           'LeavePLabelOut',
+           'LeavePGroupsOut',
            'LeavePOut',
            'ParameterGrid',
            'ParameterSampler',
diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index c96ee7f19d704aa72421add24370a9cf436d50a1..f1880555df33904e3321fa1a4db6e0c74fa2512d 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -528,15 +528,15 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         self._check_is_fitted('inverse_transform')
         return self.best_estimator_.transform(Xt)
 
-    def _fit(self, X, y, labels, parameter_iterable):
+    def _fit(self, X, y, groups, parameter_iterable):
         """Actual fitting,  performing the search over parameters."""
 
         estimator = self.estimator
         cv = check_cv(self.cv, y, classifier=is_classifier(estimator))
         self.scorer_ = check_scoring(self.estimator, scoring=self.scoring)
 
-        X, y, labels = indexable(X, y, labels)
-        n_splits = cv.get_n_splits(X, y, labels)
+        X, y, groups = indexable(X, y, groups)
+        n_splits = cv.get_n_splits(X, y, groups)
         if self.verbose > 0 and isinstance(parameter_iterable, Sized):
             n_candidates = len(parameter_iterable)
             print("Fitting {0} folds for each of {1} candidates, totalling"
@@ -554,7 +554,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
                                   self.fit_params, return_parameters=True,
                                   error_score=self.error_score)
           for parameters in parameter_iterable
-          for train, test in cv.split(X, y, labels))
+          for train, test in cv.split(X, y, groups))
 
         test_scores, test_sample_counts, _, parameters = zip(*out)
 
@@ -876,7 +876,7 @@ class GridSearchCV(BaseSearchCV):
         self.param_grid = param_grid
         _check_param_grid(param_grid)
 
-    def fit(self, X, y=None, labels=None):
+    def fit(self, X, y=None, groups=None):
         """Run fit with all sets of parameters.
 
         Parameters
@@ -890,11 +890,11 @@ class GridSearchCV(BaseSearchCV):
             Target relative to X for classification or regression;
             None for unsupervised learning.
 
-        labels : array-like, with shape (n_samples,), optional
+        groups : array-like, with shape (n_samples,), optional
             Group labels for the samples used while splitting the dataset into
             train/test set.
         """
-        return self._fit(X, y, labels, ParameterGrid(self.param_grid))
+        return self._fit(X, y, groups, ParameterGrid(self.param_grid))
 
 
 class RandomizedSearchCV(BaseSearchCV):
@@ -1104,7 +1104,7 @@ class RandomizedSearchCV(BaseSearchCV):
             n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
             pre_dispatch=pre_dispatch, error_score=error_score)
 
-    def fit(self, X, y=None, labels=None):
+    def fit(self, X, y=None, groups=None):
         """Run fit on the estimator with randomly drawn parameters.
 
         Parameters
@@ -1117,11 +1117,11 @@ class RandomizedSearchCV(BaseSearchCV):
             Target relative to X for classification or regression;
             None for unsupervised learning.
 
-        labels : array-like, with shape (n_samples,), optional
+        groups : array-like, with shape (n_samples,), optional
             Group labels for the samples used while splitting the dataset into
             train/test set.
         """
         sampled_params = ParameterSampler(self.param_distributions,
                                           self.n_iter,
                                           random_state=self.random_state)
-        return self._fit(X, y, labels, sampled_params)
+        return self._fit(X, y, groups, sampled_params)
diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py
index a4762652854374c5f59a97b68be0d70c7b67cd4b..3570d688ff70f4a3a81bbc1bfa39473fe1696b58 100644
--- a/sklearn/model_selection/_split.py
+++ b/sklearn/model_selection/_split.py
@@ -36,13 +36,13 @@ from ..gaussian_process.kernels import Kernel as GPKernel
 
 __all__ = ['BaseCrossValidator',
            'KFold',
-           'LabelKFold',
-           'LeaveOneLabelOut',
+           'GroupKFold',
+           'LeaveOneGroupOut',
            'LeaveOneOut',
-           'LeavePLabelOut',
+           'LeavePGroupsOut',
            'LeavePOut',
            'ShuffleSplit',
-           'LabelShuffleSplit',
+           'GroupShuffleSplit',
            'StratifiedKFold',
            'StratifiedShuffleSplit',
            'PredefinedSplit',
@@ -61,7 +61,7 @@ class BaseCrossValidator(with_metaclass(ABCMeta)):
         # see #6304
         pass
 
-    def split(self, X, y=None, labels=None):
+    def split(self, X, y=None, groups=None):
         """Generate indices to split data into training and test set.
 
         Parameters
@@ -73,7 +73,7 @@ class BaseCrossValidator(with_metaclass(ABCMeta)):
         y : array-like, of length n_samples
             The target variable for supervised learning problems.
 
-        labels : array-like, with shape (n_samples,), optional
+        groups : array-like, with shape (n_samples,), optional
             Group labels for the samples used while splitting the dataset into
             train/test set.
 
@@ -85,31 +85,31 @@ class BaseCrossValidator(with_metaclass(ABCMeta)):
         test : ndarray
             The testing set indices for that split.
         """
-        X, y, labels = indexable(X, y, labels)
+        X, y, groups = indexable(X, y, groups)
         indices = np.arange(_num_samples(X))
-        for test_index in self._iter_test_masks(X, y, labels):
+        for test_index in self._iter_test_masks(X, y, groups):
             train_index = indices[np.logical_not(test_index)]
             test_index = indices[test_index]
             yield train_index, test_index
 
     # Since subclasses must implement either _iter_test_masks or
     # _iter_test_indices, neither can be abstract.
-    def _iter_test_masks(self, X=None, y=None, labels=None):
+    def _iter_test_masks(self, X=None, y=None, groups=None):
         """Generates boolean masks corresponding to test sets.
 
-        By default, delegates to _iter_test_indices(X, y, labels)
+        By default, delegates to _iter_test_indices(X, y, groups)
         """
-        for test_index in self._iter_test_indices(X, y, labels):
+        for test_index in self._iter_test_indices(X, y, groups):
             test_mask = np.zeros(_num_samples(X), dtype=np.bool)
             test_mask[test_index] = True
             yield test_mask
 
-    def _iter_test_indices(self, X=None, y=None, labels=None):
+    def _iter_test_indices(self, X=None, y=None, groups=None):
         """Generates integer indices corresponding to test sets."""
         raise NotImplementedError
 
     @abstractmethod
-    def get_n_splits(self, X=None, y=None, labels=None):
+    def get_n_splits(self, X=None, y=None, groups=None):
         """Returns the number of splitting iterations in the cross-validator"""
 
     def __repr__(self):
@@ -155,17 +155,17 @@ class LeaveOneOut(BaseCrossValidator):
 
     See also
     --------
-    LeaveOneLabelOut
+    LeaveOneGroupOut
         For splitting the data according to explicit, domain-specific
         stratification of the dataset.
 
-    LabelKFold: K-fold iterator variant with non-overlapping labels.
+    GroupKFold: K-fold iterator variant with non-overlapping groups.
     """
 
-    def _iter_test_indices(self, X, y=None, labels=None):
+    def _iter_test_indices(self, X, y=None, groups=None):
         return range(_num_samples(X))
 
-    def get_n_splits(self, X, y=None, labels=None):
+    def get_n_splits(self, X, y=None, groups=None):
         """Returns the number of splitting iterations in the cross-validator
 
         Parameters
@@ -177,7 +177,7 @@ class LeaveOneOut(BaseCrossValidator):
         y : object
             Always ignored, exists for compatibility.
 
-        labels : object
+        groups : object
             Always ignored, exists for compatibility.
 
         Returns
@@ -237,11 +237,11 @@ class LeavePOut(BaseCrossValidator):
     def __init__(self, p):
         self.p = p
 
-    def _iter_test_indices(self, X, y=None, labels=None):
+    def _iter_test_indices(self, X, y=None, groups=None):
         for combination in combinations(range(_num_samples(X)), self.p):
             yield np.array(combination)
 
-    def get_n_splits(self, X, y=None, labels=None):
+    def get_n_splits(self, X, y=None, groups=None):
         """Returns the number of splitting iterations in the cross-validator
 
         Parameters
@@ -253,7 +253,7 @@ class LeavePOut(BaseCrossValidator):
         y : object
             Always ignored, exists for compatibility.
 
-        labels : object
+        groups : object
             Always ignored, exists for compatibility.
         """
         if X is None:
@@ -262,7 +262,7 @@ class LeavePOut(BaseCrossValidator):
 
 
 class _BaseKFold(with_metaclass(ABCMeta, BaseCrossValidator)):
-    """Base class for KFold, LabelKFold, and StratifiedKFold"""
+    """Base class for KFold, GroupKFold, and StratifiedKFold"""
 
     @abstractmethod
     def __init__(self, n_splits, shuffle, random_state):
@@ -286,7 +286,7 @@ class _BaseKFold(with_metaclass(ABCMeta, BaseCrossValidator)):
         self.shuffle = shuffle
         self.random_state = random_state
 
-    def split(self, X, y=None, labels=None):
+    def split(self, X, y=None, groups=None):
         """Generate indices to split data into training and test set.
 
         Parameters
@@ -298,7 +298,7 @@ class _BaseKFold(with_metaclass(ABCMeta, BaseCrossValidator)):
         y : array-like, shape (n_samples,)
             The target variable for supervised learning problems.
 
-        labels : array-like, with shape (n_samples,), optional
+        groups : array-like, with shape (n_samples,), optional
             Group labels for the samples used while splitting the dataset into
             train/test set.
 
@@ -310,7 +310,7 @@ class _BaseKFold(with_metaclass(ABCMeta, BaseCrossValidator)):
         test : ndarray
             The testing set indices for that split.
         """
-        X, y, labels = indexable(X, y, labels)
+        X, y, groups = indexable(X, y, groups)
         n_samples = _num_samples(X)
         if self.n_splits > n_samples:
             raise ValueError(
@@ -318,10 +318,10 @@ class _BaseKFold(with_metaclass(ABCMeta, BaseCrossValidator)):
                  " than the number of samples: {1}.").format(self.n_splits,
                                                              n_samples))
 
-        for train, test in super(_BaseKFold, self).split(X, y, labels):
+        for train, test in super(_BaseKFold, self).split(X, y, groups):
             yield train, test
 
-    def get_n_splits(self, X=None, y=None, labels=None):
+    def get_n_splits(self, X=None, y=None, groups=None):
         """Returns the number of splitting iterations in the cross-validator
 
         Parameters
@@ -332,7 +332,7 @@ class _BaseKFold(with_metaclass(ABCMeta, BaseCrossValidator)):
         y : object
             Always ignored, exists for compatibility.
 
-        labels : object
+        groups : object
             Always ignored, exists for compatibility.
 
         Returns
@@ -392,18 +392,18 @@ class KFold(_BaseKFold):
     See also
     --------
     StratifiedKFold
-        Takes label information into account to avoid building folds with
+        Takes class information into account to avoid building folds with
         imbalanced class distributions (for binary or multiclass
         classification tasks).
 
-    LabelKFold: K-fold iterator variant with non-overlapping labels.
+    GroupKFold: K-fold iterator variant with non-overlapping groups.
     """
 
     def __init__(self, n_splits=3, shuffle=False,
                  random_state=None):
         super(KFold, self).__init__(n_splits, shuffle, random_state)
 
-    def _iter_test_indices(self, X, y=None, labels=None):
+    def _iter_test_indices(self, X, y=None, groups=None):
         n_samples = _num_samples(X)
         indices = np.arange(n_samples)
         if self.shuffle:
@@ -419,14 +419,14 @@ class KFold(_BaseKFold):
             current = stop
 
 
-class LabelKFold(_BaseKFold):
-    """K-fold iterator variant with non-overlapping labels.
+class GroupKFold(_BaseKFold):
+    """K-fold iterator variant with non-overlapping groups.
 
-    The same label will not appear in two different folds (the number of
-    distinct labels has to be at least equal to the number of folds).
+    The same group will not appear in two different folds (the number of
+    distinct groups has to be at least equal to the number of folds).
 
     The folds are approximately balanced in the sense that the number of
-    distinct labels is approximately the same in each fold.
+    distinct groups is approximately the same in each fold.
 
     Parameters
     ----------
@@ -435,16 +435,16 @@ class LabelKFold(_BaseKFold):
 
     Examples
     --------
-    >>> from sklearn.model_selection import LabelKFold
+    >>> from sklearn.model_selection import GroupKFold
     >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
     >>> y = np.array([1, 2, 3, 4])
-    >>> labels = np.array([0, 0, 2, 2])
-    >>> label_kfold = LabelKFold(n_splits=2)
-    >>> label_kfold.get_n_splits(X, y, labels)
+    >>> groups = np.array([0, 0, 2, 2])
+    >>> group_kfold = GroupKFold(n_splits=2)
+    >>> group_kfold.get_n_splits(X, y, groups)
     2
-    >>> print(label_kfold)
-    LabelKFold(n_splits=2)
-    >>> for train_index, test_index in label_kfold.split(X, y, labels):
+    >>> print(group_kfold)
+    GroupKFold(n_splits=2)
+    >>> for train_index, test_index in group_kfold.split(X, y, groups):
     ...     print("TRAIN:", train_index, "TEST:", test_index)
     ...     X_train, X_test = X[train_index], X[test_index]
     ...     y_train, y_test = y[train_index], y[test_index]
@@ -461,46 +461,46 @@ class LabelKFold(_BaseKFold):
 
     See also
     --------
-    LeaveOneLabelOut
+    LeaveOneGroupOut
         For splitting the data according to explicit domain-specific
         stratification of the dataset.
     """
     def __init__(self, n_splits=3):
-        super(LabelKFold, self).__init__(n_splits, shuffle=False,
+        super(GroupKFold, self).__init__(n_splits, shuffle=False,
                                          random_state=None)
 
-    def _iter_test_indices(self, X, y, labels):
-        if labels is None:
-            raise ValueError("The labels parameter should not be None")
+    def _iter_test_indices(self, X, y, groups):
+        if groups is None:
+            raise ValueError("The groups parameter should not be None")
 
-        unique_labels, labels = np.unique(labels, return_inverse=True)
-        n_labels = len(unique_labels)
+        unique_groups, groups = np.unique(groups, return_inverse=True)
+        n_groups = len(unique_groups)
 
-        if self.n_splits > n_labels:
+        if self.n_splits > n_groups:
             raise ValueError("Cannot have number of splits n_splits=%d greater"
-                             " than the number of labels: %d."
-                             % (self.n_splits, n_labels))
+                             " than the number of groups: %d."
+                             % (self.n_splits, n_groups))
 
-        # Weight labels by their number of occurrences
-        n_samples_per_label = np.bincount(labels)
+        # Weight groups by their number of occurrences
+        n_samples_per_group = np.bincount(groups)
 
-        # Distribute the most frequent labels first
-        indices = np.argsort(n_samples_per_label)[::-1]
-        n_samples_per_label = n_samples_per_label[indices]
+        # Distribute the most frequent groups first
+        indices = np.argsort(n_samples_per_group)[::-1]
+        n_samples_per_group = n_samples_per_group[indices]
 
         # Total weight of each fold
         n_samples_per_fold = np.zeros(self.n_splits)
 
-        # Mapping from label index to fold index
-        label_to_fold = np.zeros(len(unique_labels))
+        # Mapping from group index to fold index
+        group_to_fold = np.zeros(len(unique_groups))
 
         # Distribute samples by adding the largest weight to the lightest fold
-        for label_index, weight in enumerate(n_samples_per_label):
+        for group_index, weight in enumerate(n_samples_per_group):
             lightest_fold = np.argmin(n_samples_per_fold)
             n_samples_per_fold[lightest_fold] += weight
-            label_to_fold[indices[label_index]] = lightest_fold
+            group_to_fold[indices[group_index]] = lightest_fold
 
-        indices = label_to_fold[labels]
+        indices = group_to_fold[groups]
 
         for f in range(self.n_splits):
             yield np.where(indices == f)[0]
@@ -557,7 +557,7 @@ class StratifiedKFold(_BaseKFold):
     def __init__(self, n_splits=3, shuffle=False, random_state=None):
         super(StratifiedKFold, self).__init__(n_splits, shuffle, random_state)
 
-    def _make_test_folds(self, X, y=None, labels=None):
+    def _make_test_folds(self, X, y=None, groups=None):
         if self.shuffle:
             rng = check_random_state(self.random_state)
         else:
@@ -566,17 +566,17 @@ class StratifiedKFold(_BaseKFold):
         n_samples = y.shape[0]
         unique_y, y_inversed = np.unique(y, return_inverse=True)
         y_counts = bincount(y_inversed)
-        min_labels = np.min(y_counts)
+        min_groups = np.min(y_counts)
         if np.all(self.n_splits > y_counts):
-            raise ValueError("All the n_labels for individual classes"
+            raise ValueError("All the n_groups for individual classes"
                              " are less than n_splits=%d."
                              % (self.n_splits))
-        if self.n_splits > min_labels:
+        if self.n_splits > min_groups:
             warnings.warn(("The least populated class in y has only %d"
                            " members, which is too few. The minimum"
-                           " number of labels for any class cannot"
+                           " number of groups for any class cannot"
                            " be less than n_splits=%d."
-                           % (min_labels, self.n_splits)), Warning)
+                           % (min_groups, self.n_splits)), Warning)
 
         # pre-assign each sample to a test fold index using individual KFold
         # splitting strategies for each class so as to respect the balance of
@@ -604,12 +604,12 @@ class StratifiedKFold(_BaseKFold):
 
         return test_folds
 
-    def _iter_test_masks(self, X, y=None, labels=None):
+    def _iter_test_masks(self, X, y=None, groups=None):
         test_folds = self._make_test_folds(X, y)
         for i in range(self.n_splits):
             yield test_folds == i
 
-    def split(self, X, y, labels=None):
+    def split(self, X, y, groups=None):
         """Generate indices to split data into training and test set.
 
         Parameters
@@ -621,7 +621,7 @@ class StratifiedKFold(_BaseKFold):
         y : array-like, shape (n_samples,)
             The target variable for supervised learning problems.
 
-        labels : array-like, with shape (n_samples,), optional
+        groups : array-like, with shape (n_samples,), optional
             Group labels for the samples used while splitting the dataset into
             train/test set.
 
@@ -633,7 +633,7 @@ class StratifiedKFold(_BaseKFold):
         test : ndarray
             The testing set indices for that split.
         """
-        return super(StratifiedKFold, self).split(X, y, labels)
+        return super(StratifiedKFold, self).split(X, y, groups)
 
 
 class TimeSeriesSplit(_BaseKFold):
@@ -686,7 +686,7 @@ class TimeSeriesSplit(_BaseKFold):
                                               shuffle=False,
                                               random_state=None)
 
-    def split(self, X, y=None, labels=None):
+    def split(self, X, y=None, groups=None):
         """Generate indices to split data into training and test set.
 
         Parameters
@@ -698,7 +698,7 @@ class TimeSeriesSplit(_BaseKFold):
         y : array-like, shape (n_samples,)
             The target variable for supervised learning problems.
 
-        labels : array-like, with shape (n_samples,), optional
+        groups : array-like, with shape (n_samples,), optional
             Group labels for the samples used while splitting the dataset into
             train/test set.
 
@@ -710,7 +710,7 @@ class TimeSeriesSplit(_BaseKFold):
         test : ndarray
             The testing set indices for that split.
         """
-        X, y, labels = indexable(X, y, labels)
+        X, y, groups = indexable(X, y, groups)
         n_samples = _num_samples(X)
         n_splits = self.n_splits
         n_folds = n_splits + 1
@@ -728,30 +728,30 @@ class TimeSeriesSplit(_BaseKFold):
                    indices[test_start:test_start + test_size])
 
 
-class LeaveOneLabelOut(BaseCrossValidator):
-    """Leave One Label Out cross-validator
+class LeaveOneGroupOut(BaseCrossValidator):
+    """Leave One Group Out cross-validator
 
     Provides train/test indices to split data according to a third-party
-    provided label. This label information can be used to encode arbitrary
+    provided group. This group information can be used to encode arbitrary
     domain specific stratifications of the samples as integers.
 
-    For instance the labels could be the year of collection of the samples
+    For instance the groups could be the year of collection of the samples
     and thus allow for cross-validation against time-based splits.
 
     Read more in the :ref:`User Guide <cross_validation>`.
 
     Examples
     --------
-    >>> from sklearn.model_selection import LeaveOneLabelOut
+    >>> from sklearn.model_selection import LeaveOneGroupOut
     >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
     >>> y = np.array([1, 2, 1, 2])
-    >>> labels = np.array([1, 1, 2, 2])
-    >>> lol = LeaveOneLabelOut()
-    >>> lol.get_n_splits(X, y, labels)
+    >>> groups = np.array([1, 1, 2, 2])
+    >>> lol = LeaveOneGroupOut()
+    >>> lol.get_n_splits(X, y, groups)
     2
     >>> print(lol)
-    LeaveOneLabelOut()
-    >>> for train_index, test_index in lol.split(X, y, labels):
+    LeaveOneGroupOut()
+    >>> for train_index, test_index in lol.split(X, y, groups):
     ...    print("TRAIN:", train_index, "TEST:", test_index)
     ...    X_train, X_test = X[train_index], X[test_index]
     ...    y_train, y_test = y[train_index], y[test_index]
@@ -767,16 +767,16 @@ class LeaveOneLabelOut(BaseCrossValidator):
 
     """
 
-    def _iter_test_masks(self, X, y, labels):
-        if labels is None:
-            raise ValueError("The labels parameter should not be None")
-        # We make a copy of labels to avoid side-effects during iteration
-        labels = np.array(labels, copy=True)
-        unique_labels = np.unique(labels)
-        for i in unique_labels:
-            yield labels == i
+    def _iter_test_masks(self, X, y, groups):
+        if groups is None:
+            raise ValueError("The groups parameter should not be None")
+        # We make a copy of groups to avoid side-effects during iteration
+        groups = np.array(groups, copy=True)
+        unique_groups = np.unique(groups)
+        for i in unique_groups:
+            yield groups == i
 
-    def get_n_splits(self, X, y, labels):
+    def get_n_splits(self, X, y, groups):
         """Returns the number of splitting iterations in the cross-validator
 
         Parameters
@@ -787,7 +787,7 @@ class LeaveOneLabelOut(BaseCrossValidator):
         y : object
             Always ignored, exists for compatibility.
 
-        labels : array-like, with shape (n_samples,), optional
+        groups : array-like, with shape (n_samples,), optional
             Group labels for the samples used while splitting the dataset into
             train/test set.
 
@@ -796,45 +796,45 @@ class LeaveOneLabelOut(BaseCrossValidator):
         n_splits : int
             Returns the number of splitting iterations in the cross-validator.
         """
-        if labels is None:
-            raise ValueError("The labels parameter should not be None")
-        return len(np.unique(labels))
+        if groups is None:
+            raise ValueError("The groups parameter should not be None")
+        return len(np.unique(groups))
 
 
-class LeavePLabelOut(BaseCrossValidator):
-    """Leave P Labels Out cross-validator
+class LeavePGroupsOut(BaseCrossValidator):
+    """Leave P Group(s) Out cross-validator
 
     Provides train/test indices to split data according to a third-party
-    provided label. This label information can be used to encode arbitrary
+    provided group. This group information can be used to encode arbitrary
     domain specific stratifications of the samples as integers.
 
-    For instance the labels could be the year of collection of the samples
+    For instance the groups could be the year of collection of the samples
     and thus allow for cross-validation against time-based splits.
 
-    The difference between LeavePLabelOut and LeaveOneLabelOut is that
+    The difference between LeavePGroupsOut and LeaveOneGroupOut is that
     the former builds the test sets with all the samples assigned to
-    ``p`` different values of the labels while the latter uses samples
-    all assigned the same labels.
+    ``p`` different values of the groups while the latter uses samples
+    all assigned the same group.
 
     Read more in the :ref:`User Guide <cross_validation>`.
 
     Parameters
     ----------
-    n_labels : int
-        Number of labels (``p``) to leave out in the test split.
+    n_groups : int
+        Number of groups (``p``) to leave out in the test split.
 
     Examples
     --------
-    >>> from sklearn.model_selection import LeavePLabelOut
+    >>> from sklearn.model_selection import LeavePGroupsOut
     >>> X = np.array([[1, 2], [3, 4], [5, 6]])
     >>> y = np.array([1, 2, 1])
-    >>> labels = np.array([1, 2, 3])
-    >>> lpl = LeavePLabelOut(n_labels=2)
-    >>> lpl.get_n_splits(X, y, labels)
+    >>> groups = np.array([1, 2, 3])
+    >>> lpl = LeavePGroupsOut(n_groups=2)
+    >>> lpl.get_n_splits(X, y, groups)
     3
     >>> print(lpl)
-    LeavePLabelOut(n_labels=2)
-    >>> for train_index, test_index in lpl.split(X, y, labels):
+    LeavePGroupsOut(n_groups=2)
+    >>> for train_index, test_index in lpl.split(X, y, groups):
     ...    print("TRAIN:", train_index, "TEST:", test_index)
     ...    X_train, X_test = X[train_index], X[test_index]
     ...    y_train, y_test = y[train_index], y[test_index]
@@ -851,25 +851,25 @@ class LeavePLabelOut(BaseCrossValidator):
 
     See also
     --------
-    LabelKFold: K-fold iterator variant with non-overlapping labels.
+    GroupKFold: K-fold iterator variant with non-overlapping groups.
     """
 
-    def __init__(self, n_labels):
-        self.n_labels = n_labels
+    def __init__(self, n_groups):
+        self.n_groups = n_groups
 
-    def _iter_test_masks(self, X, y, labels):
-        if labels is None:
-            raise ValueError("The labels parameter should not be None")
-        labels = np.array(labels, copy=True)
-        unique_labels = np.unique(labels)
-        combi = combinations(range(len(unique_labels)), self.n_labels)
+    def _iter_test_masks(self, X, y, groups):
+        if groups is None:
+            raise ValueError("The groups parameter should not be None")
+        groups = np.array(groups, copy=True)
+        unique_groups = np.unique(groups)
+        combi = combinations(range(len(unique_groups)), self.n_groups)
         for indices in combi:
             test_index = np.zeros(_num_samples(X), dtype=np.bool)
-            for l in unique_labels[np.array(indices)]:
-                test_index[labels == l] = True
+            for l in unique_groups[np.array(indices)]:
+                test_index[groups == l] = True
             yield test_index
 
-    def get_n_splits(self, X, y, labels):
+    def get_n_splits(self, X, y, groups):
         """Returns the number of splitting iterations in the cross-validator
 
         Parameters
@@ -880,7 +880,7 @@ class LeavePLabelOut(BaseCrossValidator):
         y : object
             Always ignored, exists for compatibility.
 
-        labels : array-like, with shape (n_samples,), optional
+        groups : array-like, with shape (n_samples,), optional
             Group labels for the samples used while splitting the dataset into
             train/test set.
 
@@ -889,9 +889,9 @@ class LeavePLabelOut(BaseCrossValidator):
         n_splits : int
             Returns the number of splitting iterations in the cross-validator.
         """
-        if labels is None:
-            raise ValueError("The labels parameter should not be None")
-        return int(comb(len(np.unique(labels)), self.n_labels, exact=True))
+        if groups is None:
+            raise ValueError("The groups parameter should not be None")
+        return int(comb(len(np.unique(groups)), self.n_groups, exact=True))
 
 
 class BaseShuffleSplit(with_metaclass(ABCMeta)):
@@ -905,7 +905,7 @@ class BaseShuffleSplit(with_metaclass(ABCMeta)):
         self.train_size = train_size
         self.random_state = random_state
 
-    def split(self, X, y=None, labels=None):
+    def split(self, X, y=None, groups=None):
         """Generate indices to split data into training and test set.
 
         Parameters
@@ -917,7 +917,7 @@ class BaseShuffleSplit(with_metaclass(ABCMeta)):
         y : array-like, shape (n_samples,)
             The target variable for supervised learning problems.
 
-        labels : array-like, with shape (n_samples,), optional
+        groups : array-like, with shape (n_samples,), optional
             Group labels for the samples used while splitting the dataset into
             train/test set.
 
@@ -929,15 +929,15 @@ class BaseShuffleSplit(with_metaclass(ABCMeta)):
         test : ndarray
             The testing set indices for that split.
         """
-        X, y, labels = indexable(X, y, labels)
-        for train, test in self._iter_indices(X, y, labels):
+        X, y, groups = indexable(X, y, groups)
+        for train, test in self._iter_indices(X, y, groups):
             yield train, test
 
     @abstractmethod
-    def _iter_indices(self, X, y=None, labels=None):
+    def _iter_indices(self, X, y=None, groups=None):
         """Generate (train, test) indices"""
 
-    def get_n_splits(self, X=None, y=None, labels=None):
+    def get_n_splits(self, X=None, y=None, groups=None):
         """Returns the number of splitting iterations in the cross-validator
 
         Parameters
@@ -948,7 +948,7 @@ class BaseShuffleSplit(with_metaclass(ABCMeta)):
         y : object
             Always ignored, exists for compatibility.
 
-        labels : object
+        groups : object
             Always ignored, exists for compatibility.
 
         Returns
@@ -1019,7 +1019,7 @@ class ShuffleSplit(BaseShuffleSplit):
     TRAIN: [0 2] TEST: [3]
     """
 
-    def _iter_indices(self, X, y=None, labels=None):
+    def _iter_indices(self, X, y=None, groups=None):
         n_samples = _num_samples(X)
         n_train, n_test = _validate_shuffle_split(n_samples, self.test_size,
                                                   self.train_size)
@@ -1032,26 +1032,26 @@ class ShuffleSplit(BaseShuffleSplit):
             yield ind_train, ind_test
 
 
-class LabelShuffleSplit(ShuffleSplit):
-    '''Shuffle-Labels-Out cross-validation iterator
+class GroupShuffleSplit(ShuffleSplit):
+    '''Shuffle-Group(s)-Out cross-validation iterator
 
     Provides randomized train/test indices to split data according to a
-    third-party provided label. This label information can be used to encode
+    third-party provided group. This group information can be used to encode
     arbitrary domain specific stratifications of the samples as integers.
 
-    For instance the labels could be the year of collection of the samples
+    For instance the groups could be the year of collection of the samples
     and thus allow for cross-validation against time-based splits.
 
-    The difference between LeavePLabelOut and LabelShuffleSplit is that
-    the former generates splits using all subsets of size ``p`` unique labels,
-    whereas LabelShuffleSplit generates a user-determined number of random
-    test splits, each with a user-determined fraction of unique labels.
+    The difference between LeavePGroupsOut and GroupShuffleSplit is that
+    the former generates splits using all subsets of size ``p`` unique groups,
+    whereas GroupShuffleSplit generates a user-determined number of random
+    test splits, each with a user-determined fraction of unique groups.
 
     For example, a less computationally intensive alternative to
-    ``LeavePLabelOut(p=10)`` would be
-    ``LabelShuffleSplit(test_size=10, n_splits=100)``.
+    ``LeavePGroupsOut(n_groups=10)`` would be
+    ``GroupShuffleSplit(test_size=10, n_splits=100)``.
 
-    Note: The parameters ``test_size`` and ``train_size`` refer to labels, and
+    Note: The parameters ``test_size`` and ``train_size`` refer to groups, and
     not to samples, as in ShuffleSplit.
 
 
@@ -1062,14 +1062,14 @@ class LabelShuffleSplit(ShuffleSplit):
 
     test_size : float (default 0.2), int, or None
         If float, should be between 0.0 and 1.0 and represent the
-        proportion of the labels to include in the test split. If
-        int, represents the absolute number of test labels. If None,
+        proportion of the groups to include in the test split. If
+        int, represents the absolute number of test groups. If None,
         the value is automatically set to the complement of the train size.
 
     train_size : float, int, or None (default is None)
         If float, should be between 0.0 and 1.0 and represent the
-        proportion of the labels to include in the train split. If
-        int, represents the absolute number of train labels. If None,
+        proportion of the groups to include in the train split. If
+        int, represents the absolute number of train groups. If None,
         the value is automatically set to the complement of the test size.
 
     random_state : int or RandomState
@@ -1078,23 +1078,23 @@ class LabelShuffleSplit(ShuffleSplit):
 
     def __init__(self, n_splits=5, test_size=0.2, train_size=None,
                  random_state=None):
-        super(LabelShuffleSplit, self).__init__(
+        super(GroupShuffleSplit, self).__init__(
             n_splits=n_splits,
             test_size=test_size,
             train_size=train_size,
             random_state=random_state)
 
-    def _iter_indices(self, X, y, labels):
-        if labels is None:
-            raise ValueError("The labels parameter should not be None")
-        classes, label_indices = np.unique(labels, return_inverse=True)
-        for label_train, label_test in super(
-                LabelShuffleSplit, self)._iter_indices(X=classes):
+    def _iter_indices(self, X, y, groups):
+        if groups is None:
+            raise ValueError("The groups parameter should not be None")
+        classes, group_indices = np.unique(groups, return_inverse=True)
+        for group_train, group_test in super(
+                GroupShuffleSplit, self)._iter_indices(X=classes):
             # these are the indices of classes in the partition
             # invert them into data indices
 
-            train = np.flatnonzero(np.in1d(label_indices, label_train))
-            test = np.flatnonzero(np.in1d(label_indices, label_test))
+            train = np.flatnonzero(np.in1d(group_indices, group_train))
+            test = np.flatnonzero(np.in1d(group_indices, group_test))
 
             yield train, test
 
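Illustrative sketch (not part of the patch): the ``np.unique`` / ``np.in1d``
inversion used in ``_iter_indices`` above maps group-level split indices back
to sample indices. The values below are synthetic.

    import numpy as np

    groups = np.array(['a', 'b', 'a', 'c', 'b'])
    classes, group_indices = np.unique(groups, return_inverse=True)
    # classes       -> array(['a', 'b', 'c'], ...)
    # group_indices -> array([0, 1, 0, 2, 1])

    group_test = np.array([1])   # pretend the shuffle selected group 'b'
    test = np.flatnonzero(np.in1d(group_indices, group_test))
    # test -> array([1, 4]): every sample whose group is 'b'
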
@@ -1225,7 +1225,7 @@ class StratifiedShuffleSplit(BaseShuffleSplit):
         super(StratifiedShuffleSplit, self).__init__(
             n_splits, test_size, train_size, random_state)
 
-    def _iter_indices(self, X, y, labels=None):
+    def _iter_indices(self, X, y, groups=None):
         n_samples = _num_samples(X)
         n_train, n_test = _validate_shuffle_split(n_samples, self.test_size,
                                                   self.train_size)
@@ -1236,7 +1236,7 @@ class StratifiedShuffleSplit(BaseShuffleSplit):
         if np.min(class_counts) < 2:
             raise ValueError("The least populated class in y has only 1"
                              " member, which is too few. The minimum"
-                             " number of labels for any class cannot"
+                             " number of groups for any class cannot"
                              " be less than 2.")
 
         if n_train < n_classes:
@@ -1271,7 +1271,7 @@ class StratifiedShuffleSplit(BaseShuffleSplit):
 
             yield train, test
 
-    def split(self, X, y, labels=None):
+    def split(self, X, y, groups=None):
         """Generate indices to split data into training and test set.
 
         Parameters
@@ -1283,7 +1283,7 @@ class StratifiedShuffleSplit(BaseShuffleSplit):
         y : array-like, shape (n_samples,)
             The target variable for supervised learning problems.
 
-        labels : array-like, with shape (n_samples,), optional
+        groups : array-like, with shape (n_samples,), optional
             Group labels for the samples used while splitting the dataset into
             train/test set.
 
@@ -1295,7 +1295,7 @@ class StratifiedShuffleSplit(BaseShuffleSplit):
         test : ndarray
             The testing set indices for that split.
         """
-        return super(StratifiedShuffleSplit, self).split(X, y, labels)
+        return super(StratifiedShuffleSplit, self).split(X, y, groups)
 
 
 def _validate_shuffle_split_init(test_size, train_size):
@@ -1338,13 +1338,13 @@ def _validate_shuffle_split(n_samples, test_size, train_size):
     Validation helper to check if the test/test sizes are meaningful wrt to the
     size of the data (n_samples)
     """
-    if (test_size is not None and np.asarray(test_size).dtype.kind == 'i'
-            and test_size >= n_samples):
+    if (test_size is not None and np.asarray(test_size).dtype.kind == 'i' and
+            test_size >= n_samples):
         raise ValueError('test_size=%d should be smaller than the number of '
                          'samples %d' % (test_size, n_samples))
 
-    if (train_size is not None and np.asarray(train_size).dtype.kind == 'i'
-            and train_size >= n_samples):
+    if (train_size is not None and np.asarray(train_size).dtype.kind == 'i' and
+            train_size >= n_samples):
         raise ValueError("train_size=%d should be smaller than the number of"
                          " samples %d" % (train_size, n_samples))
 
@@ -1406,7 +1406,7 @@ class PredefinedSplit(BaseCrossValidator):
         self.unique_folds = np.unique(self.test_fold)
         self.unique_folds = self.unique_folds[self.unique_folds != -1]
 
-    def split(self, X=None, y=None, labels=None):
+    def split(self, X=None, y=None, groups=None):
         """Generate indices to split data into training and test set.
 
         Parameters
@@ -1417,7 +1417,7 @@ class PredefinedSplit(BaseCrossValidator):
         y : object
             Always ignored, exists for compatibility.
 
-        labels : object
+        groups : object
             Always ignored, exists for compatibility.
 
         Returns
@@ -1442,7 +1442,7 @@ class PredefinedSplit(BaseCrossValidator):
             test_mask[test_index] = True
             yield test_mask
 
-    def get_n_splits(self, X=None, y=None, labels=None):
+    def get_n_splits(self, X=None, y=None, groups=None):
         """Returns the number of splitting iterations in the cross-validator
 
         Parameters
@@ -1453,7 +1453,7 @@ class PredefinedSplit(BaseCrossValidator):
         y : object
             Always ignored, exists for compatibility.
 
-        labels : object
+        groups : object
             Always ignored, exists for compatibility.
 
         Returns
@@ -1469,7 +1469,7 @@ class _CVIterableWrapper(BaseCrossValidator):
     def __init__(self, cv):
         self.cv = cv
 
-    def get_n_splits(self, X=None, y=None, labels=None):
+    def get_n_splits(self, X=None, y=None, groups=None):
         """Returns the number of splitting iterations in the cross-validator
 
         Parameters
@@ -1480,7 +1480,7 @@ class _CVIterableWrapper(BaseCrossValidator):
         y : object
             Always ignored, exists for compatibility.
 
-        labels : object
+        groups : object
             Always ignored, exists for compatibility.
 
         Returns
@@ -1490,7 +1490,7 @@ class _CVIterableWrapper(BaseCrossValidator):
         """
         return len(self.cv)  # Both iterables and old-cv objects support len
 
-    def split(self, X=None, y=None, labels=None):
+    def split(self, X=None, y=None, groups=None):
         """Generate indices to split data into training and test set.
 
         Parameters
@@ -1501,7 +1501,7 @@ class _CVIterableWrapper(BaseCrossValidator):
         y : object
             Always ignored, exists for compatibility.
 
-        labels : object
+        groups : object
             Always ignored, exists for compatibility.
 
         Returns
@@ -1603,7 +1603,7 @@ def train_test_split(*arrays, **options):
 
     stratify : array-like or None (default is None)
         If not None, data is split in a stratified fashion, using this as
-        the labels array.
+        the class labels.
 
     Returns
     -------
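
Illustrative sketch (not part of the patch): minimal usage of the renamed
``GroupShuffleSplit``/``groups`` API defined in this file, assuming a build
with this change applied; the toy data is made up.

    import numpy as np
    from sklearn.model_selection import GroupShuffleSplit

    X = np.arange(16).reshape(8, 2)
    y = np.array([0, 0, 0, 1, 1, 1, 1, 0])
    groups = np.array([1, 1, 2, 2, 3, 3, 4, 4])   # one group id per sample

    gss = GroupShuffleSplit(n_splits=3, test_size=0.25, random_state=0)
    for train, test in gss.split(X, y, groups=groups):
        # by construction, no group id appears on both sides of a split
        assert not set(groups[train]) & set(groups[test])
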
diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py
index 183f16a10fddbbbbc7b916063cf309e589eb70d2..febf5634a38029bc036281dc56c37fc91be153e1 100644
--- a/sklearn/model_selection/_validation.py
+++ b/sklearn/model_selection/_validation.py
@@ -28,13 +28,13 @@ from ..metrics.scorer import check_scoring
 from ..exceptions import FitFailedWarning
 
 from ._split import KFold
-from ._split import LabelKFold
-from ._split import LeaveOneLabelOut
+from ._split import GroupKFold
+from ._split import LeaveOneGroupOut
 from ._split import LeaveOneOut
-from ._split import LeavePLabelOut
+from ._split import LeavePGroupsOut
 from ._split import LeavePOut
 from ._split import ShuffleSplit
-from ._split import LabelShuffleSplit
+from ._split import GroupShuffleSplit
 from ._split import StratifiedKFold
 from ._split import StratifiedShuffleSplit
 from ._split import PredefinedSplit
@@ -44,24 +44,24 @@ __all__ = ['cross_val_score', 'cross_val_predict', 'permutation_test_score',
            'learning_curve', 'validation_curve']
 
 ALL_CVS = {'KFold': KFold,
-           'LabelKFold': LabelKFold,
-           'LeaveOneLabelOut': LeaveOneLabelOut,
+           'GroupKFold': GroupKFold,
+           'LeaveOneGroupOut': LeaveOneGroupOut,
            'LeaveOneOut': LeaveOneOut,
-           'LeavePLabelOut': LeavePLabelOut,
+           'LeavePGroupsOut': LeavePGroupsOut,
            'LeavePOut': LeavePOut,
            'ShuffleSplit': ShuffleSplit,
-           'LabelShuffleSplit': LabelShuffleSplit,
+           'GroupShuffleSplit': GroupShuffleSplit,
            'StratifiedKFold': StratifiedKFold,
            'StratifiedShuffleSplit': StratifiedShuffleSplit,
            'PredefinedSplit': PredefinedSplit}
 
-LABEL_CVS = {'LabelKFold': LabelKFold,
-             'LeaveOneLabelOut': LeaveOneLabelOut,
-             'LeavePLabelOut': LeavePLabelOut,
-             'LabelShuffleSplit': LabelShuffleSplit}
+GROUP_CVS = {'GroupKFold': GroupKFold,
+             'LeaveOneGroupOut': LeaveOneGroupOut,
+             'LeavePGroupsOut': LeavePGroupsOut,
+             'GroupShuffleSplit': GroupShuffleSplit}
 
 
-def cross_val_score(estimator, X, y=None, labels=None, scoring=None, cv=None,
+def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
                     n_jobs=1, verbose=0, fit_params=None,
                     pre_dispatch='2*n_jobs'):
     """Evaluate a score by cross-validation
@@ -80,7 +80,7 @@ def cross_val_score(estimator, X, y=None, labels=None, scoring=None, cv=None,
         The target variable to try to predict in the case of
         supervised learning.
 
-    labels : array-like, with shape (n_samples,), optional
+    groups : array-like, with shape (n_samples,), optional
         Group labels for the samples used while splitting the dataset into
         train/test set.
 
@@ -153,7 +153,7 @@ def cross_val_score(estimator, X, y=None, labels=None, scoring=None, cv=None,
         Make a scorer from a performance metric or loss function.
 
     """
-    X, y, labels = indexable(X, y, labels)
+    X, y, groups = indexable(X, y, groups)
 
     cv = check_cv(cv, y, classifier=is_classifier(estimator))
     scorer = check_scoring(estimator, scoring=scoring)
@@ -164,7 +164,7 @@ def cross_val_score(estimator, X, y=None, labels=None, scoring=None, cv=None,
     scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y, scorer,
                                               train, test, verbose, None,
                                               fit_params)
-                      for train, test in cv.split(X, y, labels))
+                      for train, test in cv.split(X, y, groups))
     return np.array(scores)[:, 0]
 
 
@@ -314,7 +314,7 @@ def _score(estimator, X_test, y_test, scorer):
     return score
 
 
-def cross_val_predict(estimator, X, y=None, labels=None, cv=None, n_jobs=1,
+def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
                       verbose=0, fit_params=None, pre_dispatch='2*n_jobs',
                       method='predict'):
     """Generate cross-validated estimates for each input data point
@@ -333,7 +333,7 @@ def cross_val_predict(estimator, X, y=None, labels=None, cv=None, n_jobs=1,
         The target variable to try to predict in the case of
         supervised learning.
 
-    labels : array-like, with shape (n_samples,), optional
+    groups : array-like, with shape (n_samples,), optional
         Group labels for the samples used while splitting the dataset into
         train/test set.
 
@@ -397,7 +397,7 @@ def cross_val_predict(estimator, X, y=None, labels=None, cv=None, n_jobs=1,
     >>> lasso = linear_model.Lasso()
     >>> y_pred = cross_val_predict(lasso, X, y)
     """
-    X, y, labels = indexable(X, y, labels)
+    X, y, groups = indexable(X, y, groups)
 
     cv = check_cv(cv, y, classifier=is_classifier(estimator))
 
@@ -412,7 +412,7 @@ def cross_val_predict(estimator, X, y=None, labels=None, cv=None, n_jobs=1,
                         pre_dispatch=pre_dispatch)
     prediction_blocks = parallel(delayed(_fit_and_predict)(
         clone(estimator), X, y, train, test, verbose, fit_params, method)
-        for train, test in cv.split(X, y, labels))
+        for train, test in cv.split(X, y, groups))
 
     # Concatenate the predictions
     predictions = [pred_block_i for pred_block_i, _ in prediction_blocks]
@@ -525,7 +525,7 @@ def _index_param_value(X, v, indices):
     return safe_indexing(v, indices)
 
 
-def permutation_test_score(estimator, X, y, labels=None, cv=None,
+def permutation_test_score(estimator, X, y, groups=None, cv=None,
                            n_permutations=100, n_jobs=1, random_state=0,
                            verbose=0, scoring=None):
     """Evaluate the significance of a cross-validated score with permutations
@@ -544,9 +544,15 @@ def permutation_test_score(estimator, X, y, labels=None, cv=None,
         The target variable to try to predict in the case of
         supervised learning.
 
-    labels : array-like, with shape (n_samples,), optional
-        Group labels for the samples used while splitting the dataset into
-        train/test set.
+    groups : array-like, with shape (n_samples,), optional
+        Labels to constrain permutation within groups, i.e. ``y`` values
+        are permuted among samples with the same group identifier.
+        When not specified, ``y`` values are permuted among all samples.
+
+        When a grouped cross-validator is used, the group labels are
+        also passed on to the ``split`` method of the cross-validator. The
+        cross-validator uses them to group the samples while splitting
+        the dataset into train/test sets.
 
     scoring : string, callable or None, optional, default: None
         A string (see model evaluation documentation) or
@@ -606,7 +612,7 @@ def permutation_test_score(estimator, X, y, labels=None, cv=None,
         vol. 11
 
     """
-    X, y, labels = indexable(X, y, labels)
+    X, y, groups = indexable(X, y, groups)
 
     cv = check_cv(cv, y, classifier=is_classifier(estimator))
     scorer = check_scoring(estimator, scoring=scoring)
@@ -614,11 +620,11 @@ def permutation_test_score(estimator, X, y, labels=None, cv=None,
 
     # We clone the estimator to make sure that all the folds are
     # independent, and that it is pickle-able.
-    score = _permutation_test_score(clone(estimator), X, y, labels, cv, scorer)
+    score = _permutation_test_score(clone(estimator), X, y, groups, cv, scorer)
     permutation_scores = Parallel(n_jobs=n_jobs, verbose=verbose)(
         delayed(_permutation_test_score)(
-            clone(estimator), X, _shuffle(y, labels, random_state),
-            labels, cv, scorer)
+            clone(estimator), X, _shuffle(y, groups, random_state),
+            groups, cv, scorer)
         for _ in range(n_permutations))
     permutation_scores = np.array(permutation_scores)
     pvalue = (np.sum(permutation_scores >= score) + 1.0) / (n_permutations + 1)
@@ -628,28 +634,28 @@ def permutation_test_score(estimator, X, y, labels=None, cv=None,
 permutation_test_score.__test__ = False  # to avoid a pb with nosetests
 
 
-def _permutation_test_score(estimator, X, y, labels, cv, scorer):
+def _permutation_test_score(estimator, X, y, groups, cv, scorer):
     """Auxiliary function for permutation_test_score"""
     avg_score = []
-    for train, test in cv.split(X, y, labels):
+    for train, test in cv.split(X, y, groups):
         estimator.fit(X[train], y[train])
         avg_score.append(scorer(estimator, X[test], y[test]))
     return np.mean(avg_score)
 
 
-def _shuffle(y, labels, random_state):
-    """Return a shuffled copy of y eventually shuffle among same labels."""
-    if labels is None:
+def _shuffle(y, groups, random_state):
+    """Return a shuffled copy of y eventually shuffle among same groups."""
+    if groups is None:
         indices = random_state.permutation(len(y))
     else:
-        indices = np.arange(len(labels))
-        for label in np.unique(labels):
-            this_mask = (labels == label)
+        indices = np.arange(len(groups))
+        for group in np.unique(groups):
+            this_mask = (groups == group)
             indices[this_mask] = random_state.permutation(indices[this_mask])
     return y[indices]
 
 
-def learning_curve(estimator, X, y, labels=None,
+def learning_curve(estimator, X, y, groups=None,
                    train_sizes=np.linspace(0.1, 1.0, 5), cv=None, scoring=None,
                    exploit_incremental_learning=False, n_jobs=1,
                    pre_dispatch="all", verbose=0):
@@ -679,7 +685,7 @@ def learning_curve(estimator, X, y, labels=None,
         Target relative to X for classification or regression;
         None for unsupervised learning.
 
-    labels : array-like, with shape (n_samples,), optional
+    groups : array-like, with shape (n_samples,), optional
         Group labels for the samples used while splitting the dataset into
         train/test set.
 
@@ -749,10 +755,10 @@ def learning_curve(estimator, X, y, labels=None,
     if exploit_incremental_learning and not hasattr(estimator, "partial_fit"):
         raise ValueError("An estimator must support the partial_fit interface "
                          "to exploit incremental learning")
-    X, y, labels = indexable(X, y, labels)
+    X, y, groups = indexable(X, y, groups)
 
     cv = check_cv(cv, y, classifier=is_classifier(estimator))
-    cv_iter = cv.split(X, y, labels)
+    cv_iter = cv.split(X, y, groups)
     # Make a list since we will be iterating multiple times over the folds
     cv_iter = list(cv_iter)
     scorer = check_scoring(estimator, scoring=scoring)
@@ -773,7 +779,7 @@ def learning_curve(estimator, X, y, labels=None,
         classes = np.unique(y) if is_classifier(estimator) else None
         out = parallel(delayed(_incremental_fit_estimator)(
             clone(estimator), X, y, classes, train, test, train_sizes_abs,
-            scorer, verbose) for train, test in cv.split(X, y, labels))
+            scorer, verbose) for train, test in cv.split(X, y, groups))
     else:
         out = parallel(delayed(_fit_and_score)(
             clone(estimator), X, y, scorer, train[:n_train_samples], test,
@@ -869,7 +875,7 @@ def _incremental_fit_estimator(estimator, X, y, classes, train, test,
     return np.array((train_scores, test_scores)).T
 
 
-def validation_curve(estimator, X, y, param_name, param_range, labels=None,
+def validation_curve(estimator, X, y, param_name, param_range, groups=None,
                      cv=None, scoring=None, n_jobs=1, pre_dispatch="all",
                      verbose=0):
     """Validation curve.
@@ -902,7 +908,7 @@ def validation_curve(estimator, X, y, param_name, param_range, labels=None,
     param_range : array-like, shape (n_values,)
         The values of the parameter that will be evaluated.
 
-    labels : array-like, with shape (n_samples,), optional
+    groups : array-like, with shape (n_samples,), optional
         Group labels for the samples used while splitting the dataset into
         train/test set.
 
@@ -950,7 +956,7 @@ def validation_curve(estimator, X, y, param_name, param_range, labels=None,
     See :ref:`sphx_glr_auto_examples_model_selection_plot_validation_curve.py`
 
     """
-    X, y, labels = indexable(X, y, labels)
+    X, y, groups = indexable(X, y, groups)
 
     cv = check_cv(cv, y, classifier=is_classifier(estimator))
 
@@ -961,7 +967,7 @@ def validation_curve(estimator, X, y, param_name, param_range, labels=None,
     out = parallel(delayed(_fit_and_score)(
         estimator, X, y, scorer, train, test, verbose,
         parameters={param_name: v}, fit_params=None, return_train_score=True)
-        for train, test in cv.split(X, y, labels) for v in param_range)
+        for train, test in cv.split(X, y, groups) for v in param_range)
 
     out = np.asarray(out)[:, :2]
     n_params = len(param_range)
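
Illustrative sketch (not part of the patch): the renamed ``groups`` argument
is simply forwarded to ``cv.split(X, y, groups)`` by the helpers in this file.
Toy data below, assuming a build with this change applied.

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.model_selection import GroupKFold, cross_val_score
    from sklearn.svm import LinearSVC

    X, y = make_classification(n_samples=30, random_state=0)
    groups = np.repeat(np.arange(6), 5)   # six "subjects", five samples each

    # each fold keeps whole subjects on one side of the split
    scores = cross_val_score(LinearSVC(random_state=0), X, y,
                             groups=groups, cv=GroupKFold(n_splits=3))
    print(scores.shape)                   # (3,)
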
diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py
index 03eafdafb1d30521bf7a96f2f22bcaf890c943da..141c1a21b46e98dc7e5de73bb8fca2dab3f627cb 100644
--- a/sklearn/model_selection/tests/test_search.py
+++ b/sklearn/model_selection/tests/test_search.py
@@ -37,10 +37,10 @@ from sklearn.datasets import make_multilabel_classification
 from sklearn.model_selection import KFold
 from sklearn.model_selection import StratifiedKFold
 from sklearn.model_selection import StratifiedShuffleSplit
-from sklearn.model_selection import LeaveOneLabelOut
-from sklearn.model_selection import LeavePLabelOut
-from sklearn.model_selection import LabelKFold
-from sklearn.model_selection import LabelShuffleSplit
+from sklearn.model_selection import LeaveOneGroupOut
+from sklearn.model_selection import LeavePGroupsOut
+from sklearn.model_selection import GroupKFold
+from sklearn.model_selection import GroupShuffleSplit
 from sklearn.model_selection import GridSearchCV
 from sklearn.model_selection import RandomizedSearchCV
 from sklearn.model_selection import ParameterGrid
@@ -224,28 +224,28 @@ def test_grid_search_score_method():
     assert_almost_equal(score_auc, score_no_score_auc)
 
 
-def test_grid_search_labels():
-    # Check if ValueError (when labels is None) propagates to GridSearchCV
-    # And also check if labels is correctly passed to the cv object
+def test_grid_search_groups():
+    # Check if ValueError (when groups is None) propagates to GridSearchCV
+    # And also check if groups is correctly passed to the cv object
     rng = np.random.RandomState(0)
 
     X, y = make_classification(n_samples=15, n_classes=2, random_state=0)
-    labels = rng.randint(0, 3, 15)
+    groups = rng.randint(0, 3, 15)
 
     clf = LinearSVC(random_state=0)
     grid = {'C': [1]}
 
-    label_cvs = [LeaveOneLabelOut(), LeavePLabelOut(2), LabelKFold(),
-                 LabelShuffleSplit()]
-    for cv in label_cvs:
+    group_cvs = [LeaveOneGroupOut(), LeavePGroupsOut(2), GroupKFold(),
+                 GroupShuffleSplit()]
+    for cv in group_cvs:
         gs = GridSearchCV(clf, grid, cv=cv)
         assert_raise_message(ValueError,
-                             "The labels parameter should not be None",
+                             "The groups parameter should not be None",
                              gs.fit, X, y)
-        gs.fit(X, y, labels)
+        gs.fit(X, y, groups=groups)
 
-    non_label_cvs = [StratifiedKFold(), StratifiedShuffleSplit()]
-    for cv in non_label_cvs:
+    non_group_cvs = [StratifiedKFold(), StratifiedShuffleSplit()]
+    for cv in non_group_cvs:
         gs = GridSearchCV(clf, grid, cv=cv)
         # Should not raise an error
         gs.fit(X, y)
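
Illustrative sketch (not part of the patch), mirroring ``test_grid_search_groups``
above: with a group-aware ``cv``, ``groups`` is now passed to ``fit`` as a
keyword argument. Toy data only.

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.model_selection import GridSearchCV, GroupKFold
    from sklearn.svm import LinearSVC

    X, y = make_classification(n_samples=15, n_classes=2, random_state=0)
    groups = np.random.RandomState(0).randint(0, 3, 15)

    gs = GridSearchCV(LinearSVC(random_state=0), {'C': [1, 10]},
                      cv=GroupKFold(n_splits=3))
    gs.fit(X, y, groups=groups)   # omitting groups here raises ValueError
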
diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py
index d4130182b0e10c468b924c685e6240203c32a1e8..ec4f07aef8e9adac1ee8a5c458c82e6980c22bc8 100644
--- a/sklearn/model_selection/tests/test_split.py
+++ b/sklearn/model_selection/tests/test_split.py
@@ -29,14 +29,14 @@ from sklearn.utils.mocking import MockDataFrame
 from sklearn.model_selection import cross_val_score
 from sklearn.model_selection import KFold
 from sklearn.model_selection import StratifiedKFold
-from sklearn.model_selection import LabelKFold
+from sklearn.model_selection import GroupKFold
 from sklearn.model_selection import TimeSeriesSplit
 from sklearn.model_selection import LeaveOneOut
-from sklearn.model_selection import LeaveOneLabelOut
+from sklearn.model_selection import LeaveOneGroupOut
 from sklearn.model_selection import LeavePOut
-from sklearn.model_selection import LeavePLabelOut
+from sklearn.model_selection import LeavePGroupsOut
 from sklearn.model_selection import ShuffleSplit
-from sklearn.model_selection import LabelShuffleSplit
+from sklearn.model_selection import GroupShuffleSplit
 from sklearn.model_selection import StratifiedShuffleSplit
 from sklearn.model_selection import PredefinedSplit
 from sklearn.model_selection import check_cv
@@ -132,7 +132,7 @@ class MockClassifier(object):
 @ignore_warnings
 def test_cross_validator_with_default_params():
     n_samples = 4
-    n_unique_labels = 4
+    n_unique_groups = 4
     n_splits = 2
     p = 2
     n_shuffle_splits = 10  # (the default value)
@@ -140,13 +140,13 @@ def test_cross_validator_with_default_params():
     X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
     X_1d = np.array([1, 2, 3, 4])
     y = np.array([1, 1, 2, 2])
-    labels = np.array([1, 2, 3, 4])
+    groups = np.array([1, 2, 3, 4])
     loo = LeaveOneOut()
     lpo = LeavePOut(p)
     kf = KFold(n_splits)
     skf = StratifiedKFold(n_splits)
-    lolo = LeaveOneLabelOut()
-    lopo = LeavePLabelOut(p)
+    lolo = LeaveOneGroupOut()
+    lopo = LeavePGroupsOut(p)
     ss = ShuffleSplit(random_state=0)
     ps = PredefinedSplit([1, 1, 2, 2])  # n_splits = np of unique folds = 2
 
@@ -154,14 +154,14 @@ def test_cross_validator_with_default_params():
     lpo_repr = "LeavePOut(p=2)"
     kf_repr = "KFold(n_splits=2, random_state=None, shuffle=False)"
     skf_repr = "StratifiedKFold(n_splits=2, random_state=None, shuffle=False)"
-    lolo_repr = "LeaveOneLabelOut()"
-    lopo_repr = "LeavePLabelOut(n_labels=2)"
+    lolo_repr = "LeaveOneGroupOut()"
+    lopo_repr = "LeavePGroupsOut(n_groups=2)"
     ss_repr = ("ShuffleSplit(n_splits=10, random_state=0, test_size=0.1, "
                "train_size=None)")
     ps_repr = "PredefinedSplit(test_fold=array([1, 1, 2, 2]))"
 
     n_splits_expected = [n_samples, comb(n_samples, p), n_splits, n_splits,
-                         n_unique_labels, comb(n_unique_labels, p),
+                         n_unique_groups, comb(n_unique_groups, p),
                          n_shuffle_splits, 2]
 
     for i, (cv, cv_repr) in enumerate(zip(
@@ -169,14 +169,14 @@ def test_cross_validator_with_default_params():
             [loo_repr, lpo_repr, kf_repr, skf_repr, lolo_repr, lopo_repr,
              ss_repr, ps_repr])):
         # Test if get_n_splits works correctly
-        assert_equal(n_splits_expected[i], cv.get_n_splits(X, y, labels))
+        assert_equal(n_splits_expected[i], cv.get_n_splits(X, y, groups))
 
         # Test if the cross-validator works as expected even if
         # the data is 1d
-        np.testing.assert_equal(list(cv.split(X, y, labels)),
-                                list(cv.split(X_1d, y, labels)))
+        np.testing.assert_equal(list(cv.split(X, y, groups)),
+                                list(cv.split(X_1d, y, groups)))
         # Test that train, test indices returned are integers
-        for train, test in cv.split(X, y, labels):
+        for train, test in cv.split(X, y, groups):
             assert_equal(np.asarray(train).dtype.kind, 'i')
             assert_equal(np.asarray(train).dtype.kind, 'i')
 
@@ -196,17 +196,17 @@ def check_valid_split(train, test, n_samples=None):
         assert_equal(train.union(test), set(range(n_samples)))
 
 
-def check_cv_coverage(cv, X, y, labels, expected_n_splits=None):
+def check_cv_coverage(cv, X, y, groups, expected_n_splits=None):
     n_samples = _num_samples(X)
     # Check that a all the samples appear at least once in a test fold
     if expected_n_splits is not None:
-        assert_equal(cv.get_n_splits(X, y, labels), expected_n_splits)
+        assert_equal(cv.get_n_splits(X, y, groups), expected_n_splits)
     else:
-        expected_n_splits = cv.get_n_splits(X, y, labels)
+        expected_n_splits = cv.get_n_splits(X, y, groups)
 
     collected_test_samples = set()
     iterations = 0
-    for train, test in cv.split(X, y, labels):
+    for train, test in cv.split(X, y, groups):
         check_valid_split(train, test, n_samples=n_samples)
         iterations += 1
         collected_test_samples.update(test)
@@ -236,9 +236,9 @@ def test_kfold_valueerrors():
     # side of the split at each split
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        check_cv_coverage(skf_3, X2, y, labels=None, expected_n_splits=3)
+        check_cv_coverage(skf_3, X2, y, groups=None, expected_n_splits=3)
 
-    # Check that errors are raised if all n_labels for individual
+    # Check that errors are raised if all n_groups for individual
     # classes are less than n_splits.
     y = np.array([3, 3, -1, -1, 2])
 
@@ -268,13 +268,13 @@ def test_kfold_indices():
     # Check all indices are returned in the test folds
     X1 = np.ones(18)
     kf = KFold(3)
-    check_cv_coverage(kf, X1, y=None, labels=None, expected_n_splits=3)
+    check_cv_coverage(kf, X1, y=None, groups=None, expected_n_splits=3)
 
     # Check all indices are returned in the test folds even when equal-sized
     # folds are not possible
     X2 = np.ones(17)
     kf = KFold(3)
-    check_cv_coverage(kf, X2, y=None, labels=None, expected_n_splits=3)
+    check_cv_coverage(kf, X2, y=None, groups=None, expected_n_splits=3)
 
     # Check if get_n_splits returns the number of folds
     assert_equal(5, KFold(5).get_n_splits(X2))
@@ -443,7 +443,7 @@ def test_shuffle_stratifiedkfold():
     for (_, test0), (_, test1) in zip(kf0.split(X_40, y),
                                       kf1.split(X_40, y)):
         assert_not_equal(set(test0), set(test1))
-    check_cv_coverage(kf0, X_40, y, labels=None, expected_n_splits=5)
+    check_cv_coverage(kf0, X_40, y, groups=None, expected_n_splits=5)
 
 
 def test_kfold_can_detect_dependent_samples_on_digits():  # see #2372
@@ -596,7 +596,7 @@ def test_stratified_shuffle_split_even():
                         "to even draws")
 
     for n_samples in (6, 22):
-        labels = np.array((n_samples // 2) * [0, 1])
+        groups = np.array((n_samples // 2) * [0, 1])
         splits = StratifiedShuffleSplit(n_splits=n_splits,
                                         test_size=1. / n_folds,
                                         random_state=0)
@@ -604,7 +604,7 @@ def test_stratified_shuffle_split_even():
         train_counts = [0] * n_samples
         test_counts = [0] * n_samples
         n_splits_actual = 0
-        for train, test in splits.split(X=np.ones(n_samples), y=labels):
+        for train, test in splits.split(X=np.ones(n_samples), y=groups):
             n_splits_actual += 1
             for counter, ids in [(train_counts, train), (test_counts, test)]:
                 for id in ids:
@@ -618,10 +618,10 @@ def test_stratified_shuffle_split_even():
         assert_equal(len(test), n_test)
         assert_equal(len(set(train).intersection(test)), 0)
 
-        label_counts = np.unique(labels)
+        group_counts = np.unique(groups)
         assert_equal(splits.test_size, 1.0 / n_folds)
-        assert_equal(n_train + n_test, len(labels))
-        assert_equal(len(label_counts), 2)
+        assert_equal(n_train + n_test, len(groups))
+        assert_equal(len(group_counts), 2)
         ex_test_p = float(n_test) / n_samples
         ex_train_p = float(n_train) / n_samples
 
@@ -664,28 +664,28 @@ def test_predefinedsplit_with_kfold_split():
     assert_array_equal(ps_test, kf_test)
 
 
-def test_label_shuffle_split():
-    labels = [np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
+def test_group_shuffle_split():
+    groups = [np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
               np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
               np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]),
               np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4])]
 
-    for l in labels:
+    for l in groups:
         X = y = np.ones(len(l))
         n_splits = 6
-        test_size = 1. / 3
-        slo = LabelShuffleSplit(n_splits, test_size=test_size, random_state=0)
+        test_size = 1. / 3
+        slo = GroupShuffleSplit(n_splits, test_size=test_size, random_state=0)
 
         # Make sure the repr works
         repr(slo)
 
         # Test that the length is correct
-        assert_equal(slo.get_n_splits(X, y, labels=l), n_splits)
+        assert_equal(slo.get_n_splits(X, y, groups=l), n_splits)
 
         l_unique = np.unique(l)
 
-        for train, test in slo.split(X, y, labels=l):
-            # First test: no train label is in the test set and vice versa
+        for train, test in slo.split(X, y, groups=l):
+            # First test: no train group is in the test set and vice versa
             l_train_unique = np.unique(l[train])
             l_test_unique = np.unique(l[test])
             assert_false(np.any(np.in1d(l[train], l_test_unique)))
@@ -698,33 +698,33 @@ def test_label_shuffle_split():
             assert_array_equal(np.intersect1d(train, test), [])
 
             # Fourth test:
-            # unique train and test labels are correct, +- 1 for rounding error
+            # unique train and test groups are correct, +- 1 for rounding error
             assert_true(abs(len(l_test_unique) -
                             round(test_size * len(l_unique))) <= 1)
             assert_true(abs(len(l_train_unique) -
                             round((1.0 - test_size) * len(l_unique))) <= 1)
 
 
-def test_leave_label_out_changing_labels():
-    # Check that LeaveOneLabelOut and LeavePLabelOut work normally if
-    # the labels variable is changed before calling split
-    labels = np.array([0, 1, 2, 1, 1, 2, 0, 0])
-    X = np.ones(len(labels))
-    labels_changing = np.array(labels, copy=True)
-    lolo = LeaveOneLabelOut().split(X, labels=labels)
-    lolo_changing = LeaveOneLabelOut().split(X, labels=labels)
-    lplo = LeavePLabelOut(n_labels=2).split(X, labels=labels)
-    lplo_changing = LeavePLabelOut(n_labels=2).split(X, labels=labels)
-    labels_changing[:] = 0
+def test_leave_group_out_changing_groups():
+    # Check that LeaveOneGroupOut and LeavePGroupsOut work normally if
+    # the groups variable is changed before calling split
+    groups = np.array([0, 1, 2, 1, 1, 2, 0, 0])
+    X = np.ones(len(groups))
+    groups_changing = np.array(groups, copy=True)
+    lolo = LeaveOneGroupOut().split(X, groups=groups)
+    lolo_changing = LeaveOneGroupOut().split(X, groups=groups)
+    lplo = LeavePGroupsOut(n_groups=2).split(X, groups=groups)
+    lplo_changing = LeavePGroupsOut(n_groups=2).split(X, groups=groups)
+    groups_changing[:] = 0
     for llo, llo_changing in [(lolo, lolo_changing), (lplo, lplo_changing)]:
         for (train, test), (train_chan, test_chan) in zip(llo, llo_changing):
             assert_array_equal(train, train_chan)
             assert_array_equal(test, test_chan)
 
-    # n_splits = no of 2 (p) label combinations of the unique labels = 3C2 = 3
-    assert_equal(3, LeavePLabelOut(n_labels=2).get_n_splits(X, y, labels))
-    # n_splits = no of unique labels (C(uniq_lbls, 1) = n_unique_labels)
-    assert_equal(3, LeaveOneLabelOut().get_n_splits(X, y, labels))
+    # n_splits = no of 2 (p) group combinations of the unique groups = 3C2 = 3
+    assert_equal(3, LeavePGroupsOut(n_groups=2).get_n_splits(X, y, groups))
+    # n_splits = no of unique groups (C(uniq_groups, 1) = n_unique_groups)
+    assert_equal(3, LeaveOneGroupOut().get_n_splits(X, y, groups))
 
 
 def test_train_test_split_errors():
@@ -931,11 +931,11 @@ def test_cv_iterable_wrapper():
     assert_equal(len(cv), wrapped_old_skf.get_n_splits())
 
 
-def test_label_kfold():
+def test_group_kfold():
     rng = np.random.RandomState(0)
 
     # Parameters of the test
-    n_labels = 15
+    n_groups = 15
     n_samples = 1000
     n_splits = 5
 
@@ -943,34 +943,34 @@ def test_label_kfold():
 
     # Construct the test data
     tolerance = 0.05 * n_samples  # 5 percent error allowed
-    labels = rng.randint(0, n_labels, n_samples)
+    groups = rng.randint(0, n_groups, n_samples)
 
-    ideal_n_labels_per_fold = n_samples // n_splits
+    ideal_n_groups_per_fold = n_samples // n_splits
 
-    len(np.unique(labels))
+    len(np.unique(groups))
     # Get the test fold indices from the test set indices of each fold
     folds = np.zeros(n_samples)
-    lkf = LabelKFold(n_splits=n_splits)
-    for i, (_, test) in enumerate(lkf.split(X, y, labels)):
+    lkf = GroupKFold(n_splits=n_splits)
+    for i, (_, test) in enumerate(lkf.split(X, y, groups)):
         folds[test] = i
 
     # Check that folds have approximately the same size
-    assert_equal(len(folds), len(labels))
+    assert_equal(len(folds), len(groups))
     for i in np.unique(folds):
         assert_greater_equal(tolerance,
-                             abs(sum(folds == i) - ideal_n_labels_per_fold))
+                             abs(sum(folds == i) - ideal_n_groups_per_fold))
 
-    # Check that each label appears only in 1 fold
-    for label in np.unique(labels):
-        assert_equal(len(np.unique(folds[labels == label])), 1)
+    # Check that each group appears only in 1 fold
+    for group in np.unique(groups):
+        assert_equal(len(np.unique(folds[groups == group])), 1)
 
-    # Check that no label is on both sides of the split
-    labels = np.asarray(labels, dtype=object)
-    for train, test in lkf.split(X, y, labels):
-        assert_equal(len(np.intersect1d(labels[train], labels[test])), 0)
+    # Check that no group is on both sides of the split
+    groups = np.asarray(groups, dtype=object)
+    for train, test in lkf.split(X, y, groups):
+        assert_equal(len(np.intersect1d(groups[train], groups[test])), 0)
 
     # Construct the test data
-    labels = np.array(['Albert', 'Jean', 'Bertrand', 'Michel', 'Jean',
+    groups = np.array(['Albert', 'Jean', 'Bertrand', 'Michel', 'Jean',
                        'Francis', 'Robert', 'Michel', 'Rachel', 'Lois',
                        'Michelle', 'Bernard', 'Marion', 'Laura', 'Jean',
                        'Rachel', 'Franck', 'John', 'Gael', 'Anna', 'Alix',
@@ -978,41 +978,41 @@ def test_label_kfold():
                        'Madmood', 'Cary', 'Mary', 'Alexandre', 'David',
                        'Francis', 'Barack', 'Abdoul', 'Rasha', 'Xi', 'Silvia'])
 
-    n_labels = len(np.unique(labels))
-    n_samples = len(labels)
+    n_groups = len(np.unique(groups))
+    n_samples = len(groups)
     n_splits = 5
     tolerance = 0.05 * n_samples  # 5 percent error allowed
-    ideal_n_labels_per_fold = n_samples // n_splits
+    ideal_n_groups_per_fold = n_samples // n_splits
 
     X = y = np.ones(n_samples)
 
     # Get the test fold indices from the test set indices of each fold
     folds = np.zeros(n_samples)
-    for i, (_, test) in enumerate(lkf.split(X, y, labels)):
+    for i, (_, test) in enumerate(lkf.split(X, y, groups)):
         folds[test] = i
 
     # Check that folds have approximately the same size
-    assert_equal(len(folds), len(labels))
+    assert_equal(len(folds), len(groups))
     for i in np.unique(folds):
         assert_greater_equal(tolerance,
-                             abs(sum(folds == i) - ideal_n_labels_per_fold))
+                             abs(sum(folds == i) - ideal_n_groups_per_fold))
 
-    # Check that each label appears only in 1 fold
+    # Check that each group appears only in 1 fold
     with warnings.catch_warnings():
         warnings.simplefilter("ignore", DeprecationWarning)
-        for label in np.unique(labels):
-            assert_equal(len(np.unique(folds[labels == label])), 1)
+        for group in np.unique(groups):
+            assert_equal(len(np.unique(folds[groups == group])), 1)
 
-    # Check that no label is on both sides of the split
-    labels = np.asarray(labels, dtype=object)
-    for train, test in lkf.split(X, y, labels):
-        assert_equal(len(np.intersect1d(labels[train], labels[test])), 0)
+    # Check that no group is on both sides of the split
+    groups = np.asarray(groups, dtype=object)
+    for train, test in lkf.split(X, y, groups):
+        assert_equal(len(np.intersect1d(groups[train], groups[test])), 0)
 
-    # Should fail if there are more folds than labels
-    labels = np.array([1, 1, 1, 2, 2])
-    X = y = np.ones(len(labels))
+    # Should fail if there are more folds than groups
+    groups = np.array([1, 1, 1, 2, 2])
+    X = y = np.ones(len(groups))
     assert_raises_regexp(ValueError, "Cannot have number of splits.*greater",
-                         next, LabelKFold(n_splits=3).split(X, y, labels))
+                         next, GroupKFold(n_splits=3).split(X, y, groups))
 
 
 def test_time_series_cv():
@@ -1058,16 +1058,16 @@ def test_nested_cv():
     rng = np.random.RandomState(0)
 
     X, y = make_classification(n_samples=15, n_classes=2, random_state=0)
-    labels = rng.randint(0, 5, 15)
+    groups = rng.randint(0, 5, 15)
 
-    cvs = [LeaveOneLabelOut(), LeaveOneOut(), LabelKFold(), StratifiedKFold(),
+    cvs = [LeaveOneGroupOut(), LeaveOneOut(), GroupKFold(), StratifiedKFold(),
            StratifiedShuffleSplit(n_splits=3, random_state=0)]
 
     for inner_cv, outer_cv in combinations_with_replacement(cvs, 2):
         gs = GridSearchCV(Ridge(), param_grid={'alpha': [1, .1]},
                           cv=inner_cv)
-        cross_val_score(gs, X=X, y=y, labels=labels, cv=outer_cv,
-                        fit_params={'labels': labels})
+        cross_val_score(gs, X=X, y=y, groups=groups, cv=outer_cv,
+                        fit_params={'groups': groups})
 
 
 def test_build_repr():
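
Illustrative sketch (not part of the patch) of the renamed leave-group-out
iterators exercised by the tests above; the values are toy data, not taken
from the test file.

    import numpy as np
    from sklearn.model_selection import LeaveOneGroupOut, LeavePGroupsOut

    X = y = np.ones(8)
    groups = np.array([0, 1, 2, 1, 1, 2, 0, 0])

    logo = LeaveOneGroupOut()
    print(logo.get_n_splits(X, y, groups))   # 3 unique groups -> 3 splits
    lpgo = LeavePGroupsOut(n_groups=2)
    print(lpgo.get_n_splits(X, y, groups))   # C(3, 2) = 3 splits

    for train, test in logo.split(X, y, groups=groups):
        # the held-out group never leaks into the training side
        assert set(groups[train]).isdisjoint(groups[test])
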
diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py
index c8afc07fffa917e0b31a2ffefde38c80dbc25061..67937711ec2d8d4f353ac468d76ca374ae7a2302 100644
--- a/sklearn/model_selection/tests/test_validation.py
+++ b/sklearn/model_selection/tests/test_validation.py
@@ -29,10 +29,10 @@ from sklearn.model_selection import permutation_test_score
 from sklearn.model_selection import KFold
 from sklearn.model_selection import StratifiedKFold
 from sklearn.model_selection import LeaveOneOut
-from sklearn.model_selection import LeaveOneLabelOut
-from sklearn.model_selection import LeavePLabelOut
-from sklearn.model_selection import LabelKFold
-from sklearn.model_selection import LabelShuffleSplit
+from sklearn.model_selection import LeaveOneGroupOut
+from sklearn.model_selection import LeavePGroupsOut
+from sklearn.model_selection import GroupKFold
+from sklearn.model_selection import GroupShuffleSplit
 from sklearn.model_selection import learning_curve
 from sklearn.model_selection import validation_curve
 from sklearn.model_selection._validation import _check_is_permutation
@@ -181,22 +181,22 @@ def test_cross_val_score():
     assert_raises(ValueError, cross_val_score, clf, X_3d, y2)
 
 
-def test_cross_val_score_predict_labels():
-    # Check if ValueError (when labels is None) propagates to cross_val_score
+def test_cross_val_score_predict_groups():
+    # Check if ValueError (when groups is None) propagates to cross_val_score
     # and cross_val_predict
-    # And also check if labels is correctly passed to the cv object
+    # And also check if groups is correctly passed to the cv object
     X, y = make_classification(n_samples=20, n_classes=2, random_state=0)
 
     clf = SVC(kernel="linear")
 
-    label_cvs = [LeaveOneLabelOut(), LeavePLabelOut(2), LabelKFold(),
-                 LabelShuffleSplit()]
-    for cv in label_cvs:
+    group_cvs = [LeaveOneGroupOut(), LeavePGroupsOut(2), GroupKFold(),
+                 GroupShuffleSplit()]
+    for cv in group_cvs:
         assert_raise_message(ValueError,
-                             "The labels parameter should not be None",
+                             "The groups parameter should not be None",
                              cross_val_score, estimator=clf, X=X, y=y, cv=cv)
         assert_raise_message(ValueError,
-                             "The labels parameter should not be None",
+                             "The groups parameter should not be None",
                              cross_val_predict, estimator=clf, X=X, y=y, cv=cv)
 
 
@@ -372,21 +372,21 @@ def test_permutation_score():
     assert_greater(score, 0.9)
     assert_almost_equal(pvalue, 0.0, 1)
 
-    score_label, _, pvalue_label = permutation_test_score(
+    score_group, _, pvalue_group = permutation_test_score(
         svm, X, y, n_permutations=30, cv=cv, scoring="accuracy",
-        labels=np.ones(y.size), random_state=0)
-    assert_true(score_label == score)
-    assert_true(pvalue_label == pvalue)
+        groups=np.ones(y.size), random_state=0)
+    assert_true(score_group == score)
+    assert_true(pvalue_group == pvalue)
 
     # check that we obtain the same results with a sparse representation
     svm_sparse = SVC(kernel='linear')
     cv_sparse = StratifiedKFold(2)
-    score_label, _, pvalue_label = permutation_test_score(
+    score_group, _, pvalue_group = permutation_test_score(
         svm_sparse, X_sparse, y, n_permutations=30, cv=cv_sparse,
-        scoring="accuracy", labels=np.ones(y.size), random_state=0)
+        scoring="accuracy", groups=np.ones(y.size), random_state=0)
 
-    assert_true(score_label == score)
-    assert_true(pvalue_label == pvalue)
+    assert_true(score_group == score)
+    assert_true(pvalue_group == pvalue)
 
     # test with custom scoring object
     def custom_score(y_true, y_pred):
@@ -483,7 +483,7 @@ def test_cross_val_predict():
     assert_equal(len(preds), len(y))
 
     class BadCV():
-        def split(self, X, y=None, labels=None):
+        def split(self, X, y=None, groups=None):
             for i in range(4):
                 yield np.array([0, 1, 2, 3]), np.array([4, 5, 6, 7, 8])