From 754f73a5962bc9bfbc94b5450347e19836435e6f Mon Sep 17 00:00:00 2001 From: Joel Nothman <joel.nothman@gmail.com> Date: Wed, 18 Oct 2017 10:15:53 +1100 Subject: [PATCH] Add DeprecationDict for #9677 --- sklearn/utils/deprecation.py | 32 ++++++++++++++++++++++++- sklearn/utils/tests/test_deprecation.py | 16 +++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/deprecation.py b/sklearn/utils/deprecation.py index 08530be264..5621f436d9 100644 --- a/sklearn/utils/deprecation.py +++ b/sklearn/utils/deprecation.py @@ -2,7 +2,7 @@ import sys import warnings import functools -__all__ = ["deprecated", ] +__all__ = ["deprecated", "DeprecationDict"] class deprecated(object): @@ -102,3 +102,33 @@ def _is_deprecated(func): for c in closures if isinstance(c.cell_contents, str)])) return is_deprecated + + +class DeprecationDict(dict): + """A dict which raises a warning when some keys are looked up + + Note, this does not raise a warning for __contains__ and iteration. + + It also will raise a warning even after the key has been manually set by + the user. + """ + def __init__(self, *args, **kwargs): + self._deprecations = {} + super(DeprecationDict, self).__init__(*args, **kwargs) + + def __getitem__(self, key): + if key in self._deprecations: + warn_args, warn_kwargs = self._deprecations[key] + warnings.warn(*warn_args, **warn_kwargs) + return super(DeprecationDict, self).__getitem__(key) + + def get(self, key, default=None): + # dict does not implement it like this, hence it needs to be overridden + try: + return self[key] + except KeyError: + return default + + def add_warning(self, key, *args, **kwargs): + """Add a warning to be triggered when the specified key is read""" + self._deprecations[key] = (args, kwargs) diff --git a/sklearn/utils/tests/test_deprecation.py b/sklearn/utils/tests/test_deprecation.py index e5a1f021cd..d7b3f48c18 100644 --- a/sklearn/utils/tests/test_deprecation.py +++ b/sklearn/utils/tests/test_deprecation.py @@ -8,7 +8,9 @@ import pickle from sklearn.utils.deprecation import _is_deprecated from sklearn.utils.deprecation import deprecated from sklearn.utils.testing import assert_warns_message +from sklearn.utils.testing import assert_no_warnings from sklearn.utils.testing import SkipTest +from sklearn.utils.deprecation import DeprecationDict @deprecated('qwerty') @@ -60,3 +62,17 @@ def test_is_deprecated(): def test_pickle(): pickle.loads(pickle.dumps(mock_function)) + + +def test_deprecationdict(): + dd = DeprecationDict() + dd.add_warning('a', 'hello') + dd.add_warning('b', 'world', DeprecationWarning) + assert 1 == assert_warns_message(UserWarning, 'hello', dd.get, 'a', 1) + dd['a'] = 5 + dd['b'] = 6 + dd['c'] = 7 + assert 5 == assert_warns_message(UserWarning, 'hello', dd.__getitem__, 'a') + assert 6 == assert_warns_message(DeprecationWarning, 'world', + dd.__getitem__, 'b') + assert 7 == assert_no_warnings(dd.get, 'c') -- GitLab