From 61b3753b658ab726c214d40f3faa1f9de4b0ff27 Mon Sep 17 00:00:00 2001
From: Gael varoquaux <gael.varoquaux@normalesup.org>
Date: Tue, 23 Nov 2010 09:21:30 +0100
Subject: [PATCH] ENH: Raise error when cloning bug estimators

Estimators with non conformant __init__ don't clone right. This prevents
cross_val from working, and leads to highly non-trivial bugs. Detect it
in clone and raise an error in this case.
---
 scikits/learn/base.py            |  4 ++++
 scikits/learn/tests/test_base.py | 10 ++++++++++
 2 files changed, 14 insertions(+)

diff --git a/scikits/learn/base.py b/scikits/learn/base.py
index fccd43759b..a78679f52d 100644
--- a/scikits/learn/base.py
+++ b/scikits/learn/base.py
@@ -46,6 +46,10 @@ def clone(estimator, safe=True):
     for name, param in new_object_params.iteritems():
         new_object_params[name] = clone(param, safe=False)
     new_object = klass(**new_object_params)
+    assert new_object._get_params(deep=False) == new_object_params, (
+            'Cannot clone object %s, as the constructor does not '
+            'seem to set parameters' % estimator
+        )
 
     return new_object
 
diff --git a/scikits/learn/tests/test_base.py b/scikits/learn/tests/test_base.py
index 074c70bb2c..53a422f5cb 100644
--- a/scikits/learn/tests/test_base.py
+++ b/scikits/learn/tests/test_base.py
@@ -23,6 +23,11 @@ class T(BaseEstimator):
         self.a = a
         self.b = b
 
+class Buggy(BaseEstimator):
+    " A buggy estimator that does not set its parameters right. "
+
+    def __init__(self, a=None):
+        self.a = 1
 
 ################################################################################
 # The tests
@@ -58,6 +63,11 @@ def test_clone_2():
     new_selector = clone(selector)
     assert_false(hasattr(new_selector, "own_attribute"))
 
+def test_clone_buggy():
+    """ Check that clone raises an error on buggy estimators """
+    buggy = Buggy()
+    buggy.a = 2
+    assert_raises(AssertionError, clone, buggy)
 
 def test_repr():
     """ Smoke test the repr of the
-- 
GitLab