from nose.tools import (assert_true, assert_false, assert_equal,
                        assert_raises)

from ..base import BaseEstimator, clone


################################################################################
# A few test classes

class MyEstimator(BaseEstimator):
    """Minimal estimator with a single constructor parameter."""

    def __init__(self, l1=0):
        self.l1 = l1


class K(BaseEstimator):
    """Estimator with two parameters, used as a nested parameter of T."""

    def __init__(self, c=None, d=None):
        self.c = c
        self.d = d


class T(BaseEstimator):
    """Estimator whose parameters are themselves estimators in the tests."""

    def __init__(self, a=None, b=None):
        self.a = a
        self.b = b


################################################################################
# The tests

def test_clone():
    """Tests that clone creates a correct deep copy.

    We create an estimator, make a copy of its original state
    (which, in this case, is the current state of the estimator),
    and check that the obtained copy is a correct deep copy.

    """
    # Imported locally, as in the original test, rather than at module level.
    from scikits.learn.feature_selection import SelectFpr, f_classif

    selector = SelectFpr(f_classif, alpha=0.1)
    new_selector = clone(selector)
    # The clone must be a distinct object carrying the same parameters.
    assert_true(selector is not new_selector)
    assert_equal(selector._get_params(), new_selector._get_params())


def test_clone_2():
    """Tests that clone doesn't copy everything.

    We first create an estimator, give it an own attribute, and
    make a copy of its original state. Then we check that the copy doesn't
    have the specific attribute we manually added to the initial estimator.

    """
    from scikits.learn.feature_selection import SelectFpr, f_classif

    selector = SelectFpr(f_classif, alpha=0.1)
    # Attribute set after construction; it is not a constructor parameter.
    selector.own_attribute = "test"
    new_selector = clone(selector)
    assert_false(hasattr(new_selector, "own_attribute"))


def test_repr():
    """Smoke test the repr of the base estimator.

    Also checks the exact repr of an estimator whose parameters are
    themselves estimators.

    """
    my_estimator = MyEstimator()
    repr(my_estimator)
    test = T(K(), K())
    assert_equal(repr(test),
                 "T(a=K(c=None, d=None), b=K(c=None, d=None))")


def test_str():
    """Smoke test the str of the base estimator."""
    my_estimator = MyEstimator()
    str(my_estimator)


def test_get_params():
    """Tests getting and setting parameters of nested estimators.

    Nested parameters use the ``<name>__<subname>`` key form: they must be
    listed only when ``deep=True``, must be settable through
    ``_set_params``, and setting an unknown nested parameter must raise
    an AssertionError.

    """
    test = T(K(), K())

    assert_true('a__d' in test._get_params(deep=True))
    assert_true('a__d' not in test._get_params(deep=False))

    test._set_params(a__d=2)
    # assert_equal rather than a bare ``assert``: it is not stripped under
    # ``python -O`` and reports both values on failure.
    assert_equal(test.a.d, 2)
    # K has no parameter named 'a', so this nested assignment is rejected.
    assert_raises(AssertionError, test._set_params, a__a=2)