diff --git a/sklearn/base.py b/sklearn/base.py index c2b36887bbea815d3847fb80ac4c4e9657ac31b8..d1628f39b3727fc193e1cad07e71d3c62ac10217 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -12,6 +12,7 @@ from .externals import six from .utils.fixes import signature from .utils.deprecation import deprecated from .exceptions import ChangedBehaviorWarning as _ChangedBehaviorWarning +from . import __version__ @deprecated("ChangedBehaviorWarning has been moved into the sklearn.exceptions" @@ -296,6 +297,24 @@ class BaseEstimator(object): return '%s(%s)' % (class_name, _pprint(self.get_params(deep=False), offset=len(class_name),),) + def __getstate__(self): + if type(self).__module__.startswith('sklearn.'): + return dict(self.__dict__.items(), _sklearn_version=__version__) + else: + return dict(self.__dict__.items()) + + def __setstate__(self, state): + if type(self).__module__.startswith('sklearn.'): + pickle_version = state.pop("_sklearn_version", "pre-0.18") + if pickle_version != __version__: + warnings.warn( + "Trying to unpickle estimator {0} from version {1} when " + "using version {2}. This might lead to breaking code or " + "invalid results. Use at your own risk.".format( + self.__class__.__name__, pickle_version, __version__), + UserWarning) + self.__dict__.update(state) + ############################################################################### class ClassifierMixin(object): diff --git a/sklearn/isotonic.py b/sklearn/isotonic.py index 827a2ef2da3d6cbb672d6b0947bb42951744c68c..0585438e8708f62651928848644d3c3caa7a67e9 100644 --- a/sklearn/isotonic.py +++ b/sklearn/isotonic.py @@ -412,8 +412,7 @@ class IsotonicRegression(BaseEstimator, TransformerMixin, RegressorMixin): def __getstate__(self): """Pickle-protocol - return state of the estimator. 
""" - # copy __dict__ - state = dict(self.__dict__) + state = super(IsotonicRegression, self).__getstate__() # remove interpolation method state.pop('f_', None) return state @@ -423,6 +422,6 @@ class IsotonicRegression(BaseEstimator, TransformerMixin, RegressorMixin): We need to rebuild the interpolation function. """ - self.__dict__.update(state) + super(IsotonicRegression, self).__setstate__(state) if hasattr(self, '_necessary_X_') and hasattr(self, '_necessary_y_'): self._build_f(self._necessary_X_, self._necessary_y_) diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py index ae6d9fdb1334457cdba23ff14c57e31fe3e54483..1c5c9bb7447a72637b7b4d611ff9f53a15077f38 100644 --- a/sklearn/tests/test_base.py +++ b/sklearn/tests/test_base.py @@ -6,22 +6,29 @@ import sys import numpy as np import scipy.sparse as sp +import sklearn from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_true from sklearn.utils.testing import assert_false from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_not_equal from sklearn.utils.testing import assert_raises +from sklearn.utils.testing import assert_no_warnings from sklearn.utils.testing import assert_warns_message from sklearn.base import BaseEstimator, clone, is_classifier from sklearn.svm import SVC from sklearn.pipeline import Pipeline from sklearn.model_selection import GridSearchCV + +from sklearn.tree import DecisionTreeClassifier +from sklearn.tree import DecisionTreeRegressor +from sklearn import datasets from sklearn.utils import deprecated from sklearn.base import TransformerMixin from sklearn.utils.mocking import MockDataFrame +import pickle ############################################################################# @@ -235,8 +242,8 @@ def test_is_classifier(): assert_true(is_classifier(svc)) assert_true(is_classifier(GridSearchCV(svc, {'C': [0.1, 1]}))) assert_true(is_classifier(Pipeline([('svc', svc)]))) - 
assert_true(is_classifier(Pipeline([('svc_cv', - GridSearchCV(svc, {'C': [0.1, 1]}))]))) + assert_true(is_classifier(Pipeline( + [('svc_cv', GridSearchCV(svc, {'C': [0.1, 1]}))]))) def test_set_params(): @@ -253,9 +260,6 @@ def test_set_params(): def test_score_sample_weight(): - from sklearn.tree import DecisionTreeClassifier - from sklearn.tree import DecisionTreeRegressor - from sklearn import datasets rng = np.random.RandomState(0) @@ -313,3 +317,45 @@ def test_clone_pandas_dataframe(): # the test assert_true((e.df == cloned_e.df).values.all()) assert_equal(e.scalar_param, cloned_e.scalar_param) + + +class TreeNoVersion(DecisionTreeClassifier): + def __getstate__(self): + return self.__dict__ + + +def test_pickle_version_warning(): + # check that warnings are raised when unpickling in a different version + + # first, check no warning when in the same version: + iris = datasets.load_iris() + tree = DecisionTreeClassifier().fit(iris.data, iris.target) + tree_pickle = pickle.dumps(tree) + assert_true(b"version" in tree_pickle) + assert_no_warnings(pickle.loads, tree_pickle) + + # check that warning is raised on different version + tree_pickle_other = tree_pickle.replace(sklearn.__version__.encode(), + b"something") + message = ("Trying to unpickle estimator DecisionTreeClassifier from " + "version {0} when using version {1}. This might lead to " + "breaking code or invalid results. 
" + "Use at your own risk.".format("something", + sklearn.__version__)) + assert_warns_message(UserWarning, message, pickle.loads, tree_pickle_other) + + # check that not including any version also works: + # TreeNoVersion's getstate returns the bare __dict__ (no version), like pre-0.18 + tree = TreeNoVersion().fit(iris.data, iris.target) + + tree_pickle_noversion = pickle.dumps(tree) + assert_false(b"version" in tree_pickle_noversion) + message = message.replace("something", "pre-0.18") + message = message.replace("DecisionTreeClassifier", "TreeNoVersion") + # check we got the warning about using pre-0.18 pickle + assert_warns_message(UserWarning, message, pickle.loads, + tree_pickle_noversion) + + # check that no warning is raised for external estimators + TreeNoVersion.__module__ = "notsklearn" + assert_no_warnings(pickle.loads, tree_pickle_noversion) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 5c031c881addce6794ab04719eb40e31dfac3d08..672886cca7231ca917409acdd00572e032599e40 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -835,6 +835,8 @@ def check_estimators_pickle(name, Estimator): # pickle and unpickle! pickled_estimator = pickle.dumps(estimator) + if Estimator.__module__.startswith('sklearn.'): + assert_true(b"version" in pickled_estimator) unpickled_estimator = pickle.loads(pickled_estimator) for method in result: