diff --git a/benchmarks/bench_multilabel_metrics.py b/benchmarks/bench_multilabel_metrics.py index 4a4cdd7f3390272698f01e14afdc20636c0b7113..7afee7543143338e7ba3fa1f26a492a999163690 100755 --- a/benchmarks/bench_multilabel_metrics.py +++ b/benchmarks/bench_multilabel_metrics.py @@ -22,7 +22,7 @@ from sklearn.utils.testing import ignore_warnings METRICS = { - 'f1': f1_score, + 'f1': partial(f1_score, average='micro'), 'f1-by-sample': partial(f1_score, average='samples'), 'accuracy': accuracy_score, 'hamming': hamming_loss, diff --git a/doc/datasets/twenty_newsgroups.rst b/doc/datasets/twenty_newsgroups.rst index 3e40a0aaf01f5e0ac08438da39e8ed79cb2473a9..003366efa4606e532587e7bb129fb21973cb7e1d 100644 --- a/doc/datasets/twenty_newsgroups.rst +++ b/doc/datasets/twenty_newsgroups.rst @@ -131,7 +131,7 @@ which is fast to train and achieves a decent F-score:: >>> clf = MultinomialNB(alpha=.01) >>> clf.fit(vectors, newsgroups_train.target) >>> pred = clf.predict(vectors_test) - >>> metrics.f1_score(newsgroups_test.target, pred) + >>> metrics.f1_score(newsgroups_test.target, pred, average='weighted') 0.88251152461278892 (The example :ref:`example_text_document_classification_20newsgroups.py` shuffles @@ -181,7 +181,7 @@ blocks, and quotation blocks respectively. ... categories=categories) >>> vectors_test = vectorizer.transform(newsgroups_test.data) >>> pred = clf.predict(vectors_test) - >>> metrics.f1_score(pred, newsgroups_test.target) + >>> metrics.f1_score(pred, newsgroups_test.target, average='weighted') 0.78409163025839435 This classifier lost over a lot of its F-score, just because we removed @@ -196,7 +196,7 @@ It loses even more if we also strip this metadata from the training data: >>> clf.fit(vectors, newsgroups_train.target) >>> vectors_test = vectorizer.transform(newsgroups_test.data) >>> pred = clf.predict(vectors_test) - >>> metrics.f1_score(newsgroups_test.target, pred) + >>> metrics.f1_score(newsgroups_test.target, pred, average='weighted') 0.73160869205141166 Some other classifiers cope better with this harder version of the task. Try diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 494f50e1994a05a1cde5150a2a0e3812193f2abc..53a8dac9c82293feee202dc1f6b1fe859ec42934 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -732,6 +732,7 @@ details. :template: function.rst metrics.make_scorer + metrics.get_scorer Classification metrics ---------------------- diff --git a/doc/modules/cross_validation.rst b/doc/modules/cross_validation.rst index 59a8486f0e3e9f5cf4f1101ac748a30851559ff8..9a3ea274837d47a6b39eef5c1a42843e766fdb03 100644 --- a/doc/modules/cross_validation.rst +++ b/doc/modules/cross_validation.rst @@ -120,7 +120,7 @@ scoring parameter:: >>> from sklearn import metrics >>> scores = cross_validation.cross_val_score(clf, iris.data, iris.target, - ... cv=5, scoring='f1') + ... cv=5, scoring='f1_weighted') >>> scores # doctest: +ELLIPSIS array([ 0.96..., 1. ..., 0.96..., 0.96..., 1. 
]) diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index 4cf37821637e4186a6707512a4083987b07cabf6..5b8a0ed23bff562b50f1c8668984fe769bfd01d3 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -60,10 +60,14 @@ Scoring Function **Classification** 'accuracy' :func:`metrics.accuracy_score` 'average_precision' :func:`metrics.average_precision_score` -'f1' :func:`metrics.f1_score` +'f1' :func:`metrics.f1_score` for binary targets +'f1_micro' :func:`metrics.f1_score` micro-averaged +'f1_macro' :func:`metrics.f1_score` macro-averaged +'f1_weighted' :func:`metrics.f1_score` weighted average +'f1_samples' :func:`metrics.f1_score` by multilabel sample 'log_loss' :func:`metrics.log_loss` requires ``predict_proba`` support -'precision' :func:`metrics.precision_score` -'recall' :func:`metrics.recall_score` +'precision' etc. :func:`metrics.precision_score` suffixes apply as with 'f1' +'recall' etc. :func:`metrics.recall_score` suffixes apply as with 'f1' 'roc_auc' :func:`metrics.roc_auc_score` **Clustering** @@ -84,7 +88,7 @@ Usage examples: >>> model = svm.SVC() >>> cross_validation.cross_val_score(model, X, y, scoring='wrong_choice') Traceback (most recent call last): - ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_rand_score', 'average_precision', 'f1', 'log_loss', 'mean_absolute_error', 'mean_squared_error', 'median_absolute_error', 'precision', 'r2', 'recall', 'roc_auc'] + ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_rand_score', 'average_precision', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted', 'log_loss', 'mean_absolute_error', 'mean_squared_error', 'median_absolute_error', 'precision', 'precision_macro', 'precision_micro', 'precision_samples', 'precision_weighted', 'r2', 'recall', 'recall_macro', 'recall_micro', 'recall_samples', 'recall_weighted', 'roc_auc'] >>> clf = svm.SVC(probability=True, random_state=0) >>> cross_validation.cross_val_score(clf, X, y, scoring='log_loss') # doctest: +ELLIPSIS array([-0.07..., -0.16..., -0.06...]) @@ -208,6 +212,8 @@ The :mod:`sklearn.metrics` module implements several loss, score, and utility functions to measure classification performance. Some metrics might require probability estimates of the positive class, confidence values, or binary decisions values. +Most implementations allow each sample to provide a weighted contribution +to the overall score, through the ``sample_weight`` parameter. Some of these are restricted to the binary classification case: @@ -254,7 +260,53 @@ And some work with binary and multilabel (but not multiclass) problems: roc_auc_score -In the following sub-sections, we will describe each of those functions. +In the following sub-sections, we will describe each of those functions, +preceded by some notes on common API and metric definition. + +From binary to multiclass and multilabel +........................................ + +Some metrics are essentially defined for binary classification tasks (e.g. +:func:`f1_score`, :func:`roc_auc_score`). In these cases, by default +only the positive label is evaluated, assuming that the positive +class is labelled ``1`` (though this may be configurable through the +``pos_label`` parameter). + +.. _average: + +In extending a binary metric to multiclass or multilabel problems, the data +is treated as a collection of binary problems, one for each class.
+There are then a number of ways to average binary metric calculations across +the set of classes, each of which may be useful in some scenario. +Where available, you should select among these using the ``average`` parameter. + +* ``"macro"`` simply calculates the mean of the binary metrics, + giving equal weight to each class. In problems where infrequent classes + are nonetheless important, macro-averaging may be a means of highlighting + their performance. On the other hand, the assumption that all classes are + equally important is often untrue, such that macro-averaging will + over-emphasise the typically low performance on an infrequent class. +* ``"weighted"`` accounts for class imbalance by computing the average of + binary metrics in which each class's score is weighted by its presence in the + true data sample. +* ``"micro"`` gives each sample-class pair an equal contribution to the overall + metric (except as a result of sample-weight). Rather than summing the + metric per class, this sums the dividends and divisors that make up the + per-class metrics to calculate an overall quotient. + Micro-averaging may be preferred in multilabel settings, including + multiclass classification where a majority class is to be ignored. +* ``"samples"`` applies only to multilabel problems. It does not calculate a + per-class measure, instead calculating the metric over the true and predicted + classes for each sample in the evaluation data, and returning their + (``sample_weight``-weighted) average. +* Selecting ``average=None`` will return an array with the score for each + class. + +While multiclass data is provided to the metric, like binary targets, as an +array of class labels, multilabel data is specified as an indicator matrix, +in which cell ``[i, j]`` has value 1 if sample ``i`` has label ``j`` and value +0 otherwise. + Accuracy score -------------- @@ -595,21 +647,10 @@ There are a few ways to combine results across labels, specified by the ``average`` argument to the :func:`average_precision_score` (multilabel only), :func:`f1_score`, :func:`fbeta_score`, :func:`precision_recall_fscore_support`, -:func:`precision_score` and :func:`recall_score` functions: - -* ``"micro"``: calculate metrics globally by counting the total true - positives, false negatives and false positives. Except in the multi-label - case, this implies that precision, recall and :math:`F` are equal. -* ``"samples"``: calculate metrics for each sample, comparing sets of - labels assigned to each, and find the mean across all samples. - This is only meaningful and available in the multilabel case. -* ``"macro"``: calculate metrics for each label, and find their mean. - This does not take label imbalance into account. -* ``"weighted"``: calculate metrics for each label, and find their average - weighted by the number of occurrences of the label in the true data. - This alters ``"macro"`` to account for label imbalance; it may produce an - F-score that is not between precision and recall. -* ``None``: calculate metrics for each label and do not average them. +:func:`precision_score` and :func:`recall_score` functions, as described +:ref:`above <average>`. Note that "micro"-averaging in a multiclass setting +will produce equal precision, recall and :math:`F`, while "weighted" averaging +may produce an F-score that is not between precision and recall.
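+
+For instance, the different ``average`` settings can be compared on a small
+multiclass problem (an illustrative sketch; outputs are truncated)::
+
+  >>> from sklearn.metrics import f1_score
+  >>> y_true = [0, 1, 2, 0, 1, 2]
+  >>> y_pred = [0, 2, 1, 0, 0, 1]
+  >>> f1_score(y_true, y_pred, average='macro')  # doctest: +ELLIPSIS
+  0.26...
+  >>> f1_score(y_true, y_pred, average='micro')  # doctest: +ELLIPSIS
+  0.33...
+  >>> f1_score(y_true, y_pred, average='weighted')  # doctest: +ELLIPSIS
+  0.26...
+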
To make this more explicit, consider the following notation: @@ -869,20 +910,7 @@ For more information see the `Wikipedia article on AUC 0.75 In multi-label classification, the :func:`roc_auc_score` function is -extended by averaging over the labels: - -* ``"micro"``: computes AUROC globally; obtained - by considering each element of the label indicator matrix as a label. -* ``"samples"``: computes AUROC for each sample, - comparing the sets of labels and scores assigned to each, and finds the mean - across all samples. -* ``"macro"``: computes AUROC for each label, and finds - their mean. -* ``"weighted"``: computes AUROC for each label and - finds their average, weighted by the number of occurrences of the label in the - true data. -* ``None``: this returns an array of scores with scores with shape (n_classes,) - instead of an aggregate scalar score. +extended by averaging over the labels as :ref:`above <average>`. Compared to metrics such as the subset accuracy, the Hamming loss, or the F1 score, ROC doesn't require optimizing a threshold for each label. The diff --git a/doc/whats_new.rst b/doc/whats_new.rst index ccc04c6ed20bae89bd20e2a4d991887287dde696..b8f1f02e3655a636d9d40570d4b2d026c937a98b 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -234,6 +234,17 @@ API changes summary :func:`linear_model.enet_path` which constrains coefficients to be positive. By `Manoj Kumar`_. + - Users should now supply an explicit ``average`` parameter to + :func:`sklearn.metrics.f1_score`, :func:`sklearn.metrics.fbeta_score`, + :func:`sklearn.metrics.recall_score` and + :func:`sklearn.metrics.precision_score` when performing multiclass + or multilabel (i.e. not binary) classification. By `Joel Nothman`_. + + - `scoring` parameter for cross validation now accepts `'f1_micro'`, + `'f1_macro'` or `'f1_weighted'`. `'f1'` is now for binary classification + only. Similar changes apply to `'precision'` and `'recall'`. + By `Joel Nothman`_. + .. _changes_0_15_2: 0.15.2 @@ -274,7 +285,7 @@ Bug fixes running the tests. By `Joel Nothman`_. - Many documentation and website fixes by `Joel Nothman`_, `Lars Buitinck`_ - and others. + `Matt Pico`_, and others. .. 
_changes_0_15_1: diff --git a/examples/model_selection/grid_search_digits.py b/examples/model_selection/grid_search_digits.py index 20f25d3751cc009d3c044e75faf61fb4058b7aca..7492d0a726ef5a07eac9f9ac27daa42c7509a57a 100644 --- a/examples/model_selection/grid_search_digits.py +++ b/examples/model_selection/grid_search_digits.py @@ -50,7 +50,8 @@ for score in scores: print("# Tuning hyper-parameters for %s" % score) print() - clf = GridSearchCV(SVC(C=1), tuned_parameters, cv=5, scoring=score) + clf = GridSearchCV(SVC(C=1), tuned_parameters, cv=5, + scoring='%s_weighted' % score) clf.fit(X_train, y_train) print("Best parameters set found on development set:") diff --git a/examples/text/document_classification_20newsgroups.py b/examples/text/document_classification_20newsgroups.py index a7197ac1f6cb0632572097a478601405426aa57d..87e29accb763f783ac3162c4f36c451ccf9ed0a1 100644 --- a/examples/text/document_classification_20newsgroups.py +++ b/examples/text/document_classification_20newsgroups.py @@ -140,7 +140,7 @@ print() # split a training set and a test set y_train, y_test = data_train.target, data_test.target -print("Extracting features from the training dataset using a sparse vectorizer") +print("Extracting features from the training data using a sparse vectorizer") t0 = time() if opts.use_hashing: vectorizer = HashingVectorizer(stop_words='english', non_negative=True, @@ -155,7 +155,7 @@ print("done in %fs at %0.3fMB/s" % (duration, data_train_size_mb / duration)) print("n_samples: %d, n_features: %d" % X_train.shape) print() -print("Extracting features from the test dataset using the same vectorizer") +print("Extracting features from the test data using the same vectorizer") t0 = time() X_test = vectorizer.transform(data_test.data) duration = time() - t0 @@ -208,8 +208,8 @@ def benchmark(clf): test_time = time() - t0 print("test time: %0.3fs" % test_time) - score = metrics.f1_score(y_test, pred) - print("f1-score: %0.3f" % score) + score = metrics.accuracy_score(y_test, pred) + print("accuracy: %0.3f" % score) if hasattr(clf, 'coef_'): print("dimensionality: %d" % clf.coef_.shape[1]) diff --git a/sklearn/feature_selection/tests/test_rfe.py b/sklearn/feature_selection/tests/test_rfe.py index 0c20377506c461b9c3430e2ed0a4417f79a7b3b2..c1207af0f4d709ccdfbb4f0a3d5e2eedd35795aa 100644 --- a/sklearn/feature_selection/tests/test_rfe.py +++ b/sklearn/feature_selection/tests/test_rfe.py @@ -14,8 +14,8 @@ from sklearn.svm import SVC, SVR from sklearn.utils import check_random_state from sklearn.utils.testing import ignore_warnings -from sklearn.metrics.scorer import SCORERS from sklearn.metrics import make_scorer +from sklearn.metrics import get_scorer def test_rfe_set_params(): @@ -97,7 +97,7 @@ def test_rfecv(): assert_array_equal(X_r, iris.data) # Test using a scorer - scorer = SCORERS['accuracy'] + scorer = get_scorer('accuracy') rfecv = RFECV(estimator=SVC(kernel="linear"), step=1, cv=5, scoring=scorer) rfecv.fit(X, y) diff --git a/sklearn/linear_model/tests/test_ridge.py b/sklearn/linear_model/tests/test_ridge.py index 4249fe0a420da129459613b0baef9c8241f706a8..d13dc3ebf461b5869ed6e63faaffa7ac43aa7332 100644 --- a/sklearn/linear_model/tests/test_ridge.py +++ b/sklearn/linear_model/tests/test_ridge.py @@ -15,8 +15,8 @@ from sklearn.utils.testing import ignore_warnings from sklearn import datasets from sklearn.metrics import mean_squared_error -from sklearn.metrics.scorer import SCORERS from sklearn.metrics import make_scorer +from sklearn.metrics import get_scorer from
sklearn.linear_model.base import LinearRegression from sklearn.linear_model.ridge import ridge_regression @@ -336,7 +336,7 @@ def _test_ridge_loo(filter_): assert_equal(ridge_gcv3.alpha_, alpha_) # check that we get same best alpha with a scorer - scorer = SCORERS['mean_squared_error'] + scorer = get_scorer('mean_squared_error') ridge_gcv4 = RidgeCV(fit_intercept=False, scoring=scorer) ridge_gcv4.fit(filter_(X_diabetes), y_diabetes) assert_equal(ridge_gcv4.alpha_, alpha_) diff --git a/sklearn/linear_model/tests/test_sgd.py b/sklearn/linear_model/tests/test_sgd.py index 6085ecb31bbb4ffdb4b531670a19787507aec851..45f459089bcc0c15bec9332f72e1482568407dbb 100644 --- a/sklearn/linear_model/tests/test_sgd.py +++ b/sklearn/linear_model/tests/test_sgd.py @@ -610,14 +610,16 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): y = y[idx] clf = self.factory(alpha=0.0001, n_iter=1000, class_weight=None).fit(X, y) - assert_almost_equal(metrics.f1_score(y, clf.predict(X)), 0.96, - decimal=1) + assert_almost_equal(metrics.f1_score(y, clf.predict(X), + average='weighted'), + 0.96, decimal=1) # make the same prediction using automated class_weight clf_auto = self.factory(alpha=0.0001, n_iter=1000, class_weight="auto").fit(X, y) - assert_almost_equal(metrics.f1_score(y, clf_auto.predict(X)), 0.96, - decimal=1) + assert_almost_equal(metrics.f1_score(y, clf_auto.predict(X), + average='weighted'), + 0.96, decimal=1) # Make sure that in the balanced case it does not change anything # to use "auto" @@ -634,19 +636,19 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): clf = self.factory(n_iter=1000, class_weight=None) clf.fit(X_imbalanced, y_imbalanced) y_pred = clf.predict(X) - assert_less(metrics.f1_score(y, y_pred), 0.96) + assert_less(metrics.f1_score(y, y_pred, average='weighted'), 0.96) # fit a model with auto class_weight enabled clf = self.factory(n_iter=1000, class_weight="auto") clf.fit(X_imbalanced, y_imbalanced) y_pred = clf.predict(X) - assert_greater(metrics.f1_score(y, y_pred), 0.96) + assert_greater(metrics.f1_score(y, y_pred, average='weighted'), 0.96) # fit another using a fit parameter override clf = self.factory(n_iter=1000, class_weight="auto") clf.fit(X_imbalanced, y_imbalanced) y_pred = clf.predict(X) - assert_greater(metrics.f1_score(y, y_pred), 0.96) + assert_greater(metrics.f1_score(y, y_pred, average='weighted'), 0.96) def test_sample_weights(self): """Test weights on individual samples""" diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index e419fec7a1c91956e370b715fe48687bb1ce7ac4..117120f46d6266b9fff6ed6335402349eadbe879 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -53,6 +53,7 @@ from .regression import r2_score from .scorer import make_scorer from .scorer import SCORERS +from .scorer import get_scorer # Deprecated in 0.16 from .ranking import auc_score @@ -73,6 +74,7 @@ __all__ = [ 'explained_variance_score', 'f1_score', 'fbeta_score', + 'get_scorer', 'hamming_loss', 'hinge_loss', 'homogeneity_completeness_v_measure', diff --git a/sklearn/metrics/classification.py b/sklearn/metrics/classification.py index 80ec8446462ad84ef03c9abfd79bb8b3556fadee..7f8ef9290d68bec9db6cea847e4559e5fb58ea09 100644 --- a/sklearn/metrics/classification.py +++ b/sklearn/metrics/classification.py @@ -475,7 +475,7 @@ def zero_one_loss(y_true, y_pred, normalize=True, sample_weight=None): return n_samples - score -def f1_score(y_true, y_pred, labels=None, pos_label=1, average='weighted', +def f1_score(y_true, y_pred, 
labels=None, pos_label=1, average='binary', sample_weight=None): """Compute the F1 score, also known as balanced F-score or F-measure @@ -504,7 +504,8 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='weighted', If ``average`` is not ``None`` and the classification target is binary, only this class's scores will be returned. - average : string, [None, 'micro', 'macro', 'samples', 'weighted' (default)] + average : one of [None, 'micro', 'macro', 'samples', 'weighted'] + This parameter is required for multiclass/multilabel targets. If ``None``, the scores for each class are returned. Otherwise, unless ``pos_label`` is given in binary classification, this determines the type of averaging performed on the data: @@ -561,7 +562,7 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='weighted', def fbeta_score(y_true, y_pred, beta, labels=None, pos_label=1, - average='weighted', sample_weight=None): + average='binary', sample_weight=None): """Compute the F-beta score The F-beta score is the weighted harmonic mean of precision and recall, @@ -590,7 +591,8 @@ def fbeta_score(y_true, y_pred, beta, labels=None, pos_label=1, If ``average`` is not ``None`` and the classification target is binary, only this class's scores will be returned. - average : string, [None, 'micro', 'macro', 'samples', 'weighted' (default)] + average : one of [None, 'micro', 'macro', 'samples', 'weighted'] + This parameter is required for multiclass/multilabel targets. If ``None``, the scores for each class are returned. Otherwise, unless ``pos_label`` is given in binary classification, this determines the type of averaging performed on the data: @@ -822,7 +824,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None, """ average_options = (None, 'micro', 'macro', 'weighted', 'samples') - if average not in average_options: + if average not in average_options and average != 'binary': raise ValueError('average has to be one of ' + str(average_options)) if beta <= 0: @@ -830,6 +832,17 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None, y_type, y_true, y_pred = _check_targets(y_true, y_pred) + if average == 'binary' and y_type != 'binary': + warnings.warn('The default `weighted` averaging is deprecated, ' + 'and from version 0.18, use of precision, recall or ' + 'F-score with multiclass or multilabel data will result ' + 'in an exception. ' + 'Please set an explicit value for `average`, one of ' + '%s. In cross validation use, for instance, ' + 'scoring="f1_weighted" instead of scoring="f1".' + % str(average_options), DeprecationWarning, stacklevel=2) + average = 'weighted' + label_order = labels # save this for later if labels is None: labels = unique_labels(y_true, y_pred) @@ -852,7 +865,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None, elif average == 'samples': raise ValueError("Sample-based precision, recall, fscore is " - "not meaningful outside multilabel" + "not meaningful outside multilabel " "classification. 
See the accuracy_score instead.") else: lb = LabelEncoder() @@ -885,11 +898,12 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None, ### Select labels to keep ### if y_type == 'binary' and average is not None and pos_label is not None: - if label_order is not None and len(label_order) == 2: + if average != 'binary' and label_order is not None \ + and len(label_order) == 2: warnings.warn('In the future, providing two `labels` values, as ' - 'well as `average` will average over those ' - 'labels. For now, please use `labels=None` with ' - '`pos_label` to evaluate precision, recall and ' + 'well as `average!=`binary`` will average over ' + 'those labels. For now, please use `labels=None` ' + 'with `pos_label` to evaluate precision, recall and ' 'F-score for the positive label only.', FutureWarning) if pos_label not in labels: @@ -954,7 +968,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None, def precision_score(y_true, y_pred, labels=None, pos_label=1, - average='weighted', sample_weight=None): + average='binary', sample_weight=None): """Compute the precision The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of @@ -979,7 +993,8 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1, If ``average`` is not ``None`` and the classification target is binary, only this class's scores will be returned. - average : string, [None, 'micro', 'macro', 'samples', 'weighted' (default)] + average : one of [None, 'micro', 'macro', 'samples', 'weighted'] + This parameter is required for multiclass/multilabel targets. If ``None``, the scores for each class are returned. Otherwise, unless ``pos_label`` is given in binary classification, this determines the type of averaging performed on the data: @@ -1036,7 +1051,7 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1, return p -def recall_score(y_true, y_pred, labels=None, pos_label=1, average='weighted', +def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary', sample_weight=None): """Compute the recall @@ -1061,7 +1076,8 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='weighted', If ``average`` is not ``None`` and the classification target is binary, only this class's scores will be returned. - average : string, [None, 'micro', 'macro', 'samples', 'weighted' (default)] + average : one of [None, 'micro', 'macro', 'samples', 'weighted'] + This parameter is required for multiclass/multilabel targets. If ``None``, the scores for each class are returned. Otherwise, unless ``pos_label`` is given in binary classification, this determines the type of averaging performed on the data: diff --git a/sklearn/metrics/scorer.py b/sklearn/metrics/scorer.py index 6f99deb815d0454dbbb29f9f71bfd4c931de91e9..ce7eaaeca4410e5f171ecc2cd5b077ead4419801 100644 --- a/sklearn/metrics/scorer.py +++ b/sklearn/metrics/scorer.py @@ -19,6 +19,7 @@ ground truth labeling (or ``None`` in the case of unsupervised models). 
# License: Simplified BSD from abc import ABCMeta, abstractmethod +from functools import partial import numpy as np @@ -342,8 +343,15 @@ SCORERS = dict(r2=r2_scorer, median_absolute_error=median_absolute_error_scorer, mean_absolute_error=mean_absolute_error_scorer, mean_squared_error=mean_squared_error_scorer, - accuracy=accuracy_scorer, f1=f1_scorer, roc_auc=roc_auc_scorer, + accuracy=accuracy_scorer, roc_auc=roc_auc_scorer, average_precision=average_precision_scorer, - precision=precision_scorer, recall=recall_scorer, log_loss=log_loss_scorer, adjusted_rand_score=adjusted_rand_scorer) + +for name, metric in [('precision', precision_score), + ('recall', recall_score), ('f1', f1_score)]: + SCORERS[name] = make_scorer(metric) + for average in ['macro', 'micro', 'samples', 'weighted']: + qualified_name = '{0}_{1}'.format(name, average) + SCORERS[qualified_name] = make_scorer(partial(metric, pos_label=None, + average=average)) diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 578e9d13db659e3604111244b5578cfb69c5e80e..1a920bfba65c2d86b035237186ffb44049048a37 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -991,6 +991,24 @@ def test_fscore_warnings(): 'being set to 0.0 due to no true samples.') +def test_prf_average_compat(): + """Ensure warning if f1_score et al.'s average is implicit for multiclass + """ + y_true = [1, 2, 3, 3] + y_pred = [1, 2, 3, 1] + + for metric in [precision_score, recall_score, f1_score, + partial(fbeta_score, beta=2)]: + score = assert_warns(DeprecationWarning, metric, y_true, y_pred) + score_weighted = assert_no_warnings(metric, y_true, y_pred, + average='weighted') + assert_equal(score, score_weighted, + 'average does not act like "weighted" by default') + + # check binary passes without warning + assert_no_warnings(metric, [0, 1, 1], [0, 1, 0]) + + @ignore_warnings # sequence of sequences is deprecated def test__check_targets(): """Check that _check_targets correctly merges target types, squeezes diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 1973b6e3364c01bb38b8b7a9d60674515eabcdd2..e7c7c33d2484ae84bc7ed92890c4c77ae84e1fef 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -107,6 +107,7 @@ CLASSIFICATION_METRICS = { "zero_one_loss": zero_one_loss, "unnormalized_zero_one_loss": partial(zero_one_loss, normalize=False), + # These are needed to test averaging "precision_score": precision_score, "recall_score": recall_score, "f1_score": f1_score, @@ -336,6 +337,7 @@ METRICS_WITHOUT_SAMPLE_WEIGHT = [ ] +@ignore_warnings def test_symmetry(): """Test the symmetry of score and loss functions""" random_state = check_random_state(0) @@ -366,6 +368,7 @@ def test_symmetry(): msg="%s seems to be symmetric" % name) +@ignore_warnings def test_sample_order_invariance(): random_state = check_random_state(0) y_true = random_state.randint(0, 2, size=(20, )) @@ -382,6 +385,7 @@ def test_sample_order_invariance(): % name) +@ignore_warnings def test_sample_order_invariance_multilabel_and_multioutput(): random_state = check_random_state(0) @@ -421,6 +425,7 @@ def test_sample_order_invariance_multilabel_and_multioutput(): % name) +@ignore_warnings def test_format_invariance_with_1d_vectors(): random_state = check_random_state(0) y1 = random_state.randint(0, 2, size=(20, )) @@ -499,6 +504,7 @@ def test_format_invariance_with_1d_vectors(): assert_raises(ValueError, metric, 
y1_row, y2_row) +@ignore_warnings def test_invariance_string_vs_numbers_labels(): """Ensure that classification metrics with string labels""" random_state = check_random_state(0) @@ -627,6 +633,7 @@ def test_multioutput_regression_invariance_to_dimension_shuffling(): "invariant" % name) +@ignore_warnings def test_multilabel_representation_invariance(): # Generate some data n_classes = 4 diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index aef0a33dc4a59a8dfe0cdd57491dcd2ff8d17789..50942c77653a53cfedd5ac2a4e230ae1685eab46 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -10,10 +10,10 @@ from sklearn.utils.testing import ignore_warnings from sklearn.utils.testing import assert_not_equal from sklearn.metrics import (f1_score, r2_score, roc_auc_score, fbeta_score, - log_loss) + log_loss, precision_score, recall_score) from sklearn.metrics.cluster import adjusted_rand_score from sklearn.metrics.scorer import check_scoring -from sklearn.metrics import make_scorer, SCORERS +from sklearn.metrics import make_scorer, get_scorer, SCORERS from sklearn.svm import LinearSVC from sklearn.cluster import KMeans from sklearn.dummy import DummyRegressor @@ -30,11 +30,17 @@ from sklearn.multiclass import OneVsRestClassifier REGRESSION_SCORERS = ['r2', 'mean_absolute_error', 'mean_squared_error', 'median_absolute_error'] -CLF_SCORERS = ['accuracy', 'f1', 'roc_auc', 'average_precision', 'precision', - 'recall', 'log_loss', + +CLF_SCORERS = ['accuracy', 'f1', 'f1_weighted', 'f1_macro', 'f1_micro', + 'roc_auc', 'average_precision', 'precision', + 'precision_weighted', 'precision_macro', 'precision_micro', + 'recall', 'recall_weighted', 'recall_macro', 'recall_micro', + 'log_loss', 'adjusted_rand_score' # not really, but works ] +MULTILABEL_ONLY_SCORERS = ['precision_samples', 'recall_samples', 'f1_samples'] + class EstimatorWithoutFit(object): """Dummy estimator to test check_scoring""" @@ -107,18 +113,38 @@ def test_make_scorer(): def test_classification_scores(): """Test classification scorers.""" - X, y = make_blobs(random_state=0) + X, y = make_blobs(random_state=0, centers=2) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) clf = LinearSVC(random_state=0) clf.fit(X_train, y_train) - score1 = SCORERS['f1'](clf, X_test, y_test) - score2 = f1_score(y_test, clf.predict(X_test)) - assert_almost_equal(score1, score2) + + for prefix, metric in [('f1', f1_score), ('precision', precision_score), + ('recall', recall_score)]: + + score1 = get_scorer('%s_weighted' % prefix)(clf, X_test, y_test) + score2 = metric(y_test, clf.predict(X_test), pos_label=None, + average='weighted') + assert_almost_equal(score1, score2) + + score1 = get_scorer('%s_macro' % prefix)(clf, X_test, y_test) + score2 = metric(y_test, clf.predict(X_test), pos_label=None, + average='macro') + assert_almost_equal(score1, score2) + + score1 = get_scorer('%s_micro' % prefix)(clf, X_test, y_test) + score2 = metric(y_test, clf.predict(X_test), pos_label=None, + average='micro') + assert_almost_equal(score1, score2) + + score1 = get_scorer('%s' % prefix)(clf, X_test, y_test) + score2 = metric(y_test, clf.predict(X_test), pos_label=1) + assert_almost_equal(score1, score2) # test fbeta score that takes an argument scorer = make_scorer(fbeta_score, beta=2) score1 = scorer(clf, X_test, y_test) - score2 = fbeta_score(y_test, clf.predict(X_test), beta=2) + score2 = fbeta_score(y_test, clf.predict(X_test), beta=2, + 
average='weighted') assert_almost_equal(score1, score2) # test that custom scorer can be pickled @@ -137,7 +163,7 @@ def test_regression_scorers(): X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) clf = Ridge() clf.fit(X_train, y_train) - score1 = SCORERS['r2'](clf, X_test, y_test) + score1 = get_scorer('r2')(clf, X_test, y_test) score2 = r2_score(y_test, clf.predict(X_test)) assert_almost_equal(score1, score2) @@ -148,20 +174,20 @@ def test_thresholded_scorers(): X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) clf = LogisticRegression(random_state=0) clf.fit(X_train, y_train) - score1 = SCORERS['roc_auc'](clf, X_test, y_test) + score1 = get_scorer('roc_auc')(clf, X_test, y_test) score2 = roc_auc_score(y_test, clf.decision_function(X_test)) score3 = roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1]) assert_almost_equal(score1, score2) assert_almost_equal(score1, score3) - logscore = SCORERS['log_loss'](clf, X_test, y_test) + logscore = get_scorer('log_loss')(clf, X_test, y_test) logloss = log_loss(y_test, clf.predict_proba(X_test)) assert_almost_equal(-logscore, logloss) # same for an estimator without decision_function clf = DecisionTreeClassifier() clf.fit(X_train, y_train) - score1 = SCORERS['roc_auc'](clf, X_test, y_test) + score1 = get_scorer('roc_auc')(clf, X_test, y_test) score2 = roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1]) assert_almost_equal(score1, score2) @@ -169,7 +195,7 @@ def test_thresholded_scorers(): X, y = make_blobs(random_state=0, centers=3) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) clf.fit(X_train, y_train) - assert_raises(ValueError, SCORERS['roc_auc'], clf, X_test, y_test) + assert_raises(ValueError, get_scorer('roc_auc'), clf, X_test, y_test) def test_thresholded_scorers_multilabel_indicator_data(): @@ -185,7 +211,7 @@ def test_thresholded_scorers_multilabel_indicator_data(): clf = DecisionTreeClassifier() clf.fit(X_train, y_train) y_proba = clf.predict_proba(X_test) - score1 = SCORERS['roc_auc'](clf, X_test, y_test) + score1 = get_scorer('roc_auc')(clf, X_test, y_test) score2 = roc_auc_score(y_test, np.vstack(p[:, -1] for p in y_proba).T) assert_almost_equal(score1, score2) @@ -198,21 +224,21 @@ def test_thresholded_scorers_multilabel_indicator_data(): clf.decision_function = lambda X: [p[:, 1] for p in clf._predict_proba(X)] y_proba = clf.decision_function(X_test) - score1 = SCORERS['roc_auc'](clf, X_test, y_test) + score1 = get_scorer('roc_auc')(clf, X_test, y_test) score2 = roc_auc_score(y_test, np.vstack(p for p in y_proba).T) assert_almost_equal(score1, score2) # Multilabel predict_proba clf = OneVsRestClassifier(DecisionTreeClassifier()) clf.fit(X_train, y_train) - score1 = SCORERS['roc_auc'](clf, X_test, y_test) + score1 = get_scorer('roc_auc')(clf, X_test, y_test) score2 = roc_auc_score(y_test, clf.predict_proba(X_test)) assert_almost_equal(score1, score2) # Multilabel decision function clf = OneVsRestClassifier(LinearSVC(random_state=0)) clf.fit(X_train, y_train) - score1 = SCORERS['roc_auc'](clf, X_test, y_test) + score1 = get_scorer('roc_auc')(clf, X_test, y_test) score2 = roc_auc_score(y_test, clf.decision_function(X_test)) assert_almost_equal(score1, score2) @@ -224,7 +250,7 @@ def test_unsupervised_scorers(): X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) km = KMeans(n_clusters=3) km.fit(X_train) - score1 = SCORERS['adjusted_rand_score'](km, X_test, y_test) + score1 = get_scorer('adjusted_rand_score')(km, X_test, y_test) score2 
= adjusted_rand_score(y_test, km.predict(X_test)) assert_almost_equal(score1, score2) @@ -242,6 +268,7 @@ def test_raises_on_score_list(): assert_raises(ValueError, grid_search.fit, X, y) +@ignore_warnings def test_scorer_sample_weight(): """Test that scorers support sample_weight or raise sensible errors""" @@ -249,7 +276,12 @@ def test_scorer_sample_weight(): # to ensure that, on the classifier output, weighted and unweighted # scores really should be unequal. X, y = make_classification(random_state=0) - X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + _, y_ml = make_multilabel_classification(n_samples=X.shape[0], + return_indicator=True, + random_state=0) + split = train_test_split(X, y, y_ml, random_state=0) + X_train, X_test, y_train, y_test, y_ml_train, y_ml_test = split + sample_weight = np.ones_like(y_test) sample_weight[:10] = 0 @@ -258,17 +290,25 @@ def test_scorer_sample_weight(): sensible_regr.fit(X_train, y_train) sensible_clf = DecisionTreeClassifier() sensible_clf.fit(X_train, y_train) + sensible_ml_clf = DecisionTreeClassifier() + sensible_ml_clf.fit(X_train, y_ml_train) estimator = dict([(name, sensible_regr) for name in REGRESSION_SCORERS] + [(name, sensible_clf) - for name in CLF_SCORERS]) + for name in CLF_SCORERS] + + [(name, sensible_ml_clf) + for name in MULTILABEL_ONLY_SCORERS]) for name, scorer in SCORERS.items(): + if name in MULTILABEL_ONLY_SCORERS: + target = y_ml_test + else: + target = y_test try: - weighted = scorer(estimator[name], X_test, y_test, + weighted = scorer(estimator[name], X_test, target, sample_weight=sample_weight) - ignored = scorer(estimator[name], X_test[10:], y_test[10:]) - unweighted = scorer(estimator[name], X_test, y_test) + ignored = scorer(estimator[name], X_test[10:], target[10:]) + unweighted = scorer(estimator[name], X_test, target) assert_not_equal(weighted, unweighted, msg="scorer {0} behaves identically when " "called with sample weights: {1} vs " diff --git a/sklearn/svm/tests/test_svm.py b/sklearn/svm/tests/test_svm.py index 79e0ec207dd01ab76ce1859826d58e65f95e1a02..844f35fb4766381a3675d40efdaafc6e3a48aebc 100644 --- a/sklearn/svm/tests/test_svm.py +++ b/sklearn/svm/tests/test_svm.py @@ -392,8 +392,9 @@ def test_auto_weight(): y_pred = clf.fit(X[unbalanced], y[unbalanced]).predict(X) clf.set_params(class_weight='auto') y_pred_balanced = clf.fit(X[unbalanced], y[unbalanced],).predict(X) - assert_true(metrics.f1_score(y, y_pred) - <= metrics.f1_score(y, y_pred_balanced)) + assert_true(metrics.f1_score(y, y_pred, average='weighted') + <= metrics.f1_score(y, y_pred_balanced, + average='weighted')) def test_bad_input(): diff --git a/sklearn/tests/test_cross_validation.py b/sklearn/tests/test_cross_validation.py index b60e5fcd8e008603b6d0f92e210152986c558bb3..116ee6b29e86ed79aa2f59912315b797a08cb5b9 100644 --- a/sklearn/tests/test_cross_validation.py +++ b/sklearn/tests/test_cross_validation.py @@ -683,7 +683,7 @@ def test_cross_val_score_with_score_func_classification(): # F1 score (class are balanced so f1_score should be equal to zero/one # score f1_scores = cval.cross_val_score(clf, iris.data, iris.target, - scoring="f1", cv=5) + scoring="f1_weighted", cv=5) assert_array_almost_equal(f1_scores, [0.97, 1., 0.97, 0.97, 1.], 2) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 19618b96a9fea8d8675970edb87e3d489f844b71..af08a22233ea02ce2c9b2493be4897b5fcf2f370 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -742,8 +742,8 @@ def 
check_class_weight_auto_classifiers(name, Classifier, X_train, y_train, classifier.set_params(class_weight='auto') classifier.fit(X_train, y_train) y_pred_auto = classifier.predict(X_test) - assert_greater(f1_score(y_test, y_pred_auto), - f1_score(y_test, y_pred)) + assert_greater(f1_score(y_test, y_pred_auto, average='weighted'), + f1_score(y_test, y_pred, average='weighted')) def check_class_weight_auto_linear_classifier(name, Classifier):
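
As a usage sketch of the behaviour introduced by this changeset (illustrative only: the estimator, data and variable names below are arbitrary and not part of the patch), the new averaged scoring strings and ``get_scorer`` can be exercised as follows::

    from sklearn import datasets, svm
    from sklearn.cross_validation import cross_val_score
    from sklearn.metrics import get_scorer

    iris = datasets.load_iris()
    clf = svm.SVC(C=1)

    # Multiclass targets now require an explicit averaging choice: 'f1_macro',
    # 'f1_micro' and 'f1_weighted' (and the analogous 'precision_*' and
    # 'recall_*' strings) select it when cross validating.
    print(cross_val_score(clf, iris.data, iris.target, cv=5, scoring='f1_macro'))

    # get_scorer exposes the same registry that was previously reached only
    # through the SCORERS dict.
    weighted_recall = get_scorer('recall_weighted')
    # weighted_recall(fitted_estimator, X_test, y_test) would return the
    # weighted-average recall of a fitted estimator on held-out data.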