diff --git a/scikits/learn/grid_search.py b/scikits/learn/grid_search.py index bab564c04fd41e23e391ad531ed7bd3b512ee5c6..727e2cf449b9dfcb7b207d9feb0ba8ec30347949 100644 --- a/scikits/learn/grid_search.py +++ b/scikits/learn/grid_search.py @@ -95,7 +95,7 @@ def fit_grid_point(X, y, base_clf, clf_params, cv, loss_func, score_func, iid, y_pred = clf.predict(X_test) this_score = -loss_func(y_test, y_pred) elif score_func is not None: - y_pred = clf.predict(X_text) + y_pred = clf.predict(X_test) this_score = score_func(y_test, y_pred) else: this_score = clf.score(X_test, y_test) diff --git a/scikits/learn/tests/test_grid_search.py b/scikits/learn/tests/test_grid_search.py index bb36eb67aacec8332d0a047c2deaee542887840f..cd40f52da1440b723ae252d98dee01417835b7b0 100644 --- a/scikits/learn/tests/test_grid_search.py +++ b/scikits/learn/tests/test_grid_search.py @@ -13,6 +13,7 @@ from scikits.learn.grid_search import GridSearchCV from scikits.learn.datasets.samples_generator import test_dataset_classif from scikits.learn.svm import LinearSVC from scikits.learn.svm.sparse import LinearSVC as SparseLinearSVC +from scikits.learn.metrics import f1_score class MockClassifier(BaseEstimator): """Dummy classifier to test the cross-validation @@ -66,3 +67,24 @@ def test_grid_search_sparse(): assert_array_equal(y_pred, y_pred2) assert_equal(C, C2) + +def test_grid_search_sparse_score_func(): + X_, y_ = test_dataset_classif(n_samples=200, n_features=100, seed=0) + + clf = LinearSVC() + cv = GridSearchCV(clf, {'C': [0.1, 1.0]}, score_func=f1_score) + cv.fit(X_[:180], y_[:180]) + y_pred = cv.predict(X_[180:]) + C = cv.best_estimator.C + + X_ = sp.csr_matrix(X_) + clf = SparseLinearSVC() + cv = GridSearchCV(clf, {'C': [0.1, 1.0]}, score_func=f1_score) + cv.fit(X_[:180], y_[:180]) + y_pred2 = cv.predict(X_[180:]) + C2 = cv.best_estimator.C + + assert_array_equal(y_pred, y_pred2) + assert_equal(C, C2) + +