From af6ab92b3bc0286e401218631859ee50f8be23f7 Mon Sep 17 00:00:00 2001 From: Fabian Pedregosa <fabian.pedregosa@inria.fr> Date: Mon, 2 May 2011 16:05:21 +0200 Subject: [PATCH] Add optional parameter n_class to load_digits. --- scikits/learn/datasets/base.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/scikits/learn/datasets/base.py b/scikits/learn/datasets/base.py index cb2b962d51..5f3996fb2d 100644 --- a/scikits/learn/datasets/base.py +++ b/scikits/learn/datasets/base.py @@ -205,9 +205,15 @@ def load_iris(): DESCR=fdescr.read()) -def load_digits(): +def load_digits(n_class=10): """load the digits dataset and returns it. + + Parameters + ---------- + n_class : integer, between 0 and 10 + Number of classes to return, defaults to 10 + Returns ------- data : Bunch @@ -237,6 +243,12 @@ def load_digits(): flat_data = data[:, :-1] images = flat_data.view() images.shape = (-1, 8, 8) + + if n_class < 10: + idx = target < n_class + flat_data, target = flat_data[idx], target[idx] + images = images[idx] + return Bunch(data=flat_data, target=target.astype(np.int), target_names=np.arange(10), images=images, -- GitLab