diff --git a/scikits/learn/linreg/tests/test_ridge.py b/scikits/learn/linreg/tests/test_ridge.py index e6b4cdc64a0ed0dbfaf52257ad96a6f65322d824..46309b6c9bf150f4ac5be2a061fcb609b334e1b5 100644 --- a/scikits/learn/linreg/tests/test_ridge.py +++ b/scikits/learn/linreg/tests/test_ridge.py @@ -84,6 +84,19 @@ def test_toy_ard_regression(): def test_toy_ridge_object(): + """ + Test BayesianRegression ridge classifier + TODO: test also nsamples > nfeatures + """ + X = np.array([[1], [2]]) + Y = np.array([1, 2]) + clf = Ridge(alpha=0.0) + clf.fit(X, Y) + Test = [[1], [2], [3], [4]] + assert_array_equal(clf.predict(Test), [1, 2, 3, 4]) # identity + + +def test_toy_bayesian_ridge_object(): """ Test BayesianRegression ridge classifier """ @@ -92,7 +105,7 @@ def test_toy_ridge_object(): clf = BayesianRidge() clf.fit(X, Y) Test = [[1], [2], [3], [4]] - assert(np.abs(clf.predict(Test)-[1, 2, 3, 4]).sum()<1.) # identity + assert_array_equal(clf.predict(Test), [1, 2, 3, 4]) # identity def test_toy_ard_object():