Skip to content
Snippets Groups Projects
Commit af6ab92b authored by Fabian Pedregosa's avatar Fabian Pedregosa
Browse files

Add optional parameter n_class to load_digits.

parent 1409e01e
No related branches found
No related tags found
No related merge requests found
......@@ -205,9 +205,15 @@ def load_iris():
DESCR=fdescr.read())
def load_digits():
def load_digits(n_class=10):
"""load the digits dataset and returns it.
Parameters
----------
n_class : integer, between 0 and 10
Number of classes to return, defaults to 10
Returns
-------
data : Bunch
......@@ -237,6 +243,12 @@ def load_digits():
flat_data = data[:, :-1]
images = flat_data.view()
images.shape = (-1, 8, 8)
if n_class < 10:
idx = target < n_class
flat_data, target = flat_data[idx], target[idx]
images = images[idx]
return Bunch(data=flat_data, target=target.astype(np.int),
target_names=np.arange(10),
images=images,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment