diff --git a/doc/datasets/index.rst b/doc/datasets/index.rst
index cbc0cedbc02a5b1c0909cc406269c11922d680b1..f0722a928bdebe98e513d87dfb081afddc8ba62f 100644
--- a/doc/datasets/index.rst
+++ b/doc/datasets/index.rst
@@ -264,3 +264,5 @@ features::
 
 .. include:: labeled_faces.rst
 .. include:: covtype.rst
+
+.. include:: rcv1.rst
diff --git a/doc/datasets/rcv1.rst b/doc/datasets/rcv1.rst
new file mode 100644
index 0000000000000000000000000000000000000000..a957d9f91ff9a531576c41115852292d9adec65f
--- /dev/null
+++ b/doc/datasets/rcv1.rst
@@ -0,0 +1,52 @@
+
+.. _rcv1:
+
+RCV1 dataset
+============
+
+Reuters Corpus Volume I (RCV1) is an archive of over 800,000 manually
+categorized newswire stories made available by Reuters, Ltd. for research
+purposes. The dataset is extensively described in [1]_.
+
+:func:`sklearn.datasets.fetch_rcv1` will load the following version:
+RCV1-v2, vectors, full sets, topics multilabels::
+
+    >>> from sklearn.datasets import fetch_rcv1
+    >>> rcv1 = fetch_rcv1()
+
+It returns a dictionary-like object with the following attributes:
+
+``data``:
+The feature matrix is a scipy CSR sparse matrix, with 804414 samples and
+47236 features. Non-zero values contain cosine-normalized, log TF-IDF
+values. A nearly chronological split is proposed in [1]_: the first 23149
+samples are the training set; the last 781265 samples are the testing set.
+About 0.16% of the values are non-zero::
+
+    >>> rcv1.data.shape
+    (804414, 47236)
+
+``target``:
+The target values are stored in a scipy CSR sparse matrix, with 804414
+samples and 103 categories. Each sample has a value of 1 in its categories,
+and 0 in the others. About 3.15% of the values are non-zero::
+
+    >>> rcv1.target.shape
+    (804414, 103)
+
+``sample_id``:
+Each sample can be identified by its ID, ranging (with gaps) from 2286
+to 810596::
+
+    >>> rcv1.sample_id[:3]
+    array([2286, 2287, 2288], dtype=int32)
+
+``target_names``:
+The names of the 103 target categories (RCV1 topics), each represented by
+a string, ordered as the columns of ``target``. Each sample belongs to at
+least one category, and to up to 17 categories. The corpus frequencies of
+the categories span five orders of magnitude, from 5 occurrences for
+'GMIL' to 381327 for 'CCAT'::
+
+    >>> rcv1.target_names[:3].tolist()
+    ['E11', 'ECAT', 'M11']
+
+The dataset will be downloaded from the `dataset's homepage`_ if necessary.
+The compressed size is about 656 MB.
+
+.. _dataset's homepage: http://jmlr.csail.mit.edu/papers/volume5/lewis04a/
+
+
+.. topic:: References
+
+    .. [1] Lewis, D. D., Yang, Y., Rose, T. G., & Li, F. (2004). RCV1: A new
+           benchmark collection for text categorization research. The Journal
+           of Machine Learning Research, 5, 361-397.
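Not part of the patch, but a usage sketch that follows from the split
documented above: since the 23149 training samples come first, the official
chronological split can be recovered by slicing::

    >>> from sklearn.datasets import fetch_rcv1
    >>> rcv1 = fetch_rcv1()
    >>> X_train, X_test = rcv1.data[:23149], rcv1.data[23149:]
    >>> Y_train, Y_test = rcv1.target[:23149], rcv1.target[23149:]
    >>> X_train.shape, X_test.shape
    ((23149, 47236), (781265, 47236))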
+ +""" +from sklearn.utils.testing import check_skip_network + + +def setup_module(): + check_skip_network() diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 3df46348788237170896e904f95170c4f3aeb8fc..e88532ce7b6d46676c6647ddc00a44ff30d78fd3 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -91,6 +91,9 @@ Enhancements - ``dump_svmlight_file`` now handles multi-label datasets. By Chih-Wei Chang. + - RCV1 dataset loader (:func:`sklearn.datasets.fetch_rcv1`). + By `Tom Dupre la Tour`_. + Bug fixes ......... diff --git a/sklearn/datasets/__init__.py b/sklearn/datasets/__init__.py index 4b041675a2a851bdc6e54291da5914bf9eec3a29..58f69cac7fc3f3df3a3f49b2b5bef34406f5f25b 100644 --- a/sklearn/datasets/__init__.py +++ b/sklearn/datasets/__init__.py @@ -49,6 +49,8 @@ from .svmlight_format import dump_svmlight_file from .olivetti_faces import fetch_olivetti_faces from .species_distributions import fetch_species_distributions from .california_housing import fetch_california_housing +from .rcv1 import fetch_rcv1 + __all__ = ['clear_data_home', 'dump_svmlight_file', @@ -61,6 +63,7 @@ __all__ = ['clear_data_home', 'fetch_species_distributions', 'fetch_california_housing', 'fetch_covtype', + 'fetch_rcv1', 'get_data_home', 'load_boston', 'load_diabetes', diff --git a/sklearn/datasets/covtype.py b/sklearn/datasets/covtype.py index cc4d111c60361a60142840cd4c4b42e09db3e19b..13217652864c28795413c8a639c8211976b1a888 100644 --- a/sklearn/datasets/covtype.py +++ b/sklearn/datasets/covtype.py @@ -15,11 +15,9 @@ Courtesy of Jock A. Blackard and Colorado State University. # License: BSD 3 clause import sys -import errno from gzip import GzipFile from io import BytesIO import logging -import os from os.path import exists, join try: from urllib2 import urlopen @@ -30,6 +28,7 @@ import numpy as np from .base import get_data_home from .base import Bunch +from ..utils.fixes import makedirs from ..externals import joblib from ..utils import check_random_state @@ -98,7 +97,7 @@ def fetch_covtype(data_home=None, download_if_missing=True, available = exists(samples_path) if download_if_missing and not available: - _mkdirp(covtype_dir) + makedirs(covtype_dir, exist_ok=True) logger.warning("Downloading %s" % URL) f = BytesIO(urlopen(URL).read()) Xy = np.genfromtxt(GzipFile(fileobj=f), delimiter=',') @@ -123,14 +122,3 @@ def fetch_covtype(data_home=None, download_if_missing=True, y = y[ind] return Bunch(data=X, target=y, DESCR=__doc__) - - -def _mkdirp(d): - """Ensure directory d exists (like mkdir -p on Unix) - No guarantee that the directory is writable. - """ - try: - os.makedirs(d) - except OSError as e: - if e.errno != errno.EEXIST: - raise diff --git a/sklearn/datasets/rcv1.py b/sklearn/datasets/rcv1.py new file mode 100644 index 0000000000000000000000000000000000000000..2fcd6c32f632d7a369687970b044404ed576f896 --- /dev/null +++ b/sklearn/datasets/rcv1.py @@ -0,0 +1,205 @@ +"""RCV1 dataset. 
+""" + +# Author: Tom Dupre la Tour +# License: BSD 3 clause + +import logging + +from os.path import exists, join +from gzip import GzipFile +from io import BytesIO +from contextlib import closing + +try: + from urllib2 import urlopen +except ImportError: + from urllib.request import urlopen + +import numpy as np +import scipy.sparse as sp + +from .base import get_data_home +from .base import Bunch +from ..utils.fixes import makedirs +from ..externals import joblib +from .svmlight_format import load_svmlight_files +from ..utils import shuffle as shuffle_ + + +URL = ('http://jmlr.csail.mit.edu/papers/volume5/lewis04a/' + 'a13-vector-files/lyrl2004_vectors') +URL_topics = ('http://jmlr.csail.mit.edu/papers/volume5/lewis04a/' + 'a08-topic-qrels/rcv1-v2.topics.qrels.gz') + +logger = logging.getLogger() + + +def fetch_rcv1(data_home=None, download_if_missing=True, + random_state=None, shuffle=False): + """Load the RCV1 multilabel dataset, downloading it if necessary. + + Version: RCV1-v2, vectors, full sets, topics multilabels. + + ============== ===================== + Classes 103 + Samples total 804414 + Dimensionality 47236 + Features real, between 0 and 1 + ============== ===================== + + Read more in the :ref:`User Guide <datasets>`. + + Parameters + ---------- + data_home : string, optional + Specify another download and cache folder for the datasets. By default + all scikit learn data is stored in '~/scikit_learn_data' subfolders. + + download_if_missing : boolean, default=True + If False, raise a IOError if the data is not locally available + instead of trying to download the data from the source site. + + random_state : int, RandomState instance or None, optional (default=None) + Random state for shuffling the dataset. + If int, random_state is the seed used by the random number generator; + If RandomState instance, random_state is the random number generator; + If None, the random number generator is the RandomState instance used + by `np.random`. + + shuffle : bool, default=False + Whether to shuffle dataset. + + Returns + ------- + dataset : dict-like object with the following attributes: + + dataset.data : scipy csr array, shape (804414, 47236) + The first 23149 samples are training samples. + The last 781265 samples are testing samples. + The array has 0.16% of non zero values. + + dataset.target : scipy csr array, shape (804414, 103) + Each sample has a value of 1 in its categories, and 0 in others. + The array has 3.15% of non zero values. + + dataset.sample_id : numpy array, shape (804414,) + Identification number of each sample, as ordered in dataset.data. + + dataset.target_names : numpy array of object, length (103) + Names of each target (RCV1 topics), as ordered in dataset.target. + + dataset.DESCR : string + Description of the RCV1 dataset. 
+ + """ + N_SAMPLES = 804414 + N_FEATURES = 47236 + N_CATEGORIES = 103 + + data_home = get_data_home(data_home=data_home) + rcv1_dir = join(data_home, "RCV1") + if download_if_missing: + makedirs(rcv1_dir, exist_ok=True) + + samples_path = join(rcv1_dir, "samples.pkl") + sample_id_path = join(rcv1_dir, "sample_id.pkl") + sample_topics_path = join(rcv1_dir, "sample_topics.pkl") + topics_path = join(rcv1_dir, "topics_names.pkl") + + # load data (X) and sample_id + if download_if_missing and (not exists(samples_path) or + not exists(sample_id_path)): + file_urls = ["%s_test_pt%d.dat.gz" % (URL, i) for i in range(4)] + file_urls.append("%s_train.dat.gz" % URL) + files = [] + for file_url in file_urls: + logger.warning("Downloading %s" % file_url) + with closing(urlopen(file_url)) as online_file: + # buffer the full file in memory to make possible to Gzip to + # work correctly + f = BytesIO(online_file.read()) + files.append(GzipFile(fileobj=f)) + + Xy = load_svmlight_files(files, n_features=N_FEATURES) + + # Training data is before testing data + X = sp.vstack([Xy[8], Xy[0], Xy[2], Xy[4], Xy[6]]).tocsr() + sample_id = np.hstack((Xy[9], Xy[1], Xy[3], Xy[5], Xy[7])) + sample_id = sample_id.astype(np.int32) + + joblib.dump(X, samples_path, compress=9) + joblib.dump(sample_id, sample_id_path, compress=9) + + else: + X = joblib.load(samples_path) + sample_id = joblib.load(sample_id_path) + + # load target (y), categories, and sample_id_bis + if download_if_missing and (not exists(sample_topics_path) or + not exists(topics_path)): + logger.warning("Downloading %s" % URL_topics) + with closing(urlopen(URL_topics)) as online_topics: + f = BytesIO(online_topics.read()) + + # parse the target file + n_cat = -1 + n_doc = -1 + doc_previous = -1 + y = np.zeros((N_SAMPLES, N_CATEGORIES), dtype=np.int8) + sample_id_bis = np.zeros(N_SAMPLES, dtype=np.int32) + category_names = {} + for line in GzipFile(fileobj=f, mode='rb'): + line_components = line.decode("ascii").split(u" ") + if len(line_components) == 3: + cat, doc, _ = line_components + if cat not in category_names: + n_cat += 1 + category_names[cat] = n_cat + + doc = int(doc) + if doc != doc_previous: + doc_previous = doc + n_doc += 1 + sample_id_bis[n_doc] = doc + y[n_doc, category_names[cat]] = 1 + + # Samples in X are ordered with sample_id, + # whereas in y, they are ordered with sample_id_bis. 
diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py
index b66a3b9ff2a759b7a5806acfc7f0c36e5db54c87..73a0331b37b00f5e782fd3a5c1687d099226cc97 100644
--- a/sklearn/utils/fixes.py
+++ b/sklearn/utils/fixes.py
@@ -14,6 +14,8 @@ import inspect
 import warnings
 import sys
 import functools
+import os
+import errno
 
 import numpy as np
 import scipy.sparse as sp
@@ -351,3 +353,25 @@ if np_version < (1, 6, 2):
 
 else:
     from numpy import bincount
+
+
+if 'exist_ok' in inspect.getargspec(os.makedirs).args:
+    makedirs = os.makedirs
+else:
+    def makedirs(name, mode=0o777, exist_ok=False):
+        """makedirs(name [, mode=0o777][, exist_ok=False])
+
+        Super-mkdir; create a leaf directory and all intermediate ones. Works
+        like mkdir, except that any intermediate path segment (not just the
+        rightmost) will be created if it does not exist. If the target
+        directory already exists, raise an OSError if exist_ok is False.
+        Otherwise no exception is raised. This is recursive.
+
+        """
+
+        try:
+            os.makedirs(name, mode=mode)
+        except OSError as e:
+            if (not exist_ok or e.errno != errno.EEXIST
+                    or not os.path.isdir(name)):
+                raise
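A short sketch of the backported ``makedirs`` semantics, using a throwaway
temporary directory::

    import os
    import tempfile

    from sklearn.utils.fixes import makedirs

    d = os.path.join(tempfile.mkdtemp(), "a", "b")
    makedirs(d)                 # creates the intermediate "a" as well
    makedirs(d, exist_ok=True)  # the second call is a silent no-op
    try:
        makedirs(d)             # exist_ok defaults to False ...
    except OSError:
        pass                    # ... so this raises, as os.makedirs would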