diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 5399e27ef4d088c27a3d478dcd634dd184885c3f..7275789c19a07cc600228a501b3fcaf20d8bf635 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -223,6 +223,7 @@ Model validation :toctree: generated/ :template: function.rst + model_selection.cross_validate model_selection.cross_val_score model_selection.cross_val_predict model_selection.permutation_test_score diff --git a/doc/modules/cross_validation.rst b/doc/modules/cross_validation.rst index cc5f6a3c07afc856f0edfbf638b78d0ae210881f..ab7d2227447b1596f338b0e304ce9ce2633fdba6 100644 --- a/doc/modules/cross_validation.rst +++ b/doc/modules/cross_validation.rst @@ -172,6 +172,65 @@ validation iterator instead, for instance:: See :ref:`combining_estimators`. + +.. _multimetric_cross_validation: + +The cross_validate function and multiple metric evaluation +---------------------------------------------------------- + +The ``cross_validate`` function differs from ``cross_val_score`` in two ways - + +- It allows specifying multiple metrics for evaluation. + +- It returns a dict containing training scores, fit-times and score-times in + addition to the test score. + +For single metric evaluation, where the scoring parameter is a string, +callable or None, the keys will be - ``['test_score', 'fit_time', 'score_time']`` + +And for multiple metric evaluation, the return value is a dict with the +following keys - +``['test_<scorer1_name>', 'test_<scorer2_name>', 'test_<scorer...>', 'fit_time', 'score_time']`` + +``return_train_score`` is set to ``True`` by default. It adds train score keys +for all the scorers. If train scores are not needed, this should be set to +``False`` explicitly. + +The multiple metrics can be specified either as a list, tuple or set of +predefined scorer names:: + + >>> from sklearn.model_selection import cross_validate + >>> from sklearn.metrics import recall_score + >>> scoring = ['precision_macro', 'recall_macro'] + >>> clf = svm.SVC(kernel='linear', C=1, random_state=0) + >>> scores = cross_validate(clf, iris.data, iris.target, scoring=scoring, + ... cv=5, return_train_score=False) + >>> sorted(scores.keys()) + ['fit_time', 'score_time', 'test_precision_macro', 'test_recall_macro'] + >>> scores['test_recall_macro'] # doctest: +ELLIPSIS + array([ 0.96..., 1. ..., 0.96..., 0.96..., 1. ]) + +Or as a dict mapping scorer name to a predefined or custom scoring function:: + + >>> from sklearn.metrics.scorer import make_scorer + >>> scoring = {'prec_macro': 'precision_macro', + ... 'rec_micro': make_scorer(recall_score, average='macro')} + >>> scores = cross_validate(clf, iris.data, iris.target, scoring=scoring, + ... cv=5, return_train_score=True) + >>> sorted(scores.keys()) # doctest: +NORMALIZE_WHITESPACE + ['fit_time', 'score_time', 'test_prec_macro', 'test_rec_micro', + 'train_prec_macro', 'train_rec_micro'] + >>> scores['train_rec_micro'] # doctest: +ELLIPSIS + array([ 0.97..., 0.97..., 0.99..., 0.98..., 0.98...]) + +Here is an example of ``cross_validate`` using a single metric:: + + >>> scores = cross_validate(clf, iris.data, iris.target, + ... 
scoring='precision_macro') + >>> sorted(scores.keys()) + ['fit_time', 'score_time', 'test_score', 'train_score'] + + Obtaining predictions by cross-validation ----------------------------------------- @@ -186,7 +245,7 @@ These prediction can then be used to evaluate the classifier:: >>> from sklearn.model_selection import cross_val_predict >>> predicted = cross_val_predict(clf, iris.data, iris.target, cv=10) >>> metrics.accuracy_score(iris.target, predicted) # doctest: +ELLIPSIS - 0.966... + 0.973... Note that the result of this computation may be slightly different from those obtained using :func:`cross_val_score` as the elements are grouped in different diff --git a/doc/modules/grid_search.rst b/doc/modules/grid_search.rst index 48870a80a6c9012963bd1b3fecf54882300dbf73..1867a66594ad4adf168c9e1761da12842a52ae02 100644 --- a/doc/modules/grid_search.rst +++ b/doc/modules/grid_search.rst @@ -84,6 +84,10 @@ evaluated and the best combination is retained. dataset. This is the best practice for evaluating the performance of a model with grid search. + - See :ref:`sphx_glr_auto_examples_model_selection_plot_multi_metric_evaluation` + for an example of :class:`GridSearchCV` being used to evaluate multiple + metrics simultaneously. + .. _randomized_parameter_search: Randomized Parameter Optimization @@ -161,6 +165,27 @@ scoring function can be specified via the ``scoring`` parameter to specialized cross-validation tools described below. See :ref:`scoring_parameter` for more details. +.. _multimetric_grid_search: + +Specifying multiple metrics for evaluation +------------------------------------------ + +``GridSearchCV`` and ``RandomizedSearchCV`` allow specifying multiple metrics +for the ``scoring`` parameter. + +Multimetric scoring can either be specified as a list of strings of predefined +scores names or a dict mapping the scorer name to the scorer function and/or +the predefined scorer name(s). See :ref:`multimetric_scoring` for more details. + +When specifying multiple metrics, the ``refit`` parameter must be set to the +metric (string) for which the ``best_params_`` will be found and used to build +the ``best_estimator_`` on the whole dataset. If the search should not be +refit, set ``refit=False``. Leaving refit to the default value ``None`` will +result in an error when using multiple metrics. + +See :ref:`sphx_glr_auto_examples_model_selection_plot_multi_metric_evaluation` +for an example usage. + Composite estimators and parameter spaces ----------------------------------------- diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index c54417586153c4c6222c6ad8540b2673a5f8ef09..dee5865bdd33ea56adf510add45f97330bc223ba 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -210,6 +210,51 @@ the following two rules: Again, by convention higher numbers are better, so if your scorer returns loss, that value should be negated. +.. _multimetric_scoring: + +Using mutiple metric evaluation +------------------------------- + +Scikit-learn also permits evaluation of multiple metrics in ``GridSearchCV``, +``RandomizedSearchCV`` and ``cross_validate``. 
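For instance, a minimal illustrative sketch of a multi-metric grid search (the dataset, estimator and parameter grid below are arbitrary and only serve as an example)::

    >>> from sklearn.datasets import make_classification
    >>> from sklearn.model_selection import GridSearchCV
    >>> from sklearn.svm import LinearSVC
    >>> X, y = make_classification(n_classes=2, random_state=0)
    >>> grid = GridSearchCV(LinearSVC(random_state=0),
    ...                     param_grid={'C': [1, 10]},
    ...                     scoring=['accuracy', 'precision'],
    ...                     refit='accuracy')
    >>> grid.fit(X, y)                                # doctest: +SKIP
    >>> sorted(k for k in grid.cv_results_
    ...        if k.startswith('mean_test_'))         # doctest: +SKIP
    ['mean_test_accuracy', 'mean_test_precision']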
+ +There are two ways to specify multiple scoring metrics for the ``scoring`` +parameter: + +- As an iterable of string metrics:: + >>> scoring = ['accuracy', 'precision'] + +- As a ``dict`` mapping the scorer name to the scoring function:: + >>> from sklearn.metrics import accuracy_score + >>> from sklearn.metrics import make_scorer + >>> scoring = {'accuracy': make_scorer(accuracy_score), + ... 'prec': 'precision'} + +Note that the dict values can either be scorer functions or one of the +predefined metric strings. + +Currently only those scorer functions that return a single score can be passed +inside the dict. Scorer functions that return multiple values are not +permitted and will require a wrapper to return a single metric:: + + >>> from sklearn.model_selection import cross_validate + >>> from sklearn.metrics import confusion_matrix + >>> # A sample toy binary classification dataset + >>> X, y = datasets.make_classification(n_classes=2, random_state=0) + >>> svm = LinearSVC(random_state=0) + >>> tp = lambda y_true, y_pred: confusion_matrix(y_true, y_pred)[0, 0] + >>> tn = lambda y_true, y_pred: confusion_matrix(y_true, y_pred)[0, 0] + >>> fp = lambda y_true, y_pred: confusion_matrix(y_true, y_pred)[1, 0] + >>> fn = lambda y_true, y_pred: confusion_matrix(y_true, y_pred)[0, 1] + >>> scoring = {'tp' : make_scorer(tp), 'tn' : make_scorer(tn), + ... 'fp' : make_scorer(fp), 'fn' : make_scorer(fn)} + >>> cv_results = cross_validate(svm.fit(X, y), X, y, scoring=scoring) + >>> # Getting the test set false positive scores + >>> print(cv_results['test_tp']) # doctest: +NORMALIZE_WHITESPACE + [12 13 15] + >>> # Getting the test set false negative scores + >>> print(cv_results['test_fn']) # doctest: +NORMALIZE_WHITESPACE + [5 4 1] .. _classification_metrics: diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 02035113485105952de4550d391563b102410077..0c5608d6b5970a028348e4878ff318dd9fd2f34a 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -31,6 +31,19 @@ Changelog New features ............ + - :class:`model_selection.GridSearchCV` and + :class:`model_selection.RandomizedSearchCV` now support simultaneous + evaluation of multiple metrics. Refer to the + :ref:`multimetric_grid_search` section of the user guide for more + information. :issue:`7388` by `Raghav RV`_ + + - Added the :func:`model_selection.cross_validate` which allows evaluation + of multiple metrics. This function returns a dict with more useful + information from cross-validation such as the train scores, fit times and + score times. + Refer to :ref:`multimetric_cross_validation` section of the userguide + for more information. :issue:`7388` by `Raghav RV`_ + - Added :class:`multioutput.ClassifierChain` for multi-label classification. By `Adam Kleczewski <adamklec>`_. diff --git a/examples/model_selection/plot_multi_metric_evaluation.py b/examples/model_selection/plot_multi_metric_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..5f4491e51f49cb591b215ee16acfef484a3c1c31 --- /dev/null +++ b/examples/model_selection/plot_multi_metric_evaluation.py @@ -0,0 +1,94 @@ +"""Demonstration of multi-metric evaluation on cross_val_score and GridSearchCV + +Multiple metric parameter search can be done by setting the ``scoring`` +parameter to a list of metric scorer names or a dict mapping the scorer names +to the scorer callables. 
+ +The scores of all the scorers are available in the ``cv_results_`` dict at keys +ending in ``'_<scorer_name>'`` (``'mean_test_precision'``, +``'rank_test_precision'``, etc...) + +The ``best_estimator_``, ``best_index_``, ``best_score_`` and ``best_params_`` +correspond to the scorer (key) that is set to the ``refit`` attribute. +""" + +# Author: Raghav RV <rvraghav93@gmail.com> +# License: BSD + +import numpy as np +from matplotlib import pyplot as plt + +from sklearn.datasets import make_hastie_10_2 +from sklearn.model_selection import GridSearchCV +from sklearn.metrics import make_scorer +from sklearn.metrics import accuracy_score +from sklearn.tree import DecisionTreeClassifier + +print(__doc__) + +############################################################################### +# Running ``GridSearchCV`` using multiple evaluation metrics +# ---------------------------------------------------------- +# + +X, y = make_hastie_10_2(n_samples=8000, random_state=42) + +# The scorers can be either be one of the predefined metric strings or a scorer +# callable, like the one returned by make_scorer +scoring = {'AUC': 'roc_auc', 'Accuracy': make_scorer(accuracy_score)} + +# Setting refit='AUC', refits an estimator on the whole dataset with the +# parameter setting that has the best cross-validated AUC score. +# That estimator is made available at ``gs.best_estimator_`` along with +# parameters like ``gs.best_score_``, ``gs.best_parameters_`` and +# ``gs.best_index_`` +gs = GridSearchCV(DecisionTreeClassifier(random_state=42), + param_grid={'min_samples_split': range(2, 403, 10)}, + scoring=scoring, cv=5, refit='AUC') +gs.fit(X, y) +results = gs.cv_results_ + +############################################################################### +# Plotting the result +# ------------------- + +plt.figure(figsize=(13, 13)) +plt.title("GridSearchCV evaluating using multiple scorers simultaneously", + fontsize=16) + +plt.xlabel("min_samples_split") +plt.ylabel("Score") +plt.grid() + +ax = plt.axes() +ax.set_xlim(0, 402) +ax.set_ylim(0.73, 1) + +# Get the regular numpy array from the MaskedArray +X_axis = np.array(results['param_min_samples_split'].data, dtype=float) + +for scorer, color in zip(sorted(scoring), ['g', 'k']): + for sample, style in (('train', '--'), ('test', '-')): + sample_score_mean = results['mean_%s_%s' % (sample, scorer)] + sample_score_std = results['std_%s_%s' % (sample, scorer)] + ax.fill_between(X_axis, sample_score_mean - sample_score_std, + sample_score_mean + sample_score_std, + alpha=0.1 if sample == 'test' else 0, color=color) + ax.plot(X_axis, sample_score_mean, style, color=color, + alpha=1 if sample == 'test' else 0.7, + label="%s (%s)" % (scorer, sample)) + + best_index = np.nonzero(results['rank_test_%s' % scorer] == 1)[0][0] + best_score = results['mean_test_%s' % scorer][best_index] + + # Plot a dotted vertical line at the best score for that scorer marked by x + ax.plot([X_axis[best_index], ] * 2, [0, best_score], + linestyle='-.', color=color, marker='x', markeredgewidth=3, ms=8) + + # Annotate the best score for that scorer + ax.annotate("%0.2f" % best_score, + (X_axis[best_index], best_score + 0.005)) + +plt.legend(loc="best") +plt.grid('off') +plt.show() diff --git a/sklearn/metrics/scorer.py b/sklearn/metrics/scorer.py index 3a163d967c542b3b48405809cf629fe8a70c4c44..1d16a9dcb01ac07cb50dffd6d6045bb0b58a3cdd 100644 --- a/sklearn/metrics/scorer.py +++ b/sklearn/metrics/scorer.py @@ -209,12 +209,15 @@ class _ThresholdScorer(_BaseScorer): def get_scorer(scoring): + 
valid = True if isinstance(scoring, six.string_types): try: scorer = SCORERS[scoring] except KeyError: scorers = [scorer for scorer in SCORERS if SCORERS[scorer]._deprecation_msg is None] + valid = False # Don't raise here to make the error message elegant + if not valid: raise ValueError('%r is not a valid scoring value. ' 'Valid options are %s' % (scoring, sorted(scorers))) @@ -253,13 +256,12 @@ def check_scoring(estimator, scoring=None, allow_none=False): A scorer callable object / function with signature ``scorer(estimator, X, y)``. """ - has_scoring = scoring is not None if not hasattr(estimator, 'fit'): raise TypeError("estimator should be an estimator implementing " "'fit' method, %r was passed" % estimator) if isinstance(scoring, six.string_types): return get_scorer(scoring) - elif has_scoring: + elif callable(scoring): # Heuristic to ensure user has not passed a metric module = getattr(scoring, '__module__', None) if hasattr(module, 'startswith') and \ @@ -272,14 +274,114 @@ def check_scoring(estimator, scoring=None, allow_none=False): 'Please use `make_scorer` to convert a metric ' 'to a scorer.' % scoring) return get_scorer(scoring) - elif hasattr(estimator, 'score'): - return _passthrough_scorer - elif allow_none: - return None + elif scoring is None: + if hasattr(estimator, 'score'): + return _passthrough_scorer + elif allow_none: + return None + else: + raise TypeError( + "If no scoring is specified, the estimator passed should " + "have a 'score' method. The estimator %r does not." + % estimator) else: - raise TypeError( - "If no scoring is specified, the estimator passed should " - "have a 'score' method. The estimator %r does not." % estimator) + raise ValueError("scoring value should either be a callable, string or" + " None. %r was passed" % scoring) + + +def _check_multimetric_scoring(estimator, scoring=None): + """Check the scoring parameter in cases when multiple metrics are allowed + + Parameters + ---------- + estimator : sklearn estimator instance + The estimator for which the scoring will be applied. + + scoring : string, callable, list/tuple, dict or None, default: None + A single string (see :ref:`scoring_parameter`) or a callable + (see :ref:`scoring`) to evaluate the predictions on the test set. + + For evaluating multiple metrics, either give a list of (unique) strings + or a dict with names as keys and callables as values. + + NOTE that when using custom scorers, each scorer should return a single + value. Metric functions returning a list/array of values can be wrapped + into multiple scorers that return one value each. + + See :ref:`multivalued_scorer_wrapping` for an example. + + If None the estimator's default scorer (if available) is used. + The return value in that case will be ``{'score': <default_scorer>}``. + If the estimator's default scorer is not available, a ``TypeError`` + is raised. + + Returns + ------- + scorers_dict : dict + A dict mapping each scorer name to its validated scorer. + + is_multimetric : bool + True if scorer is a list/tuple or dict of callables + False if scorer is None/str/callable + """ + if callable(scoring) or scoring is None or isinstance(scoring, + six.string_types): + scorers = {"score": check_scoring(estimator, scoring=scoring)} + return scorers, False + else: + err_msg_generic = ("scoring should either be a single string or " + "callable for single metric evaluation or a " + "list/tuple of strings or a dict of scorer name " + "mapped to the callable for multiple metric " + "evaluation. 
Got %s of type %s" + % (repr(scoring), type(scoring))) + + if isinstance(scoring, (list, tuple, set)): + err_msg = ("The list/tuple elements must be unique " + "strings of predefined scorers. ") + invalid = False + try: + keys = set(scoring) + except TypeError: + invalid = True + if invalid: + raise ValueError(err_msg) + + if len(keys) != len(scoring): + raise ValueError(err_msg + "Duplicate elements were found in" + " the given list. %r" % repr(scoring)) + elif len(keys) > 0: + if not all(isinstance(k, six.string_types) for k in keys): + if any(callable(k) for k in keys): + raise ValueError(err_msg + + "One or more of the elements were " + "callables. Use a dict of score name " + "mapped to the scorer callable. " + "Got %r" % repr(scoring)) + else: + raise ValueError(err_msg + + "Non-string types were found in " + "the given list. Got %r" + % repr(scoring)) + scorers = {scorer: check_scoring(estimator, scoring=scorer) + for scorer in scoring} + else: + raise ValueError(err_msg + + "Empty list was given. %r" % repr(scoring)) + + elif isinstance(scoring, dict): + keys = set(scoring) + if not all(isinstance(k, six.string_types) for k in keys): + raise ValueError("Non-string types were found in the keys of " + "the given dict. scoring=%r" % repr(scoring)) + if len(keys) == 0: + raise ValueError("An empty dict was passed. %r" + % repr(scoring)) + scorers = {key: check_scoring(estimator, scoring=scorer) + for key, scorer in scoring.items()} + else: + raise ValueError(err_msg_generic) + return scorers, True def make_scorer(score_func, greater_is_better=True, needs_proba=False, diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 461bdadf3d6e501d85fbdf9e0dc37145c5f2725b..47c4d334f893a05a22baf5ddc19e2d9ec4c2ba06 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -8,9 +8,11 @@ import numpy as np from sklearn.utils.testing import assert_almost_equal from sklearn.utils.testing import assert_array_equal +from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_raises from sklearn.utils.testing import assert_raises_regexp from sklearn.utils.testing import assert_true +from sklearn.utils.testing import assert_false from sklearn.utils.testing import ignore_warnings from sklearn.utils.testing import assert_not_equal from sklearn.utils.testing import assert_warns_message @@ -21,6 +23,8 @@ from sklearn.metrics import (f1_score, r2_score, roc_auc_score, fbeta_score, from sklearn.metrics import cluster as cluster_module from sklearn.metrics.scorer import (check_scoring, _PredictScorer, _passthrough_scorer) +from sklearn.metrics import accuracy_score +from sklearn.metrics.scorer import _check_multimetric_scoring from sklearn.metrics import make_scorer, get_scorer, SCORERS from sklearn.svm import LinearSVC from sklearn.pipeline import make_pipeline @@ -104,18 +108,18 @@ def teardown_module(): class EstimatorWithoutFit(object): - """Dummy estimator to test check_scoring""" + """Dummy estimator to test scoring validators""" pass class EstimatorWithFit(BaseEstimator): - """Dummy estimator to test check_scoring""" + """Dummy estimator to test scoring validators""" def fit(self, X, y): return self class EstimatorWithFitAndScore(object): - """Dummy estimator to test check_scoring""" + """Dummy estimator to test scoring validators""" def fit(self, X, y): return self @@ -124,7 +128,7 @@ class EstimatorWithFitAndScore(object): class EstimatorWithFitAndPredict(object): - 
"""Dummy estimator to test check_scoring""" + """Dummy estimator to test scoring validators""" def fit(self, X, y): self.y = y return self @@ -145,16 +149,16 @@ def test_all_scorers_repr(): repr(scorer) -def test_check_scoring(): - # Test all branches of check_scoring +def check_scoring_validator_for_single_metric_usecases(scoring_validator): + # Test all branches of single metric usecases estimator = EstimatorWithoutFit() pattern = (r"estimator should be an estimator implementing 'fit' method," r" .* was passed") - assert_raises_regexp(TypeError, pattern, check_scoring, estimator) + assert_raises_regexp(TypeError, pattern, scoring_validator, estimator) estimator = EstimatorWithFitAndScore() estimator.fit([[1]], [1]) - scorer = check_scoring(estimator) + scorer = scoring_validator(estimator) assert_true(scorer is _passthrough_scorer) assert_almost_equal(scorer(estimator, [[1]], [1]), 1.0) @@ -162,18 +166,85 @@ def test_check_scoring(): estimator.fit([[1]], [1]) pattern = (r"If no scoring is specified, the estimator passed should have" r" a 'score' method\. The estimator .* does not\.") - assert_raises_regexp(TypeError, pattern, check_scoring, estimator) + assert_raises_regexp(TypeError, pattern, scoring_validator, estimator) - scorer = check_scoring(estimator, "accuracy") + scorer = scoring_validator(estimator, "accuracy") assert_almost_equal(scorer(estimator, [[1]], [1]), 1.0) estimator = EstimatorWithFit() - scorer = check_scoring(estimator, "accuracy") + scorer = scoring_validator(estimator, "accuracy") assert_true(isinstance(scorer, _PredictScorer)) - estimator = EstimatorWithFit() - scorer = check_scoring(estimator, allow_none=True) - assert_true(scorer is None) + # Test the allow_none parameter for check_scoring alone + if scoring_validator is check_scoring: + estimator = EstimatorWithFit() + scorer = scoring_validator(estimator, allow_none=True) + assert_true(scorer is None) + + +def check_multimetric_scoring_single_metric_wrapper(*args, **kwargs): + # This wraps the _check_multimetric_scoring to take in single metric + # scoring parameter so we can run the tests that we will run for + # check_scoring, for check_multimetric_scoring too for single-metric + # usecases + scorers, is_multi = _check_multimetric_scoring(*args, **kwargs) + # For all single metric use cases, it should register as not multimetric + assert_false(is_multi) + if args[0] is not None: + assert_true(scorers is not None) + names, scorers = zip(*scorers.items()) + assert_equal(len(scorers), 1) + assert_equal(names[0], 'score') + scorers = scorers[0] + return scorers + + +def test_check_scoring_and_check_multimetric_scoring(): + check_scoring_validator_for_single_metric_usecases(check_scoring) + # To make sure the check_scoring is correctly applied to the constituent + # scorers + check_scoring_validator_for_single_metric_usecases( + check_multimetric_scoring_single_metric_wrapper) + + # For multiple metric use cases + # Make sure it works for the valid cases + for scoring in (('accuracy',), ['precision'], + {'acc': 'accuracy', 'precision': 'precision'}, + ('accuracy', 'precision'), ['precision', 'accuracy'], + {'accuracy': make_scorer(accuracy_score), + 'precision': make_scorer(precision_score)}): + estimator = LinearSVC(random_state=0) + estimator.fit([[1], [2], [3]], [1, 1, 0]) + + scorers, is_multi = _check_multimetric_scoring(estimator, scoring) + assert_true(is_multi) + assert_true(isinstance(scorers, dict)) + assert_equal(sorted(scorers.keys()), sorted(list(scoring))) + assert_true(all([isinstance(scorer, 
_PredictScorer) + for scorer in list(scorers.values())])) + + if 'acc' in scoring: + assert_almost_equal(scorers['acc']( + estimator, [[1], [2], [3]], [1, 0, 0]), 2. / 3.) + if 'accuracy' in scoring: + assert_almost_equal(scorers['accuracy']( + estimator, [[1], [2], [3]], [1, 0, 0]), 2. / 3.) + if 'precision' in scoring: + assert_almost_equal(scorers['precision']( + estimator, [[1], [2], [3]], [1, 0, 0]), 0.5) + + estimator = EstimatorWithFitAndPredict() + estimator.fit([[1]], [1]) + + # Make sure it raises errors when scoring parameter is not valid. + # More weird corner cases are tested at test_validation.py + error_message_regexp = ".*must be unique strings.*" + for scoring in ((make_scorer(precision_score), # Tuple of callables + make_scorer(accuracy_score)), [5], + (make_scorer(precision_score),), (), ('f1', 'f1')): + assert_raises_regexp(ValueError, error_message_regexp, + _check_multimetric_scoring, estimator, + scoring=scoring) def test_check_scoring_gridsearchcv(): diff --git a/sklearn/model_selection/__init__.py b/sklearn/model_selection/__init__.py index 73c842e706df8fc3d3ffd781475d8107af195013..82a9b9371710d214f02e8b7a752220a71b97ae50 100644 --- a/sklearn/model_selection/__init__.py +++ b/sklearn/model_selection/__init__.py @@ -18,6 +18,7 @@ from ._split import check_cv from ._validation import cross_val_score from ._validation import cross_val_predict +from ._validation import cross_validate from ._validation import learning_curve from ._validation import permutation_test_score from ._validation import validation_curve @@ -50,6 +51,7 @@ __all__ = ('BaseCrossValidator', 'check_cv', 'cross_val_predict', 'cross_val_score', + 'cross_validate', 'fit_grid_point', 'learning_curve', 'permutation_test_score', diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index 67bd8597de0d4029656e944d05d41719308b0c39..17c588c293eda708f98c4239abb7c1cadb946d9a 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -9,6 +9,7 @@ from __future__ import division # Gael Varoquaux <gael.varoquaux@normalesup.org> # Andreas Mueller <amueller@ais.uni-bonn.de> # Olivier Grisel <olivier.grisel@ensta.org> +# Raghav RV <rvraghav93@gmail.com> # License: BSD 3 clause from abc import ABCMeta, abstractmethod @@ -25,6 +26,7 @@ from ..base import BaseEstimator, is_classifier, clone from ..base import MetaEstimatorMixin from ._split import check_cv from ._validation import _fit_and_score +from ._validation import _aggregate_score_dicts from ..exceptions import NotFittedError from ..externals.joblib import Parallel, delayed from ..externals import six @@ -34,6 +36,7 @@ from ..utils.fixes import MaskedArray from ..utils.random import sample_without_replacement from ..utils.validation import indexable, check_is_fitted from ..utils.metaestimators import if_delegate_has_method +from ..metrics.scorer import _check_multimetric_scoring from ..metrics.scorer import check_scoring @@ -295,10 +298,12 @@ def fit_grid_point(X, y, estimator, parameters, train, test, scorer, test : ndarray, dtype int or bool Boolean mask or indices for test set. - scorer : callable or None. - If provided must be a scorer callable object / function with signature + scorer : callable or None + The scorer callable object / function must have its signature as ``scorer(estimator, X, y)``. + If ``None`` the estimator's default scorer is used. + verbose : int Verbosity level. 
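As a rough, illustrative sketch (the data, the single manual train/test split and the estimator below are assumptions for illustration, not part of this patch), ``fit_grid_point`` still expects a single scorer callable with the ``scorer(estimator, X, y)`` signature; a dict of scorers would be rejected by the ``check_scoring`` call introduced in the hunk below:

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.metrics import accuracy_score, make_scorer
    from sklearn.model_selection import fit_grid_point
    from sklearn.svm import LinearSVC

    X, y = make_classification(random_state=0)
    train, test = np.arange(80), np.arange(80, 100)   # one manual split
    # A single scorer callable works here; a dict such as
    # {'acc': make_scorer(accuracy_score)} would raise a ValueError.
    score, parameters, n_test = fit_grid_point(
        X, y, LinearSVC(random_state=0), {'C': 1.0},
        train, test, make_scorer(accuracy_score), verbose=0)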
@@ -314,7 +319,7 @@ def fit_grid_point(X, y, estimator, parameters, train, test, scorer, Returns ------- score : float - Score of this parameter setting on given training / test split. + Score of this parameter setting on given training / test split. parameters : dict The parameters that have been evaluated. @@ -322,12 +327,16 @@ def fit_grid_point(X, y, estimator, parameters, train, test, scorer, n_samples_test : int Number of test samples in this split. """ - score, n_samples_test, _ = _fit_and_score(estimator, X, y, scorer, train, - test, verbose, parameters, - fit_params=fit_params, - return_n_test_samples=True, - error_score=error_score) - return score, parameters, n_samples_test + # NOTE we are not using the return value as the scorer by itself should be + # validated before. We use check_scoring only to reject multimetric scorer + check_scoring(estimator, scorer) + scores, n_samples_test = _fit_and_score(estimator, X, y, + scorer, train, + test, verbose, parameters, + fit_params=fit_params, + return_n_test_samples=True, + error_score=error_score) + return scores, parameters, n_samples_test def _check_param_grid(param_grid): @@ -419,18 +428,23 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, ------- score : float """ + self._check_is_fitted('score') if self.scorer_ is None: raise ValueError("No score function explicitly defined, " "and the estimator doesn't provide one %s" % self.best_estimator_) - return self.scorer_(self.best_estimator_, X, y) + score = self.scorer_[self.refit] if self.multimetric_ else self.scorer_ + return score(self.best_estimator_, X, y) def _check_is_fitted(self, method_name): if not self.refit: - raise NotFittedError(('This GridSearchCV instance was initialized ' - 'with refit=False. %s is ' - 'available only after refitting on the best ' - 'parameters. ') % method_name) + raise NotFittedError('This %s instance was initialized ' + 'with refit=False. %s is ' + 'available only after refitting on the best ' + 'parameters. You can refit an estimator ' + 'manually using the ``best_parameters_`` ' + 'attribute' + % (type(self).__name__, method_name)) else: check_is_fitted(self, 'best_estimator_') @@ -575,7 +589,27 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, fit_params = self.fit_params estimator = self.estimator cv = check_cv(self.cv, y, classifier=is_classifier(estimator)) - self.scorer_ = check_scoring(self.estimator, scoring=self.scoring) + + scorers, self.multimetric_ = _check_multimetric_scoring( + self.estimator, scoring=self.scoring) + + if self.multimetric_: + if self.refit is not False and ( + not isinstance(self.refit, six.string_types) or + # This will work for both dict / list (tuple) + self.refit not in scorers): + raise ValueError("For multi-metric scoring, the parameter " + "refit must be set to a scorer key " + "to refit an estimator with the best " + "parameter setting on the whole data and " + "make the best_* attributes " + "available for that metric. If this is not " + "needed, refit should be set to False " + "explicitly. %r was passed." 
% self.refit) + else: + refit_metric = self.refit + else: + refit_metric = 'score' X, y, groups = indexable(X, y, groups) n_splits = cv.get_n_splits(X, y, groups) @@ -593,8 +627,8 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, out = Parallel( n_jobs=self.n_jobs, verbose=self.verbose, pre_dispatch=pre_dispatch - )(delayed(_fit_and_score)(clone(base_estimator), X, y, self.scorer_, - train, test, self.verbose, parameters, + )(delayed(_fit_and_score)(clone(base_estimator), X, y, scorers, train, + test, self.verbose, parameters, fit_params=fit_params, return_train_score=self.return_train_score, return_n_test_samples=True, @@ -605,20 +639,29 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, # if one choose to see train score, "out" will contain train score info if self.return_train_score: - (train_scores, test_scores, test_sample_counts, fit_time, + (train_score_dicts, test_score_dicts, test_sample_counts, fit_time, score_time) = zip(*out) else: - (test_scores, test_sample_counts, fit_time, score_time) = zip(*out) + (test_score_dicts, test_sample_counts, fit_time, + score_time) = zip(*out) + + # test_score_dicts and train_score dicts are lists of dictionaries and + # we make them into dict of lists + test_scores = _aggregate_score_dicts(test_score_dicts) + if self.return_train_score: + train_scores = _aggregate_score_dicts(train_score_dicts) results = dict() def _store(key_name, array, weights=None, splits=False, rank=False): """A small helper to store the scores/times to the cv_results_""" # When iterated first by splits, then by parameters + # We want `array` to have `n_candidates` rows and `n_splits` cols. array = np.array(array, dtype=np.float64).reshape(n_candidates, n_splits) if splits: for split_i in range(n_splits): + # Uses closure to alter the results results["split%d_%s" % (split_i, key_name)] = array[:, split_i] @@ -634,21 +677,8 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, results["rank_%s" % key_name] = np.asarray( rankdata(-array_means, method='min'), dtype=np.int32) - # Computed the (weighted) mean and std for test scores alone - # NOTE test_sample counts (weights) remain the same for all candidates - test_sample_counts = np.array(test_sample_counts[:n_splits], - dtype=np.int) - - _store('test_score', test_scores, splits=True, rank=True, - weights=test_sample_counts if self.iid else None) - if self.return_train_score: - _store('train_score', train_scores, splits=True) _store('fit_time', fit_time) _store('score_time', score_time) - - best_index = np.flatnonzero(results["rank_test_score"] == 1)[0] - best_parameters = candidate_params[best_index] - # Use one MaskedArray and mask all the places where the param is not # applicable for that candidate. 
Use defaultdict as each candidate may # not contain all the params @@ -664,45 +694,58 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, param_results["param_%s" % name][cand_i] = value results.update(param_results) - # Store a list of param dicts at the key 'params' results['params'] = candidate_params - self.cv_results_ = results - self.best_index_ = best_index - self.n_splits_ = n_splits + # NOTE test_sample counts (weights) remain the same for all candidates + test_sample_counts = np.array(test_sample_counts[:n_splits], + dtype=np.int) + for scorer_name in scorers.keys(): + # Computed the (weighted) mean and std for test scores alone + _store('test_%s' % scorer_name, test_scores[scorer_name], + splits=True, rank=True, + weights=test_sample_counts if self.iid else None) + if self.return_train_score: + _store('train_%s' % scorer_name, train_scores[scorer_name], + splits=True) + + # For multi-metric evaluation, store the best_index_, best_params_ and + # best_score_ iff refit is one of the scorer names + # In single metric evaluation, refit_metric is "score" + if self.refit or not self.multimetric_: + self.best_index_ = results["rank_test_%s" % refit_metric].argmin() + self.best_params_ = candidate_params[self.best_index_] + self.best_score_ = results["mean_test_%s" % refit_metric][ + self.best_index_] if self.refit: - # fit the best estimator using the entire dataset - # clone first to work around broken estimators - best_estimator = clone(base_estimator).set_params( - **best_parameters) + self.best_estimator_ = clone(base_estimator).set_params( + **self.best_params_) if y is not None: - best_estimator.fit(X, y, **fit_params) + self.best_estimator_.fit(X, y, **fit_params) else: - best_estimator.fit(X, **fit_params) - self.best_estimator_ = best_estimator - return self + self.best_estimator_.fit(X, **fit_params) - @property - def best_params_(self): - check_is_fitted(self, 'cv_results_') - return self.cv_results_['params'][self.best_index_] + # Store the only scorer not as a dict for single metric evaluation + self.scorer_ = scorers if self.multimetric_ else scorers['score'] - @property - def best_score_(self): - check_is_fitted(self, 'cv_results_') - return self.cv_results_['mean_test_score'][self.best_index_] + self.cv_results_ = results + self.n_splits_ = n_splits + + return self @property def grid_scores_(self): + check_is_fitted(self, 'cv_results_') + if self.multimetric_: + raise AttributeError("grid_scores_ attribute is not available for" + " multi-metric evaluation.") warnings.warn( "The grid_scores_ attribute was deprecated in version 0.18" " in favor of the more elaborate cv_results_ attribute." " The grid_scores_ attribute will not be available from 0.20", DeprecationWarning) - check_is_fitted(self, 'cv_results_') grid_scores = list() for i, (params, mean, std) in enumerate(zip( @@ -747,11 +790,20 @@ class GridSearchCV(BaseSearchCV): in the list are explored. This enables searching over any sequence of parameter settings. - scoring : string, callable or None, default=None - A string (see model evaluation documentation) or - a scorer callable object / function with signature - ``scorer(estimator, X, y)``. - If ``None``, the ``score`` method of the estimator is used. + scoring : string, callable, list/tuple, dict or None, default: None + A single string (see :ref:`scoring_parameter`) or a callable + (see :ref:`scoring`) to evaluate the predictions on the test set. 
+ + For evaluating multiple metrics, either give a list of (unique) strings + or a dict with names as keys and callables as values. + + NOTE that when using custom scorers, each scorer should return a single + value. Metric functions returning a list/array of values can be wrapped + into multiple scorers that return one value each. + + See :ref:`multivalued_scorer_wrapping` for an example. + + If None, the estimator's default scorer (if available) is used. fit_params : dict, optional Parameters to pass to the fit method. @@ -801,10 +853,25 @@ class GridSearchCV(BaseSearchCV): Refer :ref:`User Guide <cross_validation>` for the various cross-validation strategies that can be used here. - refit : boolean, default=True - Refit the best estimator with the entire dataset. - If "False", it is impossible to make predictions using - this GridSearchCV instance after fitting. + refit : boolean, or string, default=True + Refit an estimator using the best found parameters on the whole + dataset. + + For multiple metric evaluation, this needs to be a string denoting the + scorer is used to find the best parameters for refitting the estimator + at the end. + + The refitted estimator is made available at the ``best_estimator_`` + attribute and permits using ``predict`` directly on this + ``GridSearchCV`` instance. + + Also for multiple metric evaluation, the attributes ``best_index_``, + ``best_score_`` and ``best_parameters_`` will only be available if + ``refit`` is set and all of them will be determined w.r.t this specific + scorer. + + See ``scoring`` parameter to know more about multiple metric + evaluation. verbose : integer Controls the verbosity: the higher, the more messages. @@ -857,7 +924,7 @@ class GridSearchCV(BaseSearchCV): For instance the below given table +------------+-----------+------------+-----------------+---+---------+ - |param_kernel|param_gamma|param_degree|split0_test_score|...|rank_....| + |param_kernel|param_gamma|param_degree|split0_test_score|...|..rank...| +============+===========+============+=================+===+=========+ | 'poly' | -- | 2 | 0.8 |...| 2 | +------------+-----------+------------+-----------------+---+---------+ @@ -893,23 +960,38 @@ class GridSearchCV(BaseSearchCV): 'params' : [{'kernel': 'poly', 'degree': 2}, ...], } - NOTE that the key ``'params'`` is used to store a list of parameter - settings dict for all the parameter candidates. + NOTE + + The key ``'params'`` is used to store a list of parameter + settings dicts for all the parameter candidates. The ``mean_fit_time``, ``std_fit_time``, ``mean_score_time`` and ``std_score_time`` are all in seconds. - best_estimator_ : estimator + For multi-metric evaluation, the scores for all the scorers are + available in the ``cv_results_`` dict at the keys ending with that + scorer's name (``'_<scorer_name>'``) instead of ``'_score'`` shown + above. ('split0_test_precision', 'mean_train_precision' etc.) + + best_estimator_ : estimator or dict Estimator that was chosen by the search, i.e. estimator which gave highest score (or smallest loss if specified) - on the left out data. Not available if refit=False. + on the left out data. Not available if ``refit=False``. + + See ``refit`` parameter for more information on allowed values. best_score_ : float - Score of best_estimator on the left out data. + Mean cross-validated score of the best_estimator + + For multi-metric evaluation, this is present only if ``refit`` is + specified. 
best_params_ : dict Parameter setting that gave the best results on the hold out data. + For multi-metric evaluation, this is present only if ``refit`` is + specified. + best_index_ : int The index (of the ``cv_results_`` arrays) which corresponds to the best candidate parameter setting. @@ -918,10 +1000,16 @@ class GridSearchCV(BaseSearchCV): the parameter setting for the best model, that gives the highest mean score (``search.best_score_``). - scorer_ : function + For multi-metric evaluation, this is present only if ``refit`` is + specified. + + scorer_ : function or a dict Scorer function used on the held out data to choose the best parameters for the model. + For multi-metric evaluation, this attribute holds the validated + ``scoring`` dict which maps the scorer key to the scorer callable. + n_splits_ : int The number of cross-validation splits (folds/iterations). @@ -1012,11 +1100,20 @@ class RandomizedSearchCV(BaseSearchCV): Number of parameter settings that are sampled. n_iter trades off runtime vs quality of the solution. - scoring : string, callable or None, default=None - A string (see model evaluation documentation) or - a scorer callable object / function with signature - ``scorer(estimator, X, y)``. - If ``None``, the ``score`` method of the estimator is used. + scoring : string, callable, list/tuple, dict or None, default: None + A single string (see :ref:`scoring_parameter`) or a callable + (see :ref:`scoring`) to evaluate the predictions on the test set. + + For evaluating multiple metrics, either give a list of (unique) strings + or a dict with names as keys and callables as values. + + NOTE that when using custom scorers, each scorer should return a single + value. Metric functions returning a list/array of values can be wrapped + into multiple scorers that return one value each. + + See :ref:`multivalued_scorer_wrapping` for an example. + + If None, the estimator's default scorer (if available) is used. fit_params : dict, optional Parameters to pass to the fit method. @@ -1066,10 +1163,25 @@ class RandomizedSearchCV(BaseSearchCV): Refer :ref:`User Guide <cross_validation>` for the various cross-validation strategies that can be used here. - refit : boolean, default=True - Refit the best estimator with the entire dataset. - If "False", it is impossible to make predictions using - this RandomizedSearchCV instance after fitting. + refit : boolean, or string default=True + Refit an estimator using the best found parameters on the whole + dataset. + + For multiple metric evaluation, this needs to be a string denoting the + scorer that would be used to find the best parameters for refitting + the estimator at the end. + + The refitted estimator is made available at the ``best_estimator_`` + attribute and permits using ``predict`` directly on this + ``RandomizedSearchCV`` instance. + + Also for multiple metric evaluation, the attributes ``best_index_``, + ``best_score_`` and ``best_parameters_`` will only be available if + ``refit`` is set and all of them will be determined w.r.t this specific + scorer. + + See ``scoring`` parameter to know more about multiple metric + evaluation. verbose : integer Controls the verbosity: the higher, the more messages. 
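As a hedged sketch (the dataset, parameter distribution and scorer names here are illustrative assumptions), a multi-metric randomized search with a string ``refit`` could look like:

    from scipy.stats import expon
    from sklearn.datasets import make_classification
    from sklearn.model_selection import RandomizedSearchCV
    from sklearn.svm import SVC

    X, y = make_classification(random_state=0)
    search = RandomizedSearchCV(
        SVC(), param_distributions={'C': expon(scale=10)},
        scoring={'acc': 'accuracy', 'auc': 'roc_auc'},
        refit='auc', n_iter=5, random_state=0)
    search.fit(X, y)
    # cv_results_ gains per-scorer keys such as 'mean_test_acc' and
    # 'mean_test_auc'; best_score_, best_params_ and best_index_ all refer
    # to the 'auc' scorer because refit='auc'.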
@@ -1129,26 +1241,44 @@ class RandomizedSearchCV(BaseSearchCV): 'std_fit_time' : [0.01, 0.02, 0.01, 0.01], 'mean_score_time' : [0.007, 0.06, 0.04, 0.04], 'std_score_time' : [0.001, 0.002, 0.003, 0.005], - 'params' : [{'kernel' : 'rbf', 'gamma' : 0.1}, ...], + 'params' : [{'kernel' : 'rbf', 'gamma' : 0.1}, ...], } - NOTE that the key ``'params'`` is used to store a list of parameter - settings dict for all the parameter candidates. + NOTE + + The key ``'params'`` is used to store a list of parameter + settings dicts for all the parameter candidates. The ``mean_fit_time``, ``std_fit_time``, ``mean_score_time`` and ``std_score_time`` are all in seconds. - best_estimator_ : estimator + For multi-metric evaluation, the scores for all the scorers are + available in the ``cv_results_`` dict at the keys ending with that + scorer's name (``'_<scorer_name>'``) instead of ``'_score'`` shown + above. ('split0_test_precision', 'mean_train_precision' etc.) + + best_estimator_ : estimator or dict Estimator that was chosen by the search, i.e. estimator which gave highest score (or smallest loss if specified) - on the left out data. Not available if refit=False. + on the left out data. Not available if ``refit=False``. + + For multi-metric evaluation, this attribute is present only if + ``refit`` is specified. + + See ``refit`` parameter for more information on allowed values. best_score_ : float - Score of best_estimator on the left out data. + Mean cross-validated score of the best_estimator. + + For multi-metric evaluation, this is not available if ``refit`` is + ``False``. See ``refit`` parameter for more information. best_params_ : dict Parameter setting that gave the best results on the hold out data. + For multi-metric evaluation, this is not available if ``refit`` is + ``False``. See ``refit`` parameter for more information. + best_index_ : int The index (of the ``cv_results_`` arrays) which corresponds to the best candidate parameter setting. @@ -1157,10 +1287,16 @@ class RandomizedSearchCV(BaseSearchCV): the parameter setting for the best model, that gives the highest mean score (``search.best_score_``). - scorer_ : function + For multi-metric evaluation, this is not available if ``refit`` is + ``False``. See ``refit`` parameter for more information. + + scorer_ : function or a dict Scorer function used on the held out data to choose the best parameters for the model. + For multi-metric evaluation, this attribute holds the validated + ``scoring`` dict which maps the scorer key to the scorer callable. + n_splits_ : int The number of cross-validation splits (folds/iterations). diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index fe9c0e8c46c094190c7edabe0c6339e5b19165cc..1e5ea29740c00f28e389c92455a1df70e36d6a04 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -3,12 +3,12 @@ The :mod:`sklearn.model_selection._validation` module includes classes and functions to validate the model. 
""" -# Author: Alexandre Gramfort <alexandre.gramfort@inria.fr>, -# Gael Varoquaux <gael.varoquaux@normalesup.org>, +# Author: Alexandre Gramfort <alexandre.gramfort@inria.fr> +# Gael Varoquaux <gael.varoquaux@normalesup.org> # Olivier Grisel <olivier.grisel@ensta.org> +# Raghav RV <rvraghav93@gmail.com> # License: BSD 3 clause - from __future__ import print_function from __future__ import division @@ -24,13 +24,193 @@ from ..utils import indexable, check_random_state, safe_indexing from ..utils.validation import _is_arraylike, _num_samples from ..utils.metaestimators import _safe_split from ..externals.joblib import Parallel, delayed, logger -from ..metrics.scorer import check_scoring +from ..externals.six.moves import zip +from ..metrics.scorer import check_scoring, _check_multimetric_scoring from ..exceptions import FitFailedWarning from ._split import check_cv from ..preprocessing import LabelEncoder -__all__ = ['cross_val_score', 'cross_val_predict', 'permutation_test_score', - 'learning_curve', 'validation_curve'] + +__all__ = ['cross_validate', 'cross_val_score', 'cross_val_predict', + 'permutation_test_score', 'learning_curve', 'validation_curve'] + + +def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None, + n_jobs=1, verbose=0, fit_params=None, + pre_dispatch='2*n_jobs', return_train_score=True): + """Evaluate metric(s) by cross-validation and also record fit/score times. + + Read more in the :ref:`User Guide <multimetric_cross_validation>`. + + Parameters + ---------- + estimator : estimator object implementing 'fit' + The object to use to fit the data. + + X : array-like + The data to fit. Can be for example a list, or an array. + + y : array-like, optional, default: None + The target variable to try to predict in the case of + supervised learning. + + groups : array-like, with shape (n_samples,), optional + Group labels for the samples used while splitting the dataset into + train/test set. + + scoring : string, callable, list/tuple, dict or None, default: None + A single string (see :ref:`scoring_parameter`) or a callable + (see :ref:`scoring`) to evaluate the predictions on the test set. + + For evaluating multiple metrics, either give a list of (unique) strings + or a dict with names as keys and callables as values. + + NOTE that when using custom scorers, each scorer should return a single + value. Metric functions returning a list/array of values can be wrapped + into multiple scorers that return one value each. + + See :ref:`multivalued_scorer_wrapping` for an example. + + If None, the estimator's default scorer (if available) is used. + + cv : int, cross-validation generator or an iterable, optional + Determines the cross-validation splitting strategy. + Possible inputs for cv are: + - None, to use the default 3-fold cross validation, + - integer, to specify the number of folds in a `(Stratified)KFold`, + - An object to be used as a cross-validation generator. + - An iterable yielding train, test splits. + + For integer/None inputs, if the estimator is a classifier and ``y`` is + either binary or multiclass, :class:`StratifiedKFold` is used. In all + other cases, :class:`KFold` is used. + + Refer :ref:`User Guide <cross_validation>` for the various + cross-validation strategies that can be used here. + + n_jobs : integer, optional + The number of CPUs to use to do the computation. -1 means + 'all CPUs'. + + verbose : integer, optional + The verbosity level. + + fit_params : dict, optional + Parameters to pass to the fit method of the estimator. 
+ + pre_dispatch : int, or string, optional + Controls the number of jobs that get dispatched during parallel + execution. Reducing this number can be useful to avoid an + explosion of memory consumption when more jobs get dispatched + than CPUs can process. This parameter can be: + + - None, in which case all the jobs are immediately + created and spawned. Use this for lightweight and + fast-running jobs, to avoid delays due to on-demand + spawning of the jobs + + - An int, giving the exact number of total jobs that are + spawned + + - A string, giving an expression as a function of n_jobs, + as in '2*n_jobs' + + return_train_score : boolean, default True + Whether to include train scores in the return dict if ``scoring`` is + of multimetric type. + + Returns + ------- + scores : dict of float arrays of shape=(n_splits,) + Array of scores of the estimator for each run of the cross validation. + + A dict of arrays containing the score/time arrays for each scorer is + returned. The possible keys for this ``dict`` are: + + ``test_score`` + The score array for test scores on each cv split. + ``train_score`` + The score array for train scores on each cv split. + This is available only if ``return_train_score`` parameter + is ``True``. + ``fit_time`` + The time for fitting the estimator on the train + set for each cv split. + ``score_time`` + The time for scoring the estimator on the test set for each + cv split. (Note time for scoring on the train set is not + included even if ``return_train_score`` is set to ``True`` + + Examples + -------- + >>> from sklearn import datasets, linear_model + >>> from sklearn.model_selection import cross_val_score + >>> from sklearn.metrics.scorer import make_scorer + >>> from sklearn.metrics import confusion_matrix + >>> from sklearn.svm import LinearSVC + >>> diabetes = datasets.load_diabetes() + >>> X = diabetes.data[:150] + >>> y = diabetes.target[:150] + >>> lasso = linear_model.Lasso() + + # single metric evaluation using cross_validate + >>> cv_results = cross_validate(lasso, X, y, return_train_score=False) + >>> sorted(cv_results.keys()) # doctest: +ELLIPSIS + ['fit_time', 'score_time', 'test_score'] + >>> cv_results['test_score'] # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + array([ 0.33..., 0.08..., 0.03...]) + + # Multiple metric evaluation using cross_validate + # (Please refer the ``scoring`` parameter doc for more information) + >>> scores = cross_validate(lasso, X, y, + ... scoring=('r2', 'neg_mean_squared_error')) + >>> print(scores['test_neg_mean_squared_error']) # doctest: +ELLIPSIS + [-3635.5... -3573.3... -6114.7...] + >>> print(scores['train_r2']) # doctest: +ELLIPSIS + [ 0.28... 0.39... 0.22...] + + See Also + --------- + :func:`sklearn.metrics.cross_val_score`: + Run cross-validation for single metric evaluation. + + :func:`sklearn.metrics.make_scorer`: + Make a scorer from a performance metric or loss function. + + """ + X, y, groups = indexable(X, y, groups) + + cv = check_cv(cv, y, classifier=is_classifier(estimator)) + scorers, _ = _check_multimetric_scoring(estimator, scoring=scoring) + + # We clone the estimator to make sure that all the folds are + # independent, and that it is pickle-able. 
+ parallel = Parallel(n_jobs=n_jobs, verbose=verbose, + pre_dispatch=pre_dispatch) + scores = parallel( + delayed(_fit_and_score)( + clone(estimator), X, y, scorers, train, test, verbose, None, + fit_params, return_train_score=return_train_score, + return_times=True) + for train, test in cv.split(X, y, groups)) + + if return_train_score: + train_scores, test_scores, fit_times, score_times = zip(*scores) + train_scores = _aggregate_score_dicts(train_scores) + else: + test_scores, fit_times, score_times = zip(*scores) + test_scores = _aggregate_score_dicts(test_scores) + + ret = dict() + ret['fit_time'] = np.array(fit_times) + ret['score_time'] = np.array(score_times) + + for name in scorers: + ret['test_%s' % name] = np.array(test_scores[name]) + if return_train_score: + ret['train_%s' % name] = np.array(train_scores[name]) + + return ret def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None, @@ -46,7 +226,7 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None, The object to use to fit the data. X : array-like - The data to fit. Can be, for example a list, or an array at least 2d. + The data to fit. Can be for example a list, or an array. y : array-like, optional, default: None The target variable to try to predict in the case of @@ -122,23 +302,24 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None, See Also --------- + :func:`sklearn.model_selection.cross_validate`: + To run cross-validation on multiple metrics and also to return + train scores, fit times and score times. + :func:`sklearn.metrics.make_scorer`: Make a scorer from a performance metric or loss function. """ - X, y, groups = indexable(X, y, groups) - - cv = check_cv(cv, y, classifier=is_classifier(estimator)) + # To ensure multimetric format is not supported scorer = check_scoring(estimator, scoring=scoring) - # We clone the estimator to make sure that all the folds are - # independent, and that it is pickle-able. - parallel = Parallel(n_jobs=n_jobs, verbose=verbose, - pre_dispatch=pre_dispatch) - scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y, scorer, - train, test, verbose, None, - fit_params) - for train, test in cv.split(X, y, groups)) - return np.array(scores)[:, 0] + + cv_results = cross_validate(estimator=estimator, X=X, y=y, groups=groups, + scoring={'score': scorer}, cv=cv, + return_train_score=False, + n_jobs=n_jobs, verbose=verbose, + fit_params=fit_params, + pre_dispatch=pre_dispatch) + return cv_results['test_score'] def _fit_and_score(estimator, X, y, scorer, train, test, verbose, @@ -159,8 +340,14 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, The target variable to try to predict in the case of supervised learning. - scorer : callable - A scorer callable object / function with signature + scorer : A single callable or dict mapping scorer name to the callable + If it is a single callable, the return value for ``train_scores`` and + ``test_scores`` is a single float. + + For a dict, it should be one mapping the scorer name to the scorer + callable object / function. + + The callable object / fn should have signature ``scorer(estimator, X, y)``. train : array-like, shape (n_train_samples,) @@ -190,13 +377,20 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, return_parameters : boolean, optional, default: False Return parameters that has been used for the estimator. 
+ return_n_test_samples : boolean, optional, default: False + Whether to return the ``n_test_samples`` + + return_times : boolean, optional, default: False + Whether to return the fit/score times. + Returns ------- - train_score : float, optional - Score on training set, returned only if `return_train_score` is `True`. + train_scores : dict of scorer name -> float, optional + Score on training set (for all the scorers), + returned only if `return_train_score` is `True`. - test_score : float - Score on test set. + test_scores : dict of scorer name -> float, optional + Score on testing set (for all the scorers). n_test_samples : int Number of test samples. @@ -223,6 +417,8 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, fit_params = dict([(k, _index_param_value(X, v, train)) for k, v in fit_params.items()]) + test_scores = {} + train_scores = {} if parameters is not None: estimator.set_params(**parameters) @@ -231,6 +427,9 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, X_train, y_train = _safe_split(estimator, X, y, train) X_test, y_test = _safe_split(estimator, X, y, test, train) + is_multimetric = not callable(scorer) + n_scorers = len(scorer.keys()) if is_multimetric else 1 + try: if y_train is None: estimator.fit(X_train, **fit_params) @@ -244,9 +443,16 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, if error_score == 'raise': raise elif isinstance(error_score, numbers.Number): - test_score = error_score - if return_train_score: - train_score = error_score + if is_multimetric: + test_scores = dict(zip(scorer.keys(), + [error_score, ] * n_scorers)) + if return_train_score: + train_scores = dict(zip(scorer.keys(), + [error_score, ] * n_scorers)) + else: + test_scores = error_score + if return_train_score: + train_scores = error_score warnings.warn("Classifier fit failed. The score on this train-test" " partition for these parameters will be set to %f. " "Details: \n%r" % (error_score, e), FitFailedWarning) @@ -257,19 +463,25 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, else: fit_time = time.time() - start_time - test_score = _score(estimator, X_test, y_test, scorer) + # _score will return dict if is_multimetric is True + test_scores = _score(estimator, X_test, y_test, scorer, is_multimetric) score_time = time.time() - start_time - fit_time if return_train_score: - train_score = _score(estimator, X_train, y_train, scorer) + train_scores = _score(estimator, X_train, y_train, scorer, + is_multimetric) if verbose > 2: - msg += ", score=%f" % test_score + if is_multimetric: + for scorer_name, score in test_scores.items(): + msg += ", %s=%s" % (scorer_name, score) + else: + msg += ", score=%s" % test_scores if verbose > 1: total_time = score_time + fit_time end_msg = "%s, total=%s" % (msg, logger.short_format_time(total_time)) print("[CV] %s %s" % ((64 - len(end_msg)) * '.', end_msg)) - ret = [train_score, test_score] if return_train_score else [test_score] + ret = [train_scores, test_scores] if return_train_score else [test_scores] if return_n_test_samples: ret.append(_num_samples(X_test)) @@ -280,25 +492,61 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, return ret -def _score(estimator, X_test, y_test, scorer): - """Compute the score of an estimator on a given test set.""" - if y_test is None: - score = scorer(estimator, X_test) +def _score(estimator, X_test, y_test, scorer, is_multimetric=False): + """Compute the score(s) of an estimator on a given test set. 
+ + Will return a single float if is_multimetric is False and a dict of floats, + if is_multimetric is True + """ + if is_multimetric: + return _multimetric_score(estimator, X_test, y_test, scorer) else: - score = scorer(estimator, X_test, y_test) - if hasattr(score, 'item'): - try: - # e.g. unwrap memmapped scalars - score = score.item() - except ValueError: - # non-scalar? - pass - if not isinstance(score, numbers.Number): - raise ValueError("scoring must return a number, got %s (%s) instead." - % (str(score), type(score))) + if y_test is None: + score = scorer(estimator, X_test) + else: + score = scorer(estimator, X_test, y_test) + + if hasattr(score, 'item'): + try: + # e.g. unwrap memmapped scalars + score = score.item() + except ValueError: + # non-scalar? + pass + + if not isinstance(score, numbers.Number): + raise ValueError("scoring must return a number, got %s (%s) " + "instead. (scorer=%r)" + % (str(score), type(score), scorer)) return score +def _multimetric_score(estimator, X_test, y_test, scorers): + """Return a dict of score for multimetric scoring""" + scores = {} + + for name, scorer in scorers.items(): + if y_test is None: + score = scorer(estimator, X_test) + else: + score = scorer(estimator, X_test, y_test) + + if hasattr(score, 'item'): + try: + # e.g. unwrap memmapped scalars + score = score.item() + except ValueError: + # non-scalar? + pass + scores[name] = score + + if not isinstance(score, numbers.Number): + raise ValueError("scoring must return a number, got %s (%s) " + "instead. (scorer=%s)" + % (str(score), type(score), name)) + return scores + + def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1, verbose=0, fit_params=None, pre_dispatch='2*n_jobs', method='predict'): @@ -555,9 +803,10 @@ def permutation_test_score(estimator, X, y, groups=None, cv=None, the dataset into train/test set. scoring : string, callable or None, optional, default: None - A string (see model evaluation documentation) or - a scorer callable object / function with signature - ``scorer(estimator, X, y)``. + A single string (see :ref:`_scoring_parameter`) or a callable + (see :ref:`_scoring`) to evaluate the predictions on the test set. + + If None the estimator's default scorer, if available, is used. cv : int, cross-validation generator or an iterable, optional Determines the cross-validation splitting strategy. @@ -997,10 +1246,38 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None, parameters={param_name: v}, fit_params=None, return_train_score=True) # NOTE do not change order of iteration to allow one time cv splitters for train, test in cv.split(X, y, groups) for v in param_range) - out = np.asarray(out) n_params = len(param_range) n_cv_folds = out.shape[0] // n_params out = out.reshape(n_cv_folds, n_params, 2).transpose((2, 1, 0)) return out[0], out[1] + + +def _aggregate_score_dicts(scores): + """Aggregate the list of dict to dict of np ndarray + + The aggregated output of _fit_and_score will be a list of dict + of form [{'prec': 0.1, 'acc':1.0}, {'prec': 0.1, 'acc':1.0}, ...] + Convert it to a dict of array {'prec': np.array([0.1 ...]), ...} + + Parameters + ---------- + + scores : list of dict + List of dicts of the scores for all scorers. This is a flat list, + assumed originally to be of row major order. + + Example + ------- + + >>> scores = [{'a': 1, 'b':10}, {'a': 2, 'b':2}, {'a': 3, 'b':3}, + ... 
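(Hedged sketch, not part of the patch.) ``_multimetric_score`` above boils down to evaluating each named scorer callable in turn; the same per-metric dict can be reproduced from the public API with plain scorer objects:

from sklearn.datasets import load_iris
from sklearn.metrics import make_scorer, accuracy_score, recall_score
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
clf = SVC(kernel='linear', C=1).fit(X, y)

scorers = {'acc': make_scorer(accuracy_score),
           'rec_macro': make_scorer(recall_score, average='macro')}
# One float per named scorer, all computed on the same fitted estimator.
scores = {name: scorer(clf, X, y) for name, scorer in scorers.items()}
# e.g. {'acc': 0.99..., 'rec_macro': 0.99...}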
{'a': 10, 'b': 10}] # doctest: +SKIP + >>> _aggregate_score_dicts(scores) # doctest: +SKIP + {'a': array([1, 2, 3, 10]), + 'b': array([10, 2, 3, 10])} + """ + out = {} + for key in scores[0]: + out[key] = np.asarray([score[key] for score in scores]) + return out diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 9e6fd57ccdbc05eb5d0021abc755b97e290e9481..9dfd49714ee08de8848e2f77ea1aa163e1fc04fd 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -7,6 +7,7 @@ from sklearn.externals.joblib._compat import PY3_OR_LATER from itertools import chain, product import pickle import sys +import re import numpy as np import scipy.sparse as sp @@ -27,13 +28,14 @@ from sklearn.utils.mocking import CheckingClassifier, MockDataFrame from scipy.stats import bernoulli, expon, uniform -from sklearn.externals.six.moves import zip from sklearn.base import BaseEstimator +from sklearn.base import clone from sklearn.exceptions import NotFittedError from sklearn.datasets import make_classification from sklearn.datasets import make_blobs from sklearn.datasets import make_multilabel_classification +from sklearn.model_selection import fit_grid_point from sklearn.model_selection import KFold from sklearn.model_selection import StratifiedKFold from sklearn.model_selection import StratifiedShuffleSplit @@ -54,6 +56,8 @@ from sklearn.tree import DecisionTreeClassifier from sklearn.cluster import KMeans from sklearn.neighbors import KernelDensity from sklearn.metrics import f1_score +from sklearn.metrics import recall_score +from sklearn.metrics import accuracy_score from sklearn.metrics import make_scorer from sklearn.metrics import roc_auc_score from sklearn.preprocessing import Imputer @@ -370,19 +374,30 @@ def test_trivial_cv_results_attr(): def test_no_refit(): # Test that GSCV can be used for model selection alone without refitting clf = MockClassifier() - grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, refit=False) - grid_search.fit(X, y) - assert_true(not hasattr(grid_search, "best_estimator_") and - hasattr(grid_search, "best_index_") and - hasattr(grid_search, "best_params_")) - - # Make sure the predict/transform etc fns raise meaningfull error msg - for fn_name in ('predict', 'predict_proba', 'predict_log_proba', - 'transform', 'inverse_transform'): - assert_raise_message(NotFittedError, - ('refit=False. %s is available only after ' - 'refitting on the best parameters' % fn_name), - getattr(grid_search, fn_name), X) + for scoring in [None, ['accuracy', 'precision']]: + grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, refit=False) + grid_search.fit(X, y) + assert_true(not hasattr(grid_search, "best_estimator_") and + hasattr(grid_search, "best_index_") and + hasattr(grid_search, "best_params_")) + + # Make sure the functions predict/transform etc raise meaningful + # error messages + for fn_name in ('predict', 'predict_proba', 'predict_log_proba', + 'transform', 'inverse_transform'): + assert_raise_message(NotFittedError, + ('refit=False. 
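(Hedged, self-contained sketch mirroring ``_aggregate_score_dicts`` above; ``aggregate_score_dicts`` is an illustrative stand-in, not the library function.) A list of per-fold score dicts is folded into a dict of per-scorer arrays:

import numpy as np

def aggregate_score_dicts(scores):
    # One array per scorer name, one entry per CV fold.
    return {key: np.asarray([fold[key] for fold in scores])
            for key in scores[0]}

folds = [{'prec': 0.1, 'acc': 1.0}, {'prec': 0.2, 'acc': 0.9}]
print(aggregate_score_dicts(folds))
# e.g. {'prec': array([0.1, 0.2]), 'acc': array([1. , 0.9])}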
%s is available only after ' + 'refitting on the best parameters' + % fn_name), getattr(grid_search, fn_name), X) + + # Test that an invalid refit param raises appropriate error messages + for refit in ["", 5, True, 'recall', 'accuracy']: + assert_raise_message(ValueError, "For multi-metric scoring, the " + "parameter refit must be set to a scorer key", + GridSearchCV(clf, {}, refit=refit, + scoring={'acc': 'accuracy', + 'prec': 'precision'}).fit, + X, y) def test_grid_search_error(): @@ -622,8 +637,13 @@ def test_pandas_input(): for InputFeatureType, TargetType in types: # X dataframe, y series X_df, y_ser = InputFeatureType(X), TargetType(y) - check_df = lambda x: isinstance(x, InputFeatureType) - check_series = lambda x: isinstance(x, TargetType) + + def check_df(x): + return isinstance(x, InputFeatureType) + + def check_series(x): + return isinstance(x, TargetType) + clf = CheckingClassifier(check_X=check_df, check_y=check_series) grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}) @@ -636,16 +656,20 @@ def test_unsupervised_grid_search(): # test grid-search with unsupervised estimator X, y = make_blobs(random_state=0) km = KMeans(random_state=0) - grid_search = GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4]), - scoring='adjusted_rand_score') - grid_search.fit(X, y) - # ARI can find the right number :) - assert_equal(grid_search.best_params_["n_clusters"], 3) + # Multi-metric evaluation unsupervised + scoring = ['adjusted_rand_score', 'fowlkes_mallows_score'] + for refit in ['adjusted_rand_score', 'fowlkes_mallows_score']: + grid_search = GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4]), + scoring=scoring, refit=refit) + grid_search.fit(X, y) + # Both ARI and FMS can find the right number :) + assert_equal(grid_search.best_params_["n_clusters"], 3) + + # Single metric evaluation unsupervised grid_search = GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4]), scoring='fowlkes_mallows_score') grid_search.fit(X, y) - # So can FMS ;) assert_equal(grid_search.best_params_["n_clusters"], 3) # Now without a score, and without y @@ -694,8 +718,9 @@ def test_param_sampler(): assert_equal([x for x in sampler], [x for x in sampler]) -def check_cv_results_array_types(cv_results, param_keys, score_keys): +def check_cv_results_array_types(search, param_keys, score_keys): # Check if the search `cv_results`'s array are of correct types + cv_results = search.cv_results_ assert_true(all(isinstance(cv_results[param], np.ma.MaskedArray) for param in param_keys)) assert_true(all(cv_results[key].dtype == object for key in param_keys)) @@ -703,7 +728,11 @@ def check_cv_results_array_types(cv_results, param_keys, score_keys): for key in score_keys)) assert_true(all(cv_results[key].dtype == np.float64 for key in score_keys if not key.startswith('rank'))) - assert_true(cv_results['rank_test_score'].dtype == np.int32) + + scorer_keys = search.scorer_.keys() if search.multimetric_ else ['score'] + + for key in scorer_keys: + assert_true(cv_results['rank_test_%s' % key].dtype == np.int32) def check_cv_results_keys(cv_results, param_keys, score_keys, n_cand): @@ -715,22 +744,27 @@ def check_cv_results_keys(cv_results, param_keys, score_keys, n_cand): def check_cv_results_grid_scores_consistency(search): - # TODO Remove in 0.20 - cv_results = search.cv_results_ - res_scores = np.vstack(list([cv_results["split%d_test_score" % i] - for i in range(search.n_splits_)])).T - res_means = cv_results["mean_test_score"] - res_params = cv_results["params"] - n_cand = len(res_params) - grid_scores = 
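(Hedged usage sketch of the behaviour tested above, public API only.) With a dict of scorers, ``GridSearchCV`` requires ``refit`` to name one of the scorer keys (or be ``False``); the named metric then drives ``best_params_`` and the refit of ``best_estimator_``:

from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

X, y = make_classification(n_samples=50, n_features=4, random_state=42)
search = GridSearchCV(SVC(), param_grid={'C': [1, 10]},
                      scoring={'acc': 'accuracy', 'prec': 'precision'},
                      refit='acc')
search.fit(X, y)
print(search.best_params_)     # selected according to mean_test_acc
print(search.best_estimator_)  # refit on the full data with those params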
assert_warns(DeprecationWarning, getattr, - search, 'grid_scores_') - assert_equal(len(grid_scores), n_cand) - # Check consistency of the structure of grid_scores - for i in range(n_cand): - assert_equal(grid_scores[i].parameters, res_params[i]) - assert_array_equal(grid_scores[i].cv_validation_scores, - res_scores[i, :]) - assert_array_equal(grid_scores[i].mean_validation_score, res_means[i]) + # TODO Remove test in 0.20 + if search.multimetric_: + assert_raise_message(AttributeError, "not available for multi-metric", + getattr, search, 'grid_scores_') + else: + cv_results = search.cv_results_ + res_scores = np.vstack(list([cv_results["split%d_test_score" % i] + for i in range(search.n_splits_)])).T + res_means = cv_results["mean_test_score"] + res_params = cv_results["params"] + n_cand = len(res_params) + grid_scores = assert_warns(DeprecationWarning, getattr, + search, 'grid_scores_') + assert_equal(len(grid_scores), n_cand) + # Check consistency of the structure of grid_scores + for i in range(n_cand): + assert_equal(grid_scores[i].parameters, res_params[i]) + assert_array_equal(grid_scores[i].cv_validation_scores, + res_scores[i, :]) + assert_array_equal(grid_scores[i].mean_validation_score, + res_means[i]) def test_grid_search_cv_results(): @@ -741,12 +775,6 @@ def test_grid_search_cv_results(): n_grid_points = 6 params = [dict(kernel=['rbf', ], C=[1, 10], gamma=[0.1, 1]), dict(kernel=['poly', ], degree=[1, 2])] - grid_search = GridSearchCV(SVC(), cv=n_splits, iid=False, - param_grid=params) - grid_search.fit(X, y) - grid_search_iid = GridSearchCV(SVC(), cv=n_splits, iid=True, - param_grid=params) - grid_search_iid.fit(X, y) param_keys = ('param_C', 'param_degree', 'param_gamma', 'param_kernel') score_keys = ('mean_test_score', 'mean_train_score', @@ -760,7 +788,9 @@ def test_grid_search_cv_results(): 'mean_score_time', 'std_score_time') n_candidates = n_grid_points - for search, iid in zip((grid_search, grid_search_iid), (False, True)): + for iid in (False, True): + search = GridSearchCV(SVC(), cv=n_splits, iid=iid, param_grid=params) + search.fit(X, y) assert_equal(iid, search.iid) cv_results = search.cv_results_ # Check if score and timing are reasonable @@ -771,11 +801,11 @@ def test_grid_search_cv_results(): if 'time' not in k and k is not 'rank_test_score') # Check cv_results structure - check_cv_results_array_types(cv_results, param_keys, score_keys) + check_cv_results_array_types(search, param_keys, score_keys) check_cv_results_keys(cv_results, param_keys, score_keys, n_candidates) # Check masking - cv_results = grid_search.cv_results_ - n_candidates = len(grid_search.cv_results_['params']) + cv_results = search.cv_results_ + n_candidates = len(search.cv_results_['params']) assert_true(all((cv_results['param_C'].mask[i] and cv_results['param_gamma'].mask[i] and not cv_results['param_degree'].mask[i]) @@ -790,26 +820,12 @@ def test_grid_search_cv_results(): def test_random_search_cv_results(): - # Make a dataset with a lot of noise to get various kind of prediction - # errors across CV folds and parameter settings - X, y = make_classification(n_samples=200, n_features=100, n_informative=3, - random_state=0) + X, y = make_classification(n_samples=50, n_features=4, random_state=42) - # scipy.stats dists now supports `seed` but we still support scipy 0.12 - # which doesn't support the seed. Hence the assertions in the test for - # random_search alone should not depend on randomization. 
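(Hedged sketch of the ``cv_results_`` naming convention checked above.) Under multi-metric scoring, every ``<stat>_test_score`` key is suffixed with the scorer name instead of ``score`` and a ``rank_test_<scorer>`` column is produced per metric, while ``grid_scores_`` is not available in this mode:

from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

X, y = make_classification(n_samples=50, n_features=4, random_state=42)
search = GridSearchCV(SVC(), param_grid={'C': [1, 10]},
                      scoring=['accuracy', 'recall'], refit='accuracy')
search.fit(X, y)
print(sorted(k for k in search.cv_results_ if k.startswith('mean_test_')))
# ['mean_test_accuracy', 'mean_test_recall']
print(sorted(k for k in search.cv_results_ if k.startswith('rank_test_')))
# ['rank_test_accuracy', 'rank_test_recall']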
n_splits = 3 n_search_iter = 30 - params = dict(C=expon(scale=10), gamma=expon(scale=0.1)) - random_search = RandomizedSearchCV(SVC(), n_iter=n_search_iter, - cv=n_splits, iid=False, - param_distributions=params) - random_search.fit(X, y) - random_search_iid = RandomizedSearchCV(SVC(), n_iter=n_search_iter, - cv=n_splits, iid=True, - param_distributions=params) - random_search_iid.fit(X, y) + params = dict(C=expon(scale=10), gamma=expon(scale=0.1)) param_keys = ('param_C', 'param_gamma') score_keys = ('mean_test_score', 'mean_train_score', 'rank_test_score', @@ -822,11 +838,14 @@ def test_random_search_cv_results(): 'mean_score_time', 'std_score_time') n_cand = n_search_iter - for search, iid in zip((random_search, random_search_iid), (False, True)): + for iid in (False, True): + search = RandomizedSearchCV(SVC(), n_iter=n_search_iter, cv=n_splits, + iid=iid, param_distributions=params) + search.fit(X, y) assert_equal(iid, search.iid) cv_results = search.cv_results_ # Check results structure - check_cv_results_array_types(cv_results, param_keys, score_keys) + check_cv_results_array_types(search, param_keys, score_keys) check_cv_results_keys(cv_results, param_keys, score_keys, n_cand) # For random_search, all the param array vals should be unmasked assert_false(any(cv_results['param_C'].mask) or @@ -928,6 +947,108 @@ def test_search_iid_param(): assert_almost_equal(train_std, 0) +def test_grid_search_cv_results_multimetric(): + X, y = make_classification(n_samples=50, n_features=4, random_state=42) + + n_splits = 3 + params = [dict(kernel=['rbf', ], C=[1, 10], gamma=[0.1, 1]), + dict(kernel=['poly', ], degree=[1, 2])] + + for iid in (False, True): + grid_searches = [] + for scoring in ({'accuracy': make_scorer(accuracy_score), + 'recall': make_scorer(recall_score)}, + 'accuracy', 'recall'): + grid_search = GridSearchCV(SVC(), cv=n_splits, iid=iid, + param_grid=params, scoring=scoring, + refit=False) + grid_search.fit(X, y) + assert_equal(grid_search.iid, iid) + grid_searches.append(grid_search) + + compare_cv_results_multimetric_with_single(*grid_searches, iid=iid) + + +def test_random_search_cv_results_multimetric(): + X, y = make_classification(n_samples=50, n_features=4, random_state=42) + + n_splits = 3 + n_search_iter = 30 + scoring = ('accuracy', 'recall') + + # Scipy 0.12's stats dists do not accept seed, hence we use param grid + params = dict(C=np.logspace(-10, 1), gamma=np.logspace(-5, 0, base=0.1)) + for iid in (True, False): + for refit in (True, False): + random_searches = [] + for scoring in (('accuracy', 'recall'), 'accuracy', 'recall'): + # If True, for multi-metric pass refit='accuracy' + if refit: + refit = 'accuracy' if isinstance(scoring, tuple) else refit + clf = SVC(probability=True, random_state=42) + random_search = RandomizedSearchCV(clf, n_iter=n_search_iter, + cv=n_splits, iid=iid, + param_distributions=params, + scoring=scoring, + refit=refit, random_state=0) + random_search.fit(X, y) + random_searches.append(random_search) + + compare_cv_results_multimetric_with_single(*random_searches, + iid=iid) + if refit: + compare_refit_methods_when_refit_with_acc( + random_searches[0], random_searches[1], refit) + + +def compare_cv_results_multimetric_with_single( + search_multi, search_acc, search_rec, iid): + """Compare multi-metric cv_results with the ensemble of multiple + single metric cv_results from single metric grid/random search""" + + assert_equal(search_multi.iid, iid) + assert_true(search_multi.multimetric_) + assert_array_equal(sorted(search_multi.scorer_), + 
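(Hedged sketch of the comparison trick used above.) Single-metric results use the ``_score`` suffix, so renaming that suffix to the scorer name makes them directly comparable with the multi-metric ``cv_results_`` keys:

import re

single_metric_keys = ['mean_test_score', 'rank_test_score', 'std_test_score']
# Rewrite the trailing '_score' into the scorer's own name, e.g. 'accuracy'.
renamed = [re.sub('_score$', '_accuracy', k) for k in single_metric_keys]
print(renamed)
# ['mean_test_accuracy', 'rank_test_accuracy', 'std_test_accuracy']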
('accuracy', 'recall')) + + cv_results_multi = search_multi.cv_results_ + cv_results_acc_rec = {re.sub('_score$', '_accuracy', k): v + for k, v in search_acc.cv_results_.items()} + cv_results_acc_rec.update({re.sub('_score$', '_recall', k): v + for k, v in search_rec.cv_results_.items()}) + + # Check if score and timing are reasonable, also checks if the keys + # are present + assert_true(all((np.all(cv_results_multi[k] <= 1) for k in ( + 'mean_score_time', 'std_score_time', 'mean_fit_time', + 'std_fit_time')))) + + # Compare the keys, other than time keys, among multi-metric and + # single metric grid search results. np.testing.assert_equal performs a + # deep nested comparison of the two cv_results dicts + np.testing.assert_equal({k: v for k, v in cv_results_multi.items() + if not k.endswith('_time')}, + {k: v for k, v in cv_results_acc_rec.items() + if not k.endswith('_time')}) + + +def compare_refit_methods_when_refit_with_acc(search_multi, search_acc, refit): + """Compare refit multi-metric search methods with single metric methods""" + if refit: + assert_equal(search_multi.refit, 'accuracy') + else: + assert_false(search_multi.refit) + assert_equal(search_acc.refit, refit) + + X, y = make_blobs(n_samples=100, n_features=4, random_state=42) + for method in ('predict', 'predict_proba', 'predict_log_proba'): + assert_almost_equal(getattr(search_multi, method)(X), + getattr(search_acc, method)(X)) + assert_almost_equal(search_multi.score(X, y), search_acc.score(X, y)) + for key in ('best_index_', 'best_score_', 'best_params_'): + assert_equal(getattr(search_multi, key), getattr(search_acc, key)) + + def test_search_cv_results_rank_tie_breaking(): X, y = make_blobs(n_samples=50, random_state=42) @@ -1034,6 +1155,34 @@ def test_grid_search_correct_score_results(): assert_almost_equal(correct_score, cv_scores[i]) +def test_fit_grid_point(): + X, y = make_classification(random_state=0) + cv = StratifiedKFold(random_state=0) + svc = LinearSVC(random_state=0) + scorer = make_scorer(accuracy_score) + + for params in ({'C': 0.1}, {'C': 0.01}, {'C': 0.001}): + for train, test in cv.split(X, y): + this_scores, this_params, n_test_samples = fit_grid_point( + X, y, clone(svc), params, train, test, + scorer, verbose=False) + + est = clone(svc).set_params(**params) + est.fit(X[train], y[train]) + expected_score = scorer(est, X[test], y[test]) + + # Test the return values of fit_grid_point + assert_almost_equal(this_scores, expected_score) + assert_equal(params, this_params) + assert_equal(n_test_samples, test.size) + + # Should raise an error upon multimetric scorer + assert_raise_message(ValueError, "scoring value should either be a " + "callable, string or None.", fit_grid_point, X, y, + svc, params, train, test, {'score': scorer}, + verbose=True) + + def test_pickle(): # Test that a fit search can be pickled clf = MockClassifier() @@ -1272,20 +1421,16 @@ def test_grid_search_cv_splits_consistency(): cv=KFold(n_splits=n_splits)) gs2.fit(X, y) - def _pop_time_keys(cv_results): - for key in ('mean_fit_time', 'std_fit_time', - 'mean_score_time', 'std_score_time'): - cv_results.pop(key) - return cv_results - # OneTimeSplitter is a non-re-entrant cv where split can be called only # once if ``cv.split`` is called once per param setting in GridSearchCV.fit # the 2nd and 3rd parameter will not be evaluated as no train/test indices # will be generated for the 2nd and subsequent cv.split calls. # This is a check to make sure cv.split is not called once per param # setting. 
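(Hedged sketch of the ``fit_grid_point`` contract exercised above; passing a multi-metric dict as the scorer is expected to raise ``ValueError``.) For a single scorer it returns the test score, the evaluated parameters and the number of test samples for one train/test split:

from sklearn.datasets import make_classification
from sklearn.metrics import make_scorer, accuracy_score
from sklearn.model_selection import StratifiedKFold, fit_grid_point
from sklearn.svm import LinearSVC

X, y = make_classification(random_state=0)
train, test = next(StratifiedKFold().split(X, y))
score, params, n_test = fit_grid_point(X, y, LinearSVC(random_state=0),
                                       {'C': 0.1}, train, test,
                                       make_scorer(accuracy_score),
                                       verbose=False)
print(score, params, n_test)  # one float, the same dict, len(test)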
- np.testing.assert_equal(_pop_time_keys(gs.cv_results_), - _pop_time_keys(gs2.cv_results_)) + np.testing.assert_equal({k: v for k, v in gs.cv_results_.items() + if not k.endswith('_time')}, + {k: v for k, v in gs2.cv_results_.items() + if not k.endswith('_time')}) # Check consistency of folds across the parameters gs = GridSearchCV(LinearSVC(random_state=0), diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index 3087c1f3bda9a7ad1a80d2c3cab7a1a3be62b2fd..c73f42fb27dd2c63ef76c59b6da51b8ccb134319 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -16,6 +16,7 @@ from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_almost_equal from sklearn.utils.testing import assert_raises from sklearn.utils.testing import assert_raise_message +from sklearn.utils.testing import assert_raises_regex from sklearn.utils.testing import assert_greater from sklearn.utils.testing import assert_less from sklearn.utils.testing import assert_array_almost_equal @@ -25,6 +26,7 @@ from sklearn.utils.mocking import CheckingClassifier, MockDataFrame from sklearn.model_selection import cross_val_score from sklearn.model_selection import cross_val_predict +from sklearn.model_selection import cross_validate from sklearn.model_selection import permutation_test_score from sklearn.model_selection import KFold from sklearn.model_selection import StratifiedKFold @@ -42,7 +44,12 @@ from sklearn.datasets import load_boston from sklearn.datasets import load_iris from sklearn.metrics import explained_variance_score from sklearn.metrics import make_scorer +from sklearn.metrics import accuracy_score +from sklearn.metrics import confusion_matrix +from sklearn.metrics import precision_recall_fscore_support from sklearn.metrics import precision_score +from sklearn.metrics import r2_score +from sklearn.metrics.scorer import check_scoring from sklearn.linear_model import Ridge, LogisticRegression from sklearn.linear_model import PassiveAggressiveClassifier @@ -56,6 +63,7 @@ from sklearn.pipeline import Pipeline from sklearn.externals.six.moves import cStringIO as StringIO from sklearn.base import BaseEstimator +from sklearn.base import clone from sklearn.multiclass import OneVsRestClassifier from sklearn.utils import shuffle from sklearn.datasets import make_classification @@ -262,6 +270,196 @@ def test_cross_val_score(): assert_raises(ValueError, cross_val_score, clf, X_3d, y2) +def test_cross_validate_invalid_scoring_param(): + X, y = make_classification(random_state=0) + estimator = MockClassifier() + + # Test the errors + error_message_regexp = ".*must be unique strings.*" + + # List/tuple of callables should raise a message advising users to use + # dict of names to callables mapping + assert_raises_regex(ValueError, error_message_regexp, + cross_validate, estimator, X, y, + scoring=(make_scorer(precision_score), + make_scorer(accuracy_score))) + assert_raises_regex(ValueError, error_message_regexp, + cross_validate, estimator, X, y, + scoring=(make_scorer(precision_score),)) + + # So should empty lists/tuples + assert_raises_regex(ValueError, error_message_regexp + "Empty list.*", + cross_validate, estimator, X, y, scoring=()) + + # So should duplicated entries + assert_raises_regex(ValueError, error_message_regexp + "Duplicate.*", + cross_validate, estimator, X, y, + scoring=('f1_micro', 'f1_micro')) + + # Nested Lists should raise a generic error message + 
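(Hedged sketch of the filtering used above; ``drop_time_keys`` is a hypothetical helper name.) Timing entries vary between runs, so they are stripped before two ``cv_results_`` dicts are compared exactly:

def drop_time_keys(cv_results):
    # Keep every column except the '*_time' statistics.
    return {k: v for k, v in cv_results.items() if not k.endswith('_time')}

example = {'mean_fit_time': [0.01], 'mean_test_score': [0.9], 'params': [{}]}
print(drop_time_keys(example))
# {'mean_test_score': [0.9], 'params': [{}]}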
assert_raises_regex(ValueError, error_message_regexp, + cross_validate, estimator, X, y, + scoring=[[make_scorer(precision_score)]]) + + error_message_regexp = (".*should either be.*string or callable.*for " + "single.*.*dict.*for multi.*") + + # Empty dict should raise invalid scoring error + assert_raises_regex(ValueError, "An empty dict", + cross_validate, estimator, X, y, scoring=(dict())) + + # And so should any other invalid entry + assert_raises_regex(ValueError, error_message_regexp, + cross_validate, estimator, X, y, scoring=5) + + multiclass_scorer = make_scorer(precision_recall_fscore_support) + + # Multiclass Scorers that return multiple values are not supported yet + assert_raises_regex(ValueError, + "Can't handle mix of binary and continuous", + cross_validate, estimator, X, y, + scoring=multiclass_scorer) + assert_raises_regex(ValueError, + "Can't handle mix of binary and continuous", + cross_validate, estimator, X, y, + scoring={"foo": multiclass_scorer}) + + multivalued_scorer = make_scorer(confusion_matrix) + + # Multiclass Scorers that return multiple values are not supported yet + assert_raises_regex(ValueError, "scoring must return a number, got", + cross_validate, SVC(), X, y, + scoring=multivalued_scorer) + assert_raises_regex(ValueError, "scoring must return a number, got", + cross_validate, SVC(), X, y, + scoring={"foo": multivalued_scorer}) + + assert_raises_regex(ValueError, "'mse' is not a valid scoring value.", + cross_validate, SVC(), X, y, scoring="mse") + + +def test_cross_validate(): + # Compute train and test mse/r2 scores + cv = KFold(n_splits=5) + + # Regression + X_reg, y_reg = make_regression(n_samples=30, random_state=0) + reg = Ridge(random_state=0) + + # Classification + X_clf, y_clf = make_classification(n_samples=30, random_state=0) + clf = SVC(kernel="linear", random_state=0) + + for X, y, est in ((X_reg, y_reg, reg), (X_clf, y_clf, clf)): + # It's okay to evaluate regression metrics on classification too + mse_scorer = check_scoring(est, 'neg_mean_squared_error') + r2_scorer = check_scoring(est, 'r2') + train_mse_scores = [] + test_mse_scores = [] + train_r2_scores = [] + test_r2_scores = [] + for train, test in cv.split(X, y): + est = clone(reg).fit(X[train], y[train]) + train_mse_scores.append(mse_scorer(est, X[train], y[train])) + train_r2_scores.append(r2_scorer(est, X[train], y[train])) + test_mse_scores.append(mse_scorer(est, X[test], y[test])) + test_r2_scores.append(r2_scorer(est, X[test], y[test])) + + train_mse_scores = np.array(train_mse_scores) + test_mse_scores = np.array(test_mse_scores) + train_r2_scores = np.array(train_r2_scores) + test_r2_scores = np.array(test_r2_scores) + + scores = (train_mse_scores, test_mse_scores, train_r2_scores, + test_r2_scores) + + yield check_cross_validate_single_metric, est, X, y, scores + yield check_cross_validate_multi_metric, est, X, y, scores + + +def check_cross_validate_single_metric(clf, X, y, scores): + (train_mse_scores, test_mse_scores, train_r2_scores, + test_r2_scores) = scores + # Test single metric evaluation when scoring is string or singleton list + for (return_train_score, dict_len) in ((True, 4), (False, 3)): + # Single metric passed as a string + if return_train_score: + # It must be True by default + mse_scores_dict = cross_validate(clf, X, y, cv=5, + scoring='neg_mean_squared_error') + assert_array_almost_equal(mse_scores_dict['train_score'], + train_mse_scores) + else: + mse_scores_dict = cross_validate(clf, X, y, cv=5, + scoring='neg_mean_squared_error', + 
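(Hedged sketch of the equivalence the test above verifies, using the import path of this era; ``check_scoring`` is exposed as ``sklearn.metrics.check_scoring`` in later releases.) Scoring each fold by hand should agree with what ``cross_validate`` reports:

import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.metrics.scorer import check_scoring
from sklearn.model_selection import KFold, cross_validate

X, y = make_regression(n_samples=30, random_state=0)
reg = Ridge(random_state=0)
r2_scorer = check_scoring(reg, 'r2')

# Manual per-fold r2 on a fresh estimator, using the same deterministic folds.
manual = [r2_scorer(Ridge(random_state=0).fit(X[tr], y[tr]), X[te], y[te])
          for tr, te in KFold(n_splits=5).split(X, y)]
results = cross_validate(reg, X, y, cv=KFold(n_splits=5), scoring='r2',
                         return_train_score=False)
np.testing.assert_array_almost_equal(manual, results['test_score'])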
return_train_score=False) + assert_true(isinstance(mse_scores_dict, dict)) + assert_equal(len(mse_scores_dict), dict_len) + assert_array_almost_equal(mse_scores_dict['test_score'], + test_mse_scores) + + # Single metric passed as a list + if return_train_score: + # It must be True by default + r2_scores_dict = cross_validate(clf, X, y, cv=5, scoring=['r2']) + assert_array_almost_equal(r2_scores_dict['train_r2'], + train_r2_scores) + else: + r2_scores_dict = cross_validate(clf, X, y, cv=5, scoring=['r2'], + return_train_score=False) + assert_true(isinstance(r2_scores_dict, dict)) + assert_equal(len(r2_scores_dict), dict_len) + assert_array_almost_equal(r2_scores_dict['test_r2'], test_r2_scores) + + +def check_cross_validate_multi_metric(clf, X, y, scores): + # Test multimetric evaluation when scoring is a list / dict + (train_mse_scores, test_mse_scores, train_r2_scores, + test_r2_scores) = scores + all_scoring = (('r2', 'neg_mean_squared_error'), + {'r2': make_scorer(r2_score), + 'neg_mean_squared_error': 'neg_mean_squared_error'}) + + keys_sans_train = set(('test_r2', 'test_neg_mean_squared_error', + 'fit_time', 'score_time')) + keys_with_train = keys_sans_train.union( + set(('train_r2', 'train_neg_mean_squared_error'))) + + for return_train_score in (True, False): + for scoring in all_scoring: + if return_train_score: + # return_train_score must be True by default + cv_results = cross_validate(clf, X, y, cv=5, scoring=scoring) + assert_array_almost_equal(cv_results['train_r2'], + train_r2_scores) + assert_array_almost_equal( + cv_results['train_neg_mean_squared_error'], + train_mse_scores) + else: + cv_results = cross_validate(clf, X, y, cv=5, scoring=scoring, + return_train_score=False) + assert_true(isinstance(cv_results, dict)) + assert_equal(set(cv_results.keys()), + keys_with_train if return_train_score + else keys_sans_train) + assert_array_almost_equal(cv_results['test_r2'], test_r2_scores) + assert_array_almost_equal( + cv_results['test_neg_mean_squared_error'], test_mse_scores) + + # Make sure all the arrays are of np.ndarray type + assert type(cv_results['test_r2']) == np.ndarray + assert (type(cv_results['test_neg_mean_squared_error']) == + np.ndarray) + assert type(cv_results['fit_time'] == np.ndarray) + assert type(cv_results['score_time'] == np.ndarray) + + # Ensure all the times are within sane limits + assert np.all(cv_results['fit_time'] >= 0) + assert np.all(cv_results['fit_time'] < 10) + assert np.all(cv_results['score_time'] >= 0) + assert np.all(cv_results['score_time'] < 10) + + def test_cross_val_score_predict_groups(): # Check if ValueError (when groups is None) propagates to cross_val_score # and cross_val_predict @@ -386,8 +584,9 @@ def test_cross_val_score_score_func(): with warnings.catch_warnings(record=True): scoring = make_scorer(score_func) - score = cross_val_score(clf, X, y, scoring=scoring) + score = cross_val_score(clf, X, y, scoring=scoring, cv=3) assert_array_equal(score, [1.0, 1.0, 1.0]) + # Test that score function is called only 3 times (for cv=3) assert len(_score_func_args) == 3
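(Hedged usage sketch of the key layout checked above.) With two scorers and ``return_train_score=True``, ``cross_validate`` returns one ``test_*`` and one ``train_*`` entry per scorer plus the two timing arrays:

from sklearn.datasets import make_classification
from sklearn.model_selection import cross_validate
from sklearn.svm import SVC

X, y = make_classification(n_samples=30, random_state=0)
# As in the test, regression metrics are legal on a classifier's predictions.
cv_results = cross_validate(SVC(kernel='linear', random_state=0), X, y, cv=5,
                            scoring=('r2', 'neg_mean_squared_error'),
                            return_train_score=True)
print(sorted(cv_results.keys()))
# ['fit_time', 'score_time',
#  'test_neg_mean_squared_error', 'test_r2',
#  'train_neg_mean_squared_error', 'train_r2']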