diff --git a/doc/modules/multiclass.rst b/doc/modules/multiclass.rst
index 235bebcf0558ca1470baee9641d0125d1bab518c..5094372aca96074a6469a1b4f525f923b982d026 100644
--- a/doc/modules/multiclass.rst
+++ b/doc/modules/multiclass.rst
@@ -348,3 +348,30 @@ Below is an example of multioutput classification:
            [0, 0, 2],
            [2, 0, 0]])
 
+Classifier Chain
+================
+
+Classifier chains (see :class:`ClassifierChain`) are a way of combining a
+number of binary classifiers into a single multi-label model that is capable
+ of exploiting correlations among targets.
+
+For a multi-label classification problem with N classes, N binary
+classifiers are assigned an integer between 0 and N-1. These integers
+define the order of models in the chain. Each classifier is then fit on the
+available training data plus the true labels of the classes whose
+models were assigned a lower number.
+
+When predicting, the true labels will not be available. Instead the
+predictions of each model are passed on to the subsequent models in the
+chain to be used as features.
+
+Clearly the order of the chain is important. The first model in the chain
+has no information about the other labels while the last model in the chain
+has features indicating the presence of all of the other labels. In general
+one does not know the optimal ordering of the models in the chain so
+typically many randomly ordered chains are fit and their predictions are
+averaged together.
+
+.. topic:: References:
+    Jesse Read, Bernhard Pfahringer, Geoff Holmes, Eibe Frank,
+        "Classifier Chains for Multi-label Classification", 2009.
\ No newline at end of file
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 28cbf1d6e10ea2b6547721b0f1bf89ee82849075..d367c627c27c434ed39039d72ad0a3d361d55924 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -31,6 +31,9 @@ Changelog
 New features
 ............
 
+   - Added :class:`multioutput.ClassifierChain` for multi-label
+     classification. By `Adam Kleczewski <adamklec>`_.
+
    - Validation that input data contains no NaN or inf can now be suppressed
      using :func:`config_context`, at your own risk. This will save on runtime,
      and may be particularly useful for prediction time. :issue:`7548` by
