diff --git a/doc/datasets/rcv1.rst b/doc/datasets/rcv1.rst
index a957d9f91ff9a531576c41115852292d9adec65f..486eeee90557881e3af1e7c4293e976b9e9cda79 100644
--- a/doc/datasets/rcv1.rst
+++ b/doc/datasets/rcv1.rst
@@ -16,7 +16,7 @@ It returns a dictionary-like object, with the following attributes:
 ``data``:
 The feature matrix is a scipy CSR sparse matrix, with 804414 samples and
 47236 features. Non-zero values contains cosine-normalized, log TF-IDF vectors.
-A nearly chronological split is proposed in [1]_: The first 23149 samples are the training set. The last 781265 samples are the testing set.
+A nearly chronological split is proposed in [1]_: The first 23149 samples are the training set. The last 781265 samples are the testing set. This follows the official LYRL2004 chronological split.
 The array has 0.16% of non zero values::
 
     >>> rcv1.data.shape
@@ -34,11 +34,11 @@ Each sample can be identified by its ID, ranging (with gaps) from 2286 to 810596
     >>> rcv1.sample_id[:3]
-    array([2286, 2287, 2288], dtype=int32)
+    array([2286, 2287, 2288], dtype=uint32)
 
-``categories``:
-The target values are the categories of each sample. Each sample belongs to at least one category, and to up to 17 categories.
-There are 103 categories, each represented by a string. Their corpus frequencies span five orders of magnitude, from 5 occurrences for 'GMIL', to 381327 for 'CCAT'::
+``target_names``:
+The target values are the topics of each sample. Each sample belongs to at least one topic, and to up to 17 topics.
+There are 103 topics, each represented by a string. Their corpus frequencies span five orders of magnitude, from 5 occurrences for 'GMIL', to 381327 for 'CCAT'::
 
-    >>> rcv1.categories[:3]
+    >>> rcv1.target_names[:3].tolist()  # doctest: +SKIP
     ['E11', 'ECAT', 'M11']
 
 The dataset will be downloaded from the `dataset's homepage`_ if necessary.
diff --git a/doc/datasets/rcv1_fixture.py b/doc/datasets/rcv1_fixture.py
index 19d27120feb8dbfd5a50266731a1e17e3b13ea02..c409f2f937a648ebc4f9eac6279282b0768e3bad 100644
--- a/doc/datasets/rcv1_fixture.py
+++ b/doc/datasets/rcv1_fixture.py
@@ -8,8 +8,15 @@
 if the proper environment variable is configured (see the source code of
 check_skip_network for more details).
