diff --git a/benchmarks/bench_multilabel_metrics.py b/benchmarks/bench_multilabel_metrics.py
index 669d8604e05f01b21663e6a49faad5ad91cd3404..a7b9374126959237be4cedd803e27721f1a95c64 100755
--- a/benchmarks/bench_multilabel_metrics.py
+++ b/benchmarks/bench_multilabel_metrics.py
@@ -81,11 +81,9 @@ def benchmark(metrics=tuple(v for k, v in sorted(METRICS.items())),
     for i, (s, c, d) in enumerate(it):
         _, y_true = make_multilabel_classification(n_samples=s, n_features=1,
                                                    n_classes=c, n_labels=d * c,
-                                                   return_indicator=True,
                                                    random_state=42)
         _, y_pred = make_multilabel_classification(n_samples=s, n_features=1,
                                                    n_classes=c, n_labels=d * c,
-                                                   return_indicator=True,
                                                    random_state=84)
         for j, f in enumerate(formats):
             f_true = f(y_true)
diff --git a/doc/modules/multiclass.rst b/doc/modules/multiclass.rst
index 3fd3b8e073ae01cc727dd4be6def12a9d7926923..996447e9dbe5f347b3cfe663893930071d5ed3a8 100644
--- a/doc/modules/multiclass.rst
+++ b/doc/modules/multiclass.rst
@@ -90,18 +90,26 @@ zero elements, corresponds to the subset of labels. An array such as
 ``np.array([[1, 0, 0], [0, 1, 1], [0, 0, 0]])`` represents label 0 in the first
 sample, labels 1 and 2 in the second sample, and no labels in the third sample.
 
-Producing multilabel data as a list of sets of labels may be more intuitive.
-The transformer :class:`MultiLabelBinarizer <preprocessing.MultiLabelBinarizer>`
-will convert between a collection of collections of labels and the indicator
-format.
+A sample multilabel indicator matrix can be generated with the
+:func:`make_multilabel_classification <sklearn.datasets.make_multilabel_classification>`
+function:
 
   >>> from sklearn.datasets import make_multilabel_classification
   >>> from sklearn.preprocessing import MultiLabelBinarizer
-  >>> X, Y = make_multilabel_classification(n_samples=5, random_state=0,
-  ...                                       return_indicator=False)
-  >>> Y
-  [[2, 3, 4], [2], [0, 1, 3], [0, 1, 2, 3, 4], [0, 1, 2]]
-  >>> MultiLabelBinarizer().fit_transform(Y)
+  >>> X, y = make_multilabel_classification(n_samples=5, random_state=0)
+  >>> y
+  array([[0, 0, 1, 1, 1],
+         [0, 0, 1, 0, 0],
+         [1, 1, 0, 1, 0],
+         [1, 1, 1, 1, 1],
+         [1, 1, 1, 0, 0]])
+
+The :class:`MultiLabelBinarizer <sklearn.preprocessing.MultiLabelBinarizer>`
+transformer can be used to convert between a collection of collections of
+labels and the indicator format.
+
+  >>> y = [[2, 3, 4], [2], [0, 1, 3], [0, 1, 2, 3, 4], [0, 1, 2]]
+  >>> MultiLabelBinarizer().fit_transform(y)
   array([[0, 0, 1, 1, 1],
          [0, 0, 1, 0, 0],
          [1, 1, 0, 1, 0],
diff --git a/examples/datasets/plot_random_multilabel_dataset.py b/examples/datasets/plot_random_multilabel_dataset.py
index 4137a79bf5630ac027851fa8dc76b7c589aeabfb..683ee1d0ccecddebd48e6c000250dd754c875460 100644
--- a/examples/datasets/plot_random_multilabel_dataset.py
+++ b/examples/datasets/plot_random_multilabel_dataset.py
@@ -61,7 +61,6 @@ def plot_2d(ax, n_labels=1, n_classes=3, length=50):
     X, Y, p_c, p_w_c = make_ml_clf(n_samples=150, n_features=2,
                                    n_classes=n_classes, n_labels=n_labels,
                                    length=length, allow_unlabeled=False,
-                                   return_indicator=True,
                                    return_distributions=True,
                                    random_state=RANDOM_SEED)
 
diff --git a/examples/plot_multilabel.py b/examples/plot_multilabel.py
index d6171ade8c51143788a98d7e9310700441fa45be..e566c73d56a35bc3ae45cb299e493d677e45b793 100644
--- a/examples/plot_multilabel.py
+++ b/examples/plot_multilabel.py
@@ -98,7 +98,6 @@ plt.figure(figsize=(8, 6))
 
 X, Y = make_multilabel_classification(n_classes=2, n_labels=1,
                                       allow_unlabeled=True,
-                                      return_indicator=True,
                                       random_state=1)
 
 plot_subfigure(X, Y, 1, "With unlabeled samples + CCA", "cca")
@@ -106,7 +105,6 @@ plot_subfigure(X, Y, 2, "With unlabeled samples + PCA", "pca")
 
 X, Y = make_multilabel_classification(n_classes=2, n_labels=1,
                                       allow_unlabeled=False,
-                                      return_indicator=True,
                                       random_state=1)
 
 plot_subfigure(X, Y, 3, "Without unlabeled samples + CCA", "cca")
diff --git a/sklearn/datasets/samples_generator.py b/sklearn/datasets/samples_generator.py
index b43c6bec7a7b6440bec07672861216bfd036b05c..77ec18fefc0f3ef5ae7ab33f47b7a341db4732d4 100644
--- a/sklearn/datasets/samples_generator.py
+++ b/sklearn/datasets/samples_generator.py
@@ -250,8 +250,7 @@ def make_classification(n_samples=100, n_features=20, n_informative=2,
 
 def make_multilabel_classification(n_samples=100, n_features=20, n_classes=5,
                                    n_labels=2, length=50, allow_unlabeled=True,
-                                   sparse=False, return_indicator=False,
-                                   return_distributions=False,
+                                   sparse=False, return_distributions=False,
                                    random_state=None):
     """Generate a random multilabel classification problem.
 
@@ -295,10 +294,6 @@ def make_multilabel_classification(n_samples=100, n_features=20, n_classes=5,
     sparse : bool, optional (default=False)
         If ``True``, return a sparse feature matrix
 
-    return_indicator : bool, optional (default=False),
-        If ``True``, return ``Y`` in the binary indicator format, else
-        return a tuple of lists of labels.
-
     return_distributions : bool, optional (default=False)
         If ``True``, return the prior class probability and conditional
         probabilities of features given classes, from which the data was
@@ -383,16 +378,8 @@ def make_multilabel_classification(n_samples=100, n_features=20, n_classes=5,
     if not sparse:
         X = X.toarray()
 
-    if return_indicator:
-        lb = MultiLabelBinarizer()
-        Y = lb.fit([range(n_classes)]).transform(Y)
-    else:
-        warnings.warn('Support for the sequence of sequences multilabel '
-                      'representation is being deprecated and replaced with '
-                      'a sparse indicator matrix. '
-                      'return_indicator will default to True from version '
-                      '0.17.',
-                      DeprecationWarning)
+    lb = MultiLabelBinarizer()
+    Y = lb.fit([range(n_classes)]).transform(Y)
 
     if return_distributions:
         return X, Y, p_c, p_w_c
diff --git a/sklearn/datasets/tests/test_samples_generator.py b/sklearn/datasets/tests/test_samples_generator.py
index eb8c996aa53aa5ec3f86af83bd231da500c87f7b..67fab12cb651c2af0f19634af596f2a3cb60d0d8 100644
--- a/sklearn/datasets/tests/test_samples_generator.py
+++ b/sklearn/datasets/tests/test_samples_generator.py
@@ -13,7 +13,6 @@ from sklearn.utils.testing import assert_array_almost_equal
 from sklearn.utils.testing import assert_true
 from sklearn.utils.testing import assert_less
 from sklearn.utils.testing import assert_raises
-from sklearn.utils.testing import assert_warns
 
 from sklearn.datasets import make_classification
 from sklearn.datasets import make_multilabel_classification
@@ -132,23 +131,10 @@ def test_make_classification_informative_features():
                   n_clusters_per_class=2)
 
 
-def test_make_multilabel_classification_return_sequences():
-    for allow_unlabeled, min_length in zip((True, False), (0, 1)):
-        X, Y = assert_warns(DeprecationWarning, make_multilabel_classification,
-                            n_samples=100, n_features=20, n_classes=3,
-                            random_state=0, allow_unlabeled=allow_unlabeled)
-        assert_equal(X.shape, (100, 20), "X shape mismatch")
-        if not allow_unlabeled:
-            assert_equal(max([max(y) for y in Y]), 2)
-        assert_equal(min([len(y) for y in Y]), min_length)
-        assert_true(max([len(y) for y in Y]) <= 3)
-
-
-def test_make_multilabel_classification_return_indicator():
+def test_make_multilabel_classification():
     for allow_unlabeled, min_length in zip((True, False), (0, 1)):
         X, Y = make_multilabel_classification(n_samples=25, n_features=20,
                                               n_classes=3, random_state=0,
-                                              return_indicator=True,
                                               allow_unlabeled=allow_unlabeled)
         assert_equal(X.shape, (25, 20), "X shape mismatch")
         assert_equal(Y.shape, (25, 3), "Y shape mismatch")
@@ -157,8 +143,7 @@ def test_make_multilabel_classification_return_indicator():
     # Also test return_distributions
     X2, Y2, p_c, p_w_c = make_multilabel_classification(
         n_samples=25, n_features=20, n_classes=3, random_state=0,
-        return_indicator=True, allow_unlabeled=allow_unlabeled,
-        return_distributions=True)
+        allow_unlabeled=allow_unlabeled, return_distributions=True)
 
     assert_array_equal(X, X2)
     assert_array_equal(Y, Y2)
diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py
index 60880c416ceb7bd9096ce99ba70a57d43a351d54..e12f52d66d94a3886ff8ab203d54b5033048a0ad 100644
--- a/sklearn/ensemble/tests/test_forest.py
+++ b/sklearn/ensemble/tests/test_forest.py
@@ -474,8 +474,7 @@ def test_random_hasher():
 
 
 def test_random_hasher_sparse_data():
-    X, y = datasets.make_multilabel_classification(return_indicator=True,
-                                                   random_state=0)
+    X, y = datasets.make_multilabel_classification(random_state=0)
     hasher = RandomTreesEmbedding(n_estimators=30, random_state=1)
     X_transformed = hasher.fit_transform(X)
     X_transformed_sparse = hasher.fit_transform(csc_matrix(X))
@@ -662,8 +661,7 @@ def check_sparse_input(name, X, X_sparse, y):
 
 
 def test_sparse_input():
-    X, y = datasets.make_multilabel_classification(return_indicator=True,
-                                                   random_state=0,
+    X, y = datasets.make_multilabel_classification(random_state=0,
                                                    n_samples=40)
 
     for name, sparse_matrix in product(FOREST_ESTIMATORS,
@@ -986,4 +984,4 @@ def test_dtype_convert():
     y = [ch for ch in 'ABCDEFGHIJKLMNOPQRSTU'[:CLASSES]]
 
     result = classifier.fit(X, y).predict(X)
-    assert_array_equal(result, y)
\ No newline at end of file
+    assert_array_equal(result, y)
diff --git a/sklearn/ensemble/tests/test_voting_classifier.py b/sklearn/ensemble/tests/test_voting_classifier.py
index 7d71952c6b427464d7ceb75e276e52401f94a283..fb86d2ec46ea2f76ec8b095ecd49433f563a5dd5 100644
--- a/sklearn/ensemble/tests/test_voting_classifier.py
+++ b/sklearn/ensemble/tests/test_voting_classifier.py
@@ -151,7 +151,6 @@ def test_multilabel():
     """Check if error is raised for multilabel classification."""
     X, y = make_multilabel_classification(n_classes=2, n_labels=1,
                                           allow_unlabeled=False,
-                                          return_indicator=True,
                                           random_state=123)
     clf = OneVsRestClassifier(SVC(kernel='linear'))
 
diff --git a/sklearn/ensemble/tests/test_weight_boosting.py b/sklearn/ensemble/tests/test_weight_boosting.py
index 875b7bade834d8a60c178ec5f75ca436f3eb0294..f58933a0aaed01ee7388eb578909962430591bfe 100755
--- a/sklearn/ensemble/tests/test_weight_boosting.py
+++ b/sklearn/ensemble/tests/test_weight_boosting.py
@@ -316,7 +316,6 @@ def test_sparse_classification():
 
     X, y = datasets.make_multilabel_classification(n_classes=1, n_samples=15,
                                                    n_features=5,
-                                                   return_indicator=True,
                                                    random_state=42)
     # Flatten y to a 1d array
     y = np.ravel(y)
diff --git a/sklearn/metrics/classification.py b/sklearn/metrics/classification.py
index 1bf9b25338337d0fe1a668b2d69c5987a3fc416b..f827dcc26040a82fdfc23dcffabac0f7a7b79580 100644
--- a/sklearn/metrics/classification.py
+++ b/sklearn/metrics/classification.py
@@ -32,7 +32,6 @@ from ..preprocessing import LabelBinarizer, label_binarize
 from ..preprocessing import LabelEncoder
 from ..utils import check_array
 from ..utils import check_consistent_length
-from ..preprocessing import MultiLabelBinarizer
 from ..utils import column_or_1d
 from ..utils.multiclass import unique_labels
 from ..utils.multiclass import type_of_target
@@ -62,8 +61,7 @@ def _check_targets(y_true, y_pred):
 
     Returns
     -------
-    type_true : one of {'multilabel-indicator', 'multilabel-sequences', \
-                        'multiclass', 'binary'}
+    type_true : one of {'multilabel-indicator', 'multiclass', 'binary'}
         The type of the true target data, as output by
         ``utils.multiclass.type_of_target``
 
@@ -87,8 +85,7 @@ def _check_targets(y_true, y_pred):
     y_type = y_type.pop()
 
     # No metrics support "multiclass-multioutput" format
-    if (y_type not in ["binary", "multiclass", "multilabel-indicator",
-                       "multilabel-sequences"]):
+    if (y_type not in ["binary", "multiclass", "multilabel-indicator"]):
         raise ValueError("{0} is not supported".format(y_type))
 
     if y_type in ["binary", "multiclass"]:
@@ -96,12 +93,6 @@ def _check_targets(y_true, y_pred):
         y_pred = column_or_1d(y_pred)
 
     if y_type.startswith('multilabel'):
-        if y_type == 'multilabel-sequences':
-            labels = unique_labels(y_true, y_pred)
-            binarizer = MultiLabelBinarizer(classes=labels, sparse_output=True)
-            y_true = binarizer.fit_transform(y_true)
-            y_pred = binarizer.fit_transform(y_pred)
-
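+        # only indicator-format multilabel input reaches this point;
+        # sequences of sequences are now rejected above as "unknown"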
         y_true = csr_matrix(y_true)
         y_pred = csr_matrix(y_pred)
         y_type = 'multilabel-indicator'
@@ -985,7 +976,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
         labels = np.hstack([labels, np.setdiff1d(present_labels, labels,
                                                  assume_unique=True)])
 
-    ### Calculate tp_sum, pred_sum, true_sum ###
+    # Calculate tp_sum, pred_sum, true_sum
 
     if y_type.startswith('multilabel'):
         sum_axis = 1 if average == 'samples' else 0
@@ -1056,7 +1047,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
         pred_sum = np.array([pred_sum.sum()])
         true_sum = np.array([true_sum.sum()])
 
-    ### Finally, we have all our sufficient statistics. Divide! ###
+    # Finally, we have all our sufficient statistics. Divide!
 
     beta2 = beta ** 2
     with np.errstate(divide='ignore', invalid='ignore'):
@@ -1074,7 +1065,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
                    (beta2 * precision + recall))
         f_score[tp_sum == 0] = 0.0
 
-    ## Average the results ##
+    # Average the results
 
     if average == 'weighted':
         weights = true_sum
diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py
index 983b321e5b8de78fb8f2f1d07538d900cbe7643d..b980d886cdbe3bbfecd8946f91413a89cf8aec65 100644
--- a/sklearn/metrics/tests/test_classification.py
+++ b/sklearn/metrics/tests/test_classification.py
@@ -10,7 +10,6 @@ from sklearn import datasets
 from sklearn import svm
 
 from sklearn.datasets import make_multilabel_classification
-from sklearn.preprocessing import LabelBinarizer, MultiLabelBinarizer
 from sklearn.preprocessing import label_binarize
 from sklearn.utils.fixes import np_version
 from sklearn.utils.validation import check_random_state
@@ -25,7 +24,6 @@ from sklearn.utils.testing import assert_warns
 from sklearn.utils.testing import assert_no_warnings
 from sklearn.utils.testing import assert_warns_message
 from sklearn.utils.testing import assert_not_equal
-from sklearn.utils.testing import ignore_warnings
 
 from sklearn.metrics import accuracy_score
 from sklearn.metrics import average_precision_score
@@ -114,20 +112,6 @@ def test_multilabel_accuracy_score_subset_accuracy():
     assert_equal(accuracy_score(y1, np.zeros(y1.shape)), 0)
     assert_equal(accuracy_score(y2, np.zeros(y1.shape)), 0)
 
-    with ignore_warnings():  # sequence of sequences is deprecated
-        # List of tuple of label
-        y1 = [(1, 2,), (0, 2,)]
-        y2 = [(2,), (0, 2,)]
-
-        assert_equal(accuracy_score(y1, y2), 0.5)
-        assert_equal(accuracy_score(y1, y1), 1)
-        assert_equal(accuracy_score(y2, y2), 1)
-        assert_equal(accuracy_score(y2, [(), ()]), 0)
-        assert_equal(accuracy_score(y1, y2, normalize=False), 1)
-        assert_equal(accuracy_score(y1, y1, normalize=False), 2)
-        assert_equal(accuracy_score(y2, y2, normalize=False), 2)
-        assert_equal(accuracy_score(y2, [(), ()], normalize=False), 0)
-
 
 def test_precision_recall_f1_score_binary():
     # Test Precision Recall and F1 Score for binary classification task
@@ -161,7 +145,6 @@ def test_precision_recall_f1_score_binary():
                             (1 + 2 ** 2) * ps * rs / (2 ** 2 * ps + rs), 2)
 
 
-@ignore_warnings
 def test_precision_recall_f_binary_single_class():
     # Test precision, recall and F1 score behave with a single positive or
     # negative class
@@ -175,10 +158,8 @@ def test_precision_recall_f_binary_single_class():
     assert_equal(0., f1_score([-1, -1], [-1, -1]))
 
 
-@ignore_warnings
 def test_precision_recall_f_extra_labels():
-    """Test handling of explicit additional (not in input) labels to PRF
-    """
+    """Test handling of explicit additional (not in input) labels to PRF"""
     y_true = [1, 3, 3, 2]
     y_pred = [1, 1, 3, 2]
     y_true_bin = label_binarize(y_true, classes=np.arange(5))
@@ -216,7 +197,6 @@ def test_precision_recall_f_extra_labels():
                       labels=np.arange(-1, 4), average=average)
 
 
-@ignore_warnings
 def test_precision_recall_f_ignored_labels():
     """Test a subset of labels may be requested for PRF"""
     y_true = [1, 1, 2, 3]
@@ -278,7 +258,6 @@ def test_average_precision_score_tied_values():
     assert_not_equal(average_precision_score(y_true, y_score), 1.)
 
 
-@ignore_warnings
 def test_precision_recall_fscore_support_errors():
     y_true, y_pred, _ = make_prediction(binary=True)
 
@@ -613,15 +592,20 @@ avg / total       0.51      0.53      0.47        75
         assert_equal(report, expected_report)
 
 
-@ignore_warnings  # sequence of sequences is deprecated
 def test_multilabel_classification_report():
     n_classes = 4
     n_samples = 50
-    make_ml = make_multilabel_classification
-    _, y_true_ll = make_ml(n_features=1, n_classes=n_classes, random_state=0,
-                           n_samples=n_samples)
-    _, y_pred_ll = make_ml(n_features=1, n_classes=n_classes, random_state=1,
-                           n_samples=n_samples)
+
+    _, y_true = make_multilabel_classification(n_features=1,
+                                               n_samples=n_samples,
+                                               n_classes=n_classes,
+                                               random_state=0)
+
+    _, y_pred = make_multilabel_classification(n_features=1,
+                                               n_samples=n_samples,
+                                               n_classes=n_classes,
+                                               random_state=1)
+
     expected_report = """\
              precision    recall  f1-score   support
 
@@ -633,14 +617,8 @@ def test_multilabel_classification_report():
 avg / total       0.45      0.51      0.46       104
 """
 
-    lb = MultiLabelBinarizer()
-    lb.fit([range(4)])
-    y_true_bi = lb.transform(y_true_ll)
-    y_pred_bi = lb.transform(y_pred_ll)
-
-    for y_true, y_pred in [(y_true_ll, y_pred_ll), (y_true_bi, y_pred_bi)]:
-        report = classification_report(y_true, y_pred)
-        assert_equal(report, expected_report)
+    report = classification_report(y_true, y_pred)
+    assert_equal(report, expected_report)
 
 
 def test_multilabel_zero_one_loss_subset():
@@ -656,17 +634,6 @@ def test_multilabel_zero_one_loss_subset():
     assert_equal(zero_one_loss(y1, np.zeros(y1.shape)), 1)
     assert_equal(zero_one_loss(y2, np.zeros(y1.shape)), 1)
 
-    with ignore_warnings():  # sequence of sequences is deprecated
-        # List of tuple of label
-        y1 = [(1, 2,), (0, 2,)]
-        y2 = [(2,), (0, 2,)]
-
-        assert_equal(zero_one_loss(y1, y2), 0.5)
-        assert_equal(zero_one_loss(y1, y1), 0)
-        assert_equal(zero_one_loss(y2, y2), 0)
-        assert_equal(zero_one_loss(y2, [(), ()]), 1)
-        assert_equal(zero_one_loss(y2, [tuple(), (10, )]), 1)
-
 
 def test_multilabel_hamming_loss():
     # Dense label indicator matrix format
@@ -681,19 +648,6 @@ def test_multilabel_hamming_loss():
     assert_equal(hamming_loss(y1, np.zeros(y1.shape)), 4 / 6)
     assert_equal(hamming_loss(y2, np.zeros(y1.shape)), 0.5)
 
-    with ignore_warnings():  # sequence of sequences is deprecated
-        # List of tuple of label
-        y1 = [(1, 2,), (0, 2,)]
-        y2 = [(2,), (0, 2,)]
-
-        assert_equal(hamming_loss(y1, y2), 1 / 6)
-        assert_equal(hamming_loss(y1, y1), 0)
-        assert_equal(hamming_loss(y2, y2), 0)
-        assert_equal(hamming_loss(y2, [(), ()]), 0.75)
-        assert_equal(hamming_loss(y1, [tuple(), (10, )]), 0.625)
-        assert_almost_equal(hamming_loss(y2, [tuple(), (10, )],
-                                         classes=np.arange(11)), 0.1818, 2)
-
 
 def test_multilabel_jaccard_similarity_score():
     # Dense label indicator matrix format
@@ -711,246 +665,202 @@ def test_multilabel_jaccard_similarity_score():
     assert_equal(jaccard_similarity_score(y1, np.zeros(y1.shape)), 0)
     assert_equal(jaccard_similarity_score(y2, np.zeros(y1.shape)), 0)
 
-    with ignore_warnings():  # sequence of sequences is deprecated
-        # List of tuple of label
-        y1 = [(1, 2,), (0, 2,)]
-        y2 = [(2,), (0, 2,)]
 
-        assert_equal(jaccard_similarity_score(y1, y2), 0.75)
-        assert_equal(jaccard_similarity_score(y1, y1), 1)
-        assert_equal(jaccard_similarity_score(y2, y2), 1)
-        assert_equal(jaccard_similarity_score(y2, [(), ()]), 0)
+def test_precision_recall_f1_score_multilabel_1():
+    # Test precision_recall_f1_score on a crafted multilabel example
+    # First crafted example
 
-        # |y3 inter y4 | = [0, 1, 1]
-        # |y3 union y4 | = [2, 1, 3]
-        y3 = [(0,), (1,), (3,)]
-        y4 = [(4,), (4,), (5, 6)]
-        assert_almost_equal(jaccard_similarity_score(y3, y4), 0)
+    y_true = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1]])
+    y_pred = np.array([[0, 1, 0, 0], [0, 1, 0, 0], [1, 0, 1, 0]])
 
-        # |y5 inter y6 | = [0, 1, 1]
-        # |y5 union y6 | = [2, 1, 3]
-        y5 = [(0,), (1,), (2, 3)]
-        y6 = [(1,), (1,), (2, 0)]
+    p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average=None)
 
-        assert_almost_equal(jaccard_similarity_score(y5, y6), (1 + 1 / 3) / 3)
+    # tp = [0, 1, 1, 0]
+    # fn = [1, 0, 0, 1]
+    # fp = [1, 1, 0, 0]
+    # Check per class
 
+    assert_array_almost_equal(p, [0.0, 0.5, 1.0, 0.0], 2)
+    assert_array_almost_equal(r, [0.0, 1.0, 1.0, 0.0], 2)
+    assert_array_almost_equal(f, [0.0, 1 / 1.5, 1, 0.0], 2)
+    assert_array_almost_equal(s, [1, 1, 1, 1], 2)
 
-@ignore_warnings
-def test_precision_recall_f1_score_multilabel_1():
-    # Test precision_recall_f1_score on a crafted multilabel example
-    # First crafted example
-    y_true_ll = [(0,), (1,), (2, 3)]
-    y_pred_ll = [(1,), (1,), (2, 0)]
-    lb = LabelBinarizer()
-    lb.fit([range(4)])
-    y_true_bi = lb.transform(y_true_ll)
-    y_pred_bi = lb.transform(y_pred_ll)
-
-    for y_true, y_pred in [(y_true_ll, y_pred_ll), (y_true_bi, y_pred_bi)]:
-
-        p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
-                                                     average=None)
-        #tp = [0, 1, 1, 0]
-        #fn = [1, 0, 0, 1]
-        #fp = [1, 1, 0, 0]
-        # Check per class
-
-        assert_array_almost_equal(p, [0.0, 0.5, 1.0, 0.0], 2)
-        assert_array_almost_equal(r, [0.0, 1.0, 1.0, 0.0], 2)
-        assert_array_almost_equal(f, [0.0, 1 / 1.5, 1, 0.0], 2)
-        assert_array_almost_equal(s, [1, 1, 1, 1], 2)
-
-        f2 = fbeta_score(y_true, y_pred, beta=2, average=None)
-        support = s
-        assert_array_almost_equal(f2, [0, 0.83, 1, 0], 2)
-
-        # Check macro
-        p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
-                                                     average="macro")
-        assert_almost_equal(p, 1.5 / 4)
-        assert_almost_equal(r, 0.5)
-        assert_almost_equal(f, 2.5 / 1.5 * 0.25)
-        assert_equal(s, None)
-        assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
-                                        average="macro"),
-                            np.mean(f2))
-
-        # Check micro
-        p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
-                                                     average="micro")
-        assert_almost_equal(p, 0.5)
-        assert_almost_equal(r, 0.5)
-        assert_almost_equal(f, 0.5)
-        assert_equal(s, None)
-        assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
-                                        average="micro"),
-                            (1 + 4) * p * r / (4 * p + r))
-
-        # Check weigted
-        p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
-                                                     average="weighted")
-        assert_almost_equal(p, 1.5 / 4)
-        assert_almost_equal(r, 0.5)
-        assert_almost_equal(f, 2.5 / 1.5 * 0.25)
-        assert_equal(s, None)
-        assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
-                                        average="weighted"),
-                            np.average(f2, weights=support))
-        # Check weigted
-        # |h(x_i) inter y_i | = [0, 1, 1]
-        # |y_i| = [1, 1, 2]
-        # |h(x_i)| = [1, 1, 2]
-        p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
-                                                     average="samples")
-        assert_almost_equal(p, 0.5)
-        assert_almost_equal(r, 0.5)
-        assert_almost_equal(f, 0.5)
-        assert_equal(s, None)
-        assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
-                                        average="samples"),
-                            0.5)
+    f2 = fbeta_score(y_true, y_pred, beta=2, average=None)
+    support = s
+    assert_array_almost_equal(f2, [0, 0.83, 1, 0], 2)
+
+    # Check macro
+    p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
+                                                 average="macro")
+    assert_almost_equal(p, 1.5 / 4)
+    assert_almost_equal(r, 0.5)
+    assert_almost_equal(f, 2.5 / 1.5 * 0.25)
+    assert_equal(s, None)
+    assert_almost_equal(fbeta_score(y_true, y_pred, beta=2, average="macro"),
+                        np.mean(f2))
+
+    # Check micro
+    p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
+                                                 average="micro")
+    assert_almost_equal(p, 0.5)
+    assert_almost_equal(r, 0.5)
+    assert_almost_equal(f, 0.5)
+    assert_equal(s, None)
+    assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
+                                    average="micro"),
+                        (1 + 4) * p * r / (4 * p + r))
+
+    # Check weighted
+    p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
+                                                 average="weighted")
+    assert_almost_equal(p, 1.5 / 4)
+    assert_almost_equal(r, 0.5)
+    assert_almost_equal(f, 2.5 / 1.5 * 0.25)
+    assert_equal(s, None)
+    assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
+                                    average="weighted"),
+                        np.average(f2, weights=support))
+    # Check samples
+    # |h(x_i) inter y_i | = [0, 1, 1]
+    # |y_i| = [1, 1, 2]
+    # |h(x_i)| = [1, 1, 2]
+    p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
+                                                 average="samples")
+    assert_almost_equal(p, 0.5)
+    assert_almost_equal(r, 0.5)
+    assert_almost_equal(f, 0.5)
+    assert_equal(s, None)
+    assert_almost_equal(fbeta_score(y_true, y_pred, beta=2, average="samples"),
+                        0.5)
 
 
-@ignore_warnings
 def test_precision_recall_f1_score_multilabel_2():
     # Test precision_recall_f1_score on a crafted multilabel example 2
     # Second crafted example
-    y_true_ll = [(1,), (2,), (2, 3)]
-    y_pred_ll = [(4,), (4,), (2, 1)]
-    lb = LabelBinarizer()
-    lb.fit([range(1, 5)])
-    y_true_bi = lb.transform(y_true_ll)
-    y_pred_bi = lb.transform(y_pred_ll)
-
-    for y_true, y_pred in [(y_true_ll, y_pred_ll), (y_true_bi, y_pred_bi)]:
-        # tp = [ 0.  1.  0.  0.]
-        # fp = [ 1.  0.  0.  2.]
-        # fn = [ 1.  1.  1.  0.]
-
-        p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
-                                                     average=None)
-        assert_array_almost_equal(p, [0.0, 1.0, 0.0, 0.0], 2)
-        assert_array_almost_equal(r, [0.0, 0.5, 0.0, 0.0], 2)
-        assert_array_almost_equal(f, [0.0, 0.66, 0.0, 0.0], 2)
-        assert_array_almost_equal(s, [1, 2, 1, 0], 2)
-
-        f2 = fbeta_score(y_true, y_pred, beta=2, average=None)
-        support = s
-        assert_array_almost_equal(f2, [0, 0.55, 0, 0], 2)
-
-        p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
-                                                     average="micro")
-        assert_almost_equal(p, 0.25)
-        assert_almost_equal(r, 0.25)
-        assert_almost_equal(f, 2 * 0.25 * 0.25 / 0.5)
-        assert_equal(s, None)
-        assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
-                                        average="micro"),
-                            (1 + 4) * p * r / (4 * p + r))
-
-        p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
-                                                     average="macro")
-        assert_almost_equal(p, 0.25)
-        assert_almost_equal(r, 0.125)
-        assert_almost_equal(f, 2 / 12)
-        assert_equal(s, None)
-        assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
-                                        average="macro"),
-                            np.mean(f2))
-
-        p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
-                                                     average="weighted")
-        assert_almost_equal(p, 2 / 4)
-        assert_almost_equal(r, 1 / 4)
-        assert_almost_equal(f, 2 / 3 * 2 / 4)
-        assert_equal(s, None)
-        assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
-                                        average="weighted"),
-                            np.average(f2, weights=support))
-
-        p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
-                                                     average="samples")
-        # Check weigted
-        # |h(x_i) inter y_i | = [0, 0, 1]
-        # |y_i| = [1, 1, 2]
-        # |h(x_i)| = [1, 1, 2]
-
-        assert_almost_equal(p, 1 / 6)
-        assert_almost_equal(r, 1 / 6)
-        assert_almost_equal(f, 2 / 4 * 1 / 3)
-        assert_equal(s, None)
-        assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
-                                        average="samples"),
-                            0.1666, 2)
+    y_true = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 1, 1, 0]])
+    y_pred = np.array([[0, 0, 0, 1], [0, 0, 0, 1], [1, 1, 0, 0]])
+
+    # tp = [ 0.  1.  0.  0.]
+    # fp = [ 1.  0.  0.  2.]
+    # fn = [ 1.  1.  1.  0.]
+
+    p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
+                                                 average=None)
+    assert_array_almost_equal(p, [0.0, 1.0, 0.0, 0.0], 2)
+    assert_array_almost_equal(r, [0.0, 0.5, 0.0, 0.0], 2)
+    assert_array_almost_equal(f, [0.0, 0.66, 0.0, 0.0], 2)
+    assert_array_almost_equal(s, [1, 2, 1, 0], 2)
+
+    f2 = fbeta_score(y_true, y_pred, beta=2, average=None)
+    support = s
+    assert_array_almost_equal(f2, [0, 0.55, 0, 0], 2)
+
+    p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
+                                                 average="micro")
+    assert_almost_equal(p, 0.25)
+    assert_almost_equal(r, 0.25)
+    assert_almost_equal(f, 2 * 0.25 * 0.25 / 0.5)
+    assert_equal(s, None)
+    assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
+                                    average="micro"),
+                        (1 + 4) * p * r / (4 * p + r))
+
+    p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
+                                                 average="macro")
+    assert_almost_equal(p, 0.25)
+    assert_almost_equal(r, 0.125)
+    assert_almost_equal(f, 2 / 12)
+    assert_equal(s, None)
+    assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
+                                    average="macro"),
+                        np.mean(f2))
+
+    p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
+                                                 average="weighted")
+    assert_almost_equal(p, 2 / 4)
+    assert_almost_equal(r, 1 / 4)
+    assert_almost_equal(f, 2 / 3 * 2 / 4)
+    assert_equal(s, None)
+    assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
+                                    average="weighted"),
+                        np.average(f2, weights=support))
+
+    p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
+                                                 average="samples")
+    # Check samples
+    # |h(x_i) inter y_i | = [0, 0, 1]
+    # |y_i| = [1, 1, 2]
+    # |h(x_i)| = [1, 1, 2]
+
+    assert_almost_equal(p, 1 / 6)
+    assert_almost_equal(r, 1 / 6)
+    assert_almost_equal(f, 2 / 4 * 1 / 3)
+    assert_equal(s, None)
+    assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
+                                    average="samples"),
+                        0.1666, 2)
 
 
-@ignore_warnings
 def test_precision_recall_f1_score_with_an_empty_prediction():
-    y_true_ll = [(1,), (0,), (2, 1,)]
-    y_pred_ll = [tuple(), (3,), (2, 1)]
-
-    lb = LabelBinarizer()
-    lb.fit([range(4)])
-    y_true_bi = lb.transform(y_true_ll)
-    y_pred_bi = lb.transform(y_pred_ll)
-
-    for y_true, y_pred in [(y_true_ll, y_pred_ll), (y_true_bi, y_pred_bi)]:
-        # true_pos = [ 0.  1.  1.  0.]
-        # false_pos = [ 0.  0.  0.  1.]
-        # false_neg = [ 1.  1.  0.  0.]
-        p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
-                                                     average=None)
-        assert_array_almost_equal(p, [0.0, 1.0, 1.0, 0.0], 2)
-        assert_array_almost_equal(r, [0.0, 0.5, 1.0, 0.0], 2)
-        assert_array_almost_equal(f, [0.0, 1 / 1.5, 1, 0.0], 2)
-        assert_array_almost_equal(s, [1, 2, 1, 0], 2)
-
-        f2 = fbeta_score(y_true, y_pred, beta=2, average=None)
-        support = s
-        assert_array_almost_equal(f2, [0, 0.55, 1, 0], 2)
-
-        p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
-                                                     average="macro")
-        assert_almost_equal(p, 0.5)
-        assert_almost_equal(r, 1.5 / 4)
-        assert_almost_equal(f, 2.5 / (4 * 1.5))
-        assert_equal(s, None)
-        assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
-                                        average="macro"),
-                            np.mean(f2))
-
-        p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
-                                                     average="micro")
-        assert_almost_equal(p, 2 / 3)
-        assert_almost_equal(r, 0.5)
-        assert_almost_equal(f, 2 / 3 / (2 / 3 + 0.5))
-        assert_equal(s, None)
-        assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
-                                        average="micro"),
-                            (1 + 4) * p * r / (4 * p + r))
-
-        p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
-                                                     average="weighted")
-        assert_almost_equal(p, 3 / 4)
-        assert_almost_equal(r, 0.5)
-        assert_almost_equal(f, (2 / 1.5 + 1) / 4)
-        assert_equal(s, None)
-        assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
-                                        average="weighted"),
-                            np.average(f2, weights=support))
-
-        p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
-                                                     average="samples")
-        # |h(x_i) inter y_i | = [0, 0, 2]
-        # |y_i| = [1, 1, 2]
-        # |h(x_i)| = [0, 1, 2]
-        assert_almost_equal(p, 1 / 3)
-        assert_almost_equal(r, 1 / 3)
-        assert_almost_equal(f, 1 / 3)
-        assert_equal(s, None)
-        assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
-                                        average="samples"),
-                            0.333, 2)
+    y_true = np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 1, 1, 0]])
+    y_pred = np.array([[0, 0, 0, 0], [0, 0, 0, 1], [0, 1, 1, 0]])
+
+    # true_pos = [ 0.  1.  1.  0.]
+    # false_pos = [ 0.  0.  0.  1.]
+    # false_neg = [ 1.  1.  0.  0.]
+    p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
+                                                 average=None)
+    assert_array_almost_equal(p, [0.0, 1.0, 1.0, 0.0], 2)
+    assert_array_almost_equal(r, [0.0, 0.5, 1.0, 0.0], 2)
+    assert_array_almost_equal(f, [0.0, 1 / 1.5, 1, 0.0], 2)
+    assert_array_almost_equal(s, [1, 2, 1, 0], 2)
+
+    f2 = fbeta_score(y_true, y_pred, beta=2, average=None)
+    support = s
+    assert_array_almost_equal(f2, [0, 0.55, 1, 0], 2)
+
+    p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
+                                                 average="macro")
+    assert_almost_equal(p, 0.5)
+    assert_almost_equal(r, 1.5 / 4)
+    assert_almost_equal(f, 2.5 / (4 * 1.5))
+    assert_equal(s, None)
+    assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
+                                    average="macro"),
+                        np.mean(f2))
+
+    p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
+                                                 average="micro")
+    assert_almost_equal(p, 2 / 3)
+    assert_almost_equal(r, 0.5)
+    assert_almost_equal(f, 2 / 3 / (2 / 3 + 0.5))
+    assert_equal(s, None)
+    assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
+                                    average="micro"),
+                        (1 + 4) * p * r / (4 * p + r))
+
+    p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
+                                                 average="weighted")
+    assert_almost_equal(p, 3 / 4)
+    assert_almost_equal(r, 0.5)
+    assert_almost_equal(f, (2 / 1.5 + 1) / 4)
+    assert_equal(s, None)
+    assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
+                                    average="weighted"),
+                        np.average(f2, weights=support))
+
+    p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
+                                                 average="samples")
+    # |h(x_i) inter y_i | = [0, 0, 2]
+    # |y_i| = [1, 1, 2]
+    # |h(x_i)| = [0, 1, 2]
+    assert_almost_equal(p, 1 / 3)
+    assert_almost_equal(r, 1 / 3)
+    assert_almost_equal(f, 1 / 3)
+    assert_equal(s, None)
+    assert_almost_equal(fbeta_score(y_true, y_pred, beta=2,
+                                    average="samples"),
+                        0.333, 2)
 
 
 def test_precision_recall_f1_no_labels():
@@ -995,7 +905,6 @@ def test_precision_recall_f1_no_labels():
 
 
 def test_prf_warnings():
-
     # average of per-label scores
     f, w = precision_recall_fscore_support, UndefinedMetricWarning
     my_assert = assert_warns_message
@@ -1123,12 +1032,10 @@ def test_prf_average_compat():
                      'binary data and pos_label=None')
 
 
-@ignore_warnings  # sequence of sequences is deprecated
 def test__check_targets():
     # Check that _check_targets correctly merges target types, squeezes
     # output and fails if input lengths differ.
     IND = 'multilabel-indicator'
-    SEQ = 'multilabel-sequences'
     MC = 'multiclass'
     BIN = 'binary'
     CNT = 'continuous'
@@ -1139,7 +1046,6 @@ def test__check_targets():
         (IND, np.array([[0, 1, 1], [1, 0, 0], [0, 0, 1]])),
         # must not be considered binary
         (IND, np.array([[0, 1], [1, 0], [1, 1]])),
-        (SEQ, [[2, 3], [1], [3]]),
         (MC, [2, 3, 1]),
         (BIN, [0, 1, 1]),
         (CNT, [0., 1.5, 1.]),
@@ -1153,13 +1059,9 @@ def test__check_targets():
     # (types will be tried in either order)
     EXPECTED = {
         (IND, IND): IND,
-        (SEQ, SEQ): IND,
         (MC, MC): MC,
         (BIN, BIN): BIN,
 
-        (IND, SEQ): None,
-        (MC, SEQ): None,
-        (BIN, SEQ): None,
         (MC, IND): None,
         (BIN, IND): None,
         (BIN, MC): MC,
@@ -1169,18 +1071,15 @@ def test__check_targets():
         (MMC, MMC): None,
         (MCN, MCN): None,
         (IND, CNT): None,
-        (SEQ, CNT): None,
         (MC, CNT): None,
         (BIN, CNT): None,
         (MMC, CNT): None,
         (MCN, CNT): None,
         (IND, MMC): None,
-        (SEQ, MMC): None,
         (MC, MMC): None,
         (BIN, MMC): None,
         (MCN, MMC): None,
         (IND, MCN): None,
-        (SEQ, MCN): None,
         (MC, MCN): None,
         (BIN, MCN): None,
     }
@@ -1200,7 +1099,7 @@ def test__check_targets():
                     _check_targets, y1, y2)
 
             else:
-                if type1 not in (BIN, MC, SEQ, IND):
+                if type1 not in (BIN, MC, IND):
                     assert_raise_message(ValueError,
                                          "{0} is not supported".format(type1),
                                          _check_targets, y1, y2)
@@ -1216,6 +1115,12 @@ def test__check_targets():
                 assert_array_equal(y2out, np.squeeze(y2))
             assert_raises(ValueError, _check_targets, y1[:-1], y2)
 
+    # Make sure the sequence-of-sequences format is not supported
+    y1 = [(1, 2,), (0, 2,)]
+    y2 = [(2,), (0, 2,)]
+    assert_raise_message(ValueError, "unknown is not supported",
+                         _check_targets, y1, y2)
+
 
 def test_hinge_loss_binary():
     y_true = np.array([-1, 1, 1, -1])
diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py
index b7662ecf11b9d4d4134c1ae61e89db9fa84cbc04..091c0592570a6ba21169819fa09b1459997efd35 100644
--- a/sklearn/metrics/tests/test_score_objects.py
+++ b/sklearn/metrics/tests/test_score_objects.py
@@ -236,8 +236,7 @@ def test_thresholded_scorers():
 def test_thresholded_scorers_multilabel_indicator_data():
     # Test that the scorer work with multilabel-indicator format
     # for multilabel and multi-output multi-class classifier
-    X, y = make_multilabel_classification(return_indicator=True,
-                                          allow_unlabeled=False,
+    X, y = make_multilabel_classification(allow_unlabeled=False,
                                           random_state=0)
     X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
 
@@ -311,7 +310,6 @@ def test_scorer_sample_weight():
     # scores really should be unequal.
     X, y = make_classification(random_state=0)
     _, y_ml = make_multilabel_classification(n_samples=X.shape[0],
-                                             return_indicator=True,
                                              random_state=0)
     split = train_test_split(X, y, y_ml, random_state=0)
     X_train, X_test, y_train, y_test, y_ml_train, y_ml_test = split
diff --git a/sklearn/preprocessing/label.py b/sklearn/preprocessing/label.py
index 23db8425fc06fa08c33fc72a22c994bfb8648c42..d64e2a6ae3d733487d1859fe3fe7d05a4fe24844 100644
--- a/sklearn/preprocessing/label.py
+++ b/sklearn/preprocessing/label.py
@@ -9,7 +9,6 @@
 from collections import defaultdict
 import itertools
 import array
-import warnings
 
 import numpy as np
 import scipy.sparse as sp
@@ -20,8 +19,9 @@ from ..utils.fixes import np_version
 from ..utils.fixes import sparse_min_max
 from ..utils.fixes import astype
 from ..utils.fixes import in1d
-from ..utils import deprecated, column_or_1d
+from ..utils import column_or_1d
 from ..utils.validation import check_array
+from ..utils.validation import check_is_fitted
 from ..utils.validation import _num_samples
 from ..utils.multiclass import unique_labels
 from ..utils.multiclass import type_of_target
@@ -215,8 +215,7 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
         Represents the type of the target data as evaluated by
         utils.multiclass.type_of_target. Possible type are 'continuous',
         'continuous-multioutput', 'binary', 'multiclass',
-        'mutliclass-multioutput', 'multilabel-sequences',
-        'multilabel-indicator', and 'unknown'.
+        'multiclass-multioutput', 'multilabel-indicator', and 'unknown'.
 
     multilabel_ : boolean
         True if the transformer was fitted on a multilabel rather than a
@@ -288,20 +287,6 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
         self.pos_label = pos_label
         self.sparse_output = sparse_output
 
-    @property
-    @deprecated("Attribute ``indicator_matrix_`` is deprecated and will be "
-                "removed in 0.17. Use ``y_type_ == 'multilabel-indicator'`` "
-                "instead")
-    def indicator_matrix_(self):
-        return self.y_type_ == 'multilabel-indicator'
-
-    @property
-    @deprecated("Attribute ``multilabel_`` is deprecated and will be removed "
-                "in 0.17. Use ``y_type_.startswith('multilabel')`` "
-                "instead")
-    def multilabel_(self):
-        return self.y_type_.startswith('multilabel')
-
     def _check_fitted(self):
         if not hasattr(self, "classes_"):
             raise ValueError("LabelBinarizer was not fitted yet.")
@@ -348,7 +333,7 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
         Y : numpy array or CSR matrix of shape [n_samples, n_classes]
             Shape will be [n_samples, 1] for binary problems.
         """
-        self._check_fitted()
+        check_is_fitted(self, 'classes_')
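+        # check_is_fitted raises NotFittedError when fit has not been called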
 
         y_is_multilabel = type_of_target(y).startswith('multilabel')
         if y_is_multilabel and not self.y_type_.startswith('multilabel'):
@@ -411,8 +396,7 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
         return y_inv
 
 
-def label_binarize(y, classes, neg_label=0, pos_label=1,
-                   sparse_output=False, multilabel=None):
+def label_binarize(y, classes, neg_label=0, pos_label=1, sparse_output=False):
     """Binarize labels in a one-vs-all fashion
 
     Several regression and binary classification algorithms are
@@ -488,12 +472,6 @@ def label_binarize(y, classes, neg_label=0, pos_label=1,
                          "pos_label={0} and neg_label={1}"
                          "".format(pos_label, neg_label))
 
-    if multilabel is not None:
-        warnings.warn("The multilabel parameter is deprecated as of version "
-                      "0.15 and will be removed in 0.17. The parameter is no "
-                      "longer necessary because the value is automatically "
-                      "inferred.", DeprecationWarning)
-
     # To account for pos_label == 0 in the dense case
     pos_switch = pos_label == 0
     if pos_switch:
@@ -542,16 +520,6 @@ def label_binarize(y, classes, neg_label=0, pos_label=1,
             data.fill(pos_label)
             Y.data = data
 
-    elif y_type == "multilabel-sequences":
-        Y = MultiLabelBinarizer(classes=classes,
-                                sparse_output=sparse_output).fit_transform(y)
-
-        if sp.issparse(Y):
-            Y.data[:] = pos_label
-        else:
-            Y[Y == 1] = pos_label
-        return Y
-
     if not sparse_output:
         Y = Y.toarray()
         Y = astype(Y, int, copy=False)
@@ -664,15 +632,6 @@ def _inverse_binarize_thresholding(y, output_type, classes, threshold):
     elif output_type == "multilabel-indicator":
         return y
 
-    elif output_type == "multilabel-sequences":
-        warnings.warn('Direct support for sequence of sequences multilabel '
-                      'representation will be unavailable from version 0.17. '
-                      'Use sklearn.preprocessing.MultiLabelBinarizer to '
-                      'convert to a label indicator representation.',
-                      DeprecationWarning)
-        mlb = MultiLabelBinarizer(classes=classes).fit([])
-        return mlb.inverse_transform(y)
-
     else:
         raise ValueError("{0} format is not supported".format(output_type))
 
diff --git a/sklearn/preprocessing/tests/test_label.py b/sklearn/preprocessing/tests/test_label.py
index d692fbf3881d9306ca9213a18cd98bfdb469c8f0..788d5f19e10d785fb8a06c2437b001f72d9637b0 100644
--- a/sklearn/preprocessing/tests/test_label.py
+++ b/sklearn/preprocessing/tests/test_label.py
@@ -89,43 +89,6 @@ def test_label_binarizer_unseen_labels():
     assert_array_equal(expected, got)
 
 
-@ignore_warnings
-def test_label_binarizer_column_y():
-    # first for binary classification vs multi-label with 1 possible class
-    # lists are multi-label, array is multi-class :-/
-    inp_list = [[1], [2], [1]]
-    inp_array = np.array(inp_list)
-
-    multilabel_indicator = np.array([[1, 0], [0, 1], [1, 0]])
-    binaryclass_array = np.array([[0], [1], [0]])
-
-    lb_1 = LabelBinarizer()
-    out_1 = lb_1.fit_transform(inp_list)
-
-    lb_2 = LabelBinarizer()
-    out_2 = lb_2.fit_transform(inp_array)
-
-    assert_array_equal(out_1, multilabel_indicator)
-    assert_array_equal(out_2, binaryclass_array)
-
-    # second for multiclass classification vs multi-label with multiple
-    # classes
-    inp_list = [[1], [2], [1], [3]]
-    inp_array = np.array(inp_list)
-
-    # the indicator matrix output is the same in this case
-    indicator = np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]])
-
-    lb_1 = LabelBinarizer()
-    out_1 = lb_1.fit_transform(inp_list)
-
-    lb_2 = LabelBinarizer()
-    out_2 = lb_2.fit_transform(inp_array)
-
-    assert_array_equal(out_1, out_2)
-    assert_array_equal(out_2, indicator)
-
-
 def test_label_binarizer_set_label_encoding():
     lb = LabelBinarizer(neg_label=-2, pos_label=0)
 
@@ -174,6 +137,10 @@ def test_label_binarizer_errors():
                   y=csr_matrix([[1, 2], [2, 1]]), output_type="foo",
                   classes=[1, 2], threshold=0)
 
+    # Sequence of seq type should raise ValueError
+    y_seq_of_seqs = [[], [1, 2], [3], [0, 1, 3], [2]]
+    assert_raises(ValueError, LabelBinarizer().fit_transform, y_seq_of_seqs)
+
     # Fail on the number of classes
     assert_raises(ValueError, _inverse_binarize_thresholding,
                   y=csr_matrix([[1, 2], [2, 1]]), output_type="foo",
diff --git a/sklearn/tests/test_cross_validation.py b/sklearn/tests/test_cross_validation.py
index 59796b10eae2c4c8eabab591af6c616f75f32d79..b33e2b4c279d50d8c441a17a889dc4f516e67977 100644
--- a/sklearn/tests/test_cross_validation.py
+++ b/sklearn/tests/test_cross_validation.py
@@ -37,7 +37,7 @@ from sklearn.neighbors import KNeighborsClassifier
 from sklearn.svm import SVC
 from sklearn.cluster import KMeans
 
-from sklearn.preprocessing import Imputer, LabelBinarizer
+from sklearn.preprocessing import Imputer
 from sklearn.pipeline import Pipeline
 
 
@@ -962,15 +962,8 @@ def test_check_cv_return_types():
     assert_true(isinstance(cv, cval.StratifiedKFold))
 
     X = np.ones((5, 2))
-    y_seq_of_seqs = [[], [1, 2], [3], [0, 1, 3], [2]]
-
-    with warnings.catch_warnings(record=True):
-        # deprecated sequence of sequence format
-        cv = cval.check_cv(3, X, y_seq_of_seqs, classifier=True)
-    assert_true(isinstance(cv, cval.KFold))
-
-    y_indicator_matrix = LabelBinarizer().fit_transform(y_seq_of_seqs)
-    cv = cval.check_cv(3, X, y_indicator_matrix, classifier=True)
+    y_multilabel = [[1, 0, 1], [1, 1, 0], [0, 0, 0], [0, 1, 1], [1, 0, 0]]
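+    # multilabel indicator targets cannot be stratified, so check_cv
+    # falls back to a plain KFold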
+    cv = cval.check_cv(3, X, y_multilabel, classifier=True)
     assert_true(isinstance(cv, cval.KFold))
 
     y_multioutput = np.array([[1, 2], [0, 3], [0, 0], [3, 1], [2, 0]])
diff --git a/sklearn/tests/test_grid_search.py b/sklearn/tests/test_grid_search.py
index 8d9e5b601d2f9988d81eeb1b6e81f253455845d2..bbe248c83356941c524b220d58b2c180c773516b 100644
--- a/sklearn/tests/test_grid_search.py
+++ b/sklearn/tests/test_grid_search.py
@@ -629,8 +629,7 @@ def test_pickle():
 def test_grid_search_with_multioutput_data():
     # Test search with multi-output estimator
 
-    X, y = make_multilabel_classification(return_indicator=True,
-                                          random_state=0)
+    X, y = make_multilabel_classification(random_state=0)
 
     est_parameters = {"max_depth": [1, 2, 3, 4]}
     cv = KFold(y.shape[0], random_state=0)
diff --git a/sklearn/tests/test_multiclass.py b/sklearn/tests/test_multiclass.py
index 39b1547ea6284e659b89c59aef17d22ba7d4e8c7..965f8eb416f40356bc5436202cdf7f5a40f3fe5b 100644
--- a/sklearn/tests/test_multiclass.py
+++ b/sklearn/tests/test_multiclass.py
@@ -107,7 +107,6 @@ def test_ovr_fit_predict_sparse():
                                                        n_labels=3,
                                                        length=50,
                                                        allow_unlabeled=True,
-                                                       return_indicator=True,
                                                        random_state=0)
 
         X_train, Y_train = X[:80], Y[:80]
@@ -230,35 +229,19 @@ def test_ovr_binary():
         conduct_test(base_clf, test_predict_proba=True)
 
 
-@ignore_warnings
 def test_ovr_multilabel():
     # Toy dataset where features correspond directly to labels.
     X = np.array([[0, 4, 5], [0, 5, 0], [3, 3, 3], [4, 0, 6], [6, 0, 0]])
-    y = [["spam", "eggs"], ["spam"], ["ham", "eggs", "spam"],
-         ["ham", "eggs"], ["ham"]]
-    # y = [[1, 2], [1], [0, 1, 2], [0, 2], [0]]
-    Y = np.array([[0, 1, 1],
+    y = np.array([[0, 1, 1],
                   [0, 1, 0],
                   [1, 1, 1],
                   [1, 0, 1],
                   [1, 0, 0]])
 
-    classes = set("ham eggs spam".split())
-
     for base_clf in (MultinomialNB(), LinearSVC(random_state=0),
                      LinearRegression(), Ridge(),
                      ElasticNet(), Lasso(alpha=0.5)):
-        # test input as lists of tuples
-        clf = assert_warns(DeprecationWarning,
-                           OneVsRestClassifier(base_clf).fit,
-                           X, y)
-        assert_equal(set(clf.classes_), classes)
-        y_pred = clf.predict([[0, 4, 4]])[0]
-        assert_equal(set(y_pred), set(["spam", "eggs"]))
-        assert_true(clf.multilabel_)
-
-        # test input as label indicator matrix
-        clf = OneVsRestClassifier(base_clf).fit(X, Y)
+        clf = OneVsRestClassifier(base_clf).fit(X, y)
         y_pred = clf.predict([[0, 4, 4]])[0]
         assert_array_equal(y_pred, [0, 1, 1])
         assert_true(clf.multilabel_)
@@ -280,7 +263,6 @@ def test_ovr_multilabel_dataset():
                                                        n_labels=2,
                                                        length=50,
                                                        allow_unlabeled=au,
-                                                       return_indicator=True,
                                                        random_state=0)
         X_train, Y_train = X[:80], Y[:80]
         X_test, Y_test = X[80:], Y[80:]
@@ -305,7 +287,6 @@ def test_ovr_multilabel_predict_proba():
                                                        n_labels=3,
                                                        length=50,
                                                        allow_unlabeled=au,
-                                                       return_indicator=True,
                                                        random_state=0)
         X_train, Y_train = X[:80], Y[:80]
         X_test = X[80:]
@@ -357,7 +338,6 @@ def test_ovr_multilabel_decision_function():
                                                    n_labels=3,
                                                    length=50,
                                                    allow_unlabeled=True,
-                                                   return_indicator=True,
                                                    random_state=0)
     X_train, Y_train = X[:80], Y[:80]
     X_test = X[80:]
diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py
index 9b6a014eba58123a4a4713009ba57d90682b089d..3cdae723e873f5ae1e91a6159a115d1176dab1db 100644
--- a/sklearn/tree/tests/test_tree.py
+++ b/sklearn/tree/tests/test_tree.py
@@ -126,7 +126,7 @@ digits.target = digits.target[perm]
 
 random_state = check_random_state(0)
 X_multilabel, y_multilabel = datasets.make_multilabel_classification(
-    random_state=0, return_indicator=True, n_samples=30, n_features=10)
+    random_state=0, n_samples=30, n_features=10)
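+# y_multilabel is a dense label indicator matrix (the default output format).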
 
 X_sparse_pos = random_state.uniform(size=(20, 5))
 X_sparse_pos[X_sparse_pos <= 0.8] = 0.
diff --git a/sklearn/utils/tests/test_multiclass.py b/sklearn/utils/tests/test_multiclass.py
index 989b98a878c8ffe05507ff274b452eeba56d05f2..816d636cabb6ca9450a1c7b2e0876376a00f3a17 100644
--- a/sklearn/utils/tests/test_multiclass.py
+++ b/sklearn/utils/tests/test_multiclass.py
@@ -2,8 +2,6 @@ from __future__ import division
 import numpy as np
 import scipy.sparse as sp
 
-from itertools import product
-from functools import partial
 from sklearn.externals.six.moves import xrange
 from sklearn.externals.six import iteritems
 
@@ -20,15 +18,13 @@ from sklearn.utils.testing import assert_equal
 from sklearn.utils.testing import assert_true
 from sklearn.utils.testing import assert_false
 from sklearn.utils.testing import assert_raises
-from sklearn.utils.testing import assert_warns
-from sklearn.utils.testing import ignore_warnings
 
 from sklearn.utils.multiclass import unique_labels
 from sklearn.utils.multiclass import is_label_indicator_matrix
 from sklearn.utils.multiclass import is_multilabel
-from sklearn.utils.multiclass import is_sequence_of_sequences
 from sklearn.utils.multiclass import type_of_target
 from sklearn.utils.multiclass import class_distribution
+from sklearn.utils.multiclass import _is_sequence_of_sequences
 
 
 class NotAnArray(object):
@@ -60,18 +56,6 @@ EXAMPLES = {
         np.array([[-3, 3], [3, -3]]),
         NotAnArray(np.array([[-3, 3], [3, -3]])),
     ],
-    'multilabel-sequences': [
-        [[0, 1]],
-        [[0], [1]],
-        [[1, 2, 3]],
-        [[1, 2, 1]],  # duplicate values, why not?
-        [[1], [2], [0, 1]],
-        [[1], [2]],
-        [[]],
-        [()],
-        np.array([[], [1, 2]], dtype='object'),
-        NotAnArray(np.array([[], [1, 2]], dtype='object')),
-    ],
     'multiclass': [
         [1, 0, 2, 2, 1, 4, 2, 4, 4, 4],
         np.array([1, 0, 2]),
@@ -133,21 +117,41 @@ EXAMPLES = {
         np.array([[0, .5]]),
     ],
     'unknown': [
-        # empty second dimension
-        np.array([[], []]),
-        # 3d
-        np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]),
-        # not currently supported sequence of sequences
+        # multilabel sequences
+        [[0, 1]],
+        [[0], [1]],
+        [[1, 2, 3]],
+        [[1, 2, 1]],  # duplicate labels within a sequence
+        [[1], [2], [0, 1]],
+        [(), (2,), (0, 1)],
+        [[]],
+        [()],
+        np.array([[], [1, 2]], dtype='object'),
+        NotAnArray(np.array([[], [1, 2]], dtype='object')),
+
+        # NOTE: The first 10 items above are sequences of sequences, the
+        # multilabel format that used to be supported.
+        # test_is_sequence_of_sequences splits this list at index 10.
+
+        # Hence, please add further unknown types AFTER these 10 entries.
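+
+        # As a sanity check, each entry of this list should now be reported
+        # as 'unknown' by type_of_target (see test_type_of_target below).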
+
+        # sequence of sequences that weren't supported even before deprecation
         np.array([np.array([]), np.array([1, 2, 3])], dtype=object),
         [np.array([]), np.array([1, 2, 3])],
         [set([1, 2, 3]), set([1, 2])],
         [frozenset([1, 2, 3]), frozenset([1, 2])],
+
         # and also confusable as sequences of sequences
         [{0: 'a', 1: 'b'}, {0: 'a'}],
+
+        # empty second dimension
+        np.array([[], []]),
+
+        # 3d
+        np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]),
     ]
 }
 
-
 NON_ARRAY_LIKE_EXAMPLES = [
     set([1, 2, 3]),
     {0: 'a', 1: 'b'},
@@ -167,16 +171,7 @@ def test_unique_labels():
     assert_array_equal(unique_labels(np.arange(10)), np.arange(10))
     assert_array_equal(unique_labels([4, 0, 2]), np.array([0, 2, 4]))
 
-    # Multilabels
-    assert_array_equal(assert_warns(DeprecationWarning,
-                                    unique_labels,
-                                    [(0, 1, 2), (0,), tuple(), (2, 1)]),
-                       np.arange(3))
-    assert_array_equal(assert_warns(DeprecationWarning,
-                                    unique_labels,
-                                    [[0, 1, 2], [0], list(), [2, 1]]),
-                       np.arange(3))
-
+    # Multilabel indicator
     assert_array_equal(unique_labels(np.array([[0, 0, 1],
                                                [1, 0, 1],
                                                [0, 0, 0]])),
@@ -198,22 +193,12 @@ def test_unique_labels():
     assert_array_equal(unique_labels(np.ones((4, 5)), np.ones((5, 5))),
                        np.arange(5))
 
-    # Some tests with strings input
-    assert_array_equal(unique_labels(["a", "b", "c"], ["d"]),
-                       ["a", "b", "c", "d"])
-
-    assert_array_equal(assert_warns(DeprecationWarning, unique_labels,
-                                    [["a", "b"], ["c"]], [["d"]]),
-                       ["a", "b", "c", "d"])
 
-
-@ignore_warnings
 def test_unique_labels_non_specific():
     # Test unique_labels with a variety of collected examples
 
     # Smoke test for all supported format
-    for format in ["binary", "multiclass", "multilabel-sequences",
-                   "multilabel-indicator"]:
+    for format in ["binary", "multiclass", "multilabel-indicator"]:
         for y in EXAMPLES[format]:
             unique_labels(y)
 
@@ -227,38 +212,6 @@ def test_unique_labels_non_specific():
             assert_raises(ValueError, unique_labels, example)
 
 
-@ignore_warnings
-def test_unique_labels_mixed_types():
-    # Mix of multilabel-indicator and multilabel-sequences
-    mix_multilabel_format = product(EXAMPLES["multilabel-indicator"],
-                                    EXAMPLES["multilabel-sequences"])
-    for y_multilabel, y_multiclass in mix_multilabel_format:
-        assert_raises(ValueError, unique_labels, y_multiclass, y_multilabel)
-        assert_raises(ValueError, unique_labels, y_multilabel, y_multiclass)
-
-    # Mix with binary or multiclass and multilabel
-    mix_clf_format = product(EXAMPLES["multilabel-indicator"] +
-                             EXAMPLES["multilabel-sequences"],
-                             EXAMPLES["multiclass"] +
-                             EXAMPLES["binary"])
-
-    for y_multilabel, y_multiclass in mix_clf_format:
-        assert_raises(ValueError, unique_labels, y_multiclass, y_multilabel)
-        assert_raises(ValueError, unique_labels, y_multilabel, y_multiclass)
-
-    # Mix string and number input type
-    assert_raises(ValueError, unique_labels, [[1, 2], [3]],
-                  [["a", "d"]])
-    assert_raises(ValueError, unique_labels, ["1", 2])
-    assert_raises(ValueError, unique_labels, [["1", 2], [3]])
-    assert_raises(ValueError, unique_labels, [["1", "2"], [3]])
-
-    assert_array_equal(unique_labels([(2,), (0, 2,)], [(), ()]), [0, 2])
-    assert_array_equal(unique_labels([("2",), ("0", "2",)], [(), ()]),
-                       ["0", "2"])
-
-
-@ignore_warnings
 def test_is_multilabel():
     for group, group_examples in iteritems(EXAMPLES):
         if group.startswith('multilabel'):
@@ -313,21 +266,20 @@ def test_is_label_indicator_matrix():
 
 def test_is_sequence_of_sequences():
     for group, group_examples in iteritems(EXAMPLES):
-        if group == 'multilabel-sequences':
-            assert_, exp = assert_true, 'True'
-            check = partial(assert_warns, DeprecationWarning,
-                            is_sequence_of_sequences)
-        else:
-            assert_, exp = assert_false, 'False'
-            check = is_sequence_of_sequences
-        for example in group_examples:
-            assert_(check(example),
-                    msg='is_sequence_of_sequences(%r) should be %s'
-                    % (example, exp))
+        for i, example in enumerate(group_examples):
+            # The first 10 entries of EXAMPLES['unknown'] are seq. of seq.
+            if group == "unknown" and i < 10:
+                assert_true(_is_sequence_of_sequences(example),
+                            msg=('_is_sequence_of_sequences(%r) should '
+                                 'be True' % example))
+            else:
+                assert_false(_is_sequence_of_sequences(example),
+                             msg=('_is_sequence_of_sequences(%r) should '
+                                  'be False' % example))
 
 
-@ignore_warnings
 def test_type_of_target():
+    # Sequences of sequences are now included in the 'unknown' group,
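+    # so every such example should be typed as 'unknown' here.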
     for group, group_examples in iteritems(EXAMPLES):
         for example in group_examples:
             assert_equal(type_of_target(example), group,