diff --git a/benchmarks/bench_multilabel_metrics.py b/benchmarks/bench_multilabel_metrics.py
index 4a4cdd7f3390272698f01e14afdc20636c0b7113..7afee7543143338e7ba3fa1f26a492a999163690 100755
--- a/benchmarks/bench_multilabel_metrics.py
+++ b/benchmarks/bench_multilabel_metrics.py
@@ -22,7 +22,7 @@ from sklearn.utils.testing import ignore_warnings
 
 
 METRICS = {
-    'f1': f1_score,
+    'f1': partial(f1_score, average='micro'),
     'f1-by-sample': partial(f1_score, average='samples'),
     'accuracy': accuracy_score,
     'hamming': hamming_loss,
diff --git a/doc/datasets/twenty_newsgroups.rst b/doc/datasets/twenty_newsgroups.rst
index 3e40a0aaf01f5e0ac08438da39e8ed79cb2473a9..003366efa4606e532587e7bb129fb21973cb7e1d 100644
--- a/doc/datasets/twenty_newsgroups.rst
+++ b/doc/datasets/twenty_newsgroups.rst
@@ -131,7 +131,7 @@ which is fast to train and achieves a decent F-score::
   >>> clf = MultinomialNB(alpha=.01)
   >>> clf.fit(vectors, newsgroups_train.target)
   >>> pred = clf.predict(vectors_test)
-  >>> metrics.f1_score(newsgroups_test.target, pred)
+  >>> metrics.f1_score(newsgroups_test.target, pred, average='weighted')
   0.88251152461278892
 
 (The example :ref:`example_text_document_classification_20newsgroups.py` shuffles
@@ -181,7 +181,7 @@ blocks, and quotation blocks respectively.
   ...                                      categories=categories)
   >>> vectors_test = vectorizer.transform(newsgroups_test.data)
   >>> pred = clf.predict(vectors_test)
-  >>> metrics.f1_score(pred, newsgroups_test.target)
+  >>> metrics.f1_score(pred, newsgroups_test.target, average='weighted')
   0.78409163025839435
 
 This classifier lost a lot of its F-score, just because we removed
@@ -196,7 +196,7 @@ It loses even more if we also strip this metadata from the training data:
   >>> clf.fit(vectors, newsgroups_train.target)
   >>> vectors_test = vectorizer.transform(newsgroups_test.data)
   >>> pred = clf.predict(vectors_test)
-  >>> metrics.f1_score(newsgroups_test.target, pred)
+  >>> metrics.f1_score(newsgroups_test.target, pred, average='weighted')
   0.73160869205141166
 
 Some other classifiers cope better with this harder version of the task. Try
diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 494f50e1994a05a1cde5150a2a0e3812193f2abc..53a8dac9c82293feee202dc1f6b1fe859ec42934 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -732,6 +732,7 @@ details.
    :template: function.rst
 
    metrics.make_scorer
+   metrics.get_scorer
 
 Classification metrics
 ----------------------
diff --git a/doc/modules/cross_validation.rst b/doc/modules/cross_validation.rst
index 59a8486f0e3e9f5cf4f1101ac748a30851559ff8..9a3ea274837d47a6b39eef5c1a42843e766fdb03 100644
--- a/doc/modules/cross_validation.rst
+++ b/doc/modules/cross_validation.rst
@@ -120,7 +120,7 @@ scoring parameter::
 
   >>> from sklearn import metrics
   >>> scores = cross_validation.cross_val_score(clf, iris.data, iris.target,
-  ...     cv=5, scoring='f1')
+  ...     cv=5, scoring='f1_weighted')
   >>> scores                                              # doctest: +ELLIPSIS
   array([ 0.96...,  1.  ...,  0.96...,  0.96...,  1.        ])
 
diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst
index 4cf37821637e4186a6707512a4083987b07cabf6..5b8a0ed23bff562b50f1c8668984fe769bfd01d3 100644
--- a/doc/modules/model_evaluation.rst
+++ b/doc/modules/model_evaluation.rst
@@ -60,10 +60,14 @@ Scoring                    Function
 **Classification**
 'accuracy'                 :func:`metrics.accuracy_score`
 'average_precision'        :func:`metrics.average_precision_score`
-'f1'                       :func:`metrics.f1_score`
+'f1'                       :func:`metrics.f1_score`                    for binary targets
+'f1_micro'                 :func:`metrics.f1_score`                    micro-averaged
+'f1_macro'                 :func:`metrics.f1_score`                    macro-averaged
+'f1_weighted'              :func:`metrics.f1_score`                    weighted average
+'f1_samples'               :func:`metrics.f1_score`                    by multilabel sample
 'log_loss'                 :func:`metrics.log_loss`                    requires ``predict_proba`` support
-'precision'                :func:`metrics.precision_score`
-'recall'                   :func:`metrics.recall_score`
+'precision' etc.           :func:`metrics.precision_score`             suffixes apply as with 'f1'
+'recall' etc.              :func:`metrics.recall_score`                suffixes apply as with 'f1'
 'roc_auc'                  :func:`metrics.roc_auc_score`
 
 **Clustering**
@@ -84,7 +88,7 @@ Usage examples:
     >>> model = svm.SVC()
     >>> cross_validation.cross_val_score(model, X, y, scoring='wrong_choice')
     Traceback (most recent call last):
-    ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_rand_score', 'average_precision', 'f1', 'log_loss', 'mean_absolute_error', 'mean_squared_error', 'median_absolute_error', 'precision', 'r2', 'recall', 'roc_auc']
+    ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_rand_score', 'average_precision', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted', 'log_loss', 'mean_absolute_error', 'mean_squared_error', 'median_absolute_error', 'precision', 'precision_macro', 'precision_micro', 'precision_samples', 'precision_weighted', 'r2', 'recall', 'recall_macro', 'recall_micro', 'recall_samples', 'recall_weighted', 'roc_auc']
     >>> clf = svm.SVC(probability=True, random_state=0)
     >>> cross_validation.cross_val_score(clf, X, y, scoring='log_loss') # doctest: +ELLIPSIS
     array([-0.07..., -0.16..., -0.06...])
@@ -208,6 +212,8 @@ The :mod:`sklearn.metrics` module implements several loss, score, and utility
 functions to measure classification performance.
 Some metrics might require probability estimates of the positive class,
 confidence values, or binary decisions.
+Most implementations allow each sample to provide a weighted contribution
+to the overall score, through the ``sample_weight`` parameter.
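+
+For instance, giving the last two samples twice the weight of the others
+shifts the accuracy accordingly (a minimal illustration using
+:func:`accuracy_score`)::
+
+    >>> from sklearn.metrics import accuracy_score
+    >>> accuracy_score([0, 1, 1, 0], [0, 1, 0, 0],
+    ...                sample_weight=[1, 1, 2, 2])    # doctest: +ELLIPSIS
+    0.66...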
 
 Some of these are restricted to the binary classification case:
 
@@ -254,7 +260,53 @@ And some work with binary and multilabel (but not multiclass) problems:
    roc_auc_score
 
 
-In the following sub-sections, we will describe each of those functions.
+In the following sub-sections, we will describe each of those functions,
+preceded by some notes on common API and metric definitions.
+
+From binary to multiclass and multilabel
+........................................
+
+Some metrics are essentially defined for binary classification tasks (e.g.
+:func:`f1_score`, :func:`roc_auc_score`). In these cases, only the positive
+label is evaluated by default, assuming that the positive class is labelled
+``1`` (though this may be configurable through the ``pos_label`` parameter).
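+
+For instance, with binary targets the score is reported for the positive
+class only (a minimal illustration)::
+
+    >>> from sklearn.metrics import f1_score
+    >>> f1_score([0, 1, 1], [1, 1, 1])                # doctest: +ELLIPSIS
+    0.8...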
+
+.. _average:
+
+In extending a binary metric to multiclass or multilabel problems, the data
+is treated as a collection of binary problems, one for each class.
+There are then a number of ways to average binary metric calculations across
+the set of classes, each of which may be useful in some scenario.
+Where available, you should select among these using the ``average`` parameter.
+
+* ``"macro"`` simply calculates the mean of the binary metrics,
+  giving equal weight to each class.  In problems where infrequent classes
+  are nonetheless important, macro-averaging may be a means of highlighting
+  their performance. On the other hand, the assumption that all classes are
+  equally important is often untrue, such that macro-averaging will
+  over-emphasise the typically low performance on an infrequent class.
+* ``"weighted"`` accounts for class imbalance by computing the average of
+  binary metrics in which each class's score is weighted by its presence in the
+  true data sample.
+* ``"micro"`` gives each sample-class pair an equal contribution to the overall
+  metric (except as a result of ``sample_weight``). Rather than summing the
+  metric per class, this sums the dividends and divisors that make up the
+  per-class metrics to calculate an overall quotient.
+  Micro-averaging may be preferred in multilabel settings, including
+  multiclass classification where a majority class is to be ignored.
+* ``"samples"`` applies only to multilabel problems. It does not calculate a
+  per-class measure, instead calculating the metric over the true and predicted
+  classes for each sample in the evaluation data, and returning their
+  (``sample_weight``-weighted) average.
+* Selecting ``average=None`` will return an array with the score for each
+  class.
+
+While multiclass data is provided to the metric, like binary targets, as an
+array of class labels, multilabel data is specified as an indicator matrix,
+in which cell ``[i, j]`` has value 1 if sample ``i`` has label ``j`` and value
+0 otherwise.
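+
+As a small illustration, the multiclass averaging strategies can be compared
+on a toy problem, and ``"samples"`` averaging applied to a multilabel
+indicator matrix as just described::
+
+    >>> from sklearn.metrics import f1_score
+    >>> y_true = [0, 1, 2, 0, 1, 2]
+    >>> y_pred = [0, 2, 1, 0, 0, 1]
+    >>> f1_score(y_true, y_pred, average='macro')     # doctest: +ELLIPSIS
+    0.26...
+    >>> f1_score(y_true, y_pred, average='micro')     # doctest: +ELLIPSIS
+    0.33...
+    >>> f1_score(y_true, y_pred, average='weighted')  # doctest: +ELLIPSIS
+    0.26...
+
+    >>> import numpy as np
+    >>> y_true = np.array([[1, 0, 1], [0, 1, 0]])
+    >>> y_pred = np.array([[1, 1, 1], [0, 1, 1]])
+    >>> f1_score(y_true, y_pred, average='samples')   # doctest: +ELLIPSIS
+    0.73...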
+
 
 Accuracy score
 --------------
@@ -595,21 +647,10 @@ There are a few ways to combine results across labels,
 specified by the ``average`` argument to the
 :func:`average_precision_score` (multilabel only), :func:`f1_score`,
 :func:`fbeta_score`, :func:`precision_recall_fscore_support`,
-:func:`precision_score` and :func:`recall_score` functions:
-
-* ``"micro"``: calculate metrics globally by counting the total true
-  positives, false negatives and false positives. Except in the multi-label
-  case, this implies that precision, recall and :math:`F` are equal.
-* ``"samples"``: calculate metrics for each sample, comparing sets of
-  labels assigned to each, and find the mean across all samples.
-  This is only meaningful and available in the multilabel case.
-* ``"macro"``: calculate metrics for each label, and find their mean.
-  This does not take label imbalance into account.
-* ``"weighted"``: calculate metrics for each label, and find their average
-  weighted by the number of occurrences of the label in the true data.
-  This alters ``"macro"`` to account for label imbalance; it may produce an
-  F-score that is not between precision and recall.
-* ``None``: calculate metrics for each label and do not average them.
+:func:`precision_score` and :func:`recall_score` functions, as described
+:ref:`above <average>`. Note that "micro"-averaging in a multiclass setting
+will produce equal precision, recall and :math:`F`, while "weighted" averaging
+may produce an F-score that is not between precision and recall.
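+
+For instance, on a multiclass problem, micro-averaged precision, recall and
+:math:`F` coincide::
+
+    >>> from sklearn.metrics import precision_score, recall_score, f1_score
+    >>> y_true = [0, 1, 2, 0, 1, 2]
+    >>> y_pred = [0, 2, 1, 0, 0, 1]
+    >>> precision_score(y_true, y_pred, average='micro')  # doctest: +ELLIPSIS
+    0.33...
+    >>> recall_score(y_true, y_pred, average='micro')     # doctest: +ELLIPSIS
+    0.33...
+    >>> f1_score(y_true, y_pred, average='micro')         # doctest: +ELLIPSIS
+    0.33...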
 
 To make this more explicit, consider the following notation:
 
@@ -869,20 +910,7 @@ For more information see the `Wikipedia article on AUC
   0.75
 
 In multi-label classification, the :func:`roc_auc_score` function is
-extended by averaging over the labels:
-
-* ``"micro"``: computes AUROC globally; obtained
-  by considering each element of the label indicator matrix as a label.
-* ``"samples"``: computes AUROC for each sample,
-  comparing the sets of labels and scores assigned to each, and finds the mean
-  across all samples.
-* ``"macro"``: computes AUROC for each label, and finds
-  their mean.
-* ``"weighted"``: computes AUROC for each label and
-  finds their average, weighted by the number of occurrences of the label in the
-  true data.
-* ``None``: this returns an array of scores with scores with shape (n_classes,)
-  instead of an aggregate scalar score.
+extended by averaging over the labels as :ref:`above <average>`.
 
 Compared to metrics such as the subset accuracy, the Hamming loss, or the
 F1 score, ROC doesn't require optimizing a threshold for each label. The
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index ccc04c6ed20bae89bd20e2a4d991887287dde696..b8f1f02e3655a636d9d40570d4b2d026c937a98b 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -234,6 +234,17 @@ API changes summary
       :func:`linear_model.enet_path` which constrains coefficients to be
       positive. By `Manoj Kumar`_.
 
+    - Users should now supply an explicit ``average`` parameter to
+      :func:`sklearn.metrics.f1_score`, :func:`sklearn.metrics.fbeta_score`,
+      :func:`sklearn.metrics.recall_score` and
+      :func:`sklearn.metrics.precision_score` when performing multiclass
+      or multilabel (i.e. not binary) classification. By `Joel Nothman`_.
+
+    - The ``scoring`` parameter for cross validation now accepts
+      ``'f1_micro'``, ``'f1_macro'`` or ``'f1_weighted'``. ``'f1'`` is now for
+      binary classification only. Similar changes apply to ``'precision'``
+      and ``'recall'``. By `Joel Nothman`_.
+
 .. _changes_0_15_2:
 
 0.15.2
@@ -274,7 +285,7 @@ Bug fixes
     running the tests. By `Joel Nothman`_.
 
   - Many documentation and website fixes by `Joel Nothman`_, `Lars Buitinck`_
-    and others.
+    `Matt Pico`_, and others.
 
 .. _changes_0_15_1:
 
diff --git a/examples/model_selection/grid_search_digits.py b/examples/model_selection/grid_search_digits.py
index 20f25d3751cc009d3c044e75faf61fb4058b7aca..7492d0a726ef5a07eac9f9ac27daa42c7509a57a 100644
--- a/examples/model_selection/grid_search_digits.py
+++ b/examples/model_selection/grid_search_digits.py
@@ -50,7 +50,8 @@ for score in scores:
     print("# Tuning hyper-parameters for %s" % score)
     print()
 
-    clf = GridSearchCV(SVC(C=1), tuned_parameters, cv=5, scoring=score)
+    clf = GridSearchCV(SVC(C=1), tuned_parameters, cv=5,
+                       scoring='%s_weighted' % score)
     clf.fit(X_train, y_train)
 
     print("Best parameters set found on development set:")
diff --git a/examples/text/document_classification_20newsgroups.py b/examples/text/document_classification_20newsgroups.py
index a7197ac1f6cb0632572097a478601405426aa57d..87e29accb763f783ac3162c4f36c451ccf9ed0a1 100644
--- a/examples/text/document_classification_20newsgroups.py
+++ b/examples/text/document_classification_20newsgroups.py
@@ -140,7 +140,7 @@ print()
 # split a training set and a test set
 y_train, y_test = data_train.target, data_test.target
 
-print("Extracting features from the training dataset using a sparse vectorizer")
+print("Extracting features from the training data using a sparse vectorizer")
 t0 = time()
 if opts.use_hashing:
     vectorizer = HashingVectorizer(stop_words='english', non_negative=True,
@@ -155,7 +155,7 @@ print("done in %fs at %0.3fMB/s" % (duration, data_train_size_mb / duration))
 print("n_samples: %d, n_features: %d" % X_train.shape)
 print()
 
-print("Extracting features from the test dataset using the same vectorizer")
+print("Extracting features from the test data using the same vectorizer")
 t0 = time()
 X_test = vectorizer.transform(data_test.data)
 duration = time() - t0
@@ -208,8 +208,8 @@ def benchmark(clf):
     test_time = time() - t0
     print("test time:  %0.3fs" % test_time)
 
-    score = metrics.f1_score(y_test, pred)
-    print("f1-score:   %0.3f" % score)
+    score = metrics.accuracy_score(y_test, pred)
+    print("accuracy:   %0.3f" % score)
 
     if hasattr(clf, 'coef_'):
         print("dimensionality: %d" % clf.coef_.shape[1])
diff --git a/sklearn/feature_selection/tests/test_rfe.py b/sklearn/feature_selection/tests/test_rfe.py
index 0c20377506c461b9c3430e2ed0a4417f79a7b3b2..c1207af0f4d709ccdfbb4f0a3d5e2eedd35795aa 100644
--- a/sklearn/feature_selection/tests/test_rfe.py
+++ b/sklearn/feature_selection/tests/test_rfe.py
@@ -14,8 +14,8 @@ from sklearn.svm import SVC, SVR
 from sklearn.utils import check_random_state
 from sklearn.utils.testing import ignore_warnings
 
-from sklearn.metrics.scorer import SCORERS
 from sklearn.metrics import make_scorer
+from sklearn.metrics import get_scorer
 
 
 def test_rfe_set_params():
@@ -97,7 +97,7 @@ def test_rfecv():
     assert_array_equal(X_r, iris.data)
 
     # Test using a scorer
-    scorer = SCORERS['accuracy']
+    scorer = get_scorer('accuracy')
     rfecv = RFECV(estimator=SVC(kernel="linear"), step=1, cv=5,
                   scoring=scorer)
     rfecv.fit(X, y)
diff --git a/sklearn/linear_model/tests/test_ridge.py b/sklearn/linear_model/tests/test_ridge.py
index 4249fe0a420da129459613b0baef9c8241f706a8..d13dc3ebf461b5869ed6e63faaffa7ac43aa7332 100644
--- a/sklearn/linear_model/tests/test_ridge.py
+++ b/sklearn/linear_model/tests/test_ridge.py
@@ -15,8 +15,8 @@ from sklearn.utils.testing import ignore_warnings
 
 from sklearn import datasets
 from sklearn.metrics import mean_squared_error
-from sklearn.metrics.scorer import SCORERS
 from sklearn.metrics import make_scorer
+from sklearn.metrics import get_scorer
 
 from sklearn.linear_model.base import LinearRegression
 from sklearn.linear_model.ridge import ridge_regression
@@ -336,7 +336,7 @@ def _test_ridge_loo(filter_):
     assert_equal(ridge_gcv3.alpha_, alpha_)
 
     # check that we get same best alpha with a scorer
-    scorer = SCORERS['mean_squared_error']
+    scorer = get_scorer('mean_squared_error')
     ridge_gcv4 = RidgeCV(fit_intercept=False, scoring=scorer)
     ridge_gcv4.fit(filter_(X_diabetes), y_diabetes)
     assert_equal(ridge_gcv4.alpha_, alpha_)
diff --git a/sklearn/linear_model/tests/test_sgd.py b/sklearn/linear_model/tests/test_sgd.py
index 6085ecb31bbb4ffdb4b531670a19787507aec851..45f459089bcc0c15bec9332f72e1482568407dbb 100644
--- a/sklearn/linear_model/tests/test_sgd.py
+++ b/sklearn/linear_model/tests/test_sgd.py
@@ -610,14 +610,16 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest):
         y = y[idx]
         clf = self.factory(alpha=0.0001, n_iter=1000,
                            class_weight=None).fit(X, y)
-        assert_almost_equal(metrics.f1_score(y, clf.predict(X)), 0.96,
-                            decimal=1)
+        assert_almost_equal(metrics.f1_score(y, clf.predict(X),
+                                             average='weighted'),
+                            0.96, decimal=1)
 
         # make the same prediction using automated class_weight
         clf_auto = self.factory(alpha=0.0001, n_iter=1000,
                                 class_weight="auto").fit(X, y)
-        assert_almost_equal(metrics.f1_score(y, clf_auto.predict(X)), 0.96,
-                            decimal=1)
+        assert_almost_equal(metrics.f1_score(y, clf_auto.predict(X),
+                                             average='weighted'),
+                            0.96, decimal=1)
 
         # Make sure that in the balanced case it does not change anything
         # to use "auto"
@@ -634,19 +636,19 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest):
         clf = self.factory(n_iter=1000, class_weight=None)
         clf.fit(X_imbalanced, y_imbalanced)
         y_pred = clf.predict(X)
-        assert_less(metrics.f1_score(y, y_pred), 0.96)
+        assert_less(metrics.f1_score(y, y_pred, average='weighted'), 0.96)
 
         # fit a model with auto class_weight enabled
         clf = self.factory(n_iter=1000, class_weight="auto")
         clf.fit(X_imbalanced, y_imbalanced)
         y_pred = clf.predict(X)
-        assert_greater(metrics.f1_score(y, y_pred), 0.96)
+        assert_greater(metrics.f1_score(y, y_pred, average='weighted'), 0.96)
 
         # fit another using a fit parameter override
         clf = self.factory(n_iter=1000, class_weight="auto")
         clf.fit(X_imbalanced, y_imbalanced)
         y_pred = clf.predict(X)
-        assert_greater(metrics.f1_score(y, y_pred), 0.96)
+        assert_greater(metrics.f1_score(y, y_pred, average='weighted'), 0.96)
 
     def test_sample_weights(self):
         """Test weights on individual samples"""
diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py
index e419fec7a1c91956e370b715fe48687bb1ce7ac4..117120f46d6266b9fff6ed6335402349eadbe879 100644
--- a/sklearn/metrics/__init__.py
+++ b/sklearn/metrics/__init__.py
@@ -53,6 +53,7 @@ from .regression import r2_score
 
 from .scorer import make_scorer
 from .scorer import SCORERS
+from .scorer import get_scorer
 
 # Deprecated in 0.16
 from .ranking import auc_score
@@ -73,6 +74,7 @@ __all__ = [
     'explained_variance_score',
     'f1_score',
     'fbeta_score',
+    'get_scorer',
     'hamming_loss',
     'hinge_loss',
     'homogeneity_completeness_v_measure',
diff --git a/sklearn/metrics/classification.py b/sklearn/metrics/classification.py
index 80ec8446462ad84ef03c9abfd79bb8b3556fadee..7f8ef9290d68bec9db6cea847e4559e5fb58ea09 100644
--- a/sklearn/metrics/classification.py
+++ b/sklearn/metrics/classification.py
@@ -475,7 +475,7 @@ def zero_one_loss(y_true, y_pred, normalize=True, sample_weight=None):
         return n_samples - score
 
 
-def f1_score(y_true, y_pred, labels=None, pos_label=1, average='weighted',
+def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
              sample_weight=None):
     """Compute the F1 score, also known as balanced F-score or F-measure
 
@@ -504,7 +504,8 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='weighted',
         If ``average`` is not ``None`` and the classification target is binary,
         only this class's scores will be returned.
 
-    average : string, [None, 'micro', 'macro', 'samples', 'weighted' (default)]
+    average : one of [None, 'micro', 'macro', 'samples', 'weighted']
+        This parameter is required for multiclass/multilabel targets.
         If ``None``, the scores for each class are returned. Otherwise,
         unless ``pos_label`` is given in binary classification, this
         determines the type of averaging performed on the data:
@@ -561,7 +562,7 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='weighted',
 
 
 def fbeta_score(y_true, y_pred, beta, labels=None, pos_label=1,
-                average='weighted', sample_weight=None):
+                average='binary', sample_weight=None):
     """Compute the F-beta score
 
     The F-beta score is the weighted harmonic mean of precision and recall,
@@ -590,7 +591,8 @@ def fbeta_score(y_true, y_pred, beta, labels=None, pos_label=1,
         If ``average`` is not ``None`` and the classification target is binary,
         only this class's scores will be returned.
 
-    average : string, [None, 'micro', 'macro', 'samples', 'weighted' (default)]
+    average : one of [None, 'micro', 'macro', 'samples', 'weighted']
+        This parameter is required for multiclass/multilabel targets.
         If ``None``, the scores for each class are returned. Otherwise,
         unless ``pos_label`` is given in binary classification, this
         determines the type of averaging performed on the data:
@@ -822,7 +824,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
 
     """
     average_options = (None, 'micro', 'macro', 'weighted', 'samples')
-    if average not in average_options:
+    if average not in average_options and average != 'binary':
         raise ValueError('average has to be one of ' +
                          str(average_options))
     if beta <= 0:
@@ -830,6 +832,17 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
 
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
 
+    if average == 'binary' and y_type != 'binary':
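+        # Deprecation path: fall back to the old 'weighted' default after
+        # warning the user; from version 0.18 this is intended to raise an
+        # error instead.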
+        warnings.warn('The default `weighted` averaging is deprecated, '
+                      'and from version 0.18, use of precision, recall or '
+                      'F-score with multiclass or multilabel data will result '
+                      'in an exception. '
+                      'Please set an explicit value for `average`, one of '
+                      '%s. In cross validation use, for instance, '
+                      'scoring="f1_weighted" instead of scoring="f1".'
+                      % str(average_options), DeprecationWarning, stacklevel=2)
+        average = 'weighted'
+
     label_order = labels  # save this for later
     if labels is None:
         labels = unique_labels(y_true, y_pred)
@@ -852,7 +865,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
 
     elif average == 'samples':
         raise ValueError("Sample-based precision, recall, fscore is "
-                         "not meaningful outside multilabel"
+                         "not meaningful outside multilabel "
                          "classification. See the accuracy_score instead.")
     else:
         lb = LabelEncoder()
@@ -885,11 +898,12 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
     ### Select labels to keep ###
 
     if y_type == 'binary' and average is not None and pos_label is not None:
-        if label_order is not None and len(label_order) == 2:
+        if average != 'binary' and label_order is not None \
+           and len(label_order) == 2:
             warnings.warn('In the future, providing two `labels` values, as '
-                          'well as `average` will average over those '
-                          'labels. For now, please use `labels=None` with '
-                          '`pos_label` to evaluate precision, recall and '
+                          'well as `average != "binary"` will average over '
+                          'those labels. For now, please use `labels=None` '
+                          'with `pos_label` to evaluate precision, recall and '
                           'F-score for the positive label only.',
                           FutureWarning)
         if pos_label not in labels:
@@ -954,7 +968,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
 
 
 def precision_score(y_true, y_pred, labels=None, pos_label=1,
-                    average='weighted', sample_weight=None):
+                    average='binary', sample_weight=None):
     """Compute the precision
 
     The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
@@ -979,7 +993,8 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1,
         If ``average`` is not ``None`` and the classification target is binary,
         only this class's scores will be returned.
 
-    average : string, [None, 'micro', 'macro', 'samples', 'weighted' (default)]
+    average : one of [None, 'micro', 'macro', 'samples', 'weighted']
+        This parameter is required for multiclass/multilabel targets.
         If ``None``, the scores for each class are returned. Otherwise,
         unless ``pos_label`` is given in binary classification, this
         determines the type of averaging performed on the data:
@@ -1036,7 +1051,7 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1,
     return p
 
 
-def recall_score(y_true, y_pred, labels=None, pos_label=1, average='weighted',
+def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
                  sample_weight=None):
     """Compute the recall
 
@@ -1061,7 +1076,8 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='weighted',
         If ``average`` is not ``None`` and the classification target is binary,
         only this class's scores will be returned.
 
-    average : string, [None, 'micro', 'macro', 'samples', 'weighted' (default)]
+    average : one of [None, 'micro', 'macro', 'samples', 'weighted']
+        This parameter is required for multiclass/multilabel targets.
         If ``None``, the scores for each class are returned. Otherwise,
         unless ``pos_label`` is given in binary classification, this
         determines the type of averaging performed on the data:
diff --git a/sklearn/metrics/scorer.py b/sklearn/metrics/scorer.py
index 6f99deb815d0454dbbb29f9f71bfd4c931de91e9..ce7eaaeca4410e5f171ecc2cd5b077ead4419801 100644
--- a/sklearn/metrics/scorer.py
+++ b/sklearn/metrics/scorer.py
@@ -19,6 +19,7 @@ ground truth labeling (or ``None`` in the case of unsupervised models).
 # License: Simplified BSD
 
 from abc import ABCMeta, abstractmethod
+from functools import partial
 
 import numpy as np
 
@@ -342,8 +343,15 @@ SCORERS = dict(r2=r2_scorer,
                median_absolute_error=median_absolute_error_scorer,
                mean_absolute_error=mean_absolute_error_scorer,
                mean_squared_error=mean_squared_error_scorer,
-               accuracy=accuracy_scorer, f1=f1_scorer, roc_auc=roc_auc_scorer,
+               accuracy=accuracy_scorer, roc_auc=roc_auc_scorer,
                average_precision=average_precision_scorer,
-               precision=precision_scorer, recall=recall_scorer,
                log_loss=log_loss_scorer,
                adjusted_rand_score=adjusted_rand_scorer)
+
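+# Register averaged variants of the P/R/F scorers alongside the plain
+# (binary) ones, e.g. 'f1' plus 'f1_macro', 'f1_micro', 'f1_samples' and
+# 'f1_weighted'.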
+for name, metric in [('precision', precision_score),
+                     ('recall', recall_score), ('f1', f1_score)]:
+    SCORERS[name] = make_scorer(metric)
+    for average in ['macro', 'micro', 'samples', 'weighted']:
+        qualified_name = '{0}_{1}'.format(name, average)
+        SCORERS[qualified_name] = make_scorer(partial(metric, pos_label=None,
+                                                      average=average))
diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py
index 578e9d13db659e3604111244b5578cfb69c5e80e..1a920bfba65c2d86b035237186ffb44049048a37 100644
--- a/sklearn/metrics/tests/test_classification.py
+++ b/sklearn/metrics/tests/test_classification.py
@@ -991,6 +991,24 @@ def test_fscore_warnings():
                          'being set to 0.0 due to no true samples.')
 
 
+def test_prf_average_compat():
+    """Ensure warning if f1_score et al.'s average is implicit for multiclass
+    """
+    y_true = [1, 2, 3, 3]
+    y_pred = [1, 2, 3, 1]
+
+    for metric in [precision_score, recall_score, f1_score,
+                   partial(fbeta_score, beta=2)]:
+        score = assert_warns(DeprecationWarning, metric, y_true, y_pred)
+        score_weighted = assert_no_warnings(metric, y_true, y_pred,
+                                            average='weighted')
+        assert_equal(score, score_weighted,
+                     'average does not act like "weighted" by default')
+
+        # check binary passes without warning
+        assert_no_warnings(metric, [0, 1, 1], [0, 1, 0])
+
+
 @ignore_warnings  # sequence of sequences is deprecated
 def test__check_targets():
     """Check that _check_targets correctly merges target types, squeezes
diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 1973b6e3364c01bb38b8b7a9d60674515eabcdd2..e7c7c33d2484ae84bc7ed92890c4c77ae84e1fef 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -107,6 +107,7 @@ CLASSIFICATION_METRICS = {
     "zero_one_loss": zero_one_loss,
     "unnormalized_zero_one_loss": partial(zero_one_loss, normalize=False),
 
+    # These are needed to test averaging
     "precision_score": precision_score,
     "recall_score": recall_score,
     "f1_score": f1_score,
@@ -336,6 +337,7 @@ METRICS_WITHOUT_SAMPLE_WEIGHT = [
 ]
 
 
+@ignore_warnings
 def test_symmetry():
     """Test the symmetry of score and loss functions"""
     random_state = check_random_state(0)
@@ -366,6 +368,7 @@ def test_symmetry():
                     msg="%s seems to be symmetric" % name)
 
 
+@ignore_warnings
 def test_sample_order_invariance():
     random_state = check_random_state(0)
     y_true = random_state.randint(0, 2, size=(20, ))
@@ -382,6 +385,7 @@ def test_sample_order_invariance():
                                     % name)
 
 
+@ignore_warnings
 def test_sample_order_invariance_multilabel_and_multioutput():
     random_state = check_random_state(0)
 
@@ -421,6 +425,7 @@ def test_sample_order_invariance_multilabel_and_multioutput():
                                     % name)
 
 
+@ignore_warnings
 def test_format_invariance_with_1d_vectors():
     random_state = check_random_state(0)
     y1 = random_state.randint(0, 2, size=(20, ))
@@ -499,6 +504,7 @@ def test_format_invariance_with_1d_vectors():
             assert_raises(ValueError, metric, y1_row, y2_row)
 
 
+@ignore_warnings
 def test_invariance_string_vs_numbers_labels():
     """Ensure that classification metrics with string labels"""
     random_state = check_random_state(0)
@@ -627,6 +633,7 @@ def test_multioutput_regression_invariance_to_dimension_shuffling():
                                         "invariant" % name)
 
 
+@ignore_warnings
 def test_multilabel_representation_invariance():
     # Generate some data
     n_classes = 4
diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py
index aef0a33dc4a59a8dfe0cdd57491dcd2ff8d17789..50942c77653a53cfedd5ac2a4e230ae1685eab46 100644
--- a/sklearn/metrics/tests/test_score_objects.py
+++ b/sklearn/metrics/tests/test_score_objects.py
@@ -10,10 +10,10 @@ from sklearn.utils.testing import ignore_warnings
 from sklearn.utils.testing import assert_not_equal
 
 from sklearn.metrics import (f1_score, r2_score, roc_auc_score, fbeta_score,
-                             log_loss)
+                             log_loss, precision_score, recall_score)
 from sklearn.metrics.cluster import adjusted_rand_score
 from sklearn.metrics.scorer import check_scoring
-from sklearn.metrics import make_scorer, SCORERS
+from sklearn.metrics import make_scorer, get_scorer, SCORERS
 from sklearn.svm import LinearSVC
 from sklearn.cluster import KMeans
 from sklearn.dummy import DummyRegressor
@@ -30,11 +30,17 @@ from sklearn.multiclass import OneVsRestClassifier
 
 REGRESSION_SCORERS = ['r2', 'mean_absolute_error', 'mean_squared_error',
                       'median_absolute_error']
-CLF_SCORERS = ['accuracy', 'f1', 'roc_auc', 'average_precision', 'precision',
-               'recall', 'log_loss',
+
+CLF_SCORERS = ['accuracy', 'f1', 'f1_weighted', 'f1_macro', 'f1_micro',
+               'roc_auc', 'average_precision', 'precision',
+               'precision_weighted', 'precision_macro', 'precision_micro',
+               'recall', 'recall_weighted', 'recall_macro', 'recall_micro',
+               'log_loss',
                'adjusted_rand_score'  # not really, but works
                ]
 
+MULTILABEL_ONLY_SCORERS = ['precision_samples', 'recall_samples', 'f1_samples']
+
 
 class EstimatorWithoutFit(object):
     """Dummy estimator to test check_scoring"""
@@ -107,18 +113,38 @@ def test_make_scorer():
 
 def test_classification_scores():
     """Test classification scorers."""
-    X, y = make_blobs(random_state=0)
+    X, y = make_blobs(random_state=0, centers=2)
     X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
     clf = LinearSVC(random_state=0)
     clf.fit(X_train, y_train)
-    score1 = SCORERS['f1'](clf, X_test, y_test)
-    score2 = f1_score(y_test, clf.predict(X_test))
-    assert_almost_equal(score1, score2)
+
+    for prefix, metric in [('f1', f1_score), ('precision', precision_score),
+                           ('recall', recall_score)]:
+
+        score1 = get_scorer('%s_weighted' % prefix)(clf, X_test, y_test)
+        score2 = metric(y_test, clf.predict(X_test), pos_label=None,
+                        average='weighted')
+        assert_almost_equal(score1, score2)
+
+        score1 = get_scorer('%s_macro' % prefix)(clf, X_test, y_test)
+        score2 = metric(y_test, clf.predict(X_test), pos_label=None,
+                        average='macro')
+        assert_almost_equal(score1, score2)
+
+        score1 = get_scorer('%s_micro' % prefix)(clf, X_test, y_test)
+        score2 = metric(y_test, clf.predict(X_test), pos_label=None,
+                        average='micro')
+        assert_almost_equal(score1, score2)
+
+        score1 = get_scorer('%s' % prefix)(clf, X_test, y_test)
+        score2 = metric(y_test, clf.predict(X_test), pos_label=1)
+        assert_almost_equal(score1, score2)
 
     # test fbeta score that takes an argument
     scorer = make_scorer(fbeta_score, beta=2)
     score1 = scorer(clf, X_test, y_test)
-    score2 = fbeta_score(y_test, clf.predict(X_test), beta=2)
+    score2 = fbeta_score(y_test, clf.predict(X_test), beta=2,
+                         average='weighted')
     assert_almost_equal(score1, score2)
 
     # test that custom scorer can be pickled
@@ -137,7 +163,7 @@ def test_regression_scorers():
     X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
     clf = Ridge()
     clf.fit(X_train, y_train)
-    score1 = SCORERS['r2'](clf, X_test, y_test)
+    score1 = get_scorer('r2')(clf, X_test, y_test)
     score2 = r2_score(y_test, clf.predict(X_test))
     assert_almost_equal(score1, score2)
 
@@ -148,20 +174,20 @@ def test_thresholded_scorers():
     X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
     clf = LogisticRegression(random_state=0)
     clf.fit(X_train, y_train)
-    score1 = SCORERS['roc_auc'](clf, X_test, y_test)
+    score1 = get_scorer('roc_auc')(clf, X_test, y_test)
     score2 = roc_auc_score(y_test, clf.decision_function(X_test))
     score3 = roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1])
     assert_almost_equal(score1, score2)
     assert_almost_equal(score1, score3)
 
-    logscore = SCORERS['log_loss'](clf, X_test, y_test)
+    logscore = get_scorer('log_loss')(clf, X_test, y_test)
     logloss = log_loss(y_test, clf.predict_proba(X_test))
     assert_almost_equal(-logscore, logloss)
 
     # same for an estimator without decision_function
     clf = DecisionTreeClassifier()
     clf.fit(X_train, y_train)
-    score1 = SCORERS['roc_auc'](clf, X_test, y_test)
+    score1 = get_scorer('roc_auc')(clf, X_test, y_test)
     score2 = roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1])
     assert_almost_equal(score1, score2)
 
@@ -169,7 +195,7 @@ def test_thresholded_scorers():
     X, y = make_blobs(random_state=0, centers=3)
     X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
     clf.fit(X_train, y_train)
-    assert_raises(ValueError, SCORERS['roc_auc'], clf, X_test, y_test)
+    assert_raises(ValueError, get_scorer('roc_auc'), clf, X_test, y_test)
 
 
 def test_thresholded_scorers_multilabel_indicator_data():
@@ -185,7 +211,7 @@ def test_thresholded_scorers_multilabel_indicator_data():
     clf = DecisionTreeClassifier()
     clf.fit(X_train, y_train)
     y_proba = clf.predict_proba(X_test)
-    score1 = SCORERS['roc_auc'](clf, X_test, y_test)
+    score1 = get_scorer('roc_auc')(clf, X_test, y_test)
     score2 = roc_auc_score(y_test, np.vstack(p[:, -1] for p in y_proba).T)
     assert_almost_equal(score1, score2)
 
@@ -198,21 +224,21 @@ def test_thresholded_scorers_multilabel_indicator_data():
     clf.decision_function = lambda X: [p[:, 1] for p in clf._predict_proba(X)]
 
     y_proba = clf.decision_function(X_test)
-    score1 = SCORERS['roc_auc'](clf, X_test, y_test)
+    score1 = get_scorer('roc_auc')(clf, X_test, y_test)
     score2 = roc_auc_score(y_test, np.vstack(p for p in y_proba).T)
     assert_almost_equal(score1, score2)
 
     # Multilabel predict_proba
     clf = OneVsRestClassifier(DecisionTreeClassifier())
     clf.fit(X_train, y_train)
-    score1 = SCORERS['roc_auc'](clf, X_test, y_test)
+    score1 = get_scorer('roc_auc')(clf, X_test, y_test)
     score2 = roc_auc_score(y_test, clf.predict_proba(X_test))
     assert_almost_equal(score1, score2)
 
     # Multilabel decision function
     clf = OneVsRestClassifier(LinearSVC(random_state=0))
     clf.fit(X_train, y_train)
-    score1 = SCORERS['roc_auc'](clf, X_test, y_test)
+    score1 = get_scorer('roc_auc')(clf, X_test, y_test)
     score2 = roc_auc_score(y_test, clf.decision_function(X_test))
     assert_almost_equal(score1, score2)
 
@@ -224,7 +250,7 @@ def test_unsupervised_scorers():
     X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
     km = KMeans(n_clusters=3)
     km.fit(X_train)
-    score1 = SCORERS['adjusted_rand_score'](km, X_test, y_test)
+    score1 = get_scorer('adjusted_rand_score')(km, X_test, y_test)
     score2 = adjusted_rand_score(y_test, km.predict(X_test))
     assert_almost_equal(score1, score2)
 
@@ -242,6 +268,7 @@ def test_raises_on_score_list():
     assert_raises(ValueError, grid_search.fit, X, y)
 
 
+@ignore_warnings
 def test_scorer_sample_weight():
     """Test that scorers support sample_weight or raise sensible errors"""
 
@@ -249,7 +276,12 @@ def test_scorer_sample_weight():
     # to ensure that, on the classifier output, weighted and unweighted
     # scores really should be unequal.
     X, y = make_classification(random_state=0)
-    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+    _, y_ml = make_multilabel_classification(n_samples=X.shape[0],
+                                             return_indicator=True,
+                                             random_state=0)
+    split = train_test_split(X, y, y_ml, random_state=0)
+    X_train, X_test, y_train, y_test, y_ml_train, y_ml_test = split
+
     sample_weight = np.ones_like(y_test)
     sample_weight[:10] = 0
 
@@ -258,17 +290,25 @@ def test_scorer_sample_weight():
     sensible_regr.fit(X_train, y_train)
     sensible_clf = DecisionTreeClassifier()
     sensible_clf.fit(X_train, y_train)
+    sensible_ml_clf = DecisionTreeClassifier()
+    sensible_ml_clf.fit(X_train, y_ml_train)
     estimator = dict([(name, sensible_regr)
                       for name in REGRESSION_SCORERS] +
                      [(name, sensible_clf)
-                      for name in CLF_SCORERS])
+                      for name in CLF_SCORERS] +
+                     [(name, sensible_ml_clf)
+                      for name in MULTILABEL_ONLY_SCORERS])
 
     for name, scorer in SCORERS.items():
+        if name in MULTILABEL_ONLY_SCORERS:
+            target = y_ml_test
+        else:
+            target = y_test
         try:
-            weighted = scorer(estimator[name], X_test, y_test,
+            weighted = scorer(estimator[name], X_test, target,
                               sample_weight=sample_weight)
-            ignored = scorer(estimator[name], X_test[10:], y_test[10:])
-            unweighted = scorer(estimator[name], X_test, y_test)
+            ignored = scorer(estimator[name], X_test[10:], target[10:])
+            unweighted = scorer(estimator[name], X_test, target)
             assert_not_equal(weighted, unweighted,
                              msg="scorer {0} behaves identically when "
                              "called with sample weights: {1} vs "
diff --git a/sklearn/svm/tests/test_svm.py b/sklearn/svm/tests/test_svm.py
index 79e0ec207dd01ab76ce1859826d58e65f95e1a02..844f35fb4766381a3675d40efdaafc6e3a48aebc 100644
--- a/sklearn/svm/tests/test_svm.py
+++ b/sklearn/svm/tests/test_svm.py
@@ -392,8 +392,9 @@ def test_auto_weight():
         y_pred = clf.fit(X[unbalanced], y[unbalanced]).predict(X)
         clf.set_params(class_weight='auto')
         y_pred_balanced = clf.fit(X[unbalanced], y[unbalanced],).predict(X)
-        assert_true(metrics.f1_score(y, y_pred)
-                    <= metrics.f1_score(y, y_pred_balanced))
+        assert_true(metrics.f1_score(y, y_pred, average='weighted')
+                    <= metrics.f1_score(y, y_pred_balanced,
+                                        average='weighted'))
 
 
 def test_bad_input():
diff --git a/sklearn/tests/test_cross_validation.py b/sklearn/tests/test_cross_validation.py
index b60e5fcd8e008603b6d0f92e210152986c558bb3..116ee6b29e86ed79aa2f59912315b797a08cb5b9 100644
--- a/sklearn/tests/test_cross_validation.py
+++ b/sklearn/tests/test_cross_validation.py
@@ -683,7 +683,7 @@ def test_cross_val_score_with_score_func_classification():
     # F1 score (classes are balanced so f1_score should be equal to the
     # zero-one score)
     f1_scores = cval.cross_val_score(clf, iris.data, iris.target,
-                                     scoring="f1", cv=5)
+                                     scoring="f1_weighted", cv=5)
     assert_array_almost_equal(f1_scores, [0.97, 1., 0.97, 0.97, 1.], 2)
 
 
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 19618b96a9fea8d8675970edb87e3d489f844b71..af08a22233ea02ce2c9b2493be4897b5fcf2f370 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -742,8 +742,8 @@ def check_class_weight_auto_classifiers(name, Classifier, X_train, y_train,
     classifier.set_params(class_weight='auto')
     classifier.fit(X_train, y_train)
     y_pred_auto = classifier.predict(X_test)
-    assert_greater(f1_score(y_test, y_pred_auto),
-                   f1_score(y_test, y_pred))
+    assert_greater(f1_score(y_test, y_pred_auto, average='weighted'),
+                   f1_score(y_test, y_pred, average='weighted'))
 
 
 def check_class_weight_auto_linear_classifier(name, Classifier):