diff --git a/scikits/learn/datasets/mlcomp.py b/scikits/learn/datasets/mlcomp.py
index d47ed5be4534c852c158a510105535919a51b872..038d29415c463b92cb912b313440fa58b77be235 100644
--- a/scikits/learn/datasets/mlcomp.py
+++ b/scikits/learn/datasets/mlcomp.py
@@ -6,14 +6,20 @@ import os
 import numpy as np
 from scikits.learn.datasets.base import Bunch
 from scikits.learn.features.text import HashingVectorizer
+from scikits.learn.features.text import SparseHashingVectorizer
 
 
-def _load_document_classification(dataset_path, metadata, set_, **kw):
+def _load_document_classification(dataset_path, metadata, set_, sparse, **kw):
     """Loader implementation for the DocumentClassification format"""
     target = []
     target_names = {}
     filenames = []
-    vectorizer = kw.get('vectorizer', HashingVectorizer())
+    vectorizer = kw.get('vectorizer')
+    if vectorizer is None:
+        if sparse:
+            vectorizer = SparseHashingVectorizer()
+        else:
+            vectorizer = HashingVectorizer()
 
     # TODO: make it possible to plug a several pass system to filter-out tokens
     # that occur in more than 30% of the documents for instance.
@@ -29,7 +35,7 @@ def _load_document_classification(dataset_path, metadata, set_, **kw):
         folder_path = os.path.join(dataset_path, folder)
         documents = [os.path.join(folder_path, d)
                      for d in sorted(os.listdir(folder_path))]
-        vectorizer.vectorize(documents)
+        vectorizer.vectorize_files(documents)
         target.extend(len(documents) * [label])
         filenames.extend(documents)
 
@@ -44,7 +50,8 @@ LOADERS = {
 }
 
 
-def load_mlcomp(name_or_id, set_="raw", mlcomp_root=None, **kwargs):
+def load_mlcomp(name_or_id, set_="raw", mlcomp_root=None, sparse=False,
+                **kwargs):
     """Load a datasets as downloaded from http://mlcomp.org
 
     Parameters
@@ -59,6 +66,9 @@ def load_mlcomp(name_or_id, set_="raw", mlcomp_root=None, **kwargs):
                   are stored, if mlcomp_root is None, the MLCOMP_DATASETS_HOME
                   environment variable is looked up instead.
 
+    sparse : boolean, optional (False by default)
+             if True, use a scipy.sparse matrix for the data field
+
     **kwargs : domain specific kwargs to be passed to the dataset loader.
 
     Returns
@@ -124,6 +134,6 @@ def load_mlcomp(name_or_id, set_="raw", mlcomp_root=None, **kwargs):
     loader = LOADERS.get(format)
     if loader is None:
         raise ValueError("No loader implemented for format: " + format)
-    return loader(dataset_path, metadata, set_=set_, **kwargs)
+    return loader(dataset_path, metadata, set_=set_, sparse=sparse, **kwargs)
 
 
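A minimal usage sketch of the new sparse flag in load_mlcomp (the dataset name, the mlcomp_root path, the set_ value and the Bunch fields accessed below are illustrative assumptions, not taken from this patch):

    # Hypothetical usage of load_mlcomp with sparse=True; the dataset name and
    # paths are placeholders, and the returned Bunch is assumed to expose
    # data/target fields like the other dataset loaders.
    from scikits.learn.datasets.mlcomp import load_mlcomp

    news_train = load_mlcomp('20news-18828', set_='train',
                             mlcomp_root='/path/to/mlcomp/datasets',
                             sparse=True)
    X_train = news_train.data    # scipy.sparse TF-IDF matrix when sparse=True
    y_train = news_train.target  # one integer label per document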
diff --git a/scikits/learn/features/tests/test_text.py b/scikits/learn/features/tests/test_text.py
index 2e3462dcb2a8655a9309cd0dddc122c36c3b6958..6755f98f834203373f10914b87cdfac602b9233e 100644
--- a/scikits/learn/features/tests/test_text.py
+++ b/scikits/learn/features/tests/test_text.py
@@ -1,10 +1,29 @@
 from scikits.learn.features.text import strip_accents
 from scikits.learn.features.text import SimpleAnalyzer
 from scikits.learn.features.text import HashingVectorizer
+from scikits.learn.features.text import SparseHashingVectorizer
 from scikits.learn.logistic import LogisticRegression
+from scikits.learn.sparse.svm import SVC
 import numpy as np
 from nose.tools import *
-
+from numpy.testing import assert_array_almost_equal
+
+JUNK_FOOD_DOCS = (
+    "the pizza pizza beer",
+    "the pizza pizza beer",
+    "the the pizza beer beer",
+    "the pizza beer beer",
+    "the coke beer coke",
+    "the coke pizza pizza",
+)
+
+NOTJUNK_FOOD_DOCS = (
+    "the salad celeri",
+    "the salad salad sparkling water",
+    "the the celeri celeri",
+    "the tomato tomato salad water",
+    "the tomato salad water",
+)
 
 def test_strip_accents():
     # check some classical latin accentuated symbols
@@ -43,21 +62,8 @@ def test_simple_analyzer():
 
 def test_dense_tf_idf():
     hv = HashingVectorizer(dim=1000, probes=3)
-
-    # junk food documents
-    hv.sample_document("the pizza pizza beer")
-    hv.sample_document("the pizza pizza beer")
-    hv.sample_document("the the pizza beer beer")
-    hv.sample_document("the pizza beer beer")
-    hv.sample_document("the coke beer coke")
-    hv.sample_document("the coke pizza pizza")
-
-    # not-junk food documents
-    hv.sample_document("the salad celeri")
-    hv.sample_document("the salad salad sparkling water")
-    hv.sample_document("the the celeri celeri")
-    hv.sample_document("the tomato tomato salad water")
-    hv.sample_document("the tomato salad water")
+    hv.vectorize(JUNK_FOOD_DOCS)
+    hv.vectorize(NOTJUNK_FOOD_DOCS)
 
     # extract the TF-IDF data
     X = hv.get_tfidf()
@@ -72,3 +78,45 @@ def test_dense_tf_idf():
     assert_equal(clf.predict([X[0]]), [-1])
     assert_equal(clf.predict([X[-1]]), [1])
 
+
+def test_sparse_tf_idf():
+    hv = SparseHashingVectorizer(dim=10000, probes=3)
+    hv.vectorize(JUNK_FOOD_DOCS)
+    hv.vectorize(NOTJUNK_FOOD_DOCS)
+
+    # extract the TF-IDF data
+    X = hv.get_tfidf()
+    assert_equal(X.shape, (11, 10000))
+
+    # label junk food as -1, the others as +1
+    y = np.ones(X.shape[0])
+    y[:6] = -1
+
+    # train and test a classifier
+    clf = SVC(kernel='linear', C=10).fit(X[1:-1], y[1:-1])
+    assert_equal(clf.predict(X[0, :]), [-1])
+    assert_equal(clf.predict(X[-1, :]), [1])
+
+def test_dense_sparse_idf_sanity():
+    hv = HashingVectorizer(dim=100, probes=3)
+    shv = SparseHashingVectorizer(dim=100, probes=3)
+
+    hv.vectorize(JUNK_FOOD_DOCS)
+    shv.vectorize(JUNK_FOOD_DOCS)
+
+    # check that running TF IDF estimates are the same
+    dense_tf_idf = hv.get_tfidf()
+    sparse_tfidf = shv.get_tfidf().todense()
+
+    assert_array_almost_equal(dense_tf_idf, sparse_tfidf)
+
+    # check that incremental behaviour stays the same
+    hv.vectorize(NOTJUNK_FOOD_DOCS)
+    shv.vectorize(NOTJUNK_FOOD_DOCS)
+
+    dense_tf_idf = hv.get_tfidf()
+    sparse_tfidf = shv.get_tfidf().todense()
+
+    assert_array_almost_equal(dense_tf_idf, sparse_tfidf)
+
+
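For reference, a minimal sketch of the string-based vectorize API exercised by the tests above (the toy documents are placeholders):

    # Sketch of the new vectorize()/get_tfidf() API on raw strings.
    from scikits.learn.features.text import HashingVectorizer
    from scikits.learn.features.text import SparseHashingVectorizer

    docs = ["the pizza pizza beer", "the salad celeri"]

    hv = HashingVectorizer(dim=100, probes=3)
    hv.vectorize(docs)          # accumulates TF vectors and DF counts in place
    X_dense = hv.get_tfidf()    # numpy array of shape (2, 100)

    shv = SparseHashingVectorizer(dim=100, probes=3)
    shv.vectorize(docs)
    X_sparse = shv.get_tfidf()  # scipy.sparse matrix with near-identical values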
diff --git a/scikits/learn/features/text.py b/scikits/learn/features/text.py
index a0162cd79335ef214ee050a0f59d64301ff3f007..e0c35608fd3c35fa8699a3287f54287e47bef82c 100644
--- a/scikits/learn/features/text.py
+++ b/scikits/learn/features/text.py
@@ -3,6 +3,7 @@
 # License: BSD Style.
 """Utilities to build feature vectors from text documents"""
 
+from collections import defaultdict
 import re
 import unicodedata
 import numpy as np
@@ -106,10 +107,6 @@ class HashingVectorizer(object):
     # TODO: implement me using the murmurhash that might be faster: but profile
     # me first :)
 
-    # TODO: make it possible to select between the current dense representation
-    # and sparse alternatives from scipy.sparse once the liblinear and libsvm
-    # wrappers have been updated to be able to handle it efficiently
-
     def __init__(self, dim=5000, probes=1, analyzer=SimpleAnalyzer(),
                  use_idf=True):
         self.dim = dim
@@ -121,23 +118,14 @@ class HashingVectorizer(object):
         # computing IDF
         self.df_counts = np.ones(dim, dtype=long)
         self.tf_vectors = None
-        self.sampled = 0
 
     def hash_sign(self, token, probe=0):
+        """Compute the hash of token with number proble and hashed sign"""
         h = hash(token + (probe * u"#"))
         return abs(h) % self.dim, 1.0 if h % 2 == 0 else -1.0
 
-    def sample_document(self, text, tf_vector=None, update_estimates=True):
+    def _sample_document(self, text, tf_vector, update_estimates=True):
         """Extract features from text and update running freq estimates"""
-        if tf_vector is None:
-            # allocate term frequency vector and stack to history
-            tf_vector = np.zeros(self.dim, np.float64)
-            if self.tf_vectors is None:
-                self.tf_vectors = tf_vector.reshape((1, self.dim))
-            else:
-                self.tf_vectors = np.vstack((self.tf_vectors, tf_vector))
-                tf_vector = self.tf_vectors[-1]
-
         tokens = self.analyzer.analyze(text)
         for token in tokens:
             # TODO add support for cooccurence tokens in a sentence
@@ -150,11 +138,11 @@ class HashingVectorizer(object):
         if update_estimates and self.use_idf:
             # update the running DF estimate
             self.df_counts += tf_vector != 0.0
-            self.sampled += 1
         return tf_vector
 
     def get_idf(self):
-        return np.log(float(self.sampled) / self.df_counts)
+        n_samples = float(len(self.tf_vectors))
+        return np.log(n_samples / self.df_counts)
 
     def get_tfidf(self):
         """Compute the TF-log(IDF) vectors of the sampled documents"""
@@ -162,11 +150,22 @@ class HashingVectorizer(object):
             return None
         return self.tf_vectors * self.get_idf()
 
-    def vectorize(self, document_filepaths):
-        """Vectorize a batch of documents"""
+    def vectorize(self, text_documents):
+        """Vectorize a batch of documents in python utf-8 strings or unicode"""
+        tf_vectors = np.zeros((len(text_documents), self.dim))
+        for i, text in enumerate(text_documents):
+            self._sample_document(text, tf_vectors[i])
+
+        if self.tf_vectors is None:
+            self.tf_vectors = tf_vectors
+        else:
+            self.tf_vectors = np.vstack((self.tf_vectors, tf_vectors))
+
+    def vectorize_files(self, document_filepaths):
+        """Vectorize a batch of documents stored in utf-8 text files"""
         tf_vectors = np.zeros((len(document_filepaths), self.dim))
         for i, filepath in enumerate(document_filepaths):
-            self.sample_document(file(filepath).read(), tf_vectors[i])
+            self._sample_document(file(filepath).read(), tf_vectors[i])
 
         if self.tf_vectors is None:
             self.tf_vectors = tf_vectors
@@ -181,7 +180,7 @@ class HashingVectorizer(object):
 
 
 class SparseHashingVectorizer(object):
-    """Compute term frequencies vectors using hashed term space in sparse matrix
+    """Compute term freq vectors using hashed term space in a sparse matrix
 
     The logic is the same as HashingVectorizer but it is possible to use much
     larger dimension vectors without memory issues thanks to the usage of
@@ -199,56 +198,62 @@ class SparseHashingVectorizer(object):
         # computing IDF
         self.df_counts = np.ones(dim, dtype=long)
         self.tf_vectors = None
-        self.sampled = 0
 
     def hash_sign(self, token, probe=0):
         h = hash(token + (probe * u"#"))
         return abs(h) % self.dim, 1.0 if h % 2 == 0 else -1.0
 
-    def sample_document(self, text, tf_vectors=None, idx=0,
-                        update_estimates=True):
+    def _sample_document(self, text, tf_vectors, idx=0, update_estimates=True):
         """Extract features from text and update running freq estimates"""
-        if tf_vectors is None:
-            # allocate term frequency vector and stack to history
-            tf_vectors = sp.lil_matrix((1, self.dim))
-            stack = True
-        else:
-            stack = False
 
         tokens = self.analyzer.analyze(text)
+        counts = defaultdict(float)
         for token in tokens:
             # TODO add support for cooccurence tokens in a sentence
             # window
             for probe in xrange(self.probes):
                 i, incr = self.hash_sign(token, probe)
-                tf_vectors[idx, i] += incr
-        tf_vectors[idx, :] /= len(tokens) * self.probes
+                counts[i] += incr
+        for k, v in counts.iteritems():
+            if v == 0.0:
+                # can happen when conflicting hashed features cancel out exactly
+                continue
+            tf_vectors[idx, k] = v / (len(tokens) * self.probes)
 
-        if update_estimates and self.use_idf:
-            # update the running DF estimate
-            self.df_counts += np.sum(tf_vectors.nonzero()[0] == idx)
-            self.sampled += 1
-
-        if stack:
-            if self.tf_vectors is None:
-                self.tf_vectors = tf_vectors
-            else:
-                self.tf_vectors = sp.vstack((self.tf_vectors, tf_vectors))
+            if update_estimates and self.use_idf:
+                # update the running DF estimate
+                self.df_counts[k] += 1
 
     def get_idf(self):
-        return np.log(float(self.sampled) / self.df_counts)
+        n_samples = float(self.tf_vectors.shape[0])
+        return np.log(n_samples / self.df_counts)
 
     def get_tfidf(self):
         """Compute the TF-log(IDF) vectors of the sampled documents"""
         if self.tf_vectors is None:
             return None
-        return self.tf_vectors.multiply(self.get_idf()[np.newaxis,:])
+        sparse_idf = sp.lil_matrix((self.dim, self.dim))
+        sparse_idf.setdiag(self.get_idf())
+        # multiplying by a diagonal matrix holding the idf values emulates
+        # column-wise array broadcasting with the sparse matrix API
+        return self.tf_vectors * sparse_idf
+
+    def vectorize(self, text_documents):
+        """Vectorize a batch of documents in python utf-8 strings or unicode"""
+        tf_vectors = sp.lil_matrix((len(text_documents), self.dim))
+        for i, text in enumerate(text_documents):
+            self._sample_document(text, tf_vectors, i)
 
-    def vectorize(self, document_filepaths):
-        """Vectorize a batch of documents"""
+        if self.tf_vectors is None:
+            self.tf_vectors = tf_vectors
+        else:
+            self.tf_vectors = sp.vstack((self.tf_vectors, tf_vectors))
+
+    def vectorize_files(self, document_filepaths):
+        """Vectorize a batch of utf-8 text files"""
         tf_vectors = sp.lil_matrix((len(document_filepaths), self.dim))
         for i, filepath in enumerate(document_filepaths):
-            self.sample_document(file(filepath).read(), tf_vectors[i])
+            self._sample_document(file(filepath).read(), tf_vectors, i)
 
         if self.tf_vectors is None:
             self.tf_vectors = tf_vectors
@@ -257,8 +262,8 @@ class SparseHashingVectorizer(object):
 
     def get_vectors(self):
         if self.use_idf:
-            return self.get_tfidf().tocsr()
+            return self.get_tfidf()
         else:
-            return self.tf_vectors.tocsr()
+            return self.tf_vectors
 
 
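The diagonal multiply introduced in SparseHashingVectorizer.get_tfidf can be checked in isolation; a self-contained sketch with made-up numbers:

    # Right-multiplying the sparse TF matrix by a diagonal IDF matrix scales
    # each column by its IDF weight without densifying anything. The numbers
    # below are made up for illustration only.
    import numpy as np
    import scipy.sparse as sp

    tf = sp.lil_matrix([[0.5, 0.0, 0.5],
                        [0.0, 1.0, 0.0]])
    idf = np.log(2.0 / np.array([2.0, 1.0, 1.0]))

    sparse_idf = sp.lil_matrix((3, 3))
    sparse_idf.setdiag(idf)

    tfidf = tf * sparse_idf
    assert np.allclose(tfidf.toarray(), tf.toarray() * idf)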
diff --git a/scikits/learn/sparse/svm.py b/scikits/learn/sparse/svm.py
index 5f0969b038ef559282a44c95d88aa4d4cfaa9bfe..cfc5c47977a39bb35eff1f32b9a92ce38ddab507 100644
--- a/scikits/learn/sparse/svm.py
+++ b/scikits/learn/sparse/svm.py
@@ -148,16 +148,17 @@ class SparseBaseLibsvm(BaseEstimator):
     @property
     def coef_(self):
         if self.kernel != 'linear':
-            raise NotImplementedError('coef_ is only available when using a linear kernel')
+            raise NotImplementedError(
+                'coef_ is only available when using a linear kernel')
         return np.dot(self.dual_coef_, self.support_)
 
+
 class SVC(SparseBaseLibsvm):
     """SVC for sparse matrices (csr)
 
     For best results, this accepts a matrix in csr format
     (scipy.sparse.csr), but should be able to convert from any array-like
     object (including other sparse representations).
-
     """
     def __init__(self, impl='c_svc', kernel='rbf', degree=3, gamma=0.0,
                  coef0=0.0, cache_size=100.0, eps=1e-3, C=1.0, nu=0.5, p=0.1,