diff --git a/scikits/learn/metrics.py b/scikits/learn/metrics.py
index c04582dac9ca97fea58be3c6c1e2896b26f78791..1b3ac7762783db61c36fd66e72ec6d78de221e66 100644
--- a/scikits/learn/metrics.py
+++ b/scikits/learn/metrics.py
@@ -155,7 +155,7 @@ def precision(y_true, y_pred):
     =======
     precision : float
     """
-    return precision_recall_fscore(y_true, y_pred)[0]
+    return precision_recall_fscore_support(y_true, y_pred)[0]
 
 
 def recall(y_true, y_pred):
@@ -179,7 +179,7 @@ def recall(y_true, y_pred):
     =======
     recall : array, shape = [n_unique_labels], dtype = np.double
     """
-    return precision_recall_fscore(y_true, y_pred)[1]
+    return precision_recall_fscore_support(y_true, y_pred)[1]
 
 
 def fbeta_score(y_true, y_pred, beta):
@@ -241,8 +241,8 @@ def f1_score(y_true, y_pred):
     return fbeta_score(y_true, y_pred, 1)
 
 
-def precision_recall_fscore(y_true, y_pred, beta=1.0, labels=None):
-    """Compute precision and recall and f-measure at the same time.
+def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None):
+    """Compute precisions, recalls, f-measures and support for each class
 
     The precision is the ratio :math:`tp / (tp + fp)` where tp is the number of
     true positives and fp the number of false positives. The precision is
@@ -253,11 +253,14 @@ def precision_recall_fscore(y_true, y_pred, beta=1.0, labels=None):
     true positives and fn the number of false negatives. The recall is
     intuitively the ability of the classifier to find all the positive samples.
 
-    The F_beta score can be interpreted as a weighted average of the precision
-    and recall, where an F_beta score reaches its best value at 1 and worst
-    score at 0.
+    The F_beta score can be interpreted as a weighted harmonic mean of
+    the precision and recall, where an F_beta score reaches its best
+    value at 1 and worst score at 0.
 
-    The F_1 score weights recall beta as much as precision.
+    The F_beta score weights recall beta times as much as precision.
+    beta = 1.0 means recall and precision are equally important.
+
+    The support is the number of occurrences of each class in y_true.
 
     Parameters
     ==========
@@ -274,7 +277,8 @@ def precision_recall_fscore(y_true, y_pred, beta=1.0, labels=None):
     =======
     precision: array, shape = [n_unique_labels], dtype = np.double
     recall: array, shape = [n_unique_labels], dtype = np.double
-    precision: array, shape = [n_unique_labels], dtype = np.double
+    fbeta_score: array, shape = [n_unique_labels], dtype = np.double
+    support: array, shape = [n_unique_labels], dtype = np.long
 
     References
     ==========
@@ -290,11 +294,13 @@ def precision_recall_fscore(y_true, y_pred, beta=1.0, labels=None):
     true_pos = np.zeros(n_labels, dtype=np.double)
     false_pos = np.zeros(n_labels, dtype=np.double)
     false_neg = np.zeros(n_labels, dtype=np.double)
+    support = np.zeros(n_labels, dtype=np.long)
 
     for i, label_i in enumerate(labels):
         true_pos[i] = np.sum(y_pred[y_true == label_i] == label_i)
         false_pos[i] = np.sum(y_pred[y_true != label_i] == label_i)
         false_neg[i] = np.sum(y_pred[y_true == label_i] != label_i)
+        support[i] = np.sum(y_true == label_i)
 
     # precision and recall
     precision = true_pos / (true_pos + false_pos)
@@ -304,7 +310,7 @@ def precision_recall_fscore(y_true, y_pred, beta=1.0, labels=None):
     beta2 = beta ** 2
     fscore = (1 + beta2) * (precision * recall) / (
         beta2 * precision + recall)
-    return precision, recall, fscore
+    return precision, recall, fscore, support
 
 
 def classification_report(y_true, y_pred, labels=None, class_names=None):
@@ -336,15 +342,17 @@ def classification_report(y_true, y_pred, labels=None, class_names=None):
     else:
         labels = np.asarray(labels, dtype=np.int)
 
+    last_line_heading = 'avg / total'
+
     if class_names is None:
-        width = len('mean')
+        width = len(last_line_heading)
         class_names = ['%d' % l for l in labels]
     else:
         width = max(len(cn) for cn in class_names)
-        width = max(width, len('mean'))
+        width = max(width, len(last_line_heading))
 
 
-    headers = ["precision", "recall", "f1-score"]
+    headers = ["precision", "recall", "f1-score", "support"]
     fmt = '{0:>%d}' % width # first column: class name
     fmt += '  '
     fmt += ' '.join(['{%d:>9}' % (i + 1) for i, _ in enumerate(headers)])
@@ -354,19 +362,23 @@ def classification_report(y_true, y_pred, labels=None, class_names=None):
     report = fmt.format(*headers)
     report += '\n'
 
-    p, r, f1 = precision_recall_fscore(y_true, y_pred, labels=labels)
+    p, r, f1, s = precision_recall_fscore_support(y_true, y_pred, labels=labels)
     for i, label in enumerate(labels):
         values = [class_names[i]]
         for v in (p[i], r[i], f1[i]):
             values += ["%0.2f" % float(v)]
+        values += ["%d" % int(s[i])]
         report += fmt.format(*values)
 
     report += '\n'
 
     # compute averages
-    values = ['mean']
-    for v in (np.mean(p), np.mean(r), np.mean(f1)):
+    values = [last_line_heading]
+    for v in (np.average(p, weights=s),
+              np.average(r, weights=s),
+              np.average(f1, weights=s)):
         values += ["%0.2f" % float(v)]
+    values += ['%d' % np.sum(s)]
     report += fmt.format(*values)
     return report
 
@@ -421,7 +433,7 @@ def precision_recall_curve(y_true, probas_pred):
     for i, t in enumerate(thresholds):
         y_pred = np.ones(len(y_true))
         y_pred[probas_pred < t] = 0
-        p, r, _ = precision_recall_fscore(y_true, y_pred)
+        p, r, _, _ = precision_recall_fscore_support(y_true, y_pred)
         precision[i] = p[1]
         recall[i] = r[1]
     precision[-1] = 1.0
diff --git a/scikits/learn/tests/test_metrics.py b/scikits/learn/tests/test_metrics.py
index d5509f439282551a008cbb2f7b648c0f996528c2..4810405a30697122df378474c87b27d510b1bbbc 100644
--- a/scikits/learn/tests/test_metrics.py
+++ b/scikits/learn/tests/test_metrics.py
@@ -17,7 +17,7 @@ from ..metrics import f1_score
 from ..metrics import mean_square_error
 from ..metrics import precision
 from ..metrics import precision_recall_curve
-from ..metrics import precision_recall_fscore
+from ..metrics import precision_recall_fscore_support
 from ..metrics import recall
 from ..metrics import roc_curve
 from ..metrics import zero_one
@@ -80,10 +80,11 @@ def test_precision_recall_f1_score_binary():
     """Test Precision Recall and F1 Score for binary classification task"""
     y_true, y_pred, _ = make_prediction(binary=True)
 
-    p, r, f = precision_recall_fscore(y_true, y_pred)
+    p, r, f, s = precision_recall_fscore_support(y_true, y_pred)
     assert_array_almost_equal(p, [0.73, 0.75], 2)
     assert_array_almost_equal(r, [0.76, 0.72], 2)
     assert_array_almost_equal(f, [0.75, 0.74], 2)
+    assert_array_equal(s, [25, 25])
 
 
 def test_confusion_matrix_binary():
@@ -99,16 +100,19 @@ def test_precision_recall_f1_score_multiclass():
     y_true, y_pred, _ = make_prediction(binary=False)
 
     # compute scores with default labels introspection
-    p, r, f = precision_recall_fscore(y_true, y_pred)
+    p, r, f, s = precision_recall_fscore_support(y_true, y_pred)
     assert_array_almost_equal(p, [0.82, 0.55, 0.47], 2)
     assert_array_almost_equal(r, [0.92, 0.17, 0.90], 2)
     assert_array_almost_equal(f, [0.87, 0.26, 0.62], 2)
+    assert_array_equal(s, [25, 30, 20])
 
     # same prediction but with an explicit label ordering
-    p, r, f = precision_recall_fscore(y_true, y_pred, labels=[0, 2, 1])
+    p, r, f, s = precision_recall_fscore_support(
+        y_true, y_pred, labels=[0, 2, 1])
     assert_array_almost_equal(p, [0.82, 0.47, 0.55], 2)
     assert_array_almost_equal(r, [0.92, 0.90, 0.17], 2)
     assert_array_almost_equal(f, [0.87, 0.62, 0.26], 2)
+    assert_array_equal(s, [25, 20, 30])
 
 
 def test_confusion_matrix_multiclass():
@@ -135,13 +139,13 @@ def test_classification_report():
 
     # print classification report with class names
     expected_report = """\
-            precision    recall  f1-score
+             precision    recall  f1-score   support
 
-    setosa       0.82      0.92      0.87
-versicolor       0.56      0.17      0.26
- virginica       0.47      0.90      0.62
+     setosa       0.82      0.92      0.87        25
+ versicolor       0.56      0.17      0.26        30
+  virginica       0.47      0.90      0.62        20
 
-      mean       0.62      0.66      0.58
+avg / total       0.62      0.61      0.56        75
 """
     report = classification_report(
         y_true, y_pred, labels=range(len(iris.target_names)),
@@ -150,13 +154,13 @@ versicolor       0.56      0.17      0.26
 
     # print classification report with label detection
     expected_report = """\
-      precision    recall  f1-score
+             precision    recall  f1-score   support
 
-   0       0.82      0.92      0.87
-   1       0.56      0.17      0.26
-   2       0.47      0.90      0.62
+          0       0.82      0.92      0.87        25
+          1       0.56      0.17      0.26        30
+          2       0.47      0.90      0.62        20
 
-mean       0.62      0.66      0.58
+avg / total       0.62      0.61      0.56        75
 """
     report = classification_report(y_true, y_pred)
     assert_equal(report, expected_report)
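
A minimal usage sketch of the renamed API after this patch, assuming the
scikits.learn import paths it touches; the toy labels below are illustrative
only and are not the test fixtures:

    import numpy as np
    from scikits.learn.metrics import precision_recall_fscore_support
    from scikits.learn.metrics import classification_report

    # hypothetical labels: two samples per class
    y_true = np.array([0, 0, 1, 1, 2, 2])
    y_pred = np.array([0, 1, 1, 1, 2, 0])

    # The renamed function now returns four per-class arrays: precision,
    # recall, f-score and support (occurrences of each class in y_true).
    p, r, f, s = precision_recall_fscore_support(y_true, y_pred)
    # here s == array([2, 2, 2])

    # classification_report gains a 'support' column and replaces the
    # unweighted 'mean' row with an 'avg / total' row weighted by support.
    print(classification_report(y_true, y_pred))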