From af6ab92b3bc0286e401218631859ee50f8be23f7 Mon Sep 17 00:00:00 2001
From: Fabian Pedregosa <fabian.pedregosa@inria.fr>
Date: Mon, 2 May 2011 16:05:21 +0200
Subject: [PATCH] Add optional parameter n_class to load_digits.

---
 scikits/learn/datasets/base.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/scikits/learn/datasets/base.py b/scikits/learn/datasets/base.py
index cb2b962d51..5f3996fb2d 100644
--- a/scikits/learn/datasets/base.py
+++ b/scikits/learn/datasets/base.py
@@ -205,9 +205,15 @@ def load_iris():
                  DESCR=fdescr.read())
 
 
-def load_digits():
+def load_digits(n_class=10):
     """load the digits dataset and returns it.
 
+
+    Parameters
+    ----------
+    n_class : integer, between 0 and 10
+        Number of classes to return, defaults to 10
+
     Returns
     -------
     data : Bunch
@@ -237,6 +243,12 @@ def load_digits():
     flat_data = data[:, :-1]
     images = flat_data.view()
     images.shape = (-1, 8, 8)
+
+    if n_class < 10:
+        idx = target < n_class
+        flat_data, target = flat_data[idx], target[idx]
+        images = images[idx]
+
     return Bunch(data=flat_data, target=target.astype(np.int),
                  target_names=np.arange(10),
                  images=images,
-- 
GitLab