diff --git a/doc/tutorial/text_analytics/solutions/exercise_02_sentiment.py b/doc/tutorial/text_analytics/solutions/exercise_02_sentiment.py
index eab418fd0d8baf415fc2bcb6275591709febb45f..9f747694064ac5584853965277a4d3b8bf44f1ae 100644
--- a/doc/tutorial/text_analytics/solutions/exercise_02_sentiment.py
+++ b/doc/tutorial/text_analytics/solutions/exercise_02_sentiment.py
@@ -55,12 +55,12 @@ if __name__ == "__main__":
 
     # TASK: print the mean and std for each candidate along with the parameter
     # settings for all the candidates explored by grid search.
-    n_candidates = len(grid_search.results_['params'])
+    n_candidates = len(grid_search.cv_results_['params'])
     for i in range(n_candidates):
         print(i, 'params - %s; mean - %0.2f; std - %0.2f'
-                 % (grid_search.results_['params'][i],
-                    grid_search.results_['test_mean_score'][i],
-                    grid_search.results_['test_std_score'][i]))
+                 % (grid_search.cv_results_['params'][i],
+                    grid_search.cv_results_['mean_test_score'][i],
+                    grid_search.cv_results_['std_test_score'][i]))
 
     # TASK: Predict the outcome on the testing set and store it in a variable
     # named y_predicted
diff --git a/doc/tutorial/text_analytics/working_with_text_data.rst b/doc/tutorial/text_analytics/working_with_text_data.rst
index e1aa2d50fca4e7c680b9128b3ff97704fc6f0d54..9248e9adbbcdff528a482108d9dcef8936d6c670 100644
--- a/doc/tutorial/text_analytics/working_with_text_data.rst
+++ b/doc/tutorial/text_analytics/working_with_text_data.rst
@@ -458,9 +458,9 @@ mean score and the parameters setting corresponding to that score::
   tfidf__use_idf: True
   vect__ngram_range: (1, 1)
 
-A more detailed summary of the search is available at ``gs_clf.results_``.
+A more detailed summary of the search is available at ``gs_clf.cv_results_``.
 
-The ``results_`` parameter can be easily imported into pandas as a
+The ``cv_results_`` attribute can be easily imported into pandas as a
 ``DataFrame`` for further inspection.
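+
+A minimal sketch of such an inspection (this assumes ``pandas`` is installed
+and ``gs_clf`` has been fitted as shown above)::
+
+    >>> import pandas as pd                         # doctest: +SKIP
+    >>> pd.DataFrame(gs_clf.cv_results_).head()     # doctest: +SKIP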
 
 .. note::
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index b5611d07b1a0e3f197a044a6130c8bafa2cfac0e..3da1dab97385259048329c2ec9b5389f17b3a9b0 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -39,27 +39,27 @@ Model Selection Enhancements and API Changes
     :class:`model_selection.GridSearchCV` and
     :class:`model_selection.RandomizedSearchCV` utilities.
 
-  - **The enhanced `results_` attribute**
+  - **The enhanced ``cv_results_`` attribute**
 
-    The new ``results_`` attribute (of :class:`model_selection.GridSearchCV`
+    The new ``cv_results_`` attribute (of :class:`model_selection.GridSearchCV`
     and :class:`model_selection.RandomizedSearchCV`), introduced in lieu of the
     ``grid_scores_`` attribute, is a dict of 1D arrays with elements in each
     array corresponding to the parameter settings (i.e. search candidates).
 
-    The ``results_`` dict can be easily imported into ``pandas`` as a
+    The ``cv_results_`` dict can be easily imported into ``pandas`` as a
     ``DataFrame`` for exploring the search results.
 
-    The ``results_`` arrays include scores for each cross-validation split
-    (with keys such as ``test_split0_score``), as well as their mean
-    (``test_mean_score``) and standard deviation (``test_std_score``).
+    The ``cv_results_`` arrays include scores for each cross-validation split
+    (with keys such as ``'split0_test_score'``), as well as their mean
+    (``'mean_test_score'``) and standard deviation (``'std_test_score'``).
 
     The ranks for the search candidates (based on their mean
-    cross-validation score) is available at ``results_['test_rank_score']``.
+    cross-validation score) are available at ``cv_results_['rank_test_score']``.
 
     The parameter values for each parameter are stored separately as numpy
     masked object arrays. The value, for that search candidate, is masked if
     the corresponding parameter is not applicable. Additionally a list of all
-    the parameter dicts are stored at ``results_['params']``.
+    the parameter dicts is stored at ``cv_results_['params']``.
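+
+    A minimal sketch of accessing these arrays (assuming ``search`` is an
+    already-fitted :class:`model_selection.GridSearchCV` instance)::
+
+        >>> search.cv_results_['mean_test_score']               # doctest: +SKIP
+        >>> search.cv_results_['rank_test_score']               # doctest: +SKIP
+        >>> search.cv_results_['params'][search.best_index_]    # doctest: +SKIP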
 
   - **Parameters ``n_folds`` and ``n_iter`` renamed to ``n_splits``**
 
@@ -235,7 +235,7 @@ Enhancements
    - The :func:`ignore_warnings` now accepts a category argument to ignore only
      the warnings of a specified type. By `Thierry Guillemot`_.
 
-   - The new ``results_`` attribute of :class:`model_selection.GridSearchCV`
+   - The new ``cv_results_`` attribute of :class:`model_selection.GridSearchCV`
      (and :class:`model_selection.RandomizedSearchCV`) can be easily imported
      into pandas as a ``DataFrame``. Ref :ref:`model_selection_changes` for
      more information.
@@ -419,7 +419,7 @@ API changes summary
 
    - The ``grid_scores_`` attribute of :class:`model_selection.GridSearchCV`
      and :class:`model_selection.RandomizedSearchCV` is deprecated in favor of
-     the attribute ``results_``.
+     the attribute ``cv_results_``.
      Ref :ref:`model_selection_changes` for more information.
      (`#6697 <https://github.com/scikit-learn/scikit-learn/pull/6697>`_) by
      `Raghav R V`_.
diff --git a/examples/model_selection/grid_search_digits.py b/examples/model_selection/grid_search_digits.py
index 13755b0bc8c105ac1dd5f636dee405bb8ba523b1..50ba4dc097f8b1cfa887ac9103689f82bcf19307 100644
--- a/examples/model_selection/grid_search_digits.py
+++ b/examples/model_selection/grid_search_digits.py
@@ -60,11 +60,11 @@ for score in scores:
     print()
     print("Grid scores on development set:")
     print()
-    means = clf.results_['test_mean_score']
-    stds = clf.results_['test_std_score']
-    for i in range(len(clf.results_['params'])):
+    means = clf.cv_results_['mean_test_score']
+    stds = clf.cv_results_['std_test_score']
+    for mean, std, params in zip(means, stds, clf.cv_results_['params']):
         print("%0.3f (+/-%0.03f) for %r"
-              % (means[i], stds[i] * 2, clf.results_['params'][i]))
+              % (mean, std * 2, params))
     print()
 
     print("Detailed classification report:")
diff --git a/examples/model_selection/randomized_search.py b/examples/model_selection/randomized_search.py
index e1f7c215ab65359f8f2af1b5105cfda239ba5790..1024a01239d012357c8aec5339e427e40237e40a 100644
--- a/examples/model_selection/randomized_search.py
+++ b/examples/model_selection/randomized_search.py
@@ -41,12 +41,12 @@ clf = RandomForestClassifier(n_estimators=20)
 # Utility function to report best scores
 def report(results, n_top=3):
     for i in range(1, n_top + 1):
-        candidates = np.flatnonzero(results['test_rank_score'] == i)
+        candidates = np.flatnonzero(results['rank_test_score'] == i)
         for candidate in candidates:
             print("Model with rank: {0}".format(i))
             print("Mean validation score: {0:.3f} (std: {1:.3f})".format(
-                  results['test_mean_score'][candidate],
-                  results['test_std_score'][candidate]))
+                  results['mean_test_score'][candidate],
+                  results['std_test_score'][candidate]))
             print("Parameters: {0}".format(results['params'][candidate]))
             print("")
 
@@ -68,7 +68,7 @@ start = time()
 random_search.fit(X, y)
 print("RandomizedSearchCV took %.2f seconds for %d candidates"
       " parameter settings." % ((time() - start), n_iter_search))
-report(random_search.results_)
+report(random_search.cv_results_)
 
 # use a full grid over all parameters
 param_grid = {"max_depth": [3, None],
@@ -84,5 +84,5 @@ start = time()
 grid_search.fit(X, y)
 
 print("GridSearchCV took %.2f seconds for %d candidate parameter settings."
-      % (time() - start, len(grid_search.results_['params'])))
-report(grid_search.results_)
+      % (time() - start, len(grid_search.cv_results_['params'])))
+report(grid_search.cv_results_)
diff --git a/examples/plot_compare_reduction.py b/examples/plot_compare_reduction.py
index 0f578722ee68cc42db21aa4c43a2ef6361d235bf..1c84ea9c3a4dc907534d79dc49ffe270a4a26372 100644
--- a/examples/plot_compare_reduction.py
+++ b/examples/plot_compare_reduction.py
@@ -53,7 +53,7 @@ grid = GridSearchCV(pipe, cv=3, n_jobs=2, param_grid=param_grid)
 digits = load_digits()
 grid.fit(digits.data, digits.target)
 
-mean_scores = np.array(grid.results_['test_mean_score'])
+mean_scores = np.array(grid.cv_results_['mean_test_score'])
 # scores are in the order of param_grid iteration, which is alphabetical
 mean_scores = mean_scores.reshape(len(C_OPTIONS), -1, len(N_FEATURES_OPTIONS))
 # select score for best C
diff --git a/examples/svm/plot_rbf_parameters.py b/examples/svm/plot_rbf_parameters.py
index b71d6b22dc7c4e9a390ff25fb7c41e6259ee62e9..acec9896169b8bd014cdbb4363b29d6d6c8a4781 100644
--- a/examples/svm/plot_rbf_parameters.py
+++ b/examples/svm/plot_rbf_parameters.py
@@ -171,8 +171,8 @@ for (k, (C, gamma, clf)) in enumerate(classifiers):
     plt.yticks(())
     plt.axis('tight')
 
-scores = grid.results_['test_mean_score'].reshape(len(C_range),
-                                                  len(gamma_range))
+scores = grid.cv_results_['mean_test_score'].reshape(len(C_range),
+                                                     len(gamma_range))
 
 # Draw heatmap of the validation accuracy as a function of gamma and C
 #
diff --git a/examples/svm/plot_svm_scale_c.py b/examples/svm/plot_svm_scale_c.py
index 09934c2f5d8592f172c2d3bacea760374598be2a..5d72ca61d5157c1d028f7484e2fb55da3ae68354 100644
--- a/examples/svm/plot_svm_scale_c.py
+++ b/examples/svm/plot_svm_scale_c.py
@@ -131,7 +131,7 @@ for fignum, (clf, cs, X, y) in enumerate(clf_sets):
                             cv=ShuffleSplit(train_size=train_size,
                                             n_splits=250, random_state=1))
         grid.fit(X, y)
-        scores = grid.results_['test_mean_score']
+        scores = grid.cv_results_['mean_test_score']
 
         scales = [(1, 'No scaling'),
                   ((n_samples * train_size), '1/n_samples'),
diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index be3509770141ba0369669177bc0c6de3a943e462..435cabc68d49a6e6f8c9671aa932db42646985c6 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -573,17 +573,18 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         stds = np.sqrt(np.average((test_scores - means[:, np.newaxis]) ** 2,
                                   axis=1, weights=weights))
 
-        results = dict()
+        cv_results = dict()
         for split_i in range(n_splits):
-            results["test_split%d_score" % split_i] = test_scores[:, split_i]
-        results["test_mean_score"] = means
-        results["test_std_score"] = stds
+            cv_results["split%d_test_score" % split_i] = test_scores[:,
+                                                                     split_i]
+        cv_results["mean_test_score"] = means
+        cv_results["std_test_score"] = stds
 
         ranks = np.asarray(rankdata(-means, method='min'), dtype=np.int32)
 
         best_index = np.flatnonzero(ranks == 1)[0]
         best_parameters = candidate_params[best_index]
-        results["test_rank_score"] = ranks
+        cv_results["rank_test_score"] = ranks
 
         # Use one np.MaskedArray and mask all the places where the param is not
         # applicable for that candidate. Use defaultdict as each candidate may
@@ -597,12 +598,12 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
                 # Setting the value at an index also unmasks that index
                 param_results["param_%s" % name][cand_i] = value
 
-        results.update(param_results)
+        cv_results.update(param_results)
 
         # Store a list of param dicts at the key 'params'
-        results['params'] = candidate_params
+        cv_results['params'] = candidate_params
 
-        self.results_ = results
+        self.cv_results_ = cv_results
         self.best_index_ = best_index
         self.n_splits_ = n_splits
 
@@ -620,30 +621,31 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
 
     @property
     def best_params_(self):
-        check_is_fitted(self, 'results_')
-        return self.results_['params'][self.best_index_]
+        check_is_fitted(self, 'cv_results_')
+        return self.cv_results_['params'][self.best_index_]
 
     @property
     def best_score_(self):
-        check_is_fitted(self, 'results_')
-        return self.results_['test_mean_score'][self.best_index_]
+        check_is_fitted(self, 'cv_results_')
+        return self.cv_results_['mean_test_score'][self.best_index_]
 
     @property
     def grid_scores_(self):
         warnings.warn(
             "The grid_scores_ attribute was deprecated in version 0.18"
-            " in favor of the more elaborate results_ attribute."
+            " in favor of the more elaborate cv_results_ attribute."
             " The grid_scores_ attribute will not be available from 0.20",
             DeprecationWarning)
 
-        check_is_fitted(self, 'results_')
+        check_is_fitted(self, 'cv_results_')
         grid_scores = list()
 
         for i, (params, mean, std) in enumerate(zip(
-                self.results_['params'],
-                self.results_['test_mean_score'],
-                self.results_['test_std_score'])):
-            scores = np.array(list(self.results_['test_split%d_score' % s][i]
+                self.cv_results_['params'],
+                self.cv_results_['mean_test_score'],
+                self.cv_results_['std_test_score'])):
+            scores = np.array(list(self.cv_results_['split%d_test_score'
+                                                    % s][i]
                                    for s in range(self.n_splits_)),
                               dtype=np.float64)
             grid_scores.append(_CVScoreTuple(params, mean, scores))
@@ -763,22 +765,22 @@ class GridSearchCV(BaseSearchCV):
            fit_params={}, iid=..., n_jobs=1,
            param_grid=..., pre_dispatch=..., refit=...,
            scoring=..., verbose=...)
-    >>> sorted(clf.results_.keys())
+    >>> sorted(clf.cv_results_.keys())
     ...                             # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
-    ['param_C', 'param_kernel', 'params', 'test_mean_score',...
-     'test_rank_score', 'test_split0_score', 'test_split1_score',...
-     'test_split2_score', 'test_std_score']
+    ['mean_test_score', 'param_C', 'param_kernel', 'params',...
+     'rank_test_score', 'split0_test_score', 'split1_test_score',...
+     'split2_test_score', 'std_test_score']
 
     Attributes
     ----------
-    results_ : dict of numpy (masked) ndarrays
+    cv_results_ : dict of numpy (masked) ndarrays
         A dict with keys as column headers and values as columns, that can be
         imported into a pandas ``DataFrame``.
 
         For instance, the table below
 
         +------------+-----------+------------+-----------------+---+---------+
-        |param_kernel|param_gamma|param_degree|test_split0_score|...|...rank..|
+        |param_kernel|param_gamma|param_degree|split0_test_score|...|rank_....|
         +============+===========+============+=================+===+=========+
         |  'poly'    |     --    |      2     |        0.8      |...|    2    |
         +------------+-----------+------------+-----------------+---+---------+
@@ -789,7 +791,7 @@ class GridSearchCV(BaseSearchCV):
         |  'rbf'     |     0.2   |     --     |        0.9      |...|    1    |
         +------------+-----------+------------+-----------------+---+---------+
 
-        will be represented by a ``results_`` dict of::
+        will be represented by a ``cv_results_`` dict of::
 
             {
             'param_kernel': masked_array(data = ['poly', 'poly', 'rbf', 'rbf'],
@@ -798,11 +800,11 @@ class GridSearchCV(BaseSearchCV):
                                         mask = [ True  True False False]...),
             'param_degree': masked_array(data = [2.0 3.0 -- --],
                                          mask = [False False  True  True]...),
-            'test_split0_score' : [0.8, 0.7, 0.8, 0.9],
-            'test_split1_score' : [0.82, 0.5, 0.7, 0.78],
-            'test_mean_score'   : [0.81, 0.60, 0.75, 0.82],
-            'test_std_score'    : [0.02, 0.01, 0.03, 0.03],
-            'test_rank_score'   : [2, 4, 3, 1],
+            'split0_test_score' : [0.8, 0.7, 0.8, 0.9],
+            'split1_test_score' : [0.82, 0.5, 0.7, 0.78],
+            'mean_test_score'   : [0.81, 0.60, 0.75, 0.82],
+            'std_test_score'    : [0.02, 0.01, 0.03, 0.03],
+            'rank_test_score'   : [2, 4, 3, 1],
             'params'            : [{'kernel': 'poly', 'degree': 2}, ...],
             }
 
@@ -821,10 +823,10 @@ class GridSearchCV(BaseSearchCV):
         Parameter setting that gave the best results on the hold out data.
 
     best_index_ : int
-        The index (of the ``results_`` arrays) which corresponds to the best
+        The index (of the ``cv_results_`` arrays) which corresponds to the best
         candidate parameter setting.
 
-        The dict at ``search.results_['params'][search.best_index_]`` gives
+        The dict at ``search.cv_results_['params'][search.best_index_]`` gives
         the parameter setting for the best model, that gives the highest
         mean score (``search.best_score_``).
 
@@ -1005,14 +1007,14 @@ class RandomizedSearchCV(BaseSearchCV):
 
     Attributes
     ----------
-    results_ : dict of numpy (masked) ndarrays
+    cv_results_ : dict of numpy (masked) ndarrays
         A dict with keys as column headers and values as columns, that can be
         imported into a pandas ``DataFrame``.
 
         For instance, the table below
 
         +--------------+-------------+-------------------+---+---------------+
-        | param_kernel | param_gamma | test_split0_score |...|test_rank_score|
+        | param_kernel | param_gamma | split0_test_score |...|rank_test_score|
         +==============+=============+===================+===+===============+
         |    'rbf'     |     0.1     |        0.8        |...|       2       |
         +--------------+-------------+-------------------+---+---------------+
@@ -1021,17 +1023,17 @@ class RandomizedSearchCV(BaseSearchCV):
         |    'rbf'     |     0.3     |        0.7        |...|       1       |
         +--------------+-------------+-------------------+---+---------------+
 
-        will be represented by a ``results_`` dict of::
+        will be represented by a ``cv_results_`` dict of::
 
             {
             'param_kernel' : masked_array(data = ['rbf', 'rbf', 'rbf'],
                                           mask = False),
             'param_gamma'  : masked_array(data = [0.1 0.2 0.3], mask = False),
-            'test_split0_score' : [0.8, 0.9, 0.7],
-            'test_split1_score' : [0.82, 0.5, 0.7],
-            'test_mean_score'   : [0.81, 0.7, 0.7],
-            'test_std_score'    : [0.02, 0.2, 0.],
-            'test_rank_score'   : [3, 1, 1],
+            'split0_test_score' : [0.8, 0.9, 0.7],
+            'split1_test_score' : [0.82, 0.5, 0.7],
+            'mean_test_score'   : [0.81, 0.7, 0.7],
+            'std_test_score'    : [0.02, 0.2, 0.],
+            'rank_test_score'   : [3, 1, 1],
             'params' : [{'kernel' : 'rbf', 'gamma' : 0.1}, ...],
             }
 
@@ -1050,10 +1052,10 @@ class RandomizedSearchCV(BaseSearchCV):
         Parameter setting that gave the best results on the hold out data.
 
     best_index_ : int
-        The index (of the ``results_`` arrays) which corresponds to the best
+        The index (of the ``cv_results_`` arrays) which corresponds to the best
         candidate parameter setting.
 
-        The dict at ``search.results_['params'][search.best_index_]`` gives
+        The dict at ``search.cv_results_['params'][search.best_index_]`` gives
         the parameter setting for the best model, that gives the highest
         mean score (``search.best_score_``).
 
diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py
index b5abd8d873a51bbcb58e86065537090203a520a8..65fc1964ddf315ae19b4865ae928285c04648714 100644
--- a/sklearn/model_selection/tests/test_search.py
+++ b/sklearn/model_selection/tests/test_search.py
@@ -155,7 +155,8 @@ def test_grid_search():
     sys.stdout = old_stdout
     assert_equal(grid_search.best_estimator_.foo_param, 2)
 
-    assert_array_equal(grid_search.results_["param_foo_param"].data, [1, 2, 3])
+    assert_array_equal(grid_search.cv_results_["param_foo_param"].data,
+                       [1, 2, 3])
 
     # Smoke test the score etc:
     grid_search.score(X, y)
@@ -265,17 +266,17 @@ def test_grid_search_labels():
         gs.fit(X, y)
 
 
-def test_trivial_results_attr():
+def test_trivial_cv_results_attr():
     # Test search over a "grid" with only one point.
     # Non-regression test: grid_scores_ wouldn't be set by GridSearchCV.
     clf = MockClassifier()
     grid_search = GridSearchCV(clf, {'foo_param': [1]})
     grid_search.fit(X, y)
-    assert_true(hasattr(grid_search, "results_"))
+    assert_true(hasattr(grid_search, "cv_results_"))
 
     random_search = RandomizedSearchCV(clf, {'foo_param': [0]}, n_iter=1)
     random_search.fit(X, y)
-    assert_true(hasattr(grid_search, "results_"))
+    assert_true(hasattr(grid_search, "cv_results_"))
 
 
 def test_no_refit():
@@ -472,7 +473,7 @@ def test_gridsearch_nd():
     clf = CheckingClassifier(check_X=check_X, check_y=check_y)
     grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]})
     grid_search.fit(X_4d, y_3d).score(X, y)
-    assert_true(hasattr(grid_search, "results_"))
+    assert_true(hasattr(grid_search, "cv_results_"))
 
 
 def test_X_as_list():
@@ -484,7 +485,7 @@ def test_X_as_list():
     cv = KFold(n_splits=3)
     grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, cv=cv)
     grid_search.fit(X.tolist(), y).score(X, y)
-    assert_true(hasattr(grid_search, "results_"))
+    assert_true(hasattr(grid_search, "cv_results_"))
 
 
 def test_y_as_list():
@@ -496,7 +497,7 @@ def test_y_as_list():
     cv = KFold(n_splits=3)
     grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, cv=cv)
     grid_search.fit(X, y.tolist()).score(X, y)
-    assert_true(hasattr(grid_search, "results_"))
+    assert_true(hasattr(grid_search, "cv_results_"))
 
 
 @ignore_warnings
@@ -522,7 +523,7 @@ def test_pandas_input():
         grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]})
         grid_search.fit(X_df, y_ser).score(X_df, y_ser)
         grid_search.predict(X_df)
-        assert_true(hasattr(grid_search, "results_"))
+        assert_true(hasattr(grid_search, "cv_results_"))
 
 
 def test_unsupervised_grid_search():
@@ -581,7 +582,7 @@ def test_param_sampler():
         assert_equal([x for x in sampler], [x for x in sampler])
 
 
-def check_results_array_types(results, param_keys, score_keys):
+def check_cv_results_array_types(results, param_keys, score_keys):
     # Check if the search results' array are of correct types
     assert_true(all(isinstance(results[param], np.ma.MaskedArray)
                     for param in param_keys))
@@ -589,24 +590,24 @@ def check_results_array_types(results, param_keys, score_keys):
     assert_false(any(isinstance(results[key], np.ma.MaskedArray)
                      for key in score_keys))
     assert_true(all(results[key].dtype == np.float64
-                    for key in score_keys if key != 'test_rank_score'))
-    assert_true(results['test_rank_score'].dtype == np.int32)
+                    for key in score_keys if key != 'rank_test_score'))
+    assert_true(results['rank_test_score'].dtype == np.int32)
 
 
-def check_results_keys(results, param_keys, score_keys, n_cand):
-    # Test the search.results_ contains all the required results
+def check_cv_results_keys(results, param_keys, score_keys, n_cand):
+    # Test that search.cv_results_ contains all the required results
     assert_array_equal(sorted(results.keys()),
                        sorted(param_keys + score_keys + ('params',)))
     assert_true(all(results[key].shape == (n_cand,)
                     for key in param_keys + score_keys))
 
 
-def check_results_grid_scores_consistency(search):
+def check_cv_results_grid_scores_consistency(search):
     # TODO Remove in 0.20
-    results = search.results_
-    res_scores = np.vstack(list([results["test_split%d_score" % i]
+    results = search.cv_results_
+    res_scores = np.vstack(list([results["split%d_test_score" % i]
                                  for i in range(search.n_splits_)])).T
-    res_means = results["test_mean_score"]
+    res_means = results["mean_test_score"]
     res_params = results["params"]
     n_cand = len(res_params)
     grid_scores = assert_warns(DeprecationWarning, getattr,
@@ -636,20 +637,20 @@ def test_grid_search_results():
     grid_search_iid.fit(X, y)
 
     param_keys = ('param_C', 'param_degree', 'param_gamma', 'param_kernel')
-    score_keys = ('test_mean_score', 'test_rank_score',
-                  'test_split0_score', 'test_split1_score',
-                  'test_split2_score', 'test_std_score')
+    score_keys = ('mean_test_score', 'rank_test_score',
+                  'split0_test_score', 'split1_test_score',
+                  'split2_test_score', 'std_test_score')
     n_candidates = n_grid_points
 
     for search, iid in zip((grid_search, grid_search_iid), (False, True)):
         assert_equal(iid, search.iid)
-        results = search.results_
+        results = search.cv_results_
         # Check results structure
-        check_results_array_types(results, param_keys, score_keys)
-        check_results_keys(results, param_keys, score_keys, n_candidates)
+        check_cv_results_array_types(results, param_keys, score_keys)
+        check_cv_results_keys(results, param_keys, score_keys, n_candidates)
         # Check masking
-        results = grid_search.results_
-        n_candidates = len(grid_search.results_['params'])
+        results = grid_search.cv_results_
+        n_candidates = len(grid_search.cv_results_['params'])
         assert_true(all((results['param_C'].mask[i] and
                          results['param_gamma'].mask[i] and
                          not results['param_degree'].mask[i])
@@ -660,7 +661,7 @@ def test_grid_search_results():
                          results['param_degree'].mask[i])
                         for i in range(n_candidates)
                         if results['param_kernel'][i] == 'rbf'))
-        check_results_grid_scores_consistency(search)
+        check_cv_results_grid_scores_consistency(search)
 
 
 def test_random_search_results():
@@ -685,21 +686,21 @@ def test_random_search_results():
     random_search_iid.fit(X, y)
 
     param_keys = ('param_C', 'param_gamma')
-    score_keys = ('test_mean_score', 'test_rank_score',
-                  'test_split0_score', 'test_split1_score',
-                  'test_split2_score', 'test_std_score')
+    score_keys = ('mean_test_score', 'rank_test_score',
+                  'split0_test_score', 'split1_test_score',
+                  'split2_test_score', 'std_test_score')
     n_cand = n_search_iter
 
     for search, iid in zip((random_search, random_search_iid), (False, True)):
         assert_equal(iid, search.iid)
-        results = search.results_
+        results = search.cv_results_
         # Check results structure
-        check_results_array_types(results, param_keys, score_keys)
-        check_results_keys(results, param_keys, score_keys, n_cand)
+        check_cv_results_array_types(results, param_keys, score_keys)
+        check_cv_results_keys(results, param_keys, score_keys, n_cand)
         # For random_search, all the param array vals should be unmasked
         assert_false(any(results['param_C'].mask) or
                      any(results['param_gamma'].mask))
-        check_results_grid_scores_consistency(search)
+        check_cv_results_grid_scores_consistency(search)
 
 
 def test_search_iid_param():
@@ -726,12 +727,13 @@ def test_search_iid_param():
         assert_true(search.iid)
 
         # Test the first candidate
-        cv_scores = np.array(list(search.results_['test_split%d_score' % s][0]
+        cv_scores = np.array(list(search.cv_results_['split%d_test_score'
+                                                     % s][0]
                                   for s in range(search.n_splits_)))
-        mean = search.results_['test_mean_score'][0]
-        std = search.results_['test_std_score'][0]
+        mean = search.cv_results_['mean_test_score'][0]
+        std = search.cv_results_['std_test_score'][0]
 
-        assert_equal(search.results_['param_C'][0], 1)
+        assert_equal(search.cv_results_['param_C'][0], 1)
         assert_array_almost_equal(cv_scores, [1, 1. / 3.])
         # for first split, 1/4 of dataset is in test, for second 3/4.
         # take weighted average and weighted std
@@ -753,11 +755,12 @@ def test_search_iid_param():
         search.fit(X, y)
         assert_false(search.iid)
 
-        cv_scores = np.array(list(search.results_['test_split%d_score' % s][0]
+        cv_scores = np.array(list(search.cv_results_['split%d_test_score'
+                                                     % s][0]
                                   for s in range(search.n_splits_)))
-        mean = search.results_['test_mean_score'][0]
-        std = search.results_['test_std_score'][0]
-        assert_equal(search.results_['param_C'][0], 1)
+        mean = search.cv_results_['mean_test_score'][0]
+        std = search.cv_results_['std_test_score'][0]
+        assert_equal(search.cv_results_['param_C'][0], 1)
         # scores are the same as above
         assert_array_almost_equal(cv_scores, [1, 1. / 3.])
         # Unweighted mean/std is used
@@ -765,7 +768,7 @@ def test_search_iid_param():
         assert_almost_equal(std, np.std(cv_scores))
 
 
-def test_search_results_rank_tie_breaking():
+def test_search_cv_results_rank_tie_breaking():
     X, y = make_blobs(n_samples=50, random_state=42)
 
     # The two C values are close enough to give similar models
@@ -778,22 +781,22 @@ def test_search_results_rank_tie_breaking():
 
     for search in (grid_search, random_search):
         search.fit(X, y)
-        results = search.results_
+        results = search.cv_results_
         # Check tie breaking strategy -
         # Check that there is a tie in the mean scores between
         # candidates 1 and 2 alone
-        assert_almost_equal(results['test_mean_score'][0],
-                            results['test_mean_score'][1])
+        assert_almost_equal(results['mean_test_score'][0],
+                            results['mean_test_score'][1])
         try:
-            assert_almost_equal(results['test_mean_score'][1],
-                                results['test_mean_score'][2])
+            assert_almost_equal(results['mean_test_score'][1],
+                                results['mean_test_score'][2])
         except AssertionError:
             pass
         # 'min' rank should be assigned to the tied candidates
-        assert_almost_equal(search.results_['test_rank_score'], [1, 1, 3])
+        assert_almost_equal(search.cv_results_['rank_test_score'], [1, 1, 3])
 
 
-def test_search_results_none_param():
+def test_search_cv_results_none_param():
     X, y = [[1], [2], [3], [4], [5]], [0, 0, 0, 0, 1]
     estimators = (DecisionTreeRegressor(), DecisionTreeClassifier())
     est_parameters = {"random_state": [0, None]}
@@ -801,7 +804,7 @@ def test_search_results_none_param():
 
     for est in estimators:
         grid_search = GridSearchCV(est, est_parameters, cv=cv).fit(X, y)
-        assert_array_equal(grid_search.results_['param_random_state'],
+        assert_array_equal(grid_search.cv_results_['param_random_state'],
                            [0, None])
 
 
@@ -813,12 +816,12 @@ def test_grid_search_correct_score_results():
     Cs = [.1, 1, 10]
     for score in ['f1', 'roc_auc']:
         grid_search = GridSearchCV(clf, {'C': Cs}, scoring=score, cv=n_splits)
-        results = grid_search.fit(X, y).results_
+        results = grid_search.fit(X, y).cv_results_
 
         # Test scorer names
         result_keys = list(results.keys())
-        expected_keys = (("test_mean_score", "test_rank_score") +
-                         tuple("test_split%d_score" % cv_i
+        expected_keys = (("mean_test_score", "rank_test_score") +
+                         tuple("split%d_test_score" % cv_i
                                for cv_i in range(n_splits)))
         assert_true(all(in1d(expected_keys, result_keys)))
 
@@ -826,9 +829,10 @@ def test_grid_search_correct_score_results():
         n_splits = grid_search.n_splits_
         for candidate_i, C in enumerate(Cs):
             clf.set_params(C=C)
-            cv_scores = np.array(list(grid_search.results_['test_split%d_score'
-                                                           % s][candidate_i]
-                                      for s in range(n_splits)))
+            cv_scores = np.array(
+                list(grid_search.cv_results_['split%d_test_score'
+                                             % s][candidate_i]
+                     for s in range(n_splits)))
             for i, (train, test) in enumerate(cv.split(X, y)):
                 clf.fit(X[train], y[train])
                 if score == "f1":
@@ -868,7 +872,7 @@ def test_grid_search_with_multioutput_data():
     for est in estimators:
         grid_search = GridSearchCV(est, est_parameters, cv=cv)
         grid_search.fit(X, y)
-        res_params = grid_search.results_['params']
+        res_params = grid_search.cv_results_['params']
         for cand_i in range(len(res_params)):
             est.set_params(**res_params[cand_i])
 
@@ -877,14 +881,14 @@ def test_grid_search_with_multioutput_data():
                 correct_score = est.score(X[test], y[test])
                 assert_almost_equal(
                     correct_score,
-                    grid_search.results_['test_split%d_score' % i][cand_i])
+                    grid_search.cv_results_['split%d_test_score' % i][cand_i])
 
     # Test with a randomized search
     for est in estimators:
         random_search = RandomizedSearchCV(est, est_parameters,
                                            cv=cv, n_iter=3)
         random_search.fit(X, y)
-        res_params = random_search.results_['params']
+        res_params = random_search.cv_results_['params']
         for cand_i in range(len(res_params)):
             est.set_params(**res_params[cand_i])
 
@@ -893,7 +897,8 @@ def test_grid_search_with_multioutput_data():
                 correct_score = est.score(X[test], y[test])
                 assert_almost_equal(
                     correct_score,
-                    random_search.results_['test_split%d_score' % i][cand_i])
+                    random_search.cv_results_['split%d_test_score'
+                                              % i][cand_i])
 
 
 def test_predict_proba_disabled():
@@ -949,23 +954,24 @@ def test_grid_search_failing_classifier():
     gs = GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy',
                       refit=False, error_score=0.0)
     assert_warns(FitFailedWarning, gs.fit, X, y)
-    n_candidates = len(gs.results_['params'])
+    n_candidates = len(gs.cv_results_['params'])
     # Ensure that grid scores were set to zero as required for those fits
     # that are expected to fail.
     get_cand_scores = lambda i: np.array(list(
-        gs.results_['test_split%d_score' % s][i] for s in range(gs.n_splits_)))
+        gs.cv_results_['split%d_test_score' % s][i]
+        for s in range(gs.n_splits_)))
     assert all((np.all(get_cand_scores(cand_i) == 0.0)
                 for cand_i in range(n_candidates)
-                if gs.results_['param_parameter'][cand_i] ==
+                if gs.cv_results_['param_parameter'][cand_i] ==
                 FailingClassifier.FAILING_PARAMETER))
 
     gs = GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy',
                       refit=False, error_score=float('nan'))
     assert_warns(FitFailedWarning, gs.fit, X, y)
-    n_candidates = len(gs.results_['params'])
+    n_candidates = len(gs.cv_results_['params'])
     assert all(np.all(np.isnan(get_cand_scores(cand_i)))
                for cand_i in range(n_candidates)
-               if gs.results_['param_parameter'][cand_i] ==
+               if gs.cv_results_['param_parameter'][cand_i] ==
                FailingClassifier.FAILING_PARAMETER)