diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index f51600547ec2f4d90c899c02aaea77e68ce7a074..f0a7f539b0494a39d3ef54b4e626f63f1bf8ae08 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -22,23 +22,28 @@ except ImportError: assert_false(x in container, msg="%r in %r" % (x, container)) +def _assert_less(a, b, msg=None): + message = "%r is not lower than %r" % (a, b) + if msg is not None: + message += ": " + msg + assert a < b, message + +def _assert_greater(a, b, msg=None): + message = "%r is not lower than %r" % (a, b) + if msg is not None: + message += ": " + msg + assert a > b, message + + try: from nose.tools import assert_less except ImportError: - def assert_less(a, b, msg=None): - message = "%r is not lower than %r" % (a, b) - if msg is not None: - message += ": " + msg - assert a < b, message + assert_less = _assert_less try: from nose.tools import assert_greater except ImportError: - def assert_greater(a, b, msg=None): - message = "%r is not lower than %r" % (a, b) - if msg is not None: - message += ": " + msg - assert a < b, message + assert_greater = _assert_greater def fake_mldata_cache(columns_dict, dataname, matfile, ordering=None): diff --git a/sklearn/utils/tests/test_testing.py b/sklearn/utils/tests/test_testing.py new file mode 100644 index 0000000000000000000000000000000000000000..89c662cecbd0d686b34657f5b2be9627db17f81e --- /dev/null +++ b/sklearn/utils/tests/test_testing.py @@ -0,0 +1,32 @@ +from nose.tools import assert_raises + +from sklearn.utils.testing import _assert_less, _assert_greater + +try: + from nose.tools import assert_less + + def test_assert_less(): + # Check that the nose implementation of assert_less gives the + # same thing as the scikit's + assert_less(0, 1) + _assert_less(0, 1) + assert_raises(AssertionError, assert_less, 1, 0) + assert_raises(AssertionError, _assert_less, 1, 0) + +except ImportError: + pass + +try: + from nose.tools import assert_greater + + def test_assert_greater(): + # Check that the nose implementation of assert_less gives the + # same thing as the scikit's + assert_greater(1, 0) + _assert_greater(1, 0) + assert_raises(AssertionError, assert_greater, 0, 1) + assert_raises(AssertionError, _assert_greater, 0, 1) + +except ImportError: + pass +