diff --git a/doc/developers/index.rst b/doc/developers/index.rst
index 23dd7dfff623c33a10bfbf40e441c0c96551f710..9dd0d9b7bcc77bf9e5e96b87cf4e584442b68f1b 100644
--- a/doc/developers/index.rst
+++ b/doc/developers/index.rst
@@ -328,6 +328,11 @@ classifier or a regressor. All estimators implement the fit method::
 
     estimator.fit(X, y)
 
+All built-in estimators also have a ``set_params`` method, which sets
+data-independent parameters (overriding previous parameter values passed
+to ``__init__``). This method is not required for an object to be an
+estimator, but is used by the :class:`grid_search.GridSearchCV` class.
+
 Instantiation
 ^^^^^^^^^^^^^
 
diff --git a/examples/svm/plot_svm_anova.py b/examples/svm/plot_svm_anova.py
index 6187afe801218d78b39c388146e0e357cf133f22..100cd0593f67c024cf957690befb2263e8068bb1 100644
--- a/examples/svm/plot_svm_anova.py
+++ b/examples/svm/plot_svm_anova.py
@@ -40,7 +40,7 @@ score_stds = list()
 percentiles = (1, 3, 6, 10, 15, 20, 30, 40, 60, 80, 100)
 
 for percentile in percentiles:
-    clf._set_params(anova__percentile=percentile)
+    clf.set_params(anova__percentile=percentile)
     # Compute cross-validation score using all CPUs
     this_scores = cross_val.cross_val_score(clf, X, y, n_jobs=1)
     score_means.append(this_scores.mean())
diff --git a/scikits/learn/base.py b/scikits/learn/base.py
index 255b995240c7576e09206e68f36e5a7cd8e2cfaf..944fa82209129dcad2ab345af765a794d24c2458 100644
--- a/scikits/learn/base.py
+++ b/scikits/learn/base.py
@@ -6,8 +6,9 @@
 import copy
 import inspect
+import warnings
 import numpy as np
 from scipy import sparse
 
 from .metrics import r2_score
 
 ###############################################################################
@@ -172,13 +173,17 @@ class BaseEstimator(object):
             out[key] = value
         return out
 
-    def _set_params(self, **params):
+    def set_params(self, **params):
        """ Set the parameters of the estimator.
 
         The method works on simple estimators as well as on nested
         objects (such as pipelines). The former have parameters of the
-        form <component>__<parameter> so that the its possible to
-        update each component of the nested object.
+        form <component>__<parameter> so that it's possible to update
+        each component of a nested object.
+
+        Returns
+        -------
+        self
         """
         if not params:
             # Simple optimisation to gain speed (inspect is slow)
@@ -198,7 +203,7 @@
                     'sub parameter %s' %
                     (sub_name, self.__class__.__name__, sub_name)
                 )
-                sub_object._set_params(**{sub_name: value})
+                sub_object.set_params(**{sub_name: value})
             else:
                 # simple objects case
                 assert key in valid_params, ('Invalid parameter %s '
@@ -207,6 +212,14 @@
                 setattr(self, key, value)
         return self
 
+    def _set_params(self, **params):
+        """Deprecated alias for set_params; kept for backward compatibility."""
+        if params:
+            warnings.warn("Passing estimator parameters to fit is deprecated;"
+                          " use set_params instead",
+                          category=DeprecationWarning)
+        return self.set_params(**params)
+
     def __repr__(self):
         class_name = self.__class__.__name__
         return '%s(%s)' % (
diff --git a/scikits/learn/cluster/dbscan_.py b/scikits/learn/cluster/dbscan_.py
index 7ba2c0cc680a376aff384ac3dc2513b83299a09f..00b68b2d63f2b8ee8f42287c7820d75023704b32 100644
--- a/scikits/learn/cluster/dbscan_.py
+++ b/scikits/learn/cluster/dbscan_.py
@@ -188,7 +188,7 @@ class DBSCAN(BaseEstimator):
             Overwrite keywords from __init__.
""" - self._set_params(**params) + self.set_params(**params) self.core_sample_indices_, self.labels_ = dbscan(X, **self._get_params()) self.components_ = X[self.core_sample_indices_].copy() diff --git a/scikits/learn/cluster/k_means_.py b/scikits/learn/cluster/k_means_.py index b91cee3329240ed9cfdd6a0941a2e4236f13dbe4..cc54f0e9222ebbc21c726b94627a5da1b9e66b85 100644 --- a/scikits/learn/cluster/k_means_.py +++ b/scikits/learn/cluster/k_means_.py @@ -482,25 +482,25 @@ class KMeans(BaseEstimator): self.random_state = random_state self.copy_x = copy_x - def _check_data(self, X, **params): - """ - Set parameters and check the sample given is larger than k - """ + def _check_data(self, X): + """Verify that the number of samples given is larger than k""" if sp.issparse(X): raise ValueError("K-Means does not support sparse input matrices.") X = np.asanyarray(X) if X.shape[0] < self.k: raise ValueError("n_samples=%d should be larger than k=%d" % ( X.shape[0], self.k)) - self._set_params(**params) return X - def fit(self, X, **params): + def fit(self, X, k=None, **params): """Compute k-means""" self.random_state = check_random_state(self.random_state) X = self._check_data(X, **params) + if k != None: + self.k = k + self._set_params(**params) self.cluster_centers_, self.labels_, self.inertia_ = k_means( X, k=self.k, init=self.init, n_init=self.n_init, @@ -647,7 +647,7 @@ class MiniBatchKMeans(KMeans): self.cluster_centers_ = None self.chunk_size = chunk_size - def fit(self, X, y=None, **params): + def fit(self, X, y=None): """ Calculates the centroids on a batch X @@ -656,7 +656,6 @@ class MiniBatchKMeans(KMeans): X: array-like, shape = [n_samples, n_features] Coordinates of the data points to cluster """ - self._set_params(**params) self.random_state = check_random_state(self.random_state) X = check_arrays(X, sparse_format="csr", copy=False)[0] n_samples, n_features = X.shape @@ -709,7 +708,7 @@ class MiniBatchKMeans(KMeans): return self - def partial_fit(self, X, y=None, **params): + def partial_fit(self, X, y=None): """Update k means estimate on a single mini-batch X. 
 
         Parameters
diff --git a/scikits/learn/cluster/tests/test_dbscan.py b/scikits/learn/cluster/tests/test_dbscan.py
index 8b05d98f3cd9b81f20efa4fbf7e8d71965009f55..7c36e710e7cd2bd5928e7bb75131a078f04ff846 100644
--- a/scikits/learn/cluster/tests/test_dbscan.py
+++ b/scikits/learn/cluster/tests/test_dbscan.py
@@ -55,9 +55,8 @@ def test_dbscan_feature():
     n_clusters_1 = len(set(labels)) - int(-1 in labels)
     assert_equal(n_clusters_1, n_clusters)
 
-    db = DBSCAN()
-    labels = db.fit(X, metric=metric,
-                    eps=eps, min_samples=min_samples).labels_
+    db = DBSCAN(metric=metric)
+    labels = db.fit(X, eps=eps, min_samples=min_samples).labels_
 
     n_clusters_2 = len(set(labels)) - int(-1 in labels)
     assert_equal(n_clusters_2, n_clusters)
@@ -80,9 +79,8 @@ def test_dbscan_callable():
     n_clusters_1 = len(set(labels)) - int(-1 in labels)
     assert_equal(n_clusters_1, n_clusters)
 
-    db = DBSCAN()
-    labels = db.fit(X, metric=metric,
-                    eps=eps, min_samples=min_samples).labels_
+    db = DBSCAN(metric=metric)
+    labels = db.fit(X, eps=eps, min_samples=min_samples).labels_
 
     n_clusters_2 = len(set(labels)) - int(-1 in labels)
     assert_equal(n_clusters_2, n_clusters)
diff --git a/scikits/learn/feature_selection/univariate_selection.py b/scikits/learn/feature_selection/univariate_selection.py
index 534e15398a85c33f7cf63df630a4a62db71f9e67..8306cc651b8e33611910bf735aebd77e673c52d7 100644
--- a/scikits/learn/feature_selection/univariate_selection.py
+++ b/scikits/learn/feature_selection/univariate_selection.py
@@ -446,9 +446,9 @@ class GenericUnivariateSelect(_AbstractUnivariateFilter):
         selector = self._selection_modes[self.mode](lambda x: x)
         selector._pvalues = self._pvalues
         selector._scores = self._scores
-        # Now make some acrobaties to set the right named parameter in
+        # Now perform some acrobatics to set the right named parameter in
         # the selector
         possible_params = selector._get_param_names()
         possible_params.remove('score_func')
-        selector._set_params(**{possible_params[0]: self.param})
+        selector.set_params(**{possible_params[0]: self.param})
         return selector._get_support_mask()
diff --git a/scikits/learn/grid_search.py b/scikits/learn/grid_search.py
index 9056996b762e4ed115559092292fdc928f82f067..9f9f87bbb090ea202e66d5736fa60d965de13b38 100644
--- a/scikits/learn/grid_search.py
+++ b/scikits/learn/grid_search.py
@@ -69,7 +69,7 @@ def fit_grid_point(X, y, base_clf, clf_params, train, test, loss_func,
         print "[GridSearchCV] %s %s" % (msg, (64 - len(msg)) * '.')
     # update parameters of the classifier after a copy of its base structure
     clf = copy.deepcopy(base_clf)
-    clf._set_params(**clf_params)
+    clf.set_params(**clf_params)
 
     if isinstance(X, list) or isinstance(X, tuple):
         X_train = [X[i] for i, cond in enumerate(train) if cond]
diff --git a/scikits/learn/tests/test_base.py b/scikits/learn/tests/test_base.py
index 53a422f5cbc0cb4a62160b2472495c42ca917017..4d1fb4ab572b3b7e9797b2f64c0de3fb91a42375 100644
--- a/scikits/learn/tests/test_base.py
+++ b/scikits/learn/tests/test_base.py
@@ -93,9 +93,9 @@ def test_get_params():
     assert_true('a__d' in test._get_params(deep=True))
     assert_true('a__d' not in test._get_params(deep=False))
 
-    test._set_params(a__d=2)
+    test.set_params(a__d=2)
     assert test.a.d == 2
-    assert_raises(AssertionError, test._set_params, a__a=2)
+    assert_raises(AssertionError, test.set_params, a__a=2)
 
 
 def test_is_classifier():
diff --git a/scikits/learn/tests/test_cross_val.py b/scikits/learn/tests/test_cross_val.py
index c849ae1328748acb4f51eddfa510a367414b0065..c130d4178825d4394899a332c0348317a77b201e 100644
--- a/scikits/learn/tests/test_cross_val.py
+++ b/scikits/learn/tests/test_cross_val.py
@@ -23,8 +23,7 @@ class MockClassifier(BaseEstimator):
     def __init__(self, a=0):
         self.a = a
 
-    def fit(self, X, Y, **params):
-        self._set_params(**params)
+    def fit(self, X, Y):
         return self
 
     def predict(self, T):
diff --git a/scikits/learn/tests/test_grid_search.py b/scikits/learn/tests/test_grid_search.py
index e52e06ba2a08263bd21bb4d61d8af4f4e954bd1d..83e18ab46766ddf0bbfc928c47db1db50312e4d0 100644
--- a/scikits/learn/tests/test_grid_search.py
+++ b/scikits/learn/tests/test_grid_search.py
@@ -21,8 +21,7 @@ class MockClassifier(BaseEstimator):
     def __init__(self, foo_param=0):
         self.foo_param = foo_param
 
-    def fit(self, X, Y, **params):
-        self._set_params(**params)
+    def fit(self, X, Y):
         return self
 
     def predict(self, T):
@@ -88,7 +87,7 @@ def test_grid_search_sparse_score_func():
     clf = LinearSVC()
     cv = GridSearchCV(clf, {'C': [0.1, 1.0]}, score_func=f1_score)
     # XXX: set refit to False due to a random bug when True (default)
-    cv.fit(X_[:180], y_[:180], refit=False)
+    cv.set_params(refit=False).fit(X_[:180], y_[:180])
     y_pred = cv.predict(X_[180:])
 
     C = cv.best_estimator.C
@@ -96,7 +95,7 @@ def test_grid_search_sparse_score_func():
     clf = SparseLinearSVC()
     cv = GridSearchCV(clf, {'C': [0.1, 1.0]}, score_func=f1_score)
     # XXX: set refit to False due to a random bug when True (default)
-    cv.fit(X_[:180], y_[:180], refit=False)
+    cv.set_params(refit=False).fit(X_[:180], y_[:180])
     y_pred2 = cv.predict(X_[180:])
 
     C2 = cv.best_estimator.C
diff --git a/scikits/learn/tests/test_naive_bayes.py b/scikits/learn/tests/test_naive_bayes.py
index 81321931d1122a7b60623148f39ff5f825746019..730b7e5866ab204078dcfd39593a0c1b189be012 100644
--- a/scikits/learn/tests/test_naive_bayes.py
+++ b/scikits/learn/tests/test_naive_bayes.py
@@ -123,7 +123,8 @@ def test_discretenb_uniform_prior():
     when fit_prior=False and class_prior=None"""
 
     for cls in [BernoulliNB, MultinomialNB]:
-        clf = cls(fit_prior=False)
+        clf = cls()
+        clf.set_params(fit_prior=False)
         clf.fit([[0], [0], [1]], [0, 0, 1])
         prior = np.exp(clf.class_log_prior_)
         assert prior[0] == prior[1]
diff --git a/scikits/learn/tests/test_pipeline.py b/scikits/learn/tests/test_pipeline.py
index 98911a44c0a6b3328795ad351510801c0120968a..42e1f7f78fa1eda79f35ac5003cf0a227548f79c 100644
--- a/scikits/learn/tests/test_pipeline.py
+++ b/scikits/learn/tests/test_pipeline.py
@@ -44,7 +44,7 @@ def test_pipeline_init():
                 dict(svc__a=None, svc__b=None, svc=clf))
 
     # Check that params are set
-    pipe._set_params(svc__a=0.1)
+    pipe.set_params(svc__a=0.1)
     assert_equal(clf.a, 0.1)
     # Smoke test the repr:
     repr(pipe)
@@ -55,13 +55,13 @@ def test_pipeline_init():
     pipe = Pipeline([('anova', filter1), ('svc', clf)])
 
     # Check that params are set
-    pipe._set_params(svc__C=0.1)
+    pipe.set_params(svc__C=0.1)
     assert_equal(clf.C, 0.1)
 
     # Smoke test the repr:
     repr(pipe)
 
     # Check that params are not set when naming them wrong
-    assert_raises(AssertionError, pipe._set_params, anova__C=0.1)
+    assert_raises(AssertionError, pipe.set_params, anova__C=0.1)
 
     # Test clone
     pipe2 = clone(pipe)
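
Usage note (illustrative sketch, not part of the patch): the renamed
``set_params`` is easiest to see on a pipeline, where the
``<component>__<parameter>`` form routes values to named steps. This sketch
assumes the anova/SVC setup from examples/svm/plot_svm_anova.py and uses the
iris data as stand-in training data::

    from scikits.learn import datasets, svm
    from scikits.learn.feature_selection import SelectPercentile, f_classif
    from scikits.learn.pipeline import Pipeline

    iris = datasets.load_iris()
    X, y = iris.data, iris.target

    # Step names ('anova', 'svc') become the parameter prefixes.
    clf = Pipeline([('anova', SelectPercentile(f_classif)),
                    ('svc', svm.SVC(kernel='linear'))])

    # Simple form: set a data-independent parameter on one component.
    clf.set_params(svc__C=0.1)

    # Nested form: 'percentile' is routed to the 'anova' step, and
    # set_params returns self, so the call chains straight into fit.
    clf.set_params(anova__percentile=10).fit(X, y)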
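A second sketch exercises the backward-compatibility shim added in base.py:
the old ``_set_params`` spelling still forwards to ``set_params``, but now
warns when parameters are passed. ``KMeans`` with its ``k`` parameter is only
a placeholder choice of estimator here::

    import warnings

    from scikits.learn.cluster import KMeans

    km = KMeans(k=3)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        km._set_params(k=8)  # deprecated spelling, forwards to set_params
    assert km.k == 8  # the parameter is still updated ...
    assert issubclass(caught[-1].category, DeprecationWarning)  # ... with a warning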