diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index ce72f193ed8ddf1a4417c6f03631dfa9d839c867..816471cb5232f8ffcc4b93c6ce9f5697c9037f9c 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -163,6 +163,13 @@ Enhancements
    - In :class:`gaussian_process.GaussianProcessRegressor`, method ``predict``
      is a lot faster with ``return_std=True`` by :user:`Hadrien Bertrand
      <hbertrand>`.
 
+   - Added ability to use sparse matrices in :func:`feature_selection.f_regression`
+     with ``center=True``. :issue:`8065` by :user:`Daniel LeJeune <acadiansith>`.
+
+   - :class:`ensemble.VotingClassifier` now allows changing estimators by using
+     :meth:`ensemble.VotingClassifier.set_params`. An estimator can also be
+     removed by setting it to ``None``.
+     :issue:`7674` by :user:`Yichuan Liu <yl565>`.
 
 Bug fixes
 .........
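A sketch of what the new ``VotingClassifier`` entry enables (not part of this patch; it assumes only the ``set_params``/``None`` behaviour introduced here): sub-estimators become regular parameters, so a grid search can tune, replace, or drop them::

    from sklearn.datasets import load_iris
    from sklearn.ensemble import RandomForestClassifier, VotingClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import GridSearchCV

    X, y = load_iris(return_X_y=True)
    eclf = VotingClassifier(estimators=[('lr', LogisticRegression()),
                                        ('rf', RandomForestClassifier())])
    # 'lr__C' tunes a nested parameter; 'rf' swaps the whole estimator,
    # with None removing it from the vote entirely.
    param_grid = {'lr__C': [1.0, 100.0],
                  'rf': [None, RandomForestClassifier(n_estimators=50)]}
    grid = GridSearchCV(eclf, param_grid=param_grid)
    grid.fit(X, y)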
diff --git a/sklearn/ensemble/tests/test_voting_classifier.py b/sklearn/ensemble/tests/test_voting_classifier.py
index 2ad007741940c6b423cc13f6bf135239fba62e1b..d61d8bfac62bee6d42fb2c98b2ce4e527907670f 100644
--- a/sklearn/ensemble/tests/test_voting_classifier.py
+++ b/sklearn/ensemble/tests/test_voting_classifier.py
@@ -2,7 +2,7 @@
 
 import numpy as np
 from sklearn.utils.testing import assert_almost_equal, assert_array_equal
-from sklearn.utils.testing import assert_equal
+from sklearn.utils.testing import assert_equal, assert_true, assert_false
 from sklearn.utils.testing import assert_raise_message
 from sklearn.exceptions import NotFittedError
 from sklearn.linear_model import LogisticRegression
@@ -40,6 +40,19 @@ def test_estimator_init():
            '; got 2 weights, 1 estimators')
     assert_raise_message(ValueError, msg, eclf.fit, X, y)
 
+    eclf = VotingClassifier(estimators=[('lr', clf), ('lr', clf)],
+                            weights=[1, 2])
+    msg = "Names provided are not unique: ['lr', 'lr']"
+    assert_raise_message(ValueError, msg, eclf.fit, X, y)
+
+    eclf = VotingClassifier(estimators=[('lr__', clf)])
+    msg = "Estimator names must not contain __: got ['lr__']"
+    assert_raise_message(ValueError, msg, eclf.fit, X, y)
+
+    eclf = VotingClassifier(estimators=[('estimators', clf)])
+    msg = "Estimator names conflict with constructor arguments: ['estimators']"
+    assert_raise_message(ValueError, msg, eclf.fit, X, y)
+
 
 def test_predictproba_hardvoting():
     eclf = VotingClassifier(estimators=[('lr1', LogisticRegression()),
@@ -260,6 +273,82 @@ def test_sample_weight():
     assert_raise_message(ValueError, msg, eclf3.fit, X, y, sample_weight)
 
 
+def test_set_params():
+    """set_params should be able to set estimators"""
+    clf1 = LogisticRegression(random_state=123, C=1.0)
+    clf2 = RandomForestClassifier(random_state=123, max_depth=None)
+    clf3 = GaussianNB()
+    eclf1 = VotingClassifier([('lr', clf1), ('rf', clf2)], voting='soft',
+                             weights=[1, 2])
+    eclf1.fit(X, y)
+    eclf2 = VotingClassifier([('lr', clf1), ('nb', clf3)], voting='soft',
+                             weights=[1, 2])
+    eclf2.set_params(nb=clf2).fit(X, y)
+    assert_false(hasattr(eclf2, 'nb'))
+
+    assert_array_equal(eclf1.predict(X), eclf2.predict(X))
+    assert_array_equal(eclf1.predict_proba(X), eclf2.predict_proba(X))
+    assert_equal(eclf2.estimators[0][1].get_params(), clf1.get_params())
+    assert_equal(eclf2.estimators[1][1].get_params(), clf2.get_params())
+
+    eclf1.set_params(lr__C=10.0)
+    eclf2.set_params(nb__max_depth=5)
+
+    assert_true(eclf1.estimators[0][1].get_params()['C'] == 10.0)
+    assert_true(eclf2.estimators[1][1].get_params()['max_depth'] == 5)
+    assert_equal(eclf1.get_params()["lr__C"],
+                 eclf1.get_params()["lr"].get_params()['C'])
+
+
+def test_set_estimator_none():
+    """VotingClassifier set_params should be able to set estimators as None"""
+    # Test predict
+    clf1 = LogisticRegression(random_state=123)
+    clf2 = RandomForestClassifier(random_state=123)
+    clf3 = GaussianNB()
+    eclf1 = VotingClassifier(estimators=[('lr', clf1), ('rf', clf2),
+                                         ('nb', clf3)],
+                             voting='hard', weights=[1, 0, 0.5]).fit(X, y)
+
+    eclf2 = VotingClassifier(estimators=[('lr', clf1), ('rf', clf2),
+                                         ('nb', clf3)],
+                             voting='hard', weights=[1, 1, 0.5])
+    eclf2.set_params(rf=None).fit(X, y)
+    assert_array_equal(eclf1.predict(X), eclf2.predict(X))
+
+    assert_true(dict(eclf2.estimators)["rf"] is None)
+    assert_true(len(eclf2.estimators_) == 2)
+    assert_true(all([not isinstance(est, RandomForestClassifier) for est in
+                     eclf2.estimators_]))
+    assert_true(eclf2.get_params()["rf"] is None)
+
+    eclf1.set_params(voting='soft').fit(X, y)
+    eclf2.set_params(voting='soft').fit(X, y)
+    assert_array_equal(eclf1.predict(X), eclf2.predict(X))
+    assert_array_equal(eclf1.predict_proba(X), eclf2.predict_proba(X))
+    msg = ('All estimators are None. At least one is required'
+           ' to be a classifier!')
+    assert_raise_message(
+        ValueError, msg, eclf2.set_params(lr=None, rf=None, nb=None).fit, X, y)
+
+    # Test soft voting transform
+    X1 = np.array([[1], [2]])
+    y1 = np.array([1, 2])
+    eclf1 = VotingClassifier(estimators=[('rf', clf2), ('nb', clf3)],
+                             voting='soft', weights=[0, 0.5]).fit(X1, y1)
+
+    eclf2 = VotingClassifier(estimators=[('rf', clf2), ('nb', clf3)],
+                             voting='soft', weights=[1, 0.5])
+    eclf2.set_params(rf=None).fit(X1, y1)
+    assert_array_equal(eclf1.transform(X1), np.array([[[0.7, 0.3], [0.3, 0.7]],
+                                                      [[1., 0.], [0., 1.]]]))
+    assert_array_equal(eclf2.transform(X1), np.array([[[1., 0.], [0., 1.]]]))
+    eclf1.set_params(voting='hard')
+    eclf2.set_params(voting='hard')
+    assert_array_equal(eclf1.transform(X1), np.array([[0, 0], [1, 1]]))
+    assert_array_equal(eclf2.transform(X1), np.array([[0], [1]]))
+
+
 def test_estimator_weights_format():
     # Test estimator weights inputs as list and array
     clf1 = LogisticRegression(random_state=123)
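A minimal sketch of the contract the new tests pin down (illustrative only, not part of the patch): a ``None`` slot is kept in ``estimators`` and ``get_params`` but is never fitted into ``estimators_``::

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier, VotingClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.naive_bayes import GaussianNB

    X = np.array([[-1.1, -1.5], [-1.2, -1.4], [-3.4, -2.2], [1.1, 1.2]])
    y = np.array([1, 1, 2, 2])
    eclf = VotingClassifier(estimators=[('lr', LogisticRegression()),
                                        ('rf', RandomForestClassifier()),
                                        ('nb', GaussianNB())])
    eclf.set_params(rf=None).fit(X, y)
    print(len(eclf.estimators))     # 3 -- the 'rf' slot is kept as None
    print(len(eclf.estimators_))    # 2 -- no clone is fitted for None slots
    print(eclf.get_params()['rf'])  # None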
diff --git a/sklearn/ensemble/voting_classifier.py b/sklearn/ensemble/voting_classifier.py
index cb0d6ad19c983b302a8c07b0cdd671d7fe865a3d..44cf4fe775ce3c385e484f55170a23452ae91d10 100644
--- a/sklearn/ensemble/voting_classifier.py
+++ b/sklearn/ensemble/voting_classifier.py
@@ -13,14 +13,13 @@ classification estimators.
 
 import numpy as np
 
-from ..base import BaseEstimator
 from ..base import ClassifierMixin
 from ..base import TransformerMixin
 from ..base import clone
 from ..preprocessing import LabelEncoder
-from ..externals import six
 from ..externals.joblib import Parallel, delayed
 from ..utils.validation import has_fit_parameter, check_is_fitted
+from ..utils.metaestimators import _BaseComposition
 
 
 def _parallel_fit_estimator(estimator, X, y, sample_weight):
@@ -32,7 +31,7 @@ def _parallel_fit_estimator(estimator, X, y, sample_weight):
     return estimator
 
 
-class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
+class VotingClassifier(_BaseComposition, ClassifierMixin, TransformerMixin):
     """Soft Voting/Majority Rule classifier for unfitted estimators.
 
     .. versionadded:: 0.17
@@ -44,7 +43,8 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
     estimators : list of (string, estimator) tuples
         Invoking the ``fit`` method on the ``VotingClassifier`` will fit clones
         of those original estimators that will be stored in the class attribute
-        `self.estimators_`.
+        ``self.estimators_``. An estimator can be set to ``None`` using
+        ``set_params``.
 
     voting : str, {'hard', 'soft'} (default='hard')
         If 'hard', uses predicted class labels for majority rule voting.
@@ -64,7 +64,8 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
     Attributes
     ----------
     estimators_ : list of classifiers
-        The collection of fitted sub-estimators.
+        The collection of fitted sub-estimators as defined in ``estimators``
+        that are not ``None``.
 
     classes_ : array-like, shape = [n_predictions]
         The classes labels.
@@ -102,11 +103,14 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
 
     def __init__(self, estimators, voting='hard', weights=None, n_jobs=1):
         self.estimators = estimators
-        self.named_estimators = dict(estimators)
         self.voting = voting
         self.weights = weights
        self.n_jobs = n_jobs
 
+    @property
+    def named_estimators(self):
+        return dict(self.estimators)
+
     def fit(self, X, y, sample_weight=None):
         """ Fit the estimators.
 
@@ -150,11 +154,16 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
         if sample_weight is not None:
             for name, step in self.estimators:
                 if not has_fit_parameter(step, 'sample_weight'):
-                    raise ValueError('Underlying estimator \'%s\' does not support'
-                                     ' sample weights.' % name)
-
-        self.le_ = LabelEncoder()
-        self.le_.fit(y)
+                    raise ValueError('Underlying estimator \'%s\' does not'
+                                     ' support sample weights.' % name)
+        names, clfs = zip(*self.estimators)
+        self._validate_names(names)
+
+        n_isnone = np.sum([clf is None for _, clf in self.estimators])
+        if n_isnone == len(self.estimators):
+            raise ValueError('All estimators are None. At least one is '
+                             'required to be a classifier!')
+
+        self.le_ = LabelEncoder().fit(y)
         self.classes_ = self.le_.classes_
         self.estimators_ = []
 
@@ -162,11 +171,19 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
         transformed_y = self.le_.transform(y)
 
         self.estimators_ = Parallel(n_jobs=self.n_jobs)(
                 delayed(_parallel_fit_estimator)(clone(clf), X, transformed_y,
-                                                 sample_weight)
-                for _, clf in self.estimators)
+                                                 sample_weight)
+                for clf in clfs if clf is not None)
 
         return self
 
+    @property
+    def _weights_not_none(self):
+        """Get the weights of estimators that are not ``None``."""
+        if self.weights is None:
+            return None
+        return [w for est, w in zip(self.estimators,
+                                    self.weights) if est[1] is not None]
+
     def predict(self, X):
         """ Predict class labels for X.
 
@@ -188,11 +205,10 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
 
         else:  # 'hard' voting
             predictions = self._predict(X)
-            maj = np.apply_along_axis(lambda x:
-                                      np.argmax(np.bincount(x,
-                                                weights=self.weights)),
-                                      axis=1,
-                                      arr=predictions.astype('int'))
+            maj = np.apply_along_axis(
+                lambda x: np.argmax(
+                    np.bincount(x, weights=self._weights_not_none)),
+                axis=1, arr=predictions.astype('int'))
 
         maj = self.le_.inverse_transform(maj)
 
@@ -208,7 +224,8 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
             raise AttributeError("predict_proba is not available when"
                                  " voting=%r" % self.voting)
         check_is_fitted(self, 'estimators_')
-        avg = np.average(self._collect_probas(X), axis=0, weights=self.weights)
+        avg = np.average(self._collect_probas(X), axis=0,
+                         weights=self._weights_not_none)
         return avg
 
     @property
@@ -252,17 +269,42 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
         else:
             return self._predict(X)
 
+    def set_params(self, **params):
+        """Set the parameters for the voting classifier.
+
+        Valid parameter keys can be listed with ``get_params()``.
+
+        Parameters
+        ----------
+        params : keyword arguments
+            Specific parameters using e.g.
+            ``set_params(parameter_name=new_value)``. In addition to setting
+            the parameters of the ``VotingClassifier``, the individual
+            classifiers of the ``VotingClassifier`` can also be set, or can
+            be removed by setting them to ``None``.
+
+        Examples
+        --------
+        # In this example, the RandomForestClassifier is removed
+        clf1 = LogisticRegression()
+        clf2 = RandomForestClassifier()
+        eclf = VotingClassifier(estimators=[('lr', clf1), ('rf', clf2)])
+        eclf.set_params(rf=None)
+        """
+        super(VotingClassifier, self)._set_params('estimators', **params)
+        return self
+
     def get_params(self, deep=True):
-        """Return estimator parameter names for GridSearch support"""
-        if not deep:
-            return super(VotingClassifier, self).get_params(deep=False)
-        else:
-            out = super(VotingClassifier, self).get_params(deep=False)
-            out.update(self.named_estimators.copy())
-            for name, step in six.iteritems(self.named_estimators):
-                for key, value in six.iteritems(step.get_params(deep=True)):
-                    out['%s__%s' % (name, key)] = value
-            return out
+        """Get the parameters of the VotingClassifier.
+
+        Parameters
+        ----------
+        deep : bool
+            Setting it to True gets the various classifiers and the
+            parameters of the classifiers as well.
+        """
+        return super(VotingClassifier,
+                     self)._get_params('estimators', deep=deep)
 
     def _predict(self, X):
         """Collect results from clf.predict calls. """
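The weight filtering performed by ``_weights_not_none`` reduces to this standalone snippet (illustrative only; the placeholder strings stand in for fitted classifiers)::

    estimators = [('lr', 'clf1'), ('rf', None), ('nb', 'clf3')]
    weights = [1, 1, 0.5]
    # Drop the weight of every slot whose estimator is None, so that
    # np.bincount / np.average see exactly one weight per fitted estimator.
    weights_not_none = [w for est, w in zip(estimators, weights)
                        if est[1] is not None]
    print(weights_not_none)  # [1, 0.5]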
diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py
index 0361e109015ff2a918222eacb1416db4381c1df0..9377c8e2fd7aaace8210c16276b442d692ea475d 100644
--- a/sklearn/pipeline.py
+++ b/sklearn/pipeline.py
@@ -10,6 +10,7 @@ estimator, as a chain of transforms and estimators.
 # License: BSD
 
 from collections import defaultdict
+
 from abc import ABCMeta, abstractmethod
 
 import numpy as np
@@ -22,68 +23,12 @@ from .utils import tosequence
 from .utils.metaestimators import if_delegate_has_method
 from .utils import Bunch
 
-__all__ = ['Pipeline', 'FeatureUnion']
-
+from .utils.metaestimators import _BaseComposition
 
-class _BasePipeline(six.with_metaclass(ABCMeta, BaseEstimator)):
-    """Handles parameter management for classifiers composed of named steps.
-    """
+__all__ = ['Pipeline', 'FeatureUnion']
 
-    @abstractmethod
-    def __init__(self):
-        pass
-
-    def _replace_step(self, steps_attr, name, new_val):
-        # assumes `name` is a valid step name
-        new_steps = getattr(self, steps_attr)[:]
-        for i, (step_name, _) in enumerate(new_steps):
-            if step_name == name:
-                new_steps[i] = (name, new_val)
-                break
-        setattr(self, steps_attr, new_steps)
-
-    def _get_params(self, steps_attr, deep=True):
-        out = super(_BasePipeline, self).get_params(deep=False)
-        if not deep:
-            return out
-        steps = getattr(self, steps_attr)
-        out.update(steps)
-        for name, estimator in steps:
-            if estimator is None:
-                continue
-            for key, value in six.iteritems(estimator.get_params(deep=True)):
-                out['%s__%s' % (name, key)] = value
-        return out
-
-    def _set_params(self, steps_attr, **params):
-        # Ensure strict ordering of parameter setting:
-        # 1. All steps
-        if steps_attr in params:
-            setattr(self, steps_attr, params.pop(steps_attr))
-        # 2. Step replacement
-        step_names, _ = zip(*getattr(self, steps_attr))
-        for name in list(six.iterkeys(params)):
-            if '__' not in name and name in step_names:
-                self._replace_step(steps_attr, name, params.pop(name))
-        # 3. Step parameters and other initilisation arguments
-        super(_BasePipeline, self).set_params(**params)
-        return self
-
-    def _validate_names(self, names):
-        if len(set(names)) != len(names):
-            raise ValueError('Names provided are not unique: '
-                             '{0!r}'.format(list(names)))
-        invalid_names = set(names).intersection(self.get_params(deep=False))
-        if invalid_names:
-            raise ValueError('Step names conflict with constructor arguments: '
-                             '{0!r}'.format(sorted(invalid_names)))
-        invalid_names = [name for name in names if '__' in name]
-        if invalid_names:
-            raise ValueError('Step names must not contain __: got '
-                             '{0!r}'.format(invalid_names))
-
-
-class Pipeline(_BasePipeline):
+class Pipeline(_BaseComposition):
     """Pipeline of transforms with a final estimator.
 
     Sequentially apply a list of transforms and a final estimator.
@@ -631,7 +576,7 @@ def _fit_transform_one(transformer, weight, X, y,
         return res * weight, transformer
 
 
-class FeatureUnion(_BasePipeline, TransformerMixin):
+class FeatureUnion(_BaseComposition, TransformerMixin):
     """Concatenates results of multiple transformer objects.
 
     This estimator applies a list of transformer objects in parallel to the
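``Pipeline`` keeps its existing behaviour through the relocated base class. A sketch (not part of the patch, assuming ``Pipeline``'s pre-existing support for ``None`` steps) of replacing and disabling steps by name::

    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.svm import SVC

    pipe = Pipeline([('scale', StandardScaler()), ('clf', SVC())])
    pipe.set_params(clf=LogisticRegression())   # replace a step by name
    pipe.set_params(scale=None)                 # None disables the step
    print([name for name, step in pipe.steps])  # ['scale', 'clf']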
diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py
index d4c4844fe375de21ab4ab65f2822a89191d0c572..a7c8e4593420ff290d87aae6085a69c7a0bf4a39 100644
--- a/sklearn/tests/test_pipeline.py
+++ b/sklearn/tests/test_pipeline.py
@@ -811,9 +811,9 @@ def test_step_name_validation():
         # we validate in construction (despite scikit-learn convention)
         bad_steps3 = [('a', Mult(2)), (param, Mult(3))]
         for bad_steps, message in [
-            (bad_steps1, "Step names must not contain __: got ['a__q']"),
+            (bad_steps1, "Estimator names must not contain __: got ['a__q']"),
             (bad_steps2, "Names provided are not unique: ['a', 'a']"),
-            (bad_steps3, "Step names conflict with constructor "
+            (bad_steps3, "Estimator names conflict with constructor "
                          "arguments: ['%s']" % param),
         ]:
             # three ways to make invalid:
diff --git a/sklearn/utils/metaestimators.py b/sklearn/utils/metaestimators.py
index 3123bb1778ce36d1bde905326d7be883d978a5c6..fcbbbf894b76e8b6e05b13bdc754e13499a8e33b 100644
--- a/sklearn/utils/metaestimators.py
+++ b/sklearn/utils/metaestimators.py
@@ -3,14 +3,75 @@
 # Andreas Mueller
 # License: BSD
 
+from abc import ABCMeta, abstractmethod
 from operator import attrgetter
 from functools import update_wrapper
 
 import numpy as np
+
 from ..utils import safe_indexing
+from ..externals import six
+from ..base import BaseEstimator
 
 __all__ = ['if_delegate_has_method']
 
 
+class _BaseComposition(six.with_metaclass(ABCMeta, BaseEstimator)):
+    """Handles parameter management for classifiers composed of named estimators.
+    """
+    @abstractmethod
+    def __init__(self):
+        pass
+
+    def _get_params(self, attr, deep=True):
+        out = super(_BaseComposition, self).get_params(deep=False)
+        if not deep:
+            return out
+        estimators = getattr(self, attr)
+        out.update(estimators)
+        for name, estimator in estimators:
+            if estimator is None:
+                continue
+            for key, value in six.iteritems(estimator.get_params(deep=True)):
+                out['%s__%s' % (name, key)] = value
+        return out
+
+    def _set_params(self, attr, **params):
+        # Ensure strict ordering of parameter setting:
+        # 1. All steps
+        if attr in params:
+            setattr(self, attr, params.pop(attr))
+        # 2. Step replacement
+        names, _ = zip(*getattr(self, attr))
+        for name in list(six.iterkeys(params)):
+            if '__' not in name and name in names:
+                self._replace_estimator(attr, name, params.pop(name))
+        # 3. Step parameters and other initialisation arguments
+        super(_BaseComposition, self).set_params(**params)
+        return self
+
+    def _replace_estimator(self, attr, name, new_val):
+        # assumes `name` is a valid estimator name
+        new_estimators = getattr(self, attr)[:]
+        for i, (estimator_name, _) in enumerate(new_estimators):
+            if estimator_name == name:
+                new_estimators[i] = (name, new_val)
+                break
+        setattr(self, attr, new_estimators)
+
+    def _validate_names(self, names):
+        if len(set(names)) != len(names):
+            raise ValueError('Names provided are not unique: '
+                             '{0!r}'.format(list(names)))
+        invalid_names = set(names).intersection(self.get_params(deep=False))
+        if invalid_names:
+            raise ValueError('Estimator names conflict with constructor '
+                             'arguments: {0!r}'.format(sorted(invalid_names)))
+        invalid_names = [name for name in names if '__' in name]
+        if invalid_names:
+            raise ValueError('Estimator names must not contain __: got '
+                             '{0!r}'.format(invalid_names))
+
+
 class _IffHasAttrDescriptor(object):
     """Implements a conditional property using the descriptor protocol.
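A hypothetical downstream use of the class added above (``NamedEnsemble`` is invented for illustration; ``_BaseComposition`` is private API and subject to change)::

    from sklearn.utils.metaestimators import _BaseComposition

    class NamedEnsemble(_BaseComposition):
        def __init__(self, estimators):
            self.estimators = estimators

        def get_params(self, deep=True):
            # exposes 'estimators', each named estimator, and every nested
            # parameter under the 'name__param' convention
            return self._get_params('estimators', deep=deep)

        def set_params(self, **params):
            # supports est.set_params(lr=None) and est.set_params(lr__C=1.0)
            self._set_params('estimators', **params)
            return self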