From 9b4be76534ceb66e6fbbb3d7790eec982e84fec8 Mon Sep 17 00:00:00 2001 From: Olivier Grisel <olivier.grisel@ensta.org> Date: Wed, 2 May 2012 04:24:54 +0200 Subject: [PATCH] FIX #807: non regression test for KPCA on make_circles dataset --- .../decomposition/tests/test_kernel_pca.py | 29 ++++++++++++++++++- sklearn/utils/testing.py | 5 ++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/sklearn/decomposition/tests/test_kernel_pca.py b/sklearn/decomposition/tests/test_kernel_pca.py index e0ef36a29b..e9a680b7c1 100644 --- a/sklearn/decomposition/tests/test_kernel_pca.py +++ b/sklearn/decomposition/tests/test_kernel_pca.py @@ -5,7 +5,11 @@ from numpy.testing import assert_array_almost_equal from nose.tools import assert_equal from nose.tools import assert_raises -from .. import PCA, KernelPCA +from sklearn.decomposition import PCA, KernelPCA +from sklearn.datasets import make_circles +from sklearn.linear_model import Perceptron +from sklearn.cross_validation import cross_val_score +from sklearn.utils.testing import assert_lower def test_kernel_pca(): @@ -114,6 +118,29 @@ def test_kernel_pca_invalid_kernel(): assert_raises(ValueError, kpca.fit, X_fit) +def test_nested_circles(): + """Test the linear separability of the first 2D KPCA transform""" + X, y = make_circles(n_samples=400, factor=.3, noise=.05, + random_state=0) + + # 2D nested circles are not linearly separable + train_score = Perceptron().fit(X, y).score(X, y) + assert_lower(train_score, 0.8) + + # Project the circles data into the first 2 components of a RBF Kernel + # PCA model. + # Not that the gamma value is data dependent. If this test breaks + # and the gamma value has to be updated, the Kernel PCA example will + # have to be updated too. + kpca = KernelPCA(kernel="rbf", n_components=2, + fit_inverse_transform=True, gamma=10.) + X_kpca = kpca.fit_transform(X) + + # The data is perfectly linearly separable in that space + train_score = Perceptron().fit(X_kpca, y).score(X_kpca, y) + assert_equal(train_score, 1.0) + + if __name__ == '__main__': import nose nose.run(argv=['', __file__]) diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index 2e9ba61658..96d1a13811 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -22,6 +22,11 @@ except ImportError: assert_false(x in container, msg="%r in %r" % (x, container)) +def assert_lower(a, b): + message = "%r is not lower than %r" % (a, b) + assert a < b, message + + def fake_mldata_cache(columns_dict, dataname, matfile, ordering=None): """Create a fake mldata data set in the cache_path. -- GitLab