diff --git a/examples/document_classification_20newsgroups.py b/examples/document_classification_20newsgroups.py index 61bab1b878d70293c1c52d0374ece17b3a568956..4a70c666ec5a0087be41eb6e16b337d2e604b936 100644 --- a/examples/document_classification_20newsgroups.py +++ b/examples/document_classification_20newsgroups.py @@ -92,22 +92,16 @@ categories = [ print "Loading 20 newsgroups dataset for categories:" print categories -data = load_files('20news-18828', categories=categories) +data = load_files('20news-18828', categories=categories, shuffle=True, rng=42) print "%d documents" % len(data.filenames) print "%d categories" % len(data.target_names) print -# shuffle the ordering and split a training set and a test set +# split a training set and a test set filenames = data.filenames y = data.target -n = y.shape[0] -indices = np.arange(n) -np.random.seed(42) -np.random.shuffle(indices) - -filenames = filenames[indices] -y = y[indices] +n = filenames.shape[0] filenames_train, filenames_test = filenames[:-n/2], filenames[-n/2:] y_train, y_test = y[:-n/2], y[-n/2:] diff --git a/scikits/learn/datasets/base.py b/scikits/learn/datasets/base.py index eee7420d89a176993b0f7b8e818a42deb5aacc9c..6a0033ad1ee640db8c3cd407f488ca802d631df8 100644 --- a/scikits/learn/datasets/base.py +++ b/scikits/learn/datasets/base.py @@ -22,7 +22,8 @@ class Bunch(dict): self.__dict__ = self -def load_files(container_path, description=None, categories=None): +def load_files(container_path, description=None, categories=None, shuffle=True, + rng=42): """Load files with categories as subfolder names Individual samples are assumed to be files stored a two levels folder @@ -67,6 +68,12 @@ def load_files(container_path, description=None, categories=None): if None (default), load all the categories. if not Non, list of category names to load (other categories ignored) + shuffle : True by default + whether or not to shuffle the data: might be important for + + rng : a numpy random number generator or a seed integer, 42 by default + used to shuffle the dataset + Returns ------- @@ -96,9 +103,21 @@ def load_files(container_path, description=None, categories=None): target.extend(len(documents) * [label]) filenames.extend(documents) - return Bunch(filenames=np.array(filenames), + # convert as array for fancy indexing + filenames = np.array(filenames) + target = np.array(target) + + if shuffle: + if isinstance(rng, int): + rng = np.random.RandomState(rng) + indices = np.arange(filenames.shape[0]) + rng.shuffle(indices) + filenames = filenames[indices] + target = target[indices] + + return Bunch(filenames=filenames, target_names=target_names, - target=np.array(target), + target=target, DESCR=description)