diff --git a/doc/tutorial.rst b/doc/tutorial.rst
index d15496c79bc843f65a9300a909ca816372704c75..565069b16bd5ddf27df82916657c06ed896a2cce 100644
--- a/doc/tutorial.rst
+++ b/doc/tutorial.rst
@@ -16,11 +16,11 @@ single number, and for instance a multi-dimensional entry (aka
 *multivariate* data), is it said to have several attributes, or
 *features*.
 
-We can separate learning problems in a few large categories: 
+We can separate learning problems in a few large categories:
 
  * **supervised learning**, in which the data comes with additional
    attributes that we want to predict. This problem can be either:
-   
+
     * **classification**: samples belong to two or more classes and we
       want to learn from already labeled data how to predict the class
       of un-labeled data. An example of classification problem would
@@ -45,7 +45,7 @@ We can separate learning problems in a few large categories:
 .. topic:: Training set and testing set
 
     Machine learning is about learning some properties of a data set and
-    applying them to new data. This is why a common practice in machine 
+    applying them to new data. This is why a common practice in machine
     learning to evaluate an algorithm is to split the data at hand in two
     sets, one that we call a *training set* on which we learn data
     properties, and one that we call a *testing set*, on which we test
@@ -76,7 +76,7 @@ access to the features that can be used to classify the digits samples::
     array([[  0.,   0.,   5., ...,   0.,   0.,   0.],
            [  0.,   0.,   0., ...,  10.,   0.,   0.],
            [  0.,   0.,   0., ...,  16.,   9.,   0.],
-           ..., 
+           ...,
            [  0.,   0.,   1., ...,   6.,   0.,   0.],
            [  0.,   0.,   2., ...,  12.,   0.,   0.],
            [  0.,   0.,  10., ...,  12.,   1.,   0.]])
@@ -89,7 +89,7 @@ learn:
 array([0, 1, 2, ..., 8, 9, 8])
 
 .. topic:: Shape of the data arrays
-   
+
     The data is always a 2D array, `n_samples, n_features`, although
     the original data may have had a different shape. In the case of the
     digits, each original sample is an image of shape `8, 8` and can be
@@ -106,10 +106,19 @@ array([0, 1, 2, ..., 8, 9, 8])
            [  0.,   0.,   6.,  13.,  10.,   0.,   0.,   0.]])
 
     The :ref:`simple example on this dataset <example_plot_digits_classification.py>`
-    illustrates how starting from the original problem one can shape the 
+    illustrates how starting from the original problem one can shape the
     data for consumption in the `scikit.learn`.
 
 
+``scikits.learn`` also makes it possible to reuse external datasets from the
+http://mlcomp.org online service, which provides a repository of public
+datasets for various tasks (binary & multi-label classification, regression,
+document classification, ...) along with a runtime environment to compare
+program performance on those datasets. Please refer to the following example
+for instructions on the ``mlcomp`` dataset loader:
+:ref:`example_mlcomp_document_classification.py`.
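+
+For instance, assuming the 20 newsgroups dataset available from
+http://mlcomp.org/datasets/379 has been downloaded and installed as described
+in that example, loading its training split could look like::
+
+    from scikits.learn.datasets import load_mlcomp
+
+    news_train = load_mlcomp('20news-18828', 'train')
+    n_samples, n_features = news_train.data.shape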
+
+
 Learning and Predicting
 ------------------------
 
@@ -158,4 +167,4 @@ resolution. Do you agree with the classifier?
 
 A complete example of this classification problem is available as an
 example that you can run and study:
-:ref:`example_plot_digits_classification.py`. 
+:ref:`example_plot_digits_classification.py`.
diff --git a/examples/mlcomp_document_classification.py b/examples/mlcomp_document_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..2782a393b89e584b8f07b00a690ad36db36a2e75
--- /dev/null
+++ b/examples/mlcomp_document_classification.py
@@ -0,0 +1,99 @@
+"""
+================================
+Classification of text documents
+================================
+
+This is an example showing how scikit-learn can be used to classify documents
+by topic using a bag-of-words approach.
+
+The dataset used in this example is the 20 newsgroups dataset and should be
+downloaded from http://mlcomp.org (free registration required):
+
+  http://mlcomp.org/datasets/379
+
+Once downloaded, unzip the archive somewhere on your filesystem, for instance::
+
+  % mkdir -p ~/data/mlcomp
+  % cd  ~/data/mlcomp
+  % unzip /path/to/dataset-379-20news-18828_XXXXX.zip
+
+You should get a folder ``~/data/mlcomp/379`` with a file named ``metadata`` and
+subfolders ``raw``, ``train`` and ``test`` holding the text documents organized by
+newsgroups.
+
+Then set the ``MLCOMP_DATASETS_HOME`` environment variable pointing to
+the root folder holding the uncompressed archive::
+
+  % export MLCOMP_DATASETS_HOME="~/data/mlcomp"
+
+Then you are ready to run this example using your favorite python shell::
+
+  % ipython examples/mlcomp_document_classification.py
+
+"""
+# Author: Olivier Grisel <olivier.grisel@ensta.org>
+# License: Simplified BSD
+
+from time import time
+import sys
+import os
+import numpy as np
+import pylab as pl
+from scikits.learn.datasets import load_mlcomp
+from scikits.learn.logistic import LogisticRegression
+from scikits.learn.svm import LinearSVC
+from scikits.learn.metrics import confusion_matrix
+
+if 'MLCOMP_DATASETS_HOME' not in os.environ:
+    print "Please follow those instructions to get started:"
+    print __doc__
+    sys.exit(0)
+
+# Load the training set
+print "Loading 20 newsgroups training set... "
+t0 = time()
+news_train = load_mlcomp('20news-18828', 'train')
+print "done in %fs" % (time() - t0)
+
+# The documents have been hashed into TF-IDF (Term Frequencies times Inverse
+# Document Frequencies) vectors of a fixed dimension.
+# Currently most scikits.learn wrappers or algorithm implementations cannot
+# efficiently leverage sparse data structures; hence we use this dense
+# representation of the text dataset. Efficient handling of sparse data
+# structures is expected in an upcoming version of scikits.learn.
+print "n_samples: %d, n_features: %d" % news_train.data.shape
+
+print "Training a linear classification model with L1 penalty... "
+parameters = {
+    'loss': 'l2',
+    'penalty': 'l1',
+    'C': 10,
+    'dual': False,
+    'eps': 1e-4,
+}
+print "parameters:", parameters
+t0 = time()
+clf = LinearSVC(**parameters).fit(news_train.data, news_train.target)
+print "done in %fs" % (time() - t0)
+print "Percentage of non zeros coef: %f" % (np.mean(clf.coef_ != 0) * 100)
+
+print "Loading 20 newsgroups test set... "
+t0 = time()
+news_test = load_mlcomp('20news-18828', 'test')
+print "done in %fs" % (time() - t0)
+
+print "Predicting the labels of the test set..."
+t0 = time()
+pred = clf.predict(news_test.data)
+print "done in %fs" % (time() - t0)
+print "Classification accuracy: %f" % (np.mean(pred == news_test.target) * 100)
+
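+# the confusion matrix is expected to have one row per true class and one
+# column per predicted class: cm[i, j] counts the test documents of class i
+# that were predicted as class j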
+cm = confusion_matrix(news_test.target, pred)
+print "Confusion matrix:"
+print cm
+
+# Show confusion matrix
+pl.matshow(cm)
+pl.title('Confusion matrix')
+pl.colorbar()
+pl.show()
diff --git a/scikits/learn/datasets/mlcomp.py b/scikits/learn/datasets/mlcomp.py
index 42a0deb8ebd3d2e77b5c3c4dfe8edfb74d800d44..f6678dbe32c168ffdf4a5a52db49acf622b11fb1 100644
--- a/scikits/learn/datasets/mlcomp.py
+++ b/scikits/learn/datasets/mlcomp.py
@@ -3,6 +3,7 @@
 """Glue code to load http://mlcomp.org data as a scikit.learn dataset"""
 
 import os
+import numpy as np
 from scikits.learn.datasets.base import Bunch
 from scikits.learn.features.text import HashingVectorizer
 
@@ -14,6 +15,12 @@ def load_document_classification(dataset_path, metadata, set_, **kw):
     filenames = []
     vectorizer = kw.get('vectorizer', HashingVectorizer())
 
+    # TODO: make it possible to plug in a multi-pass system to filter out
+    # tokens that occur in, for instance, more than 30% of the documents.
+
+    # TODO: use joblib.Parallel or multiprocessing to parallelize the following
+    # (provided this is not IO bound)
+
     dataset_path = os.path.join(dataset_path, set_)
     folders = [f for f in sorted(os.listdir(dataset_path))
                if os.path.isdir(os.path.join(dataset_path, f))]
@@ -26,7 +33,7 @@ def load_document_classification(dataset_path, metadata, set_, **kw):
         target.extend(len(documents) * [label])
         filenames.extend(documents)
 
-    return Bunch(data=vectorizer.get_vectors(), target=target,
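+    # store the labels as a numpy array so they can be used directly with
+    # numpy operations such as np.mean(pred == target)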
+    return Bunch(data=vectorizer.get_vectors(), target=np.array(target),
                  target_names=target_names, filenames=filenames,
                  DESCR=metadata.get('description'))
 
diff --git a/scikits/learn/features/text.py b/scikits/learn/features/text.py
index 9835401388d6d9ae934d3f18c87d4f6c2f386966..e77a5f842a092339947ba293195fea0eba6b6573 100644
--- a/scikits/learn/features/text.py
+++ b/scikits/learn/features/text.py
@@ -26,6 +26,8 @@ class SimpleAnalyzer(object):
 
     token_pattern = re.compile(r"\b\w\w+\b", re.U)
 
+    # TODO: make it possible to pass a stop words list here
+
     def __init__(self, default_charset='utf-8'):
         self.charset = default_charset
 
@@ -58,7 +60,11 @@ class HashingVectorizer(object):
     # TODO: implement me using the murmurhash that might be faster: but profile
     # me first :)
 
-    def __init__(self, dim=1000, probes=3, analyzer=SimpleAnalyzer(),
+    # TODO: make it possible to select between the current dense representation
+    # and sparse alternatives from scipy.sparse once the liblinear and libsvm
+    # wrappers have been updated to handle them efficiently
+
+    def __init__(self, dim=5000, probes=3, analyzer=SimpleAnalyzer(),
                  use_idf=True):
         self.dim = dim
         self.probes = probes
@@ -121,7 +127,6 @@ class HashingVectorizer(object):
         else:
             self.tf_vectors = np.vstack((self.tf_vectors, tf_vectors))
 
-
     def get_vectors(self):
         if self.use_idf:
             return self.get_tfidf()
@@ -129,5 +134,3 @@ class HashingVectorizer(object):
             return self.tf_vectors
 
 
-
-