diff --git a/doc/tutorial/text_analytics/solutions/exercise_02_sentiment.py b/doc/tutorial/text_analytics/solutions/exercise_02_sentiment.py
index eab418fd0d8baf415fc2bcb6275591709febb45f..9f747694064ac5584853965277a4d3b8bf44f1ae 100644
--- a/doc/tutorial/text_analytics/solutions/exercise_02_sentiment.py
+++ b/doc/tutorial/text_analytics/solutions/exercise_02_sentiment.py
@@ -55,12 +55,12 @@ if __name__ == "__main__":

     # TASK: print the mean and std for each candidate along with the parameter
     # settings for all the candidates explored by grid search.
-    n_candidates = len(grid_search.results_['params'])
+    n_candidates = len(grid_search.cv_results_['params'])
     for i in range(n_candidates):
         print(i, 'params - %s; mean - %0.2f; std - %0.2f'
-                 % (grid_search.results_['params'][i],
-                    grid_search.results_['test_mean_score'][i],
-                    grid_search.results_['test_std_score'][i]))
+                 % (grid_search.cv_results_['params'][i],
+                    grid_search.cv_results_['mean_test_score'][i],
+                    grid_search.cv_results_['std_test_score'][i]))

     # TASK: Predict the outcome on the testing set and store it in a variable
     # named y_predicted
diff --git a/doc/tutorial/text_analytics/working_with_text_data.rst b/doc/tutorial/text_analytics/working_with_text_data.rst
index e1aa2d50fca4e7c680b9128b3ff97704fc6f0d54..9248e9adbbcdff528a482108d9dcef8936d6c670 100644
--- a/doc/tutorial/text_analytics/working_with_text_data.rst
+++ b/doc/tutorial/text_analytics/working_with_text_data.rst
@@ -458,9 +458,9 @@ mean score and the parameters setting corresponding to that score::
     tfidf__use_idf: True
     vect__ngram_range: (1, 1)

-A more detailed summary of the search is available at ``gs_clf.results_``.
+A more detailed summary of the search is available at ``gs_clf.cv_results_``.

-The ``results_`` parameter can be easily imported into pandas as a
+The ``cv_results_`` parameter can be easily imported into pandas as a
 ``DataFrame`` for further inspection.

 .. note:
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index b5611d07b1a0e3f197a044a6130c8bafa2cfac0e..3da1dab97385259048329c2ec9b5389f17b3a9b0 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -39,27 +39,27 @@ Model Selection Enhancements and API Changes
     :class:`model_selection.GridSearchCV` and
     :class:`model_selection.RandomizedSearchCV` utilities.

-  - **The enhanced `results_` attribute**
+  - **The enhanced ``cv_results_`` attribute**

-    The new ``results_`` attribute (of :class:`model_selection.GridSearchCV`
+    The new ``cv_results_`` attribute (of :class:`model_selection.GridSearchCV`
     and :class:`model_selection.RandomizedSearchCV`) introduced in lieu of the
     ``grid_scores_`` attribute is a dict of 1D arrays with elements in each
     array corresponding to the parameter settings (i.e. search candidates).

-    The ``results_`` dict can be easily imported into ``pandas`` as a
+    The ``cv_results_`` dict can be easily imported into ``pandas`` as a
     ``DataFrame`` for exploring the search results.

-    The ``results_`` arrays include scores for each cross-validation split
-    (with keys such as ``test_split0_score``), as well as their mean
-    (``test_mean_score``) and standard deviation (``test_std_score``).
+    The ``cv_results_`` arrays include scores for each cross-validation split
+    (with keys such as ``'split0_test_score'``), as well as their mean
+    (``'mean_test_score'``) and standard deviation (``'std_test_score'``).

     The ranks for the search candidates (based on their mean
-    cross-validation score) is available at ``results_['test_rank_score']``.
+    cross-validation score) is available at ``cv_results_['rank_test_score']``.

     The parameter values for each parameter is stored separately as numpy
     masked object arrays. The value, for that search candidate, is masked if
     the corresponding parameter is not applicable. Additionally a list of all
-    the parameter dicts are stored at ``results_['params']``.
+    the parameter dicts are stored at ``cv_results_['params']``.

   - **Parameters ``n_folds`` and ``n_iter`` renamed to ``n_splits``**
@@ -235,7 +235,7 @@ Enhancements
    - The :func: `ignore_warnings` now accept a category argument to ignore
      only the warnings of a specified type. By `Thierry Guillemot`_.

-   - The new ``results_`` attribute of :class:`model_selection.GridSearchCV`
+   - The new ``cv_results_`` attribute of :class:`model_selection.GridSearchCV`
      (and :class:`model_selection.RandomizedSearchCV`) can be easily imported
      into pandas as a ``DataFrame``. Ref :ref:`model_selection_changes` for
      more information.
@@ -419,7 +419,7 @@ API changes summary

    - The ``grid_scores_`` attribute of :class:`model_selection.GridSearchCV`
      and :class:`model_selection.RandomizedSearchCV` is deprecated in favor of
-     the attribute ``results_``.
+     the attribute ``cv_results_``.
      Ref :ref:`model_selection_changes` for more information.
      (`#6697 <https://github.com/scikit-learn/scikit-learn/pull/6697>`_)
      by `Raghav R V`_.
diff --git a/examples/model_selection/grid_search_digits.py b/examples/model_selection/grid_search_digits.py
index 13755b0bc8c105ac1dd5f636dee405bb8ba523b1..50ba4dc097f8b1cfa887ac9103689f82bcf19307 100644
--- a/examples/model_selection/grid_search_digits.py
+++ b/examples/model_selection/grid_search_digits.py
@@ -60,11 +60,11 @@ for score in scores:
     print()
     print("Grid scores on development set:")
     print()
-    means = clf.results_['test_mean_score']
-    stds = clf.results_['test_std_score']
-    for i in range(len(clf.results_['params'])):
+    means = clf.cv_results_['mean_test_score']
+    stds = clf.cv_results_['std_test_score']
+    for mean, std, params in zip(means, stds, clf.cv_results_['params']):
         print("%0.3f (+/-%0.03f) for %r"
-              % (means[i], stds[i] * 2, clf.results_['params'][i]))
+              % (mean, std * 2, params))
     print()

     print("Detailed classification report:")
diff --git a/examples/model_selection/randomized_search.py b/examples/model_selection/randomized_search.py
index e1f7c215ab65359f8f2af1b5105cfda239ba5790..1024a01239d012357c8aec5339e427e40237e40a 100644
--- a/examples/model_selection/randomized_search.py
+++ b/examples/model_selection/randomized_search.py
@@ -41,12 +41,12 @@ clf = RandomForestClassifier(n_estimators=20)

 # Utility function to report best scores
 def report(results, n_top=3):
     for i in range(1, n_top + 1):
-        candidates = np.flatnonzero(results['test_rank_score'] == i)
+        candidates = np.flatnonzero(results['rank_test_score'] == i)
         for candidate in candidates:
             print("Model with rank: {0}".format(i))
             print("Mean validation score: {0:.3f} (std: {1:.3f})".format(
-                  results['test_mean_score'][candidate],
-                  results['test_std_score'][candidate]))
+                  results['mean_test_score'][candidate],
+                  results['std_test_score'][candidate]))
             print("Parameters: {0}".format(results['params'][candidate]))
             print("")

@@ -68,7 +68,7 @@ start = time()
 random_search.fit(X, y)
 print("RandomizedSearchCV took %.2f seconds for %d candidates"
       " parameter settings." % ((time() - start), n_iter_search))
-report(random_search.results_)
+report(random_search.cv_results_)

 # use a full grid over all parameters
 param_grid = {"max_depth": [3, None],
@@ -84,5 +84,5 @@ start = time()
 grid_search.fit(X, y)

 print("GridSearchCV took %.2f seconds for %d candidate parameter settings."
-      % (time() - start, len(grid_search.results_['params'])))
-report(grid_search.results_)
+      % (time() - start, len(grid_search.cv_results_['params'])))
+report(grid_search.cv_results_)
diff --git a/examples/plot_compare_reduction.py b/examples/plot_compare_reduction.py
index 0f578722ee68cc42db21aa4c43a2ef6361d235bf..1c84ea9c3a4dc907534d79dc49ffe270a4a26372 100644
--- a/examples/plot_compare_reduction.py
+++ b/examples/plot_compare_reduction.py
@@ -53,7 +53,7 @@ grid = GridSearchCV(pipe, cv=3, n_jobs=2, param_grid=param_grid)
 digits = load_digits()
 grid.fit(digits.data, digits.target)

-mean_scores = np.array(grid.results_['test_mean_score'])
+mean_scores = np.array(grid.cv_results_['mean_test_score'])
 # scores are in the order of param_grid iteration, which is alphabetical
 mean_scores = mean_scores.reshape(len(C_OPTIONS), -1, len(N_FEATURES_OPTIONS))
 # select score for best C
diff --git a/examples/svm/plot_rbf_parameters.py b/examples/svm/plot_rbf_parameters.py
index b71d6b22dc7c4e9a390ff25fb7c41e6259ee62e9..acec9896169b8bd014cdbb4363b29d6d6c8a4781 100644
--- a/examples/svm/plot_rbf_parameters.py
+++ b/examples/svm/plot_rbf_parameters.py
@@ -171,8 +171,8 @@ for (k, (C, gamma, clf)) in enumerate(classifiers):
     plt.yticks(())
     plt.axis('tight')

-scores = grid.results_['test_mean_score'].reshape(len(C_range),
-                                                  len(gamma_range))
+scores = grid.cv_results_['mean_test_score'].reshape(len(C_range),
+                                                     len(gamma_range))

 # Draw heatmap of the validation accuracy as a function of gamma and C
 #
diff --git a/examples/svm/plot_svm_scale_c.py b/examples/svm/plot_svm_scale_c.py
index 09934c2f5d8592f172c2d3bacea760374598be2a..5d72ca61d5157c1d028f7484e2fb55da3ae68354 100644
--- a/examples/svm/plot_svm_scale_c.py
+++ b/examples/svm/plot_svm_scale_c.py
@@ -131,7 +131,7 @@ for fignum, (clf, cs, X, y) in enumerate(clf_sets):
                             cv=ShuffleSplit(train_size=train_size,
                                             n_splits=250, random_state=1))
         grid.fit(X, y)
-        scores = grid.results_['test_mean_score']
+        scores = grid.cv_results_['mean_test_score']

         scales = [(1, 'No scaling'),
                   ((n_samples * train_size), '1/n_samples'),
diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index be3509770141ba0369669177bc0c6de3a943e462..435cabc68d49a6e6f8c9671aa932db42646985c6 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -573,17 +573,18 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         stds = np.sqrt(np.average((test_scores - means[:, np.newaxis]) ** 2,
                                   axis=1, weights=weights))

-        results = dict()
+        cv_results = dict()
         for split_i in range(n_splits):
-            results["test_split%d_score" % split_i] = test_scores[:, split_i]
-        results["test_mean_score"] = means
-        results["test_std_score"] = stds
+            cv_results["split%d_test_score" % split_i] = test_scores[:,
+                                                                     split_i]
+        cv_results["mean_test_score"] = means
+        cv_results["std_test_score"] = stds

         ranks = np.asarray(rankdata(-means, method='min'), dtype=np.int32)

         best_index = np.flatnonzero(ranks == 1)[0]
         best_parameters = candidate_params[best_index]
-        results["test_rank_score"] = ranks
+        cv_results["rank_test_score"] = ranks

         # Use one np.MaskedArray and mask all the places where the param is not
         # applicable for that candidate. Use defaultdict as each candidate may
@@ -597,12 +598,12 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
                 # Setting the value at an index also unmasks that index
                 param_results["param_%s" % name][cand_i] = value

-        results.update(param_results)
+        cv_results.update(param_results)

         # Store a list of param dicts at the key 'params'
-        results['params'] = candidate_params
+        cv_results['params'] = candidate_params

-        self.results_ = results
+        self.cv_results_ = cv_results
         self.best_index_ = best_index
         self.n_splits_ = n_splits
@@ -620,30 +621,31 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
     @property
     def best_params_(self):
-        check_is_fitted(self, 'results_')
-        return self.results_['params'][self.best_index_]
+        check_is_fitted(self, 'cv_results_')
+        return self.cv_results_['params'][self.best_index_]

     @property
     def best_score_(self):
-        check_is_fitted(self, 'results_')
-        return self.results_['test_mean_score'][self.best_index_]
+        check_is_fitted(self, 'cv_results_')
+        return self.cv_results_['mean_test_score'][self.best_index_]

     @property
     def grid_scores_(self):
         warnings.warn(
             "The grid_scores_ attribute was deprecated in version 0.18"
-            " in favor of the more elaborate results_ attribute."
+            " in favor of the more elaborate cv_results_ attribute."
             " The grid_scores_ attribute will not be available from 0.20",
             DeprecationWarning)

-        check_is_fitted(self, 'results_')
+        check_is_fitted(self, 'cv_results_')
         grid_scores = list()

         for i, (params, mean, std) in enumerate(zip(
-                self.results_['params'],
-                self.results_['test_mean_score'],
-                self.results_['test_std_score'])):
-            scores = np.array(list(self.results_['test_split%d_score' % s][i]
+                self.cv_results_['params'],
+                self.cv_results_['mean_test_score'],
+                self.cv_results_['std_test_score'])):
+            scores = np.array(list(self.cv_results_['split%d_test_score'
+                                                    % s][i]
                                    for s in range(self.n_splits_)),
                               dtype=np.float64)
             grid_scores.append(_CVScoreTuple(params, mean, scores))

@@ -763,22 +765,22 @@ class GridSearchCV(BaseSearchCV):
            fit_params={}, iid=..., n_jobs=1,
            param_grid=..., pre_dispatch=..., refit=...,
            scoring=..., verbose=...)
-    >>> sorted(clf.results_.keys())
+    >>> sorted(clf.cv_results_.keys())
     ...                             # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
-    ['param_C', 'param_kernel', 'params', 'test_mean_score',...
-     'test_rank_score', 'test_split0_score', 'test_split1_score',...
-     'test_split2_score', 'test_std_score']
+    ['mean_test_score', 'param_C', 'param_kernel', 'params',...
+     'rank_test_score', 'split0_test_score', 'split1_test_score',...
+     'split2_test_score', 'std_test_score']

     Attributes
     ----------
-    results_ : dict of numpy (masked) ndarrays
+    cv_results_ : dict of numpy (masked) ndarrays
         A dict with keys as column headers and values as columns, that can be
         imported into a pandas ``DataFrame``.

         For instance the below given table

         +------------+-----------+------------+-----------------+---+---------+
-        |param_kernel|param_gamma|param_degree|test_split0_score|...|...rank..|
+        |param_kernel|param_gamma|param_degree|split0_test_score|...|rank_....|
         +============+===========+============+=================+===+=========+
         | 'poly'     | --        | 2          | 0.8             |...|    2    |
         +------------+-----------+------------+-----------------+---+---------+
@@ -789,7 +791,7 @@ class GridSearchCV(BaseSearchCV):
         | 'rbf'      | 0.2       | --         | 0.9             |...|    1    |
         +------------+-----------+------------+-----------------+---+---------+

-        will be represented by a ``results_`` dict of::
+        will be represented by a ``cv_results_`` dict of::

         {
         'param_kernel': masked_array(data = ['poly', 'poly', 'rbf', 'rbf'],
@@ -798,11 +800,11 @@ class GridSearchCV(BaseSearchCV):
                                      mask = [ True  True False False]...),
         'param_degree': masked_array(data = [2.0 3.0 -- --],
                                      mask = [False False  True  True]...),
-        'test_split0_score'  : [0.8, 0.7, 0.8, 0.9],
-        'test_split1_score'  : [0.82, 0.5, 0.7, 0.78],
-        'test_mean_score'    : [0.81, 0.60, 0.75, 0.82],
-        'test_std_score'     : [0.02, 0.01, 0.03, 0.03],
-        'test_rank_score'    : [2, 4, 3, 1],
+        'split0_test_score'  : [0.8, 0.7, 0.8, 0.9],
+        'split1_test_score'  : [0.82, 0.5, 0.7, 0.78],
+        'mean_test_score'    : [0.81, 0.60, 0.75, 0.82],
+        'std_test_score'     : [0.02, 0.01, 0.03, 0.03],
+        'rank_test_score'    : [2, 4, 3, 1],
         'params'             : [{'kernel': 'poly', 'degree': 2}, ...],
         }

@@ -821,10 +823,10 @@ class GridSearchCV(BaseSearchCV):
         Parameter setting that gave the best results on the hold out data.

     best_index_ : int
-        The index (of the ``results_`` arrays) which corresponds to the best
+        The index (of the ``cv_results_`` arrays) which corresponds to the best
         candidate parameter setting.

-        The dict at ``search.results_['params'][search.best_index_]`` gives
+        The dict at ``search.cv_results_['params'][search.best_index_]`` gives
         the parameter setting for the best model, that gives the highest
         mean score (``search.best_score_``).

@@ -1005,14 +1007,14 @@ class RandomizedSearchCV(BaseSearchCV):

     Attributes
     ----------
-    results_ : dict of numpy (masked) ndarrays
+    cv_results_ : dict of numpy (masked) ndarrays
         A dict with keys as column headers and values as columns, that can be
         imported into a pandas ``DataFrame``.

         For instance the below given table

         +--------------+-------------+-------------------+---+---------------+
-        | param_kernel | param_gamma | test_split0_score |...|test_rank_score|
+        | param_kernel | param_gamma | split0_test_score |...|rank_test_score|
         +==============+=============+===================+===+===============+
         | 'rbf'        | 0.1         | 0.8               |...|       2       |
         +--------------+-------------+-------------------+---+---------------+
@@ -1021,17 +1023,17 @@ class RandomizedSearchCV(BaseSearchCV):
         | 'rbf'        | 0.3         | 0.7               |...|       1       |
         +--------------+-------------+-------------------+---+---------------+

-        will be represented by a ``results_`` dict of::
+        will be represented by a ``cv_results_`` dict of::

         {
         'param_kernel' : masked_array(data = ['rbf', rbf', 'rbf'],
                                       mask = False),
         'param_gamma'  : masked_array(data = [0.1 0.2 0.3], mask = False),
-        'test_split0_score'  : [0.8, 0.9, 0.7],
-        'test_split1_score'  : [0.82, 0.5, 0.7],
-        'test_mean_score'    : [0.81, 0.7, 0.7],
-        'test_std_score'     : [0.02, 0.2, 0.],
-        'test_rank_score'    : [3, 1, 1],
+        'split0_test_score'  : [0.8, 0.9, 0.7],
+        'split1_test_score'  : [0.82, 0.5, 0.7],
+        'mean_test_score'    : [0.81, 0.7, 0.7],
+        'std_test_score'     : [0.02, 0.2, 0.],
+        'rank_test_score'    : [3, 1, 1],
         'params' : [{'kernel' : 'rbf', 'gamma' : 0.1}, ...],
         }

@@ -1050,10 +1052,10 @@ class RandomizedSearchCV(BaseSearchCV):
         Parameter setting that gave the best results on the hold out data.

     best_index_ : int
-        The index (of the ``results_`` arrays) which corresponds to the best
+        The index (of the ``cv_results_`` arrays) which corresponds to the best
        candidate parameter setting.

-        The dict at ``search.results_['params'][search.best_index_]`` gives
+        The dict at ``search.cv_results_['params'][search.best_index_]`` gives
         the parameter setting for the best model, that gives the highest
         mean score (``search.best_score_``).
diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py
index b5abd8d873a51bbcb58e86065537090203a520a8..65fc1964ddf315ae19b4865ae928285c04648714 100644
--- a/sklearn/model_selection/tests/test_search.py
+++ b/sklearn/model_selection/tests/test_search.py
@@ -155,7 +155,8 @@ def test_grid_search():
     sys.stdout = old_stdout
     assert_equal(grid_search.best_estimator_.foo_param, 2)

-    assert_array_equal(grid_search.results_["param_foo_param"].data, [1, 2, 3])
+    assert_array_equal(grid_search.cv_results_["param_foo_param"].data,
+                       [1, 2, 3])

     # Smoke test the score etc:
     grid_search.score(X, y)
@@ -265,17 +266,17 @@ def test_grid_search_labels():
     gs.fit(X, y)


-def test_trivial_results_attr():
+def test_trivial_cv_results_attr():
     # Test search over a "grid" with only one point.
     # Non-regression test: grid_scores_ wouldn't be set by GridSearchCV.
     clf = MockClassifier()
     grid_search = GridSearchCV(clf, {'foo_param': [1]})
     grid_search.fit(X, y)
-    assert_true(hasattr(grid_search, "results_"))
+    assert_true(hasattr(grid_search, "cv_results_"))

     random_search = RandomizedSearchCV(clf, {'foo_param': [0]}, n_iter=1)
     random_search.fit(X, y)
-    assert_true(hasattr(grid_search, "results_"))
+    assert_true(hasattr(grid_search, "cv_results_"))


 def test_no_refit():
@@ -472,7 +473,7 @@ def test_gridsearch_nd():
     clf = CheckingClassifier(check_X=check_X, check_y=check_y)
     grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]})
     grid_search.fit(X_4d, y_3d).score(X, y)
-    assert_true(hasattr(grid_search, "results_"))
+    assert_true(hasattr(grid_search, "cv_results_"))


 def test_X_as_list():
@@ -484,7 +485,7 @@ def test_X_as_list():
     cv = KFold(n_splits=3)
     grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, cv=cv)
     grid_search.fit(X.tolist(), y).score(X, y)
-    assert_true(hasattr(grid_search, "results_"))
+    assert_true(hasattr(grid_search, "cv_results_"))


 def test_y_as_list():
@@ -496,7 +497,7 @@ def test_y_as_list():
     cv = KFold(n_splits=3)
     grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, cv=cv)
     grid_search.fit(X, y.tolist()).score(X, y)
-    assert_true(hasattr(grid_search, "results_"))
+    assert_true(hasattr(grid_search, "cv_results_"))


 @ignore_warnings
@@ -522,7 +523,7 @@ def test_pandas_input():
     grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]})
     grid_search.fit(X_df, y_ser).score(X_df, y_ser)
     grid_search.predict(X_df)
-    assert_true(hasattr(grid_search, "results_"))
+    assert_true(hasattr(grid_search, "cv_results_"))


 def test_unsupervised_grid_search():
@@ -581,7 +582,7 @@ def test_param_sampler():
         assert_equal([x for x in sampler], [x for x in sampler])


-def check_results_array_types(results, param_keys, score_keys):
+def check_cv_results_array_types(results, param_keys, score_keys):
     # Check if the search results' array are of correct types
     assert_true(all(isinstance(results[param], np.ma.MaskedArray)
                     for param in param_keys))
@@ -589,24 +590,24 @@ def check_results_array_types(results, param_keys, score_keys):
     assert_false(any(isinstance(results[key], np.ma.MaskedArray)
                      for key in score_keys))
     assert_true(all(results[key].dtype == np.float64
-                    for key in score_keys if key != 'test_rank_score'))
-    assert_true(results['test_rank_score'].dtype == np.int32)
+                    for key in score_keys if key != 'rank_test_score'))
+    assert_true(results['rank_test_score'].dtype == np.int32)


-def check_results_keys(results, param_keys, score_keys, n_cand):
-    # Test the search.results_ contains all the required results
+def check_cv_results_keys(results, param_keys, score_keys, n_cand):
+    # Test the search.cv_results_ contains all the required results
     assert_array_equal(sorted(results.keys()),
                        sorted(param_keys + score_keys + ('params',)))
     assert_true(all(results[key].shape == (n_cand,)
                     for key in param_keys + score_keys))


-def check_results_grid_scores_consistency(search):
+def check_cv_results_grid_scores_consistency(search):
     # TODO Remove in 0.20
-    results = search.results_
-    res_scores = np.vstack(list([results["test_split%d_score" % i]
+    results = search.cv_results_
+    res_scores = np.vstack(list([results["split%d_test_score" % i]
                                  for i in range(search.n_splits_)])).T
-    res_means = results["test_mean_score"]
+    res_means = results["mean_test_score"]
     res_params = results["params"]
     n_cand = len(res_params)
     grid_scores = assert_warns(DeprecationWarning, getattr,
@@ -636,20 +637,20 @@ def test_grid_search_results():
     grid_search_iid.fit(X, y)

     param_keys = ('param_C', 'param_degree', 'param_gamma', 'param_kernel')
-    score_keys = ('test_mean_score', 'test_rank_score',
-                  'test_split0_score', 'test_split1_score',
-                  'test_split2_score', 'test_std_score')
+    score_keys = ('mean_test_score', 'rank_test_score',
+                  'split0_test_score', 'split1_test_score',
+                  'split2_test_score', 'std_test_score')
     n_candidates = n_grid_points

     for search, iid in zip((grid_search, grid_search_iid), (False, True)):
         assert_equal(iid, search.iid)
-        results = search.results_
+        results = search.cv_results_
         # Check results structure
-        check_results_array_types(results, param_keys, score_keys)
-        check_results_keys(results, param_keys, score_keys, n_candidates)
+        check_cv_results_array_types(results, param_keys, score_keys)
+        check_cv_results_keys(results, param_keys, score_keys, n_candidates)

         # Check masking
-        results = grid_search.results_
-        n_candidates = len(grid_search.results_['params'])
+        results = grid_search.cv_results_
+        n_candidates = len(grid_search.cv_results_['params'])
         assert_true(all((results['param_C'].mask[i] and
                          results['param_gamma'].mask[i] and
                          not results['param_degree'].mask[i])
                         for i in range(n_candidates)
                         if results['param_kernel'][i] == 'linear'))
@@ -660,7 +661,7 @@ def test_grid_search_results():
                          results['param_degree'].mask[i])
                         for i in range(n_candidates)
                         if results['param_kernel'][i] == 'rbf'))
-        check_results_grid_scores_consistency(search)
+        check_cv_results_grid_scores_consistency(search)


 def test_random_search_results():
@@ -685,21 +686,21 @@ def test_random_search_results():
     random_search_iid.fit(X, y)

     param_keys = ('param_C', 'param_gamma')
-    score_keys = ('test_mean_score', 'test_rank_score',
-                  'test_split0_score', 'test_split1_score',
-                  'test_split2_score', 'test_std_score')
+    score_keys = ('mean_test_score', 'rank_test_score',
+                  'split0_test_score', 'split1_test_score',
+                  'split2_test_score', 'std_test_score')
     n_cand = n_search_iter

     for search, iid in zip((random_search, random_search_iid), (False, True)):
         assert_equal(iid, search.iid)
-        results = search.results_
+        results = search.cv_results_
         # Check results structure
-        check_results_array_types(results, param_keys, score_keys)
-        check_results_keys(results, param_keys, score_keys, n_cand)
+        check_cv_results_array_types(results, param_keys, score_keys)
+        check_cv_results_keys(results, param_keys, score_keys, n_cand)
         # For random_search, all the param array vals should be unmasked
         assert_false(any(results['param_C'].mask) or
                      any(results['param_gamma'].mask))
-        check_results_grid_scores_consistency(search)
+        check_cv_results_grid_scores_consistency(search)


 def test_search_iid_param():
@@ -726,12 +727,13 @@ def test_search_iid_param():
     assert_true(search.iid)

     # Test the first candidate
-    cv_scores = np.array(list(search.results_['test_split%d_score' % s][0]
+    cv_scores = np.array(list(search.cv_results_['split%d_test_score'
+                                                 % s][0]
                               for s in range(search.n_splits_)))
-    mean = search.results_['test_mean_score'][0]
-    std = search.results_['test_std_score'][0]
+    mean = search.cv_results_['mean_test_score'][0]
+    std = search.cv_results_['std_test_score'][0]

-    assert_equal(search.results_['param_C'][0], 1)
+    assert_equal(search.cv_results_['param_C'][0], 1)
     assert_array_almost_equal(cv_scores, [1, 1. / 3.])
     # for first split, 1/4 of dataset is in test, for second 3/4.
     # take weighted average and weighted std
@@ -753,11 +755,12 @@ def test_search_iid_param():
     search.fit(X, y)
     assert_false(search.iid)

-    cv_scores = np.array(list(search.results_['test_split%d_score' % s][0]
+    cv_scores = np.array(list(search.cv_results_['split%d_test_score'
+                                                 % s][0]
                               for s in range(search.n_splits_)))
-    mean = search.results_['test_mean_score'][0]
-    std = search.results_['test_std_score'][0]
-    assert_equal(search.results_['param_C'][0], 1)
+    mean = search.cv_results_['mean_test_score'][0]
+    std = search.cv_results_['std_test_score'][0]
+    assert_equal(search.cv_results_['param_C'][0], 1)
     # scores are the same as above
     assert_array_almost_equal(cv_scores, [1, 1. / 3.])
     # Unweighted mean/std is used
@@ -765,7 +768,7 @@ def test_search_iid_param():
     assert_almost_equal(std, np.std(cv_scores))


-def test_search_results_rank_tie_breaking():
+def test_search_cv_results_rank_tie_breaking():
     X, y = make_blobs(n_samples=50, random_state=42)

     # The two C values are close enough to give similar models
@@ -778,22 +781,22 @@ def test_search_results_rank_tie_breaking():
     for search in (grid_search, random_search):
         search.fit(X, y)
-        results = search.results_
+        results = search.cv_results_
         # Check tie breaking strategy -
         # Check that there is a tie in the mean scores between
         # candidates 1 and 2 alone
-        assert_almost_equal(results['test_mean_score'][0],
-                            results['test_mean_score'][1])
+        assert_almost_equal(results['mean_test_score'][0],
+                            results['mean_test_score'][1])
         try:
-            assert_almost_equal(results['test_mean_score'][1],
-                                results['test_mean_score'][2])
+            assert_almost_equal(results['mean_test_score'][1],
+                                results['mean_test_score'][2])
         except AssertionError:
             pass
         # 'min' rank should be assigned to the tied candidates
-        assert_almost_equal(search.results_['test_rank_score'], [1, 1, 3])
+        assert_almost_equal(search.cv_results_['rank_test_score'], [1, 1, 3])


-def test_search_results_none_param():
+def test_search_cv_results_none_param():
     X, y = [[1], [2], [3], [4], [5]], [0, 0, 0, 0, 1]
     estimators = (DecisionTreeRegressor(), DecisionTreeClassifier())
     est_parameters = {"random_state": [0, None]}
@@ -801,7 +804,7 @@ def test_search_results_none_param():
     for est in estimators:
         grid_search = GridSearchCV(est, est_parameters, cv=cv).fit(X, y)
-        assert_array_equal(grid_search.results_['param_random_state'],
+        assert_array_equal(grid_search.cv_results_['param_random_state'],
                            [0, None])

@@ -813,12 +816,12 @@ def test_grid_search_correct_score_results():
     Cs = [.1, 1, 10]
     for score in ['f1', 'roc_auc']:
         grid_search = GridSearchCV(clf, {'C': Cs}, scoring=score, cv=n_splits)
-        results = grid_search.fit(X, y).results_
+        results = grid_search.fit(X, y).cv_results_

         # Test scorer names
         result_keys = list(results.keys())
-        expected_keys = (("test_mean_score", "test_rank_score") +
-                         tuple("test_split%d_score" % cv_i
+        expected_keys = (("mean_test_score", "rank_test_score") +
+                         tuple("split%d_test_score" % cv_i
                                for cv_i in range(n_splits)))
         assert_true(all(in1d(expected_keys, result_keys)))

@@ -826,9 +829,10 @@ def test_grid_search_correct_score_results():
         n_splits = grid_search.n_splits_
         for candidate_i, C in enumerate(Cs):
             clf.set_params(C=C)
-            cv_scores = np.array(list(grid_search.results_['test_split%d_score'
-                                      % s][candidate_i]
-                                      for s in range(n_splits)))
+            cv_scores = np.array(
+                list(grid_search.cv_results_['split%d_test_score'
+                                             % s][candidate_i]
+                     for s in range(n_splits)))
             for i, (train, test) in enumerate(cv.split(X, y)):
                 clf.fit(X[train], y[train])
                 if score == "f1":
@@ -868,7 +872,7 @@ def test_grid_search_with_multioutput_data():
     for est in estimators:
         grid_search = GridSearchCV(est, est_parameters, cv=cv)
         grid_search.fit(X, y)
-        res_params = grid_search.results_['params']
+        res_params = grid_search.cv_results_['params']
         for cand_i in range(len(res_params)):
             est.set_params(**res_params[cand_i])

@@ -877,14 +881,14 @@ def test_grid_search_with_multioutput_data():
                 correct_score = est.score(X[test], y[test])
                 assert_almost_equal(
                     correct_score,
-                    grid_search.results_['test_split%d_score' % i][cand_i])
+                    grid_search.cv_results_['split%d_test_score' % i][cand_i])

     # Test with a randomized search
     for est in estimators:
         random_search = RandomizedSearchCV(est, est_parameters,
                                            cv=cv, n_iter=3)
         random_search.fit(X, y)
-        res_params = random_search.results_['params']
+        res_params = random_search.cv_results_['params']
         for cand_i in range(len(res_params)):
             est.set_params(**res_params[cand_i])

@@ -893,7 +897,8 @@ def test_grid_search_with_multioutput_data():
                 correct_score = est.score(X[test], y[test])
                 assert_almost_equal(
                     correct_score,
-                    random_search.results_['test_split%d_score' % i][cand_i])
+                    random_search.cv_results_['split%d_test_score'
+                                              % i][cand_i])


 def test_predict_proba_disabled():
@@ -949,23 +954,24 @@ def test_grid_search_failing_classifier():
     gs = GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy',
                       refit=False, error_score=0.0)
     assert_warns(FitFailedWarning, gs.fit, X, y)
-    n_candidates = len(gs.results_['params'])
+    n_candidates = len(gs.cv_results_['params'])

     # Ensure that grid scores were set to zero as required for those fits
     # that are expected to fail.
     get_cand_scores = lambda i: np.array(list(
-        gs.results_['test_split%d_score' % s][i] for s in range(gs.n_splits_)))
+        gs.cv_results_['split%d_test_score' % s][i]
+        for s in range(gs.n_splits_)))
     assert all((np.all(get_cand_scores(cand_i) == 0.0)
                 for cand_i in range(n_candidates)
-                if gs.results_['param_parameter'][cand_i] ==
+                if gs.cv_results_['param_parameter'][cand_i] ==
                 FailingClassifier.FAILING_PARAMETER))

     gs = GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy',
                       refit=False, error_score=float('nan'))
     assert_warns(FitFailedWarning, gs.fit, X, y)
-    n_candidates = len(gs.results_['params'])
+    n_candidates = len(gs.cv_results_['params'])
     assert all(np.all(np.isnan(get_cand_scores(cand_i)))
                for cand_i in range(n_candidates)
-               if gs.results_['param_parameter'][cand_i] ==
+               if gs.cv_results_['param_parameter'][cand_i] ==
                FailingClassifier.FAILING_PARAMETER)
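
The snippet below is not part of the diff above; it is a minimal sketch of how the renamed ``cv_results_`` keys are meant to be consumed once this change is applied. It assumes a scikit-learn version that exposes ``cv_results_`` (0.18 or later); the dataset, estimator and parameter grid are arbitrary illustration choices, and the pandas step is optional.

# Minimal usage sketch for the renamed cv_results_ keys (not part of the patch).
from sklearn import datasets
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

iris = datasets.load_iris()
search = GridSearchCV(SVC(), {'kernel': ['linear', 'rbf'], 'C': [1, 10]}, cv=3)
search.fit(iris.data, iris.target)

# The old 'test_*' spellings map onto the new key names used throughout the diff:
#   'test_mean_score'   -> 'mean_test_score'
#   'test_std_score'    -> 'std_test_score'
#   'test_rank_score'   -> 'rank_test_score'
#   'test_split0_score' -> 'split0_test_score', and so on per split.
for mean, std, params in zip(search.cv_results_['mean_test_score'],
                             search.cv_results_['std_test_score'],
                             search.cv_results_['params']):
    print("%0.3f (+/-%0.3f) for %r" % (mean, 2 * std, params))

# Every value in cv_results_ is a 1D array with one entry per candidate, so the
# whole dict drops straight into a pandas DataFrame for further inspection:
#     import pandas as pd
#     pd.DataFrame(search.cv_results_).sort_values('rank_test_score')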