diff --git a/doc/datasets/index.rst b/doc/datasets/index.rst
index cbc0cedbc02a5b1c0909cc406269c11922d680b1..f0722a928bdebe98e513d87dfb081afddc8ba62f 100644
--- a/doc/datasets/index.rst
+++ b/doc/datasets/index.rst
@@ -264,3 +264,5 @@ features::
 .. include:: labeled_faces.rst
 
 .. include:: covtype.rst
+
+.. include:: rcv1.rst
diff --git a/doc/datasets/rcv1.rst b/doc/datasets/rcv1.rst
new file mode 100644
index 0000000000000000000000000000000000000000..a957d9f91ff9a531576c41115852292d9adec65f
--- /dev/null
+++ b/doc/datasets/rcv1.rst
@@ -0,0 +1,52 @@
+
+.. _rcv1:
+
+RCV1 dataset
+============
+
+Reuters Corpus Volume I (RCV1) is an archive of over 800,000 manually categorized newswire stories made available by Reuters, Ltd. for research purposes. The dataset is extensively described in [1]_.
+
+:func:`sklearn.datasets.fetch_rcv1` will load the following version: RCV1-v2, vectors, full sets, topics multilabels::
+
+    >>> from sklearn.datasets import fetch_rcv1
+    >>> rcv1 = fetch_rcv1()
+
+It returns a dictionary-like object, with the following attributes:
+
+``data``:
+The feature matrix is a scipy CSR sparse matrix, with 804414 samples and
+47236 features. Non-zero values contains cosine-normalized, log TF-IDF vectors.
+A nearly chronological split is proposed in [1]_: The first 23149 samples are the training set. The last 781265 samples are the testing set. 
+The array has 0.16% of non zero values::
+
+    >>> rcv1.data.shape
+    (804414, 47236)
+
+``target``:
+The target values are stored in a scipy CSR sparse matrix, with 804414 samples and 103 categories. Each sample has a value of 1 in its categories, and 0 in others. The array has 3.15% of non zero values::
+
+    >>> rcv1.target.shape
+    (804414, 103)
+
+``sample_id``:
+Each sample can be identified by its ID, ranging (with gaps) from 2286 to 810596::
+
+    >>> rcv1.sample_id[:3]
+    array([2286, 2287, 2288], dtype=int32)
+
+``categories``:
+The target values are the categories of each sample. Each sample belongs to at least one category, and to up to 17 categories. 
+There are 103 categories, each represented by a string. Their corpus frequencies span five orders of magnitude, from 5 occurrences for 'GMIL', to 381327 for 'CCAT'::
+
+    >>> rcv1.categories[:3]
+    ['E11', 'ECAT', 'M11']
+
+The dataset will be downloaded from the `dataset's homepage`_ if necessary.
+The compressed size is about 656 MB.
+
+.. _dataset's homepage: http://jmlr.csail.mit.edu/papers/volume5/lewis04a/
+
+
+.. topic:: References
+
+    .. [1] Lewis, D. D., Yang, Y., Rose, T. G., & Li, F. (2004). RCV1: A new benchmark collection for text categorization research. The Journal of Machine Learning Research, 5, 361-397.
diff --git a/doc/datasets/rcv1_fixture.py b/doc/datasets/rcv1_fixture.py
new file mode 100644
index 0000000000000000000000000000000000000000..19d27120feb8dbfd5a50266731a1e17e3b13ea02
--- /dev/null
+++ b/doc/datasets/rcv1_fixture.py
@@ -0,0 +1,15 @@
+"""Fixture module to skip the datasets loading when offline
+
+The RCV1 data is rather large and some CI workers such as travis are
+stateless hence will not cache the dataset as regular sklearn users would do.
+
+The following will skip the execution of the rcv1.rst doctests
+if the proper environment variable is configured (see the source code of
+check_skip_network for more details).
+
+"""
+from sklearn.utils.testing import check_skip_network
+
+
+def setup_module():
+    check_skip_network()
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 3df46348788237170896e904f95170c4f3aeb8fc..e88532ce7b6d46676c6647ddc00a44ff30d78fd3 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -91,6 +91,9 @@ Enhancements
    - ``dump_svmlight_file`` now handles multi-label datasets.
      By Chih-Wei Chang.
 
