From 434eed65b4b2eb5f0f8814467576cd808742ccfd Mon Sep 17 00:00:00 2001
From: Olivier Grisel <olivier.grisel@ensta.org>
Date: Sun, 4 Jul 2010 22:40:27 +0200
Subject: [PATCH] example usage of MLComp document classification datasets

---
 doc/tutorial.rst                           | 23 +++--
 examples/mlcomp_document_classification.py | 99 ++++++++++++++++++++++
 scikits/learn/datasets/mlcomp.py           |  9 +-
 scikits/learn/features/text.py             | 11 ++-
 4 files changed, 130 insertions(+), 12 deletions(-)
 create mode 100644 examples/mlcomp_document_classification.py

diff --git a/doc/tutorial.rst b/doc/tutorial.rst
index d15496c79b..565069b16b 100644
--- a/doc/tutorial.rst
+++ b/doc/tutorial.rst
@@ -16,11 +16,11 @@ single number, and for instance a multi-dimensional entry (aka
 *multivariate* data), is it said to have several attributes, or
 *features*.
 
-We can separate learning problems in a few large categories: 
+We can separate learning problems in a few large categories:
 
  * **supervised learning**, in which the data comes with additional
    attributes that we want to predict. This problem can be either:
-    
+
     * **classification**: samples belong to two or more classes and we
       want to learn from already labeled data how to predict the class
       of un-labeled data. An example of classification problem would
@@ -45,7 +45,7 @@ We can separate learning problems in a few large categories:
 .. topic:: Training set and testing set
 
     Machine learning is about learning some properties of a data set and
-    applying them to new data. This is why a common practice in machine 
+    applying them to new data. This is why a common practice in machine
     learning to evaluate an algorithm is to split the data at hand in two
     sets, one that we call a *training set* on which we learn data
     properties, and one that we call a *testing set*, on which we test
@@ -76,7 +76,7 @@ access to the features that can be used to classify the digits samples::
 array([[  0.,   0.,   5., ...,   0.,   0.,   0.],
        [  0.,   0.,   0., ...,  10.,   0.,   0.],
        [  0.,   0.,   0., ...,  16.,   9.,   0.],
-       ..., 
+       ...,
        [  0.,   0.,   1., ...,   6.,   0.,   0.],
        [  0.,   0.,   2., ...,  12.,   0.,   0.],
        [  0.,   0.,  10., ...,  12.,   1.,   0.]])
@@ -89,7 +89,7 @@ learn:
 array([0, 1, 2, ..., 8, 9, 8])
 
 .. topic:: Shape of the data arrays
-    
+
     The data is always a 2D array, `n_samples, n_features`, although
     the original data may have had a different shape. In the case of the
     digits, each original sample is an image of shape `8, 8` and can be
@@ -106,10 +106,19 @@ array([0, 1, 2, ..., 8, 9, 8])
            [  0.,   0.,   6.,  13.,  10.,   0.,   0.,   0.]])
 
     The :ref:`simple example on this dataset <example_plot_digits_classification.py>`
-    illustrates how starting from the original problem one can shape the 
+    illustrates how starting from the original problem one can shape the
     data for consumption in the `scikit.learn`.
 
 
+``scikits.learn`` also makes it possible to reuse external datasets from the
+http://mlcomp.org online service, which provides a repository of public
+datasets for various tasks (binary & multi-label classification, regression,
+document classification, ...) along with a runtime environment to compare
+program performance on those datasets. Please refer to the following example
+for instructions on the ``mlcomp`` dataset loader:
+:ref:`example_mlcomp_document_classification.py`.
+
+
 Learning and Predicting
 ------------------------
 
@@ -158,4 +167,4 @@ resolution. Do you agree with the classifier?
 
 A complete example of this classification problem is available as an
 example that you can run and study:
-:ref:`example_plot_digits_classification.py`. 
+:ref:`example_plot_digits_classification.py`.
diff --git a/examples/mlcomp_document_classification.py b/examples/mlcomp_document_classification.py
new file mode 100644
index 0000000000..2782a393b8
--- /dev/null
+++ b/examples/mlcomp_document_classification.py
@@ -0,0 +1,99 @@
+"""
+================================
+Classification of text documents
+================================
+
+This is an example showing how scikit-learn can be used to classify
+documents by topic using a bag-of-words approach.
+
+The dataset used in this example is the 20 newsgroups dataset and should be
+downloaded from http://mlcomp.org (free registration required):
+
+  http://mlcomp.org/datasets/379
+
+Once downloaded, unzip the archive somewhere on your filesystem, for instance in::
+
+  % mkdir -p ~/data/mlcomp
+  % cd ~/data/mlcomp
+  % unzip /path/to/dataset-379-20news-18828_XXXXX.zip
+
+You should get a folder ``~/data/mlcomp/379`` with a file named ``metadata`` and
+subfolders ``raw``, ``train`` and ``test`` holding the text documents organized by
+newsgroup.
+
+Then set the ``MLCOMP_DATASETS_HOME`` environment variable to point to
+the root folder holding the uncompressed archive::
+
+  % export MLCOMP_DATASETS_HOME="~/data/mlcomp"
+
+Then you are ready to run this example using your favorite Python shell::
+
+  % ipython examples/mlcomp_document_classification.py
+
+"""
+# Author: Olivier Grisel <olivier.grisel@ensta.org>
+# License: Simplified BSD
+
+from time import time
+import sys
+import os
+import numpy as np
+import pylab as pl
+from scikits.learn.datasets import load_mlcomp
+from scikits.learn.logistic import LogisticRegression
+from scikits.learn.svm import LinearSVC
+from scikits.learn.metrics import confusion_matrix
+
+if 'MLCOMP_DATASETS_HOME' not in os.environ:
+    print "Please follow these instructions to get started:"
+    print __doc__
+    sys.exit(0)
+
+# Load the training set
+print "Loading 20 newsgroups training set... "
+t0 = time()
+news_train = load_mlcomp('20news-18828', 'train')
+print "done in %fs" % (time() - t0)
+
+# The documents have been hashed into TF-IDF (Term Frequency times Inverse
+# Document Frequency) vectors of a fixed dimension.
+# Currently most scikits.learn wrappers and algorithm implementations cannot
+# efficiently leverage a sparse data structure, hence the use of this dense
+# representation of the text dataset. Efficient handling of sparse data
+# structures should be expected in an upcoming version of scikits.learn.
+print "n_samples: %d, n_features: %d" % news_train.data.shape
+
+print "Training a linear classification model with L1 penalty... "
+parameters = {
+    'loss': 'l2',
+    'penalty': 'l1',
+    'C': 10,
+    'dual': False,
+    'eps': 1e-4,
+}
+print "parameters:", parameters
+t0 = time()
+clf = LinearSVC(**parameters).fit(news_train.data, news_train.target)
+print "done in %fs" % (time() - t0)
+print "Percentage of non-zero coefficients: %f" % (np.mean(clf.coef_ != 0) * 100)
+
+print "Loading 20 newsgroups test set... "
+t0 = time()
+news_test = load_mlcomp('20news-18828', 'test')
+print "done in %fs" % (time() - t0)
+
+print "Predicting the labels of the test set..."
+t0 = time()
+pred = clf.predict(news_test.data)
+print "done in %fs" % (time() - t0)
+print "Classification accuracy: %f%%" % (np.mean(pred == news_test.target) * 100)
+
+cm = confusion_matrix(news_test.target, pred)
+print "Confusion matrix:"
+print cm
+
+# Show confusion matrix
+pl.matshow(cm)
+pl.title('Confusion matrix')
+pl.colorbar()
+pl.show()
diff --git a/scikits/learn/datasets/mlcomp.py b/scikits/learn/datasets/mlcomp.py
index 42a0deb8eb..f6678dbe32 100644
--- a/scikits/learn/datasets/mlcomp.py
+++ b/scikits/learn/datasets/mlcomp.py
@@ -3,6 +3,7 @@
 """Glue code to load http://mlcomp.org data as a scikit.learn dataset"""
 
 import os
+import numpy as np
 
 from scikits.learn.datasets.base import Bunch
 from scikits.learn.features.text import HashingVectorizer
@@ -14,6 +15,12 @@ def load_document_classification(dataset_path, metadata, set_, **kw):
     filenames = []
     vectorizer = kw.get('vectorizer', HashingVectorizer())
 
+    # TODO: make it possible to plug in a multi-pass system to filter out
+    # tokens that occur in more than 30% of the documents, for instance.
+
+    # TODO: use joblib.Parallel or multiprocessing to parallelize the following
+    # (provided this is not IO bound)
+
     dataset_path = os.path.join(dataset_path, set_)
     folders = [f for f in sorted(os.listdir(dataset_path))
                if os.path.isdir(os.path.join(dataset_path, f))]
@@ -26,7 +33,7 @@ def load_document_classification(dataset_path, metadata, set_, **kw):
         target.extend(len(documents) * [label])
         filenames.extend(documents)
 
-    return Bunch(data=vectorizer.get_vectors(), target=target,
+    return Bunch(data=vectorizer.get_vectors(), target=np.array(target),
                  target_names=target_names, filenames=filenames,
                  DESCR=metadata.get('description'))
 
diff --git a/scikits/learn/features/text.py b/scikits/learn/features/text.py
index 9835401388..e77a5f842a 100644
--- a/scikits/learn/features/text.py
+++ b/scikits/learn/features/text.py
@@ -26,6 +26,8 @@ class SimpleAnalyzer(object):
 
     token_pattern = re.compile(r"\b\w\w+\b", re.U)
 
+    # TODO: make it possible to pass a list of stop words here
+
    def __init__(self, default_charset='utf-8'):
         self.charset = default_charset
 
@@ -58,7 +60,11 @@ class HashingVectorizer(object):
     # TODO: implement me using the murmurhash that might be faster: but profile
     # me first :)
 
-    def __init__(self, dim=1000, probes=3, analyzer=SimpleAnalyzer(),
+    # TODO: make it possible to select between the current dense representation
+    # and sparse alternatives from scipy.sparse once the liblinear and libsvm
+    # wrappers have been updated to handle them efficiently
+
+    def __init__(self, dim=5000, probes=3, analyzer=SimpleAnalyzer(),
                  use_idf=True):
         self.dim = dim
         self.probes = probes
@@ -121,7 +127,6 @@ class HashingVectorizer(object):
         else:
             self.tf_vectors = np.vstack((self.tf_vectors, tf_vectors))
 
-
     def get_vectors(self):
         if self.use_idf:
             return self.get_tfidf()
@@ -129,5 +134,3 @@ class HashingVectorizer(object):
             return self.tf_vectors
 
 
-
-
-- 
GitLab
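
For reviewers who want a quick feel for the hashing trick that
``HashingVectorizer`` relies on (the ``dim`` and ``probes`` parameters touched
by this patch), here is a minimal, self-contained sketch. It is illustrative
only: the helper name ``hash_vectorize`` and the use of Python's built-in
``hash`` are assumptions made for the sketch, not the actual implementation in
``scikits/learn/features/text.py``::

    import re
    import numpy as np

    # the token pattern used by SimpleAnalyzer
    token_pattern = re.compile(r"\b\w\w+\b", re.U)

    def hash_vectorize(text, dim=5000, probes=3):
        """Map a document to a fixed-size term-frequency vector (sketch only)."""
        vector = np.zeros(dim)
        tokens = token_pattern.findall(text.lower())
        for token in tokens:
            # each token increments ``probes`` distinct buckets, which reduces
            # the impact of hash collisions at the price of a denser vector
            for probe in range(probes):
                vector[hash((token, probe)) % dim] += 1.0
        if tokens:
            # normalize so that document length does not dominate the features
            vector /= len(tokens) * probes
        return vector

    print hash_vectorize(u"The quick brown fox jumps over the lazy dog").shape

The IDF weighting and the file-based API used by ``load_mlcomp`` are left out
on purpose; see the ``HashingVectorizer`` changes above for the real
implementation.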