From 868a58b2e0ff23427fbfb625e4fc7f997c734480 Mon Sep 17 00:00:00 2001
From: Raghav RV <rvraghav93@gmail.com>
Date: Mon, 10 Oct 2016 21:33:44 +0200
Subject: [PATCH] [MRG+1] FIX Make sure GridSearchCV and RandomizedSearchCV are
 pickle-able (#7594)

* FIX Subclass a new MaskedArray which allows pickling even when dtype=object

* TST unpickling too

* FIX Use MaskedArray from utils.fixes rather than from numpy

* FIX imports

* Don't assign a variable

* FIX np --> numpy

* Use tostring instead of tobytes for old numpy

* COSMIT pickle-able --> picklable

* use #noqa comment to turn off flake8

* TST/ENH Check if the pickled est's predict matches with the original one's
---
 sklearn/model_selection/_search.py           |  7 +++++--
 sklearn/model_selection/tests/test_search.py |  8 +++++--
 sklearn/utils/fixes.py                       | 18 ++++++++++++++++
 sklearn/utils/tests/test_fixes.py            | 22 +++++++++++++++++---
 4 files changed, 48 insertions(+), 7 deletions(-)

diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index b2345398aa..82516f1e6b 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -30,6 +30,7 @@ from ..externals import six
 from ..utils import check_random_state
 from ..utils.fixes import sp_version
 from ..utils.fixes import rankdata
+from ..utils.fixes import MaskedArray
 from ..utils.random import sample_without_replacement
 from ..utils.validation import indexable, check_is_fitted
 from ..utils.metaestimators import if_delegate_has_method
@@ -611,10 +612,12 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
         best_index = np.flatnonzero(results["rank_test_score"] == 1)[0]
         best_parameters = candidate_params[best_index]
 
-        # Use one np.MaskedArray and mask all the places where the param is not
+        # Use one MaskedArray and mask all the places where the param is not
         # applicable for that candidate. Use defaultdict as each candidate may
         # not contain all the params
-        param_results = defaultdict(partial(np.ma.masked_all, (n_candidates,),
+        param_results = defaultdict(partial(MaskedArray,
+                                            np.empty(n_candidates,),
+                                            mask=True,
                                             dtype=object))
         for cand_i, params in enumerate(candidate_params):
             for name, value in params.items():
diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py
index 30daacf4c4..36e6965a11 100644
--- a/sklearn/model_selection/tests/test_search.py
+++ b/sklearn/model_selection/tests/test_search.py
@@ -940,12 +940,16 @@ def test_pickle():
     clf = MockClassifier()
     grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, refit=True)
     grid_search.fit(X, y)
-    pickle.dumps(grid_search)  # smoke test
+    grid_search_pickled = pickle.loads(pickle.dumps(grid_search))
+    assert_array_almost_equal(grid_search.predict(X),
+                              grid_search_pickled.predict(X))
 
     random_search = RandomizedSearchCV(clf, {'foo_param': [1, 2, 3]},
                                        refit=True, n_iter=3)
     random_search.fit(X, y)
-    pickle.dumps(random_search)  # smoke test
+    random_search_pickled = pickle.loads(pickle.dumps(random_search))
+    assert_array_almost_equal(random_search.predict(X),
+                              random_search_pickled.predict(X))
 
 
 def test_grid_search_with_multioutput_data():
diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py
index aa27bf5434..682ab7733c 100644
--- a/sklearn/utils/fixes.py
+++ b/sklearn/utils/fixes.py
@@ -401,3 +401,21 @@ if sp_version < (0, 13, 0):
         return .5 * (count[dense] + count[dense - 1] + 1)
 else:
     from scipy.stats import rankdata
+
+
+if np_version < (1, 12, 0):
+    class MaskedArray(np.ma.MaskedArray):
+        # Before numpy 1.12, np.ma.MaskedArray object is not picklable
+        # This fix is needed to make our model_selection.GridSearchCV
+        # picklable as the ``cv_results_`` attribute uses MaskedArray
+        def __getstate__(self):
+            """Return the internal state of the masked array, for pickling
+            purposes.
+
+            """
+            cf = 'CF'[self.flags.fnc]
+            data_state = super(np.ma.MaskedArray, self).__reduce__()[2]
+            return data_state + (np.ma.getmaskarray(self).tostring(cf),
+                                 self._fill_value)
+else:
+    from numpy.ma import MaskedArray    # noqa
diff --git a/sklearn/utils/tests/test_fixes.py b/sklearn/utils/tests/test_fixes.py
index f5817f246b..ef1110bfc4 100644
--- a/sklearn/utils/tests/test_fixes.py
+++ b/sklearn/utils/tests/test_fixes.py
@@ -3,13 +3,19 @@
 #          Lars Buitinck
 # License: BSD 3 clause
 
+import pickle
 import numpy as np
 
-from numpy.testing import (assert_almost_equal,
-                           assert_array_almost_equal)
+from sklearn.utils.testing import assert_equal
+from sklearn.utils.testing import assert_false
+from sklearn.utils.testing import assert_true
+from sklearn.utils.testing import assert_almost_equal
+from sklearn.utils.testing import assert_array_equal
+from sklearn.utils.testing import assert_array_almost_equal
+
 from sklearn.utils.fixes import divide, expit
 from sklearn.utils.fixes import astype
-from sklearn.utils.testing import assert_equal, assert_false, assert_true
+from sklearn.utils.fixes import MaskedArray
 
 
 def test_expit():
@@ -50,3 +56,13 @@ def test_astype_copy_memory():
 
     e_int32 = astype(a_int32, dtype=np.int32)
     assert_false(np.may_share_memory(e_int32, a_int32))
+
+
+def test_masked_array_obj_dtype_pickleable():
+    marr = MaskedArray([1, None, 'a'], dtype=object)
+
+    for mask in (True, False, [0, 1, 0]):
+        marr.mask = mask
+        marr_pickled = pickle.loads(pickle.dumps(marr))
+        assert_array_equal(marr.data, marr_pickled.data)
+        assert_array_equal(marr.mask, marr_pickled.mask)
-- 
GitLab