diff --git a/doc/datasets/twenty_newsgroups.rst b/doc/datasets/twenty_newsgroups.rst
index 0f4d728f7151aaad3694c53f9fdea32dea143c9c..e82b5e4112673040594456abc3a49087a8f987bf 100644
--- a/doc/datasets/twenty_newsgroups.rst
+++ b/doc/datasets/twenty_newsgroups.rst
@@ -12,7 +12,7 @@
 This module contains two loaders. The first one,
 returns a list of the raw text files that can be fed to text feature
 extractors such as :class:`sklearn.feature_extraction.text.Vectorizer`
 with custom parameters so as to extract feature vectors.
-The second one, ``sklearn.datasets.fetch_20newsgroups_tfidf``,
+The second one, ``sklearn.datasets.fetch_20newsgroups_vectorized``,
 returns ready-to-use features, i.e., it is not necessary to use a feature
 extractor.
@@ -98,7 +98,7 @@ zero features)::
   >>> vectors.nnz / vectors.shape[0]
   118
 
-``sklearn.datasets.fetch_20newsgroups_tfidf`` is a function which returns
+``sklearn.datasets.fetch_20newsgroups_vectorized`` is a function which returns
 ready-to-use tfidf features instead of file names.
 
 .. _`20 newsgroups website`: http://people.csail.mit.edu/jrennie/20Newsgroups/
diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 6eb51039808828cf31fe8147737049b52d83c337..4d089e14482cf2729d50272cefc57beee3063834 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -130,7 +130,7 @@ Loaders
    datasets.fetch_lfw_people
    datasets.load_20newsgroups
    datasets.fetch_20newsgroups
-   datasets.fetch_20newsgroups_tfidf
+   datasets.fetch_20newsgroups_vectorized
    datasets.fetch_olivetti_faces
 
 Samples generator
diff --git a/sklearn/datasets/__init__.py b/sklearn/datasets/__init__.py
index efa1015b995daf25c3e879eaf551017e98dff32c..74fb2e71084d877bbf0eafe5ede2df7ebb6aeb86 100644
--- a/sklearn/datasets/__init__.py
+++ b/sklearn/datasets/__init__.py
@@ -20,7 +20,7 @@ from .lfw import load_lfw_people
 from .lfw import fetch_lfw_pairs
 from .lfw import fetch_lfw_people
 from .twenty_newsgroups import fetch_20newsgroups
-from .twenty_newsgroups import fetch_20newsgroups_tfidf
+from .twenty_newsgroups import fetch_20newsgroups_vectorized
 from .twenty_newsgroups import load_20newsgroups
 from .mldata import fetch_mldata, mldata_filename
 from .samples_generator import make_classification
diff --git a/sklearn/datasets/tests/test_20news.py b/sklearn/datasets/tests/test_20news.py
index 22fa6256e3148af9d6bafed975c82d6bd67776cd..61d428a6a6b5080917eb1fb349819ed8c1a18ad0 100644
--- a/sklearn/datasets/tests/test_20news.py
+++ b/sklearn/datasets/tests/test_20news.py
@@ -34,18 +34,21 @@ def test_20news():
     assert_equal(entry1, entry2)
 
 
-def test_20news_tfidf():
+def test_20news_vectorized():
     # This test is slow.
     raise SkipTest
 
-    bunch = datasets.fetch_20newsgroups_tfidf(subset="train")
+    bunch = datasets.fetch_20newsgroups_vectorized(subset="train")
     assert_equal(bunch.data.shape, (11314, 107130))
     assert_equal(bunch.target.shape[0], 11314)
+    assert_equal(bunch.data.dtype, np.float64)
 
-    bunch = datasets.fetch_20newsgroups_tfidf(subset="test")
+    bunch = datasets.fetch_20newsgroups_vectorized(subset="test")
     assert_equal(bunch.data.shape, (7532, 107130))
     assert_equal(bunch.target.shape[0], 7532)
+    assert_equal(bunch.data.dtype, np.float64)
 
-    bunch = datasets.fetch_20newsgroups_tfidf(subset="all")
+    bunch = datasets.fetch_20newsgroups_vectorized(subset="all")
     assert_equal(bunch.data.shape, (11314 + 7532, 107130))
     assert_equal(bunch.target.shape[0], 11314 + 7532)
+    assert_equal(bunch.data.dtype, np.float64)
diff --git a/sklearn/datasets/twenty_newsgroups.py b/sklearn/datasets/twenty_newsgroups.py
index ab66d4f5a6cc8ffbbb7cd0d65ea58d1d69b0dcb7..ea06d74c0a75c20b96c169e7fe90a645bce00f91 100644
--- a/sklearn/datasets/twenty_newsgroups.py
+++ b/sklearn/datasets/twenty_newsgroups.py
@@ -50,8 +50,9 @@ from .base import Bunch
 from .base import load_files
 from ..utils import check_random_state, deprecated
 from ..utils.fixes import in1d
-from ..feature_extraction.text import Vectorizer
-from sklearn.externals import joblib
+from ..feature_extraction.text import CountVectorizer
+from ..preprocessing import normalize
+from ..externals import joblib
 
 
 logger = logging.getLogger(__name__)
@@ -192,7 +193,7 @@
     return data
 
 
-def fetch_20newsgroups_tfidf(subset="train", data_home=None):
+def fetch_20newsgroups_vectorized(subset="train", data_home=None):
     """Load the 20 newsgroups dataset and transform it into tf-idf vectors
 
     This is a convenience function; the tf-idf transformation is done using the
@@ -238,11 +239,18 @@
     if os.path.exists(target_file):
         X_train, X_test = joblib.load(target_file)
     else:
-        vectorizer = Vectorizer()
+        vectorizer = CountVectorizer(dtype=np.int16)
         X_train = vectorizer.fit_transform(data_train.data)
         X_test = vectorizer.transform(data_test.data)
         joblib.dump((X_train, X_test), target_file)
 
+    # the data is stored as int16 for compactness
+    # but normalize needs floats
+    X_train = X_train.astype(np.float64)
+    X_test = X_test.astype(np.float64)
+    normalize(X_train, copy=False)
+    normalize(X_test, copy=False)
+
     target_names = data_train.target_names
 
     if subset == "train":
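
Usage sketch (not part of the patch): a minimal check of the renamed loader,
assuming a scikit-learn build that includes this change. The expected shapes
are the ones asserted in test_20news.py above; the float64 dtype and the unit
row norms follow from the astype()/normalize() calls added to
twenty_newsgroups.py::

    import numpy as np
    from sklearn.datasets import fetch_20newsgroups_vectorized

    # Downloads and vectorizes the data on the first call, then reuses the
    # joblib cache under data_home on later calls.
    bunch = fetch_20newsgroups_vectorized(subset="train")
    print(bunch.data.shape)    # (11314, 107130), a scipy sparse matrix
    print(bunch.data.dtype)    # float64: int16 counts are cast before normalizing
    print(bunch.target.shape)  # (11314,)

    # normalize() defaults to the l2 norm, so every non-empty row of the
    # returned matrix has unit Euclidean length.
    norms = np.asarray(np.sqrt(bunch.data.multiply(bunch.data).sum(axis=1))).ravel()
    print(np.allclose(norms[norms > 0], 1.0))  # True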