From 2f7f5a1a50c2a2022d42160fce9d0596ecac2ada Mon Sep 17 00:00:00 2001
From: "(Venkat) Raghav (Rajagopalan)" <rvraghav93@gmail.com>
Date: Fri, 6 Jan 2017 12:03:10 +0100
Subject: [PATCH] [MRG + 1] Add fowlkes-mallows and other supervised cluster
 metrics to SCORERS dict so they can be used in hyper-param search (#8117)

* Add supervised cluster metrics to metrics.scorers

* Add all the supervised cluster metrics to the tests

* Add test for fowlkes_mallows_score in unsupervised grid search

* COSMIT: Clarify comment on CLUSTER_SCORERS

* Fix doctest
---
 doc/modules/model_evaluation.rst             |  2 +-
 sklearn/metrics/scorer.py                    | 28 ++++++++++++++++++-
 sklearn/metrics/tests/test_score_objects.py  | 29 +++++++++++++-------
 sklearn/model_selection/tests/test_search.py |  6 ++++
 4 files changed, 53 insertions(+), 12 deletions(-)

diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst
index 5b13f82428..db7b59d6c1 100644
--- a/doc/modules/model_evaluation.rst
+++ b/doc/modules/model_evaluation.rst
@@ -94,7 +94,7 @@ Usage examples:
     >>> model = svm.SVC()
     >>> cross_val_score(model, X, y, scoring='wrong_choice')
     Traceback (most recent call last):
-    ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_rand_score', 'average_precision', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted', 'neg_log_loss', 'neg_mean_absolute_error', 'neg_mean_squared_error', 'neg_mean_squared_log_error', 'neg_median_absolute_error', 'precision', 'precision_macro', 'precision_micro', 'precision_samples', 'precision_weighted', 'r2', 'recall', 'recall_macro', 'recall_micro', 'recall_samples', 'recall_weighted', 'roc_auc']
+    ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_mutual_info_score', 'adjusted_rand_score', 'average_precision', 'completeness_score', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted', 'fowlkes_mallows_score', 'homogeneity_score', 'mutual_info_score', 'neg_log_loss', 'neg_mean_absolute_error', 'neg_mean_squared_error', 'neg_mean_squared_log_error', 'neg_median_absolute_error', 'normalized_mutual_info_score', 'precision', 'precision_macro', 'precision_micro', 'precision_samples', 'precision_weighted', 'r2', 'recall', 'recall_macro', 'recall_micro', 'recall_samples', 'recall_weighted', 'roc_auc', 'v_measure_score']
 
 .. note::
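The doctest above simply mirrors the sorted keys of the SCORERS dict that the
scorer.py change below extends. As a minimal sketch, assuming this patch is
applied (and that SCORERS remains importable from sklearn.metrics, as the
tests below assume), the same list can be reproduced directly:

    # Reproduce the "Valid options" list from the ValueError above.
    from sklearn.metrics import SCORERS
    print(sorted(SCORERS.keys()))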
diff --git a/sklearn/metrics/scorer.py b/sklearn/metrics/scorer.py
index 4aeea1710d..3a163d967c 100644
--- a/sklearn/metrics/scorer.py
+++ b/sklearn/metrics/scorer.py
@@ -27,7 +27,16 @@ from . import (r2_score, median_absolute_error, mean_absolute_error,
                mean_squared_error, mean_squared_log_error, accuracy_score,
                f1_score, roc_auc_score, average_precision_score,
                precision_score, recall_score, log_loss)
+
 from .cluster import adjusted_rand_score
+from .cluster import homogeneity_score
+from .cluster import completeness_score
+from .cluster import v_measure_score
+from .cluster import mutual_info_score
+from .cluster import adjusted_mutual_info_score
+from .cluster import normalized_mutual_info_score
+from .cluster import fowlkes_mallows_score
+
 from ..utils.multiclass import type_of_target
 from ..externals import six
 from ..base import is_regressor
@@ -393,6 +402,14 @@ log_loss_scorer._deprecation_msg = deprecation_msg
 
 # Clustering scores
 adjusted_rand_scorer = make_scorer(adjusted_rand_score)
+homogeneity_scorer = make_scorer(homogeneity_score)
+completeness_scorer = make_scorer(completeness_score)
+v_measure_scorer = make_scorer(v_measure_score)
+mutual_info_scorer = make_scorer(mutual_info_score)
+adjusted_mutual_info_scorer = make_scorer(adjusted_mutual_info_score)
+normalized_mutual_info_scorer = make_scorer(normalized_mutual_info_score)
+fowlkes_mallows_scorer = make_scorer(fowlkes_mallows_score)
+
 
 SCORERS = dict(r2=r2_scorer,
                neg_median_absolute_error=neg_median_absolute_error_scorer,
@@ -406,7 +423,16 @@ SCORERS = dict(r2=r2_scorer,
                average_precision=average_precision_scorer,
                log_loss=log_loss_scorer,
                neg_log_loss=neg_log_loss_scorer,
-               adjusted_rand_score=adjusted_rand_scorer)
+               # Cluster metrics that use supervised evaluation
+               adjusted_rand_score=adjusted_rand_scorer,
+               homogeneity_score=homogeneity_scorer,
+               completeness_score=completeness_scorer,
+               v_measure_score=v_measure_scorer,
+               mutual_info_score=mutual_info_scorer,
+               adjusted_mutual_info_score=adjusted_mutual_info_scorer,
+               normalized_mutual_info_score=normalized_mutual_info_scorer,
+               fowlkes_mallows_score=fowlkes_mallows_scorer)
+
 
 for name, metric in [('precision', precision_score),
                      ('recall', recall_score), ('f1', f1_score)]:
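Each make_scorer(...) call above wraps a metric with signature
metric(y_true, y_pred) into a scorer with signature
scorer(estimator, X, y_true) that calls estimator.predict(X) internally,
which is what lets these supervised cluster metrics plug into grid search.
A minimal sketch of that behaviour, assuming the patch is applied (the
dataset and estimator choices here are arbitrary):

    from sklearn.cluster import KMeans
    from sklearn.datasets import make_blobs
    from sklearn.metrics import get_scorer
    from sklearn.metrics.cluster import v_measure_score

    X, y = make_blobs(random_state=0, centers=2)
    km = KMeans(n_clusters=2, random_state=0).fit(X)

    # The scorer predicts cluster labels for X and compares them
    # against the ground-truth labels y.
    score_via_scorer = get_scorer('v_measure_score')(km, X, y)
    score_direct = v_measure_score(y, km.predict(X))
    assert abs(score_via_scorer - score_direct) < 1e-12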
diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py
index 17a4811f52..461bdadf3d 100644
--- a/sklearn/metrics/tests/test_score_objects.py
+++ b/sklearn/metrics/tests/test_score_objects.py
@@ -18,7 +18,7 @@ from sklearn.utils.testing import assert_warns_message
 from sklearn.base import BaseEstimator
 from sklearn.metrics import (f1_score, r2_score, roc_auc_score, fbeta_score,
                              log_loss, precision_score, recall_score)
-from sklearn.metrics.cluster import adjusted_rand_score
+from sklearn.metrics import cluster as cluster_module
 from sklearn.metrics.scorer import (check_scoring, _PredictScorer,
                                     _passthrough_scorer)
 from sklearn.metrics import make_scorer, get_scorer, SCORERS
@@ -47,9 +47,17 @@ CLF_SCORERS = ['accuracy', 'f1', 'f1_weighted', 'f1_macro', 'f1_micro',
                'roc_auc', 'average_precision', 'precision',
                'precision_weighted', 'precision_macro', 'precision_micro',
                'recall', 'recall_weighted', 'recall_macro', 'recall_micro',
-               'neg_log_loss', 'log_loss',
-               'adjusted_rand_score'  # not really, but works
-               ]
+               'neg_log_loss', 'log_loss']
+
+# All supervised cluster scorers (they behave like classification metrics)
+CLUSTER_SCORERS = ["adjusted_rand_score",
+                   "homogeneity_score",
+                   "completeness_score",
+                   "v_measure_score",
+                   "mutual_info_score",
+                   "adjusted_mutual_info_score",
+                   "normalized_mutual_info_score",
+                   "fowlkes_mallows_score"]
 
 MULTILABEL_ONLY_SCORERS = ['precision_samples', 'recall_samples', 'f1_samples']
 
@@ -65,6 +73,7 @@ def _make_estimators(X_train, y_train, y_ml_train):
     return dict(
         [(name, sensible_regr) for name in REGRESSION_SCORERS] +
         [(name, sensible_clf) for name in CLF_SCORERS] +
+        [(name, sensible_clf) for name in CLUSTER_SCORERS] +
         [(name, sensible_ml_clf) for name in MULTILABEL_ONLY_SCORERS]
     )
 
@@ -330,16 +339,16 @@ def test_thresholded_scorers_multilabel_indicator_data():
     assert_almost_equal(score1, score2)
 
 
-def test_unsupervised_scorers():
+def test_supervised_cluster_scorers():
     # Test clustering scorers against gold standard labeling.
-    # We don't have any real unsupervised Scorers yet.
     X, y = make_blobs(random_state=0, centers=2)
     X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
     km = KMeans(n_clusters=3)
     km.fit(X_train)
-    score1 = get_scorer('adjusted_rand_score')(km, X_test, y_test)
-    score2 = adjusted_rand_score(y_test, km.predict(X_test))
-    assert_almost_equal(score1, score2)
+    for name in CLUSTER_SCORERS:
+        score1 = get_scorer(name)(km, X_test, y_test)
+        score2 = getattr(cluster_module, name)(y_test, km.predict(X_test))
+        assert_almost_equal(score1, score2)
 
 
 @ignore_warnings
@@ -445,4 +454,4 @@ def test_scoring_is_not_metric():
     assert_raises_regexp(ValueError, 'make_scorer', check_scoring,
                          Ridge(), r2_score)
     assert_raises_regexp(ValueError, 'make_scorer', check_scoring,
-                         KMeans(), adjusted_rand_score)
+                         KMeans(), cluster_module.adjusted_rand_score)

diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py
index 49d1d566bd..117b81a35a 100644
--- a/sklearn/model_selection/tests/test_search.py
+++ b/sklearn/model_selection/tests/test_search.py
@@ -542,6 +542,12 @@ def test_unsupervised_grid_search():
     # ARI can find the right number :)
     assert_equal(grid_search.best_params_["n_clusters"], 3)
 
+    grid_search = GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4]),
+                               scoring='fowlkes_mallows_score')
+    grid_search.fit(X, y)
+    # So can FMS ;)
+    assert_equal(grid_search.best_params_["n_clusters"], 3)
+
     # Now without a score, and without y
     grid_search = GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4]))
     grid_search.fit(X)
-- 
GitLab
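End to end, this is the usage the patch enables: supervised cluster metrics
as scoring strings in hyper-parameter search, with the ground-truth labels
passed to fit. A hedged sketch, assuming the patch is applied (the data and
parameter grid here are arbitrary, chosen to mirror the
test_unsupervised_grid_search addition above):

    from sklearn.cluster import KMeans
    from sklearn.datasets import make_blobs
    from sklearn.model_selection import GridSearchCV

    # Three well-separated blobs; the search should recover n_clusters=3.
    X, y = make_blobs(n_samples=100, centers=3, random_state=0)

    grid = GridSearchCV(KMeans(random_state=0),
                        param_grid=dict(n_clusters=[2, 3, 4]),
                        scoring='fowlkes_mallows_score')
    grid.fit(X, y)  # y is required: predicted clusters are scored against it
    print(grid.best_params_)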