From a03db89eba7978cbe8d22573cf64def4df8b5d72 Mon Sep 17 00:00:00 2001 From: b0noI <viacheslav@b0noi.com> Date: Wed, 7 Sep 2016 14:06:44 -0700 Subject: [PATCH] [MRG+1] Support of the collections.Sequence type has been added to the _check_param_grid method from model_selection. (#7323) * Support of the collections.Sequence type has been added to the _check_param_grid method from model_selection. * test_grid_search_when_param_grid_includes_range test was refactored (parts that are not nessesary have been removed). * test_grid_search_bad_param_grid now checks that value is not string. This is important since string is a Sequence. * _check_param_grid now checks is the type is not string together with the check for other types. * whats_new.rst has been updated to include information about bug fix for bug #7322. * Description of the fix for the bug #7322 has been updated. * Fix for indented in model_selection._search.py. --- doc/whats_new.rst | 5 ++ sklearn/model_selection/_search.py | 9 ++-- sklearn/model_selection/tests/test_search.py | 48 ++++++++++++-------- 3 files changed, 40 insertions(+), 22 deletions(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 3da1dab973..3f280a943f 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -296,6 +296,11 @@ Enhancements Bug fixes ......... + - :func: `model_selection.tests._search._check_param_grid` now works correctly with all types + that extends/implements `Sequence` (except string), including range (Python 3.x) and xrange + (Python 2.x). + (`#7323 <https://github.com/scikit-learn/scikit-learn/pull/7323>`_) by `Viacheslav Kovalevskyi`_. + - :class:`StratifiedKFold` now raises error if all n_labels for individual classes is less than n_folds. (`#6182 <https://github.com/scikit-learn/scikit-learn/pull/6182>`_) by `Devashish Deshpande`_. diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index 435cabc68d..c96ee7f19d 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -12,7 +12,7 @@ from __future__ import division # License: BSD 3 clause from abc import ABCMeta, abstractmethod -from collections import Mapping, namedtuple, Sized, defaultdict +from collections import Mapping, namedtuple, Sized, defaultdict, Sequence from functools import partial, reduce from itertools import product import operator @@ -332,10 +332,11 @@ def _check_param_grid(param_grid): if isinstance(v, np.ndarray) and v.ndim > 1: raise ValueError("Parameter array should be one-dimensional.") - check = [isinstance(v, k) for k in (list, tuple, np.ndarray)] - if True not in check: + if (isinstance(v, six.string_types) or + not isinstance(v, (np.ndarray, Sequence))): raise ValueError("Parameter values for parameter ({0}) need " - "to be a sequence.".format(name)) + "to be a sequence(but not a string) or" + " np.ndarray.".format(name)) if len(v) == 0: raise ValueError("Parameter values for parameter ({0}) need " diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 65fc1964dd..03eafdafb1 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -3,6 +3,7 @@ from collections import Iterable, Sized from sklearn.externals.six.moves import cStringIO as StringIO from sklearn.externals.six.moves import xrange +from sklearn.externals.joblib._compat import PY3_OR_LATER from itertools import chain, product import pickle import sys @@ -169,22 +170,6 @@ def test_grid_search(): assert_raises(ValueError, grid_search.fit, X, y) -def test_grid_search_incorrect_param_grid(): - clf = MockClassifier() - assert_raise_message( - ValueError, - "Parameter values for parameter (C) need to be a sequence.", - GridSearchCV, clf, {'C': 1}) - - -def test_grid_search_param_grid_includes_sequence_of_a_zero_length(): - clf = MockClassifier() - assert_raise_message( - ValueError, - "Parameter values for parameter (C) need to be a non-empty sequence.", - GridSearchCV, clf, {'C': []}) - - @ignore_warnings def test_grid_search_no_score(): # Test grid-search on classifier that has no score function. @@ -320,14 +305,41 @@ def test_grid_search_one_grid_point(): assert_array_equal(clf.dual_coef_, cv.best_estimator_.dual_coef_) +def test_grid_search_when_param_grid_includes_range(): + # Test that the best estimator contains the right value for foo_param + clf = MockClassifier() + grid_search = None + if PY3_OR_LATER: + grid_search = GridSearchCV(clf, {'foo_param': range(1, 4)}) + else: + grid_search = GridSearchCV(clf, {'foo_param': xrange(1, 4)}) + grid_search.fit(X, y) + assert_equal(grid_search.best_estimator_.foo_param, 2) + + def test_grid_search_bad_param_grid(): param_dict = {"C": 1.0} clf = SVC() - assert_raises(ValueError, GridSearchCV, clf, param_dict) + assert_raise_message( + ValueError, + "Parameter values for parameter (C) need to be a sequence" + "(but not a string) or np.ndarray.", + GridSearchCV, clf, param_dict) param_dict = {"C": []} clf = SVC() - assert_raises(ValueError, GridSearchCV, clf, param_dict) + assert_raise_message( + ValueError, + "Parameter values for parameter (C) need to be a non-empty sequence.", + GridSearchCV, clf, param_dict) + + param_dict = {"C": "1,2,3"} + clf = SVC() + assert_raise_message( + ValueError, + "Parameter values for parameter (C) need to be a sequence" + "(but not a string) or np.ndarray.", + GridSearchCV, clf, param_dict) param_dict = {"C": np.ones(6).reshape(3, 2)} clf = SVC() -- GitLab