diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index 86a67cc16b69217b654a8b19b635e10b7bace43b..10721daa823caf5dc814a7a7115c12805cda1523 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -54,7 +54,9 @@ from sklearn.base import (ClassifierMixin, RegressorMixin, TransformerMixin, __all__ = ["assert_equal", "assert_not_equal", "assert_raises", "assert_raises_regexp", "raises", "with_setup", "assert_true", "assert_false", "assert_almost_equal", "assert_array_equal", - "assert_array_almost_equal", "assert_array_less"] + "assert_array_almost_equal", "assert_array_less", + "assert_less", "assert_less_equal", + "assert_greater", "assert_greater_equal"] try: @@ -105,6 +107,20 @@ def _assert_greater(a, b, msg=None): assert a > b, message +def assert_less_equal(a, b, msg=None): + message = "%r is not lower than or equal to %r" % (a, b) + if msg is not None: + message += ": " + msg + assert a <= b, message + + +def assert_greater_equal(a, b, msg=None): + message = "%r is not greater than or equal to %r" % (a, b) + if msg is not None: + message += ": " + msg + assert a >= b, message + + # To remove when we support numpy 1.7 def assert_warns(warning_class, func, *args, **kw): """Test that a certain warning occurs.