""" -from sklearn.utils.testing import check_skip_network +from sklearn.utils.testing import check_skip_network, SkipTest +import os +from sklearn.datasets import get_data_home def setup_module(): check_skip_network() + + # skip the test in rcv1.rst if the dataset is not already loaded + rcv1_dir = os.path.join(get_data_home(), "RCV1") + if not os.path.exists(rcv1_dir): + raise SkipTest("Download RCV1 dataset to run this test.") + diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 9895d5b35a4aebbbe0928ad660feeffa60cd0099..ff5708efeca7c030aa005ac8b925badb238a2f55 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -218,6 +218,7 @@ Loaders datasets.fetch_olivetti_faces datasets.fetch_california_housing datasets.fetch_covtype + datasets.fetch_rcv1 datasets.load_mlcomp datasets.load_sample_image datasets.load_sample_images diff --git a/sklearn/datasets/rcv1.py b/sklearn/datasets/rcv1.py index 2fcd6c32f632d7a369687970b044404ed576f896..090f03560d6ff76b08f061acf03b7a0d7b3fb9cd 100644 --- a/sklearn/datasets/rcv1.py +++ b/sklearn/datasets/rcv1.py @@ -35,7 +35,7 @@ URL_topics = ('http://jmlr.csail.mit.edu/papers/volume5/lewis04a/' logger = logging.getLogger() -def fetch_rcv1(data_home=None, download_if_missing=True, +def fetch_rcv1(data_home=None, subset='all', download_if_missing=True, random_state=None, shuffle=False): """Load the RCV1 multilabel dataset, downloading it if necessary. @@ -56,6 +56,12 @@ def fetch_rcv1(data_home=None, download_if_missing=True, Specify another download and cache folder for the datasets. By default all scikit learn data is stored in '~/scikit_learn_data' subfolders. + subset: string, 'train', 'test', or 'all', default='all' + Select the dataset to load: 'train' for the training set + (23149 samples), 'test' for the test set (781265 samples), + 'all' for both, with the training samples first if shuffle is False. + This follows the official LYRL2004 chronological split. + download_if_missing : boolean, default=True If False, raise a IOError if the data is not locally available instead of trying to download the data from the source site. @@ -74,28 +80,33 @@ def fetch_rcv1(data_home=None, download_if_missing=True, ------- dataset : dict-like object with the following attributes: - dataset.data : scipy csr array, shape (804414, 47236) - The first 23149 samples are training samples. - The last 781265 samples are testing samples. + dataset.data : scipy csr array, dtype np.float64, shape (804414, 47236) The array has 0.16% of non zero values. - dataset.target : scipy csr array, shape (804414, 103) + dataset.target : scipy csr array, dtype np.uint8, shape (804414, 103) Each sample has a value of 1 in its categories, and 0 in others. The array has 3.15% of non zero values. - dataset.sample_id : numpy array, shape (804414,) + dataset.sample_id : numpy array, dtype np.uint32, shape (804414,) Identification number of each sample, as ordered in dataset.data. - dataset.target_names : numpy array of object, length (103) + dataset.target_names : numpy array, dtype object, length (103) Names of each target (RCV1 topics), as ordered in dataset.target. dataset.DESCR : string Description of the RCV1 dataset. + Reference + --------- + Lewis, D. D., Yang, Y., Rose, T. G., & Li, F. (2004). RCV1: A new + benchmark collection for text categorization research. The Journal of + Machine Learning Research, 5, 361-397. 
+ """ N_SAMPLES = 804414 N_FEATURES = 47236 N_CATEGORIES = 103 + N_TRAIN = 23149 data_home = get_data_home(data_home=data_home) rcv1_dir = join(data_home, "RCV1") @@ -126,7 +137,7 @@ def fetch_rcv1(data_home=None, download_if_missing=True, # Training data is before testing data X = sp.vstack([Xy[8], Xy[0], Xy[2], Xy[4], Xy[6]]).tocsr() sample_id = np.hstack((Xy[9], Xy[1], Xy[3], Xy[5], Xy[7])) - sample_id = sample_id.astype(np.int32) + sample_id = sample_id.astype(np.uint32) joblib.dump(X, samples_path, compress=9) joblib.dump(sample_id, sample_id_path, compress=9) @@ -146,7 +157,7 @@ def fetch_rcv1(data_home=None, download_if_missing=True, n_cat = -1 n_doc = -1 doc_previous = -1 - y = np.zeros((N_SAMPLES, N_CATEGORIES), dtype=np.int8) + y = np.zeros((N_SAMPLES, N_CATEGORIES), dtype=np.uint8) sample_id_bis = np.zeros(N_SAMPLES, dtype=np.int32) category_names = {} for line in GzipFile(fileobj=f, mode='rb'): @@ -181,6 +192,20 @@ def fetch_rcv1(data_home=None, download_if_missing=True, y = joblib.load(sample_topics_path) categories = joblib.load(topics_path) + if subset == 'all': + pass + elif subset == 'train': + X = X[:N_TRAIN, :] + y = y[:N_TRAIN, :] + sample_id = sample_id[:N_TRAIN] + elif subset == 'test': + X = X[N_TRAIN:, :] + y = y[N_TRAIN:, :] + sample_id = sample_id[N_TRAIN:] + else: + raise ValueError("Unknown subset parameter. Got '%s' instead of one" + " of ('all', 'train', test')" % subset) + if shuffle: X, y, sample_id = shuffle_(X, y, sample_id, random_state=random_state) diff --git a/sklearn/datasets/tests/test_rcv1.py b/sklearn/datasets/tests/test_rcv1.py index 471e02afc99fac9a3792f8fda8a452449bd704c9..e9833f215e9b04286a6c2be2d860b78a1b7770d7 100644 --- a/sklearn/datasets/tests/test_rcv1.py +++ b/sklearn/datasets/tests/test_rcv1.py @@ -43,29 +43,25 @@ def test_fetch_rcv1(): j = cat_list.index(cat) assert_equal(num, Y1[:, j].data.size) - # test shuffling - data2 = fetch_rcv1(shuffle=True, random_state=77, + # test shuffling and subset + data2 = fetch_rcv1(shuffle=True, subset='train', random_state=77, download_if_missing=False) X2, Y2 = data2.data, data2.target s2 = data2.sample_id - assert_true((s1 != s2).any()) - assert_array_equal(np.sort(s1), np.sort(s2)) + # The first 23149 samples are the training samples + assert_array_equal(np.sort(s1[:23149]), np.sort(s2)) # test some precise values - some_sample_id = (2286, 333274, 810593) - # indice of first nonzero feature - indices = (863, 863, 814) - # value of first nonzero feature - feature = (0.04973993, 0.12272136, 0.14245221) - for i, j, v in zip(some_sample_id, indices, feature): - i1 = np.nonzero(s1 == i)[0][0] - i2 = np.nonzero(s2 == i)[0][0] + some_sample_ids = (2286, 3274, 14042) + for sample_id in some_sample_ids: + idx1 = s1.tolist().index(sample_id) + idx2 = s2.tolist().index(sample_id) - sp_1 = X1[i1].sorted_indices() - sp_2 = X2[i2].sorted_indices() + feature_values_1 = X1[idx1, :].toarray() + feature_values_2 = X2[idx2, :].toarray() + assert_almost_equal(feature_values_1, feature_values_2) - assert_almost_equal(sp_1[0, j], v) - assert_almost_equal(sp_2[0, j], v) - - assert_array_equal(np.sort(Y1[i1].indices), np.sort(Y2[i2].indices)) + target_values_1 = Y1[idx1, :].toarray() + target_values_2 = Y2[idx2, :].toarray() + assert_almost_equal(target_values_1, target_values_2)