Commit 78e5a01f authored by Olivier Grisel

make the pipeline / grid_search object nicer to introspect in tests

parent a30698ce
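The net effect is that the objects produced by a pipeline grid search can be introspected by name rather than by position. Below is a minimal sketch of the intended usage, assuming `grid_search` is an already fitted GridSearchCV wrapping a Pipeline whose first step was registered under the name 'vect'; only the `best_score`, `best_estimator`, `named_steps` and `steps` attributes are taken from this commit, the surrounding setup is assumed.

    # Sketch only: `grid_search` is assumed to be a fitted GridSearchCV built on a
    # Pipeline([('vect', vectorizer), ('clf', classifier)]) as in the test below.
    best_score = grid_search.best_score            # best CV score, now stored on the object
    best_pipeline = grid_search.best_estimator     # best pipeline, refit on the full dataset

    # Lookup by step name via the now-public `named_steps` dict ...
    best_vectorizer = best_pipeline.named_steps['vect']
    # ... is equivalent to the positional access used by the old test assertion.
    assert best_vectorizer is best_pipeline.steps[0][1]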
@@ -277,10 +277,12 @@ def test_dense_vectorizer_pipeline_grid_selection():
     pred = grid_search.fit(list(train_data), y_train).predict(list(test_data))
     assert_array_equal(pred, y_test)
-    # on this toy dataset bigram representation yields higher predictive
-    # accuracy
-    # TODO: unstable test...
-    # assert_equal(grid_search.best_estimator.steps[0][1].analyzer.max_n, 2)
+    # on this toy dataset the bigram representation, which is the last point
+    # explored by the grid search, is selected as the best estimator since all
+    # candidates converge to 100% accuracy models
+    assert_equal(grid_search.best_score, 1.0)
+    best_vectorizer = grid_search.best_estimator.named_steps['vect']
+    assert_equal(best_vectorizer.analyzer.max_n, 2)

 def test_pickle():
@@ -62,6 +62,7 @@ def iter_grid(param_grid):
 def fit_grid_point(X, y, base_clf, clf_params, cv, loss_func, iid,
                    **fit_params):
     """Run fit on one set of parameters
+    Returns the score and the instance of the classifier
     """
     # update parameters of the classifier after a copy of its base structure
@@ -209,7 +210,7 @@ class GridSearchCV(BaseEstimator):
             for clf_params in grid)
         # Out is a list of pairs: score, estimator
-        best_estimator = max(out)[1]  # get maximum score
+        self.best_score, best_estimator = max(out)  # get maximum score
         if refit:
             # fit the best estimator using the entire dataset
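Since `out` collects (score, estimator) pairs, a single `max(out)` selects the winner by score, and tuple unpacking keeps the score around instead of discarding it with `[1]`. A toy stand-alone illustration of the pattern follows; the scores and labels are made up.

    # Toy stand-in for the (score, estimator) pairs collected by fit_grid_point.
    out = [(0.8, 'unigram pipeline'), (1.0, 'bigram pipeline')]
    best_score, best_estimator = max(out)   # tuples compare on the score first
    assert best_score == 1.0 and best_estimator == 'bigram pipeline'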
@@ -92,10 +92,10 @@ class Pipeline(BaseEstimator):
             fit/transform) that are chained, in the order in which
             they are chained, with the last object an estimator.
         """
-        self._named_steps = dict(steps)
+        self.named_steps = dict(steps)
         names, estimators = zip(*steps)
         self.steps = steps
-        assert len(self._named_steps) == len(steps), ("Names provided are "
+        assert len(self.named_steps) == len(steps), ("Names provided are "
             "not unique: %s" % names)
         transforms = estimators[:-1]
         estimator = estimators[-1]
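Because `named_steps` is built with `dict(steps)`, duplicate step names silently collapse into a single entry; the length check above is what turns that into an explicit error. A small stand-alone illustration of the check, with plain strings standing in for the step objects:

    # Step objects replaced by strings for brevity; only the dict/length logic is shown.
    steps = [('vect', 'CountVectorizer()'),
             ('vect', 'TfidfTransformer()'),   # duplicate name
             ('clf', 'LinearSVC()')]
    named_steps = dict(steps)        # the duplicate 'vect' key collapses to one entry
    assert len(named_steps) == 2     # != len(steps) == 3, so the Pipeline assert would fire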
@@ -115,8 +115,8 @@ class Pipeline(BaseEstimator):
         if not deep:
             return super(Pipeline, self)._get_params(deep=False)
         else:
-            out = self._named_steps.copy()
-            for name, step in self._named_steps.iteritems():
+            out = self.named_steps.copy()
+            for name, step in self.named_steps.iteritems():
                 for key, value in step._get_params(deep=True).iteritems():
                     out['%s__%s' % (name, key)] = value
             return out
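The `deep=True` branch flattens each step's own parameters into the pipeline's namespace using the '<step name>__<param name>' convention, which is what lets a grid search address nested parameters such as the vectorizer's `analyzer.max_n` checked in the test above. A stand-alone sketch of that flattening loop, with made-up parameter values and plain dicts standing in for the estimators:

    # Plain dicts stand in for each step's _get_params(deep=True) result; values are made up.
    step_params = {
        'vect': {'analyzer__max_n': 2},
        'clf': {'C': 1.0},
    }
    out = {}
    for name, params in step_params.items():
        for key, value in params.items():
            out['%s__%s' % (name, key)] = value
    # out == {'vect__analyzer__max_n': 2, 'clf__C': 1.0}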
@@ -56,7 +56,7 @@ def test_pipeline_init():
     # Test clone
     pipe2 = clone(pipe)
-    assert_false(pipe._named_steps['svc'] is pipe2._named_steps['svc'])
+    assert_false(pipe.named_steps['svc'] is pipe2.named_steps['svc'])
     # Check that apart from the estimators, the parameters are the same
     params = pipe._get_params()
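The identity check holds because `clone` rebuilds each estimator from its parameters rather than copying references, so the cloned pipeline holds a distinct, unfitted 'svc' object even though its parameters compare equal. A hedged sketch of what this test exercises; the import paths and the single-step pipeline are assumptions about the scikits.learn layout of this era, not something the hunk shows.

    # Sketch, assuming era import paths; only named_steps and clone behaviour come from the diff.
    from scikits.learn.base import clone
    from scikits.learn.pipeline import Pipeline
    from scikits.learn.svm import SVC

    pipe = Pipeline([('svc', SVC())])
    pipe2 = clone(pipe)
    # clone() constructs a fresh SVC from the original's parameters,
    # so the two pipelines do not share the step object.
    assert pipe.named_steps['svc'] is not pipe2.named_steps['svc']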