diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index ce72f193ed8ddf1a4417c6f03631dfa9d839c867..816471cb5232f8ffcc4b93c6ce9f5697c9037f9c 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -163,6 +163,13 @@ Enhancements
 
    - In :class:`gaussian_process.GaussianProcessRegressor`, method ``predict`` 
      is a lot faster with ``return_std=True`` by :user:`Hadrien Bertrand <hbertrand>`.
+   - Added ability to use sparse matrices in :func:`feature_selection.f_regression`
+     with ``center=True``. :issue:`8065` by :user:`Daniel LeJeune <acadiansith>`.
+
+   - :class:`ensemble.VotingClassifier` now allow changing estimators by using
+     :meth:`ensemble.VotingClassifier.set_params`. Estimators can also be
+     removed by setting it to `None`.
+     :issue:`7674` by:user:`Yichuan Liu <yl565>`.
 
 Bug fixes
 .........
diff --git a/sklearn/ensemble/tests/test_voting_classifier.py b/sklearn/ensemble/tests/test_voting_classifier.py
index 2ad007741940c6b423cc13f6bf135239fba62e1b..d61d8bfac62bee6d42fb2c98b2ce4e527907670f 100644
--- a/sklearn/ensemble/tests/test_voting_classifier.py
+++ b/sklearn/ensemble/tests/test_voting_classifier.py
@@ -2,7 +2,7 @@
 
 import numpy as np
 from sklearn.utils.testing import assert_almost_equal, assert_array_equal
-from sklearn.utils.testing import assert_equal
+from sklearn.utils.testing import assert_equal, assert_true, assert_false
 from sklearn.utils.testing import assert_raise_message
 from sklearn.exceptions import NotFittedError
 from sklearn.linear_model import LogisticRegression
@@ -40,6 +40,19 @@ def test_estimator_init():
            '; got 2 weights, 1 estimators')
     assert_raise_message(ValueError, msg, eclf.fit, X, y)
 
