diff --git a/scikits/learn/metrics.py b/scikits/learn/metrics.py
index 4333df19bbe20e14c287df22871948d5b6e519ef..aa26e4d4665a29a0e302ed381247325d1ea9a867 100644
--- a/scikits/learn/metrics.py
+++ b/scikits/learn/metrics.py
@@ -353,10 +353,18 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None):
     precision = true_pos / (true_pos + false_pos)
     recall = true_pos / (true_pos + false_neg)
 
+    # handle division by 0.0 in precision and recall
+    precision[(true_pos + false_pos) == 0.0] = 0.0
+    recall[(true_pos + false_neg) == 0.0] = 0.0
+
     # fbeta score
     beta2 = beta ** 2
     fscore = (1 + beta2) * (precision * recall) / (
         beta2 * precision + recall)
+
+    # handle division by 0.0 in fscore
+    fscore[(precision + recall) == 0.0] = 0.0
+
     return precision, recall, fscore, support
 
 
diff --git a/scikits/learn/tests/test_metrics.py b/scikits/learn/tests/test_metrics.py
index 5b7782bdc6e311fb7adf78e4dfd83972e118f285..93fac32958ae19b2cfc69cb9ec91f7779ea4a75a 100644
--- a/scikits/learn/tests/test_metrics.py
+++ b/scikits/learn/tests/test_metrics.py
@@ -141,6 +141,16 @@ def test_precision_recall_f1_score_multiclass():
     assert_array_equal(s, [25, 20, 30])
 
 
+def test_zero_precision_recall():
+    """Check that pathological cases do not bring NaNs"""
+    y_true = np.array([0, 1, 2, 0, 1, 2])
+    y_pred = np.array([2, 0, 1, 1, 2, 0])
+
+    assert_almost_equal(precision_score(y_true, y_pred), 0.0, 2)
+    assert_almost_equal(recall_score(y_true, y_pred), 0.0, 2)
+    assert_almost_equal(f1_score(y_true, y_pred), 0.0, 2)
+
+
 def test_confusion_matrix_multiclass():
     """Test confusion matrix - multi-class case"""
     y_true, y_pred, _ = make_prediction(binary=False)
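
For reference, a minimal sketch (not part of the patch) of the pathological case the new test exercises, assuming the patched `scikits.learn` tree is importable:

```python
import numpy as np
from scikits.learn.metrics import precision_recall_fscore_support

# Every prediction is wrong: precision and recall come out 0.0 for each
# class, so the fbeta ratio becomes 0.0 / 0.0.
y_true = np.array([0, 1, 2, 0, 1, 2])
y_pred = np.array([2, 0, 1, 1, 2, 0])

p, r, f, s = precision_recall_fscore_support(y_true, y_pred)

# With the guards above, the undefined fscore is clamped to 0.0 instead of
# propagating NaN (expected output shown as comments):
print(p)  # [ 0.  0.  0.]
print(r)  # [ 0.  0.  0.]
print(f)  # [ 0.  0.  0.]
print(s)  # [2 2 2]
```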