diff --git a/scikits/learn/grid_search.py b/scikits/learn/grid_search.py
index 90774361bd9a7f51d044c363fe0cfb090d45931e..6c434561dfab67e8ab00c8ce72686c9cef50184f 100644
--- a/scikits/learn/grid_search.py
+++ b/scikits/learn/grid_search.py
@@ -12,6 +12,9 @@
 from .externals.joblib import Parallel, delayed
 from .cross_val import KFold, StratifiedKFold
 from .base import BaseEstimator, is_classifier, clone
+import numpy as np
+import scipy.sparse as sp
+
 try:
     from itertools import product
 except:
@@ -70,6 +73,12 @@ def fit_grid_point(X, y, base_clf, clf_params, cv, loss_func, iid,
     score = 0.
     n_test_samples = 0.
     for train, test in cv:
+        if sp.issparse(X):
+            # slicing only works with indices in sparse matrices
+            ind = np.arange(X.shape[0])
+            train = ind[train]
+            test = ind[test]
+
         clf.fit(X[train], y[train], **fit_params)
         y_test = y[test]
         if loss_func is not None:
@@ -78,7 +87,7 @@ def fit_grid_point(X, y, base_clf, clf_params, cv, loss_func, iid,
         else:
             this_score = clf.score(X[test], y_test)
         if iid:
-            this_n_test_samples = y.shape[0]
+            this_n_test_samples = y.shape[0]
             this_score *= this_n_test_samples
             n_test_samples += this_n_test_samples
         score += this_score
@@ -150,7 +159,7 @@ class GridSearchCV(BaseEstimator):
                  fit_params={}, n_jobs=1, iid=True):
         assert hasattr(estimator, 'fit') and hasattr(estimator, 'predict'), (
             "estimator should a be an estimator implementing 'fit' and "
-            "'predict' methods, %s (type %s) was passed" %
+            "'predict' methods, %s (type %s) was passed" %
             (estimator, type(estimator))
             )
         if loss_func is None:
diff --git a/scikits/learn/tests/test_grid_search.py b/scikits/learn/tests/test_grid_search.py
index e3d29099e14959bdc5306e0b4c18c6d34a90a7a0..bb36eb67aacec8332d0a047c2deaee542887840f 100644
--- a/scikits/learn/tests/test_grid_search.py
+++ b/scikits/learn/tests/test_grid_search.py
@@ -3,14 +3,20 @@ Testing for grid search module (scikits.learn.grid_search)

 """

 from nose.tools import assert_equal
+from numpy.testing import assert_array_equal
 import numpy as np
+import scipy.sparse as sp
+
 from scikits.learn.base import BaseEstimator
 from scikits.learn.grid_search import GridSearchCV
+from scikits.learn.datasets.samples_generator import test_dataset_classif
+from scikits.learn.svm import LinearSVC
+from scikits.learn.svm.sparse import LinearSVC as SparseLinearSVC


 class MockClassifier(BaseEstimator):
     """Dummy classifier to test the cross-validation
-    
+
     """
     def __init__(self, foo_param=0):
@@ -33,7 +39,30 @@
 X = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1]])
 y = np.array([1, 1, 2, 2])

-def test_GridSearch():
+def test_grid_search():
+    """Test that the best estimator contains the right value for foo_param"""
     clf = MockClassifier()
     cross_validation = GridSearchCV(clf, {'foo_param':[1, 2, 3]})
     assert_equal(cross_validation.fit(X, y).best_estimator.foo_param, 2)
+
+
+def test_grid_search_sparse():
+    """Test that grid search works with both dense and sparse matrices"""
+    X_, y_ = test_dataset_classif(n_samples=200, n_features=100, seed=0)
+
+    clf = LinearSVC()
+    cv = GridSearchCV(clf, {'C':[0.1, 1.0]})
+    cv.fit(X_[:180], y_[:180])
+    y_pred = cv.predict(X_[180:])
+    C = cv.best_estimator.C
+
+    X_ = sp.csr_matrix(X_)
+    clf = SparseLinearSVC()
+    cv = GridSearchCV(clf, {'C':[0.1, 1.0]})
+    cv.fit(X_[:180], y_[:180])
+    y_pred2 = cv.predict(X_[180:])
+    C2 = cv.best_estimator.C
+
+    assert_array_equal(y_pred, y_pred2)
+    assert_equal(C, C2)
+
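
Note (not part of the patch): the `np.arange` conversion added to `fit_grid_point` exists because the cross-validation iterators yield boolean masks, and sparse matrices cannot be sliced with a boolean mask the way dense arrays can; they do accept an array of integer row indices. Below is a minimal standalone sketch of that workaround, using a toy CSR matrix and a hand-built mask rather than an actual KFold split, so the variable names are illustrative only.

import numpy as np
import scipy.sparse as sp

# Toy stand-in for the X passed to fit_grid_point (hypothetical example).
X = sp.csr_matrix(np.arange(20, dtype=np.float64).reshape(10, 2))

# Boolean mask of the kind a KFold-style iterator would yield.
train = np.zeros(10, dtype=bool)
train[:8] = True

# The patch's trick: turn the mask into explicit row indices,
# which CSR matrices support for row slicing.
ind = np.arange(X.shape[0])
train_idx = ind[train]

X_train = X[train_idx]
print(X_train.shape)  # -> (8, 2)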