From ea20d1c30b3305367915e02853471bfe7352c937 Mon Sep 17 00:00:00 2001
From: Olivier Grisel <olivier.grisel@ensta.org>
Date: Sun, 4 Jul 2010 17:10:11 +0200
Subject: [PATCH] remove labels handling from vectorizer code

---
 scikits/learn/features/tests/test_text.py | 30 +++++++++++++----------
 scikits/learn/features/text.py            | 12 ++++-----
 2 files changed, 22 insertions(+), 20 deletions(-)

diff --git a/scikits/learn/features/tests/test_text.py b/scikits/learn/features/tests/test_text.py
index 421361466d..688882c507 100644
--- a/scikits/learn/features/tests/test_text.py
+++ b/scikits/learn/features/tests/test_text.py
@@ -2,6 +2,7 @@ from scikits.learn.features.text import strip_accents
 from scikits.learn.features.text import SimpleAnalyzer
 from scikits.learn.features.text import HashingVectorizer
 from scikits.learn.logistic import LogisticRegression
+import numpy as np
 
 from nose.tools import *
 
@@ -44,24 +45,27 @@ def test_tf_idf():
     hv = HashingVectorizer(dim=1000, probes=3)
 
     # junk food documents
-    hv.sample_document("the pizza pizza beer", label=-1)
-    hv.sample_document("the pizza pizza beer", label=-1)
-    hv.sample_document("the the pizza beer beer", label=-1)
-    hv.sample_document("the pizza beer beer", label=-1)
-    hv.sample_document("the coke beer coke", label=-1)
-    hv.sample_document("the coke pizza pizza", label=-1)
+    hv.sample_document("the pizza pizza beer")
+    hv.sample_document("the pizza pizza beer")
+    hv.sample_document("the the pizza beer beer")
+    hv.sample_document("the pizza beer beer")
+    hv.sample_document("the coke beer coke")
+    hv.sample_document("the coke pizza pizza")
 
     # not-junk food documents
-    hv.sample_document("the salad celeri", label=1)
-    hv.sample_document("the salad salad sparkling water", label=1)
-    hv.sample_document("the the celeri celeri", label=1)
-    hv.sample_document("the tomato tomato salad water", label=1)
-    hv.sample_document("the tomato salad water", label=1)
+    hv.sample_document("the salad celeri")
+    hv.sample_document("the salad salad sparkling water")
+    hv.sample_document("the the celeri celeri")
+    hv.sample_document("the tomato tomato salad water")
+    hv.sample_document("the tomato salad water")
 
     # extract the TF-IDF data
-    X, y = hv.get_tfidf(), hv.labels
+    X = hv.get_tfidf()
     assert_equal(X.shape, (11, 1000))
-    assert_equal(len(y), 11)
+
+    # label junk food as -1, the others as +1
+    y = np.ones(X.shape[0])
+    y[:6] = -1
 
     # train and test a classifier
     clf = LogisticRegression().fit(X[1:-1], y[1:-1])
diff --git a/scikits/learn/features/text.py b/scikits/learn/features/text.py
index adc9f2e0d9..41766e58e4 100644
--- a/scikits/learn/features/text.py
+++ b/scikits/learn/features/text.py
@@ -63,14 +63,13 @@ class HashingVectorizer(object):
         # computing IDF
         self.df_counts = np.ones(dim, dtype=long)
         self.tf_vectors = None
-        self.labels = []
         self.sampled = 0
 
     def hash_sign(self, token, probe=0):
         h = hash(token + (probe * u"#"))
         return abs(h) % self.dim, 1.0 if h % 2 == 0 else -1.0
 
-    def sample_document(self, text, tf_vector=None, label=None):
+    def sample_document(self, text, tf_vector=None, update_estimates=True):
         """Extract features from text and update running freq estimates"""
         if tf_vector is None:
             # allocate term frequency vector and stack to history
@@ -80,8 +79,6 @@ class HashingVectorizer(object):
         else:
             self.tf_vectors = np.vstack((self.tf_vectors, tf_vector))
             tf_vector = self.tf_vectors[-1]
-        if label is not None:
-            self.labels.append(label)
 
         tokens = self.analyzer.analyze(text)
         for token in tokens:
@@ -92,9 +89,10 @@ class HashingVectorizer(object):
                 tf_vector[i] += incr
         tf_vector /= len(tokens) * self.probes
 
-        # update the running DF estimate
-        self.df_counts += tf_vector != 0.0
-        self.sampled += 1
+        if update_estimates:
+            # update the running DF estimate
+            self.df_counts += tf_vector != 0.0
+            self.sampled += 1
         return tf_vector
 
     def get_tfidf(self):
-- 
GitLab
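A minimal usage sketch of the API after this change, assuming the 2010-era
scikits.learn module layout used in the patched test file; the document
strings and label values simply mirror the updated test above.

    import numpy as np
    from scikits.learn.features.text import HashingVectorizer
    from scikits.learn.logistic import LogisticRegression

    # the vectorizer now only extracts hashed TF-IDF features; it no longer
    # stores labels
    hv = HashingVectorizer(dim=1000, probes=3)

    junk = ["the pizza pizza beer", "the coke pizza pizza"]
    not_junk = ["the salad celeri", "the tomato salad water"]

    for doc in junk + not_junk:
        # update_estimates=True (the default) keeps the running DF counts
        # and the sample counter up to date
        hv.sample_document(doc)

    X = hv.get_tfidf()

    # labels now live outside the vectorizer, e.g. as a plain numpy array
    y = np.ones(X.shape[0])
    y[:len(junk)] = -1

    clf = LogisticRegression().fit(X, y)

    # further documents can be vectorized without touching the DF estimates
    x_new = hv.sample_document("the pizza beer", update_estimates=False)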