From a08555a2384884c03d5deb509192a052c06caa85 Mon Sep 17 00:00:00 2001
From: "(Venkat) Raghav, Rajagopalan" <rvraghav93@gmail.com>
Date: Fri, 7 Jul 2017 17:12:31 +0200
Subject: [PATCH] [MRG + 2] ENH Allow `cross_val_score`, `GridSearchCV` et al.
 to evaluate on multiple metrics (#7388)

* ENH cross_val_score now supports multiple metrics

* DOCFIX permutation_test_score

* ENH validate multiple metric scorers

* ENH Move validation of multimetric scoring param out

* ENH GridSearchCV and RandomizedSearchCV now support multiple metrics

* EXA Add an example demonstrating the multiple metric in GridSearchCV

* ENH Let check_multimetric_scoring tell if it's multimetric or not

* FIX For single metric name of scorer should remain 'score'

* ENH validation_curve and learning_curve now support multiple metrics

* MNT move _aggregate_score_dicts helper into _validation.py

* TST More testing/ Fixing scores to the correct values

* EXA Add cross_val_score to multimetric example

* Rename to multiple_metric_evaluation.py

* MNT Remove scaffolding

* FIX doctest imports

* FIX wrap the scorer and unwrap the score when using _score() in rfe

* TST Cleanup the tests. Test for is_multimetric too

* TST Make sure it registers as single metric when scoring is of that type

* PEP8

* Don't use dict comprehension to make it work in python2.6

* ENH/FIX/TST grid_scores_ should not be available for multimetric evaluation

* FIX+TST delegated methods NA when multimetric is enabled...

TST Add general tests to GridSearchCV and RandomizedSearchCV

* ENH add option to disable delegation on multimetric scoring

* Remove old function from __all__

* flake8

* FIX revert disable_on_multimetric

* stash

* Fix incorrect rebase

* [ci skip]

* Make sure refit works as expected and remove irrelevant tests

* Allow passing standard scorers by name in multimetric scorers

* Fix example

* flake8

* Address reviews

* Fix indentation

* Ensure {'acc': 'accuracy'} and ['precision'] are valid inputs

* Test that for single metric, 'score' is a key

* Typos

* Fix incorrect rebase

* Compare multimetric grid search with multiple single metric searches

* Test X, y list and pandas input; Test multimetric for unsupervised grid search

* Fix tests; Unsupervised multimetric gs will not pass until #8117 is merged

* Make a plot of Precision vs ROC AUC for RandomForest varying the n_estimators

* Add example to grid_search.rst

* Use the classic tuning of C param in SVM instead of estimators in RF

* FIX Remove scoring arg in default scorer test

* flake8

* Search for min_samples_split in DTC; Also show f-score

* REVIEW Make check_multimetric_scoring private

* FIX Add more samples to see if 3% mismatch on 32 bit systems gets fixed

* REVIEW Plot best score; Shorten legends

* REVIEW/COSMIT multimetric --> multi-metric

* REVIEW Mark the best scores of P/R scores too

* Revert "FIX Add more samples to see if 3% mismatch on 32 bit systems gets fixed"

This reverts commit ba766d98353380a186fbc3dade211670ee72726d.

* ENH Use looping for iid testing

* FIX use param grid as scipy's stats dist in 0.12 do not accept seed

* ENH more looping less code; Use small non-noisy dataset

* FIX Use named arg after expanded args

* TST More testing of the refit parameter

* Test that in multimetric search refit to single metric, the delegated methods
  work as expected.
* Test that setting probability=False works with multimetric too
* Test refit=False gives sensible error

* COSMIT multimetric --> multi-metric

* REV Correct example doc

* COSMIT

* REVIEW Make tests stronger; Fix bugs in _check_multimetric_scorer

* REVIEW refit param: Raise for empty strings

* TST Invalid refit params

* REVIEW Use <scorer_name> alone; recall --> Recall

* REV specify when we expect scorers to not be None

* FLAKE8

* REVERT multimetrics in learning_curve and validation_curve

* REVIEW Simpler coding style

* COSMIT

* COSMIT

* REV Compress example a bit. Move comment to top

* FIX fit_grid_point's previous API must be preserved

* Flake8

* TST Use loop; Compare with single-metric

* REVIEW Use dict-comprehension instead of helper

* REVIEW Remove redundant test

* Fix tests incorrect braces

* COSMIT

* REVIEW Use regexp

* REV Simplify aggregation of score dicts

* FIX precision and accuracy test

* FIX doctest and flake8

* TST the best_* attributes multimetric with single metric

* Address @jnothman's review

* Address more comments \o/

* DOCFIXES

* Fix use the validated fit_param from fit's arguments

* Revert alpha to a lower value as before

* Using def instead of lambda

* Address @jnothman's review batch 1: Fix tests / Doc fixes

* Remove superfluous tests

* Remove more superfluous testing

* TST/FIX loop over refit and check found n_clusters

* Cosmetic touches

* Use zip instead of manually listing the keys

* Fix inverse_transform

* FIX bug in fit_grid_point; Allow only single score

TST if fit_grid_point works as intended

* ENH Use only ROC-AUC and F1-score

* Fix typos and flake8; Address Andy's reviews

MNT Add a comment on why we do such a transpose + some fixes

* ENH Better error messages for incorrect multimetric scoring values +...

ENH Avoid exception traceback while using incorrect scoring string

* Dict keys must be of string type only

* 1. Better error message for invalid scoring 2...
Internal functions return single score for single metric scoring

* Fix test failures and shuffle tests

* Avoid wrapping scorer as dict in learning_curve

* Remove doc example as asked for

* Some leftover ones

* Don't wrap scorer in validation_curve either

* Add a doc example and skip it as dict order fails doctest

* Import zip from six for python2.7 compat

* Make cross_val_score return a cv_results-like dict

* Add relevant sections to userguide

* Flake8 fixes

* Add whatsnew and fix broken links

* Use AUC and accuracy instead of f1

* Fix failing doctests cross_validation.rst

* DOC add the wrapper example for metrics that return multiple return values

* Address andy's comments

* Be less weird

* Address more of andy's comments

* Make a separate cross_validate function to return dict and a cross_val_score

* Update the docs to reflect the new cross_validate function

* Add cross_validate to toc-tree

* Add more tests on type of cross_validate return and time limits

* FIX failing doctests

* FIX ensure keys are not plural

* DOC fix

* Address some pending comments

* Remove the comment as it is irrelevant now

* Remove excess blank line

* Fix flake8 inconsistencies

* Allow fit_times to be 0 to conform with windows precision

* DOC specify how refit param is to be set in multiple metric case

* TST ensure cross_validate works for string single metrics + address @jnothman's reviews

* Doc fixes

* Remove the shape and transform parameter of _aggregate_score_dicts

* Address Joel's doc comments

* Fix broken doctest

* Fix the spurious file

* Address Andy's comments

* MNT Remove erroneous entry

* Address Andy's comments

* FIX broken links

* Update whats_new.rst

missing newline
---
 doc/modules/classes.rst                       |   1 +
 doc/modules/cross_validation.rst              |  61 ++-
 doc/modules/grid_search.rst                   |  25 ++
 doc/modules/model_evaluation.rst              |  45 +++
 doc/whats_new.rst                             |  13 +
 .../plot_multi_metric_evaluation.py           |  94 +++++
 sklearn/metrics/scorer.py                     | 120 +++++-
 sklearn/metrics/tests/test_score_objects.py   |  99 ++++-
 sklearn/model_selection/__init__.py           |   2 +
 sklearn/model_selection/_search.py            | 306 ++++++++++----
 sklearn/model_selection/_validation.py        | 379 +++++++++++++++---
 sklearn/model_selection/tests/test_search.py  | 297 ++++++++++----
 .../model_selection/tests/test_validation.py  | 201 +++++++++-
 13 files changed, 1406 insertions(+), 237 deletions(-)
 create mode 100644 examples/model_selection/plot_multi_metric_evaluation.py

diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 5399e27ef4..7275789c19 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -223,6 +223,7 @@ Model validation
    :toctree: generated/
    :template: function.rst
 
+   model_selection.cross_validate
    model_selection.cross_val_score
    model_selection.cross_val_predict
    model_selection.permutation_test_score
diff --git a/doc/modules/cross_validation.rst b/doc/modules/cross_validation.rst
index cc5f6a3c07..ab7d222744 100644
--- a/doc/modules/cross_validation.rst
+++ b/doc/modules/cross_validation.rst
@@ -172,6 +172,65 @@ validation iterator instead, for instance::
 
     See :ref:`combining_estimators`.
 
+
+.. _multimetric_cross_validation:
+
+The cross_validate function and multiple metric evaluation
+----------------------------------------------------------
+
+The ``cross_validate`` function differs from ``cross_val_score`` in two ways -
+
+- It allows specifying multiple metrics for evaluation.
+
+- It returns a dict containing training scores, fit-times and score-times in
+  addition to the test score.
+
+For single metric evaluation, where the scoring parameter is a string,
+callable or None, the keys will be ``['test_score', 'fit_time', 'score_time']``.
+
+For multiple metric evaluation, the return value is a dict with the
+following keys -
+``['test_<scorer1_name>', 'test_<scorer2_name>', 'test_<scorer...>', 'fit_time', 'score_time']``
+
+``return_train_score`` is set to ``True`` by default. It adds train score keys
+for all the scorers. If train scores are not needed, this should be set to
+``False`` explicitly.
+
+The multiple metrics can be specified either as a list, tuple or set of
+predefined scorer names::
+
+    >>> from sklearn.model_selection import cross_validate
+    >>> from sklearn.metrics import recall_score
+    >>> scoring = ['precision_macro', 'recall_macro']
+    >>> clf = svm.SVC(kernel='linear', C=1, random_state=0)
+    >>> scores = cross_validate(clf, iris.data, iris.target, scoring=scoring,
+    ...                         cv=5, return_train_score=False)
+    >>> sorted(scores.keys())
+    ['fit_time', 'score_time', 'test_precision_macro', 'test_recall_macro']
+    >>> scores['test_recall_macro']                       # doctest: +ELLIPSIS
+    array([ 0.96...,  1.  ...,  0.96...,  0.96...,  1.        ])
+
+Or as a dict mapping scorer name to a predefined or custom scoring function::
+
+    >>> from sklearn.metrics.scorer import make_scorer
+    >>> scoring = {'prec_macro': 'precision_macro',
+    ...            'rec_macro': make_scorer(recall_score, average='macro')}
+    >>> scores = cross_validate(clf, iris.data, iris.target, scoring=scoring,
+    ...                         cv=5, return_train_score=True)
+    >>> sorted(scores.keys())                 # doctest: +NORMALIZE_WHITESPACE
+    ['fit_time', 'score_time', 'test_prec_macro', 'test_rec_macro',
+     'train_prec_macro', 'train_rec_macro']
+    >>> scores['train_rec_macro']                         # doctest: +ELLIPSIS
+    array([ 0.97...,  0.97...,  0.99...,  0.98...,  0.98...])
+
+Here is an example of ``cross_validate`` using a single metric::
+
+    >>> scores = cross_validate(clf, iris.data, iris.target,
+    ...                         scoring='precision_macro')
+    >>> sorted(scores.keys())
+    ['fit_time', 'score_time', 'test_score', 'train_score']
+
+
 Obtaining predictions by cross-validation
 -----------------------------------------
 
@@ -186,7 +245,7 @@ These prediction can then be used to evaluate the classifier::
   >>> from sklearn.model_selection import cross_val_predict
   >>> predicted = cross_val_predict(clf, iris.data, iris.target, cv=10)
   >>> metrics.accuracy_score(iris.target, predicted) # doctest: +ELLIPSIS
-  0.966...
+  0.973...
 
 Note that the result of this computation may be slightly different from those
 obtained using :func:`cross_val_score` as the elements are grouped in different
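
A supplementary sketch (not from the patch itself; it assumes the same iris
data and linear ``SVC`` as the snippets above): each entry of the dict
returned by ``cross_validate`` is an array with one value per CV split, so
per-metric summaries are one-liners::

    from sklearn import datasets, svm
    from sklearn.model_selection import cross_validate

    iris = datasets.load_iris()
    clf = svm.SVC(kernel='linear', C=1, random_state=0)
    scores = cross_validate(clf, iris.data, iris.target,
                            scoring=['precision_macro', 'recall_macro'],
                            cv=5, return_train_score=False)
    for key, values in sorted(scores.items()):
        # each value is an array with one entry per CV split
        print("%s: %.3f +/- %.3f" % (key, values.mean(), values.std()))
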
diff --git a/doc/modules/grid_search.rst b/doc/modules/grid_search.rst
index 48870a80a6..1867a66594 100644
--- a/doc/modules/grid_search.rst
+++ b/doc/modules/grid_search.rst
@@ -84,6 +84,10 @@ evaluated and the best combination is retained.
       dataset. This is the best practice for evaluating the performance of a
       model with grid search.
 
+    - See :ref:`sphx_glr_auto_examples_model_selection_plot_multi_metric_evaluation.py`
+      for an example of :class:`GridSearchCV` being used to evaluate multiple
+      metrics simultaneously.
+
 .. _randomized_parameter_search:
 
 Randomized Parameter Optimization
@@ -161,6 +165,27 @@ scoring function can be specified via the ``scoring`` parameter to
 specialized cross-validation tools described below.
 See :ref:`scoring_parameter` for more details.
 
+.. _multimetric_grid_search:
+
+Specifying multiple metrics for evaluation
+------------------------------------------
+
+``GridSearchCV`` and ``RandomizedSearchCV`` allow specifying multiple metrics
+for the ``scoring`` parameter.
+
+Multimetric scoring can either be specified as a list of strings of predefined
+scorer names or a dict mapping scorer names to scorer functions and/or the
+predefined scorer name(s). See :ref:`multimetric_scoring` for more details.
+
+When specifying multiple metrics, the ``refit`` parameter must be set to the
+metric (string) for which the ``best_params_`` will be found and used to build
+the ``best_estimator_`` on the whole dataset. If the search should not be
+refit, set ``refit=False``. Leaving ``refit`` at its default value ``True``
+will result in an error when using multiple metrics.
+
+See :ref:`sphx_glr_auto_examples_model_selection_plot_multi_metric_evaluation.py`
+for an example usage.
+
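
A concrete sketch of the ``refit`` requirement (not part of the patch; the
estimator and parameter grid below are arbitrary illustrative choices)::

    from sklearn.datasets import load_iris
    from sklearn.model_selection import GridSearchCV
    from sklearn.svm import SVC

    X, y = load_iris(return_X_y=True)
    gs = GridSearchCV(SVC(), param_grid={'C': [0.1, 1, 10]},
                      scoring=['accuracy', 'recall_macro'],
                      refit='accuracy',  # a scorer key (or False); the default
                                         # True raises for multiple metrics
                      cv=5)
    gs.fit(X, y)
    gs.best_params_                            # chosen by the 'accuracy' ranking
    gs.cv_results_['mean_test_recall_macro']   # other metrics are still reported
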
 Composite estimators and parameter spaces
 -----------------------------------------
 
diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst
index c544175861..dee5865bdd 100644
--- a/doc/modules/model_evaluation.rst
+++ b/doc/modules/model_evaluation.rst
@@ -210,6 +210,51 @@ the following two rules:
   Again, by convention higher numbers are better, so if your scorer
   returns loss, that value should be negated.
 
+.. _multimetric_scoring:
+
+Using multiple metric evaluation
+--------------------------------
+
+Scikit-learn also permits evaluation of multiple metrics in ``GridSearchCV``,
+``RandomizedSearchCV`` and ``cross_validate``.
+
+There are two ways to specify multiple scoring metrics for the ``scoring``
+parameter:
+
+- As an iterable of string metrics::
+
+      >>> scoring = ['accuracy', 'precision']
+
+- As a ``dict`` mapping the scorer name to the scoring function::
+
+      >>> from sklearn.metrics import accuracy_score
+      >>> from sklearn.metrics import make_scorer
+      >>> scoring = {'accuracy': make_scorer(accuracy_score),
+      ...            'prec': 'precision'}
+
+Note that the dict values can either be scorer functions or one of the
+predefined metric strings.
+
+Currently only those scorer functions that return a single score can be passed
+inside the dict. Scorer functions that return multiple values are not
+permitted and will require a wrapper to return a single metric::
+
+    >>> from sklearn.model_selection import cross_validate
+    >>> from sklearn.metrics import confusion_matrix
+    >>> # A sample toy binary classification dataset
+    >>> X, y = datasets.make_classification(n_classes=2, random_state=0)
+    >>> svm = LinearSVC(random_state=0)
+    >>> tp = lambda y_true, y_pred: confusion_matrix(y_true, y_pred)[0, 0]
+    >>> tn = lambda y_true, y_pred: confusion_matrix(y_true, y_pred)[0, 0]
+    >>> fp = lambda y_true, y_pred: confusion_matrix(y_true, y_pred)[1, 0]
+    >>> fn = lambda y_true, y_pred: confusion_matrix(y_true, y_pred)[0, 1]
+    >>> scoring = {'tp' : make_scorer(tp), 'tn' : make_scorer(tn),
+    ...            'fp' : make_scorer(fp), 'fn' : make_scorer(fn)}
+    >>> cv_results = cross_validate(svm.fit(X, y), X, y, scoring=scoring)
+    >>> # Getting the test set true positive scores
+    >>> print(cv_results['test_tp'])          # doctest: +NORMALIZE_WHITESPACE
+    [12 13 15]
+    >>> # Getting the test set false negative scores
+    >>> print(cv_results['test_fn'])          # doctest: +NORMALIZE_WHITESPACE
+    [5 4 1]
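
A further sketch of mine (not from the patch): the same wrapping idea applies
to any metric that returns several values, for instance picking single
elements out of ``precision_recall_fscore_support`` with ``functools.partial``
on binary targets such as the toy dataset above. The helper name below is
illustrative only::

    from functools import partial

    from sklearn.metrics import make_scorer, precision_recall_fscore_support

    def prf_element(y_true, y_pred, element):
        # (precision, recall, f-score, support) -> keep one value so that
        # the resulting scorer returns a single number
        return precision_recall_fscore_support(y_true, y_pred,
                                               average='binary')[element]

    scoring = {name: make_scorer(partial(prf_element, element=index))
               for index, name in enumerate(['precision', 'recall', 'f1'])}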
 
 .. _classification_metrics:
 
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 0203511348..0c5608d6b5 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -31,6 +31,19 @@ Changelog
 New features
 ............
 
+   - :class:`model_selection.GridSearchCV` and
+     :class:`model_selection.RandomizedSearchCV` now support simultaneous
+     evaluation of multiple metrics. Refer to the
+     :ref:`multimetric_grid_search` section of the user guide for more
+     information. :issue:`7388` by `Raghav RV`_
+
+   - Added :func:`model_selection.cross_validate`, which allows evaluation
+     of multiple metrics. This function returns a dict with more useful
+     information from cross-validation such as the train scores, fit times and
+     score times.
+     Refer to the :ref:`multimetric_cross_validation` section of the user
+     guide for more information. :issue:`7388` by `Raghav RV`_
+
    - Added :class:`multioutput.ClassifierChain` for multi-label
      classification. By `Adam Kleczewski <adamklec>`_.
 
diff --git a/examples/model_selection/plot_multi_metric_evaluation.py b/examples/model_selection/plot_multi_metric_evaluation.py
new file mode 100644
index 0000000000..5f4491e51f
--- /dev/null
+++ b/examples/model_selection/plot_multi_metric_evaluation.py
@@ -0,0 +1,94 @@
+"""Demonstration of multi-metric evaluation on cross_val_score and GridSearchCV
+
+Multiple metric parameter search can be done by setting the ``scoring``
+parameter to a list of metric scorer names or a dict mapping the scorer names
+to the scorer callables.
+
+The scores of all the scorers are available in the ``cv_results_`` dict at keys
+ending in ``'_<scorer_name>'`` (``'mean_test_precision'``,
+``'rank_test_precision'``, etc...)
+
+The ``best_estimator_``, ``best_index_``, ``best_score_`` and ``best_params_``
+correspond to the scorer (key) that is set to the ``refit`` attribute.
+"""
+
+# Author: Raghav RV <rvraghav93@gmail.com>
+# License: BSD
+
+import numpy as np
+from matplotlib import pyplot as plt
+
+from sklearn.datasets import make_hastie_10_2
+from sklearn.model_selection import GridSearchCV
+from sklearn.metrics import make_scorer
+from sklearn.metrics import accuracy_score
+from sklearn.tree import DecisionTreeClassifier
+
+print(__doc__)
+
+###############################################################################
+# Running ``GridSearchCV`` using multiple evaluation metrics
+# ----------------------------------------------------------
+#
+
+X, y = make_hastie_10_2(n_samples=8000, random_state=42)
+
+# The scorers can either be one of the predefined metric strings or a scorer
+# callable, like the one returned by make_scorer
+scoring = {'AUC': 'roc_auc', 'Accuracy': make_scorer(accuracy_score)}
+
+# Setting refit='AUC' refits an estimator on the whole dataset with the
+# parameter setting that has the best cross-validated AUC score.
+# That estimator is made available at ``gs.best_estimator_`` along with
+# attributes like ``gs.best_score_``, ``gs.best_params_`` and
+# ``gs.best_index_``
+gs = GridSearchCV(DecisionTreeClassifier(random_state=42),
+                  param_grid={'min_samples_split': range(2, 403, 10)},
+                  scoring=scoring, cv=5, refit='AUC')
+gs.fit(X, y)
+results = gs.cv_results_
+
+###############################################################################
+# Plotting the result
+# -------------------
+
+plt.figure(figsize=(13, 13))
+plt.title("GridSearchCV evaluating using multiple scorers simultaneously",
+          fontsize=16)
+
+plt.xlabel("min_samples_split")
+plt.ylabel("Score")
+plt.grid()
+
+ax = plt.axes()
+ax.set_xlim(0, 402)
+ax.set_ylim(0.73, 1)
+
+# Get the regular numpy array from the MaskedArray
+X_axis = np.array(results['param_min_samples_split'].data, dtype=float)
+
+for scorer, color in zip(sorted(scoring), ['g', 'k']):
+    for sample, style in (('train', '--'), ('test', '-')):
+        sample_score_mean = results['mean_%s_%s' % (sample, scorer)]
+        sample_score_std = results['std_%s_%s' % (sample, scorer)]
+        ax.fill_between(X_axis, sample_score_mean - sample_score_std,
+                        sample_score_mean + sample_score_std,
+                        alpha=0.1 if sample == 'test' else 0, color=color)
+        ax.plot(X_axis, sample_score_mean, style, color=color,
+                alpha=1 if sample == 'test' else 0.7,
+                label="%s (%s)" % (scorer, sample))
+
+    best_index = np.nonzero(results['rank_test_%s' % scorer] == 1)[0][0]
+    best_score = results['mean_test_%s' % scorer][best_index]
+
+    # Plot a dotted vertical line at the best score for that scorer marked by x
+    ax.plot([X_axis[best_index], ] * 2, [0, best_score],
+            linestyle='-.', color=color, marker='x', markeredgewidth=3, ms=8)
+
+    # Annotate the best score for that scorer
+    ax.annotate("%0.2f" % best_score,
+                (X_axis[best_index], best_score + 0.005))
+
+plt.legend(loc="best")
+plt.grid('off')
+plt.show()
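
An editorial addendum (assuming pandas is installed; not part of the example
file): the per-scorer columns produced by the grid search ``gs`` above can
also be inspected as a table::

    import pandas as pd

    df = pd.DataFrame(gs.cv_results_)
    # one mean/std/rank column per scorer name, here 'AUC' and 'Accuracy'
    cols = ['param_min_samples_split',
            'mean_test_AUC', 'rank_test_AUC',
            'mean_test_Accuracy', 'rank_test_Accuracy']
    print(df[cols].sort_values('rank_test_AUC').head())
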
diff --git a/sklearn/metrics/scorer.py b/sklearn/metrics/scorer.py
index 3a163d967c..1d16a9dcb0 100644
--- a/sklearn/metrics/scorer.py
+++ b/sklearn/metrics/scorer.py
@@ -209,12 +209,15 @@ class _ThresholdScorer(_BaseScorer):
 
 
 def get_scorer(scoring):
+    valid = True
     if isinstance(scoring, six.string_types):
         try:
             scorer = SCORERS[scoring]
         except KeyError:
             scorers = [scorer for scorer in SCORERS
                        if SCORERS[scorer]._deprecation_msg is None]
+            valid = False  # Don't raise here to make the error message elegant
+        if not valid:
             raise ValueError('%r is not a valid scoring value. '
                              'Valid options are %s'
                              % (scoring, sorted(scorers)))
@@ -253,13 +256,12 @@ def check_scoring(estimator, scoring=None, allow_none=False):
         A scorer callable object / function with signature
         ``scorer(estimator, X, y)``.
     """
-    has_scoring = scoring is not None
     if not hasattr(estimator, 'fit'):
         raise TypeError("estimator should be an estimator implementing "
                         "'fit' method, %r was passed" % estimator)
     if isinstance(scoring, six.string_types):
         return get_scorer(scoring)
-    elif has_scoring:
+    elif callable(scoring):
         # Heuristic to ensure user has not passed a metric
         module = getattr(scoring, '__module__', None)
         if hasattr(module, 'startswith') and \
@@ -272,14 +274,114 @@ def check_scoring(estimator, scoring=None, allow_none=False):
                              'Please use `make_scorer` to convert a metric '
                              'to a scorer.' % scoring)
         return get_scorer(scoring)
-    elif hasattr(estimator, 'score'):
-        return _passthrough_scorer
-    elif allow_none:
-        return None
+    elif scoring is None:
+        if hasattr(estimator, 'score'):
+            return _passthrough_scorer
+        elif allow_none:
+            return None
+        else:
+            raise TypeError(
+                "If no scoring is specified, the estimator passed should "
+                "have a 'score' method. The estimator %r does not."
+                % estimator)
     else:
-        raise TypeError(
-            "If no scoring is specified, the estimator passed should "
-            "have a 'score' method. The estimator %r does not." % estimator)
+        raise ValueError("scoring value should either be a callable, string or"
+                         " None. %r was passed" % scoring)
+
+
+def _check_multimetric_scoring(estimator, scoring=None):
+    """Check the scoring parameter in cases when multiple metrics are allowed
+
+    Parameters
+    ----------
+    estimator : sklearn estimator instance
+        The estimator for which the scoring will be applied.
+
+    scoring : string, callable, list/tuple, dict or None, default: None
+        A single string (see :ref:`scoring_parameter`) or a callable
+        (see :ref:`scoring`) to evaluate the predictions on the test set.
+
+        For evaluating multiple metrics, either give a list of (unique) strings
+        or a dict with names as keys and callables as values.
+
+        NOTE that when using custom scorers, each scorer should return a single
+        value. Metric functions returning a list/array of values can be wrapped
+        into multiple scorers that return one value each.
+
+        See :ref:`multivalued_scorer_wrapping` for an example.
+
+        If None the estimator's default scorer (if available) is used.
+        The return value in that case will be ``{'score': <default_scorer>}``.
+        If the estimator's default scorer is not available, a ``TypeError``
+        is raised.
+
+    Returns
+    -------
+    scorers_dict : dict
+        A dict mapping each scorer name to its validated scorer.
+
+    is_multimetric : bool
+        True if scorer is a list/tuple or dict of callables
+        False if scorer is None/str/callable
+    """
+    if callable(scoring) or scoring is None or isinstance(scoring,
+                                                          six.string_types):
+        scorers = {"score": check_scoring(estimator, scoring=scoring)}
+        return scorers, False
+    else:
+        err_msg_generic = ("scoring should either be a single string or "
+                           "callable for single metric evaluation or a "
+                           "list/tuple of strings or a dict of scorer name "
+                           "mapped to the callable for multiple metric "
+                           "evaluation. Got %s of type %s"
+                           % (repr(scoring), type(scoring)))
+
+        if isinstance(scoring, (list, tuple, set)):
+            err_msg = ("The list/tuple elements must be unique "
+                       "strings of predefined scorers. ")
+            invalid = False
+            try:
+                keys = set(scoring)
+            except TypeError:
+                invalid = True
+            if invalid:
+                raise ValueError(err_msg)
+
+            if len(keys) != len(scoring):
+                raise ValueError(err_msg + "Duplicate elements were found in"
+                                 " the given list. %r" % repr(scoring))
+            elif len(keys) > 0:
+                if not all(isinstance(k, six.string_types) for k in keys):
+                    if any(callable(k) for k in keys):
+                        raise ValueError(err_msg +
+                                         "One or more of the elements were "
+                                         "callables. Use a dict of score name "
+                                         "mapped to the scorer callable. "
+                                         "Got %r" % repr(scoring))
+                    else:
+                        raise ValueError(err_msg +
+                                         "Non-string types were found in "
+                                         "the given list. Got %r"
+                                         % repr(scoring))
+                scorers = {scorer: check_scoring(estimator, scoring=scorer)
+                           for scorer in scoring}
+            else:
+                raise ValueError(err_msg +
+                                 "Empty list was given. %r" % repr(scoring))
+
+        elif isinstance(scoring, dict):
+            keys = set(scoring)
+            if not all(isinstance(k, six.string_types) for k in keys):
+                raise ValueError("Non-string types were found in the keys of "
+                                 "the given dict. scoring=%r" % repr(scoring))
+            if len(keys) == 0:
+                raise ValueError("An empty dict was passed. %r"
+                                 % repr(scoring))
+            scorers = {key: check_scoring(estimator, scoring=scorer)
+                       for key, scorer in scoring.items()}
+        else:
+            raise ValueError(err_msg_generic)
+        return scorers, True
 
 
 def make_scorer(score_func, greater_is_better=True, needs_proba=False,
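
A small sketch of the new helper's contract (it is private, so this is for
reviewers' orientation only, not a public API; behaviour as defined in the
diff above)::

    from sklearn.metrics.scorer import _check_multimetric_scoring
    from sklearn.svm import LinearSVC

    est = LinearSVC(random_state=0)

    # Single metric: wrapped under the key 'score', flagged as not multimetric
    scorers, is_multi = _check_multimetric_scoring(est, scoring='accuracy')
    assert not is_multi and list(scorers) == ['score']

    # Multiple metrics: one validated scorer per name, flagged as multimetric
    scorers, is_multi = _check_multimetric_scoring(
        est, scoring=['accuracy', 'precision'])
    assert is_multi and sorted(scorers) == ['accuracy', 'precision']
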
diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py
index 461bdadf3d..47c4d334f8 100644
--- a/sklearn/metrics/tests/test_score_objects.py
+++ b/sklearn/metrics/tests/test_score_objects.py
@@ -8,9 +8,11 @@ import numpy as np
 
 from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_array_equal
+from sklearn.utils.testing import assert_equal
 from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_raises_regexp
 from sklearn.utils.testing import assert_true
+from sklearn.utils.testing import assert_false
 from sklearn.utils.testing import ignore_warnings
 from sklearn.utils.testing import assert_not_equal
 from sklearn.utils.testing import assert_warns_message
@@ -21,6 +23,8 @@ from sklearn.metrics import (f1_score, r2_score, roc_auc_score, fbeta_score,
 from sklearn.metrics import cluster as cluster_module
 from sklearn.metrics.scorer import (check_scoring, _PredictScorer,
                                     _passthrough_scorer)
+from sklearn.metrics import accuracy_score
+from sklearn.metrics.scorer import _check_multimetric_scoring
 from sklearn.metrics import make_scorer, get_scorer, SCORERS
 from sklearn.svm import LinearSVC
 from sklearn.pipeline import make_pipeline
@@ -104,18 +108,18 @@ def teardown_module():
 
 
 class EstimatorWithoutFit(object):
-    """Dummy estimator to test check_scoring"""
+    """Dummy estimator to test scoring validators"""
     pass
 
 
 class EstimatorWithFit(BaseEstimator):
-    """Dummy estimator to test check_scoring"""
+    """Dummy estimator to test scoring validators"""
     def fit(self, X, y):
         return self
 
 
 class EstimatorWithFitAndScore(object):
-    """Dummy estimator to test check_scoring"""
+    """Dummy estimator to test scoring validators"""
     def fit(self, X, y):
         return self
 
@@ -124,7 +128,7 @@ class EstimatorWithFitAndScore(object):
 
 
 class EstimatorWithFitAndPredict(object):
-    """Dummy estimator to test check_scoring"""
+    """Dummy estimator to test scoring validators"""
     def fit(self, X, y):
         self.y = y
         return self
@@ -145,16 +149,16 @@ def test_all_scorers_repr():
         repr(scorer)
 
 
-def test_check_scoring():
-    # Test all branches of check_scoring
+def check_scoring_validator_for_single_metric_usecases(scoring_validator):
+    # Test all branches of single metric usecases
     estimator = EstimatorWithoutFit()
     pattern = (r"estimator should be an estimator implementing 'fit' method,"
                r" .* was passed")
-    assert_raises_regexp(TypeError, pattern, check_scoring, estimator)
+    assert_raises_regexp(TypeError, pattern, scoring_validator, estimator)
 
     estimator = EstimatorWithFitAndScore()
     estimator.fit([[1]], [1])
-    scorer = check_scoring(estimator)
+    scorer = scoring_validator(estimator)
     assert_true(scorer is _passthrough_scorer)
     assert_almost_equal(scorer(estimator, [[1]], [1]), 1.0)
 
@@ -162,18 +166,85 @@ def test_check_scoring():
     estimator.fit([[1]], [1])
     pattern = (r"If no scoring is specified, the estimator passed should have"
                r" a 'score' method\. The estimator .* does not\.")
-    assert_raises_regexp(TypeError, pattern, check_scoring, estimator)
+    assert_raises_regexp(TypeError, pattern, scoring_validator, estimator)
 
-    scorer = check_scoring(estimator, "accuracy")
+    scorer = scoring_validator(estimator, "accuracy")
     assert_almost_equal(scorer(estimator, [[1]], [1]), 1.0)
 
     estimator = EstimatorWithFit()
-    scorer = check_scoring(estimator, "accuracy")
+    scorer = scoring_validator(estimator, "accuracy")
     assert_true(isinstance(scorer, _PredictScorer))
 
-    estimator = EstimatorWithFit()
-    scorer = check_scoring(estimator, allow_none=True)
-    assert_true(scorer is None)
+    # Test the allow_none parameter for check_scoring alone
+    if scoring_validator is check_scoring:
+        estimator = EstimatorWithFit()
+        scorer = scoring_validator(estimator, allow_none=True)
+        assert_true(scorer is None)
+
+
+def check_multimetric_scoring_single_metric_wrapper(*args, **kwargs):
+    # This wraps the _check_multimetric_scoring to take in single metric
+    # scoring parameter so we can run the tests that we will run for
+    # check_scoring, for check_multimetric_scoring too for single-metric
+    # usecases
+    scorers, is_multi = _check_multimetric_scoring(*args, **kwargs)
+    # For all single metric use cases, it should register as not multimetric
+    assert_false(is_multi)
+    if args[0] is not None:
+        assert_true(scorers is not None)
+        names, scorers = zip(*scorers.items())
+        assert_equal(len(scorers), 1)
+        assert_equal(names[0], 'score')
+        scorers = scorers[0]
+    return scorers
+
+
+def test_check_scoring_and_check_multimetric_scoring():
+    check_scoring_validator_for_single_metric_usecases(check_scoring)
+    # To make sure the check_scoring is correctly applied to the constituent
+    # scorers
+    check_scoring_validator_for_single_metric_usecases(
+        check_multimetric_scoring_single_metric_wrapper)
+
+    # For multiple metric use cases
+    # Make sure it works for the valid cases
+    for scoring in (('accuracy',), ['precision'],
+                    {'acc': 'accuracy', 'precision': 'precision'},
+                    ('accuracy', 'precision'), ['precision', 'accuracy'],
+                    {'accuracy': make_scorer(accuracy_score),
+                     'precision': make_scorer(precision_score)}):
+        estimator = LinearSVC(random_state=0)
+        estimator.fit([[1], [2], [3]], [1, 1, 0])
+
+        scorers, is_multi = _check_multimetric_scoring(estimator, scoring)
+        assert_true(is_multi)
+        assert_true(isinstance(scorers, dict))
+        assert_equal(sorted(scorers.keys()), sorted(list(scoring)))
+        assert_true(all([isinstance(scorer, _PredictScorer)
+                         for scorer in list(scorers.values())]))
+
+        if 'acc' in scoring:
+            assert_almost_equal(scorers['acc'](
+                estimator, [[1], [2], [3]], [1, 0, 0]), 2. / 3.)
+        if 'accuracy' in scoring:
+            assert_almost_equal(scorers['accuracy'](
+                estimator, [[1], [2], [3]], [1, 0, 0]), 2. / 3.)
+        if 'precision' in scoring:
+            assert_almost_equal(scorers['precision'](
+                estimator, [[1], [2], [3]], [1, 0, 0]), 0.5)
+
+    estimator = EstimatorWithFitAndPredict()
+    estimator.fit([[1]], [1])
+
+    # Make sure it raises errors when scoring parameter is not valid.
+    # More weird corner cases are tested at test_validation.py
+    error_message_regexp = ".*must be unique strings.*"
+    for scoring in ((make_scorer(precision_score),  # Tuple of callables
+                     make_scorer(accuracy_score)), [5],
+                    (make_scorer(precision_score),), (), ('f1', 'f1')):
+        assert_raises_regexp(ValueError, error_message_regexp,
+                             _check_multimetric_scoring, estimator,
+                             scoring=scoring)
 
 
 def test_check_scoring_gridsearchcv():
diff --git a/sklearn/model_selection/__init__.py b/sklearn/model_selection/__init__.py
index 73c842e706..82a9b93717 100644
--- a/sklearn/model_selection/__init__.py
+++ b/sklearn/model_selection/__init__.py
@@ -18,6 +18,7 @@ from ._split import check_cv
 
 from ._validation import cross_val_score
 from ._validation import cross_val_predict
+from ._validation import cross_validate
 from ._validation import learning_curve
 from ._validation import permutation_test_score
 from ._validation import validation_curve
@@ -50,6 +51,7 @@ __all__ = ('BaseCrossValidator',
            'check_cv',
            'cross_val_predict',
            'cross_val_score',
+           'cross_validate',
            'fit_grid_point',
            'learning_curve',
            'permutation_test_score',
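
Editorial note: with this change the new function is importable directly from
the package namespace, alongside the existing helpers::

    from sklearn.model_selection import cross_validate, cross_val_score
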
diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index 67bd8597de..17c588c293 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -9,6 +9,7 @@ from __future__ import division
 #         Gael Varoquaux <gael.varoquaux@normalesup.org>
 #         Andreas Mueller <amueller@ais.uni-bonn.de>
 #         Olivier Grisel <olivier.grisel@ensta.org>
+#         Raghav RV <rvraghav93@gmail.com>
 # License: BSD 3 clause
 
 from abc import ABCMeta, abstractmethod
@@ -25,6 +26,7 @@ from ..base import BaseEstimator, is_classifier, clone
 from ..base import MetaEstimatorMixin
 from ._split import check_cv
 from ._validation import _fit_and_score
+from ._validation import _aggregate_score_dicts
 from ..exceptions import NotFittedError
 from ..externals.joblib import Parallel, delayed
 from ..externals import six
@@ -34,6 +36,7 @@ from ..utils.fixes import MaskedArray
 from ..utils.random import sample_without_replacement
 from ..utils.validation import indexable, check_is_fitted
 from ..utils.metaestimators import if_delegate_has_method
+from ..metrics.scorer import _check_multimetric_scoring
 from ..metrics.scorer import check_scoring
 
 
@@ -295,10 +298,12 @@ def fit_grid_point(X, y, estimator, parameters, train, test, scorer,
     test : ndarray, dtype int or bool
         Boolean mask or indices for test set.
 
-    scorer : callable or None.
-        If provided must be a scorer callable object / function with signature
+    scorer : callable or None
+        The scorer callable object / function must have the signature
         ``scorer(estimator, X, y)``.
 
+        If ``None`` the estimator's default scorer is used.
+
     verbose : int
         Verbosity level.
 
@@ -314,7 +319,7 @@ def fit_grid_point(X, y, estimator, parameters, train, test, scorer,
     Returns
     -------
     score : float
-        Score of this parameter setting on given training / test split.
+         Score of this parameter setting on given training / test split.
 
     parameters : dict
         The parameters that have been evaluated.
@@ -322,12 +327,16 @@ def fit_grid_point(X, y, estimator, parameters, train, test, scorer,
     n_samples_test : int
         Number of test samples in this split.
     """
-    score, n_samples_test, _ = _fit_and_score(estimator, X, y, scorer, train,
-                                              test, verbose, parameters,
-                                              fit_params=fit_params,
-                                              return_n_test_samples=True,
-                                              error_score=error_score)
-    return score, parameters, n_samples_test
+    # NOTE the return value is not used; the scorer should already have been
+    # validated. check_scoring is used here only to reject multimetric scoring.
+    check_scoring(estimator, scorer)
+    scores, n_samples_test = _fit_and_score(estimator, X, y,
+                                            scorer, train,
+                                            test, verbose, parameters,
+                                            fit_params=fit_params,
+                                            return_n_test_samples=True,
+                                            error_score=error_score)
+    return scores, parameters, n_samples_test
 
 
 def _check_param_grid(param_grid):
@@ -419,18 +428,23 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         -------
         score : float
         """
+        self._check_is_fitted('score')
         if self.scorer_ is None:
             raise ValueError("No score function explicitly defined, "
                              "and the estimator doesn't provide one %s"
                              % self.best_estimator_)
-        return self.scorer_(self.best_estimator_, X, y)
+        score = self.scorer_[self.refit] if self.multimetric_ else self.scorer_
+        return score(self.best_estimator_, X, y)
 
     def _check_is_fitted(self, method_name):
         if not self.refit:
-            raise NotFittedError(('This GridSearchCV instance was initialized '
-                                  'with refit=False. %s is '
-                                  'available only after refitting on the best '
-                                  'parameters. ') % method_name)
+            raise NotFittedError('This %s instance was initialized '
+                                 'with refit=False. %s is '
+                                 'available only after refitting on the best '
+                                 'parameters. You can refit an estimator '
+                                 'manually using the ``best_params_`` '
+                                 'attribute'
+                                 % (type(self).__name__, method_name))
         else:
             check_is_fitted(self, 'best_estimator_')
 
@@ -575,7 +589,27 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
                 fit_params = self.fit_params
         estimator = self.estimator
         cv = check_cv(self.cv, y, classifier=is_classifier(estimator))
-        self.scorer_ = check_scoring(self.estimator, scoring=self.scoring)
+
+        scorers, self.multimetric_ = _check_multimetric_scoring(
+            self.estimator, scoring=self.scoring)
+
+        if self.multimetric_:
+            if self.refit is not False and (
+                    not isinstance(self.refit, six.string_types) or
+                    # This will work for both dict / list (tuple)
+                    self.refit not in scorers):
+                raise ValueError("For multi-metric scoring, the parameter "
+                                 "refit must be set to a scorer key "
+                                 "to refit an estimator with the best "
+                                 "parameter setting on the whole data and "
+                                 "make the best_* attributes "
+                                 "available for that metric. If this is not "
+                                 "needed, refit should be set to False "
+                                 "explicitly. %r was passed." % self.refit)
+            else:
+                refit_metric = self.refit
+        else:
+            refit_metric = 'score'
 
         X, y, groups = indexable(X, y, groups)
         n_splits = cv.get_n_splits(X, y, groups)
@@ -593,8 +627,8 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         out = Parallel(
             n_jobs=self.n_jobs, verbose=self.verbose,
             pre_dispatch=pre_dispatch
-        )(delayed(_fit_and_score)(clone(base_estimator), X, y, self.scorer_,
-                                  train, test, self.verbose, parameters,
+        )(delayed(_fit_and_score)(clone(base_estimator), X, y, scorers, train,
+                                  test, self.verbose, parameters,
                                   fit_params=fit_params,
                                   return_train_score=self.return_train_score,
                                   return_n_test_samples=True,
@@ -605,20 +639,29 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
 
         # if one choose to see train score, "out" will contain train score info
         if self.return_train_score:
-            (train_scores, test_scores, test_sample_counts, fit_time,
+            (train_score_dicts, test_score_dicts, test_sample_counts, fit_time,
              score_time) = zip(*out)
         else:
-            (test_scores, test_sample_counts, fit_time, score_time) = zip(*out)
+            (test_score_dicts, test_sample_counts, fit_time,
+             score_time) = zip(*out)
+
+        # test_score_dicts and train_score_dicts are lists of dictionaries
+        # and we make them into a dict of lists
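+        # e.g. [{'acc': 0.8}, {'acc': 0.9}] becomes something like
+        # {'acc': [0.8, 0.9]} (one entry per split, keyed by scorer name)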
+        test_scores = _aggregate_score_dicts(test_score_dicts)
+        if self.return_train_score:
+            train_scores = _aggregate_score_dicts(train_score_dicts)
 
         results = dict()
 
         def _store(key_name, array, weights=None, splits=False, rank=False):
             """A small helper to store the scores/times to the cv_results_"""
             # When iterated first by splits, then by parameters
+            # We want `array` to have `n_candidates` rows and `n_splits` cols.
             array = np.array(array, dtype=np.float64).reshape(n_candidates,
                                                               n_splits)
             if splits:
                 for split_i in range(n_splits):
+                    # Uses closure to alter the results
                     results["split%d_%s"
                             % (split_i, key_name)] = array[:, split_i]
 
@@ -634,21 +677,8 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
                 results["rank_%s" % key_name] = np.asarray(
                     rankdata(-array_means, method='min'), dtype=np.int32)
 
-        # Computed the (weighted) mean and std for test scores alone
-        # NOTE test_sample counts (weights) remain the same for all candidates
-        test_sample_counts = np.array(test_sample_counts[:n_splits],
-                                      dtype=np.int)
-
-        _store('test_score', test_scores, splits=True, rank=True,
-               weights=test_sample_counts if self.iid else None)
-        if self.return_train_score:
-            _store('train_score', train_scores, splits=True)
         _store('fit_time', fit_time)
         _store('score_time', score_time)
-
-        best_index = np.flatnonzero(results["rank_test_score"] == 1)[0]
-        best_parameters = candidate_params[best_index]
-
         # Use one MaskedArray and mask all the places where the param is not
         # applicable for that candidate. Use defaultdict as each candidate may
         # not contain all the params
@@ -664,45 +694,58 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
                 param_results["param_%s" % name][cand_i] = value
 
         results.update(param_results)
-
         # Store a list of param dicts at the key 'params'
         results['params'] = candidate_params
 
-        self.cv_results_ = results
-        self.best_index_ = best_index
-        self.n_splits_ = n_splits
+        # NOTE test_sample counts (weights) remain the same for all candidates
+        test_sample_counts = np.array(test_sample_counts[:n_splits],
+                                      dtype=np.int)
+        for scorer_name in scorers.keys():
+            # Computed the (weighted) mean and std for test scores alone
+            _store('test_%s' % scorer_name, test_scores[scorer_name],
+                   splits=True, rank=True,
+                   weights=test_sample_counts if self.iid else None)
+            if self.return_train_score:
+                _store('train_%s' % scorer_name, train_scores[scorer_name],
+                       splits=True)
+
+        # For multi-metric evaluation, store the best_index_, best_params_ and
+        # best_score_ iff refit is one of the scorer names
+        # In single metric evaluation, refit_metric is "score"
+        if self.refit or not self.multimetric_:
+            self.best_index_ = results["rank_test_%s" % refit_metric].argmin()
+            self.best_params_ = candidate_params[self.best_index_]
+            self.best_score_ = results["mean_test_%s" % refit_metric][
+                self.best_index_]
 
         if self.refit:
-            # fit the best estimator using the entire dataset
-            # clone first to work around broken estimators
-            best_estimator = clone(base_estimator).set_params(
-                **best_parameters)
+            self.best_estimator_ = clone(base_estimator).set_params(
+                **self.best_params_)
             if y is not None:
-                best_estimator.fit(X, y, **fit_params)
+                self.best_estimator_.fit(X, y, **fit_params)
             else:
-                best_estimator.fit(X, **fit_params)
-            self.best_estimator_ = best_estimator
-        return self
+                self.best_estimator_.fit(X, **fit_params)
 
-    @property
-    def best_params_(self):
-        check_is_fitted(self, 'cv_results_')
-        return self.cv_results_['params'][self.best_index_]
+        # Store the only scorer not as a dict for single metric evaluation
+        self.scorer_ = scorers if self.multimetric_ else scorers['score']
 
-    @property
-    def best_score_(self):
-        check_is_fitted(self, 'cv_results_')
-        return self.cv_results_['mean_test_score'][self.best_index_]
+        self.cv_results_ = results
+        self.n_splits_ = n_splits
+
+        return self
 
     @property
     def grid_scores_(self):
+        check_is_fitted(self, 'cv_results_')
+        if self.multimetric_:
+            raise AttributeError("grid_scores_ attribute is not available for"
+                                 " multi-metric evaluation.")
         warnings.warn(
             "The grid_scores_ attribute was deprecated in version 0.18"
             " in favor of the more elaborate cv_results_ attribute."
             " The grid_scores_ attribute will not be available from 0.20",
             DeprecationWarning)
 
-        check_is_fitted(self, 'cv_results_')
         grid_scores = list()
 
         for i, (params, mean, std) in enumerate(zip(
@@ -747,11 +790,20 @@ class GridSearchCV(BaseSearchCV):
         in the list are explored. This enables searching over any sequence
         of parameter settings.
 
-    scoring : string, callable or None, default=None
-        A string (see model evaluation documentation) or
-        a scorer callable object / function with signature
-        ``scorer(estimator, X, y)``.
-        If ``None``, the ``score`` method of the estimator is used.
+    scoring : string, callable, list/tuple, dict or None, default: None
+        A single string (see :ref:`scoring_parameter`) or a callable
+        (see :ref:`scoring`) to evaluate the predictions on the test set.
+
+        For evaluating multiple metrics, either give a list of (unique) strings
+        or a dict with names as keys and callables as values.
+
+        NOTE that when using custom scorers, each scorer should return a single
+        value. Metric functions returning a list/array of values can be wrapped
+        into multiple scorers that return one value each.
+
+        See :ref:`multivalued_scorer_wrapping` for an example.
+
+        If None, the estimator's default scorer (if available) is used.
 
     fit_params : dict, optional
         Parameters to pass to the fit method.
@@ -801,10 +853,25 @@ class GridSearchCV(BaseSearchCV):
         Refer :ref:`User Guide <cross_validation>` for the various
         cross-validation strategies that can be used here.
 
-    refit : boolean, default=True
-        Refit the best estimator with the entire dataset.
-        If "False", it is impossible to make predictions using
-        this GridSearchCV instance after fitting.
+    refit : boolean, or string, default=True
+        Refit an estimator using the best found parameters on the whole
+        dataset.
+
+        For multiple metric evaluation, this needs to be a string denoting the
+        scorer that is used to find the best parameters for refitting the
+        estimator at the end.
+
+        The refitted estimator is made available at the ``best_estimator_``
+        attribute and permits using ``predict`` directly on this
+        ``GridSearchCV`` instance.
+
+        Also for multiple metric evaluation, the attributes ``best_index_``,
+        ``best_score_`` and ``best_params_`` will only be available if
+        ``refit`` is set and all of them will be determined w.r.t. this
+        specific scorer.
+
+        See ``scoring`` parameter to know more about multiple metric
+        evaluation.
 
     verbose : integer
         Controls the verbosity: the higher, the more messages.
@@ -857,7 +924,7 @@ class GridSearchCV(BaseSearchCV):
         For instance the below given table
 
         +------------+-----------+------------+-----------------+---+---------+
-        |param_kernel|param_gamma|param_degree|split0_test_score|...|rank_....|
+        |param_kernel|param_gamma|param_degree|split0_test_score|...|..rank...|
         +============+===========+============+=================+===+=========+
         |  'poly'    |     --    |      2     |        0.8      |...|    2    |
         +------------+-----------+------------+-----------------+---+---------+
@@ -893,23 +960,38 @@ class GridSearchCV(BaseSearchCV):
             'params'             : [{'kernel': 'poly', 'degree': 2}, ...],
             }
 
-        NOTE that the key ``'params'`` is used to store a list of parameter
-        settings dict for all the parameter candidates.
+        NOTE
+
+        The key ``'params'`` is used to store a list of parameter
+        settings dicts for all the parameter candidates.
 
         The ``mean_fit_time``, ``std_fit_time``, ``mean_score_time`` and
         ``std_score_time`` are all in seconds.
 
-    best_estimator_ : estimator
+        For multi-metric evaluation, the scores for all the scorers are
+        available in the ``cv_results_`` dict at the keys ending with that
+        scorer's name (``'_<scorer_name>'``) instead of ``'_score'`` shown
+        above. ('split0_test_precision', 'mean_train_precision' etc.)
+
+    best_estimator_ : estimator
         Estimator that was chosen by the search, i.e. estimator
         which gave highest score (or smallest loss if specified)
-        on the left out data. Not available if refit=False.
+        on the left out data. Not available if ``refit=False``.
+
+        See ``refit`` parameter for more information on allowed values.
 
     best_score_ : float
-        Score of best_estimator on the left out data.
+        Mean cross-validated score of the best_estimator.
+
+        For multi-metric evaluation, this is present only if ``refit`` is
+        specified.
 
     best_params_ : dict
         Parameter setting that gave the best results on the hold out data.
 
+        For multi-metric evaluation, this is present only if ``refit`` is
+        specified.
+
     best_index_ : int
         The index (of the ``cv_results_`` arrays) which corresponds to the best
         candidate parameter setting.
@@ -918,10 +1000,16 @@ class GridSearchCV(BaseSearchCV):
         the parameter setting for the best model, that gives the highest
         mean score (``search.best_score_``).
 
-    scorer_ : function
+        For multi-metric evaluation, this is present only if ``refit`` is
+        specified.
+
+    scorer_ : function or a dict
         Scorer function used on the held out data to choose the best
         parameters for the model.
 
+        For multi-metric evaluation, this attribute holds the validated
+        ``scoring`` dict which maps the scorer key to the scorer callable.
+
     n_splits_ : int
         The number of cross-validation splits (folds/iterations).
 
@@ -1012,11 +1100,20 @@ class RandomizedSearchCV(BaseSearchCV):
         Number of parameter settings that are sampled. n_iter trades
         off runtime vs quality of the solution.
 
-    scoring : string, callable or None, default=None
-        A string (see model evaluation documentation) or
-        a scorer callable object / function with signature
-        ``scorer(estimator, X, y)``.
-        If ``None``, the ``score`` method of the estimator is used.
+    scoring : string, callable, list/tuple, dict or None, default: None
+        A single string (see :ref:`scoring_parameter`) or a callable
+        (see :ref:`scoring`) to evaluate the predictions on the test set.
+
+        For evaluating multiple metrics, either give a list of (unique) strings
+        or a dict with names as keys and callables as values.
+
+        NOTE that when using custom scorers, each scorer should return a single
+        value. Metric functions returning a list/array of values can be wrapped
+        into multiple scorers that return one value each.
+
+        See :ref:`multivalued_scorer_wrapping` for an example.
+
+        If None, the estimator's default scorer (if available) is used.
 
     fit_params : dict, optional
         Parameters to pass to the fit method.
@@ -1066,10 +1163,25 @@ class RandomizedSearchCV(BaseSearchCV):
         Refer :ref:`User Guide <cross_validation>` for the various
         cross-validation strategies that can be used here.
 
-    refit : boolean, default=True
-        Refit the best estimator with the entire dataset.
-        If "False", it is impossible to make predictions using
-        this RandomizedSearchCV instance after fitting.
+    refit : boolean, or string, default=True
+        Refit an estimator using the best found parameters on the whole
+        dataset.
+
+        For multiple metric evaluation, this needs to be a string denoting the
+        scorer that would be used to find the best parameters for refitting
+        the estimator at the end.
+
+        The refitted estimator is made available at the ``best_estimator_``
+        attribute and permits using ``predict`` directly on this
+        ``RandomizedSearchCV`` instance.
+
+        Also for multiple metric evaluation, the attributes ``best_index_``,
+        ``best_score_`` and ``best_params_`` will only be available if
+        ``refit`` is set and all of them will be determined w.r.t. this
+        specific scorer.
+
+        See ``scoring`` parameter to know more about multiple metric
+        evaluation.
 
     verbose : integer
         Controls the verbosity: the higher, the more messages.
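An illustrative sketch of how a string ``refit`` ties the ``best_*`` attributes to one of several scorers (data, search space and values are assumptions, not part of the patch):

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.model_selection import RandomizedSearchCV
    from sklearn.svm import SVC

    X, y = make_classification(random_state=0)
    search = RandomizedSearchCV(SVC(), {'C': [0.1, 1, 10]}, n_iter=3,
                                scoring=('accuracy', 'recall'),
                                refit='accuracy', random_state=0)
    search.fit(X, y)
    # best_index_, best_score_ and best_params_ all refer to 'accuracy'.
    best = search.cv_results_['mean_test_accuracy'][search.best_index_]
    assert np.isclose(search.best_score_, best)
    search.predict(X)  # available because the best estimator was refit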
@@ -1129,26 +1241,44 @@ class RandomizedSearchCV(BaseSearchCV):
             'std_fit_time'       : [0.01, 0.02, 0.01, 0.01],
             'mean_score_time'    : [0.007, 0.06, 0.04, 0.04],
             'std_score_time'     : [0.001, 0.002, 0.003, 0.005],
-            'params' : [{'kernel' : 'rbf', 'gamma' : 0.1}, ...],
+            'params'             : [{'kernel' : 'rbf', 'gamma' : 0.1}, ...],
             }
 
-        NOTE that the key ``'params'`` is used to store a list of parameter
-        settings dict for all the parameter candidates.
+        NOTE
+
+        The key ``'params'`` is used to store a list of parameter
+        settings dicts for all the parameter candidates.
 
         The ``mean_fit_time``, ``std_fit_time``, ``mean_score_time`` and
         ``std_score_time`` are all in seconds.
 
-    best_estimator_ : estimator
+        For multi-metric evaluation, the scores for all the scorers are
+        available in the ``cv_results_`` dict at the keys ending with that
+        scorer's name (``'_<scorer_name>'``) instead of ``'_score'`` shown
+        above (e.g. ``'split0_test_precision'``, ``'mean_train_precision'``).
+
+    best_estimator_ : estimator
         Estimator that was chosen by the search, i.e. estimator
         which gave highest score (or smallest loss if specified)
-        on the left out data. Not available if refit=False.
+        on the left out data. Not available if ``refit=False``.
+
+        For multi-metric evaluation, this attribute is present only if
+        ``refit`` is specified.
+
+        See ``refit`` parameter for more information on allowed values.
 
     best_score_ : float
-        Score of best_estimator on the left out data.
+        Mean cross-validated score of the best_estimator.
+
+        For multi-metric evaluation, this is not available if ``refit`` is
+        ``False``. See ``refit`` parameter for more information.
 
     best_params_ : dict
         Parameter setting that gave the best results on the hold out data.
 
+        For multi-metric evaluation, this is not available if ``refit`` is
+        ``False``. See ``refit`` parameter for more information.
+
     best_index_ : int
         The index (of the ``cv_results_`` arrays) which corresponds to the best
         candidate parameter setting.
@@ -1157,10 +1287,16 @@ class RandomizedSearchCV(BaseSearchCV):
         the parameter setting for the best model, that gives the highest
         mean score (``search.best_score_``).
 
-    scorer_ : function
+        For multi-metric evaluation, this is not available if ``refit`` is
+        ``False``. See ``refit`` parameter for more information.
+
+    scorer_ : function or a dict
         Scorer function used on the held out data to choose the best
         parameters for the model.
 
+        For multi-metric evaluation, this attribute holds the validated
+        ``scoring`` dict which maps the scorer key to the scorer callable.
+
     n_splits_ : int
         The number of cross-validation splits (folds/iterations).
 
diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py
index fe9c0e8c46..1e5ea29740 100644
--- a/sklearn/model_selection/_validation.py
+++ b/sklearn/model_selection/_validation.py
@@ -3,12 +3,12 @@ The :mod:`sklearn.model_selection._validation` module includes classes and
 functions to validate the model.
 """
 
-# Author: Alexandre Gramfort <alexandre.gramfort@inria.fr>,
-#         Gael Varoquaux <gael.varoquaux@normalesup.org>,
+# Author: Alexandre Gramfort <alexandre.gramfort@inria.fr>
+#         Gael Varoquaux <gael.varoquaux@normalesup.org>
 #         Olivier Grisel <olivier.grisel@ensta.org>
+#         Raghav RV <rvraghav93@gmail.com>
 # License: BSD 3 clause
 
-
 from __future__ import print_function
 from __future__ import division
 
@@ -24,13 +24,193 @@ from ..utils import indexable, check_random_state, safe_indexing
 from ..utils.validation import _is_arraylike, _num_samples
 from ..utils.metaestimators import _safe_split
 from ..externals.joblib import Parallel, delayed, logger
-from ..metrics.scorer import check_scoring
+from ..externals.six.moves import zip
+from ..metrics.scorer import check_scoring, _check_multimetric_scoring
 from ..exceptions import FitFailedWarning
 from ._split import check_cv
 from ..preprocessing import LabelEncoder
 
-__all__ = ['cross_val_score', 'cross_val_predict', 'permutation_test_score',
-           'learning_curve', 'validation_curve']
+
+__all__ = ['cross_validate', 'cross_val_score', 'cross_val_predict',
+           'permutation_test_score', 'learning_curve', 'validation_curve']
+
+
+def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None,
+                   n_jobs=1, verbose=0, fit_params=None,
+                   pre_dispatch='2*n_jobs', return_train_score=True):
+    """Evaluate metric(s) by cross-validation and also record fit/score times.
+
+    Read more in the :ref:`User Guide <multimetric_cross_validation>`.
+
+    Parameters
+    ----------
+    estimator : estimator object implementing 'fit'
+        The object to use to fit the data.
+
+    X : array-like
+        The data to fit. Can be, for example, a list or an array.
+
+    y : array-like, optional, default: None
+        The target variable to try to predict in the case of
+        supervised learning.
+
+    groups : array-like, with shape (n_samples,), optional
+        Group labels for the samples used while splitting the dataset into
+        train/test set.
+
+    scoring : string, callable, list/tuple, dict or None, default: None
+        A single string (see :ref:`scoring_parameter`) or a callable
+        (see :ref:`scoring`) to evaluate the predictions on the test set.
+
+        For evaluating multiple metrics, either give a list of (unique) strings
+        or a dict with names as keys and callables as values.
+
+        NOTE that when using custom scorers, each scorer should return a single
+        value. Metric functions returning a list/array of values can be wrapped
+        into multiple scorers that return one value each.
+
+        See :ref:`multivalued_scorer_wrapping` for an example.
+
+        If None, the estimator's default scorer (if available) is used.
+
+    cv : int, cross-validation generator or an iterable, optional
+        Determines the cross-validation splitting strategy.
+        Possible inputs for cv are:
+          - None, to use the default 3-fold cross validation,
+          - integer, to specify the number of folds in a `(Stratified)KFold`,
+          - An object to be used as a cross-validation generator.
+          - An iterable yielding train, test splits.
+
+        For integer/None inputs, if the estimator is a classifier and ``y`` is
+        either binary or multiclass, :class:`StratifiedKFold` is used. In all
+        other cases, :class:`KFold` is used.
+
+        Refer :ref:`User Guide <cross_validation>` for the various
+        cross-validation strategies that can be used here.
+
+    n_jobs : integer, optional
+        The number of CPUs to use to do the computation. -1 means
+        'all CPUs'.
+
+    verbose : integer, optional
+        The verbosity level.
+
+    fit_params : dict, optional
+        Parameters to pass to the fit method of the estimator.
+
+    pre_dispatch : int, or string, optional
+        Controls the number of jobs that get dispatched during parallel
+        execution. Reducing this number can be useful to avoid an
+        explosion of memory consumption when more jobs get dispatched
+        than CPUs can process. This parameter can be:
+
+            - None, in which case all the jobs are immediately
+              created and spawned. Use this for lightweight and
+              fast-running jobs, to avoid delays due to on-demand
+              spawning of the jobs
+
+            - An int, giving the exact number of total jobs that are
+              spawned
+
+            - A string, giving an expression as a function of n_jobs,
+              as in '2*n_jobs'
+
+    return_train_score : boolean, default True
+        Whether to include train scores in the returned dict, in addition to
+        the test scores.
+
+    Returns
+    -------
+    scores : dict of float arrays of shape=(n_splits,)
+        Array of scores of the estimator for each run of the cross validation.
+
+        A dict of arrays containing the score/time arrays for each scorer is
+        returned. The possible keys for this ``dict`` are:
+
+            ``test_score``
+                The score array for test scores on each cv split.
+            ``train_score``
+                The score array for train scores on each cv split.
+                This is available only if ``return_train_score`` parameter
+                is ``True``.
+            ``fit_time``
+                The time for fitting the estimator on the train
+                set for each cv split.
+            ``score_time``
+                The time for scoring the estimator on the test set for each
+                cv split. (Note time for scoring on the train set is not
+                included even if ``return_train_score`` is set to ``True``.)
+
+    Examples
+    --------
+    >>> from sklearn import datasets, linear_model
+    >>> from sklearn.model_selection import cross_validate
+    >>> from sklearn.metrics.scorer import make_scorer
+    >>> from sklearn.metrics import confusion_matrix
+    >>> from sklearn.svm import LinearSVC
+    >>> diabetes = datasets.load_diabetes()
+    >>> X = diabetes.data[:150]
+    >>> y = diabetes.target[:150]
+    >>> lasso = linear_model.Lasso()
+
+    # Single metric evaluation using cross_validate
+    >>> cv_results = cross_validate(lasso, X, y, return_train_score=False)
+    >>> sorted(cv_results.keys())                         # doctest: +ELLIPSIS
+    ['fit_time', 'score_time', 'test_score']
+    >>> cv_results['test_score']    # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+    array([ 0.33...,  0.08...,  0.03...])
+
+    # Multiple metric evaluation using cross_validate
+    # (Please refer the ``scoring`` parameter doc for more information)
+    >>> scores = cross_validate(lasso, X, y,
+    ...                         scoring=('r2', 'neg_mean_squared_error'))
+    >>> print(scores['test_neg_mean_squared_error'])      # doctest: +ELLIPSIS
+    [-3635.5... -3573.3... -6114.7...]
+    >>> print(scores['train_r2'])                         # doctest: +ELLIPSIS
+    [ 0.28...  0.39...  0.22...]
+
+    See Also
+    ---------
+    :func:`sklearn.model_selection.cross_val_score`:
+        Run cross-validation for single metric evaluation.
+
+    :func:`sklearn.metrics.make_scorer`:
+        Make a scorer from a performance metric or loss function.
+
+    """
+    X, y, groups = indexable(X, y, groups)
+
+    cv = check_cv(cv, y, classifier=is_classifier(estimator))
+    scorers, _ = _check_multimetric_scoring(estimator, scoring=scoring)
+
+    # We clone the estimator to make sure that all the folds are
+    # independent, and that it is pickle-able.
+    parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
+                        pre_dispatch=pre_dispatch)
+    scores = parallel(
+        delayed(_fit_and_score)(
+            clone(estimator), X, y, scorers, train, test, verbose, None,
+            fit_params, return_train_score=return_train_score,
+            return_times=True)
+        for train, test in cv.split(X, y, groups))
+
+    if return_train_score:
+        train_scores, test_scores, fit_times, score_times = zip(*scores)
+        train_scores = _aggregate_score_dicts(train_scores)
+    else:
+        test_scores, fit_times, score_times = zip(*scores)
+    test_scores = _aggregate_score_dicts(test_scores)
+
+    ret = dict()
+    ret['fit_time'] = np.array(fit_times)
+    ret['score_time'] = np.array(score_times)
+
+    for name in scorers:
+        ret['test_%s' % name] = np.array(test_scores[name])
+        if return_train_score:
+            ret['train_%s' % name] = np.array(train_scores[name])
+
+    return ret
 
 
 def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
@@ -46,7 +226,7 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
         The object to use to fit the data.
 
     X : array-like
-        The data to fit. Can be, for example a list, or an array at least 2d.
+        The data to fit. Can be, for example, a list or an array.
 
     y : array-like, optional, default: None
         The target variable to try to predict in the case of
@@ -122,23 +302,24 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
 
     See Also
     ---------
+    :func:`sklearn.model_selection.cross_validate`:
+        To run cross-validation on multiple metrics and also to return
+        train scores, fit times and score times.
+
     :func:`sklearn.metrics.make_scorer`:
         Make a scorer from a performance metric or loss function.
 
     """
-    X, y, groups = indexable(X, y, groups)
-
-    cv = check_cv(cv, y, classifier=is_classifier(estimator))
+    # check_scoring rejects multi-metric scoring, which is not supported here
     scorer = check_scoring(estimator, scoring=scoring)
-    # We clone the estimator to make sure that all the folds are
-    # independent, and that it is pickle-able.
-    parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
-                        pre_dispatch=pre_dispatch)
-    scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y, scorer,
-                                              train, test, verbose, None,
-                                              fit_params)
-                      for train, test in cv.split(X, y, groups))
-    return np.array(scores)[:, 0]
+
+    cv_results = cross_validate(estimator=estimator, X=X, y=y, groups=groups,
+                                scoring={'score': scorer}, cv=cv,
+                                return_train_score=False,
+                                n_jobs=n_jobs, verbose=verbose,
+                                fit_params=fit_params,
+                                pre_dispatch=pre_dispatch)
+    return cv_results['test_score']
 
 
 def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
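For orientation, an illustrative comparison of the two public entry points after this refactor (the dataset and metric names are assumptions):

    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn.model_selection import cross_val_score, cross_validate
    from sklearn.svm import SVC

    X, y = load_iris(return_X_y=True)
    # cross_val_score keeps its old contract: one metric in, one array out.
    scores = cross_val_score(SVC(), X, y, scoring='accuracy')
    # cross_validate returns a dict and accepts several metrics at once.
    results = cross_validate(SVC(), X, y, scoring=['accuracy', 'f1_macro'],
                             return_train_score=False)
    assert np.allclose(scores, results['test_accuracy'])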
@@ -159,8 +340,14 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
         The target variable to try to predict in the case of
         supervised learning.
 
-    scorer : callable
-        A scorer callable object / function with signature
+    scorer : A single callable or dict mapping scorer name to the callable
+        If it is a single callable, the return value for ``train_scores`` and
+        ``test_scores`` is a single float.
+
+        If it is a dict, it should map each scorer name to a scorer callable
+        object / function.
+
+        Each callable object / function should have the signature
         ``scorer(estimator, X, y)``.
 
     train : array-like, shape (n_train_samples,)
@@ -190,13 +377,20 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
     return_parameters : boolean, optional, default: False
         Return parameters that has been used for the estimator.
 
+    return_n_test_samples : boolean, optional, default: False
+        Whether to return the ``n_test_samples``.
+
+    return_times : boolean, optional, default: False
+        Whether to return the fit/score times.
+
     Returns
     -------
-    train_score : float, optional
-        Score on training set, returned only if `return_train_score` is `True`.
+    train_scores : dict of scorer name -> float, optional
+        Score on training set (for all the scorers),
+        returned only if `return_train_score` is `True`.
 
-    test_score : float
-        Score on test set.
+    test_scores : dict of scorer name -> float
+        Score on the test set (for all the scorers).
 
     n_test_samples : int
         Number of test samples.
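To make the return layout concrete, a sketch of the list that ``_fit_and_score`` is expected to return for a dict of two scorers when all the ``return_*`` flags are enabled; every value and scorer name below is made up for illustration:

    ret = [
        {'acc': 0.97, 'prec': 0.95},  # train_scores (if return_train_score)
        {'acc': 0.92, 'prec': 0.90},  # test_scores
        33,                           # n_test_samples (if return_n_test_samples)
        0.004,                        # fit_time (if return_times)
        0.001,                        # score_time (if return_times)
        {'C': 1},                     # parameters (if return_parameters)
    ]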
@@ -223,6 +417,8 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
     fit_params = dict([(k, _index_param_value(X, v, train))
                       for k, v in fit_params.items()])
 
+    test_scores = {}
+    train_scores = {}
     if parameters is not None:
         estimator.set_params(**parameters)
 
@@ -231,6 +427,9 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
     X_train, y_train = _safe_split(estimator, X, y, train)
     X_test, y_test = _safe_split(estimator, X, y, test, train)
 
+    is_multimetric = not callable(scorer)
+    n_scorers = len(scorer.keys()) if is_multimetric else 1
+
     try:
         if y_train is None:
             estimator.fit(X_train, **fit_params)
@@ -244,9 +443,16 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
         if error_score == 'raise':
             raise
         elif isinstance(error_score, numbers.Number):
-            test_score = error_score
-            if return_train_score:
-                train_score = error_score
+            if is_multimetric:
+                test_scores = dict(zip(scorer.keys(),
+                                   [error_score, ] * n_scorers))
+                if return_train_score:
+                    train_scores = dict(zip(scorer.keys(),
+                                        [error_score, ] * n_scorers))
+            else:
+                test_scores = error_score
+                if return_train_score:
+                    train_scores = error_score
             warnings.warn("Classifier fit failed. The score on this train-test"
                           " partition for these parameters will be set to %f. "
                           "Details: \n%r" % (error_score, e), FitFailedWarning)
@@ -257,19 +463,25 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
 
     else:
         fit_time = time.time() - start_time
-        test_score = _score(estimator, X_test, y_test, scorer)
+        # _score will return dict if is_multimetric is True
+        test_scores = _score(estimator, X_test, y_test, scorer, is_multimetric)
         score_time = time.time() - start_time - fit_time
         if return_train_score:
-            train_score = _score(estimator, X_train, y_train, scorer)
+            train_scores = _score(estimator, X_train, y_train, scorer,
+                                  is_multimetric)
 
     if verbose > 2:
-        msg += ", score=%f" % test_score
+        if is_multimetric:
+            for scorer_name, score in test_scores.items():
+                msg += ", %s=%s" % (scorer_name, score)
+        else:
+            msg += ", score=%s" % test_scores
     if verbose > 1:
         total_time = score_time + fit_time
         end_msg = "%s, total=%s" % (msg, logger.short_format_time(total_time))
         print("[CV] %s %s" % ((64 - len(end_msg)) * '.', end_msg))
 
-    ret = [train_score, test_score] if return_train_score else [test_score]
+    ret = [train_scores, test_scores] if return_train_score else [test_scores]
 
     if return_n_test_samples:
         ret.append(_num_samples(X_test))
@@ -280,25 +492,61 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
     return ret
 
 
-def _score(estimator, X_test, y_test, scorer):
-    """Compute the score of an estimator on a given test set."""
-    if y_test is None:
-        score = scorer(estimator, X_test)
+def _score(estimator, X_test, y_test, scorer, is_multimetric=False):
+    """Compute the score(s) of an estimator on a given test set.
+
+    Will return a single float if is_multimetric is False and a dict of
+    floats if is_multimetric is True.
+    """
+    if is_multimetric:
+        return _multimetric_score(estimator, X_test, y_test, scorer)
     else:
-        score = scorer(estimator, X_test, y_test)
-    if hasattr(score, 'item'):
-        try:
-            # e.g. unwrap memmapped scalars
-            score = score.item()
-        except ValueError:
-            # non-scalar?
-            pass
-    if not isinstance(score, numbers.Number):
-        raise ValueError("scoring must return a number, got %s (%s) instead."
-                         % (str(score), type(score)))
+        if y_test is None:
+            score = scorer(estimator, X_test)
+        else:
+            score = scorer(estimator, X_test, y_test)
+
+        if hasattr(score, 'item'):
+            try:
+                # e.g. unwrap memmapped scalars
+                score = score.item()
+            except ValueError:
+                # non-scalar?
+                pass
+
+        if not isinstance(score, numbers.Number):
+            raise ValueError("scoring must return a number, got %s (%s) "
+                             "instead. (scorer=%r)"
+                             % (str(score), type(score), scorer))
     return score
 
 
+def _multimetric_score(estimator, X_test, y_test, scorers):
+    """Return a dict of score for multimetric scoring"""
+    scores = {}
+
+    for name, scorer in scorers.items():
+        if y_test is None:
+            score = scorer(estimator, X_test)
+        else:
+            score = scorer(estimator, X_test, y_test)
+
+        if hasattr(score, 'item'):
+            try:
+                # e.g. unwrap memmapped scalars
+                score = score.item()
+            except ValueError:
+                # non-scalar?
+                pass
+        scores[name] = score
+
+        if not isinstance(score, numbers.Number):
+            raise ValueError("scoring must return a number, got %s (%s) "
+                             "instead. (scorer=%s)"
+                             % (str(score), type(score), name))
+    return scores
+
+
 def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
                       verbose=0, fit_params=None, pre_dispatch='2*n_jobs',
                       method='predict'):
@@ -555,9 +803,10 @@ def permutation_test_score(estimator, X, y, groups=None, cv=None,
         the dataset into train/test set.
 
     scoring : string, callable or None, optional, default: None
-        A string (see model evaluation documentation) or
-        a scorer callable object / function with signature
-        ``scorer(estimator, X, y)``.
+        A single string (see :ref:`scoring_parameter`) or a callable
+        (see :ref:`scoring`) to evaluate the predictions on the test set.
+
+        If None, the estimator's default scorer (if available) is used.
 
     cv : int, cross-validation generator or an iterable, optional
         Determines the cross-validation splitting strategy.
@@ -997,10 +1246,38 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
         parameters={param_name: v}, fit_params=None, return_train_score=True)
         # NOTE do not change order of iteration to allow one time cv splitters
         for train, test in cv.split(X, y, groups) for v in param_range)
-
     out = np.asarray(out)
     n_params = len(param_range)
     n_cv_folds = out.shape[0] // n_params
     out = out.reshape(n_cv_folds, n_params, 2).transpose((2, 1, 0))
 
     return out[0], out[1]
+
+
+def _aggregate_score_dicts(scores):
+    """Aggregate the list of dict to dict of np ndarray
+
+    The aggregated output of _fit_and_score will be a list of dict
+    of form [{'prec': 0.1, 'acc':1.0}, {'prec': 0.1, 'acc':1.0}, ...]
+    Convert it to a dict of array {'prec': np.array([0.1 ...]), ...}
+
+    Parameters
+    ----------
+
+    scores : list of dict
+        List of dicts of the scores for all scorers. This is a flat list,
+        assumed originally to be of row major order.
+
+    Example
+    -------
+
+    >>> scores = [{'a': 1, 'b':10}, {'a': 2, 'b':2}, {'a': 3, 'b':3},
+    ...           {'a': 10, 'b': 10}]                         # doctest: +SKIP
+    >>> _aggregate_score_dicts(scores)                        # doctest: +SKIP
+    {'a': array([1, 2, 3, 10]),
+     'b': array([10, 2, 3, 10])}
+    """
+    out = {}
+    for key in scores[0]:
+        out[key] = np.asarray([score[key] for score in scores])
+    return out
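Since the doctest above is skipped, a runnable equivalent of the same aggregation written against plain numpy, independent of the private helper:

    import numpy as np

    scores = [{'a': 1, 'b': 10}, {'a': 2, 'b': 2},
              {'a': 3, 'b': 3}, {'a': 10, 'b': 10}]
    aggregated = {}
    for key in scores[0]:
        aggregated[key] = np.asarray([s[key] for s in scores])
    # {'a': array([ 1,  2,  3, 10]), 'b': array([10,  2,  3, 10])}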
diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py
index 9e6fd57ccd..9dfd49714e 100644
--- a/sklearn/model_selection/tests/test_search.py
+++ b/sklearn/model_selection/tests/test_search.py
@@ -7,6 +7,7 @@ from sklearn.externals.joblib._compat import PY3_OR_LATER
 from itertools import chain, product
 import pickle
 import sys
+import re
 
 import numpy as np
 import scipy.sparse as sp
@@ -27,13 +28,14 @@ from sklearn.utils.mocking import CheckingClassifier, MockDataFrame
 
 from scipy.stats import bernoulli, expon, uniform
 
-from sklearn.externals.six.moves import zip
 from sklearn.base import BaseEstimator
+from sklearn.base import clone
 from sklearn.exceptions import NotFittedError
 from sklearn.datasets import make_classification
 from sklearn.datasets import make_blobs
 from sklearn.datasets import make_multilabel_classification
 
+from sklearn.model_selection import fit_grid_point
 from sklearn.model_selection import KFold
 from sklearn.model_selection import StratifiedKFold
 from sklearn.model_selection import StratifiedShuffleSplit
@@ -54,6 +56,8 @@ from sklearn.tree import DecisionTreeClassifier
 from sklearn.cluster import KMeans
 from sklearn.neighbors import KernelDensity
 from sklearn.metrics import f1_score
+from sklearn.metrics import recall_score
+from sklearn.metrics import accuracy_score
 from sklearn.metrics import make_scorer
 from sklearn.metrics import roc_auc_score
 from sklearn.preprocessing import Imputer
@@ -370,19 +374,30 @@ def test_trivial_cv_results_attr():
 def test_no_refit():
     # Test that GSCV can be used for model selection alone without refitting
     clf = MockClassifier()
-    grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, refit=False)
-    grid_search.fit(X, y)
-    assert_true(not hasattr(grid_search, "best_estimator_") and
-                hasattr(grid_search, "best_index_") and
-                hasattr(grid_search, "best_params_"))
-
-    # Make sure the predict/transform etc fns raise meaningfull error msg
-    for fn_name in ('predict', 'predict_proba', 'predict_log_proba',
-                    'transform', 'inverse_transform'):
-        assert_raise_message(NotFittedError,
-                             ('refit=False. %s is available only after '
-                              'refitting on the best parameters' % fn_name),
-                             getattr(grid_search, fn_name), X)
+    for scoring in [None, ['accuracy', 'precision']]:
+        grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, refit=False)
+        grid_search.fit(X, y)
+        assert_true(not hasattr(grid_search, "best_estimator_") and
+                    hasattr(grid_search, "best_index_") and
+                    hasattr(grid_search, "best_params_"))
+
+        # Make sure the functions predict/transform etc raise meaningful
+        # error messages
+        for fn_name in ('predict', 'predict_proba', 'predict_log_proba',
+                        'transform', 'inverse_transform'):
+            assert_raise_message(NotFittedError,
+                                 ('refit=False. %s is available only after '
+                                  'refitting on the best parameters'
+                                  % fn_name), getattr(grid_search, fn_name), X)
+
+    # Test that an invalid refit param raises appropriate error messages
+    for refit in ["", 5, True, 'recall', 'accuracy']:
+        assert_raise_message(ValueError, "For multi-metric scoring, the "
+                             "parameter refit must be set to a scorer key",
+                             GridSearchCV(clf, {}, refit=refit,
+                                          scoring={'acc': 'accuracy',
+                                                   'prec': 'precision'}).fit,
+                             X, y)
 
 
 def test_grid_search_error():
@@ -622,8 +637,13 @@ def test_pandas_input():
     for InputFeatureType, TargetType in types:
         # X dataframe, y series
         X_df, y_ser = InputFeatureType(X), TargetType(y)
-        check_df = lambda x: isinstance(x, InputFeatureType)
-        check_series = lambda x: isinstance(x, TargetType)
+
+        def check_df(x):
+            return isinstance(x, InputFeatureType)
+
+        def check_series(x):
+            return isinstance(x, TargetType)
+
         clf = CheckingClassifier(check_X=check_df, check_y=check_series)
 
         grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]})
@@ -636,16 +656,20 @@ def test_unsupervised_grid_search():
     # test grid-search with unsupervised estimator
     X, y = make_blobs(random_state=0)
     km = KMeans(random_state=0)
-    grid_search = GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4]),
-                               scoring='adjusted_rand_score')
-    grid_search.fit(X, y)
-    # ARI can find the right number :)
-    assert_equal(grid_search.best_params_["n_clusters"], 3)
 
+    # Multi-metric evaluation unsupervised
+    scoring = ['adjusted_rand_score', 'fowlkes_mallows_score']
+    for refit in ['adjusted_rand_score', 'fowlkes_mallows_score']:
+        grid_search = GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4]),
+                                   scoring=scoring, refit=refit)
+        grid_search.fit(X, y)
+        # Both ARI and FMS can find the right number :)
+        assert_equal(grid_search.best_params_["n_clusters"], 3)
+
+    # Single metric evaluation unsupervised
     grid_search = GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4]),
                                scoring='fowlkes_mallows_score')
     grid_search.fit(X, y)
-    # So can FMS ;)
     assert_equal(grid_search.best_params_["n_clusters"], 3)
 
     # Now without a score, and without y
@@ -694,8 +718,9 @@ def test_param_sampler():
         assert_equal([x for x in sampler], [x for x in sampler])
 
 
-def check_cv_results_array_types(cv_results, param_keys, score_keys):
+def check_cv_results_array_types(search, param_keys, score_keys):
     # Check if the search `cv_results`'s array are of correct types
+    cv_results = search.cv_results_
     assert_true(all(isinstance(cv_results[param], np.ma.MaskedArray)
                     for param in param_keys))
     assert_true(all(cv_results[key].dtype == object for key in param_keys))
@@ -703,7 +728,11 @@ def check_cv_results_array_types(cv_results, param_keys, score_keys):
                      for key in score_keys))
     assert_true(all(cv_results[key].dtype == np.float64
                     for key in score_keys if not key.startswith('rank')))
-    assert_true(cv_results['rank_test_score'].dtype == np.int32)
+
+    scorer_keys = search.scorer_.keys() if search.multimetric_ else ['score']
+
+    for key in scorer_keys:
+        assert_true(cv_results['rank_test_%s' % key].dtype == np.int32)
 
 
 def check_cv_results_keys(cv_results, param_keys, score_keys, n_cand):
@@ -715,22 +744,27 @@ def check_cv_results_keys(cv_results, param_keys, score_keys, n_cand):
 
 
 def check_cv_results_grid_scores_consistency(search):
-    # TODO Remove in 0.20
-    cv_results = search.cv_results_
-    res_scores = np.vstack(list([cv_results["split%d_test_score" % i]
-                                 for i in range(search.n_splits_)])).T
-    res_means = cv_results["mean_test_score"]
-    res_params = cv_results["params"]
-    n_cand = len(res_params)
-    grid_scores = assert_warns(DeprecationWarning, getattr,
-                               search, 'grid_scores_')
-    assert_equal(len(grid_scores), n_cand)
-    # Check consistency of the structure of grid_scores
-    for i in range(n_cand):
-        assert_equal(grid_scores[i].parameters, res_params[i])
-        assert_array_equal(grid_scores[i].cv_validation_scores,
-                           res_scores[i, :])
-        assert_array_equal(grid_scores[i].mean_validation_score, res_means[i])
+    # TODO Remove test in 0.20
+    if search.multimetric_:
+        assert_raise_message(AttributeError, "not available for multi-metric",
+                             getattr, search, 'grid_scores_')
+    else:
+        cv_results = search.cv_results_
+        res_scores = np.vstack(list([cv_results["split%d_test_score" % i]
+                                     for i in range(search.n_splits_)])).T
+        res_means = cv_results["mean_test_score"]
+        res_params = cv_results["params"]
+        n_cand = len(res_params)
+        grid_scores = assert_warns(DeprecationWarning, getattr,
+                                   search, 'grid_scores_')
+        assert_equal(len(grid_scores), n_cand)
+        # Check consistency of the structure of grid_scores
+        for i in range(n_cand):
+            assert_equal(grid_scores[i].parameters, res_params[i])
+            assert_array_equal(grid_scores[i].cv_validation_scores,
+                               res_scores[i, :])
+            assert_array_equal(grid_scores[i].mean_validation_score,
+                               res_means[i])
 
 
 def test_grid_search_cv_results():
@@ -741,12 +775,6 @@ def test_grid_search_cv_results():
     n_grid_points = 6
     params = [dict(kernel=['rbf', ], C=[1, 10], gamma=[0.1, 1]),
               dict(kernel=['poly', ], degree=[1, 2])]
-    grid_search = GridSearchCV(SVC(), cv=n_splits, iid=False,
-                               param_grid=params)
-    grid_search.fit(X, y)
-    grid_search_iid = GridSearchCV(SVC(), cv=n_splits, iid=True,
-                                   param_grid=params)
-    grid_search_iid.fit(X, y)
 
     param_keys = ('param_C', 'param_degree', 'param_gamma', 'param_kernel')
     score_keys = ('mean_test_score', 'mean_train_score',
@@ -760,7 +788,9 @@ def test_grid_search_cv_results():
                   'mean_score_time', 'std_score_time')
     n_candidates = n_grid_points
 
-    for search, iid in zip((grid_search, grid_search_iid), (False, True)):
+    for iid in (False, True):
+        search = GridSearchCV(SVC(), cv=n_splits, iid=iid, param_grid=params)
+        search.fit(X, y)
         assert_equal(iid, search.iid)
         cv_results = search.cv_results_
         # Check if score and timing are reasonable
@@ -771,11 +801,11 @@ def test_grid_search_cv_results():
                     if 'time' not in k and
                     k is not 'rank_test_score')
         # Check cv_results structure
-        check_cv_results_array_types(cv_results, param_keys, score_keys)
+        check_cv_results_array_types(search, param_keys, score_keys)
         check_cv_results_keys(cv_results, param_keys, score_keys, n_candidates)
         # Check masking
-        cv_results = grid_search.cv_results_
-        n_candidates = len(grid_search.cv_results_['params'])
+        cv_results = search.cv_results_
+        n_candidates = len(search.cv_results_['params'])
         assert_true(all((cv_results['param_C'].mask[i] and
                          cv_results['param_gamma'].mask[i] and
                          not cv_results['param_degree'].mask[i])
@@ -790,26 +820,12 @@ def test_grid_search_cv_results():
 
 
 def test_random_search_cv_results():
-    # Make a dataset with a lot of noise to get various kind of prediction
-    # errors across CV folds and parameter settings
-    X, y = make_classification(n_samples=200, n_features=100, n_informative=3,
-                               random_state=0)
+    X, y = make_classification(n_samples=50, n_features=4, random_state=42)
 
-    # scipy.stats dists now supports `seed` but we still support scipy 0.12
-    # which doesn't support the seed. Hence the assertions in the test for
-    # random_search alone should not depend on randomization.
     n_splits = 3
     n_search_iter = 30
-    params = dict(C=expon(scale=10), gamma=expon(scale=0.1))
-    random_search = RandomizedSearchCV(SVC(), n_iter=n_search_iter,
-                                       cv=n_splits, iid=False,
-                                       param_distributions=params)
-    random_search.fit(X, y)
-    random_search_iid = RandomizedSearchCV(SVC(), n_iter=n_search_iter,
-                                           cv=n_splits, iid=True,
-                                           param_distributions=params)
-    random_search_iid.fit(X, y)
 
+    params = dict(C=expon(scale=10), gamma=expon(scale=0.1))
     param_keys = ('param_C', 'param_gamma')
     score_keys = ('mean_test_score', 'mean_train_score',
                   'rank_test_score',
@@ -822,11 +838,14 @@ def test_random_search_cv_results():
                   'mean_score_time', 'std_score_time')
     n_cand = n_search_iter
 
-    for search, iid in zip((random_search, random_search_iid), (False, True)):
+    for iid in (False, True):
+        search = RandomizedSearchCV(SVC(), n_iter=n_search_iter, cv=n_splits,
+                                    iid=iid, param_distributions=params)
+        search.fit(X, y)
         assert_equal(iid, search.iid)
         cv_results = search.cv_results_
         # Check results structure
-        check_cv_results_array_types(cv_results, param_keys, score_keys)
+        check_cv_results_array_types(search, param_keys, score_keys)
         check_cv_results_keys(cv_results, param_keys, score_keys, n_cand)
         # For random_search, all the param array vals should be unmasked
         assert_false(any(cv_results['param_C'].mask) or
@@ -928,6 +947,108 @@ def test_search_iid_param():
         assert_almost_equal(train_std, 0)
 
 
+def test_grid_search_cv_results_multimetric():
+    X, y = make_classification(n_samples=50, n_features=4, random_state=42)
+
+    n_splits = 3
+    params = [dict(kernel=['rbf', ], C=[1, 10], gamma=[0.1, 1]),
+              dict(kernel=['poly', ], degree=[1, 2])]
+
+    for iid in (False, True):
+        grid_searches = []
+        for scoring in ({'accuracy': make_scorer(accuracy_score),
+                         'recall': make_scorer(recall_score)},
+                        'accuracy', 'recall'):
+            grid_search = GridSearchCV(SVC(), cv=n_splits, iid=iid,
+                                       param_grid=params, scoring=scoring,
+                                       refit=False)
+            grid_search.fit(X, y)
+            assert_equal(grid_search.iid, iid)
+            grid_searches.append(grid_search)
+
+        compare_cv_results_multimetric_with_single(*grid_searches, iid=iid)
+
+
+def test_random_search_cv_results_multimetric():
+    X, y = make_classification(n_samples=50, n_features=4, random_state=42)
+
+    n_splits = 3
+    n_search_iter = 30
+    scoring = ('accuracy', 'recall')
+
+    # Scipy 0.12's stats dists do not accept seed, hence we use param grid
+    params = dict(C=np.logspace(-10, 1), gamma=np.logspace(-5, 0, base=0.1))
+    for iid in (True, False):
+        for refit in (True, False):
+            random_searches = []
+            for scoring in (('accuracy', 'recall'), 'accuracy', 'recall'):
+                # If True, for multi-metric pass refit='accuracy'
+                if refit:
+                    refit = 'accuracy' if isinstance(scoring, tuple) else refit
+                clf = SVC(probability=True, random_state=42)
+                random_search = RandomizedSearchCV(clf, n_iter=n_search_iter,
+                                                   cv=n_splits, iid=iid,
+                                                   param_distributions=params,
+                                                   scoring=scoring,
+                                                   refit=refit, random_state=0)
+                random_search.fit(X, y)
+                random_searches.append(random_search)
+
+            compare_cv_results_multimetric_with_single(*random_searches,
+                                                       iid=iid)
+            if refit:
+                compare_refit_methods_when_refit_with_acc(
+                    random_searches[0], random_searches[1], refit)
+
+
+def compare_cv_results_multimetric_with_single(
+        search_multi, search_acc, search_rec, iid):
+    """Compare multi-metric cv_results with the ensemble of multiple
+    single metric cv_results from single metric grid/random search"""
+
+    assert_equal(search_multi.iid, iid)
+    assert_true(search_multi.multimetric_)
+    assert_array_equal(sorted(search_multi.scorer_),
+                       ('accuracy', 'recall'))
+
+    cv_results_multi = search_multi.cv_results_
+    cv_results_acc_rec = {re.sub('_score$', '_accuracy', k): v
+                          for k, v in search_acc.cv_results_.items()}
+    cv_results_acc_rec.update({re.sub('_score$', '_recall', k): v
+                               for k, v in search_rec.cv_results_.items()})
+
+    # Check if score and timing are reasonable, also checks if the keys
+    # are present
+    assert_true(all((np.all(cv_results_multi[k] <= 1) for k in (
+                    'mean_score_time', 'std_score_time', 'mean_fit_time',
+                    'std_fit_time'))))
+
+    # Compare the keys, other than time keys, among multi-metric and
+    # single metric grid search results. np.testing.assert_equal performs a
+    # deep nested comparison of the two cv_results dicts
+    np.testing.assert_equal({k: v for k, v in cv_results_multi.items()
+                             if not k.endswith('_time')},
+                            {k: v for k, v in cv_results_acc_rec.items()
+                             if not k.endswith('_time')})
+
+
+def compare_refit_methods_when_refit_with_acc(search_multi, search_acc, refit):
+    """Compare refit multi-metric search methods with single metric methods"""
+    if refit:
+        assert_equal(search_multi.refit, 'accuracy')
+    else:
+        assert_false(search_multi.refit)
+    assert_equal(search_acc.refit, refit)
+
+    X, y = make_blobs(n_samples=100, n_features=4, random_state=42)
+    for method in ('predict', 'predict_proba', 'predict_log_proba'):
+        assert_almost_equal(getattr(search_multi, method)(X),
+                            getattr(search_acc, method)(X))
+    assert_almost_equal(search_multi.score(X, y), search_acc.score(X, y))
+    for key in ('best_index_', 'best_score_', 'best_params_'):
+        assert_equal(getattr(search_multi, key), getattr(search_acc, key))
+
+
 def test_search_cv_results_rank_tie_breaking():
     X, y = make_blobs(n_samples=50, random_state=42)
 
@@ -1034,6 +1155,34 @@ def test_grid_search_correct_score_results():
                 assert_almost_equal(correct_score, cv_scores[i])
 
 
+def test_fit_grid_point():
+    X, y = make_classification(random_state=0)
+    cv = StratifiedKFold(random_state=0)
+    svc = LinearSVC(random_state=0)
+    scorer = make_scorer(accuracy_score)
+
+    for params in ({'C': 0.1}, {'C': 0.01}, {'C': 0.001}):
+        for train, test in cv.split(X, y):
+            this_scores, this_params, n_test_samples = fit_grid_point(
+                X, y, clone(svc), params, train, test,
+                scorer, verbose=False)
+
+            est = clone(svc).set_params(**params)
+            est.fit(X[train], y[train])
+            expected_score = scorer(est, X[test], y[test])
+
+            # Test the return values of fit_grid_point
+            assert_almost_equal(this_scores, expected_score)
+            assert_equal(params, this_params)
+            assert_equal(n_test_samples, test.size)
+
+    # Should raise an error upon multimetric scorer
+    assert_raise_message(ValueError, "scoring value should either be a "
+                         "callable, string or None.", fit_grid_point, X, y,
+                         svc, params, train, test, {'score': scorer},
+                         verbose=True)
+
+
 def test_pickle():
     # Test that a fit search can be pickled
     clf = MockClassifier()
@@ -1272,20 +1421,16 @@ def test_grid_search_cv_splits_consistency():
                        cv=KFold(n_splits=n_splits))
     gs2.fit(X, y)
 
-    def _pop_time_keys(cv_results):
-        for key in ('mean_fit_time', 'std_fit_time',
-                    'mean_score_time', 'std_score_time'):
-            cv_results.pop(key)
-        return cv_results
-
     # OneTimeSplitter is a non-re-entrant cv where split can be called only
     # once if ``cv.split`` is called once per param setting in GridSearchCV.fit
     # the 2nd and 3rd parameter will not be evaluated as no train/test indices
     # will be generated for the 2nd and subsequent cv.split calls.
     # This is a check to make sure cv.split is not called once per param
     # setting.
-    np.testing.assert_equal(_pop_time_keys(gs.cv_results_),
-                            _pop_time_keys(gs2.cv_results_))
+    np.testing.assert_equal({k: v for k, v in gs.cv_results_.items()
+                             if not k.endswith('_time')},
+                            {k: v for k, v in gs2.cv_results_.items()
+                             if not k.endswith('_time')})
 
     # Check consistency of folds across the parameters
     gs = GridSearchCV(LinearSVC(random_state=0),
diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py
index 3087c1f3bd..c73f42fb27 100644
--- a/sklearn/model_selection/tests/test_validation.py
+++ b/sklearn/model_selection/tests/test_validation.py
@@ -16,6 +16,7 @@ from sklearn.utils.testing import assert_equal
 from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_raise_message
+from sklearn.utils.testing import assert_raises_regex
 from sklearn.utils.testing import assert_greater
 from sklearn.utils.testing import assert_less
 from sklearn.utils.testing import assert_array_almost_equal
@@ -25,6 +26,7 @@ from sklearn.utils.mocking import CheckingClassifier, MockDataFrame
 
 from sklearn.model_selection import cross_val_score
 from sklearn.model_selection import cross_val_predict
+from sklearn.model_selection import cross_validate
 from sklearn.model_selection import permutation_test_score
 from sklearn.model_selection import KFold
 from sklearn.model_selection import StratifiedKFold
@@ -42,7 +44,12 @@ from sklearn.datasets import load_boston
 from sklearn.datasets import load_iris
 from sklearn.metrics import explained_variance_score
 from sklearn.metrics import make_scorer
+from sklearn.metrics import accuracy_score
+from sklearn.metrics import confusion_matrix
+from sklearn.metrics import precision_recall_fscore_support
 from sklearn.metrics import precision_score
+from sklearn.metrics import r2_score
+from sklearn.metrics.scorer import check_scoring
 
 from sklearn.linear_model import Ridge, LogisticRegression
 from sklearn.linear_model import PassiveAggressiveClassifier
@@ -56,6 +63,7 @@ from sklearn.pipeline import Pipeline
 
 from sklearn.externals.six.moves import cStringIO as StringIO
 from sklearn.base import BaseEstimator
+from sklearn.base import clone
 from sklearn.multiclass import OneVsRestClassifier
 from sklearn.utils import shuffle
 from sklearn.datasets import make_classification
@@ -262,6 +270,196 @@ def test_cross_val_score():
     assert_raises(ValueError, cross_val_score, clf, X_3d, y2)
 
 
+def test_cross_validate_invalid_scoring_param():
+    X, y = make_classification(random_state=0)
+    estimator = MockClassifier()
+
+    # Test the errors
+    error_message_regexp = ".*must be unique strings.*"
+
+    # List/tuple of callables should raise a message advising users to use
+    # dict of names to callables mapping
+    assert_raises_regex(ValueError, error_message_regexp,
+                        cross_validate, estimator, X, y,
+                        scoring=(make_scorer(precision_score),
+                                 make_scorer(accuracy_score)))
+    assert_raises_regex(ValueError, error_message_regexp,
+                        cross_validate, estimator, X, y,
+                        scoring=(make_scorer(precision_score),))
+
+    # So should empty lists/tuples
+    assert_raises_regex(ValueError, error_message_regexp + "Empty list.*",
+                        cross_validate, estimator, X, y, scoring=())
+
+    # So should duplicated entries
+    assert_raises_regex(ValueError, error_message_regexp + "Duplicate.*",
+                        cross_validate, estimator, X, y,
+                        scoring=('f1_micro', 'f1_micro'))
+
+    # Nested Lists should raise a generic error message
+    assert_raises_regex(ValueError, error_message_regexp,
+                        cross_validate, estimator, X, y,
+                        scoring=[[make_scorer(precision_score)]])
+
+    error_message_regexp = (".*should either be.*string or callable.*for "
+                            "single.*.*dict.*for multi.*")
+
+    # Empty dict should raise invalid scoring error
+    assert_raises_regex(ValueError, "An empty dict",
+                        cross_validate, estimator, X, y, scoring=(dict()))
+
+    # And so should any other invalid entry
+    assert_raises_regex(ValueError, error_message_regexp,
+                        cross_validate, estimator, X, y, scoring=5)
+
+    multiclass_scorer = make_scorer(precision_recall_fscore_support)
+
+    # Multiclass Scorers that return multiple values are not supported yet
+    assert_raises_regex(ValueError,
+                        "Can't handle mix of binary and continuous",
+                        cross_validate, estimator, X, y,
+                        scoring=multiclass_scorer)
+    assert_raises_regex(ValueError,
+                        "Can't handle mix of binary and continuous",
+                        cross_validate, estimator, X, y,
+                        scoring={"foo": multiclass_scorer})
+
+    multivalued_scorer = make_scorer(confusion_matrix)
+
+    # Scorers that return multiple values are not supported yet
+    assert_raises_regex(ValueError, "scoring must return a number, got",
+                        cross_validate, SVC(), X, y,
+                        scoring=multivalued_scorer)
+    assert_raises_regex(ValueError, "scoring must return a number, got",
+                        cross_validate, SVC(), X, y,
+                        scoring={"foo": multivalued_scorer})
+
+    assert_raises_regex(ValueError, "'mse' is not a valid scoring value.",
+                        cross_validate, SVC(), X, y, scoring="mse")
+
+
+def test_cross_validate():
+    # Compute train and test mse/r2 scores
+    cv = KFold(n_splits=5)
+
+    # Regression
+    X_reg, y_reg = make_regression(n_samples=30, random_state=0)
+    reg = Ridge(random_state=0)
+
+    # Classification
+    X_clf, y_clf = make_classification(n_samples=30, random_state=0)
+    clf = SVC(kernel="linear", random_state=0)
+
+    for X, y, est in ((X_reg, y_reg, reg), (X_clf, y_clf, clf)):
+        # It's okay to evaluate regression metrics on classification too
+        mse_scorer = check_scoring(est, 'neg_mean_squared_error')
+        r2_scorer = check_scoring(est, 'r2')
+        train_mse_scores = []
+        test_mse_scores = []
+        train_r2_scores = []
+        test_r2_scores = []
+        for train, test in cv.split(X, y):
+            est = clone(reg).fit(X[train], y[train])
+            train_mse_scores.append(mse_scorer(est, X[train], y[train]))
+            train_r2_scores.append(r2_scorer(est, X[train], y[train]))
+            test_mse_scores.append(mse_scorer(est, X[test], y[test]))
+            test_r2_scores.append(r2_scorer(est, X[test], y[test]))
+
+        train_mse_scores = np.array(train_mse_scores)
+        test_mse_scores = np.array(test_mse_scores)
+        train_r2_scores = np.array(train_r2_scores)
+        test_r2_scores = np.array(test_r2_scores)
+
+        scores = (train_mse_scores, test_mse_scores, train_r2_scores,
+                  test_r2_scores)
+
+        yield check_cross_validate_single_metric, est, X, y, scores
+        yield check_cross_validate_multi_metric, est, X, y, scores
+
+
+def check_cross_validate_single_metric(clf, X, y, scores):
+    (train_mse_scores, test_mse_scores, train_r2_scores,
+     test_r2_scores) = scores
+    # Test single metric evaluation when scoring is string or singleton list
+    for (return_train_score, dict_len) in ((True, 4), (False, 3)):
+        # Single metric passed as a string
+        if return_train_score:
+            # It must be True by default
+            mse_scores_dict = cross_validate(clf, X, y, cv=5,
+                                             scoring='neg_mean_squared_error')
+            assert_array_almost_equal(mse_scores_dict['train_score'],
+                                      train_mse_scores)
+        else:
+            mse_scores_dict = cross_validate(clf, X, y, cv=5,
+                                             scoring='neg_mean_squared_error',
+                                             return_train_score=False)
+        assert_true(isinstance(mse_scores_dict, dict))
+        assert_equal(len(mse_scores_dict), dict_len)
+        assert_array_almost_equal(mse_scores_dict['test_score'],
+                                  test_mse_scores)
+
+        # Single metric passed as a list
+        if return_train_score:
+            # It must be True by default
+            r2_scores_dict = cross_validate(clf, X, y, cv=5, scoring=['r2'])
+            assert_array_almost_equal(r2_scores_dict['train_r2'],
+                                      train_r2_scores)
+        else:
+            r2_scores_dict = cross_validate(clf, X, y, cv=5, scoring=['r2'],
+                                            return_train_score=False)
+        assert_true(isinstance(r2_scores_dict, dict))
+        assert_equal(len(r2_scores_dict), dict_len)
+        assert_array_almost_equal(r2_scores_dict['test_r2'], test_r2_scores)
+
+
+def check_cross_validate_multi_metric(clf, X, y, scores):
+    # Test multimetric evaluation when scoring is a list / dict
+    (train_mse_scores, test_mse_scores, train_r2_scores,
+     test_r2_scores) = scores
+    all_scoring = (('r2', 'neg_mean_squared_error'),
+                   {'r2': make_scorer(r2_score),
+                    'neg_mean_squared_error': 'neg_mean_squared_error'})
+
+    keys_sans_train = set(('test_r2', 'test_neg_mean_squared_error',
+                           'fit_time', 'score_time'))
+    keys_with_train = keys_sans_train.union(
+        set(('train_r2', 'train_neg_mean_squared_error')))
+
+    for return_train_score in (True, False):
+        for scoring in all_scoring:
+            if return_train_score:
+                # return_train_score must be True by default
+                cv_results = cross_validate(clf, X, y, cv=5, scoring=scoring)
+                assert_array_almost_equal(cv_results['train_r2'],
+                                          train_r2_scores)
+                assert_array_almost_equal(
+                    cv_results['train_neg_mean_squared_error'],
+                    train_mse_scores)
+            else:
+                cv_results = cross_validate(clf, X, y, cv=5, scoring=scoring,
+                                            return_train_score=False)
+            assert_true(isinstance(cv_results, dict))
+            assert_equal(set(cv_results.keys()),
+                         keys_with_train if return_train_score
+                         else keys_sans_train)
+            assert_array_almost_equal(cv_results['test_r2'], test_r2_scores)
+            assert_array_almost_equal(
+                cv_results['test_neg_mean_squared_error'], test_mse_scores)
+
+            # Make sure all the arrays are of np.ndarray type
+            assert type(cv_results['test_r2']) == np.ndarray
+            assert (type(cv_results['test_neg_mean_squared_error']) ==
+                    np.ndarray)
+            assert type(cv_results['fit_time']) == np.ndarray
+            assert type(cv_results['score_time']) == np.ndarray
+
+            # Ensure all the times are within sane limits
+            assert np.all(cv_results['fit_time'] >= 0)
+            assert np.all(cv_results['fit_time'] < 10)
+            assert np.all(cv_results['score_time'] >= 0)
+            assert np.all(cv_results['score_time'] < 10)
+
+
 def test_cross_val_score_predict_groups():
     # Check if ValueError (when groups is None) propagates to cross_val_score
     # and cross_val_predict
@@ -386,8 +584,9 @@ def test_cross_val_score_score_func():
 
     with warnings.catch_warnings(record=True):
         scoring = make_scorer(score_func)
-        score = cross_val_score(clf, X, y, scoring=scoring)
+        score = cross_val_score(clf, X, y, scoring=scoring, cv=3)
     assert_array_equal(score, [1.0, 1.0, 1.0])
+    # Test that score function is called only 3 times (for cv=3)
     assert len(_score_func_args) == 3
 
 
-- 
GitLab