diff --git a/sklearn/tests/test_cross_validation.py b/sklearn/tests/test_cross_validation.py index bde785969fba761b9cd28b0111dffff0a0c00132..6a522476f4d364c7ce320c3f0521908fce2fdc6d 100644 --- a/sklearn/tests/test_cross_validation.py +++ b/sklearn/tests/test_cross_validation.py @@ -652,22 +652,37 @@ def test_train_test_split_errors(): def test_train_test_split(): X = np.arange(100).reshape((10, 10)) X_s = coo_matrix(X) - y = range(10) - split = cval.train_test_split(X, X_s, y, allow_lists=False) - X_train, X_test, X_s_train, X_s_test, y_train, y_test = split - assert_array_equal(X_train, X_s_train.toarray()) - assert_array_equal(X_test, X_s_test.toarray()) - assert_array_equal(X_train[:, 0], y_train * 10) - assert_array_equal(X_test[:, 0], y_test * 10) + y = np.arange(10) + + # simple test split = cval.train_test_split(X, y, test_size=None, train_size=.5) X_train, X_test, y_train, y_test = split assert_equal(len(y_test), len(y_train)) + # test correspondence of X and y + assert_array_equal(X_train[:, 0], y_train * 10) + assert_array_equal(X_test[:, 0], y_test * 10) + + # conversion of lists to arrays (deprecated?) 
+ split = cval.train_test_split(X, X_s, y.tolist(), force_arrays=True) + X_train, X_test, X_s_train, X_s_test, y_train, y_test = split + assert_array_equal(X_train, X_s_train.toarray()) + assert_array_equal(X_test, X_s_test.toarray()) - split = cval.train_test_split(X, X_s, y) + # don't convert lists to anything else by default + split = cval.train_test_split(X, X_s, y.tolist()) X_train, X_test, X_s_train, X_s_test, y_train, y_test = split assert_true(isinstance(y_train, list)) assert_true(isinstance(y_test, list)) + # allow nd-arrays + X_4d = np.arange(10 * 5 * 3 * 2).reshape(10, 5, 3, 2) + y_3d = np.arange(10 * 7 * 11).reshape(10, 7, 11) + split = cval.train_test_split(X_4d, y_3d) + assert_equal(split[0].shape, (7, 5, 3, 2)) + assert_equal(split[1].shape, (3, 5, 3, 2)) + assert_equal(split[2].shape, (7, 7, 11)) + assert_equal(split[3].shape, (3, 7, 11)) + def train_test_split_pandas(): # check cross_val_score doesn't destroy pandas dataframe diff --git a/sklearn/tests/test_grid_search.py b/sklearn/tests/test_grid_search.py index 1caf1854672e466ef227255cf0566b62e2b76b80..efc022b16bdd99a832a51608d1f085053f9f635e 100644 --- a/sklearn/tests/test_grid_search.py +++ b/sklearn/tests/test_grid_search.py @@ -435,6 +435,18 @@ def test_refit(): clf.fit(X, y) +def test_gridsearch_nd(): + """Check that GridSearchCV handles nd-arrays for X and y""" + X_4d = np.arange(10 * 5 * 3 * 2).reshape(10, 5, 3, 2) + y_3d = np.arange(10 * 7 * 11).reshape(10, 7, 11) + check_X = lambda x: x.shape[1:] == (5, 3, 2) + check_y = lambda x: x.shape[1:] == (7, 11) + clf = CheckingClassifier(check_X=check_X, check_y=check_y) + grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}) + grid_search.fit(X_4d, y_3d).score(X_4d, y_3d) + assert_true(hasattr(grid_search, "grid_scores_")) + + def test_X_as_list(): """Pass X as list in GridSearchCV""" X = np.arange(100).reshape(10, 10)