diff --git a/doc/modules/multiclass.rst b/doc/modules/multiclass.rst
index 235bebcf0558ca1470baee9641d0125d1bab518c..5094372aca96074a6469a1b4f525f923b982d026 100644
--- a/doc/modules/multiclass.rst
+++ b/doc/modules/multiclass.rst
@@ -348,3 +348,53 @@ Below is an example of multioutput classification:
            [0, 0, 2],
            [2, 0, 0]])
 
+Classifier Chain
+================
+
+Classifier chains (see :class:`ClassifierChain`) are a way of combining a
+number of binary classifiers into a single multi-label model that is
+capable of exploiting correlations among targets.
+
+For a multi-label classification problem with N classes, N binary
+classifiers are assigned an integer between 0 and N-1. These integers
+define the order of models in the chain. Each classifier is then fit on the
+available training data plus the true labels of the classes whose
+models were assigned a lower number.
+
+When predicting, the true labels will not be available. Instead, the
+predictions of each model are passed on to the subsequent models in the
+chain to be used as features.
+
+Clearly the order of the chain is important. The first model in the chain
+has no information about the other labels, while the last model in the
+chain has features indicating the presence of all of the other labels. In
+general one does not know the optimal ordering of the models in the chain,
+so typically many randomly ordered chains are fit and their predictions are
+averaged together.
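+
+The snippet below is a minimal sketch of this procedure, using
+``LogisticRegression`` as the base estimator on a small synthetic
+multi-label problem::
+
+    import numpy as np
+    from sklearn.datasets import make_multilabel_classification
+    from sklearn.linear_model import LogisticRegression
+    from sklearn.multioutput import ClassifierChain
+
+    # A small synthetic multi-label problem (100 samples, 5 labels).
+    X, Y = make_multilabel_classification(n_samples=100, n_classes=5,
+                                          random_state=0)
+
+    # Ten chains, each trained with a different random label order.
+    chains = [ClassifierChain(LogisticRegression(), order='random',
+                              random_state=i) for i in range(10)]
+
+    # Average the binary predictions of the chains and threshold at 0.5.
+    Y_prob = np.mean([chain.fit(X, Y).predict(X) for chain in chains],
+                     axis=0)
+    Y_pred = Y_prob >= 0.5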
+
+.. topic:: References:
+
+    Jesse Read, Bernhard Pfahringer, Geoff Holmes, Eibe Frank,
+    "Classifier Chains for Multi-label Classification", 2009.
\ No newline at end of file
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 28cbf1d6e10ea2b6547721b0f1bf89ee82849075..d367c627c27c434ed39039d72ad0a3d361d55924 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -31,6 +31,9 @@ Changelog
 New features
 ............
 
+   - Added :class:`multioutput.ClassifierChain` for multi-label
+     classification. By `Adam Kleczewski <adamklec>`_.
+
    - Validation that input data contains no NaN or inf can now be suppressed
      using :func:`config_context`, at your own risk. This will save on runtime,
      and may be particularly useful for prediction time. :issue:`7548` by
diff --git a/examples/multioutput/README.txt b/examples/multioutput/README.txt
new file mode 100644
index 0000000000000000000000000000000000000000..57adada325e437a73e346a5d32fba0e83ac63bfb
--- /dev/null
+++ b/examples/multioutput/README.txt
@@ -0,0 +1,6 @@
+.. _multioutput_examples:
+
+Multioutput methods
+-------------------
+
+Examples concerning the :mod:`sklearn.multioutput` module.
\ No newline at end of file
diff --git a/examples/multioutput/plot_classifier_chain_yeast.py b/examples/multioutput/plot_classifier_chain_yeast.py
new file mode 100644
index 0000000000000000000000000000000000000000..af649268a6151da0d40e2a6aa741d62e9bf6d3aa
--- /dev/null
+++ b/examples/multioutput/plot_classifier_chain_yeast.py
@@ -0,0 +1,110 @@
+"""
+============================
+Classifier Chain
+============================
+Example of using classifier chains on a multilabel dataset.
+
+For this example we will use the `yeast
+<http://mldata.org/repository/data/viewslug/yeast/>`_ dataset which
+contains 2417 datapoints each with 103 features and 14 possible labels. Each
+datapoint has at least one label. As a baseline we first train a logistic
+regression classifier for each of the 14 labels. To evaluate the performance
+of these classifiers we predict on a held-out test set and calculate the
+:ref:`Jaccard similarity score <jaccard_similarity_score>`.
+
+Next we create 10 classifier chains. Each classifier chain contains a
+logistic regression model for each of the 14 labels. The models in each
+chain are ordered randomly. In addition to the 103 features in the dataset,
+each model gets the predictions of the preceding models in the chain as
+features (note that by default at training time each model gets the true
+labels as features). These additional features allow each chain to exploit
+correlations among the classes. The Jaccard similarity score for each chain
+tends to be greater than that of the set of independent logistic models.
+
+Because the models in each chain are arranged randomly there is significant
+variation in performance among the chains. Presumably there is an optimal
+ordering of the classes in a chain that will yield the best performance.
+However, we do not know that ordering a priori. Instead we can construct a
+voting ensemble of classifier chains by averaging the binary predictions of
+the chains and applying a threshold of 0.5. The Jaccard similarity score of
+the ensemble is greater than that of the independent models and tends to
+exceed the score of each chain in the ensemble (although this is not
+guaranteed with randomly ordered chains).
+"""
+
+print(__doc__)
+
+# Author: Adam Kleczewski
+# License: BSD 3 clause
+
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.multioutput import ClassifierChain
+from sklearn.model_selection import train_test_split
+from sklearn.multiclass import OneVsRestClassifier
+from sklearn.metrics import jaccard_similarity_score
+from sklearn.linear_model import LogisticRegression
+from sklearn.datasets import fetch_mldata
+
+# Load a multi-label dataset
+yeast = fetch_mldata('yeast')
+X = yeast['data']
+Y = yeast['target'].transpose().toarray()
+X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.2,
+                                                    random_state=0)
+
+# Fit an independent logistic regression model for each class using the
+# OneVsRestClassifier wrapper.
+ovr = OneVsRestClassifier(LogisticRegression())
+ovr.fit(X_train, Y_train)
+Y_pred_ovr = ovr.predict(X_test)
+ovr_jaccard_score = jaccard_similarity_score(Y_test, Y_pred_ovr)
+
+# Fit an ensemble of logistic regression classifier chains and take the
+# average prediction of all the chains.
+chains = [ClassifierChain(LogisticRegression(), order='random',
+                          random_state=i) for i in range(10)]
+for chain in chains:
+    chain.fit(X_train, Y_train)
+
+Y_pred_chains = np.array([chain.predict(X_test)
+                          for chain in chains])
+chain_jaccard_scores = [jaccard_similarity_score(Y_test, Y_pred_chain >= .5)
+                        for Y_pred_chain in Y_pred_chains]
+
+Y_pred_ensemble = Y_pred_chains.mean(axis=0)
+ensemble_jaccard_score = jaccard_similarity_score(Y_test,
+                                                  Y_pred_ensemble >= .5)
+
+model_scores = [ovr_jaccard_score] + chain_jaccard_scores
+model_scores.append(ensemble_jaccard_score)
+
+model_names = ('Independent Models',
+               'Chain 1',
+               'Chain 2',
+               'Chain 3',
+               'Chain 4',
+               'Chain 5',
+               'Chain 6',
+               'Chain 7',
+               'Chain 8',
+               'Chain 9',
+               'Chain 10',
+               'Ensemble Average')
+
+x_pos = np.arange(len(model_names))
+x_pos[1:] += 1
+x_pos[-1] += 1
+
+# Plot the Jaccard similarity scores for the independent model, each of the
+# chains, and the ensemble (note that the vertical axis on this plot does
+# not begin at 0).
+
+fig = plt.figure(figsize=(7, 4))
+plt.title('Classifier Chain Ensemble')
+plt.xticks(x_pos, model_names, rotation='vertical')
+plt.ylabel('Jaccard Similarity Score')
+plt.ylim([min(model_scores) * .9, max(model_scores) * 1.1])
+colors = ['r'] + ['b'] * len(chain_jaccard_scores) + ['g']
+plt.bar(x_pos, model_scores, align='center', alpha=0.5, color=colors)
+plt.show()
diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py
index bdb85ad890a97358db79d0a9eebf964afbc92e81..64e394272ffd7fd9211d2d356a0ebd597cad0f2d 100644
--- a/sklearn/multioutput.py
+++ b/sklearn/multioutput.py
@@ -14,20 +14,22 @@ extends single output estimators to multioutput estimators.
 #
 # License: BSD 3 clause
 
 import numpy as np
+import scipy.sparse as sp
 
 from abc import ABCMeta, abstractmethod
 from .base import BaseEstimator, clone, MetaEstimatorMixin
 from .base import RegressorMixin, ClassifierMixin
-from .utils import check_array, check_X_y
+from .model_selection import cross_val_predict
+from .utils import check_array, check_X_y, check_random_state
 from .utils.fixes import parallel_helper
-from .utils.validation import check_is_fitted, has_fit_parameter
 from .utils.metaestimators import if_delegate_has_method
+from .utils.validation import check_is_fitted, has_fit_parameter
 from .utils.multiclass import check_classification_targets
 from .externals.joblib import Parallel, delayed
 from .externals import six
 
-__all__ = ["MultiOutputRegressor", "MultiOutputClassifier"]
+__all__ = ["MultiOutputRegressor", "MultiOutputClassifier", "ClassifierChain"]
 
 
 def _fit_estimator(estimator, X, y, sample_weight=None):
@@ -365,3 +368,240 @@ class MultiOutputClassifier(MultiOutputEstimator, ClassifierMixin):
                              format(n_outputs_, y.shape[1]))
         y_pred = self.predict(X)
         return np.mean(np.all(y == y_pred, axis=1))
+
+
+class ClassifierChain(BaseEstimator):
+    """A multi-label model that arranges binary classifiers into a chain.
+
+    Each model makes a prediction in the order specified by the chain using
+    all of the available features provided to the model plus the predictions
+    of models that are earlier in the chain.
+
+    Parameters
+    ----------
+    base_estimator : estimator
+        The base estimator from which the classifier chain is built.
+
+    order : array-like, shape=[n_outputs] or 'random', optional
+        By default the order will be determined by the order of columns in
+        the label matrix Y::
+
+            order = [0, 1, 2, ..., Y.shape[1] - 1]
+
+        The order of the chain can be explicitly set by providing a list of
+        integers. For example, for a chain of length 5::
+
+            order = [1, 3, 2, 4, 0]
+
+        means that the first model in the chain will make predictions for
+        column 1 in the Y matrix, the second model will make predictions
+        for column 3, etc.
+
+        If order is 'random', a random ordering will be used.
+
+    cv : int, cross-validation generator or an iterable, optional
+        (default=None)
+        Determines whether to use cross-validated predictions or true
+        labels for the results of previous estimators in the chain.
+        If cv is None the true labels are used when fitting. Otherwise
+        possible inputs for cv are:
+            * integer, to specify the number of folds in a (Stratified)KFold,
+            * an object to be used as a cross-validation generator, or
+            * an iterable yielding train/test splits.
+ + random_state : int, RandomState instance or None, optional (default=None) + If int, random_state is the seed used by the random number generator; + If RandomState instance, random_state is the random number generator; + If None, the random number generator is the RandomState instance used + by `np.random`. + + The random number generator is used to generate random chain orders. + + Attributes + ---------- + classes_ : list + A list of arrays of length len(estimators_) containing the + class labels for each estimator in the chain. + + estimators_ : list + A list of clones of base_estimator. + + order_ : list + The order of labels in the classifier chain. + + References + ---------- + Jesse Read, Bernhard Pfahringer, Geoff Holmes, Eibe Frank, "Classifier + Chains for Multi-label Classification", 2009. + + """ + def __init__(self, base_estimator, order=None, cv=None, random_state=None): + self.base_estimator = base_estimator + self.order = order + self.cv = cv + self.random_state = random_state + + def fit(self, X, Y): + """Fit the model to data matrix X and targets Y. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape (n_samples, n_features) + The input data. + Y : array-like, shape (n_samples, n_classes) + The target values. + + Returns + ------- + self : object + Returns self. + """ + X, Y = check_X_y(X, Y, multi_output=True, accept_sparse=True) + + random_state = check_random_state(self.random_state) + check_array(X, accept_sparse=True) + self.order_ = self.order + if self.order_ is None: + self.order_ = np.array(range(Y.shape[1])) + elif isinstance(self.order_, str): + if self.order_ == 'random': + self.order_ = random_state.permutation(Y.shape[1]) + elif sorted(self.order_) != list(range(Y.shape[1])): + raise ValueError("invalid order") + + self.estimators_ = [clone(self.base_estimator) + for _ in range(Y.shape[1])] + + self.classes_ = [] + + if self.cv is None: + Y_pred_chain = Y[:, self.order_] + if sp.issparse(X): + X_aug = sp.hstack((X, Y_pred_chain), format='lil') + X_aug = X_aug.tocsr() + else: + X_aug = np.hstack((X, Y_pred_chain)) + + elif sp.issparse(X): + Y_pred_chain = sp.lil_matrix((X.shape[0], Y.shape[1])) + X_aug = sp.hstack((X, Y_pred_chain), format='lil') + + else: + Y_pred_chain = np.zeros((X.shape[0], Y.shape[1])) + X_aug = np.hstack((X, Y_pred_chain)) + + del Y_pred_chain + + for chain_idx, estimator in enumerate(self.estimators_): + y = Y[:, self.order_[chain_idx]] + estimator.fit(X_aug[:, :(X.shape[1] + chain_idx)], y) + if self.cv is not None and chain_idx < len(self.estimators_) - 1: + col_idx = X.shape[1] + chain_idx + cv_result = cross_val_predict( + self.base_estimator, X_aug[:, :col_idx], + y=y, cv=self.cv) + if sp.issparse(X_aug): + X_aug[:, col_idx] = np.expand_dims(cv_result, 1) + else: + X_aug[:, col_idx] = cv_result + + self.classes_.append(estimator.classes_) + return self + + def predict(self, X): + """Predict on the data matrix X using the ClassifierChain model. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape (n_samples, n_features) + The input data. + + Returns + ------- + Y_pred : array-like, shape (n_samples, n_classes) + The predicted values. 
+
+        """
+        X = check_array(X, accept_sparse=True)
+        Y_pred_chain = np.zeros((X.shape[0], len(self.estimators_)))
+        for chain_idx, estimator in enumerate(self.estimators_):
+            previous_predictions = Y_pred_chain[:, :chain_idx]
+            if sp.issparse(X):
+                if chain_idx == 0:
+                    X_aug = X
+                else:
+                    X_aug = sp.hstack((X, previous_predictions))
+            else:
+                X_aug = np.hstack((X, previous_predictions))
+            Y_pred_chain[:, chain_idx] = estimator.predict(X_aug)
+
+        inv_order = np.empty_like(self.order_)
+        inv_order[self.order_] = np.arange(len(self.order_))
+        Y_pred = Y_pred_chain[:, inv_order]
+
+        return Y_pred
+
+    @if_delegate_has_method('base_estimator')
+    def predict_proba(self, X):
+        """Predict probability estimates.
+
+        The inputs to later models in a chain are the binary class
+        predictions of the earlier models, not their predicted
+        probabilities. This also holds during fitting when cv is set,
+        since cross-validated class predictions are used as features.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix}, shape (n_samples, n_features)
+
+        Returns
+        -------
+        Y_prob : array-like, shape (n_samples, n_classes)
+        """
+        X = check_array(X, accept_sparse=True)
+        Y_prob_chain = np.zeros((X.shape[0], len(self.estimators_)))
+        Y_pred_chain = np.zeros((X.shape[0], len(self.estimators_)))
+        for chain_idx, estimator in enumerate(self.estimators_):
+            previous_predictions = Y_pred_chain[:, :chain_idx]
+            if sp.issparse(X):
+                X_aug = sp.hstack((X, previous_predictions))
+            else:
+                X_aug = np.hstack((X, previous_predictions))
+            Y_prob_chain[:, chain_idx] = estimator.predict_proba(X_aug)[:, 1]
+            Y_pred_chain[:, chain_idx] = estimator.predict(X_aug)
+        inv_order = np.empty_like(self.order_)
+        inv_order[self.order_] = np.arange(len(self.order_))
+        Y_prob = Y_prob_chain[:, inv_order]
+
+        return Y_prob
+
+    @if_delegate_has_method('base_estimator')
+    def decision_function(self, X):
+        """Evaluate the decision_function of the models in the chain.
+
+        Parameters
+        ----------
+        X : array-like, shape (n_samples, n_features)
+
+        Returns
+        -------
+        Y_decision : array-like, shape (n_samples, n_classes)
+            Returns the decision function of the sample for each model
+            in the chain.
+        """
+        Y_decision_chain = np.zeros((X.shape[0], len(self.estimators_)))
+        Y_pred_chain = np.zeros((X.shape[0], len(self.estimators_)))
+        for chain_idx, estimator in enumerate(self.estimators_):
+            previous_predictions = Y_pred_chain[:, :chain_idx]
+            if sp.issparse(X):
+                X_aug = sp.hstack((X, previous_predictions))
+            else:
+                X_aug = np.hstack((X, previous_predictions))
+            Y_decision_chain[:, chain_idx] = estimator.decision_function(X_aug)
+            Y_pred_chain[:, chain_idx] = estimator.predict(X_aug)
+
+        inv_order = np.empty_like(self.order_)
+        inv_order[self.order_] = np.arange(len(self.order_))
+        Y_decision = Y_decision_chain[:, inv_order]
+
+        return Y_decision
diff --git a/sklearn/tests/test_multioutput.py b/sklearn/tests/test_multioutput.py
index 26647c3d19a74e4e77f652fdb647d32af7ccce78..00085a32af94f0c33d8b42902438ced863213356 100644
--- a/sklearn/tests/test_multioutput.py
+++ b/sklearn/tests/test_multioutput.py
@@ -1,7 +1,8 @@
 from __future__ import division
+
 import numpy as np
 import scipy.sparse as sp
-from sklearn.utils import shuffle
+
 from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_false
@@ -9,19 +10,26 @@ from sklearn.utils.testing import assert_raises_regex
 from sklearn.utils.testing import assert_raise_message
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_equal
+from sklearn.utils.testing import assert_greater
 from sklearn.utils.testing import assert_not_equal
 from sklearn.utils.testing import assert_array_almost_equal
 from sklearn import datasets
 from sklearn.base import clone
+from sklearn.datasets import fetch_mldata
+from sklearn.datasets import make_classification
 from sklearn.ensemble import GradientBoostingRegressor, RandomForestClassifier
+from sklearn.exceptions import NotFittedError
 from sklearn.linear_model import Lasso
+from sklearn.linear_model import LogisticRegression
 from sklearn.linear_model import SGDClassifier
 from sklearn.linear_model import SGDRegressor
-from sklearn.linear_model import LogisticRegression
-from sklearn.svm import LinearSVC
+from sklearn.metrics import jaccard_similarity_score
 from sklearn.multiclass import OneVsRestClassifier
-from sklearn.multioutput import MultiOutputRegressor, MultiOutputClassifier
+from sklearn.multioutput import ClassifierChain
+from sklearn.multioutput import MultiOutputClassifier
+from sklearn.multioutput import MultiOutputRegressor
+from sklearn.svm import LinearSVC
+from sklearn.utils import shuffle
 
 
 def test_multi_target_regression():
@@ -339,3 +347,147 @@ def test_multi_output_exceptions():
     assert_raises(ValueError, moc.score, X, y_new)
     # ValueError when y is continuous
     assert_raise_message(ValueError, "Unknown label type", moc.fit, X, X[:, 1])
+
+
+def generate_multilabel_dataset_with_correlations():
+    # Generate a multilabel data set from a multiclass dataset by
+    # representing the integer class number of each sample with a binary
+    # encoding.
+    X, y = make_classification(n_samples=1000,
+                               n_features=100,
+                               n_classes=16,
+                               n_informative=10)
+
+    Y_multi = np.array([[int(yyy) for yyy in format(yy, '#06b')[2:]]
+                        for yy in y])
+    return X, Y_multi
+
+
+def test_classifier_chain_fit_and_predict_with_logistic_regression():
+    # Fit classifier chain and verify predict performance
+    X, Y = generate_multilabel_dataset_with_correlations()
+    classifier_chain = ClassifierChain(LogisticRegression())
+    classifier_chain.fit(X, Y)
+
+    Y_pred = classifier_chain.predict(X)
+    assert_equal(Y_pred.shape, Y.shape)
+
+    Y_prob = classifier_chain.predict_proba(X)
+    Y_binary = (Y_prob >= .5)
+    assert_array_equal(Y_binary, Y_pred)
+
+    assert_equal([c.coef_.size for c in classifier_chain.estimators_],
+                 list(range(X.shape[1], X.shape[1] + Y.shape[1])))
+
+
+def test_classifier_chain_fit_and_predict_with_linear_svc():
+    # Fit classifier chain and verify predict performance using LinearSVC
+    X, Y = generate_multilabel_dataset_with_correlations()
+    classifier_chain = ClassifierChain(LinearSVC())
+    classifier_chain.fit(X, Y)
+
+    Y_pred = classifier_chain.predict(X)
+    assert_equal(Y_pred.shape, Y.shape)
+
+    Y_decision = classifier_chain.decision_function(X)
+
+    Y_binary = (Y_decision >= 0)
+    assert_array_equal(Y_binary, Y_pred)
+    assert not hasattr(classifier_chain, 'predict_proba')
+
+
+def test_classifier_chain_fit_and_predict_with_sparse_data():
+    # Fit classifier chain with sparse data
+    X, Y = generate_multilabel_dataset_with_correlations()
+    X_sparse = sp.csr_matrix(X)
+
+    classifier_chain = ClassifierChain(LogisticRegression())
+    classifier_chain.fit(X_sparse, Y)
+    Y_pred_sparse = classifier_chain.predict(X_sparse)
+
+    classifier_chain = ClassifierChain(LogisticRegression())
+    classifier_chain.fit(X, Y)
+    Y_pred_dense = classifier_chain.predict(X)
+
+    assert_array_equal(Y_pred_sparse, Y_pred_dense)
+
+
+def test_classifier_chain_fit_and_predict_with_sparse_data_and_cv():
+    # Fit classifier chain with sparse data and cross_val_predict
+    X, Y = generate_multilabel_dataset_with_correlations()
+    X_sparse = sp.csr_matrix(X)
+    classifier_chain = ClassifierChain(LogisticRegression(), cv=3)
+    classifier_chain.fit(X_sparse, Y)
+    Y_pred = classifier_chain.predict(X_sparse)
+    assert_equal(Y_pred.shape, Y.shape)
+
+
+def test_classifier_chain_random_order():
+    # Fit classifier chain with random order
+    X, Y = generate_multilabel_dataset_with_correlations()
+    classifier_chain_random = ClassifierChain(LogisticRegression(),
+                                              order='random',
+                                              random_state=42)
+    classifier_chain_random.fit(X, Y)
+    Y_pred_random = classifier_chain_random.predict(X)
+
+    assert_not_equal(list(classifier_chain_random.order_), list(range(4)))
+    assert_equal(len(classifier_chain_random.order_), 4)
+    assert_equal(len(set(classifier_chain_random.order_)), 4)
+
+    classifier_chain_fixed = \
+        ClassifierChain(LogisticRegression(),
+                        order=classifier_chain_random.order_)
+    classifier_chain_fixed.fit(X, Y)
+    Y_pred_fixed = classifier_chain_fixed.predict(X)
+
+    # Randomly ordered chain should behave identically to a fixed order chain
+    # with the same order.
+ assert_array_equal(Y_pred_random, Y_pred_fixed) + + +def test_classifier_chain_crossval_fit_and_predict(): + # Fit classifier chain with cross_val_predict and verify predict + # performance + X, Y = generate_multilabel_dataset_with_correlations() + classifier_chain_cv = ClassifierChain(LogisticRegression(), cv=3) + classifier_chain_cv.fit(X, Y) + + classifier_chain = ClassifierChain(LogisticRegression()) + classifier_chain.fit(X, Y) + + Y_pred_cv = classifier_chain_cv.predict(X) + Y_pred = classifier_chain.predict(X) + + assert_equal(Y_pred_cv.shape, Y.shape) + assert_greater(jaccard_similarity_score(Y, Y_pred_cv), 0.4) + + assert_not_equal(jaccard_similarity_score(Y, Y_pred_cv), + jaccard_similarity_score(Y, Y_pred)) + + +def test_classifier_chain_vs_independent_models(): + # Verify that an ensemble of classifier chains (each of length + # N) can achieve a higher Jaccard similarity score than N independent + # models + yeast = fetch_mldata('yeast') + X = yeast['data'] + Y = yeast['target'].transpose().toarray() + X_train = X[:2000, :] + X_test = X[2000:, :] + Y_train = Y[:2000, :] + Y_test = Y[2000:, :] + + ovr = OneVsRestClassifier(LogisticRegression()) + ovr.fit(X_train, Y_train) + Y_pred_ovr = ovr.predict(X_test) + + chain = ClassifierChain(LogisticRegression(), + order=np.array([0, 2, 4, 6, 8, 10, + 12, 1, 3, 5, 7, 9, + 11, 13])) + chain.fit(X_train, Y_train) + Y_pred_chain = chain.predict(X_test) + + assert_greater(jaccard_similarity_score(Y_test, Y_pred_chain), + jaccard_similarity_score(Y_test, Y_pred_ovr)) diff --git a/sklearn/utils/metaestimators.py b/sklearn/utils/metaestimators.py index fcbbbf894b76e8b6e05b13bdc754e13499a8e33b..df97ed0134ee19087b879d3194d24b831016d0fd 100644 --- a/sklearn/utils/metaestimators.py +++ b/sklearn/utils/metaestimators.py @@ -129,7 +129,7 @@ def if_delegate_has_method(delegate): delegate : string, list of strings or tuple of strings Name of the sub-estimator that can be accessed as an attribute of the base object. If a list or a tuple of names are provided, the first - sub-estimator that is an attribute of the base object will be used. + sub-estimator that is an attribute of the base object will be used. """ if isinstance(delegate, list): diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index 9638efadd493151c24c0f3364143653080aaeeff..035b901abe952446f58eafacabe017f4a1fccfaf 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -508,7 +508,7 @@ def uninstall_mldata_mock(): META_ESTIMATORS = ["OneVsOneClassifier", "MultiOutputEstimator", "MultiOutputRegressor", "MultiOutputClassifier", "OutputCodeClassifier", "OneVsRestClassifier", - "RFE", "RFECV", "BaseEnsemble"] + "RFE", "RFECV", "BaseEnsemble", "ClassifierChain"] # estimators that there is no way to default-construct sensibly OTHER = ["Pipeline", "FeatureUnion", "GridSearchCV", "RandomizedSearchCV", "SelectFromModel"]