+    eclf = VotingClassifier(estimators=[('lr', clf), ('lr', clf)],
+                            weights=[1, 2])
+    msg = "Names provided are not unique: ['lr', 'lr']"
+    assert_raise_message(ValueError, msg, eclf.fit, X, y)
+
+    eclf = VotingClassifier(estimators=[('lr__', clf)])
+    msg = "Estimator names must not contain __: got ['lr__']"
+    assert_raise_message(ValueError, msg, eclf.fit, X, y)
+
+    eclf = VotingClassifier(estimators=[('estimators', clf)])
+    msg = "Estimator names conflict with constructor arguments: ['estimators']"
+    assert_raise_message(ValueError, msg, eclf.fit, X, y)
+
 
 def test_predictproba_hardvoting():
     eclf = VotingClassifier(estimators=[('lr1', LogisticRegression()),
@@ -260,6 +273,82 @@ def test_sample_weight():
     assert_raise_message(ValueError, msg, eclf3.fit, X, y, sample_weight)
 
 
+def test_set_params():
+    """set_params should be able to set estimators"""
+    clf1 = LogisticRegression(random_state=123, C=1.0)
+    clf2 = RandomForestClassifier(random_state=123, max_depth=None)
+    clf3 = GaussianNB()
+    eclf1 = VotingClassifier([('lr', clf1), ('rf', clf2)], voting='soft',
+                             weights=[1, 2])
+    eclf1.fit(X, y)
+    eclf2 = VotingClassifier([('lr', clf1), ('nb', clf3)], voting='soft',
+                             weights=[1, 2])
+    eclf2.set_params(nb=clf2).fit(X, y)
+    assert_false(hasattr(eclf2, 'nb'))
+
+    assert_array_equal(eclf1.predict(X), eclf2.predict(X))
+    assert_array_equal(eclf1.predict_proba(X), eclf2.predict_proba(X))
+    assert_equal(eclf2.estimators[0][1].get_params(), clf1.get_params())
+    assert_equal(eclf2.estimators[1][1].get_params(), clf2.get_params())
+
+    eclf1.set_params(lr__C=10.0)
+    eclf2.set_params(nb__max_depth=5)
+
+    assert_true(eclf1.estimators[0][1].get_params()['C'] == 10.0)
+    assert_true(eclf2.estimators[1][1].get_params()['max_depth'] == 5)
+    assert_equal(eclf1.get_params()["lr__C"],
+                 eclf1.get_params()["lr"].get_params()['C'])
+
+
+def test_set_estimator_none():
+    """VotingClassifier set_params should be able to set estimators as None"""
+    # Test predict
+    clf1 = LogisticRegression(random_state=123)
+    clf2 = RandomForestClassifier(random_state=123)
+    clf3 = GaussianNB()
+    eclf1 = VotingClassifier(estimators=[('lr', clf1), ('rf', clf2),
+                                         ('nb', clf3)],
+                             voting='hard', weights=[1, 0, 0.5]).fit(X, y)
+
+    eclf2 = VotingClassifier(estimators=[('lr', clf1), ('rf', clf2),
+                                         ('nb', clf3)],
+                             voting='hard', weights=[1, 1, 0.5])
+    eclf2.set_params(rf=None).fit(X, y)
+    assert_array_equal(eclf1.predict(X), eclf2.predict(X))
+
+    assert_true(dict(eclf2.estimators)["rf"] is None)
+    assert_true(len(eclf2.estimators_) == 2)
+    assert_true(all([not isinstance(est, RandomForestClassifier) for est in
+                     eclf2.estimators_]))
+    assert_true(eclf2.get_params()["rf"] is None)
+
+    eclf1.set_params(voting='soft').fit(X, y)
+    eclf2.set_params(voting='soft').fit(X, y)
+    assert_array_equal(eclf1.predict(X), eclf2.predict(X))
+    assert_array_equal(eclf1.predict_proba(X), eclf2.predict_proba(X))
+    msg = ('All estimators are None. At least one is required'
+           ' to be a classifier!')
+    assert_raise_message(
+        ValueError, msg, eclf2.set_params(lr=None, rf=None, nb=None).fit, X, y)
+
+    # Test soft voting transform
+    X1 = np.array([[1], [2]])
+    y1 = np.array([1, 2])
+    eclf1 = VotingClassifier(estimators=[('rf', clf2), ('nb', clf3)],
+                             voting='soft', weights=[0, 0.5]).fit(X1, y1)
+
+    eclf2 = VotingClassifier(estimators=[('rf', clf2), ('nb', clf3)],
+                             voting='soft', weights=[1, 0.5])
+    eclf2.set_params(rf=None).fit(X1, y1)
+    assert_array_equal(eclf1.transform(X1), np.array([[[0.7, 0.3], [0.3, 0.7]],
+                                                      [[1., 0.], [0., 1.]]]))
+    assert_array_equal(eclf2.transform(X1), np.array([[[1., 0.], [0., 1.]]]))
+    eclf1.set_params(voting='hard')
+    eclf2.set_params(voting='hard')
+    assert_array_equal(eclf1.transform(X1), np.array([[0, 0], [1, 1]]))
+    assert_array_equal(eclf2.transform(X1), np.array([[0], [1]]))
+
+
 def test_estimator_weights_format():
     # Test estimator weights inputs as list and array
     clf1 = LogisticRegression(random_state=123)
diff --git a/sklearn/ensemble/voting_classifier.py b/sklearn/ensemble/voting_classifier.py
index cb0d6ad19c983b302a8c07b0cdd671d7fe865a3d..44cf4fe775ce3c385e484f55170a23452ae91d10 100644
--- a/sklearn/ensemble/voting_classifier.py
+++ b/sklearn/ensemble/voting_classifier.py
@@ -13,14 +13,13 @@ classification estimators.
 
 import numpy as np
 
-from ..base import BaseEstimator
 from ..base import ClassifierMixin
 from ..base import TransformerMixin
 from ..base import clone
 from ..preprocessing import LabelEncoder
-from ..externals import six
 from ..externals.joblib import Parallel, delayed
 from ..utils.validation import has_fit_parameter, check_is_fitted
+from ..utils.metaestimators import _BaseComposition
 
 
 def _parallel_fit_estimator(estimator, X, y, sample_weight):
@@ -32,7 +31,7 @@ def _parallel_fit_estimator(estimator, X, y, sample_weight):
     return estimator
 
 
-class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
+class VotingClassifier(_BaseComposition, ClassifierMixin, TransformerMixin):
     """Soft Voting/Majority Rule classifier for unfitted estimators.
 
     .. versionadded:: 0.17
@@ -44,7 +43,8 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
     estimators : list of (string, estimator) tuples
         Invoking the ``fit`` method on the ``VotingClassifier`` will fit clones
         of those original estimators that will be stored in the class attribute
-        `self.estimators_`.
+        ``self.estimators_``. An estimator can be set to `None` using
+        ``set_params``.
 
     voting : str, {'hard', 'soft'} (default='hard')
         If 'hard', uses predicted class labels for majority rule voting.
@@ -64,7 +64,8 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
     Attributes
     ----------
     estimators_ : list of classifiers
-        The collection of fitted sub-estimators.
+        The collection of fitted sub-estimators as defined in ``estimators``
+        that are not `None`.
 
     classes_ : array-like, shape = [n_predictions]
         The classes labels.
@@ -102,11 +103,14 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
 
     def __init__(self, estimators, voting='hard', weights=None, n_jobs=1):
         self.estimators = estimators
-        self.named_estimators = dict(estimators)
         self.voting = voting
         self.weights = weights
         self.n_jobs = n_jobs
 
+    @property
+    def named_estimators(self):
+        return dict(self.estimators)
+
     def fit(self, X, y, sample_weight=None):
         """ Fit the estimators.
 
@@ -150,11 +154,16 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
         if sample_weight is not None:
             for name, step in self.estimators:
                 if not has_fit_parameter(step, 'sample_weight'):
-                    raise ValueError('Underlying estimator \'%s\' does not support'
-                                     ' sample weights.' % name)
-
-        self.le_ = LabelEncoder()
-        self.le_.fit(y)
+                    raise ValueError('Underlying estimator \'%s\' does not'
+                                     ' support sample weights.' % name)
+        names, clfs = zip(*self.estimators)
+        self._validate_names(names)
+
+        n_isnone = np.sum([clf is None for _, clf in self.estimators])
+        if n_isnone == len(self.estimators):
+            raise ValueError('All estimators are None. At least one is '
+                             'required to be a classifier!')
+        self.le_ = LabelEncoder().fit(y)
         self.classes_ = self.le_.classes_
         self.estimators_ = []
 
@@ -162,11 +171,19 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
 
         self.estimators_ = Parallel(n_jobs=self.n_jobs)(
                 delayed(_parallel_fit_estimator)(clone(clf), X, transformed_y,
-                    sample_weight)
-                    for _, clf in self.estimators)
+                                                 sample_weight)
+                for clf in clfs if clf is not None)
 
         return self
 
