diff --git a/scikits/learn/datasets/base.py b/scikits/learn/datasets/base.py
index d872401b0a3ea80155b1eec0c1b83629f1b50e33..61b3363e8f9c3875a451fb19582ea53e19bbd8bd 100644
--- a/scikits/learn/datasets/base.py
+++ b/scikits/learn/datasets/base.py
@@ -22,11 +22,11 @@ class Bunch(dict):
         self.__dict__ = self


-def load_text_files(container_path, description):
-    """Load text document files with categories as subfolder names
+def load_files(container_path, description=None, categories=None):
+    """Load files with categories as subfolder names

-    Individual samples are assumed to be utf-8 encoded text files in a two level
-    folder structure such as the following:
+    Individual samples are assumed to be files stored in a two-level folder
+    structure such as the following:

         container_folder/
             category_1_folder/
@@ -42,13 +42,16 @@ def load_text_files(container_path, description):

     The folder names are used as supervised signal label names. The individual
     file names are not important.

-    This function does not try to load the text features into a numpy array or
-    scipy sparse matrix, nor does it try to load the text in memory.
+    This function does not try to extract features into a numpy array or
+    scipy sparse matrix, nor does it try to load the files in memory.

-    The use text files in a scikit-learn classification or clustering algorithm
-    you will first need to use the `scikits.learn.features.text` module to build
-    a feature extraction transformer that suits your problem.
+    To use utf-8 text files in a scikit-learn classification or clustering
+    algorithm, you will first need to use the `scikits.learn.features.text`
+    module to build a feature extraction transformer that suits your
+    problem.
+    Similar feature extractors should be built for other kinds of unstructured
+    data input such as images, audio, video, ...

     Parameters
     ----------
@@ -60,6 +63,10 @@ def load_text_files(container_path, description):
         a paragraph describing the characteristics of the dataset, its
         source, reference, ...

+    categories : None or collection of string or unicode
+        if None (default), load all the categories.
+        if not None, list of category names to load (other categories ignored)
+
     Returns
     -------

@@ -77,6 +84,10 @@

     folders = [f for f in sorted(os.listdir(container_path))
                if os.path.isdir(os.path.join(container_path, f))]
+
+    if categories is not None:
+        folders = [f for f in folders if f in categories]
+
     for label, folder in enumerate(folders):
         target_names[label] = folder
         folder_path = os.path.join(container_path, folder)
diff --git a/scikits/learn/datasets/mlcomp.py b/scikits/learn/datasets/mlcomp.py
index 4c08d371b18f1edde5a65a9281bf8ce965b57c0a..4ea38494df254332ced437ce55b04758f0b4f1f9 100644
--- a/scikits/learn/datasets/mlcomp.py
+++ b/scikits/learn/datasets/mlcomp.py
@@ -4,7 +4,7 @@
 import os
 import numpy as np

-from scikits.learn.datasets.base import load_text_files
+from scikits.learn.datasets.base import load_files
 from scikits.learn.feature_extraction.text import HashingVectorizer
 from scikits.learn.feature_extraction.text.sparse import HashingVectorizer as \
     SparseCountVectorizer
@@ -13,7 +13,7 @@ from scikits.learn.feature_extraction.text.sparse import HashingVectorizer as \
 def _load_document_classification(dataset_path, metadata, set_=None):
     if set_ is not None:
         dataset_path = os.path.join(dataset_path, set_)
-    return load_text_files(dataset_path, metadata.get('description'))
+    return load_files(dataset_path, metadata.get('description'))


 LOADERS = {
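Usage note (not part of the patch): a minimal sketch of how the renamed loader might be called once this change lands. The container path, the category names, and the `Bunch` attribute accessed at the end are illustrative assumptions drawn from the docstring above, not verified against the rest of the module.

    import os

    from scikits.learn.datasets.base import load_files

    # Hypothetical on-disk layout, following the docstring:
    #   container/
    #       comp.graphics/     # one sub-folder per category
    #           doc_1.txt
    #       sci.med/
    #           doc_2.txt
    container_path = os.path.expanduser('~/data/container')

    # Load only two of the category sub-folders; any other folder is
    # skipped thanks to the new `categories` filter.
    dataset = load_files(container_path,
                         description='two newsgroups categories',
                         categories=['comp.graphics', 'sci.med'])

    # The loader records folder names as label names and does not read file
    # contents into memory; feature extraction (e.g. with the
    # scikits.learn.feature_extraction.text vectorizers) is a separate step.
    print(dataset.target_names)  # assumed attribute, one entry per folder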