diff --git a/scikits/learn/datasets/mlcomp.py b/scikits/learn/datasets/mlcomp.py
index d47ed5be4534c852c158a510105535919a51b872..038d29415c463b92cb912b313440fa58b77be235 100644
--- a/scikits/learn/datasets/mlcomp.py
+++ b/scikits/learn/datasets/mlcomp.py
@@ -6,14 +6,20 @@ import os
 import numpy as np
 from scikits.learn.datasets.base import Bunch
 from scikits.learn.features.text import HashingVectorizer
+from scikits.learn.features.text import SparseHashingVectorizer
 
 
-def _load_document_classification(dataset_path, metadata, set_, **kw):
+def _load_document_classification(dataset_path, metadata, set_, sparse, **kw):
     """Loader implementation for the DocumentClassification format"""
     target = []
     target_names = {}
     filenames = []
-    vectorizer = kw.get('vectorizer', HashingVectorizer())
+    vectorizer = kw.get('vectorizer')
+    if vectorizer is None:
+        if sparse:
+            vectorizer = SparseHashingVectorizer()
+        else:
+            vectorizer = HashingVectorizer()
 
     # TODO: make it possible to plug a several pass system to filter-out tokens
     # that occur in more than 30% of the documents for instance.
@@ -29,7 +35,7 @@ def _load_document_classification(dataset_path, metadata, set_, **kw):
         folder_path = os.path.join(dataset_path, folder)
         documents = [os.path.join(folder_path, d)
                      for d in sorted(os.listdir(folder_path))]
-        vectorizer.vectorize(documents)
+        vectorizer.vectorize_files(documents)
         target.extend(len(documents) * [label])
         filenames.extend(documents)
 
@@ -44,7 +50,8 @@ LOADERS = {
 }
 
 
-def load_mlcomp(name_or_id, set_="raw", mlcomp_root=None, **kwargs):
+def load_mlcomp(name_or_id, set_="raw", mlcomp_root=None, sparse=False,
+                **kwargs):
     """Load a datasets as downloaded from http://mlcomp.org
 
     Parameters
@@ -59,6 +66,9 @@ def load_mlcomp(name_or_id, set_="raw", mlcomp_root=None, **kwargs):
     are stored, if mlcomp_root is None, the MLCOMP_DATASETS_HOME
     environment variable is looked up instead.
 
+    sparse : boolean, False by default
+        if True, use a scipy.sparse matrix for the data field
+
     **kwargs : domain specific kwargs to be passed to the dataset loader.
 
     Returns
@@ -124,6 +134,6 @@
 
     loader = LOADERS.get(format)
     if loader is None:
        raise ValueError("No loader implemented for format: " + format)
-    return loader(dataset_path, metadata, set_=set_, **kwargs)
+    return loader(dataset_path, metadata, set_=set_, sparse=sparse, **kwargs)
 
diff --git a/scikits/learn/features/tests/test_text.py b/scikits/learn/features/tests/test_text.py
index 2e3462dcb2a8655a9309cd0dddc122c36c3b6958..6755f98f834203373f10914b87cdfac602b9233e 100644
--- a/scikits/learn/features/tests/test_text.py
+++ b/scikits/learn/features/tests/test_text.py
@@ -1,10 +1,29 @@
 from scikits.learn.features.text import strip_accents
 from scikits.learn.features.text import SimpleAnalyzer
 from scikits.learn.features.text import HashingVectorizer
+from scikits.learn.features.text import SparseHashingVectorizer
 from scikits.learn.logistic import LogisticRegression
+from scikits.learn.sparse.svm import SVC
 import numpy as np
 from nose.tools import *
-
+from numpy.testing import assert_array_almost_equal
+
+JUNK_FOOD_DOCS = (
+    "the pizza pizza beer",
+    "the pizza pizza beer",
+    "the the pizza beer beer",
+    "the pizza beer beer",
+    "the coke beer coke",
+    "the coke pizza pizza",
+)
+
+NOTJUNK_FOOD_DOCS = (
+    "the salad celeri",
+    "the salad salad sparkling water",
+    "the the celeri celeri",
+    "the tomato tomato salad water",
+    "the tomato salad water",
+)
 
 def test_strip_accents():
     # check some classical latin accentuated symbols
@@ -43,21 +62,8 @@ def test_simple_analyzer():
 
 def test_dense_tf_idf():
     hv = HashingVectorizer(dim=1000, probes=3)
-
-    # junk food documents
-    hv.sample_document("the pizza pizza beer")
-    hv.sample_document("the pizza pizza beer")
-    hv.sample_document("the the pizza beer beer")
-    hv.sample_document("the pizza beer beer")
-    hv.sample_document("the coke beer coke")
-    hv.sample_document("the coke pizza pizza")
-
-    # not-junk food documents
-    hv.sample_document("the salad celeri")
-    hv.sample_document("the salad salad sparkling water")
-    hv.sample_document("the the celeri celeri")
-    hv.sample_document("the tomato tomato salad water")
-    hv.sample_document("the tomato salad water")
+    hv.vectorize(JUNK_FOOD_DOCS)
+    hv.vectorize(NOTJUNK_FOOD_DOCS)
 
     # extract the TF-IDF data
     X = hv.get_tfidf()
@@ -72,3 +78,45 @@ def test_dense_tf_idf():
 
     assert_equal(clf.predict([X[0]]), [-1])
     assert_equal(clf.predict([X[-1]]), [1])
+
+def test_sparse_tf_idf():
+    hv = SparseHashingVectorizer(dim=10000, probes=3)
+    hv.vectorize(JUNK_FOOD_DOCS)
+    hv.vectorize(NOTJUNK_FOOD_DOCS)
+
+    # extract the TF-IDF data
+    X = hv.get_tfidf()
+    assert_equal(X.shape, (11, 10000))
+
+    # label junk food as -1, the others as +1
+    y = np.ones(X.shape[0])
+    y[:6] = -1
+
+    # train and test a classifier
+    clf = SVC(kernel='linear', C=10).fit(X[1:-1], y[1:-1])
+    assert_equal(clf.predict(X[0, :]), [-1])
+    assert_equal(clf.predict(X[-1, :]), [1])
+
+def test_dense_sparse_idf_sanity():
+    hv = HashingVectorizer(dim=100, probes=3)
+    shv = SparseHashingVectorizer(dim=100, probes=3)
+
+    hv.vectorize(JUNK_FOOD_DOCS)
+    shv.vectorize(JUNK_FOOD_DOCS)
+
+    # check that running TF IDF estimates are the same
+    dense_tf_idf = hv.get_tfidf()
+    sparse_tfidf = shv.get_tfidf().todense()
+
+    assert_array_almost_equal(dense_tf_idf, sparse_tfidf)
+
+    # check that incremental behaviour stays the same
+    hv.vectorize(NOTJUNK_FOOD_DOCS)
+    shv.vectorize(NOTJUNK_FOOD_DOCS)
+
+    dense_tf_idf = hv.get_tfidf()
+    sparse_tfidf = shv.get_tfidf().todense()
+
+    assert_array_almost_equal(dense_tf_idf, sparse_tfidf)
+
+
diff --git a/scikits/learn/features/text.py b/scikits/learn/features/text.py
index a0162cd79335ef214ee050a0f59d64301ff3f007..e0c35608fd3c35fa8699a3287f54287e47bef82c 100644
--- a/scikits/learn/features/text.py
+++ b/scikits/learn/features/text.py
@@ -3,6 +3,7 @@
 # License: BSD Style.
 """Utilities to build feature vectors from text documents"""
 
+from collections import defaultdict
 import re
 import unicodedata
 import numpy as np
@@ -106,10 +107,6 @@ class HashingVectorizer(object):
     # TODO: implement me using the murmurhash that might be faster: but profile
     # me first :)
 
-    # TODO: make it possible to select between the current dense representation
-    # and sparse alternatives from scipy.sparse once the liblinear and libsvm
-    # wrappers have been updated to be able to handle it efficiently
-
     def __init__(self, dim=5000, probes=1, analyzer=SimpleAnalyzer(),
                  use_idf=True):
         self.dim = dim
@@ -121,23 +118,14 @@ class HashingVectorizer(object):
         # computing IDF
         self.df_counts = np.ones(dim, dtype=long)
         self.tf_vectors = None
-        self.sampled = 0
 
     def hash_sign(self, token, probe=0):
+        """Compute the hashed index of token for this probe, and its sign"""
         h = hash(token + (probe * u"#"))
         return abs(h) % self.dim, 1.0 if h % 2 == 0 else -1.0
 
-    def sample_document(self, text, tf_vector=None, update_estimates=True):
+    def _sample_document(self, text, tf_vector, update_estimates=True):
         """Extract features from text and update running freq estimates"""
-        if tf_vector is None:
-            # allocate term frequency vector and stack to history
-            tf_vector = np.zeros(self.dim, np.float64)
-            if self.tf_vectors is None:
-                self.tf_vectors = tf_vector.reshape((1, self.dim))
-            else:
-                self.tf_vectors = np.vstack((self.tf_vectors, tf_vector))
-            tf_vector = self.tf_vectors[-1]
-
         tokens = self.analyzer.analyze(text)
         for token in tokens:
             # TODO add support for cooccurence tokens in a sentence
@@ -150,11 +138,11 @@ class HashingVectorizer(object):
         if update_estimates and self.use_idf:
             # update the running DF estimate
             self.df_counts += tf_vector != 0.0
-            self.sampled += 1
         return tf_vector
 
     def get_idf(self):
-        return np.log(float(self.sampled) / self.df_counts)
+        n_samples = float(len(self.tf_vectors))
+        return np.log(n_samples / self.df_counts)
 
     def get_tfidf(self):
         """Compute the TF-log(IDF) vectors of the sampled documents"""
@@ -162,11 +150,22 @@ class HashingVectorizer(object):
             return None
         return self.tf_vectors * self.get_idf()
 
-    def vectorize(self, document_filepaths):
-        """Vectorize a batch of documents"""
+    def vectorize(self, text_documents):
+        """Vectorize a batch of documents given as utf-8 or unicode strings"""
+        tf_vectors = np.zeros((len(text_documents), self.dim))
+        for i, text in enumerate(text_documents):
+            self._sample_document(text, tf_vectors[i])
+
+        if self.tf_vectors is None:
+            self.tf_vectors = tf_vectors
+        else:
+            self.tf_vectors = np.vstack((self.tf_vectors, tf_vectors))
+
+    def vectorize_files(self, document_filepaths):
+        """Vectorize a batch of documents stored in utf-8 text files"""
         tf_vectors = np.zeros((len(document_filepaths), self.dim))
         for i, filepath in enumerate(document_filepaths):
-            self.sample_document(file(filepath).read(), tf_vectors[i])
+            self._sample_document(file(filepath).read(), tf_vectors[i])
 
         if self.tf_vectors is None:
             self.tf_vectors = tf_vectors
@@ -181,7 +180,7 @@ class HashingVectorizer(object):
 
 
 class SparseHashingVectorizer(object):
-    """Compute term frequencies vectors using hashed term space in sparse matrix
+    """Compute term freq vectors using hashed term space in a sparse matrix
 
     The logic is the same as HashingVectorizer but it is possible to use much
     larger dimension vectors without memory issues thanks to the usage of
@@ -199,56 +198,62 @@ class SparseHashingVectorizer(object):
         # computing IDF
         self.df_counts = np.ones(dim, dtype=long)
         self.tf_vectors = None
-        self.sampled = 0
 
     def hash_sign(self, token, probe=0):
         h = hash(token + (probe * u"#"))
         return abs(h) % self.dim, 1.0 if h % 2 == 0 else -1.0
 
-    def sample_document(self, text, tf_vectors=None, idx=0,
-                        update_estimates=True):
+    def _sample_document(self, text, tf_vectors, idx=0, update_estimates=True):
         """Extract features from text and update running freq estimates"""
-        if tf_vectors is None:
-            # allocate term frequency vector and stack to history
-            tf_vectors = sp.lil_matrix((1, self.dim))
-            stack = True
-        else:
-            stack = False
-
         tokens = self.analyzer.analyze(text)
+        counts = defaultdict(lambda: 0.0)
         for token in tokens:
             # TODO add support for cooccurence tokens in a sentence
             # window
             for probe in xrange(self.probes):
                 i, incr = self.hash_sign(token, probe)
-                tf_vectors[idx, i] += incr
-        tf_vectors[idx, :] /= len(tokens) * self.probes
+                counts[i] += incr
+        for k, v in counts.iteritems():
+            if v == 0.0:
+                # colliding tokens with opposite signs can cancel out exactly
+                continue
+            tf_vectors[idx, k] = v / (len(tokens) * self.probes)
 
-        if update_estimates and self.use_idf:
-            # update the running DF estimate
-            self.df_counts += np.sum(tf_vectors.nonzero()[0] == idx)
-            self.sampled += 1
-
-        if stack:
-            if self.tf_vectors is None:
-                self.tf_vectors = tf_vectors
-            else:
-                self.tf_vectors = sp.vstack((self.tf_vectors, tf_vectors))
+            if update_estimates and self.use_idf:
+                # update the running DF estimate
+                self.df_counts[k] += 1
 
     def get_idf(self):
-        return np.log(float(self.sampled) / self.df_counts)
+        n_samples = float(self.tf_vectors.shape[0])
+        return np.log(n_samples / self.df_counts)
 
     def get_tfidf(self):
         """Compute the TF-log(IDF) vectors of the sampled documents"""
         if self.tf_vectors is None:
             return None
-        return self.tf_vectors.multiply(self.get_idf()[np.newaxis,:])
+        sparse_idf = sp.lil_matrix((self.dim, self.dim))
+        sparse_idf.setdiag(self.get_idf())
+        # use matrix multiply by a diagonal version of idf to emulate array
+        # broadcasting with a matrix API
+        return self.tf_vectors * sparse_idf
+
+    def vectorize(self, text_documents):
+        """Vectorize a batch of documents given as utf-8 or unicode strings"""
+        tf_vectors = sp.lil_matrix((len(text_documents), self.dim))
+        for i, text in enumerate(text_documents):
+            self._sample_document(text, tf_vectors, i)
 
-    def vectorize(self, document_filepaths):
-        """Vectorize a batch of documents"""
+        if self.tf_vectors is None:
+            self.tf_vectors = tf_vectors
+        else:
+            self.tf_vectors = sp.vstack((self.tf_vectors, tf_vectors))
+
+    def vectorize_files(self, document_filepaths):
+        """Vectorize a batch of utf-8 text files"""
         tf_vectors = sp.lil_matrix((len(document_filepaths), self.dim))
         for i, filepath in enumerate(document_filepaths):
-            self.sample_document(file(filepath).read(), tf_vectors[i])
+            self._sample_document(file(filepath).read(), tf_vectors, i)
 
         if self.tf_vectors is None:
             self.tf_vectors = tf_vectors
@@ -257,8 +262,8 @@ class SparseHashingVectorizer(object):
 
     def get_vectors(self):
         if self.use_idf:
-            return self.get_tfidf().tocsr()
+            return self.get_tfidf()
         else:
-            return self.tf_vectors.tocsr()
+            return self.tf_vectors
 
 
diff --git a/scikits/learn/sparse/svm.py b/scikits/learn/sparse/svm.py
index 5f0969b038ef559282a44c95d88aa4d4cfaa9bfe..cfc5c47977a39bb35eff1f32b9a92ce38ddab507 100644
--- a/scikits/learn/sparse/svm.py
+++ b/scikits/learn/sparse/svm.py
@@ -148,16 +148,17 @@ class SparseBaseLibsvm(BaseEstimator):
     @property
     def coef_(self):
         if self.kernel != 'linear':
-            raise NotImplementedError('coef_ is only available when using a linear kernel')
+            raise NotImplementedError(
+                'coef_ is only available when using a linear kernel')
         return np.dot(self.dual_coef_, self.support_)
 
+
 class SVC(SparseBaseLibsvm):
     """SVC for sparse matrices (csr)
 
     For best results, this accepts a matrix in csr format
     (scipy.sparse.csr), but should be able to convert from any array-like
     object (including other sparse representations).
-
     """
     def __init__(self, impl='c_svc', kernel='rbf', degree=3, gamma=0.0,
                  coef0=0.0, cache_size=100.0, eps=1e-3, C=1.0, nu=0.5, p=0.1,
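Usage sketch for reviewers (not part of the patch): how the dense and sparse
vectorizers are meant to be driven after this change, and how the new sparse
flag of load_mlcomp is forwarded. The toy documents below are made up, and the
MLComp call assumes a dataset (for instance '20news-18828') has already been
downloaded under MLCOMP_DATASETS_HOME, so it is left commented out.

    from scikits.learn.features.text import HashingVectorizer
    from scikits.learn.features.text import SparseHashingVectorizer

    docs = [u"the pizza pizza beer", u"the salad celeri"]

    # in-memory strings now go through vectorize(); paths to utf-8 text
    # files go through the new vectorize_files()
    hv = HashingVectorizer(dim=1000, probes=3)
    hv.vectorize(docs)
    X_dense = hv.get_tfidf()      # numpy array, shape (2, 1000)

    # same calls on the sparse variant, but the TF-IDF data is scipy.sparse
    shv = SparseHashingVectorizer(dim=100000, probes=3)
    shv.vectorize(docs)
    X_sparse = shv.get_tfidf()    # scipy.sparse matrix, shape (2, 100000)

    # the mlcomp loader simply forwards the flag to the document
    # classification loader:
    # from scikits.learn.datasets.mlcomp import load_mlcomp
    # news = load_mlcomp('20news-18828', 'train', sparse=True)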
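The sparse get_tfidf() change replaces elementwise multiply broadcasting with a
matrix product against a diagonal IDF matrix, since the scipy.sparse matrix API
has no array broadcasting. A minimal standalone check of that equivalence, with
made-up numbers and independent of the vectorizer classes:

    import numpy as np
    import scipy.sparse as sp

    tf = sp.lil_matrix(np.array([[0.5, 0.0, 0.5],
                                 [0.0, 1.0, 0.0]]))
    idf = np.log(2.0 / np.array([2.0, 1.0, 2.0]))   # toy per-term IDF values

    # right-multiplying by diag(idf) scales column j by idf[j], which is
    # exactly what numpy broadcasting does in the dense code path
    diag_idf = sp.lil_matrix((3, 3))
    diag_idf.setdiag(idf)
    tfidf = tf * diag_idf

    assert np.allclose(tfidf.toarray(), tf.toarray() * idf)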