diff --git a/scikits/learn/feature_extraction/tests/test_text.py b/scikits/learn/feature_extraction/tests/test_text.py
index ce1f8e1a5a0af1daee89cdf607203f394dfa9dea..d6124d833ebf3bb6244e7e68841b09964d9b6281 100644
--- a/scikits/learn/feature_extraction/tests/test_text.py
+++ b/scikits/learn/feature_extraction/tests/test_text.py
@@ -277,10 +277,12 @@ def test_dense_vectorizer_pipeline_grid_selection():
     pred = grid_search.fit(list(train_data), y_train).predict(list(test_data))
     assert_array_equal(pred, y_test)

-    # on this toy dataset bigram representation yields higher predictive
-    # accurracy
-    # TODO: unstable test...
-    # assert_equal(grid_search.best_estimator.steps[0][1].analyzer.max_n, 2)
+    # on this toy dataset the bigram representation, which is tried last in
+    # the grid search, is selected as the best estimator since all candidates
+    # converge to 100% accuracy models
+    assert_equal(grid_search.best_score, 1.0)
+    best_vectorizer = grid_search.best_estimator.named_steps['vect']
+    assert_equal(best_vectorizer.analyzer.max_n, 2)


 def test_pickle():
diff --git a/scikits/learn/grid_search.py b/scikits/learn/grid_search.py
index 47bca81ca0964af26c5fa18f4ce258622e92fe32..1d8ca1378f4031158b704cb6aa727e883a772995 100644
--- a/scikits/learn/grid_search.py
+++ b/scikits/learn/grid_search.py
@@ -62,6 +62,7 @@ def iter_grid(param_grid):
 def fit_grid_point(X, y, base_clf, clf_params, cv, loss_func, iid,
                    **fit_params):
     """Run fit on one set of parameters

+    Returns the score and the instance of the classifier
     """
     # update parameters of the classifier after a copy of its base structure
@@ -209,7 +210,7 @@ class GridSearchCV(BaseEstimator):
                 for clf_params in grid)

         # Out is a list of pairs: score, estimator
-        best_estimator = max(out)[1] # get maximum score
+        self.best_score, best_estimator = max(out) # get maximum score

         if refit:
             # fit the best estimator using the entire dataset
diff --git a/scikits/learn/pipeline.py b/scikits/learn/pipeline.py
index 0b89ae03b8577958e8de69ac28cc7d714da2b69c..899fc5ec6d210ecd6875103c8133bfd9e64ec098 100644
--- a/scikits/learn/pipeline.py
+++ b/scikits/learn/pipeline.py
@@ -92,10 +92,10 @@ class Pipeline(BaseEstimator):
             fit/transform) that are chained, in the order in which
             they are chained, with the last object an estimator.
""" - self._named_steps = dict(steps) + self.named_steps = dict(steps) names, estimators = zip(*steps) self.steps = steps - assert len(self._named_steps) == len(steps), ("Names provided are " + assert len(self.named_steps) == len(steps), ("Names provided are " "not unique: %s" % names) transforms = estimators[:-1] estimator = estimators[-1] @@ -115,8 +115,8 @@ class Pipeline(BaseEstimator): if not deep: return super(Pipeline, self)._get_params(deep=False) else: - out = self._named_steps.copy() - for name, step in self._named_steps.iteritems(): + out = self.named_steps.copy() + for name, step in self.named_steps.iteritems(): for key, value in step._get_params(deep=True).iteritems(): out['%s__%s' % (name, key)] = value return out diff --git a/scikits/learn/tests/test_pipeline.py b/scikits/learn/tests/test_pipeline.py index 01807f26ce59eafa0b75e83d0589d40b78fc82df..9e8bdbdbdfb60a83d870b46df9d4ffe3f1f69f9e 100644 --- a/scikits/learn/tests/test_pipeline.py +++ b/scikits/learn/tests/test_pipeline.py @@ -56,7 +56,7 @@ def test_pipeline_init(): # Test clone pipe2 = clone(pipe) - assert_false(pipe._named_steps['svc'] is pipe2._named_steps['svc']) + assert_false(pipe.named_steps['svc'] is pipe2.named_steps['svc']) # Check that appart from estimators, the parameters are the same params = pipe._get_params()