From a03db89eba7978cbe8d22573cf64def4df8b5d72 Mon Sep 17 00:00:00 2001
From: b0noI <viacheslav@b0noi.com>
Date: Wed, 7 Sep 2016 14:06:44 -0700
Subject: [PATCH] [MRG+1] Support of the collections.Sequence type has been
 added to the _check_param_grid method from model_selection. (#7323)

* Support of the collections.Sequence type has been added to the _check_param_grid method from model_selection.

* test_grid_search_when_param_grid_includes_range test was refactored (parts that are not necessary have been removed).

* test_grid_search_bad_param_grid now checks that the value is not a string. This is important since a string is a Sequence.

* _check_param_grid now checks that the type is not a string, together with the check for other types.

* whats_new.rst has been updated to include information about bug fix for bug #7322.

* Description of the fix for the bug #7322 has been updated.

* Fix for indentation in model_selection._search.py.
---
 doc/whats_new.rst                            |  5 ++
 sklearn/model_selection/_search.py           |  9 ++--
 sklearn/model_selection/tests/test_search.py | 48 ++++++++++++--------
 3 files changed, 40 insertions(+), 22 deletions(-)

diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 3da1dab973..3f280a943f 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -296,6 +296,11 @@ Enhancements
 Bug fixes
 .........
 
+    - :func:`model_selection._search._check_param_grid` now works correctly with all types
+      that extend/implement `Sequence` (except strings), including range (Python 3.x) and xrange
+      (Python 2.x).
+      (`#7323 <https://github.com/scikit-learn/scikit-learn/pull/7323>`_) by `Viacheslav Kovalevskyi`_.
+
     - :class:`StratifiedKFold` now raises error if all n_labels for individual classes is less than n_folds.
       (`#6182 <https://github.com/scikit-learn/scikit-learn/pull/6182>`_) by `Devashish Deshpande`_.
 
diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index 435cabc68d..c96ee7f19d 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -12,7 +12,7 @@ from __future__ import division
 # License: BSD 3 clause
 
 from abc import ABCMeta, abstractmethod
-from collections import Mapping, namedtuple, Sized, defaultdict
+from collections import Mapping, namedtuple, Sized, defaultdict, Sequence
 from functools import partial, reduce
 from itertools import product
 import operator
@@ -332,10 +332,11 @@ def _check_param_grid(param_grid):
             if isinstance(v, np.ndarray) and v.ndim > 1:
                 raise ValueError("Parameter array should be one-dimensional.")
 
-            check = [isinstance(v, k) for k in (list, tuple, np.ndarray)]
-            if True not in check:
+            if (isinstance(v, six.string_types) or
+                    not isinstance(v, (np.ndarray, Sequence))):
                 raise ValueError("Parameter values for parameter ({0}) need "
-                                 "to be a sequence.".format(name))
+                                 "to be a sequence(but not a string) or"
+                                 " np.ndarray.".format(name))
 
             if len(v) == 0:
                 raise ValueError("Parameter values for parameter ({0}) need "
diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py
index 65fc1964dd..03eafdafb1 100644
--- a/sklearn/model_selection/tests/test_search.py
+++ b/sklearn/model_selection/tests/test_search.py
@@ -3,6 +3,7 @@
 from collections import Iterable, Sized
 from sklearn.externals.six.moves import cStringIO as StringIO
 from sklearn.externals.six.moves import xrange
+from sklearn.externals.joblib._compat import PY3_OR_LATER
 from itertools import chain, product
 import pickle
 import sys
@@ -169,22 +170,6 @@ def test_grid_search():
     assert_raises(ValueError, grid_search.fit, X, y)
 
 
-def test_grid_search_incorrect_param_grid():
-    clf = MockClassifier()
-    assert_raise_message(
-        ValueError,
-        "Parameter values for parameter (C) need to be a sequence.",
-        GridSearchCV, clf, {'C': 1})
-
-
-def test_grid_search_param_grid_includes_sequence_of_a_zero_length():
-    clf = MockClassifier()
-    assert_raise_message(
-        ValueError,
-        "Parameter values for parameter (C) need to be a non-empty sequence.",
-        GridSearchCV, clf, {'C': []})
-
-
 @ignore_warnings
 def test_grid_search_no_score():
     # Test grid-search on classifier that has no score function.
@@ -320,14 +305,41 @@ def test_grid_search_one_grid_point():
     assert_array_equal(clf.dual_coef_, cv.best_estimator_.dual_coef_)
 
 
+def test_grid_search_when_param_grid_includes_range():
+    # Test that the best estimator contains the right value for foo_param
+    clf = MockClassifier()
+    grid_search = None
+    if PY3_OR_LATER:
+        grid_search = GridSearchCV(clf, {'foo_param': range(1, 4)})
+    else:
+        grid_search = GridSearchCV(clf, {'foo_param': xrange(1, 4)})
+    grid_search.fit(X, y)
+    assert_equal(grid_search.best_estimator_.foo_param, 2)
+
+
 def test_grid_search_bad_param_grid():
     param_dict = {"C": 1.0}
     clf = SVC()
-    assert_raises(ValueError, GridSearchCV, clf, param_dict)
+    assert_raise_message(
+        ValueError,
+        "Parameter values for parameter (C) need to be a sequence"
+        "(but not a string) or np.ndarray.",
+        GridSearchCV, clf, param_dict)
 
     param_dict = {"C": []}
     clf = SVC()
-    assert_raises(ValueError, GridSearchCV, clf, param_dict)
+    assert_raise_message(
+        ValueError,
+        "Parameter values for parameter (C) need to be a non-empty sequence.",
+        GridSearchCV, clf, param_dict)
+
+    param_dict = {"C": "1,2,3"}
+    clf = SVC()
+    assert_raise_message(
+        ValueError,
+        "Parameter values for parameter (C) need to be a sequence"
+        "(but not a string) or np.ndarray.",
+        GridSearchCV, clf, param_dict)
 
     param_dict = {"C": np.ones(6).reshape(3, 2)}
     clf = SVC()
-- 
GitLab