diff --git a/sklearn/metrics/classification.py b/sklearn/metrics/classification.py
index 50d3d3b7523f0a52b6ec6b00026ae3a53e14657d..2a7be716ee48da5139816da2bb1c7463799d3ce3 100644
--- a/sklearn/metrics/classification.py
+++ b/sklearn/metrics/classification.py
@@ -623,9 +623,10 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
            parameter *labels* improved for multiclass problem.

     pos_label : str or int, 1 by default
-        The class to report if ``average='binary'``. Until version 0.18 it is
-        necessary to set ``pos_label=None`` if seeking to use another averaging
-        method over binary targets.
+        The class to report if ``average='binary'`` and the data is binary.
+        If the data are multiclass or multilabel, this will be ignored;
+        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
+        scores for that label only.

     average : string, [None, 'binary' (default), 'micro', 'macro', 'samples', \
                        'weighted']
@@ -652,10 +653,6 @@
             meaningful for multilabel classification where this differs from
             :func:`accuracy_score`).

-        Note that if ``pos_label`` is given in binary classification with
-        `average != 'binary'`, only that positive class is reported. This
-        behavior is deprecated and will change in version 0.18.
-
     sample_weight : array-like of shape = [n_samples], optional
         Sample weights.

@@ -729,9 +726,10 @@ def fbeta_score(y_true, y_pred, beta, labels=None, pos_label=1,
            parameter *labels* improved for multiclass problem.

     pos_label : str or int, 1 by default
-        The class to report if ``average='binary'``. Until version 0.18 it is
-        necessary to set ``pos_label=None`` if seeking to use another averaging
-        method over binary targets.
+        The class to report if ``average='binary'`` and the data is binary.
+        If the data are multiclass or multilabel, this will be ignored;
+        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
+        scores for that label only.

     average : string, [None, 'binary' (default), 'micro', 'macro', 'samples', \
                        'weighted']
@@ -758,10 +756,6 @@
             meaningful for multilabel classification where this differs from
             :func:`accuracy_score`).

-        Note that if ``pos_label`` is given in binary classification with
-        `average != 'binary'`, only that positive class is reported. This
-        behavior is deprecated and will change in version 0.18.
-
     sample_weight : array-like of shape = [n_samples], optional
         Sample weights.

@@ -905,9 +899,10 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
         ``y_pred`` are used in sorted order.

     pos_label : str or int, 1 by default
-        The class to report if ``average='binary'``. Until version 0.18 it is
-        necessary to set ``pos_label=None`` if seeking to use another averaging
-        method over binary targets.
+        The class to report if ``average='binary'`` and the data is binary.
+        If the data are multiclass or multilabel, this will be ignored;
+        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
+        scores for that label only.

     average : string, [None (default), 'binary', 'micro', 'macro', 'samples', \
                        'weighted']
@@ -933,10 +928,6 @@
             meaningful for multilabel classification where this differs from
             :func:`accuracy_score`).

-        Note that if ``pos_label`` is given in binary classification with
-        `average != 'binary'`, only that positive class is reported. This
-        behavior is deprecated and will change in version 0.18.
-
     warn_for : tuple or set, for internal use
         This determines which warnings will be made in the case that this
         function is being used to return only one of its metrics.
@@ -1008,25 +999,8 @@
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
     present_labels = unique_labels(y_true, y_pred)

-    if average == 'binary' and (y_type != 'binary' or pos_label is None):
-        warnings.warn('The default `weighted` averaging is deprecated, '
-                      'and from version 0.18, use of precision, recall or '
-                      'F-score with multiclass or multilabel data or '
-                      'pos_label=None will result in an exception. '
-                      'Please set an explicit value for `average`, one of '
-                      '%s. In cross validation use, for instance, '
-                      'scoring="f1_weighted" instead of scoring="f1".'
-                      % str(average_options), DeprecationWarning, stacklevel=2)
-        average = 'weighted'
-
-    if y_type == 'binary' and pos_label is not None and average is not None:
-        if average != 'binary':
-            warnings.warn('From version 0.18, binary input will not be '
-                          'handled specially when using averaged '
-                          'precision/recall/F-score. '
-                          'Please use average=\'binary\' to report only the '
-                          'positive class performance.', DeprecationWarning)
-        if labels is None or len(labels) <= 2:
+    if average == 'binary':
+        if y_type == 'binary':
             if pos_label not in present_labels:
                 if len(present_labels) < 2:
                     # Only negative labels
@@ -1035,6 +1009,15 @@
                     raise ValueError("pos_label=%r is not a valid label: %r" %
                                      (pos_label, present_labels))
             labels = [pos_label]
+        else:
+            raise ValueError("Target is %s but average='binary'. Please "
+                             "choose another average setting." % y_type)
+    elif pos_label not in (None, 1):
+        warnings.warn("Note that pos_label (set to %r) is ignored when "
+                      "average != 'binary' (got %r). You may use "
+                      "labels=[pos_label] to specify a single positive class."
+                      % (pos_label, average), UserWarning)
+
     if labels is None:
         labels = present_labels
         n_labels = None
@@ -1187,9 +1170,10 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1,
            parameter *labels* improved for multiclass problem.

     pos_label : str or int, 1 by default
-        The class to report if ``average='binary'``. Until version 0.18 it is
-        necessary to set ``pos_label=None`` if seeking to use another averaging
-        method over binary targets.
+        The class to report if ``average='binary'`` and the data is binary.
+        If the data are multiclass or multilabel, this will be ignored;
+        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
+        scores for that label only.

     average : string, [None, 'binary' (default), 'micro', 'macro', 'samples', \
                        'weighted']
@@ -1216,10 +1200,6 @@
             meaningful for multilabel classification where this differs from
             :func:`accuracy_score`).

-        Note that if ``pos_label`` is given in binary classification with
-        `average != 'binary'`, only that positive class is reported. This
-        behavior is deprecated and will change in version 0.18.
-
     sample_weight : array-like of shape = [n_samples], optional
         Sample weights.

@@ -1289,9 +1269,10 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
            parameter *labels* improved for multiclass problem.

     pos_label : str or int, 1 by default
-        The class to report if ``average='binary'``. Until version 0.18 it is
-        necessary to set ``pos_label=None`` if seeking to use another averaging
-        method over binary targets.
+        The class to report if ``average='binary'`` and the data is binary.
+        If the data are multiclass or multilabel, this will be ignored;
+        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
+        scores for that label only.

     average : string, [None, 'binary' (default), 'micro', 'macro', 'samples', \
                        'weighted']
@@ -1318,10 +1299,6 @@
             meaningful for multilabel classification where this differs from
             :func:`accuracy_score`).

-        Note that if ``pos_label`` is given in binary classification with
-        `average != 'binary'`, only that positive class is reported. This
-        behavior is deprecated and will change in version 0.18.
-
     sample_weight : array-like of shape = [n_samples], optional
         Sample weights.

diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py
index 576c3a362b29fb1e734223f48578bf1163ff9a19..dc8a7c0686b591686747bf815d552ce8a66bf674 100644
--- a/sklearn/metrics/tests/test_classification.py
+++ b/sklearn/metrics/tests/test_classification.py
@@ -130,10 +130,8 @@ def test_precision_recall_f1_score_binary():
     # individual scoring function that can be used for grid search: in the
     # binary class case the score is the value of the measure for the positive
     # class (e.g. label == 1). This is deprecated for average != 'binary'.
-    assert_dep_warning = partial(assert_warns, DeprecationWarning)
     for kwargs, my_assert in [({}, assert_no_warnings),
-                              ({'average': 'binary'}, assert_no_warnings),
-                              ({'average': 'micro'}, assert_dep_warning)]:
+                              ({'average': 'binary'}, assert_no_warnings)]:
         ps = my_assert(precision_score, y_true, y_pred, **kwargs)
         assert_array_almost_equal(ps, 0.85, 2)

@@ -273,13 +271,24 @@ def test_precision_recall_fscore_support_errors():

     # Bad pos_label
     assert_raises(ValueError, precision_recall_fscore_support,
-                  y_true, y_pred, pos_label=2, average='macro')
+                  y_true, y_pred, pos_label=2, average='binary')

     # Bad average option
     assert_raises(ValueError, precision_recall_fscore_support,
                   [0, 1, 2], [1, 2, 0], average='mega')


+def test_precision_recall_f_unused_pos_label():
+    # Check warning that pos_label unused when set to non-default value
+    # but average != 'binary'; even if data is binary.
+    assert_warns_message(UserWarning,
+                         "Note that pos_label (set to 2) is "
+                         "ignored when average != 'binary' (got 'macro'). You "
+                         "may use labels=[pos_label] to specify a single "
+                         "positive class.", precision_recall_fscore_support,
+                         [1, 2, 1], [1, 2, 2], pos_label=2, average='macro')
+
+
 def test_confusion_matrix_binary():
     # Test confusion matrix - binary classification case
     y_true, y_pred, _ = make_prediction(binary=True)
@@ -458,17 +467,24 @@ def test_precision_refcall_f1_score_multilabel_unordered_labels():
     assert_array_equal(s, [0, 1, 1, 0])


-def test_precision_recall_f1_score_multiclass_pos_label_none():
-    # Test Precision Recall and F1 Score for multiclass classification task
-    # GH Issue #1296
-    # initialize data
+def test_precision_recall_f1_score_binary_averaged():
     y_true = np.array([0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1])
     y_pred = np.array([1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1])

     # compute scores with default labels introspection
-    p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
-                                                 pos_label=None,
+    ps, rs, fs, _ = precision_recall_fscore_support(y_true, y_pred,
+                                                    average=None)
+    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred,
                                                  average='macro')
+    assert_equal(p, np.mean(ps))
+    assert_equal(r, np.mean(rs))
+    assert_equal(f, np.mean(fs))
+    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred,
+                                                 average='weighted')
+    support = np.bincount(y_true)
+    assert_equal(p, np.average(ps, weights=support))
+    assert_equal(r, np.average(rs, weights=support))
+    assert_equal(f, np.average(fs, weights=support))


 def test_zero_precision_recall():
@@ -1041,37 +1057,37 @@ def test_prf_warnings():
                'being set to 0.0 in labels with no true samples.')
         my_assert(w, msg, f, [1, 1, 2], [0, 1, 2], average=average)

-        # average of per-sample scores
-        msg = ('Precision and F-score are ill-defined and '
-               'being set to 0.0 in samples with no predicted labels.')
-        my_assert(w, msg, f, np.array([[1, 0], [1, 0]]),
-                  np.array([[1, 0], [0, 0]]), average='samples')
+    # average of per-sample scores
+    msg = ('Precision and F-score are ill-defined and '
+           'being set to 0.0 in samples with no predicted labels.')
+    my_assert(w, msg, f, np.array([[1, 0], [1, 0]]),
+              np.array([[1, 0], [0, 0]]), average='samples')

-        msg = ('Recall and F-score are ill-defined and '
-               'being set to 0.0 in samples with no true labels.')
-        my_assert(w, msg, f, np.array([[1, 0], [0, 0]]),
-                  np.array([[1, 0], [1, 0]]),
-                  average='samples')
+    msg = ('Recall and F-score are ill-defined and '
+           'being set to 0.0 in samples with no true labels.')
+    my_assert(w, msg, f, np.array([[1, 0], [0, 0]]),
+              np.array([[1, 0], [1, 0]]),
+              average='samples')

-        # single score: micro-average
-        msg = ('Precision and F-score are ill-defined and '
-               'being set to 0.0 due to no predicted samples.')
-        my_assert(w, msg, f, np.array([[1, 1], [1, 1]]),
-                  np.array([[0, 0], [0, 0]]), average='micro')
+    # single score: micro-average
+    msg = ('Precision and F-score are ill-defined and '
+           'being set to 0.0 due to no predicted samples.')
+    my_assert(w, msg, f, np.array([[1, 1], [1, 1]]),
+              np.array([[0, 0], [0, 0]]), average='micro')

-        msg = ('Recall and F-score are ill-defined and '
-               'being set to 0.0 due to no true samples.')
-        my_assert(w, msg, f, np.array([[0, 0], [0, 0]]),
-                  np.array([[1, 1], [1, 1]]), average='micro')
+    msg = ('Recall and F-score are ill-defined and '
+           'being set to 0.0 due to no true samples.')
+    my_assert(w, msg, f, np.array([[0, 0], [0, 0]]),
+              np.array([[1, 1], [1, 1]]), average='micro')

-        # single postive label
-        msg = ('Precision and F-score are ill-defined and '
-               'being set to 0.0 due to no predicted samples.')
-        my_assert(w, msg, f, [1, 1], [-1, -1], average='macro')
+    # single postive label
+    msg = ('Precision and F-score are ill-defined and '
+           'being set to 0.0 due to no predicted samples.')
+    my_assert(w, msg, f, [1, 1], [-1, -1], average='binary')

-        msg = ('Recall and F-score are ill-defined and '
-               'being set to 0.0 due to no true samples.')
-        my_assert(w, msg, f, [-1, -1], [1, 1], average='macro')
+    msg = ('Recall and F-score are ill-defined and '
+           'being set to 0.0 due to no true samples.')
+    my_assert(w, msg, f, [-1, -1], [1, 1], average='binary')


 def test_recall_warnings():
@@ -1128,32 +1144,23 @@
                          'being set to 0.0 due to no true samples.')


-def test_prf_average_compat():
-    # Ensure warning if f1_score et al.'s average is implicit for multiclass
-    y_true = [1, 2, 3, 3]
-    y_pred = [1, 2, 3, 1]
-    y_true_bin = [0, 1, 1]
-    y_pred_bin = [0, 1, 0]
-
-    for metric in [precision_score, recall_score, f1_score,
-                   partial(fbeta_score, beta=2)]:
-        score = assert_warns(DeprecationWarning, metric, y_true, y_pred)
-        score_weighted = assert_no_warnings(metric, y_true, y_pred,
-                                            average='weighted')
-        assert_equal(score, score_weighted,
-                     'average does not act like "weighted" by default')
-
-        # check binary passes without warning
-        assert_no_warnings(metric, y_true_bin, y_pred_bin)
-
-        # but binary with pos_label=None should behave like multiclass
-        score = assert_warns(DeprecationWarning, metric,
-                             y_true_bin, y_pred_bin, pos_label=None)
-        score_weighted = assert_no_warnings(metric, y_true_bin, y_pred_bin,
-                                            pos_label=None, average='weighted')
-        assert_equal(score, score_weighted,
-                     'average does not act like "weighted" by default with '
-                     'binary data and pos_label=None')
+def test_prf_average_binary_data_non_binary():
+    # Error if user does not explicitly set non-binary average mode
+    y_true_mc = [1, 2, 3, 3]
+    y_pred_mc = [1, 2, 3, 1]
+    y_true_ind = np.array([[0, 1, 1], [1, 0, 0], [0, 0, 1]])
+    y_pred_ind = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
+
+    for y_true, y_pred, y_type in [
+        (y_true_mc, y_pred_mc, 'multiclass'),
+        (y_true_ind, y_pred_ind, 'multilabel-indicator'),
+    ]:
+        for metric in [precision_score, recall_score, f1_score,
+                       partial(fbeta_score, beta=2)]:
+            assert_raise_message(ValueError,
+                                 "Target is %s but average='binary'. Please "
+                                 "choose another average setting." % y_type,
+                                 metric, y_true, y_pred)


 def test__check_targets():
diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 4974d29356be80f8dfd7cfd9f0cf0de96b2e9186..fa4c7e8d3124bb19ef513540836225f51c27daa5 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -217,6 +217,13 @@ METRIC_UNDEFINED_BINARY = [
 METRIC_UNDEFINED_MULTICLASS = [
     "brier_score_loss",
     "matthews_corrcoef_score",
+
+    # with default average='binary', multiclass is prohibited
+    "precision_score",
+    "recall_score",
+    "f1_score",
+    "f2_score",
+    "f0.5_score",
 ]

 # Metric undefined with "binary" or "multiclass" input
@@ -303,17 +310,15 @@ MULTILABELS_METRICS = [
     "jaccard_similarity_score", "unnormalized_jaccard_similarity_score",
     "zero_one_loss", "unnormalized_zero_one_loss",

-    "precision_score", "recall_score", "f1_score", "f2_score", "f0.5_score",
-
     "weighted_f0.5_score", "weighted_f1_score", "weighted_f2_score",
     "weighted_precision_score", "weighted_recall_score",

-    "micro_f0.5_score", "micro_f1_score", "micro_f2_score",
-    "micro_precision_score", "micro_recall_score",
-
     "macro_f0.5_score", "macro_f1_score", "macro_f2_score",
     "macro_precision_score", "macro_recall_score",

+    "micro_f0.5_score", "micro_f1_score", "micro_f2_score",
+    "micro_precision_score", "micro_recall_score",
+
     "samples_f0.5_score", "samples_f1_score", "samples_f2_score",
     "samples_precision_score", "samples_recall_score",
 ]
@@ -332,7 +337,11 @@ SYMMETRIC_METRICS = [
     "jaccard_similarity_score", "unnormalized_jaccard_similarity_score",
     "zero_one_loss", "unnormalized_zero_one_loss",

-    "f1_score", "weighted_f1_score", "micro_f1_score", "macro_f1_score",
+    "f1_score", "micro_f1_score", "macro_f1_score",
+    "weighted_recall_score",
+    # P = R = F = accuracy in multiclass case
+    "micro_f0.5_score", "micro_f1_score", "micro_f2_score",
+    "micro_precision_score", "micro_recall_score",

     "matthews_corrcoef_score", "mean_absolute_error", "mean_squared_error",
     "median_absolute_error",
@@ -349,11 +358,8 @@ NOT_SYMMETRIC_METRICS = [

     "precision_score", "recall_score", "f2_score", "f0.5_score",

-    "weighted_f0.5_score", "weighted_f2_score", "weighted_precision_score",
-    "weighted_recall_score",
-
-    "micro_f0.5_score", "micro_f2_score", "micro_precision_score",
-    "micro_recall_score",
+    "weighted_f0.5_score", "weighted_f1_score", "weighted_f2_score",
+    "weighted_precision_score",

     "macro_f0.5_score", "macro_f2_score", "macro_precision_score",
     "macro_recall_score", "log_loss", "hinge_loss"
 ]
@@ -382,7 +388,8 @@ def test_symmetry():
     # We shouldn't forget any metrics
     assert_equal(set(SYMMETRIC_METRICS).union(
         NOT_SYMMETRIC_METRICS, THRESHOLDED_METRICS,
-        METRIC_UNDEFINED_BINARY_MULTICLASS), set(ALL_METRICS))
+        METRIC_UNDEFINED_BINARY_MULTICLASS),
+        set(ALL_METRICS))

     assert_equal(
         set(SYMMETRIC_METRICS).intersection(set(NOT_SYMMETRIC_METRICS)),
@@ -1062,6 +1069,7 @@ def test_sample_weight_invariance(n_samples=50):
                    y_pred)


+@ignore_warnings
 def test_no_averaging_labels():
     # test labels argument when not using averaging
     # in multi-class and multi-label cases
@@ -1075,7 +1083,7 @@ def test_no_averaging_labels():
     for name in METRICS_WITH_AVERAGING:
         for y_true, y_pred in [[y_true_multiclass, y_pred_multiclass],
                                [y_true_multilabel, y_pred_multilabel]]:
-            if name not in MULTILABELS_METRICS and y_pred.shape[1] > 0:
+            if name not in MULTILABELS_METRICS and y_pred.ndim > 1:
                 continue

             metric = ALL_METRICS[name]
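# ---------------------------------------------------------------------------
# Not part of the patch: a minimal usage sketch of the behaviour the diff
# above introduces, assuming a scikit-learn build that includes this change.
# Only public sklearn.metrics functions touched by the patch are used; the
# exact wording of the error and warning messages may differ between releases.
# ---------------------------------------------------------------------------
import warnings

import numpy as np
from sklearn.metrics import f1_score, precision_recall_fscore_support

# Binary targets: the default average='binary' reports only the class
# selected by pos_label (1 by default).
y_true = np.array([0, 1, 1, 0, 1])
y_pred = np.array([0, 1, 0, 0, 1])
print(f1_score(y_true, y_pred, pos_label=1, average='binary'))

# Multiclass targets: average='binary' is now rejected with a ValueError
# instead of silently falling back to weighted averaging.
try:
    f1_score([0, 1, 2, 2], [0, 2, 2, 1])  # default average='binary'
except ValueError as exc:
    print(exc)

# A non-default pos_label is ignored (with a UserWarning) when
# average != 'binary'; use labels=[...] to score a single class instead.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    precision_recall_fscore_support(y_true, y_pred, pos_label=0,
                                    average='macro')
print([w.category.__name__ for w in caught])
print(precision_recall_fscore_support(y_true, y_pred, labels=[0],
                                      average='macro'))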