Commit 792e5295 authored by TomDLT

add fetch_rcv1

parent 7df3c232
@@ -264,3 +264,5 @@ features::
.. include:: labeled_faces.rst
.. include:: covtype.rst
.. include:: rcv1.rst
.. _rcv1:
RCV1 dataset
============
Reuters Corpus Volume I (RCV1) is an archive of over 800,000 manually categorized newswire stories made available by Reuters, Ltd. for research purposes. The dataset is extensively described in [1]_.
:func:`sklearn.datasets.fetch_rcv1` will load the following version: RCV1-v2, vectors, full sets, topics multilabels::
>>> from sklearn.datasets import fetch_rcv1
>>> rcv1 = fetch_rcv1()
It returns a dictionary-like object, with the following attributes:
``data``:
The feature matrix is a scipy CSR sparse matrix, with 804414 samples and
47236 features. Non-zero values contain cosine-normalized, log TF-IDF vectors.
A nearly chronological split is proposed in [1]_: the first 23149 samples are the training set and the last 781265 samples are the testing set.
The array has 0.16% non-zero values::
>>> rcv1.data.shape
(804414, 47236)
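A minimal sketch (illustrative slicing, not an attribute of the returned object) of how the chronological benchmark split described above can be recovered from the feature matrix::
>>> X_train, X_test = rcv1.data[:23149], rcv1.data[23149:]
>>> X_train.shape, X_test.shape
((23149, 47236), (781265, 47236))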
``target``:
The target values are stored in a scipy CSR sparse matrix, with 804414 samples and 103 categories. Each sample has a value of 1 in its categories, and 0 in others. The array has 3.15% non-zero values::
>>> rcv1.target.shape
(804414, 103)
``sample_id``:
Each sample can be identified by its ID, ranging (with gaps) from 2286 to 810596::
>>> rcv1.sample_id[:3]
array([2286, 2287, 2288], dtype=int32)
``target_names``:
The target values are the categories of each sample. Each sample belongs to at least one category, and up to 17 categories.
There are 103 categories, each represented by a string. Their corpus frequencies span five orders of magnitude, from 5 occurrences for 'GMIL' to 381327 for 'CCAT'::
>>> rcv1.target_names[:3].tolist()
['E11', 'ECAT', 'M11']
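As an illustration (the variable name below is arbitrary, not part of the loader's API), the category names of a single sample can be recovered by indexing ``target_names`` with the non-zero columns of the corresponding ``target`` row::
>>> first_sample_categories = rcv1.target_names[rcv1.target[0].indices]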
The dataset will be downloaded from the `dataset's homepage`_ if necessary.
The compressed size is about 656 MB.
.. _dataset's homepage: http://jmlr.csail.mit.edu/papers/volume5/lewis04a/
.. topic:: References
.. [1] Lewis, D. D., Yang, Y., Rose, T. G., & Li, F. (2004). RCV1: A new benchmark collection for text categorization research. The Journal of Machine Learning Research, 5, 361-397.
"""Fixture module to skip the datasets loading when offline
The RCV1 data is rather large and some CI workers such as Travis are
stateless, hence they will not cache the dataset as regular sklearn users would.
The following will skip the execution of the rcv1.rst doctests
if the proper environment variable is configured (see the source code of
check_skip_network for more details).
"""
from sklearn.utils.testing import check_skip_network
def setup_module():
    check_skip_network()
@@ -91,6 +91,9 @@ Enhancements
- ``dump_svmlight_file`` now handles multi-label datasets.
By Chih-Wei Chang.
- RCV1 dataset loader (:func:`sklearn.datasets.fetch_rcv1`).
By `Tom Dupre la Tour`_.
Bug fixes
.........
@@ -49,6 +49,8 @@ from .svmlight_format import dump_svmlight_file
from .olivetti_faces import fetch_olivetti_faces
from .species_distributions import fetch_species_distributions
from .california_housing import fetch_california_housing
from .rcv1 import fetch_rcv1
__all__ = ['clear_data_home',
           'dump_svmlight_file',
@@ -61,6 +63,7 @@ __all__ = ['clear_data_home',
           'fetch_species_distributions',
           'fetch_california_housing',
           'fetch_covtype',
           'fetch_rcv1',
           'get_data_home',
           'load_boston',
           'load_diabetes',
@@ -15,11 +15,9 @@ Courtesy of Jock A. Blackard and Colorado State University.
# License: BSD 3 clause
import sys
import errno
from gzip import GzipFile
from io import BytesIO
import logging
import os
from os.path import exists, join
try:
    from urllib2 import urlopen
@@ -30,6 +28,7 @@ import numpy as np
from .base import get_data_home
from .base import Bunch
from ..utils.fixes import makedirs
from ..externals import joblib
from ..utils import check_random_state
@@ -98,7 +97,7 @@ def fetch_covtype(data_home=None, download_if_missing=True,
    available = exists(samples_path)
    if download_if_missing and not available:
        _mkdirp(covtype_dir)
        makedirs(covtype_dir, exist_ok=True)
        logger.warning("Downloading %s" % URL)
        f = BytesIO(urlopen(URL).read())
        Xy = np.genfromtxt(GzipFile(fileobj=f), delimiter=',')
@@ -123,14 +122,3 @@ def fetch_covtype(data_home=None, download_if_missing=True,
        y = y[ind]
    return Bunch(data=X, target=y, DESCR=__doc__)
def _mkdirp(d):
    """Ensure directory d exists (like mkdir -p on Unix)

    No guarantee that the directory is writable.
    """
    try:
        os.makedirs(d)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise
"""RCV1 dataset.
"""
# Author: Tom Dupre la Tour
# License: BSD 3 clause
import logging
from os.path import exists, join
from gzip import GzipFile
from io import BytesIO
from contextlib import closing
try:
    from urllib2 import urlopen
except ImportError:
    from urllib.request import urlopen
import numpy as np
import scipy.sparse as sp
from .base import get_data_home
from .base import Bunch
from ..utils.fixes import makedirs
from ..externals import joblib
from .svmlight_format import load_svmlight_files
from ..utils import shuffle as shuffle_
URL = ('http://jmlr.csail.mit.edu/papers/volume5/lewis04a/'
       'a13-vector-files/lyrl2004_vectors')
URL_topics = ('http://jmlr.csail.mit.edu/papers/volume5/lewis04a/'
              'a08-topic-qrels/rcv1-v2.topics.qrels.gz')
logger = logging.getLogger()
def fetch_rcv1(data_home=None, download_if_missing=True,
               random_state=None, shuffle=False):
    """Load the RCV1 multilabel dataset, downloading it if necessary.

    Version: RCV1-v2, vectors, full sets, topics multilabels.

    ==============     =====================
    Classes                              103
    Samples total                     804414
    Dimensionality                     47236
    Features           real, between 0 and 1
    ==============     =====================

    Read more in the :ref:`User Guide <datasets>`.

    Parameters
    ----------
    data_home : string, optional
        Specify another download and cache folder for the datasets. By default
        all scikit-learn data is stored in '~/scikit_learn_data' subfolders.

    download_if_missing : boolean, default=True
        If False, raise an IOError if the data is not locally available
        instead of trying to download the data from the source site.

    random_state : int, RandomState instance or None, optional (default=None)
        Random state for shuffling the dataset.
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.

    shuffle : bool, default=False
        Whether to shuffle the dataset.

    Returns
    -------
    dataset : dict-like object with the following attributes:

    dataset.data : scipy csr array, shape (804414, 47236)
        The first 23149 samples are training samples.
        The last 781265 samples are testing samples.
        The array has 0.16% non-zero values.

    dataset.target : scipy csr array, shape (804414, 103)
        Each sample has a value of 1 in its categories, and 0 in others.
        The array has 3.15% non-zero values.

    dataset.sample_id : numpy array, shape (804414,)
        Identification number of each sample, as ordered in dataset.data.

    dataset.target_names : numpy array of object, length (103)
        Names of each target (RCV1 topics), as ordered in dataset.target.

    dataset.DESCR : string
        Description of the RCV1 dataset.
    """
    N_SAMPLES = 804414
    N_FEATURES = 47236
    N_CATEGORIES = 103

    data_home = get_data_home(data_home=data_home)
    rcv1_dir = join(data_home, "RCV1")
    if download_if_missing:
        makedirs(rcv1_dir, exist_ok=True)

    samples_path = join(rcv1_dir, "samples.pkl")
    sample_id_path = join(rcv1_dir, "sample_id.pkl")
    sample_topics_path = join(rcv1_dir, "sample_topics.pkl")
    topics_path = join(rcv1_dir, "topics_names.pkl")

    # load data (X) and sample_id
    if download_if_missing and (not exists(samples_path) or
                                not exists(sample_id_path)):
        file_urls = ["%s_test_pt%d.dat.gz" % (URL, i) for i in range(4)]
        file_urls.append("%s_train.dat.gz" % URL)
        files = []
        for file_url in file_urls:
            logger.warning("Downloading %s" % file_url)
            with closing(urlopen(file_url)) as online_file:
                # buffer the full file in memory so that GzipFile can
                # decompress it correctly
                f = BytesIO(online_file.read())
            files.append(GzipFile(fileobj=f))

        Xy = load_svmlight_files(files, n_features=N_FEATURES)

        # Training data is before testing data
        X = sp.vstack([Xy[8], Xy[0], Xy[2], Xy[4], Xy[6]]).tocsr()
        sample_id = np.hstack((Xy[9], Xy[1], Xy[3], Xy[5], Xy[7]))
        sample_id = sample_id.astype(np.int32)

        joblib.dump(X, samples_path, compress=9)
        joblib.dump(sample_id, sample_id_path, compress=9)
    else:
        X = joblib.load(samples_path)
        sample_id = joblib.load(sample_id_path)
    # load target (y), categories, and sample_id_bis
    if download_if_missing and (not exists(sample_topics_path) or
                                not exists(topics_path)):
        logger.warning("Downloading %s" % URL_topics)
        with closing(urlopen(URL_topics)) as online_topics:
            f = BytesIO(online_topics.read())

        # parse the target file
        n_cat = -1
        n_doc = -1
        doc_previous = -1
        y = np.zeros((N_SAMPLES, N_CATEGORIES), dtype=np.int8)
        sample_id_bis = np.zeros(N_SAMPLES, dtype=np.int32)
        category_names = {}
        for line in GzipFile(fileobj=f, mode='rb'):
            line_components = line.decode("ascii").split(u" ")
            if len(line_components) == 3:
                cat, doc, _ = line_components
                if cat not in category_names:
                    n_cat += 1
                    category_names[cat] = n_cat

                doc = int(doc)
                if doc != doc_previous:
                    doc_previous = doc
                    n_doc += 1
                    sample_id_bis[n_doc] = doc
                y[n_doc, category_names[cat]] = 1

        # Samples in X are ordered by sample_id,
        # whereas in y, they are ordered by sample_id_bis.
        permutation = _find_permutation(sample_id_bis, sample_id)
        y = sp.csr_matrix(y[permutation, :])

        # save the category names in an array, in the same order as y columns
        categories = np.empty(N_CATEGORIES, dtype=object)
        for k in category_names.keys():
            categories[category_names[k]] = k

        joblib.dump(y, sample_topics_path, compress=9)
        joblib.dump(categories, topics_path, compress=9)
    else:
        y = joblib.load(sample_topics_path)
        categories = joblib.load(topics_path)

    if shuffle:
        X, y, sample_id = shuffle_(X, y, sample_id, random_state=random_state)

    return Bunch(data=X, target=y, sample_id=sample_id,
                 target_names=categories, DESCR=__doc__)
def _inverse_permutation(p):
    """inverse permutation p"""
    n = p.size
    s = np.zeros(n, dtype=np.int32)
    i = np.arange(n, dtype=np.int32)
    np.put(s, p, i)  # s[p] = i
    return s


def _find_permutation(a, b):
    """find the permutation from a to b"""
    t = np.argsort(a)
    u = np.argsort(b)
    u_ = _inverse_permutation(u)
    return t[u_]
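# Small self-contained illustration of the two helpers above (added here for
# exposition only, it is not part of the module): _find_permutation(a, b)
# returns the permutation p such that a[p] == b, which is how the qrels
# ordering (sample_id_bis) is realigned with the vectors ordering (sample_id).
if __name__ == '__main__':
    a = np.array([10, 30, 20])        # e.g. document ids in qrels order
    b = np.array([20, 10, 30])        # e.g. document ids in vectors order
    p = _find_permutation(a, b)
    print(p)                          # [2 0 1]
    print(np.array_equal(a[p], b))    # True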
"""Test the rcv1 loader.
Skipped if rcv1 is not already downloaded to data_home.
"""
import errno
import scipy.sparse as sp
import numpy as np
from sklearn.datasets import fetch_rcv1
from sklearn.utils.testing import assert_almost_equal
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_true
from sklearn.utils.testing import SkipTest
def test_fetch_rcv1():
    try:
        data1 = fetch_rcv1(shuffle=False, download_if_missing=False)
    except IOError as e:
        if e.errno == errno.ENOENT:
            raise SkipTest("Download RCV1 dataset to run this test.")

    X1, Y1 = data1.data, data1.target
    cat_list, s1 = data1.target_names.tolist(), data1.sample_id

    # test sparsity
    assert_true(sp.issparse(X1))
    assert_true(sp.issparse(Y1))
    assert_equal(60915113, X1.data.size)
    assert_equal(2606875, Y1.data.size)

    # test shapes
    assert_equal((804414, 47236), X1.shape)
    assert_equal((804414, 103), Y1.shape)
    assert_equal((804414,), s1.shape)
    assert_equal(103, len(cat_list))

    # test the number of samples in some categories
    some_categories = ('GMIL', 'E143', 'CCAT')
    number_non_zero_in_cat = (5, 1206, 381327)
    for num, cat in zip(number_non_zero_in_cat, some_categories):
        j = cat_list.index(cat)
        assert_equal(num, Y1[:, j].data.size)

    # test shuffling
    data2 = fetch_rcv1(shuffle=True, random_state=77,
                       download_if_missing=False)
    X2, Y2 = data2.data, data2.target
    s2 = data2.sample_id

    assert_true((s1 != s2).any())
    assert_array_equal(np.sort(s1), np.sort(s2))

    # test some precise values
    some_sample_id = (2286, 333274, 810593)
    # index of the first nonzero feature
    indices = (863, 863, 814)
    # value of the first nonzero feature
    feature = (0.04973993, 0.12272136, 0.14245221)
    for i, j, v in zip(some_sample_id, indices, feature):
        i1 = np.nonzero(s1 == i)[0][0]
        i2 = np.nonzero(s2 == i)[0][0]
        sp_1 = X1[i1].sorted_indices()
        sp_2 = X2[i2].sorted_indices()
        assert_almost_equal(sp_1[0, j], v)
        assert_almost_equal(sp_2[0, j], v)
        assert_array_equal(np.sort(Y1[i1].indices), np.sort(Y2[i2].indices))
@@ -14,6 +14,8 @@ import inspect
import warnings
import sys
import functools
import os
import errno
import numpy as np
import scipy.sparse as sp
@@ -351,3 +353,25 @@ if np_version < (1, 6, 2):
else:
    from numpy import bincount
if 'exist_ok' in inspect.getargspec(os.makedirs).args:
    makedirs = os.makedirs
else:
    def makedirs(name, mode=0o777, exist_ok=False):
        """makedirs(name [, mode=0o777][, exist_ok=False])

        Super-mkdir; create a leaf directory and all intermediate ones. Works
        like mkdir, except that any intermediate path segment (not just the
        rightmost) will be created if it does not exist. If the target
        directory already exists, raise an OSError if exist_ok is False.
        Otherwise no exception is raised. This is recursive.
        """
        try:
            os.makedirs(name, mode=mode)
        except OSError as e:
            if (not exist_ok or e.errno != errno.EEXIST
                    or not os.path.isdir(name)):
                raise
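# Illustrative check of the backport above (added for exposition, not part of
# the patch): with exist_ok=True a second call on an existing directory must
# not raise, mirroring the Python 3 behaviour of os.makedirs.
if __name__ == '__main__':
    import tempfile
    target = os.path.join(tempfile.mkdtemp(), 'a', 'b')
    makedirs(target, exist_ok=True)
    makedirs(target, exist_ok=True)   # directory already exists: no OSError
    print(os.path.isdir(target))      # True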