From 868a58b2e0ff23427fbfb625e4fc7f997c734480 Mon Sep 17 00:00:00 2001 From: Raghav RV <rvraghav93@gmail.com> Date: Mon, 10 Oct 2016 21:33:44 +0200 Subject: [PATCH] [MRG+1] FIX Make sure GridSearchCV and RandomizedSearchCV are pickle-able (#7594) * FIX Subclass a new MaskedArray which allows pickling even when dype=object * TST unpickling too * FIX Use MaskedArray from utils.fixes rather than from numpy * FIX imports * Don't assign a variable * FIX np --> numpy * Use tostring instead of tobytes for old numpy * COSMIT pickle-able --> picklable * use #noqa comment to turn off flake8 * TST/ENH Check if the pickled est's predict matches with the original one's --- sklearn/model_selection/_search.py | 7 +++++-- sklearn/model_selection/tests/test_search.py | 8 +++++-- sklearn/utils/fixes.py | 18 ++++++++++++++++ sklearn/utils/tests/test_fixes.py | 22 +++++++++++++++++--- 4 files changed, 48 insertions(+), 7 deletions(-) diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index b2345398aa..82516f1e6b 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -30,6 +30,7 @@ from ..externals import six from ..utils import check_random_state from ..utils.fixes import sp_version from ..utils.fixes import rankdata +from ..utils.fixes import MaskedArray from ..utils.random import sample_without_replacement from ..utils.validation import indexable, check_is_fitted from ..utils.metaestimators import if_delegate_has_method @@ -611,10 +612,12 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, best_index = np.flatnonzero(results["rank_test_score"] == 1)[0] best_parameters = candidate_params[best_index] - # Use one np.MaskedArray and mask all the places where the param is not + # Use one MaskedArray and mask all the places where the param is not # applicable for that candidate. Use defaultdict as each candidate may # not contain all the params - param_results = defaultdict(partial(np.ma.masked_all, (n_candidates,), + param_results = defaultdict(partial(MaskedArray, + np.empty(n_candidates,), + mask=True, dtype=object)) for cand_i, params in enumerate(candidate_params): for name, value in params.items(): diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 30daacf4c4..36e6965a11 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -940,12 +940,16 @@ def test_pickle(): clf = MockClassifier() grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, refit=True) grid_search.fit(X, y) - pickle.dumps(grid_search) # smoke test + grid_search_pickled = pickle.loads(pickle.dumps(grid_search)) + assert_array_almost_equal(grid_search.predict(X), + grid_search_pickled.predict(X)) random_search = RandomizedSearchCV(clf, {'foo_param': [1, 2, 3]}, refit=True, n_iter=3) random_search.fit(X, y) - pickle.dumps(random_search) # smoke test + random_search_pickled = pickle.loads(pickle.dumps(random_search)) + assert_array_almost_equal(random_search.predict(X), + random_search_pickled.predict(X)) def test_grid_search_with_multioutput_data(): diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index aa27bf5434..682ab7733c 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -401,3 +401,21 @@ if sp_version < (0, 13, 0): return .5 * (count[dense] + count[dense - 1] + 1) else: from scipy.stats import rankdata + + +if np_version < (1, 12, 0): + class MaskedArray(np.ma.MaskedArray): + # Before numpy 1.12, np.ma.MaskedArray object is not picklable + # This fix is needed to make our model_selection.GridSearchCV + # picklable as the ``cv_results_`` param uses MaskedArray + def __getstate__(self): + """Return the internal state of the masked array, for pickling + purposes. + + """ + cf = 'CF'[self.flags.fnc] + data_state = super(np.ma.MaskedArray, self).__reduce__()[2] + return data_state + (np.ma.getmaskarray(self).tostring(cf), + self._fill_value) +else: + from numpy.ma import MaskedArray # noqa diff --git a/sklearn/utils/tests/test_fixes.py b/sklearn/utils/tests/test_fixes.py index f5817f246b..ef1110bfc4 100644 --- a/sklearn/utils/tests/test_fixes.py +++ b/sklearn/utils/tests/test_fixes.py @@ -3,13 +3,19 @@ # Lars Buitinck # License: BSD 3 clause +import pickle import numpy as np -from numpy.testing import (assert_almost_equal, - assert_array_almost_equal) +from sklearn.utils.testing import assert_equal +from sklearn.utils.testing import assert_false +from sklearn.utils.testing import assert_true +from sklearn.utils.testing import assert_almost_equal +from sklearn.utils.testing import assert_array_equal +from sklearn.utils.testing import assert_array_almost_equal + from sklearn.utils.fixes import divide, expit from sklearn.utils.fixes import astype -from sklearn.utils.testing import assert_equal, assert_false, assert_true +from sklearn.utils.fixes import MaskedArray def test_expit(): @@ -50,3 +56,13 @@ def test_astype_copy_memory(): e_int32 = astype(a_int32, dtype=np.int32) assert_false(np.may_share_memory(e_int32, a_int32)) + + +def test_masked_array_obj_dtype_pickleable(): + marr = MaskedArray([1, None, 'a'], dtype=object) + + for mask in (True, False, [0, 1, 0]): + marr.mask = mask + marr_pickled = pickle.loads(pickle.dumps(marr)) + assert_array_equal(marr.data, marr_pickled.data) + assert_array_equal(marr.mask, marr_pickled.mask) -- GitLab