+   - RCV1 dataset loader (:func:`sklearn.datasets.fetch_rcv1`).
+     By `Tom Dupre la Tour`_.
+
 Bug fixes
 .........
 
diff --git a/sklearn/datasets/__init__.py b/sklearn/datasets/__init__.py
index 4b041675a2a851bdc6e54291da5914bf9eec3a29..58f69cac7fc3f3df3a3f49b2b5bef34406f5f25b 100644
--- a/sklearn/datasets/__init__.py
+++ b/sklearn/datasets/__init__.py
@@ -49,6 +49,8 @@ from .svmlight_format import dump_svmlight_file
 from .olivetti_faces import fetch_olivetti_faces
 from .species_distributions import fetch_species_distributions
 from .california_housing import fetch_california_housing
+from .rcv1 import fetch_rcv1
+
 
 __all__ = ['clear_data_home',
            'dump_svmlight_file',
@@ -61,6 +63,7 @@ __all__ = ['clear_data_home',
            'fetch_species_distributions',
            'fetch_california_housing',
            'fetch_covtype',
+           'fetch_rcv1',
            'get_data_home',
            'load_boston',
            'load_diabetes',
diff --git a/sklearn/datasets/covtype.py b/sklearn/datasets/covtype.py
index cc4d111c60361a60142840cd4c4b42e09db3e19b..13217652864c28795413c8a639c8211976b1a888 100644
--- a/sklearn/datasets/covtype.py
+++ b/sklearn/datasets/covtype.py
@@ -15,11 +15,9 @@ Courtesy of Jock A. Blackard and Colorado State University.
 # License: BSD 3 clause
 
 import sys
-import errno
 from gzip import GzipFile
 from io import BytesIO
 import logging
-import os
 from os.path import exists, join
 try:
     from urllib2 import urlopen
@@ -30,6 +28,7 @@ import numpy as np
 
 from .base import get_data_home
 from .base import Bunch
+from ..utils.fixes import makedirs
 from ..externals import joblib
 from ..utils import check_random_state
 
@@ -98,7 +97,7 @@ def fetch_covtype(data_home=None, download_if_missing=True,
     available = exists(samples_path)
 
     if download_if_missing and not available:
-        _mkdirp(covtype_dir)
+        makedirs(covtype_dir, exist_ok=True)
         logger.warning("Downloading %s" % URL)
         f = BytesIO(urlopen(URL).read())
         Xy = np.genfromtxt(GzipFile(fileobj=f), delimiter=',')
@@ -123,14 +122,3 @@ def fetch_covtype(data_home=None, download_if_missing=True,
         y = y[ind]
 
     return Bunch(data=X, target=y, DESCR=__doc__)
-
-
-def _mkdirp(d):
-    """Ensure directory d exists (like mkdir -p on Unix)
-    No guarantee that the directory is writable.
-    """
-    try:
-        os.makedirs(d)
-    except OSError as e:
-        if e.errno != errno.EEXIST:
-            raise
diff --git a/sklearn/datasets/rcv1.py b/sklearn/datasets/rcv1.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fcd6c32f632d7a369687970b044404ed576f896
--- /dev/null
+++ b/sklearn/datasets/rcv1.py
@@ -0,0 +1,205 @@
+"""RCV1 dataset.
+"""
+
+# Author: Tom Dupre la Tour
+# License: BSD 3 clause
+
+import logging
+
+from os.path import exists, join
+from gzip import GzipFile
+from io import BytesIO
+from contextlib import closing
+
+try:
+    from urllib2 import urlopen
+except ImportError:
+    from urllib.request import urlopen
+
+import numpy as np
+import scipy.sparse as sp
+
+from .base import get_data_home
+from .base import Bunch
+from ..utils.fixes import makedirs
+from ..externals import joblib
+from .svmlight_format import load_svmlight_files
+from ..utils import shuffle as shuffle_
+
+
+URL = ('http://jmlr.csail.mit.edu/papers/volume5/lewis04a/'
+       'a13-vector-files/lyrl2004_vectors')
+URL_topics = ('http://jmlr.csail.mit.edu/papers/volume5/lewis04a/'
+              'a08-topic-qrels/rcv1-v2.topics.qrels.gz')
+
+logger = logging.getLogger()
+
+
+def fetch_rcv1(data_home=None, download_if_missing=True,
+               random_state=None, shuffle=False):
+    """Load the RCV1 multilabel dataset, downloading it if necessary.
+
+    Version: RCV1-v2, vectors, full sets, topics multilabels.
+
+    ==============     =====================
+    Classes                              103
+    Samples total                     804414
+    Dimensionality                     47236
+    Features           real, between 0 and 1
+    ==============     =====================
+
+    Read more in the :ref:`User Guide <datasets>`.
+
+    Parameters
+    ----------
+    data_home : string, optional
+        Specify another download and cache folder for the datasets. By default
+        all scikit learn data is stored in '~/scikit_learn_data' subfolders.
+
+    download_if_missing : boolean, default=True
+        If False, raise a IOError if the data is not locally available
+        instead of trying to download the data from the source site.
+
+    random_state : int, RandomState instance or None, optional (default=None)
+        Random state for shuffling the dataset.
+        If int, random_state is the seed used by the random number generator;
+        If RandomState instance, random_state is the random number generator;
+        If None, the random number generator is the RandomState instance used
+        by `np.random`.
+
+    shuffle : bool, default=False
+        Whether to shuffle dataset.
+
+    Returns
+    -------
+    dataset : dict-like object with the following attributes:
+
+    dataset.data : scipy csr array, shape (804414, 47236)
+        The first 23149 samples are training samples.
+        The last 781265 samples are testing samples.
+        The array has 0.16% of non zero values.
+
+    dataset.target : scipy csr array, shape (804414, 103)
+        Each sample has a value of 1 in its categories, and 0 in others.
+        The array has 3.15% of non zero values.
+
+    dataset.sample_id : numpy array, shape (804414,)
+        Identification number of each sample, as ordered in dataset.data.
+
+    dataset.target_names : numpy array of object, length (103)
+        Names of each target (RCV1 topics), as ordered in dataset.target.
+
+    dataset.DESCR : string
+        Description of the RCV1 dataset.
+
+    """
+    N_SAMPLES = 804414
+    N_FEATURES = 47236
+    N_CATEGORIES = 103
+
+    data_home = get_data_home(data_home=data_home)
+    rcv1_dir = join(data_home, "RCV1")
+    if download_if_missing:
+        makedirs(rcv1_dir, exist_ok=True)
+
+    samples_path = join(rcv1_dir, "samples.pkl")
+    sample_id_path = join(rcv1_dir, "sample_id.pkl")
+    sample_topics_path = join(rcv1_dir, "sample_topics.pkl")
+    topics_path = join(rcv1_dir, "topics_names.pkl")
+
+    # load data (X) and sample_id
+    if download_if_missing and (not exists(samples_path) or
+                                not exists(sample_id_path)):
+        file_urls = ["%s_test_pt%d.dat.gz" % (URL, i) for i in range(4)]
+        file_urls.append("%s_train.dat.gz" % URL)
+        files = []
+        for file_url in file_urls:
+            logger.warning("Downloading %s" % file_url)
+            with closing(urlopen(file_url)) as online_file:
+                # buffer the full file in memory to make possible to Gzip to
+                # work correctly
+                f = BytesIO(online_file.read())
+            files.append(GzipFile(fileobj=f))
+
+        Xy = load_svmlight_files(files, n_features=N_FEATURES)
+
+        # Training data is before testing data
+        X = sp.vstack([Xy[8], Xy[0], Xy[2], Xy[4], Xy[6]]).tocsr()
+        sample_id = np.hstack((Xy[9], Xy[1], Xy[3], Xy[5], Xy[7]))
+        sample_id = sample_id.astype(np.int32)
+
+        joblib.dump(X, samples_path, compress=9)
+        joblib.dump(sample_id, sample_id_path, compress=9)
+
+    else:
+        X = joblib.load(samples_path)
+        sample_id = joblib.load(sample_id_path)
+
+    # load target (y), categories, and sample_id_bis
+    if download_if_missing and (not exists(sample_topics_path) or
+                                not exists(topics_path)):
+        logger.warning("Downloading %s" % URL_topics)
+        with closing(urlopen(URL_topics)) as online_topics:
+            f = BytesIO(online_topics.read())
+
+        # parse the target file
+        n_cat = -1
+        n_doc = -1
+        doc_previous = -1
+        y = np.zeros((N_SAMPLES, N_CATEGORIES), dtype=np.int8)
+        sample_id_bis = np.zeros(N_SAMPLES, dtype=np.int32)
+        category_names = {}
+        for line in GzipFile(fileobj=f, mode='rb'):
+            line_components = line.decode("ascii").split(u" ")
+            if len(line_components) == 3:
+                cat, doc, _ = line_components
+                if cat not in category_names:
+                    n_cat += 1
+                    category_names[cat] = n_cat
+
+                doc = int(doc)
+                if doc != doc_previous:
+                    doc_previous = doc
+                    n_doc += 1
+                    sample_id_bis[n_doc] = doc
+                y[n_doc, category_names[cat]] = 1
+
+        # Samples in X are ordered with sample_id,
+        # whereas in y, they are ordered with sample_id_bis.
+        permutation = _find_permutation(sample_id_bis, sample_id)
+        y = sp.csr_matrix(y[permutation, :])
+
+        # save category names in a list, with same order than y
+        categories = np.empty(N_CATEGORIES, dtype=object)
+        for k in category_names.keys():
+            categories[category_names[k]] = k
+
+        joblib.dump(y, sample_topics_path, compress=9)
+        joblib.dump(categories, topics_path, compress=9)
+
+    else:
+        y = joblib.load(sample_topics_path)
+        categories = joblib.load(topics_path)
+
+    if shuffle:
+        X, y, sample_id = shuffle_(X, y, sample_id, random_state=random_state)
+
+    return Bunch(data=X, target=y, sample_id=sample_id,
+                 target_names=categories, DESCR=__doc__)
+
+
+def _inverse_permutation(p):
+    """inverse permutation p"""
+    n = p.size
+    s = np.zeros(n, dtype=np.int32)
+    i = np.arange(n, dtype=np.int32)
+    np.put(s, p, i)  # s[p] = i
+    return s
+
+
+def _find_permutation(a, b):
+    """find the permutation from a to b"""
+    t = np.argsort(a)
+    u = np.argsort(b)
+    u_ = _inverse_permutation(u)
+    return t[u_]
diff --git a/sklearn/datasets/tests/test_rcv1.py b/sklearn/datasets/tests/test_rcv1.py
new file mode 100644
index 0000000000000000000000000000000000000000..471e02afc99fac9a3792f8fda8a452449bd704c9
--- /dev/null
+++ b/sklearn/datasets/tests/test_rcv1.py
@@ -0,0 +1,71 @@
+"""Test the rcv1 loader.
+
+Skipped if rcv1 is not already downloaded to data_home.
+"""
+
+import errno
+import scipy.sparse as sp
+import numpy as np
+from sklearn.datasets import fetch_rcv1
+from sklearn.utils.testing import assert_almost_equal
+from sklearn.utils.testing import assert_array_equal
+from sklearn.utils.testing import assert_equal
+from sklearn.utils.testing import assert_true
+from sklearn.utils.testing import SkipTest
+
+
+def test_fetch_rcv1():
+    try:
+        data1 = fetch_rcv1(shuffle=False, download_if_missing=False)
+    except IOError as e:
+        if e.errno == errno.ENOENT:
+            raise SkipTest("Download RCV1 dataset to run this test.")
+
+    X1, Y1 = data1.data, data1.target
+    cat_list, s1 = data1.target_names.tolist(), data1.sample_id
+
+    # test sparsity
+    assert_true(sp.issparse(X1))
+    assert_true(sp.issparse(Y1))
+    assert_equal(60915113, X1.data.size)
+    assert_equal(2606875, Y1.data.size)
+
+    # test shapes
+    assert_equal((804414, 47236), X1.shape)
+    assert_equal((804414, 103), Y1.shape)
+    assert_equal((804414,), s1.shape)
+    assert_equal(103, len(cat_list))
+
+    # test number of sample for some categories
+    some_categories = ('GMIL', 'E143', 'CCAT')
+    number_non_zero_in_cat = (5, 1206, 381327)
+    for num, cat in zip(number_non_zero_in_cat, some_categories):
+        j = cat_list.index(cat)
+        assert_equal(num, Y1[:, j].data.size)
+
+    # test shuffling
+    data2 = fetch_rcv1(shuffle=True, random_state=77,
+                       download_if_missing=False)
+    X2, Y2 = data2.data, data2.target
+    s2 = data2.sample_id
+
+    assert_true((s1 != s2).any())
+    assert_array_equal(np.sort(s1), np.sort(s2))
+
+    # test some precise values
+    some_sample_id = (2286, 333274, 810593)
+    # indice of first nonzero feature
+    indices = (863, 863, 814)
+    # value of first nonzero feature
+    feature = (0.04973993, 0.12272136, 0.14245221)
+    for i, j, v in zip(some_sample_id, indices, feature):
+        i1 = np.nonzero(s1 == i)[0][0]
+        i2 = np.nonzero(s2 == i)[0][0]
+
+        sp_1 = X1[i1].sorted_indices()
+        sp_2 = X2[i2].sorted_indices()
+
+        assert_almost_equal(sp_1[0, j], v)
+        assert_almost_equal(sp_2[0, j], v)
+
+        assert_array_equal(np.sort(Y1[i1].indices), np.sort(Y2[i2].indices))
diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py
index b66a3b9ff2a759b7a5806acfc7f0c36e5db54c87..73a0331b37b00f5e782fd3a5c1687d099226cc97 100644
--- a/sklearn/utils/fixes.py
+++ b/sklearn/utils/fixes.py
@@ -14,6 +14,8 @@ import inspect
 import warnings
 import sys
 import functools
+import os
+import errno
 
 import numpy as np
 import scipy.sparse as sp
@@ -351,3 +353,25 @@ if np_version < (1, 6, 2):
 
 else:
     from numpy import bincount
+
+
+if 'exist_ok' in inspect.getargspec(os.makedirs).args:
+    makedirs = os.makedirs
+else:
+    def makedirs(name, mode=0o777, exist_ok=False):
+        """makedirs(name [, mode=0o777][, exist_ok=False])
+
+        Super-mkdir; create a leaf directory and all intermediate ones.  Works
+        like mkdir, except that any intermediate path segment (not just the
+        rightmost) will be created if it does not exist. If the target
+        directory already exists, raise an OSError if exist_ok is False.
+        Otherwise no exception is raised.  This is recursive.
+
+        """
+
+        try:
+            os.makedirs(name, mode=mode)
+        except OSError as e:
+            if (not exist_ok or e.errno != errno.EEXIST
+                    or not os.path.isdir(name)):
+                raise