diff --git a/scikits/learn/base.py b/scikits/learn/base.py index fccd43759b56b77162c901f7f9d9af42cbc36178..a78679f52d279db9d37635e2269c840c1c47775f 100644 --- a/scikits/learn/base.py +++ b/scikits/learn/base.py @@ -46,6 +46,10 @@ def clone(estimator, safe=True): for name, param in new_object_params.iteritems(): new_object_params[name] = clone(param, safe=False) new_object = klass(**new_object_params) + assert new_object._get_params(deep=False) == new_object_params, ( + 'Cannot clone object %s, as the constructor does not ' + 'seem to set parameters' % estimator + ) return new_object diff --git a/scikits/learn/tests/test_base.py b/scikits/learn/tests/test_base.py index 074c70bb2c0202ff309afe4b2f2f87d8e3b20329..53a422f5cbc0cb4a62160b2472495c42ca917017 100644 --- a/scikits/learn/tests/test_base.py +++ b/scikits/learn/tests/test_base.py @@ -23,6 +23,11 @@ class T(BaseEstimator): self.a = a self.b = b +class Buggy(BaseEstimator): + " A buggy estimator that does not set its parameters right. " + + def __init__(self, a=None): + self.a = 1 ################################################################################ # The tests @@ -58,6 +63,11 @@ def test_clone_2(): new_selector = clone(selector) assert_false(hasattr(new_selector, "own_attribute")) +def test_clone_buggy(): + """ Check that clone raises an error on buggy estimators """ + buggy = Buggy() + buggy.a = 2 + assert_raises(AssertionError, clone, buggy) def test_repr(): """ Smoke test the repr of the