diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 64da48dfe6f2d4690340322ab6f2295b9fd75f43..7c68cff7da090ef901295785ce001c59e6b66d5f 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -116,6 +116,20 @@ Model Selection Enhancements and API Changes
     The parameter ``n_labels`` in the newly renamed
     :class:`model_selection.LeavePGroupsOut` is changed to ``n_groups``.
 
+  - Training scores and timing information
+
+    ``cv_results_`` also includes the training scores for each
+    cross-validation split (with keys such as ``'split0_train_score'``), as
+    well as their mean (``'mean_train_score'``) and standard deviation
+    (``'std_train_score'``). To avoid the cost of evaluating the training
+    scores, set ``return_train_score=False``.
+
+    Additionally, the mean and standard deviation of the times taken to fit
+    and score the model across the cross-validation splits are available at
+    the keys ``'mean_fit_time'``, ``'std_fit_time'``, ``'mean_score_time'``
+    and ``'std_score_time'`` respectively (see the sketch below).
+
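+    A minimal usage sketch, reading the new entries from a fitted
+    :class:`model_selection.GridSearchCV` (the exact values will vary)::
+
+        from sklearn.datasets import load_iris
+        from sklearn.model_selection import GridSearchCV
+        from sklearn.svm import SVC
+
+        iris = load_iris()
+        search = GridSearchCV(SVC(), {'C': [1, 10]}, return_train_score=True)
+        search.fit(iris.data, iris.target)
+
+        # Per-split and aggregated training scores
+        search.cv_results_['split0_train_score']
+        search.cv_results_['mean_train_score']
+        search.cv_results_['std_train_score']
+
+        # Fit and score timings, in seconds
+        search.cv_results_['mean_fit_time']
+        search.cv_results_['std_score_time']
+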
+Changelog
+---------
 
 New features
 ............
@@ -362,6 +376,12 @@ Enhancements
      now accept arbitrary kernel functions in addition to strings ``knn`` and ``rbf``.
      (`#5762 <https://github.com/scikit-learn/scikit-learn/pull/5762>`_) By `Utkarsh Upadhyay`_.
 
+   - The training scores, as well as the fit and score times, for each
+     search candidate are now available in the ``cv_results_`` dict.
+     See :ref:`model_selection_changes` for more information.
+     (`#7325 <https://github.com/scikit-learn/scikit-learn/pull/7325>`_)
+     By `Eugene Chen`_ and `Raghav RV`_.
+
 
 Bug fixes
 .........
@@ -4731,3 +4751,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
 .. _Russell Smith: https://github.com/rsmith54
 
 .. _Utkarsh Upadhyay: https://github.com/musically-ut
+
+.. _Eugene Chen: https://github.com/eyc88
diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index 7c6344c02c853e40b0a9cdadf95715452375fda0..424263c0d1c3d7582aa6963eb0fc5f437df4d144 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -319,7 +319,9 @@ def fit_grid_point(X, y, estimator, parameters, train, test, scorer,
     """
-    score, n_samples_test, _ = _fit_and_score(estimator, X, y, scorer, train,
-                                              test, verbose, parameters,
-                                              fit_params, error_score)
+    score, n_samples_test = _fit_and_score(estimator, X, y, scorer, train,
+                                           test, verbose, parameters,
+                                           fit_params=fit_params,
+                                           return_n_test_samples=True,
+                                           error_score=error_score)
     return score, parameters, n_samples_test
 
 
@@ -374,7 +376,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
     def __init__(self, estimator, scoring=None,
                  fit_params=None, n_jobs=1, iid=True,
                  refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs',
-                 error_score='raise'):
+                 error_score='raise', return_train_score=True):
 
         self.scoring = scoring
         self.estimator = estimator
@@ -386,6 +388,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         self.verbose = verbose
         self.pre_dispatch = pre_dispatch
         self.error_score = error_score
+        self.return_train_score = return_train_score
 
     @property
     def _estimator_type(self):
@@ -551,41 +554,61 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
             pre_dispatch=pre_dispatch
         )(delayed(_fit_and_score)(clone(base_estimator), X, y, self.scorer_,
                                   train, test, self.verbose, parameters,
-                                  self.fit_params, return_parameters=True,
+                                  fit_params=self.fit_params,
+                                  return_train_score=self.return_train_score,
+                                  return_n_test_samples=True,
+                                  return_times=True, return_parameters=True,
                                   error_score=self.error_score)
           for parameters in parameter_iterable
           for train, test in cv.split(X, y, groups))
 
-        test_scores, test_sample_counts, _, parameters = zip(*out)
+        # if return_train_score is True, "out" also contains the train scores
+        if self.return_train_score:
+            (train_scores, test_scores, test_sample_counts,
+             fit_time, score_time, parameters) = zip(*out)
+        else:
+            (test_scores, test_sample_counts,
+             fit_time, score_time, parameters) = zip(*out)
 
         candidate_params = parameters[::n_splits]
         n_candidates = len(candidate_params)
 
-        test_scores = np.array(test_scores,
-                               dtype=np.float64).reshape(n_candidates,
-                                                         n_splits)
+        results = dict()
+
+        def _store(key_name, array, weights=None, splits=False, rank=False):
+            """A small helper to store the scores/times to the cv_results_"""
+            array = np.array(array, dtype=np.float64).reshape(n_candidates,
+                                                              n_splits)
+            if splits:
+                for split_i in range(n_splits):
+                    results["split%d_%s"
+                            % (split_i, key_name)] = array[:, split_i]
+
+            array_means = np.average(array, axis=1, weights=weights)
+            results['mean_%s' % key_name] = array_means
+            # Weighted std is not directly available in numpy
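+            # so compute sqrt(average((x - weighted mean) ** 2)) with the
+            # same weights (the test-set sizes per split when iid=True)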
+            array_stds = np.sqrt(np.average((array -
+                                             array_means[:, np.newaxis]) ** 2,
+                                            axis=1, weights=weights))
+            results['std_%s' % key_name] = array_stds
+
+            if rank:
+                results["rank_%s" % key_name] = np.asarray(
+                    rankdata(-array_means, method='min'), dtype=np.int32)
+
+        # Compute the (weighted) mean and std for the test scores alone
         # NOTE test_sample counts (weights) remain the same for all candidates
         test_sample_counts = np.array(test_sample_counts[:n_splits],
                                       dtype=np.int)
 
-        # Computed the (weighted) mean and std for all the candidates
-        weights = test_sample_counts if self.iid else None
-        means = np.average(test_scores, axis=1, weights=weights)
-        stds = np.sqrt(np.average((test_scores - means[:, np.newaxis]) ** 2,
-                                  axis=1, weights=weights))
-
-        cv_results = dict()
-        for split_i in range(n_splits):
-            cv_results["split%d_test_score" % split_i] = test_scores[:,
-                                                                     split_i]
-        cv_results["mean_test_score"] = means
-        cv_results["std_test_score"] = stds
-
-        ranks = np.asarray(rankdata(-means, method='min'), dtype=np.int32)
+        _store('test_score', test_scores, splits=True, rank=True,
+               weights=test_sample_counts if self.iid else None)
+        if self.return_train_score:
+            _store('train_score', train_scores, splits=True)
+        _store('fit_time', fit_time)
+        _store('score_time', score_time)
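+        # The two calls above populate 'mean_fit_time', 'std_fit_time',
+        # 'mean_score_time' and 'std_score_time' (all in seconds)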
 
-        best_index = np.flatnonzero(ranks == 1)[0]
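+        # Rank 1 corresponds to the best mean test score; when several
+        # candidates are tied, the first one encountered is picked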
+        best_index = np.flatnonzero(results["rank_test_score"] == 1)[0]
         best_parameters = candidate_params[best_index]
-        cv_results["rank_test_score"] = ranks
 
         # Use one np.MaskedArray and mask all the places where the param is not
         # applicable for that candidate. Use defaultdict as each candidate may
@@ -599,12 +622,12 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
                 # Setting the value at an index also unmasks that index
                 param_results["param_%s" % name][cand_i] = value
 
-        cv_results.update(param_results)
+        results.update(param_results)
 
         # Store a list of param dicts at the key 'params'
-        cv_results['params'] = candidate_params
+        results['params'] = candidate_params
 
-        self.cv_results_ = cv_results
+        self.cv_results_ = results
         self.best_index_ = best_index
         self.n_splits_ = n_splits
 
@@ -746,6 +769,10 @@ class GridSearchCV(BaseSearchCV):
         FitFailedWarning is raised. This parameter does not affect the refit
         step, which will always raise the error.
 
+    return_train_score : boolean, default=True
+        If ``False``, the ``cv_results_`` attribute will not include training
+        scores.
+
 
     Examples
     --------
@@ -764,13 +791,16 @@ class GridSearchCV(BaseSearchCV):
                          random_state=None, shrinking=True, tol=...,
                          verbose=False),
            fit_params={}, iid=..., n_jobs=1,
-           param_grid=..., pre_dispatch=..., refit=...,
+           param_grid=..., pre_dispatch=..., refit=..., return_train_score=...,
            scoring=..., verbose=...)
     >>> sorted(clf.cv_results_.keys())
     ...                             # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
-    ['mean_test_score', 'param_C', 'param_kernel', 'params',...
-     'rank_test_score', 'split0_test_score', 'split1_test_score',...
-     'split2_test_score', 'std_test_score']
+    ['mean_fit_time', 'mean_score_time', 'mean_test_score',...
+     'mean_train_score', 'param_C', 'param_kernel', 'params',...
+     'rank_test_score', 'split0_test_score',...
+     'split0_train_score', 'split1_test_score', 'split1_train_score',...
+     'split2_test_score', 'split2_train_score',...
+     'std_fit_time', 'std_score_time', 'std_test_score', 'std_train_score']
 
     Attributes
     ----------
@@ -801,17 +831,28 @@ class GridSearchCV(BaseSearchCV):
                                         mask = [ True  True False False]...),
             'param_degree': masked_array(data = [2.0 3.0 -- --],
                                          mask = [False False  True  True]...),
-            'split0_test_score' : [0.8, 0.7, 0.8, 0.9],
-            'split1_test_score' : [0.82, 0.5, 0.7, 0.78],
-            'mean_test_score'   : [0.81, 0.60, 0.75, 0.82],
-            'std_test_score'    : [0.02, 0.01, 0.03, 0.03],
-            'rank_test_score'   : [2, 4, 3, 1],
-            'params'            : [{'kernel': 'poly', 'degree': 2}, ...],
+            'split0_test_score'  : [0.8, 0.7, 0.8, 0.9],
+            'split1_test_score'  : [0.82, 0.5, 0.7, 0.78],
+            'mean_test_score'    : [0.81, 0.60, 0.75, 0.82],
+            'std_test_score'     : [0.02, 0.01, 0.03, 0.03],
+            'rank_test_score'    : [2, 4, 3, 1],
+            'split0_train_score' : [0.8, 0.9, 0.7, 0.8],
+            'split1_train_score' : [0.82, 0.5, 0.7, 0.9],
+            'mean_train_score'   : [0.81, 0.7, 0.7, 0.85],
+            'std_train_score'    : [0.03, 0.03, 0.04, 0.05],
+            'mean_fit_time'      : [0.73, 0.63, 0.43, 0.49],
+            'std_fit_time'       : [0.01, 0.02, 0.01, 0.01],
+            'mean_score_time'    : [0.007, 0.06, 0.04, 0.04],
+            'std_score_time'     : [0.001, 0.002, 0.003, 0.005],
+            'params'             : [{'kernel': 'poly', 'degree': 2}, ...],
             }
 
         NOTE that the key ``'params'`` is used to store a list of parameter
         settings dict for all the parameter candidates.
 
+        The ``mean_fit_time``, ``std_fit_time``, ``mean_score_time`` and
+        ``std_score_time`` are all in seconds.
+
     best_estimator_ : estimator
         Estimator that was chosen by the search, i.e. estimator
         which gave highest score (or smallest loss if specified)
@@ -868,11 +909,13 @@ class GridSearchCV(BaseSearchCV):
 
     def __init__(self, estimator, param_grid, scoring=None, fit_params=None,
                  n_jobs=1, iid=True, refit=True, cv=None, verbose=0,
-                 pre_dispatch='2*n_jobs', error_score='raise'):
+                 pre_dispatch='2*n_jobs', error_score='raise',
+                 return_train_score=True):
         super(GridSearchCV, self).__init__(
             estimator=estimator, scoring=scoring, fit_params=fit_params,
             n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
-            pre_dispatch=pre_dispatch, error_score=error_score)
+            pre_dispatch=pre_dispatch, error_score=error_score,
+            return_train_score=return_train_score)
         self.param_grid = param_grid
         _check_param_grid(param_grid)
 
@@ -1006,6 +1049,10 @@ class RandomizedSearchCV(BaseSearchCV):
         FitFailedWarning is raised. This parameter does not affect the refit
         step, which will always raise the error.
 
+    return_train_score : boolean, default=True
+        If ``False``, the ``cv_results_`` attribute will not include training
+        scores.
+
     Attributes
     ----------
     cv_results_ : dict of numpy (masked) ndarrays
@@ -1030,17 +1077,28 @@ class RandomizedSearchCV(BaseSearchCV):
             'param_kernel' : masked_array(data = ['rbf', rbf', 'rbf'],
                                           mask = False),
             'param_gamma'  : masked_array(data = [0.1 0.2 0.3], mask = False),
-            'split0_test_score' : [0.8, 0.9, 0.7],
-            'split1_test_score' : [0.82, 0.5, 0.7],
-            'mean_test_score'   : [0.81, 0.7, 0.7],
-            'std_test_score'    : [0.02, 0.2, 0.],
-            'rank_test_score'   : [3, 1, 1],
+            'split0_test_score'  : [0.8, 0.9, 0.7],
+            'split1_test_score'  : [0.82, 0.5, 0.7],
+            'mean_test_score'    : [0.81, 0.7, 0.7],
+            'std_test_score'     : [0.02, 0.2, 0.],
+            'rank_test_score'    : [3, 1, 1],
+            'split0_train_score' : [0.8, 0.9, 0.7],
+            'split1_train_score' : [0.82, 0.5, 0.7],
+            'mean_train_score'   : [0.81, 0.7, 0.7],
+            'std_train_score'    : [0.03, 0.03, 0.04],
+            'mean_fit_time'      : [0.73, 0.63, 0.43],
+            'std_fit_time'       : [0.01, 0.02, 0.01],
+            'mean_score_time'    : [0.007, 0.06, 0.04],
+            'std_score_time'     : [0.001, 0.002, 0.003],
             'params' : [{'kernel' : 'rbf', 'gamma' : 0.1}, ...],
             }
 
         NOTE that the key ``'params'`` is used to store a list of parameter
         settings dict for all the parameter candidates.
 
+        The ``mean_fit_time``, ``std_fit_time``, ``mean_score_time`` and
+        ``std_score_time`` are all in seconds.
+
     best_estimator_ : estimator
         Estimator that was chosen by the search, i.e. estimator
         which gave highest score (or smallest loss if specified)
@@ -1094,15 +1152,15 @@ class RandomizedSearchCV(BaseSearchCV):
     def __init__(self, estimator, param_distributions, n_iter=10, scoring=None,
                  fit_params=None, n_jobs=1, iid=True, refit=True, cv=None,
                  verbose=0, pre_dispatch='2*n_jobs', random_state=None,
-                 error_score='raise'):
-
+                 error_score='raise', return_train_score=True):
         self.param_distributions = param_distributions
         self.n_iter = n_iter
         self.random_state = random_state
         super(RandomizedSearchCV, self).__init__(
-            estimator=estimator, scoring=scoring, fit_params=fit_params,
-            n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
-            pre_dispatch=pre_dispatch, error_score=error_score)
+            estimator=estimator, scoring=scoring, fit_params=fit_params,
+            n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
+            pre_dispatch=pre_dispatch, error_score=error_score,
+            return_train_score=return_train_score)
 
     def fit(self, X, y=None, groups=None):
         """Run fit on the estimator with randomly drawn parameters.
diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py
index d82a62707ea9a356c3b420fa59cbc6619080364b..9745cb9decf731b21ce874f29b0b4bf095ee8bf5 100644
--- a/sklearn/model_selection/_validation.py
+++ b/sklearn/model_selection/_validation.py
@@ -1,3 +1,4 @@
+
 """
 The :mod:`sklearn.model_selection._validation` module includes classes and
 functions to validate the model.
@@ -142,7 +143,8 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
 
 def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
                    parameters, fit_params, return_train_score=False,
-                   return_parameters=False, error_score='raise'):
+                   return_parameters=False, return_n_test_samples=False,
+                   return_times=False, error_score='raise'):
     """Fit estimator and compute scores for a given dataset split.
 
     Parameters
@@ -199,8 +201,11 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
     n_test_samples : int
         Number of test samples.
 
-    scoring_time : float
-        Time spent for fitting and scoring in seconds.
+    fit_time : float
+        Time spent for fitting in seconds.
+
+    score_time : float
+        Time spent for scoring in seconds.
 
     parameters : dict or None, optional
         The parameters that have been evaluated.
@@ -233,6 +238,9 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
             estimator.fit(X_train, y_train, **fit_params)
 
     except Exception as e:
+        # Record the fit time as the time elapsed until the error occurred
+        fit_time = time.time() - start_time
+        score_time = 0.0
         if error_score == 'raise':
             raise
         elif isinstance(error_score, numbers.Number):
@@ -248,20 +256,24 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
                              " make sure that it has been spelled correctly.)")
 
     else:
+        fit_time = time.time() - start_time
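+        # fit_time measures the time elapsed up to the end of fit();
+        # the scoring below is timed separately as score_time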
         test_score = _score(estimator, X_test, y_test, scorer)
+        score_time = time.time() - start_time - fit_time
         if return_train_score:
             train_score = _score(estimator, X_train, y_train, scorer)
 
-    scoring_time = time.time() - start_time
-
     if verbose > 2:
         msg += ", score=%f" % test_score
     if verbose > 1:
-        end_msg = "%s -%s" % (msg, logger.short_format_time(scoring_time))
+        total_time = fit_time + score_time
+        end_msg = "%s -%s" % (msg, logger.short_format_time(total_time))
         print("[CV] %s %s" % ((64 - len(end_msg)) * '.', end_msg))
 
-    ret = [train_score] if return_train_score else []
-    ret.extend([test_score, _num_samples(X_test), scoring_time])
+    ret = [train_score, test_score] if return_train_score else [test_score]
+
+    if return_n_test_samples:
+        ret.append(_num_samples(X_test))
+    if return_times:
+        ret.extend([fit_time, score_time])
     if return_parameters:
         ret.append(parameters)
     return ret
@@ -758,7 +770,7 @@ def learning_curve(estimator, X, y, groups=None,
             verbose, parameters=None, fit_params=None, return_train_score=True)
             for train, test in cv_iter
             for n_train_samples in train_sizes_abs)
-        out = np.array(out)[:, :2]
+        out = np.array(out)
         n_cv_folds = out.shape[0] // n_unique_ticks
         out = out.reshape(n_cv_folds, n_unique_ticks, 2)
 
@@ -941,7 +953,7 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
         parameters={param_name: v}, fit_params=None, return_train_score=True)
         for train, test in cv.split(X, y, groups) for v in param_range)
 
-    out = np.asarray(out)[:, :2]
+    out = np.asarray(out)
     n_params = len(param_range)
     n_cv_folds = out.shape[0] // n_params
     out = out.reshape(n_cv_folds, n_params, 2).transpose((2, 1, 0))
diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py
index bb21a386d35b73b001003d07cc170cbda1d3345e..fa4949d317052bb013545d2096ad19f1fe833908 100644
--- a/sklearn/model_selection/tests/test_search.py
+++ b/sklearn/model_selection/tests/test_search.py
@@ -595,33 +595,33 @@ def test_param_sampler():
         assert_equal([x for x in sampler], [x for x in sampler])
 
 
-def check_cv_results_array_types(results, param_keys, score_keys):
-    # Check if the search results' array are of correct types
-    assert_true(all(isinstance(results[param], np.ma.MaskedArray)
+def check_cv_results_array_types(cv_results, param_keys, score_keys):
+    # Check whether the arrays in `cv_results` have the correct types
+    assert_true(all(isinstance(cv_results[param], np.ma.MaskedArray)
                     for param in param_keys))
-    assert_true(all(results[key].dtype == object for key in param_keys))
-    assert_false(any(isinstance(results[key], np.ma.MaskedArray)
+    assert_true(all(cv_results[key].dtype == object for key in param_keys))
+    assert_false(any(isinstance(cv_results[key], np.ma.MaskedArray)
                      for key in score_keys))
-    assert_true(all(results[key].dtype == np.float64
-                    for key in score_keys if key != 'rank_test_score'))
-    assert_true(results['rank_test_score'].dtype == np.int32)
+    assert_true(all(cv_results[key].dtype == np.float64
+                    for key in score_keys if not key.startswith('rank')))
+    assert_true(cv_results['rank_test_score'].dtype == np.int32)
 
 
-def check_cv_results_keys(results, param_keys, score_keys, n_cand):
+def check_cv_results_keys(cv_results, param_keys, score_keys, n_cand):
     # Test the search.cv_results_ contains all the required results
-    assert_array_equal(sorted(results.keys()),
+    assert_array_equal(sorted(cv_results.keys()),
                        sorted(param_keys + score_keys + ('params',)))
-    assert_true(all(results[key].shape == (n_cand,)
+    assert_true(all(cv_results[key].shape == (n_cand,)
                     for key in param_keys + score_keys))
 
 
 def check_cv_results_grid_scores_consistency(search):
     # TODO Remove in 0.20
-    results = search.cv_results_
-    res_scores = np.vstack(list([results["split%d_test_score" % i]
+    cv_results = search.cv_results_
+    res_scores = np.vstack(list([cv_results["split%d_test_score" % i]
                                  for i in range(search.n_splits_)])).T
-    res_means = results["mean_test_score"]
-    res_params = results["params"]
+    res_means = cv_results["mean_test_score"]
+    res_params = cv_results["params"]
     n_cand = len(res_params)
     grid_scores = assert_warns(DeprecationWarning, getattr,
                                search, 'grid_scores_')
@@ -634,7 +634,7 @@ def check_cv_results_grid_scores_consistency(search):
         assert_array_equal(grid_scores[i].mean_validation_score, res_means[i])
 
 
-def test_grid_search_results():
+def test_grid_search_cv_results():
     X, y = make_classification(n_samples=50, n_features=4,
                                random_state=42)
 
@@ -650,34 +650,47 @@ def test_grid_search_results():
     grid_search_iid.fit(X, y)
 
     param_keys = ('param_C', 'param_degree', 'param_gamma', 'param_kernel')
-    score_keys = ('mean_test_score', 'rank_test_score',
+    score_keys = ('mean_test_score', 'mean_train_score',
+                  'rank_test_score',
                   'split0_test_score', 'split1_test_score',
-                  'split2_test_score', 'std_test_score')
+                  'split2_test_score',
+                  'split0_train_score', 'split1_train_score',
+                  'split2_train_score',
+                  'std_test_score', 'std_train_score',
+                  'mean_fit_time', 'std_fit_time',
+                  'mean_score_time', 'std_score_time')
     n_candidates = n_grid_points
 
     for search, iid in zip((grid_search, grid_search_iid), (False, True)):
         assert_equal(iid, search.iid)
-        results = search.cv_results_
-        # Check results structure
-        check_cv_results_array_types(results, param_keys, score_keys)
-        check_cv_results_keys(results, param_keys, score_keys, n_candidates)
+        cv_results = search.cv_results_
+        # Check if score and timing are reasonable
+        assert_true(all(cv_results['rank_test_score'] >= 1))
+        assert_true(all(np.all(cv_results[k] >= 0) for k in score_keys
+                        if k != 'rank_test_score'))
+        assert_true(all(np.all(cv_results[k] <= 1) for k in score_keys
+                        if 'time' not in k and k != 'rank_test_score'))
+        # Check cv_results structure
+        check_cv_results_array_types(cv_results, param_keys, score_keys)
+        check_cv_results_keys(cv_results, param_keys, score_keys, n_candidates)
         # Check masking
-        results = grid_search.cv_results_
+        cv_results = grid_search.cv_results_
         n_candidates = len(grid_search.cv_results_['params'])
-        assert_true(all((results['param_C'].mask[i] and
-                         results['param_gamma'].mask[i] and
-                         not results['param_degree'].mask[i])
+        assert_true(all((cv_results['param_C'].mask[i] and
+                         cv_results['param_gamma'].mask[i] and
+                         not cv_results['param_degree'].mask[i])
                         for i in range(n_candidates)
-                        if results['param_kernel'][i] == 'linear'))
-        assert_true(all((not results['param_C'].mask[i] and
-                         not results['param_gamma'].mask[i] and
-                         results['param_degree'].mask[i])
+                        if cv_results['param_kernel'][i] == 'linear'))
+        assert_true(all((not cv_results['param_C'].mask[i] and
+                         not cv_results['param_gamma'].mask[i] and
+                         cv_results['param_degree'].mask[i])
                         for i in range(n_candidates)
-                        if results['param_kernel'][i] == 'rbf'))
+                        if cv_results['param_kernel'][i] == 'rbf'))
         check_cv_results_grid_scores_consistency(search)
 
 
-def test_random_search_results():
+def test_random_search_cv_results():
     # Make a dataset with a lot of noise to get various kind of prediction
     # errors across CV folds and parameter settings
     X, y = make_classification(n_samples=200, n_features=100, n_informative=3,
@@ -690,8 +703,8 @@ def test_random_search_results():
     n_search_iter = 30
     params = dict(C=expon(scale=10), gamma=expon(scale=0.1))
     random_search = RandomizedSearchCV(SVC(), n_iter=n_search_iter,
-                                       cv=n_splits,
-                                       iid=False, param_distributions=params)
+                                       cv=n_splits, iid=False,
+                                       param_distributions=params)
     random_search.fit(X, y)
     random_search_iid = RandomizedSearchCV(SVC(), n_iter=n_search_iter,
                                            cv=n_splits, iid=True,
@@ -699,20 +712,26 @@ def test_random_search_results():
     random_search_iid.fit(X, y)
 
     param_keys = ('param_C', 'param_gamma')
-    score_keys = ('mean_test_score', 'rank_test_score',
+    score_keys = ('mean_test_score', 'mean_train_score',
+                  'rank_test_score',
                   'split0_test_score', 'split1_test_score',
-                  'split2_test_score', 'std_test_score')
+                  'split2_test_score',
+                  'split0_train_score', 'split1_train_score',
+                  'split2_train_score',
+                  'std_test_score', 'std_train_score',
+                  'mean_fit_time', 'std_fit_time',
+                  'mean_score_time', 'std_score_time')
     n_cand = n_search_iter
 
     for search, iid in zip((random_search, random_search_iid), (False, True)):
         assert_equal(iid, search.iid)
-        results = search.cv_results_
+        cv_results = search.cv_results_
         # Check results structure
-        check_cv_results_array_types(results, param_keys, score_keys)
-        check_cv_results_keys(results, param_keys, score_keys, n_cand)
+        check_cv_results_array_types(cv_results, param_keys, score_keys)
+        check_cv_results_keys(cv_results, param_keys, score_keys, n_cand)
         # For random_search, all the param array vals should be unmasked
-        assert_false(any(results['param_C'].mask) or
-                     any(results['param_gamma'].mask))
+        assert_false(any(cv_results['param_C'].mask) or
+                     any(cv_results['param_gamma'].mask))
         check_cv_results_grid_scores_consistency(search)
 
 
@@ -739,22 +758,39 @@ def test_search_iid_param():
         search.fit(X, y)
         assert_true(search.iid)
 
-        # Test the first candidate
-        cv_scores = np.array(list(search.cv_results_['split%d_test_score'
-                                                     % s][0]
-                                  for s in range(search.n_splits_)))
-        mean = search.cv_results_['mean_test_score'][0]
-        std = search.cv_results_['std_test_score'][0]
+        test_cv_scores = np.array(list(search.cv_results_['split%d_test_score'
+                                                          % s_i][0]
+                                       for s_i in range(search.n_splits_)))
+        test_mean = search.cv_results_['mean_test_score'][0]
+        test_std = search.cv_results_['std_test_score'][0]
+
+        train_cv_scores = np.array(list(search.cv_results_['split%d_train_'
+                                                           'score' % s_i][0]
+                                        for s_i in range(search.n_splits_)))
+        train_mean = search.cv_results_['mean_train_score'][0]
+        train_std = search.cv_results_['std_train_score'][0]
 
+        # Test the first candidate
         assert_equal(search.cv_results_['param_C'][0], 1)
-        assert_array_almost_equal(cv_scores, [1, 1. / 3.])
+        assert_array_almost_equal(test_cv_scores, [1, 1. / 3.])
+        assert_array_almost_equal(train_cv_scores, [1, 1])
+
         # for first split, 1/4 of dataset is in test, for second 3/4.
         # take weighted average and weighted std
-        expected_mean = 1 * 1. / 4. + 1. / 3. * 3. / 4.
-        expected_std = np.sqrt(1. / 4 * (expected_mean - 1) ** 2 +
-                               3. / 4 * (expected_mean - 1. / 3.) ** 2)
-        assert_almost_equal(mean, expected_mean)
-        assert_almost_equal(std, expected_std)
+        expected_test_mean = 1 * 1. / 4. + 1. / 3. * 3. / 4.
+        expected_test_std = np.sqrt(1. / 4 * (expected_test_mean - 1) ** 2 +
+                                    3. / 4 *
+                                    (expected_test_mean - 1. / 3.) ** 2)
+        assert_almost_equal(test_mean, expected_test_mean)
+        assert_almost_equal(test_std, expected_test_std)
+
+        # The mean/std of the train scores are always unweighted, regardless
+        # of the iid setting
+        assert_almost_equal(train_mean, 1)
+        assert_almost_equal(train_std, 0)
 
     # once with iid=False
     grid_search = GridSearchCV(SVC(),
@@ -768,17 +804,29 @@ def test_search_iid_param():
         search.fit(X, y)
         assert_false(search.iid)
 
-        cv_scores = np.array(list(search.cv_results_['split%d_test_score'
-                                                     % s][0]
-                                  for s in range(search.n_splits_)))
-        mean = search.cv_results_['mean_test_score'][0]
-        std = search.cv_results_['std_test_score'][0]
+        test_cv_scores = np.array(list(search.cv_results_['split%d_test_score'
+                                                          % s][0]
+                                       for s in range(search.n_splits_)))
+        test_mean = search.cv_results_['mean_test_score'][0]
+        test_std = search.cv_results_['std_test_score'][0]
+
+        train_cv_scores = np.array(list(search.cv_results_['split%d_train_'
+                                                           'score' % s][0]
+                                        for s in range(search.n_splits_)))
+        train_mean = search.cv_results_['mean_train_score'][0]
+        train_std = search.cv_results_['std_train_score'][0]
+
         assert_equal(search.cv_results_['param_C'][0], 1)
         # scores are the same as above
-        assert_array_almost_equal(cv_scores, [1, 1. / 3.])
+        assert_array_almost_equal(test_cv_scores, [1, 1. / 3.])
         # Unweighted mean/std is used
-        assert_almost_equal(mean, np.mean(cv_scores))
-        assert_almost_equal(std, np.std(cv_scores))
+        assert_almost_equal(test_mean, np.mean(test_cv_scores))
+        assert_almost_equal(test_std, np.std(test_cv_scores))
+
+        # The mean/std of the train scores are always unweighted, regardless
+        # of the iid setting
+        assert_almost_equal(train_mean, 1)
+        assert_almost_equal(train_std, 0)
 
 
 def test_search_cv_results_rank_tie_breaking():
@@ -794,15 +842,22 @@ def test_search_cv_results_rank_tie_breaking():
 
     for search in (grid_search, random_search):
         search.fit(X, y)
-        results = search.cv_results_
+        cv_results = search.cv_results_
         # Check tie breaking strategy -
         # Check that there is a tie in the mean scores between
         # candidates 1 and 2 alone
-        assert_almost_equal(results['mean_test_score'][0],
-                            results['mean_test_score'][1])
+        assert_almost_equal(cv_results['mean_test_score'][0],
+                            cv_results['mean_test_score'][1])
+        assert_almost_equal(cv_results['mean_train_score'][0],
+                            cv_results['mean_train_score'][1])
+        try:
+            assert_almost_equal(cv_results['mean_test_score'][1],
+                                cv_results['mean_test_score'][2])
+        except AssertionError:
+            pass
         try:
-            assert_almost_equal(results['mean_test_score'][1],
-                                results['mean_test_score'][2])
+            assert_almost_equal(cv_results['mean_train_score'][1],
+                                cv_results['mean_train_score'][2])
         except AssertionError:
             pass
         # 'min' rank should be assigned to the tied candidates
@@ -821,6 +876,30 @@ def test_search_cv_results_none_param():
                            [0, None])
 
 
+@ignore_warnings()
+def test_search_cv_timing():
+    svc = LinearSVC(random_state=0)
+
+    X = [[1, ], [2, ], [3, ], [4, ]]
+    y = [0, 1, 1, 0]
+
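+    # 'C': 0 is invalid for LinearSVC, so that candidate takes the
+    # error_score path and its score_time is recorded as 0.0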
+    gs = GridSearchCV(svc, {'C': [0, 1]}, cv=2, error_score=0)
+    rs = RandomizedSearchCV(svc, {'C': [0, 1]}, cv=2, error_score=0, n_iter=2)
+
+    for search in (gs, rs):
+        search.fit(X, y)
+        for key in ['mean_fit_time', 'std_fit_time']:
+            # NOTE The precision of time.time on Windows is not high
+            # enough for the fit/score times to be non-zero for trivial X and y
+            assert_true(np.all(search.cv_results_[key] >= 0))
+            assert_true(np.all(search.cv_results_[key] < 1))
+
+        for key in ['mean_score_time', 'std_score_time']:
+            assert_true(search.cv_results_[key][1] >= 0)
+            assert_true(search.cv_results_[key][0] == 0.0)
+            assert_true(np.all(search.cv_results_[key] < 1))
+
+
 def test_grid_search_correct_score_results():
     # test that correct scores are used
     n_splits = 3
@@ -829,10 +908,10 @@ def test_grid_search_correct_score_results():
     Cs = [.1, 1, 10]
     for score in ['f1', 'roc_auc']:
         grid_search = GridSearchCV(clf, {'C': Cs}, scoring=score, cv=n_splits)
-        results = grid_search.fit(X, y).cv_results_
+        cv_results = grid_search.fit(X, y).cv_results_
 
         # Test scorer names
-        result_keys = list(results.keys())
+        result_keys = list(cv_results.keys())
         expected_keys = (("mean_test_score", "rank_test_score") +
                          tuple("split%d_test_score" % cv_i
                                for cv_i in range(n_splits)))
@@ -1039,8 +1118,8 @@ def test_stochastic_gradient_loss_param():
     param_grid = {
         'loss': ['log'],
     }
-    X = np.arange(20).reshape(5, -1)
-    y = [0, 0, 1, 1, 1]
+    X = np.arange(24).reshape(6, -1)
+    y = [0, 0, 0, 1, 1, 1]
     clf = GridSearchCV(estimator=SGDClassifier(loss='hinge'),
                        param_grid=param_grid)