diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 3da1dab97385259048329c2ec9b5389f17b3a9b0..3f280a943f7c2b2f9f5564b7d6ece5cf54dd5763 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -296,6 +296,11 @@ Enhancements Bug fixes ......... + - :func:`model_selection._search._check_param_grid` now works correctly with all types + that extend/implement `Sequence` (except strings), including range (Python 3.x) and xrange + (Python 2.x). + (`#7323 <https://github.com/scikit-learn/scikit-learn/pull/7323>`_) by `Viacheslav Kovalevskyi`_. + - :class:`StratifiedKFold` now raises error if all n_labels for individual classes is less than n_folds. (`#6182 <https://github.com/scikit-learn/scikit-learn/pull/6182>`_) by `Devashish Deshpande`_. diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index 435cabc68d49a6e6f8c9671aa932db42646985c6..c96ee7f19d704aa72421add24370a9cf436d50a1 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -12,7 +12,7 @@ from __future__ import division # License: BSD 3 clause from abc import ABCMeta, abstractmethod -from collections import Mapping, namedtuple, Sized, defaultdict +from collections import Mapping, namedtuple, Sized, defaultdict, Sequence from functools import partial, reduce from itertools import product import operator @@ -332,10 +332,11 @@ def _check_param_grid(param_grid): if isinstance(v, np.ndarray) and v.ndim > 1: raise ValueError("Parameter array should be one-dimensional.") - check = [isinstance(v, k) for k in (list, tuple, np.ndarray)] - if True not in check: + if (isinstance(v, six.string_types) or + not isinstance(v, (np.ndarray, Sequence))): raise ValueError("Parameter values for parameter ({0}) need " - "to be a sequence.".format(name)) + "to be a sequence(but not a string) or" + " np.ndarray.".format(name)) if len(v) == 0: raise ValueError("Parameter values for parameter ({0}) need " diff --git
a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 65fc1964ddf315ae19b4865ae928285c04648714..03eafdafb1d30521bf7a96f2f22bcaf890c943da 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -3,6 +3,7 @@ from collections import Iterable, Sized from sklearn.externals.six.moves import cStringIO as StringIO from sklearn.externals.six.moves import xrange +from sklearn.externals.joblib._compat import PY3_OR_LATER from itertools import chain, product import pickle import sys @@ -169,22 +170,6 @@ def test_grid_search(): assert_raises(ValueError, grid_search.fit, X, y) -def test_grid_search_incorrect_param_grid(): - clf = MockClassifier() - assert_raise_message( - ValueError, - "Parameter values for parameter (C) need to be a sequence.", - GridSearchCV, clf, {'C': 1}) - - -def test_grid_search_param_grid_includes_sequence_of_a_zero_length(): - clf = MockClassifier() - assert_raise_message( - ValueError, - "Parameter values for parameter (C) need to be a non-empty sequence.", - GridSearchCV, clf, {'C': []}) - - @ignore_warnings def test_grid_search_no_score(): # Test grid-search on classifier that has no score function. 
@@ -320,14 +305,41 @@ def test_grid_search_one_grid_point(): assert_array_equal(clf.dual_coef_, cv.best_estimator_.dual_coef_) +def test_grid_search_when_param_grid_includes_range(): + # Test that the best estimator contains the right value for foo_param + clf = MockClassifier() + grid_search = None + if PY3_OR_LATER: + grid_search = GridSearchCV(clf, {'foo_param': range(1, 4)}) + else: + grid_search = GridSearchCV(clf, {'foo_param': xrange(1, 4)}) + grid_search.fit(X, y) + assert_equal(grid_search.best_estimator_.foo_param, 2) + + def test_grid_search_bad_param_grid(): param_dict = {"C": 1.0} clf = SVC() - assert_raises(ValueError, GridSearchCV, clf, param_dict) + assert_raise_message( + ValueError, + "Parameter values for parameter (C) need to be a sequence" + "(but not a string) or np.ndarray.", + GridSearchCV, clf, param_dict) param_dict = {"C": []} clf = SVC() - assert_raises(ValueError, GridSearchCV, clf, param_dict) + assert_raise_message( + ValueError, + "Parameter values for parameter (C) need to be a non-empty sequence.", + GridSearchCV, clf, param_dict) + + param_dict = {"C": "1,2,3"} + clf = SVC() + assert_raise_message( + ValueError, + "Parameter values for parameter (C) need to be a sequence" + "(but not a string) or np.ndarray.", + GridSearchCV, clf, param_dict) param_dict = {"C": np.ones(6).reshape(3, 2)} clf = SVC()