diff --git a/examples/multioutput/README.txt b/examples/multioutput/README.txt
new file mode 100644
index 0000000000000000000000000000000000000000..57adada325e437a73e346a5d32fba0e83ac63bfb
--- /dev/null
+++ b/examples/multioutput/README.txt
@@ -0,0 +1,6 @@
+.. _multioutput_examples:
+
+Multioutput methods
+----------------
+
+Examples concerning the :mod:`sklearn.multioutput` module.
\ No newline at end of file
diff --git a/examples/multioutput/plot_classifier_chain_yeast.py b/examples/multioutput/plot_classifier_chain_yeast.py
new file mode 100644
index 0000000000000000000000000000000000000000..af649268a6151da0d40e2a6aa741d62e9bf6d3aa
--- /dev/null
+++ b/examples/multioutput/plot_classifier_chain_yeast.py
@@ -0,0 +1,110 @@
+"""
+============================
+Classifier Chain
+============================
+Example of using classifier chain on a multilabel dataset.
+
+For this example we will use the `yeast
+http://mldata.org/repository/data/viewslug/yeast/`_ dataset which
+contains 2417 datapoints each with 103 features and 14 possible labels. Each
+datapoint has at least one label. As a baseline we first train a logistic
+regression classifier for each of the 14 labels. To evaluate the performance
+of these classifiers we predict on a held-out test set and calculate the
+:ref:`User Guide <jaccard_similarity_score>`.
+
+Next we create 10 classifier chains. Each classifier chain contains a
+logistic regression model for each of the 14 labels. The models in each
+chain are ordered randomly. In addition to the 103 features in the dataset,
+each model gets the predictions of the preceding models in the chain as
+features (note that by default at training time each model gets the true
+labels as features). These additional features allow each chain to exploit
+correlations among the classes. The Jaccard similarity score for each chain
+tends to be greater than that of the set independent logistic models.
+
+Because the models in each chain are arranged randomly there is significant
+variation in performance among the chains. Presumably there is an optimal
+ordering of the classes in a chain that will yield the best performance.
+However we do not know that ordering a priori. Instead we can construct an
+voting ensemble of classifier chains by averaging the binary predictions of
+the chains and apply a threshold of 0.5. The Jaccard similarity score of the
+ensemble is greater than that of the independent models and tends to exceed
+the score of each chain in the ensemble (although this is not guaranteed
+with randomly ordered chains).
+"""
+
+print(__doc__)
+
+# Author: Adam Kleczewski
+# License: BSD 3 clause
+
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.multioutput import ClassifierChain
+from sklearn.model_selection import train_test_split
+from sklearn.multiclass import OneVsRestClassifier
+from sklearn.metrics import jaccard_similarity_score
+from sklearn.linear_model import LogisticRegression
+from sklearn.datasets import fetch_mldata
+
+# Load a multi-label dataset
+yeast = fetch_mldata('yeast')
+X = yeast['data']
+Y = yeast['target'].transpose().toarray()
+X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.2,
+                                                    random_state=0)
+
+# Fit an independent logistic regression model for each class using the
+# OneVsRestClassifier wrapper.
+ovr = OneVsRestClassifier(LogisticRegression())
+ovr.fit(X_train, Y_train)
+Y_pred_ovr = ovr.predict(X_test)
+ovr_jaccard_score = jaccard_similarity_score(Y_test, Y_pred_ovr)
+
+# Fit an ensemble of logistic regression classifier chains and take the
+# take the average prediction of all the chains.
+chains = [ClassifierChain(LogisticRegression(), order='random', random_state=i)
+          for i in range(10)]
+for chain in chains:
+    chain.fit(X_train, Y_train)
+
+Y_pred_chains = np.array([chain.predict(X_test) for chain in
+                          chains])
+chain_jaccard_scores = [jaccard_similarity_score(Y_test, Y_pred_chain >= .5)
+                        for Y_pred_chain in Y_pred_chains]
+
+Y_pred_ensemble = Y_pred_chains.mean(axis=0)
+ensemble_jaccard_score = jaccard_similarity_score(Y_test,
+                                                  Y_pred_ensemble >= .5)
+
+model_scores = [ovr_jaccard_score] + chain_jaccard_scores
+model_scores.append(ensemble_jaccard_score)
+
+model_names = ('Independent Models',
+               'Chain 1',
+               'Chain 2',
+               'Chain 3',
+               'Chain 4',
+               'Chain 5',
+               'Chain 6',
+               'Chain 7',
+               'Chain 8',
+               'Chain 9',
+               'Chain 10',
+               'Ensemble Average')
+
+y_pos = np.arange(len(model_names))
+y_pos[1:] += 1
+y_pos[-1] += 1
+
+# Plot the Jaccard similarity scores for the independent model, each of the
+# chains, and the ensemble (note that the vertical axis on this plot does
+# not begin at 0).
+
+fig = plt.figure(figsize=(7, 4))
+plt.title('Classifier Chain Ensemble')
+plt.xticks(y_pos, model_names, rotation='vertical')
+plt.ylabel('Jaccard Similarity Score')
+plt.ylim([min(model_scores) * .9, max(model_scores) * 1.1])
+colors = ['r'] + ['b'] * len(chain_jaccard_scores) + ['g']
+plt.bar(y_pos, model_scores, align='center', alpha=0.5, color=colors)
+plt.show()
diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py
index bdb85ad890a97358db79d0a9eebf964afbc92e81..64e394272ffd7fd9211d2d356a0ebd597cad0f2d 100644
--- a/sklearn/multioutput.py
+++ b/sklearn/multioutput.py
@@ -14,20 +14,23 @@ extends single output estimators to multioutput estimators.
 #
 # License: BSD 3 clause
 
-import numpy as np
+from abc import ABCMeta
 
