diff --git a/sklearn/metrics/classification.py b/sklearn/metrics/classification.py
index 50d3d3b7523f0a52b6ec6b00026ae3a53e14657d..2a7be716ee48da5139816da2bb1c7463799d3ce3 100644
--- a/sklearn/metrics/classification.py
+++ b/sklearn/metrics/classification.py
@@ -623,9 +623,10 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
            parameter *labels* improved for multiclass problem.
 
     pos_label : str or int, 1 by default
-        The class to report if ``average='binary'``. Until version 0.18 it is
-        necessary to set ``pos_label=None`` if seeking to use another averaging
-        method over binary targets.
+        The class to report if ``average='binary'`` and the data is binary.
+        If the data is multiclass or multilabel, this will be ignored;
+        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
+        scores for that label only.
 
     average : string, [None, 'binary' (default), 'micro', 'macro', 'samples', \
                        'weighted']
@@ -652,10 +653,6 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
             meaningful for multilabel classification where this differs from
             :func:`accuracy_score`).
 
-        Note that if ``pos_label`` is given in binary classification with
-        `average != 'binary'`, only that positive class is reported. This
-        behavior is deprecated and will change in version 0.18.
-
     sample_weight : array-like of shape = [n_samples], optional
         Sample weights.
 
@@ -729,9 +726,10 @@ def fbeta_score(y_true, y_pred, beta, labels=None, pos_label=1,
            parameter *labels* improved for multiclass problem.
 
     pos_label : str or int, 1 by default
-        The class to report if ``average='binary'``. Until version 0.18 it is
-        necessary to set ``pos_label=None`` if seeking to use another averaging
-        method over binary targets.
+        The class to report if ``average='binary'`` and the data is binary.
+        If the data is multiclass or multilabel, this will be ignored;
+        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
+        scores for that label only.
 
     average : string, [None, 'binary' (default), 'micro', 'macro', 'samples', \
                        'weighted']
@@ -758,10 +756,6 @@ def fbeta_score(y_true, y_pred, beta, labels=None, pos_label=1,
             meaningful for multilabel classification where this differs from
             :func:`accuracy_score`).
 
-        Note that if ``pos_label`` is given in binary classification with
-        `average != 'binary'`, only that positive class is reported. This
-        behavior is deprecated and will change in version 0.18.
-
     sample_weight : array-like of shape = [n_samples], optional
         Sample weights.
 
@@ -905,9 +899,10 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
         ``y_pred`` are used in sorted order.
 
     pos_label : str or int, 1 by default
-        The class to report if ``average='binary'``. Until version 0.18 it is
-        necessary to set ``pos_label=None`` if seeking to use another averaging
-        method over binary targets.
+        The class to report if ``average='binary'`` and the data is binary.
+        If the data is multiclass or multilabel, this will be ignored;
+        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
+        scores for that label only.
 
     average : string, [None (default), 'binary', 'micro', 'macro', 'samples', \
                        'weighted']
@@ -933,10 +928,6 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
             meaningful for multilabel classification where this differs from
             :func:`accuracy_score`).
 
-        Note that if ``pos_label`` is given in binary classification with
-        `average != 'binary'`, only that positive class is reported. This
-        behavior is deprecated and will change in version 0.18.
-
     warn_for : tuple or set, for internal use
         This determines which warnings will be made in the case that this
         function is being used to return only one of its metrics.
@@ -1008,25 +999,8 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
     present_labels = unique_labels(y_true, y_pred)
 
-    if average == 'binary' and (y_type != 'binary' or pos_label is None):
-        warnings.warn('The default `weighted` averaging is deprecated, '
-                      'and from version 0.18, use of precision, recall or '
-                      'F-score with multiclass or multilabel data or '
-                      'pos_label=None will result in an exception. '
-                      'Please set an explicit value for `average`, one of '
-                      '%s. In cross validation use, for instance, '
-                      'scoring="f1_weighted" instead of scoring="f1".'
-                      % str(average_options), DeprecationWarning, stacklevel=2)
-        average = 'weighted'
-
-    if y_type == 'binary' and pos_label is not None and average is not None:
-        if average != 'binary':
-            warnings.warn('From version 0.18, binary input will not be '
-                          'handled specially when using averaged '
-                          'precision/recall/F-score. '
-                          'Please use average=\'binary\' to report only the '
-                          'positive class performance.', DeprecationWarning)
-        if labels is None or len(labels) <= 2:
+    if average == 'binary':
+        if y_type == 'binary':
             if pos_label not in present_labels:
                 if len(present_labels) < 2:
                     # Only negative labels
@@ -1035,6 +1009,15 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
                     raise ValueError("pos_label=%r is not a valid label: %r" %
                                      (pos_label, present_labels))
             labels = [pos_label]
+        else:
+            raise ValueError("Target is %s but average='binary'. Please "
+                             "choose another average setting." % y_type)
+    elif pos_label not in (None, 1):
+        warnings.warn("Note that pos_label (set to %r) is ignored when "
+                      "average != 'binary' (got %r). You may use "
+                      "labels=[pos_label] to specify a single positive class."
+                      % (pos_label, average), UserWarning)
+
     if labels is None:
         labels = present_labels
         n_labels = None
@@ -1187,9 +1170,10 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1,
            parameter *labels* improved for multiclass problem.
 
     pos_label : str or int, 1 by default
-        The class to report if ``average='binary'``. Until version 0.18 it is
-        necessary to set ``pos_label=None`` if seeking to use another averaging
-        method over binary targets.
+        The class to report if ``average='binary'`` and the data is binary.
+        If the data is multiclass or multilabel, this will be ignored;
+        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
+        scores for that label only.
 
     average : string, [None, 'binary' (default), 'micro', 'macro', 'samples', \
                        'weighted']
@@ -1216,10 +1200,6 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1,
             meaningful for multilabel classification where this differs from
             :func:`accuracy_score`).
 
-        Note that if ``pos_label`` is given in binary classification with
-        `average != 'binary'`, only that positive class is reported. This
-        behavior is deprecated and will change in version 0.18.
-
     sample_weight : array-like of shape = [n_samples], optional
         Sample weights.
 
@@ -1289,9 +1269,10 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
            parameter *labels* improved for multiclass problem.
 
     pos_label : str or int, 1 by default
-        The class to report if ``average='binary'``. Until version 0.18 it is
-        necessary to set ``pos_label=None`` if seeking to use another averaging
-        method over binary targets.
+        The class to report if ``average='binary'`` and the data is binary.
+        If the data is multiclass or multilabel, this will be ignored;
+        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
+        scores for that label only.
 
     average : string, [None, 'binary' (default), 'micro', 'macro', 'samples', \
                        'weighted']
@@ -1318,10 +1299,6 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
             meaningful for multilabel classification where this differs from
             :func:`accuracy_score`).
 
-        Note that if ``pos_label`` is given in binary classification with
-        `average != 'binary'`, only that positive class is reported. This
-        behavior is deprecated and will change in version 0.18.
-
     sample_weight : array-like of shape = [n_samples], optional
         Sample weights.
 
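A minimal sketch of the behaviour described by the docstring and validation
changes above, assuming a scikit-learn build that includes this patch; the
data, variable names and printed values are illustrative only:

    import warnings

    from sklearn.metrics import f1_score

    # Binary targets with the default average='binary' report the positive
    # class only; pos_label picks which class counts as positive.
    y_true_bin = [0, 1, 1, 0, 1]
    y_pred_bin = [0, 1, 0, 0, 1]
    f1_score(y_true_bin, y_pred_bin)               # F1 of class 1
    f1_score(y_true_bin, y_pred_bin, pos_label=0)  # F1 of class 0

    # Multiclass targets with the default average='binary' now raise a
    # ValueError instead of silently falling back to 'weighted' averaging.
    y_true_mc = [1, 2, 3, 3]
    y_pred_mc = [1, 2, 3, 1]
    try:
        f1_score(y_true_mc, y_pred_mc)
    except ValueError as exc:
        print(exc)  # Target is multiclass but average='binary'. ...

    # An explicit average works as before, and labels=[pos_label] restricts
    # the reported score to a single class when not using average='binary'.
    f1_score(y_true_mc, y_pred_mc, average='macro')
    f1_score(y_true_mc, y_pred_mc, labels=[3], average='macro')

    # A non-default pos_label is ignored, with a UserWarning, whenever
    # average != 'binary'.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        f1_score(y_true_bin, y_pred_bin, pos_label=0, average='macro')
    assert any(w.category is UserWarning for w in caught)
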
diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py
index 576c3a362b29fb1e734223f48578bf1163ff9a19..dc8a7c0686b591686747bf815d552ce8a66bf674 100644
--- a/sklearn/metrics/tests/test_classification.py
+++ b/sklearn/metrics/tests/test_classification.py
@@ -130,10 +130,8 @@ def test_precision_recall_f1_score_binary():
     # individual scoring function that can be used for grid search: in the
     # binary class case the score is the value of the measure for the positive
     # class (e.g. label == 1). This is deprecated for average != 'binary'.
-    assert_dep_warning = partial(assert_warns, DeprecationWarning)
     for kwargs, my_assert in [({}, assert_no_warnings),
-                              ({'average': 'binary'}, assert_no_warnings),
-                              ({'average': 'micro'}, assert_dep_warning)]:
+                              ({'average': 'binary'}, assert_no_warnings)]:
         ps = my_assert(precision_score, y_true, y_pred, **kwargs)
         assert_array_almost_equal(ps, 0.85, 2)
 
@@ -273,13 +271,24 @@ def test_precision_recall_fscore_support_errors():
 
     # Bad pos_label
     assert_raises(ValueError, precision_recall_fscore_support,
-                  y_true, y_pred, pos_label=2, average='macro')
+                  y_true, y_pred, pos_label=2, average='binary')
 
     # Bad average option
     assert_raises(ValueError, precision_recall_fscore_support,
                   [0, 1, 2], [1, 2, 0], average='mega')
 
 
+def test_precision_recall_f_unused_pos_label():
+    # Check that a warning is raised when pos_label is set to a non-default
+    # value but average != 'binary', even if the data itself is binary.
+    assert_warns_message(UserWarning,
+                         "Note that pos_label (set to 2) is "
+                         "ignored when average != 'binary' (got 'macro'). You "
+                         "may use labels=[pos_label] to specify a single "
+                         "positive class.", precision_recall_fscore_support,
+                         [1, 2, 1], [1, 2, 2], pos_label=2, average='macro')
+
+
 def test_confusion_matrix_binary():
     # Test confusion matrix - binary classification case
     y_true, y_pred, _ = make_prediction(binary=True)
@@ -458,17 +467,24 @@ def test_precision_refcall_f1_score_multilabel_unordered_labels():
             assert_array_equal(s, [0, 1, 1, 0])
 
 
-def test_precision_recall_f1_score_multiclass_pos_label_none():
-    # Test Precision Recall and F1 Score for multiclass classification task
-    # GH Issue #1296
-    # initialize data
+def test_precision_recall_f1_score_binary_averaged():
     y_true = np.array([0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1])
     y_pred = np.array([1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1])
 
     # compute scores with default labels introspection
-    p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
-                                                 pos_label=None,
+    ps, rs, fs, _ = precision_recall_fscore_support(y_true, y_pred,
+                                                    average=None)
+    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred,
                                                  average='macro')
+    assert_equal(p, np.mean(ps))
+    assert_equal(r, np.mean(rs))
+    assert_equal(f, np.mean(fs))
+    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred,
+                                                 average='weighted')
+    support = np.bincount(y_true)
+    assert_equal(p, np.average(ps, weights=support))
+    assert_equal(r, np.average(rs, weights=support))
+    assert_equal(f, np.average(fs, weights=support))
 
 
 def test_zero_precision_recall():
@@ -1041,37 +1057,37 @@ def test_prf_warnings():
                'being set to 0.0 in labels with no true samples.')
         my_assert(w, msg, f, [1, 1, 2], [0, 1, 2], average=average)
 
-        # average of per-sample scores
-        msg = ('Precision and F-score are ill-defined and '
-               'being set to 0.0 in samples with no predicted labels.')
-        my_assert(w, msg, f, np.array([[1, 0], [1, 0]]),
-                  np.array([[1, 0], [0, 0]]), average='samples')
+    # average of per-sample scores
+    msg = ('Precision and F-score are ill-defined and '
+           'being set to 0.0 in samples with no predicted labels.')
+    my_assert(w, msg, f, np.array([[1, 0], [1, 0]]),
+              np.array([[1, 0], [0, 0]]), average='samples')
 
-        msg = ('Recall and F-score are ill-defined and '
-               'being set to 0.0 in samples with no true labels.')
-        my_assert(w, msg, f, np.array([[1, 0], [0, 0]]),
-                  np.array([[1, 0], [1, 0]]),
-                  average='samples')
+    msg = ('Recall and F-score are ill-defined and '
+           'being set to 0.0 in samples with no true labels.')
+    my_assert(w, msg, f, np.array([[1, 0], [0, 0]]),
+              np.array([[1, 0], [1, 0]]),
+              average='samples')
 
-        # single score: micro-average
-        msg = ('Precision and F-score are ill-defined and '
-               'being set to 0.0 due to no predicted samples.')
-        my_assert(w, msg, f, np.array([[1, 1], [1, 1]]),
-                  np.array([[0, 0], [0, 0]]), average='micro')
+    # single score: micro-average
+    msg = ('Precision and F-score are ill-defined and '
+           'being set to 0.0 due to no predicted samples.')
+    my_assert(w, msg, f, np.array([[1, 1], [1, 1]]),
+              np.array([[0, 0], [0, 0]]), average='micro')
 
-        msg = ('Recall and F-score are ill-defined and '
-               'being set to 0.0 due to no true samples.')
-        my_assert(w, msg, f, np.array([[0, 0], [0, 0]]),
-                  np.array([[1, 1], [1, 1]]), average='micro')
+    msg = ('Recall and F-score are ill-defined and '
+           'being set to 0.0 due to no true samples.')
+    my_assert(w, msg, f, np.array([[0, 0], [0, 0]]),
+              np.array([[1, 1], [1, 1]]), average='micro')
 
-        # single postive label
-        msg = ('Precision and F-score are ill-defined and '
-               'being set to 0.0 due to no predicted samples.')
-        my_assert(w, msg, f, [1, 1], [-1, -1], average='macro')
+    # single positive label
+    msg = ('Precision and F-score are ill-defined and '
+           'being set to 0.0 due to no predicted samples.')
+    my_assert(w, msg, f, [1, 1], [-1, -1], average='binary')
 
-        msg = ('Recall and F-score are ill-defined and '
-               'being set to 0.0 due to no true samples.')
-        my_assert(w, msg, f, [-1, -1], [1, 1], average='macro')
+    msg = ('Recall and F-score are ill-defined and '
+           'being set to 0.0 due to no true samples.')
+    my_assert(w, msg, f, [-1, -1], [1, 1], average='binary')
 
 
 def test_recall_warnings():
@@ -1128,32 +1144,23 @@ def test_fscore_warnings():
                          'being set to 0.0 due to no true samples.')
 
 
-def test_prf_average_compat():
-    # Ensure warning if f1_score et al.'s average is implicit for multiclass
-    y_true = [1, 2, 3, 3]
-    y_pred = [1, 2, 3, 1]
-    y_true_bin = [0, 1, 1]
-    y_pred_bin = [0, 1, 0]
-
-    for metric in [precision_score, recall_score, f1_score,
-                   partial(fbeta_score, beta=2)]:
-        score = assert_warns(DeprecationWarning, metric, y_true, y_pred)
-        score_weighted = assert_no_warnings(metric, y_true, y_pred,
-                                            average='weighted')
-        assert_equal(score, score_weighted,
-                     'average does not act like "weighted" by default')
-
-        # check binary passes without warning
-        assert_no_warnings(metric, y_true_bin, y_pred_bin)
-
-        # but binary with pos_label=None should behave like multiclass
-        score = assert_warns(DeprecationWarning, metric,
-                             y_true_bin, y_pred_bin, pos_label=None)
-        score_weighted = assert_no_warnings(metric, y_true_bin, y_pred_bin,
-                                            pos_label=None, average='weighted')
-        assert_equal(score, score_weighted,
-                     'average does not act like "weighted" by default with '
-                     'binary data and pos_label=None')
+def test_prf_average_binary_data_non_binary():
+    # Error if the user does not explicitly set a non-binary average mode
+    y_true_mc = [1, 2, 3, 3]
+    y_pred_mc = [1, 2, 3, 1]
+    y_true_ind = np.array([[0, 1, 1], [1, 0, 0], [0, 0, 1]])
+    y_pred_ind = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
+
+    for y_true, y_pred, y_type in [
+        (y_true_mc, y_pred_mc, 'multiclass'),
+        (y_true_ind, y_pred_ind, 'multilabel-indicator'),
+    ]:
+        for metric in [precision_score, recall_score, f1_score,
+                       partial(fbeta_score, beta=2)]:
+            assert_raise_message(ValueError,
+                                 "Target is %s but average='binary'. Please "
+                                 "choose another average setting." % y_type,
+                                 metric, y_true, y_pred)
 
 
 def test__check_targets():
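
The rewritten test_precision_recall_f1_score_binary_averaged above replaces
the old pos_label=None test with direct checks of the averaging identities; a
condensed sketch of the same relationships, reusing the test's own data, is
given below (illustrative only, assuming the patched behaviour):

    import numpy as np

    from sklearn.metrics import precision_recall_fscore_support

    y_true = np.array([0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1])
    y_pred = np.array([1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1])

    # Per-class scores and supports (average=None returns one value per class).
    ps, rs, fs, support = precision_recall_fscore_support(y_true, y_pred,
                                                          average=None)

    # 'macro' is the unweighted mean of the per-class scores ...
    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred,
                                                 average='macro')
    assert np.isclose(p, ps.mean()) and np.isclose(r, rs.mean())
    assert np.isclose(f, fs.mean())

    # ... while 'weighted' weights each class by its support.
    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred,
                                                 average='weighted')
    assert np.isclose(p, np.average(ps, weights=support))
    assert np.isclose(r, np.average(rs, weights=support))
    assert np.isclose(f, np.average(fs, weights=support))
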
diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 4974d29356be80f8dfd7cfd9f0cf0de96b2e9186..fa4c7e8d3124bb19ef513540836225f51c27daa5 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -217,6 +217,13 @@ METRIC_UNDEFINED_BINARY = [
 METRIC_UNDEFINED_MULTICLASS = [
     "brier_score_loss",
     "matthews_corrcoef_score",
+
+    # with default average='binary', multiclass is prohibited
+    "precision_score",
+    "recall_score",
+    "f1_score",
+    "f2_score",
+    "f0.5_score",
 ]
 
 # Metric undefined with "binary" or "multiclass" input
@@ -303,17 +310,15 @@ MULTILABELS_METRICS = [
     "jaccard_similarity_score", "unnormalized_jaccard_similarity_score",
     "zero_one_loss", "unnormalized_zero_one_loss",
 
-    "precision_score", "recall_score", "f1_score", "f2_score", "f0.5_score",
-
     "weighted_f0.5_score", "weighted_f1_score", "weighted_f2_score",
     "weighted_precision_score", "weighted_recall_score",
 
-    "micro_f0.5_score", "micro_f1_score", "micro_f2_score",
-    "micro_precision_score", "micro_recall_score",
-
     "macro_f0.5_score", "macro_f1_score", "macro_f2_score",
     "macro_precision_score", "macro_recall_score",
 
+    "micro_f0.5_score", "micro_f1_score", "micro_f2_score",
+    "micro_precision_score", "micro_recall_score",
+
     "samples_f0.5_score", "samples_f1_score", "samples_f2_score",
     "samples_precision_score", "samples_recall_score",
 ]
@@ -332,7 +337,11 @@ SYMMETRIC_METRICS = [
     "jaccard_similarity_score", "unnormalized_jaccard_similarity_score",
     "zero_one_loss", "unnormalized_zero_one_loss",
 
-    "f1_score", "weighted_f1_score", "micro_f1_score", "macro_f1_score",
+    "f1_score", "micro_f1_score", "macro_f1_score",
+    "weighted_recall_score",
+    # P = R = F = accuracy in multiclass case
+    "micro_f0.5_score", "micro_f1_score", "micro_f2_score",
+    "micro_precision_score", "micro_recall_score",
 
     "matthews_corrcoef_score", "mean_absolute_error", "mean_squared_error",
     "median_absolute_error",
@@ -349,11 +358,8 @@ NOT_SYMMETRIC_METRICS = [
 
     "precision_score", "recall_score", "f2_score", "f0.5_score",
 
-    "weighted_f0.5_score", "weighted_f2_score", "weighted_precision_score",
-    "weighted_recall_score",
-
-    "micro_f0.5_score", "micro_f2_score", "micro_precision_score",
-    "micro_recall_score",
+    "weighted_f0.5_score", "weighted_f1_score", "weighted_f2_score",
+    "weighted_precision_score",
 
     "macro_f0.5_score", "macro_f2_score", "macro_precision_score",
     "macro_recall_score", "log_loss", "hinge_loss"
@@ -382,7 +388,8 @@ def test_symmetry():
     # We shouldn't forget any metrics
     assert_equal(set(SYMMETRIC_METRICS).union(
         NOT_SYMMETRIC_METRICS, THRESHOLDED_METRICS,
-        METRIC_UNDEFINED_BINARY_MULTICLASS), set(ALL_METRICS))
+        METRIC_UNDEFINED_BINARY_MULTICLASS),
+        set(ALL_METRICS))
 
     assert_equal(
         set(SYMMETRIC_METRICS).intersection(set(NOT_SYMMETRIC_METRICS)),
@@ -1062,6 +1069,7 @@ def test_sample_weight_invariance(n_samples=50):
                    y_pred)
 
 
+@ignore_warnings
 def test_no_averaging_labels():
     # test labels argument when not using averaging
     # in multi-class and multi-label cases
@@ -1075,7 +1083,7 @@ def test_no_averaging_labels():
     for name in METRICS_WITH_AVERAGING:
         for y_true, y_pred in [[y_true_multiclass, y_pred_multiclass],
                                [y_true_multilabel, y_pred_multilabel]]:
-            if name not in MULTILABELS_METRICS and y_pred.shape[1] > 0:
+            if name not in MULTILABELS_METRICS and y_pred.ndim > 1:
                 continue
 
             metric = ALL_METRICS[name]
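
The comment added to SYMMETRIC_METRICS above ("P = R = F = accuracy in
multiclass case") rests on the fact that, for single-label multiclass targets,
every misclassified sample contributes exactly one false positive and one
false negative across classes, so the micro-averaged precision, recall and
F-scores all collapse to accuracy. A small sketch with arbitrary illustrative
labels:

    import numpy as np

    from sklearn.metrics import (accuracy_score, f1_score, precision_score,
                                 recall_score)

    y_true = [0, 1, 2, 2, 1, 0, 2]
    y_pred = [0, 2, 2, 1, 1, 0, 2]

    acc = accuracy_score(y_true, y_pred)
    for score in (precision_score, recall_score, f1_score):
        # Micro-averaging pools TP/FP/FN over all classes; with total
        # FP == total FN the resulting scores all equal the accuracy.
        assert np.isclose(score(y_true, y_pred, average='micro'), acc)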