+    @property
+    def _weights_not_none(self):
+        """Get the weights of not `None` estimators"""
+        if self.weights is None:
+            return None
+        return [w for est, w in zip(self.estimators,
+                                    self.weights) if est[1] is not None]
+
     def predict(self, X):
         """ Predict class labels for X.
 
@@ -188,11 +205,10 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
 
         else:  # 'hard' voting
             predictions = self._predict(X)
-            maj = np.apply_along_axis(lambda x:
-                                      np.argmax(np.bincount(x,
-                                                weights=self.weights)),
-                                      axis=1,
-                                      arr=predictions.astype('int'))
+            maj = np.apply_along_axis(
+                lambda x: np.argmax(
+                    np.bincount(x, weights=self._weights_not_none)),
+                axis=1, arr=predictions.astype('int'))
 
         maj = self.le_.inverse_transform(maj)
 
@@ -208,7 +224,8 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
             raise AttributeError("predict_proba is not available when"
                                  " voting=%r" % self.voting)
         check_is_fitted(self, 'estimators_')
-        avg = np.average(self._collect_probas(X), axis=0, weights=self.weights)
+        avg = np.average(self._collect_probas(X), axis=0,
+                         weights=self._weights_not_none)
         return avg
 
     @property
@@ -252,17 +269,42 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
         else:
             return self._predict(X)
 
+    def set_params(self, **params):
+        """ Setting the parameters for the voting classifier
+
+        Valid parameter keys can be listed with get_params().
+
+        Parameters
+        ----------
+        params: keyword arguments
+            Specific parameters using e.g. set_params(parameter_name=new_value)
+            In addition, to setting the parameters of the ``VotingClassifier``,
+            the individual classifiers of the ``VotingClassifier`` can also be
+            set or replaced by setting them to None.
+
+        Examples
+        --------
+        # In this example, the RandomForestClassifier is removed
+        clf1 = LogisticRegression()
+        clf2 = RandomForestClassifier()
+        eclf = VotingClassifier(estimators=[('lr', clf1), ('rf', clf2)]
+        eclf.set_params(rf=None)
+
+        """
+        super(VotingClassifier, self)._set_params('estimators', **params)
+        return self
+
     def get_params(self, deep=True):
-        """Return estimator parameter names for GridSearch support"""
-        if not deep:
-            return super(VotingClassifier, self).get_params(deep=False)
-        else:
-            out = super(VotingClassifier, self).get_params(deep=False)
-            out.update(self.named_estimators.copy())
-            for name, step in six.iteritems(self.named_estimators):
-                for key, value in six.iteritems(step.get_params(deep=True)):
-                    out['%s__%s' % (name, key)] = value
-            return out
+        """ Get the parameters of the VotingClassifier
+
+        Parameters
+        ----------
+        deep: bool
+            Setting it to True gets the various classifiers and the parameters
+            of the classifiers as well
+        """
+        return super(VotingClassifier,
+                     self)._get_params('estimators', deep=deep)
 
     def _predict(self, X):
         """Collect results from clf.predict calls. """
diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py
index 0361e109015ff2a918222eacb1416db4381c1df0..9377c8e2fd7aaace8210c16276b442d692ea475d 100644
--- a/sklearn/pipeline.py
+++ b/sklearn/pipeline.py
@@ -10,6 +10,7 @@ estimator, as a chain of transforms and estimators.
 # License: BSD
 
 from collections import defaultdict
+
 from abc import ABCMeta, abstractmethod
 
 import numpy as np
@@ -22,68 +23,12 @@ from .utils import tosequence
 from .utils.metaestimators import if_delegate_has_method
 from .utils import Bunch
 
-__all__ = ['Pipeline', 'FeatureUnion']
-
+from .utils.metaestimators import _BaseComposition
 
-class _BasePipeline(six.with_metaclass(ABCMeta, BaseEstimator)):
-    """Handles parameter management for classifiers composed of named steps.
-    """
+__all__ = ['Pipeline', 'FeatureUnion']
 
-    @abstractmethod
-    def __init__(self):
-        pass
-
-    def _replace_step(self, steps_attr, name, new_val):
-        # assumes `name` is a valid step name
-        new_steps = getattr(self, steps_attr)[:]
-        for i, (step_name, _) in enumerate(new_steps):
-            if step_name == name:
-                new_steps[i] = (name, new_val)
-                break
-        setattr(self, steps_attr, new_steps)
-
-    def _get_params(self, steps_attr, deep=True):
-        out = super(_BasePipeline, self).get_params(deep=False)
-        if not deep:
-            return out
-        steps = getattr(self, steps_attr)
-        out.update(steps)
-        for name, estimator in steps:
-            if estimator is None:
-                continue
-            for key, value in six.iteritems(estimator.get_params(deep=True)):
-                out['%s__%s' % (name, key)] = value
-        return out
-
-    def _set_params(self, steps_attr, **params):
-        # Ensure strict ordering of parameter setting:
-        # 1. All steps
-        if steps_attr in params:
-            setattr(self, steps_attr, params.pop(steps_attr))
-        # 2. Step replacement
-        step_names, _ = zip(*getattr(self, steps_attr))
-        for name in list(six.iterkeys(params)):
-            if '__' not in name and name in step_names:
-                self._replace_step(steps_attr, name, params.pop(name))
-        # 3. Step parameters and other initilisation arguments
-        super(_BasePipeline, self).set_params(**params)
-        return self
 
-    def _validate_names(self, names):
-        if len(set(names)) != len(names):
-            raise ValueError('Names provided are not unique: '
-                             '{0!r}'.format(list(names)))
-        invalid_names = set(names).intersection(self.get_params(deep=False))
-        if invalid_names:
-            raise ValueError('Step names conflict with constructor arguments: '
-                             '{0!r}'.format(sorted(invalid_names)))
-        invalid_names = [name for name in names if '__' in name]
-        if invalid_names:
-            raise ValueError('Step names must not contain __: got '
-                             '{0!r}'.format(invalid_names))
-
-
-class Pipeline(_BasePipeline):
+class Pipeline(_BaseComposition):
     """Pipeline of transforms with a final estimator.
 
     Sequentially apply a list of transforms and a final estimator.
@@ -631,7 +576,7 @@ def _fit_transform_one(transformer, weight, X, y,
     return res * weight, transformer
 
 
-class FeatureUnion(_BasePipeline, TransformerMixin):
+class FeatureUnion(_BaseComposition, TransformerMixin):
     """Concatenates results of multiple transformer objects.
 
     This estimator applies a list of transformer objects in parallel to the
diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py
index d4c4844fe375de21ab4ab65f2822a89191d0c572..a7c8e4593420ff290d87aae6085a69c7a0bf4a39 100644
--- a/sklearn/tests/test_pipeline.py
+++ b/sklearn/tests/test_pipeline.py
@@ -811,9 +811,9 @@ def test_step_name_validation():
         # we validate in construction (despite scikit-learn convention)
         bad_steps3 = [('a', Mult(2)), (param, Mult(3))]
         for bad_steps, message in [
-            (bad_steps1, "Step names must not contain __: got ['a__q']"),
+            (bad_steps1, "Estimator names must not contain __: got ['a__q']"),
             (bad_steps2, "Names provided are not unique: ['a', 'a']"),
-            (bad_steps3, "Step names conflict with constructor "
+            (bad_steps3, "Estimator names conflict with constructor "
                          "arguments: ['%s']" % param),
         ]:
             # three ways to make invalid:
diff --git a/sklearn/utils/metaestimators.py b/sklearn/utils/metaestimators.py
index 3123bb1778ce36d1bde905326d7be883d978a5c6..fcbbbf894b76e8b6e05b13bdc754e13499a8e33b 100644
--- a/sklearn/utils/metaestimators.py
+++ b/sklearn/utils/metaestimators.py
@@ -3,14 +3,75 @@
 #         Andreas Mueller
 # License: BSD
 
+from abc import ABCMeta, abstractmethod
 from operator import attrgetter
 from functools import update_wrapper
 import numpy as np
+
 from ..utils import safe_indexing
+from ..externals import six
+from ..base import BaseEstimator
 
 __all__ = ['if_delegate_has_method']
 
 
+class _BaseComposition(six.with_metaclass(ABCMeta, BaseEstimator)):
+    """Handles parameter management for classifiers composed of named estimators.
+    """
+    @abstractmethod
+    def __init__(self):
+        pass
+
+    def _get_params(self, attr, deep=True):
+        out = super(_BaseComposition, self).get_params(deep=False)
+        if not deep:
+            return out
+        estimators = getattr(self, attr)
+        out.update(estimators)
+        for name, estimator in estimators:
+            if estimator is None:
+                continue
+            for key, value in six.iteritems(estimator.get_params(deep=True)):
+                out['%s__%s' % (name, key)] = value
+        return out
+
+    def _set_params(self, attr, **params):
+        # Ensure strict ordering of parameter setting:
+        # 1. All steps
+        if attr in params:
+            setattr(self, attr, params.pop(attr))
+        # 2. Step replacement
+        names, _ = zip(*getattr(self, attr))
+        for name in list(six.iterkeys(params)):
+            if '__' not in name and name in names:
+                self._replace_estimator(attr, name, params.pop(name))
+        # 3. Step parameters and other initilisation arguments
+        super(_BaseComposition, self).set_params(**params)
+        return self
+
+    def _replace_estimator(self, attr, name, new_val):
+        # assumes `name` is a valid estimator name
+        new_estimators = getattr(self, attr)[:]
+        for i, (estimator_name, _) in enumerate(new_estimators):
+            if estimator_name == name:
+                new_estimators[i] = (name, new_val)
+                break
+        setattr(self, attr, new_estimators)
+
+    def _validate_names(self, names):
+        if len(set(names)) != len(names):
+            raise ValueError('Names provided are not unique: '
+                             '{0!r}'.format(list(names)))
+        invalid_names = set(names).intersection(self.get_params(deep=False))
+        if invalid_names:
+            raise ValueError('Estimator names conflict with constructor '
+                             'arguments: {0!r}'.format(sorted(invalid_names)))
+        invalid_names = [name for name in names if '__' in name]
+        if invalid_names:
+            raise ValueError('Estimator names must not contain __: got '
+                             '{0!r}'.format(invalid_names))
+
+
 class _IffHasAttrDescriptor(object):
     """Implements a conditional property using the descriptor protocol.