+import numpy as np
+import scipy.sparse as sp
 from abc import ABCMeta, abstractmethod
 from .base import BaseEstimator, clone, MetaEstimatorMixin
 from .base import RegressorMixin, ClassifierMixin
-from .utils import check_array, check_X_y
+from .model_selection import cross_val_predict
+from .utils import check_array, check_X_y, check_random_state
 from .utils.fixes import parallel_helper
-from .utils.validation import check_is_fitted, has_fit_parameter
 from .utils.metaestimators import if_delegate_has_method
+from .utils.validation import check_is_fitted, has_fit_parameter
 from .utils.multiclass import check_classification_targets
 from .externals.joblib import Parallel, delayed
 from .externals import six
 
-__all__ = ["MultiOutputRegressor", "MultiOutputClassifier"]
+__all__ = ["MultiOutputRegressor", "MultiOutputClassifier", "ClassifierChain"]
 
 
 def _fit_estimator(estimator, X, y, sample_weight=None):
@@ -365,3 +368,240 @@ class MultiOutputClassifier(MultiOutputEstimator, ClassifierMixin):
                              format(n_outputs_, y.shape[1]))
         y_pred = self.predict(X)
         return np.mean(np.all(y == y_pred, axis=1))
+
+
+class ClassifierChain(BaseEstimator):
+    """A multi-label model that arranges binary classifiers into a chain.
+
+    Each model makes a prediction in the order specified by the chain using
+    all of the available features provided to the model plus the predictions
+    of models that are earlier in the chain.
+
+    Parameters
+    ----------
+    base_estimator : estimator
+        The base estimator from which the classifier chain is built.
+
+    order : array-like, shape=[n_outputs] or 'random', optional
+        By default the order will be determined by the order of columns in
+        the label matrix Y.::
+
+            order = [0, 1, 2, ..., Y.shape[1] - 1]
+
+        The order of the chain can be explicitly set by providing a list of
+        integers. For example, for a chain of length 5.::
+
+            order = [1, 3, 2, 4, 0]
+
+        means that the first model in the chain will make predictions for
+        column 1 in the Y matrix, the second model will make predictions
+        for column 3, etc.
+
+        If order is 'random' a random ordering will be used.
+
+    cv : int, cross-validation generator or an iterable, optional (
+    default=None)
+        Determines whether to use cross validated predictions or true
+        labels for the results of previous estimators in the chain.
+        If cv is None the true labels are used when fitting. Otherwise
+        possible inputs for cv are:
+            * integer, to specify the number of folds in a (Stratified)KFold,
+            * An object to be used as a cross-validation generator.
+            * An iterable yielding train, test splits.
+
+    random_state : int, RandomState instance or None, optional (default=None)
+        If int, random_state is the seed used by the random number generator;
+        If RandomState instance, random_state is the random number generator;
+        If None, the random number generator is the RandomState instance used
+        by `np.random`.
+
+        The random number generator is used to generate random chain orders.
+
+    Attributes
+    ----------
+    classes_ : list
+        A list of arrays of length len(estimators_) containing the
+        class labels for each estimator in the chain.
+
+    estimators_ : list
+        A list of clones of base_estimator.
+
+    order_ : list
+        The order of labels in the classifier chain.
+
+    References
+    ----------
+    Jesse Read, Bernhard Pfahringer, Geoff Holmes, Eibe Frank, "Classifier
+    Chains for Multi-label Classification", 2009.
+
+    """
+    def __init__(self, base_estimator, order=None, cv=None, random_state=None):
+        self.base_estimator = base_estimator
+        self.order = order
+        self.cv = cv
+        self.random_state = random_state
+
+    def fit(self, X, Y):
+        """Fit the model to data matrix X and targets Y.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix}, shape (n_samples, n_features)
+            The input data.
+        Y : array-like, shape (n_samples, n_classes)
+            The target values.
+
+        Returns
+        -------
+        self : object
+            Returns self.
+        """
+        X, Y = check_X_y(X, Y,  multi_output=True, accept_sparse=True)
+
+        random_state = check_random_state(self.random_state)
+        check_array(X, accept_sparse=True)
+        self.order_ = self.order
+        if self.order_ is None:
+            self.order_ = np.array(range(Y.shape[1]))
+        elif isinstance(self.order_, str):
+            if self.order_ == 'random':
+                self.order_ = random_state.permutation(Y.shape[1])
+        elif sorted(self.order_) != list(range(Y.shape[1])):
+                raise ValueError("invalid order")
+
+        self.estimators_ = [clone(self.base_estimator)
+                            for _ in range(Y.shape[1])]
+
+        self.classes_ = []
+
+        if self.cv is None:
+            Y_pred_chain = Y[:, self.order_]
+            if sp.issparse(X):
+                X_aug = sp.hstack((X, Y_pred_chain), format='lil')
+                X_aug = X_aug.tocsr()
+            else:
+                X_aug = np.hstack((X, Y_pred_chain))
+
+        elif sp.issparse(X):
+            Y_pred_chain = sp.lil_matrix((X.shape[0], Y.shape[1]))
+            X_aug = sp.hstack((X, Y_pred_chain), format='lil')
+
+        else:
+            Y_pred_chain = np.zeros((X.shape[0], Y.shape[1]))
+            X_aug = np.hstack((X, Y_pred_chain))
+
+        del Y_pred_chain
+
+        for chain_idx, estimator in enumerate(self.estimators_):
+            y = Y[:, self.order_[chain_idx]]
+            estimator.fit(X_aug[:, :(X.shape[1] + chain_idx)], y)
+            if self.cv is not None and chain_idx < len(self.estimators_) - 1:
+                col_idx = X.shape[1] + chain_idx
+                cv_result = cross_val_predict(
+                    self.base_estimator, X_aug[:, :col_idx],
+                    y=y, cv=self.cv)
+                if sp.issparse(X_aug):
+                    X_aug[:, col_idx] = np.expand_dims(cv_result, 1)
+                else:
+                    X_aug[:, col_idx] = cv_result
+
+            self.classes_.append(estimator.classes_)
+        return self
+
+    def predict(self, X):
+        """Predict on the data matrix X using the ClassifierChain model.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix}, shape (n_samples, n_features)
+            The input data.
+
+        Returns
+        -------
+        Y_pred : array-like, shape (n_samples, n_classes)
+            The predicted values.
+
+        """
+        X = check_array(X, accept_sparse=True)
+        Y_pred_chain = np.zeros((X.shape[0], len(self.estimators_)))
+        for chain_idx, estimator in enumerate(self.estimators_):
+            previous_predictions = Y_pred_chain[:, :chain_idx]
+            if sp.issparse(X):
+                if chain_idx == 0:
+                    X_aug = X
+                else:
+                    X_aug = sp.hstack((X, previous_predictions))
+            else:
+                X_aug = np.hstack((X, previous_predictions))
+            Y_pred_chain[:, chain_idx] = estimator.predict(X_aug)
+
+        inv_order = np.empty_like(self.order_)
+        inv_order[self.order_] = np.arange(len(self.order_))
+        Y_pred = Y_pred_chain[:, inv_order]
+
+        return Y_pred
+
+    @if_delegate_has_method('base_estimator')
+    def predict_proba(self, X):
+        """Predict probability estimates.
+
+        By default the inputs to later models in a chain is the binary class
+        predictions not the class probabilities. To use class probabilities
+        as features in subsequent models set the cv property to be one of
+        the allowed values other than None.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix}, shape (n_samples, n_features)
+
+        Returns
+        -------
+        Y_prob : array-like, shape (n_samples, n_classes)
+        """
+        X = check_array(X, accept_sparse=True)
+        Y_prob_chain = np.zeros((X.shape[0], len(self.estimators_)))
+        Y_pred_chain = np.zeros((X.shape[0], len(self.estimators_)))
+        for chain_idx, estimator in enumerate(self.estimators_):
+            previous_predictions = Y_pred_chain[:, :chain_idx]
+            if sp.issparse(X):
+                X_aug = sp.hstack((X, previous_predictions))
+            else:
+                X_aug = np.hstack((X, previous_predictions))
+            Y_prob_chain[:, chain_idx] = estimator.predict_proba(X_aug)[:, 1]
+            Y_pred_chain[:, chain_idx] = estimator.predict(X_aug)
+        inv_order = np.empty_like(self.order_)
+        inv_order[self.order_] = np.arange(len(self.order_))
+        Y_prob = Y_prob_chain[:, inv_order]
+
+        return Y_prob
+
+    @if_delegate_has_method('base_estimator')
+    def decision_function(self, X):
+        """Evaluate the decision_function of the models in the chain.
+
+        Parameters
+        ----------
+        X : array-like, shape (n_samples, n_features)
+
+        Returns
+        -------
+        Y_decision : array-like, shape (n_samples, n_classes )
+            Returns the decision function of the sample for each model
+            in the chain.
+        """
+        Y_decision_chain = np.zeros((X.shape[0], len(self.estimators_)))
+        Y_pred_chain = np.zeros((X.shape[0], len(self.estimators_)))
+        for chain_idx, estimator in enumerate(self.estimators_):
+            previous_predictions = Y_pred_chain[:, :chain_idx]
+            if sp.issparse(X):
+                X_aug = sp.hstack((X, previous_predictions))
+            else:
+                X_aug = np.hstack((X, previous_predictions))
+            Y_decision_chain[:, chain_idx] = estimator.decision_function(X_aug)
+            Y_pred_chain[:, chain_idx] = estimator.predict(X_aug)
+
+        inv_order = np.empty_like(self.order_)
+        inv_order[self.order_] = np.arange(len(self.order_))
+        Y_decision = Y_decision_chain[:, inv_order]
+
+        return Y_decision
diff --git a/sklearn/tests/test_multioutput.py b/sklearn/tests/test_multioutput.py
index 26647c3d19a74e4e77f652fdb647d32af7ccce78..00085a32af94f0c33d8b42902438ced863213356 100644
--- a/sklearn/tests/test_multioutput.py
+++ b/sklearn/tests/test_multioutput.py
@@ -1,7 +1,8 @@
 from __future__ import division
