From 754f73a5962bc9bfbc94b5450347e19836435e6f Mon Sep 17 00:00:00 2001
From: Joel Nothman <joel.nothman@gmail.com>
Date: Wed, 18 Oct 2017 10:15:53 +1100
Subject: [PATCH] Add DeprecationDict for #9677

---
 sklearn/utils/deprecation.py            | 32 ++++++++++++++++++++++++-
 sklearn/utils/tests/test_deprecation.py | 16 +++++++++++++
 2 files changed, 47 insertions(+), 1 deletion(-)

diff --git a/sklearn/utils/deprecation.py b/sklearn/utils/deprecation.py
index 08530be264..5621f436d9 100644
--- a/sklearn/utils/deprecation.py
+++ b/sklearn/utils/deprecation.py
@@ -2,7 +2,7 @@ import sys
 import warnings
 import functools
 
-__all__ = ["deprecated", ]
+__all__ = ["deprecated", "DeprecationDict"]
 
 
 class deprecated(object):
@@ -102,3 +102,33 @@ def _is_deprecated(func):
                                               for c in closures
                      if isinstance(c.cell_contents, str)]))
     return is_deprecated
+
+
+class DeprecationDict(dict):
+    """A dict which raises a warning when some keys are looked up
+
+    Note, this does not raise a warning for __contains__ and iteration.
+
+    It also will raise a warning even after the key has been manually set by
+    the user.
+    """
+    def __init__(self, *args, **kwargs):
+        self._deprecations = {}
+        super(DeprecationDict, self).__init__(*args, **kwargs)
+
+    def __getitem__(self, key):
+        if key in self._deprecations:
+            warn_args, warn_kwargs = self._deprecations[key]
+            warnings.warn(*warn_args, **warn_kwargs)
+        return super(DeprecationDict, self).__getitem__(key)
+
+    def get(self, key, default=None):
+        # dict does not implement it like this, hence it needs to be overridden
+        try:
+            return self[key]
+        except KeyError:
+            return default
+
+    def add_warning(self, key, *args, **kwargs):
+        """Add a warning to be triggered when the specified key is read"""
+        self._deprecations[key] = (args, kwargs)
diff --git a/sklearn/utils/tests/test_deprecation.py b/sklearn/utils/tests/test_deprecation.py
index e5a1f021cd..d7b3f48c18 100644
--- a/sklearn/utils/tests/test_deprecation.py
+++ b/sklearn/utils/tests/test_deprecation.py
@@ -8,7 +8,9 @@ import pickle
 from sklearn.utils.deprecation import _is_deprecated
 from sklearn.utils.deprecation import deprecated
 from sklearn.utils.testing import assert_warns_message
+from sklearn.utils.testing import assert_no_warnings
 from sklearn.utils.testing import SkipTest
+from sklearn.utils.deprecation import DeprecationDict
 
 
 @deprecated('qwerty')
@@ -60,3 +62,17 @@ def test_is_deprecated():
 
 def test_pickle():
     pickle.loads(pickle.dumps(mock_function))
+
+
+def test_deprecationdict():
+    dd = DeprecationDict()
+    dd.add_warning('a', 'hello')
+    dd.add_warning('b', 'world', DeprecationWarning)
+    assert 1 == assert_warns_message(UserWarning, 'hello', dd.get, 'a', 1)
+    dd['a'] = 5
+    dd['b'] = 6
+    dd['c'] = 7
+    assert 5 == assert_warns_message(UserWarning, 'hello', dd.__getitem__, 'a')
+    assert 6 == assert_warns_message(DeprecationWarning, 'world',
+                                     dd.__getitem__, 'b')
+    assert 7 == assert_no_warnings(dd.get, 'c')
-- 
GitLab