From b22eb2ac9cfe6328fc6920042b243c346e6f678f Mon Sep 17 00:00:00 2001 From: Lars Buitinck <L.J.Buitinck@uva.nl> Date: Mon, 30 Apr 2012 14:31:38 +0200 Subject: [PATCH] =?UTF-8?q?BUG=20chi=C2=B2=20feature=20selection=20didn't?= =?UTF-8?q?=20work=20for=20COO=20matrices?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reported via StackOverflow, http://stackoverflow.com/q/10373317/166749 --- sklearn/feature_selection/tests/test_chi2.py | 20 ++++++++++++++----- .../feature_selection/univariate_selection.py | 6 +++--- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/sklearn/feature_selection/tests/test_chi2.py b/sklearn/feature_selection/tests/test_chi2.py index a163b6baf2..1fa66e7fdd 100644 --- a/sklearn/feature_selection/tests/test_chi2.py +++ b/sklearn/feature_selection/tests/test_chi2.py @@ -5,17 +5,17 @@ specifically to work with sparse matrices. import numpy as np from numpy.testing import assert_equal -from scipy.sparse import csr_matrix +from scipy.sparse import coo_matrix, csr_matrix from .. import SelectKBest, chi2 # Feature 0 is highly informative for class 1; # feature 1 is the same everywhere; # feature 2 is a bit informative for class 2. -X = ([[2, 1, 2], - [9, 1, 1], - [6, 1, 2], - [0, 1, 2]]) +X = [[2, 1, 2], + [9, 1, 1], + [6, 1, 2], + [0, 1, 2]] y = [0, 1, 2, 2] @@ -45,3 +45,13 @@ def test_chi2(): Xtrans = Xtrans.toarray() Xtrans2 = mkchi2(k=2).fit_transform(Xsp, y).toarray() assert_equal(Xtrans, Xtrans2) + + +def test_chi2_coo(): + """Check that chi2 works with a COO matrix + + (as returned by CountVectorizer, DictVectorizer) + """ + Xcoo = coo_matrix(X) + mkchi2(k=2).fit_transform(Xcoo, y) + # if we got here without an exception, we're safe diff --git a/sklearn/feature_selection/univariate_selection.py b/sklearn/feature_selection/univariate_selection.py index e19fbd891e..1e82a731ee 100644 --- a/sklearn/feature_selection/univariate_selection.py +++ b/sklearn/feature_selection/univariate_selection.py @@ -13,7 +13,7 @@ from scipy.sparse import issparse from ..base import BaseEstimator, TransformerMixin from ..preprocessing import LabelBinarizer -from ..utils import array2d, safe_asarray, deprecated, as_float_array +from ..utils import array2d, atleast2d_or_csr, deprecated, as_float_array from ..utils.extmath import safe_sparse_dot ###################################################################### @@ -145,7 +145,7 @@ def chi2(X, y): # XXX: we might want to do some of the following in logspace instead for # numerical stability. - X = safe_asarray(X) + X = atleast2d_or_csr(X) Y = LabelBinarizer().fit_transform(y) if Y.shape[1] == 1: Y = np.append(1 - Y, Y, axis=1) @@ -266,7 +266,7 @@ class _AbstractUnivariateFilter(BaseEstimator, TransformerMixin): """ Transform a new matrix using the selected features """ - return safe_asarray(X)[:, self.get_support(indices=issparse(X))] + return atleast2d_or_csr(X)[:, self.get_support(indices=issparse(X))] def inverse_transform(self, X): """ -- GitLab