+
 import numpy as np
 import scipy.sparse as sp
-from sklearn.utils import shuffle
+
 from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_false
@@ -9,19 +10,26 @@ from sklearn.utils.testing import assert_raises_regex
 from sklearn.utils.testing import assert_raise_message
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_equal
+from sklearn.utils.testing import assert_greater
 from sklearn.utils.testing import assert_not_equal
 from sklearn.utils.testing import assert_array_almost_equal
-from sklearn.exceptions import NotFittedError
 from sklearn import datasets
 from sklearn.base import clone
+from sklearn.datasets import fetch_mldata
+from sklearn.datasets import make_classification
 from sklearn.ensemble import GradientBoostingRegressor, RandomForestClassifier
+from sklearn.exceptions import NotFittedError
 from sklearn.linear_model import Lasso
+from sklearn.linear_model import LogisticRegression
 from sklearn.linear_model import SGDClassifier
 from sklearn.linear_model import SGDRegressor
-from sklearn.linear_model import LogisticRegression
-from sklearn.svm import LinearSVC
+from sklearn.metrics import jaccard_similarity_score
 from sklearn.multiclass import OneVsRestClassifier
-from sklearn.multioutput import MultiOutputRegressor, MultiOutputClassifier
+from sklearn.multioutput import ClassifierChain
+from sklearn.multioutput import MultiOutputClassifier
+from sklearn.multioutput import MultiOutputRegressor
+from sklearn.svm import LinearSVC
+from sklearn.utils import shuffle
 
 
 def test_multi_target_regression():
