From 61b3753b658ab726c214d40f3faa1f9de4b0ff27 Mon Sep 17 00:00:00 2001 From: Gael varoquaux <gael.varoquaux@normalesup.org> Date: Tue, 23 Nov 2010 09:21:30 +0100 Subject: [PATCH] ENH: Raise error when cloning bug estimators Estimators with non conformant __init__ don't clone right. This prevents cross_val from working, and leads to highly non-trivial bugs. Detect it in clone and raise an error in this case. --- scikits/learn/base.py | 4 ++++ scikits/learn/tests/test_base.py | 10 ++++++++++ 2 files changed, 14 insertions(+) diff --git a/scikits/learn/base.py b/scikits/learn/base.py index fccd43759b..a78679f52d 100644 --- a/scikits/learn/base.py +++ b/scikits/learn/base.py @@ -46,6 +46,10 @@ def clone(estimator, safe=True): for name, param in new_object_params.iteritems(): new_object_params[name] = clone(param, safe=False) new_object = klass(**new_object_params) + assert new_object._get_params(deep=False) == new_object_params, ( + 'Cannot clone object %s, as the constructor does not ' + 'seem to set parameters' % estimator + ) return new_object diff --git a/scikits/learn/tests/test_base.py b/scikits/learn/tests/test_base.py index 074c70bb2c..53a422f5cb 100644 --- a/scikits/learn/tests/test_base.py +++ b/scikits/learn/tests/test_base.py @@ -23,6 +23,11 @@ class T(BaseEstimator): self.a = a self.b = b +class Buggy(BaseEstimator): + " A buggy estimator that does not set its parameters right. " + + def __init__(self, a=None): + self.a = 1 ################################################################################ # The tests @@ -58,6 +63,11 @@ def test_clone_2(): new_selector = clone(selector) assert_false(hasattr(new_selector, "own_attribute")) +def test_clone_buggy(): + """ Check that clone raises an error on buggy estimators """ + buggy = Buggy() + buggy.a = 2 + assert_raises(AssertionError, clone, buggy) def test_repr(): """ Smoke test the repr of the -- GitLab