"""
================================
Recognizing hand-written digits
================================

An example showing how scikit-learn can be used to recognize images of
hand-written digits.

"""
# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
# License: Simplified BSD
#
# Provenance (metadata of the git patch this script was delivered in):
#   Commit:  d7c07abbc50b536531a5f622a089690d5ff5faa3
#   From:    Gael Varoquaux <gael.varoquaux@normalesup.org>
#   Date:    Tue, 20 Apr 2010 22:47:50 +0000
#   Subject: [PATCH] ENH/DOC: Add an example doing classification on digits.
#   git-svn-id: https://scikit-learn.svn.sourceforge.net/svnroot/scikit-learn/
#               trunk@669 22fbfee3-77ab-4535-9bad-27d1bd3bc7d8
#   File:    examples/plot_digits_classification.py (new file, 52 insertions)

# Standard scientific Python imports
import pylab as pl

# The digits dataset and a classifier
from scikits.learn import datasets
from scikits.learn import svm

# Number of images displayed on each row of the figure (first row: training
# samples with their true label, second row: test samples with the
# predicted label).
n_shown = 4

digits = datasets.load_digits()

# The data that we are interested in is made of 8x8 images of digits.
# Let's have a look at the first few images. We know which digit they
# represent: it is given in the 'target' of the dataset.
# The inputs are sliced *before* being zipped: the result of zip() cannot
# be subscripted on Python 3, and on Python 2 this avoids building the list
# of all (image, label) pairs just to keep a handful of them.
for index, (image, label) in enumerate(
        zip(digits.images[:n_shown], digits.target[:n_shown])):
    pl.subplot(2, n_shown, index + 1)
    pl.imshow(image, cmap=pl.cm.gray_r)
    pl.title('Training: %i' % label)

# To apply a classifier on this data, we need to flatten each image, to
# turn the data into a (samples, features) matrix:
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))

# Create the classifier:
classifier = svm.SVC()

# Floor division keeps the split index an integer whether or not true
# division is in effect, so it stays usable as a slice bound.
n_train = n_samples // 2

# We learn the digits on the first half of the dataset
classifier.fit(data[:n_train], digits.target[:n_train])

# Now predict the value of the digit on the second half:
predicted = classifier.predict(data[n_train:])

# Show the first few test images, titled with the predicted digit, on the
# second row of the figure.
for index, (image, prediction) in enumerate(
        zip(digits.images[n_train:n_train + n_shown], predicted[:n_shown])):
    pl.subplot(2, n_shown, n_shown + index + 1)
    pl.imshow(image, cmap=pl.cm.gray_r)
    pl.title('Prediction: %i' % prediction)


pl.show()