@@ -339,3 +347,147 @@ def test_multi_output_exceptions():
     assert_raises(ValueError, moc.score, X, y_new)
     # ValueError when y is continuous
     assert_raise_message(ValueError, "Unknown label type", moc.fit, X, X[:, 1])
+
+
+def generate_multilabel_dataset_with_correlations():
+    # Generate a multilabel data set from a multiclass dataset as a way of
+    # by representing the integer number of the original class using a binary
+    # encoding.
+    X, y = make_classification(n_samples=1000,
+                               n_features=100,
+                               n_classes=16,
+                               n_informative=10)
+
+    Y_multi = np.array([[int(yyy) for yyy in format(yy, '#06b')[2:]]
+                        for yy in y])
+    return X, Y_multi
+
+
+def test_classifier_chain_fit_and_predict_with_logistic_regression():
+    # Fit classifier chain and verify predict performance
+    X, Y = generate_multilabel_dataset_with_correlations()
+    classifier_chain = ClassifierChain(LogisticRegression())
+    classifier_chain.fit(X, Y)
+
+    Y_pred = classifier_chain.predict(X)
+    assert_equal(Y_pred.shape, Y.shape)
+
+    Y_prob = classifier_chain.predict_proba(X)
+    Y_binary = (Y_prob >= .5)
+    assert_array_equal(Y_binary, Y_pred)
+
+    assert_equal([c.coef_.size for c in classifier_chain.estimators_],
+                 list(range(X.shape[1], X.shape[1] + Y.shape[1])))
+
+
+def test_classifier_chain_fit_and_predict_with_linear_svc():
+    # Fit classifier chain and verify predict performance using LinearSVC
+    X, Y = generate_multilabel_dataset_with_correlations()
+    classifier_chain = ClassifierChain(LinearSVC())
+    classifier_chain.fit(X, Y)
+
+    Y_pred = classifier_chain.predict(X)
+    assert_equal(Y_pred.shape, Y.shape)
+
+    Y_decision = classifier_chain.decision_function(X)
+
+    Y_binary = (Y_decision >= 0)
+    assert_array_equal(Y_binary, Y_pred)
+    assert not hasattr(classifier_chain, 'predict_proba')
+
+
+def test_classifier_chain_fit_and_predict_with_sparse_data():
+    # Fit classifier chain with sparse data
+    X, Y = generate_multilabel_dataset_with_correlations()
+    X_sparse = sp.csr_matrix(X)
+
+    classifier_chain = ClassifierChain(LogisticRegression())
+    classifier_chain.fit(X_sparse, Y)
+    Y_pred_sparse = classifier_chain.predict(X_sparse)
+
+    classifier_chain = ClassifierChain(LogisticRegression())
+    classifier_chain.fit(X, Y)
+    Y_pred_dense = classifier_chain.predict(X)
+
+    assert_array_equal(Y_pred_sparse, Y_pred_dense)
+
+
+def test_classifier_chain_fit_and_predict_with_sparse_data_and_cv():
+    # Fit classifier chain with sparse data cross_val_predict
+    X, Y = generate_multilabel_dataset_with_correlations()
+    X_sparse = sp.csr_matrix(X)
+    classifier_chain = ClassifierChain(LogisticRegression(), cv=3)
+    classifier_chain.fit(X_sparse, Y)
+    Y_pred = classifier_chain.predict(X_sparse)
+    assert_equal(Y_pred.shape, Y.shape)
+
+
+def test_classifier_chain_random_order():
+    # Fit classifier chain with random order
+    X, Y = generate_multilabel_dataset_with_correlations()
+    classifier_chain_random = ClassifierChain(LogisticRegression(),
+                                              order='random',
+                                              random_state=42)
+    classifier_chain_random.fit(X, Y)
+    Y_pred_random = classifier_chain_random.predict(X)
+
+    assert_not_equal(list(classifier_chain_random.order), list(range(4)))
+    assert_equal(len(classifier_chain_random.order_), 4)
+    assert_equal(len(set(classifier_chain_random.order_)), 4)
+
+    classifier_chain_fixed = \
+        ClassifierChain(LogisticRegression(),
+                        order=classifier_chain_random.order_)
+    classifier_chain_fixed.fit(X, Y)
+    Y_pred_fixed = classifier_chain_fixed.predict(X)
+
+    # Randomly ordered chain should behave identically to a fixed order chain
+    # with the same order.
+    assert_array_equal(Y_pred_random, Y_pred_fixed)
+
+
+def test_classifier_chain_crossval_fit_and_predict():
+    # Fit classifier chain with cross_val_predict and verify predict
+    # performance
+    X, Y = generate_multilabel_dataset_with_correlations()
+    classifier_chain_cv = ClassifierChain(LogisticRegression(), cv=3)
+    classifier_chain_cv.fit(X, Y)
+
+    classifier_chain = ClassifierChain(LogisticRegression())
+    classifier_chain.fit(X, Y)
+
+    Y_pred_cv = classifier_chain_cv.predict(X)
+    Y_pred = classifier_chain.predict(X)
+
+    assert_equal(Y_pred_cv.shape, Y.shape)
+    assert_greater(jaccard_similarity_score(Y, Y_pred_cv), 0.4)
+
+    assert_not_equal(jaccard_similarity_score(Y, Y_pred_cv),
+                     jaccard_similarity_score(Y, Y_pred))
+
+
+def test_classifier_chain_vs_independent_models():
+    # Verify that an ensemble of classifier chains (each of length
+    # N) can achieve a higher Jaccard similarity score than N independent
+    # models
+    yeast = fetch_mldata('yeast')
+    X = yeast['data']
+    Y = yeast['target'].transpose().toarray()
+    X_train = X[:2000, :]
+    X_test = X[2000:, :]
+    Y_train = Y[:2000, :]
+    Y_test = Y[2000:, :]
+
+    ovr = OneVsRestClassifier(LogisticRegression())
+    ovr.fit(X_train, Y_train)
+    Y_pred_ovr = ovr.predict(X_test)
+
+    chain = ClassifierChain(LogisticRegression(),
+                            order=np.array([0, 2, 4, 6, 8, 10,
+                                            12, 1, 3, 5, 7, 9,
+                                            11, 13]))
+    chain.fit(X_train, Y_train)
+    Y_pred_chain = chain.predict(X_test)
+
+    assert_greater(jaccard_similarity_score(Y_test, Y_pred_chain),
+                   jaccard_similarity_score(Y_test, Y_pred_ovr))
diff --git a/sklearn/utils/metaestimators.py b/sklearn/utils/metaestimators.py
index fcbbbf894b76e8b6e05b13bdc754e13499a8e33b..df97ed0134ee19087b879d3194d24b831016d0fd 100644
--- a/sklearn/utils/metaestimators.py
+++ b/sklearn/utils/metaestimators.py
@@ -129,7 +129,7 @@ def if_delegate_has_method(delegate):
     delegate : string, list of strings or tuple of strings
         Name of the sub-estimator that can be accessed as an attribute of the
         base object. If a list or a tuple of names are provided, the first
-        sub-estimator that is an attribute of the base object  will be used.
+        sub-estimator that is an attribute of the base object will be used.
 
     """
     if isinstance(delegate, list):
diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py
index 9638efadd493151c24c0f3364143653080aaeeff..035b901abe952446f58eafacabe017f4a1fccfaf 100644
--- a/sklearn/utils/testing.py
+++ b/sklearn/utils/testing.py
@@ -508,7 +508,7 @@ def uninstall_mldata_mock():
 META_ESTIMATORS = ["OneVsOneClassifier", "MultiOutputEstimator",
                    "MultiOutputRegressor", "MultiOutputClassifier",
                    "OutputCodeClassifier", "OneVsRestClassifier",
-                   "RFE", "RFECV", "BaseEnsemble"]
+                   "RFE", "RFECV", "BaseEnsemble", "ClassifierChain"]
 # estimators that there is no way to default-construct sensibly
 OTHER = ["Pipeline", "FeatureUnion", "GridSearchCV", "RandomizedSearchCV",
          "SelectFromModel"]