From a39c8ab75c3a2685bc3a038ee9d9e5e9fb4970b8 Mon Sep 17 00:00:00 2001
From: Olivier Grisel <olivier.grisel@ensta.org>
Date: Fri, 16 Jun 2017 18:11:47 +0200
Subject: [PATCH] ENH svmlight chunk loader (#935)

---
 doc/whats_new.rst                             |   5 +
 sklearn/datasets/_svmlight_format.pyx         |  15 ++-
 sklearn/datasets/svmlight_format.py           |  70 +++++++++---
 .../datasets/tests/test_svmlight_format.py    | 108 +++++++++++++++++-
 4 files changed, 181 insertions(+), 17 deletions(-)

diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 6792906bb9..70310366d0 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -202,11 +202,16 @@ Enhancements
 
    - Prevent cast from float32 to float64 in
      :class:`linear_model.Ridge` when using svd, sparse_cg, cholesky or lsqr solvers
+     :class:`sklearn.linear_model.Ridge` when using svd, sparse_cg, cholesky or lsqr solvers
      by :user:`Joan Massich <massich>`, :user:`Nicolas Cordier <ncordier>`
 
    - Add ``max_train_size`` parameter to :class:`model_selection.TimeSeriesSplit`
      :issue:`8282` by :user:`Aman Dalmia <dalmia>`.
 
+   - Make it possible to load a chunk of an svmlight formatted file by
+     passing a range of bytes to :func:`datasets.load_svmlight_file`.
+     :issue:`935` by :user:`Olivier Grisel <ogrisel>`.
+
 Bug fixes
 .........
 
diff --git a/sklearn/datasets/_svmlight_format.pyx b/sklearn/datasets/_svmlight_format.pyx
index 3596f5eef1..152bd4325d 100644
--- a/sklearn/datasets/_svmlight_format.pyx
+++ b/sklearn/datasets/_svmlight_format.pyx
@@ -26,7 +26,7 @@ cdef bytes COLON = u':'.encode('ascii')
 @cython.boundscheck(False)
 @cython.wraparound(False)
 def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
-                        bint query_id):
+                        bint query_id, long long offset, long long length):
     cdef array.array data, indices, indptr
     cdef bytes line
     cdef char *hash_ptr
@@ -35,6 +35,7 @@ def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
     cdef Py_ssize_t i
     cdef bytes qid_prefix = b('qid')
     cdef Py_ssize_t n_features
+    cdef long long offset_max = offset + length if length > 0 else -1
 
     # Special-case float32 but use float64 for everything else;
     # the Python code will do further conversions.
@@ -52,6 +53,12 @@ def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
     else:
         labels = array.array("d")
 
+    if offset > 0:
+        f.seek(offset)
+        # drop the current line that might be truncated and is to be
+        # fetched by another call
+        f.readline()
+
     for line in f:
         # skip comments
         line_cstr = line
@@ -90,7 +97,7 @@ def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
             idx = int(idx_s)
             if idx < 0 or not zero_based and idx == 0:
                 raise ValueError(
-                        "Invalid index %d in SVMlight/LibSVM data file." % idx)
+                    "Invalid index %d in SVMlight/LibSVM data file." % idx)
             if idx <= prev_idx:
                 raise ValueError("Feature indices in SVMlight/LibSVM data "
                                  "file should be sorted and unique.")
@@ -106,4 +113,8 @@ def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
         array.resize_smart(indptr, len(indptr) + 1)
         indptr[len(indptr) - 1] = len(data)
 
+        if offset_max != -1 and f.tell() > offset_max:
+            # Stop here and let another call deal with the following.
+            break
+
     return (dtype, data, indices, indptr, labels, query)
diff --git a/sklearn/datasets/svmlight_format.py b/sklearn/datasets/svmlight_format.py
index a567e2091e..c919dc8c0a 100644
--- a/sklearn/datasets/svmlight_format.py
+++ b/sklearn/datasets/svmlight_format.py
@@ -31,7 +31,8 @@ from ..utils import check_array
 
 
 def load_svmlight_file(f, n_features=None, dtype=np.float64,
-                       multilabel=False, zero_based="auto", query_id=False):
+                       multilabel=False, zero_based="auto", query_id=False,
+                       offset=0, length=-1):
     """Load datasets in the svmlight / libsvm format into sparse CSR matrix
 
     This format is a text-based format, with one sample per line. It does
@@ -76,6 +77,8 @@ def load_svmlight_file(f, n_features=None, dtype=np.float64,
         bigger sliced dataset: each subset might not have examples of
         every feature, hence the inferred shape might vary from one
         slice to another.
+        n_features is only required if ``offset`` or ``length`` are passed a
+        non-default value.
 
     multilabel : boolean, optional, default False
         Samples may have several labels each (see
@@ -88,7 +91,10 @@ def load_svmlight_file(f, n_features=None, dtype=np.float64,
         If set to "auto", a heuristic check is applied to determine this from
         the file contents. Both kinds of files occur "in the wild", but they
         are unfortunately not self-identifying. Using "auto" or True should
-        always be safe.
+        always be safe when no ``offset`` or ``length`` is passed.
+        If ``offset`` or ``length`` are passed, the "auto" mode falls back
+        to ``zero_based=True`` to avoid having the heuristic check yield
+        inconsistent results on different segments of the file.
 
     query_id : boolean, default False
         If True, will return the query_id array for each file.
@@ -97,6 +103,15 @@ def load_svmlight_file(f, n_features=None, dtype=np.float64,
         Data type of dataset to be loaded. This will be the data type of the
         output numpy arrays ``X`` and ``y``.
 
+    offset : integer, optional, default 0
+        Ignore the offset first bytes by seeking forward, then
+        discarding the following bytes up until the next new line
+        character.
+
+    length : integer, optional, default -1
+        If strictly positive, stop reading any new line of data once the
+        position in the file has reached the (offset + length) bytes threshold.
+
     Returns
     -------
     X : scipy.sparse matrix of shape (n_samples, n_features)
@@ -129,7 +144,7 @@ def load_svmlight_file(f, n_features=None, dtype=np.float64,
         X, y = get_data()
     """
     return tuple(load_svmlight_files([f], n_features, dtype, multilabel,
-                                     zero_based, query_id))
+                                     zero_based, query_id, offset, length))
 
 
 def _gen_open(f):
@@ -149,15 +164,18 @@ def _gen_open(f):
         return open(f, "rb")
 
 
-def _open_and_load(f, dtype, multilabel, zero_based, query_id):
+def _open_and_load(f, dtype, multilabel, zero_based, query_id,
+                   offset=0, length=-1):
     if hasattr(f, "read"):
         actual_dtype, data, ind, indptr, labels, query = \
-            _load_svmlight_file(f, dtype, multilabel, zero_based, query_id)
+            _load_svmlight_file(f, dtype, multilabel, zero_based, query_id,
+                                offset, length)
     # XXX remove closing when Python 2.7+/3.1+ required
     else:
         with closing(_gen_open(f)) as f:
             actual_dtype, data, ind, indptr, labels, query = \
-                _load_svmlight_file(f, dtype, multilabel, zero_based, query_id)
+                _load_svmlight_file(f, dtype, multilabel, zero_based, query_id,
+                                    offset, length)
 
     # convert from array.array, give data the right dtype
     if not multilabel:
@@ -172,7 +190,8 @@ def _open_and_load(f, dtype, multilabel, zero_based, query_id):
 
 
 def load_svmlight_files(files, n_features=None, dtype=np.float64,
-                        multilabel=False, zero_based="auto", query_id=False):
+                        multilabel=False, zero_based="auto", query_id=False,
+                        offset=0, length=-1):
     """Load dataset from multiple files in SVMlight format
 
     This function is equivalent to mapping load_svmlight_file over a list of
@@ -216,7 +235,10 @@ def load_svmlight_files(files, n_features=None, dtype=np.float64,
         If set to "auto", a heuristic check is applied to determine this from
         the file contents. Both kinds of files occur "in the wild", but they
         are unfortunately not self-identifying. Using "auto" or True should
-        always be safe.
+        always be safe when no offset or length is passed.
+        If offset or length are passed, the "auto" mode falls back
+        to zero_based=True to avoid having the heuristic check yield
+        inconsistent results on different segments of the file.
 
     query_id : boolean, defaults to False
         If True, will return the query_id array for each file.
@@ -225,6 +247,15 @@ def load_svmlight_files(files, n_features=None, dtype=np.float64,
         Data type of dataset to be loaded. This will be the data type of the
         output numpy arrays ``X`` and ``y``.
 
+    offset : integer, optional, default 0
+        Ignore the offset first bytes by seeking forward, then
+        discarding the following bytes up until the next new line
+        character.
+
+    length : integer, optional, default -1
+        If strictly positive, stop reading any new line of data once the
+        position in the file has reached the (offset + length) bytes threshold.
+
     Returns
     -------
     [X1, y1, ..., Xn, yn]
@@ -245,16 +276,27 @@ def load_svmlight_files(files, n_features=None, dtype=np.float64,
     --------
     load_svmlight_file
     """
-    r = [_open_and_load(f, dtype, multilabel, bool(zero_based), bool(query_id))
+    if (offset != 0 or length > 0) and zero_based == "auto":
+        # disable heuristic search to avoid getting inconsistent results on
+        # different segments of the file
+        zero_based = True
+
+    if (offset != 0 or length > 0) and n_features is None:
+        raise ValueError(
+            "n_features is required when offset or length is specified.")
+
+    r = [_open_and_load(f, dtype, multilabel, bool(zero_based), bool(query_id),
+                        offset=offset, length=length)
          for f in files]
 
-    if (zero_based is False
-            or zero_based == "auto" and all(np.min(tmp[1]) > 0 for tmp in r)):
-        for ind in r:
-            indices = ind[1]
+    if (zero_based is False or
+            zero_based == "auto" and all(len(tmp[1]) and np.min(tmp[1]) > 0
+                                         for tmp in r)):
+        for _, indices, _, _, _ in r:
             indices -= 1
 
-    n_f = max(ind[1].max() for ind in r) + 1
+    n_f = max(ind[1].max() if len(ind[1]) else 0 for ind in r) + 1
+
     if n_features is None:
         n_features = n_f
     elif n_features < n_f:
diff --git a/sklearn/datasets/tests/test_svmlight_format.py b/sklearn/datasets/tests/test_svmlight_format.py
index 47d956d1ee..c98206065f 100644
--- a/sklearn/datasets/tests/test_svmlight_format.py
+++ b/sklearn/datasets/tests/test_svmlight_format.py
@@ -1,3 +1,4 @@
+from __future__ import division
 from bz2 import BZ2File
 import gzip
 from io import BytesIO
@@ -13,8 +14,10 @@ from sklearn.utils.testing import assert_equal
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_array_almost_equal
 from sklearn.utils.testing import assert_raises
+from sklearn.utils.testing import assert_raises_regex
 from sklearn.utils.testing import raises
 from sklearn.utils.testing import assert_in
+from sklearn.utils.fixes import sp_version
 
 import sklearn
 from sklearn.datasets import (load_svmlight_file, load_svmlight_files,
@@ -401,4 +404,107 @@ def test_load_with_long_qid():
     f.seek(0)
     X, y = load_svmlight_file(f, query_id=False, zero_based=True)
     assert_array_equal(y, true_y)
-    assert_array_equal(X.toarray(), true_X)
\ No newline at end of file
+    assert_array_equal(X.toarray(), true_X)
+
+
+def test_load_zeros():
+    f = BytesIO()
+    true_X = sp.csr_matrix(np.zeros(shape=(3, 4)))
+    true_y = np.array([0, 1, 0])
+    dump_svmlight_file(true_X, true_y, f)
+
+    for zero_based in ['auto', True, False]:
+        f.seek(0)
+        X, y = load_svmlight_file(f, n_features=4, zero_based=zero_based)
+        assert_array_equal(y, true_y)
+        assert_array_equal(X.toarray(), true_X.toarray())
+
+
+def test_load_with_offsets():
+    def check_load_with_offsets(sparsity, n_samples, n_features):
+        rng = np.random.RandomState(0)
+        X = rng.uniform(low=0.0, high=1.0, size=(n_samples, n_features))
+        if sparsity:
+            X[X < sparsity] = 0.0
+        X = sp.csr_matrix(X)
+        y = rng.randint(low=0, high=2, size=n_samples)
+
+        f = BytesIO()
+        dump_svmlight_file(X, y, f)
+        f.seek(0)
+
+        size = len(f.getvalue())
+
+        # put some marks that are likely to happen anywhere in a row
+        mark_0 = 0
+        mark_1 = size // 3
+        length_0 = mark_1 - mark_0
+        mark_2 = 4 * size // 5
+        length_1 = mark_2 - mark_1
+
+        # load the original sparse matrix into 3 independant CSR matrices
+        X_0, y_0 = load_svmlight_file(f, n_features=n_features,
+                                      offset=mark_0, length=length_0)
+        X_1, y_1 = load_svmlight_file(f, n_features=n_features,
+                                      offset=mark_1, length=length_1)
+        X_2, y_2 = load_svmlight_file(f, n_features=n_features,
+                                      offset=mark_2)
+
+        y_concat = np.concatenate([y_0, y_1, y_2])
+        X_concat = sp.vstack([X_0, X_1, X_2])
+        assert_array_equal(y, y_concat)
+        assert_array_almost_equal(X.toarray(), X_concat.toarray())
+
+    # Generate a uniformly random sparse matrix
+    for sparsity in [0, 0.1, .5, 0.99, 1]:
+        for n_samples in [13, 101]:
+            for n_features in [2, 7, 41]:
+                yield check_load_with_offsets, sparsity, n_samples, n_features
+
+
+def test_load_offset_exhaustive_splits():
+    rng = np.random.RandomState(0)
+    X = np.array([
+        [0, 0, 0, 0, 0, 0],
+        [1, 2, 3, 4, 0, 6],
+        [1, 2, 3, 4, 0, 6],
+        [0, 0, 0, 0, 0, 0],
+        [1, 0, 3, 0, 0, 0],
+        [0, 0, 0, 0, 0, 1],
+        [1, 0, 0, 0, 0, 0],
+    ])
+    X = sp.csr_matrix(X)
+    n_samples, n_features = X.shape
+    y = rng.randint(low=0, high=2, size=n_samples)
+    query_id = np.arange(n_samples) // 2
+
+    f = BytesIO()
+    dump_svmlight_file(X, y, f, query_id=query_id)
+    f.seek(0)
+
+    size = len(f.getvalue())
+
+    # load the same data in 2 parts with all the possible byte offsets to
+    # locate the split so has to test for particular boundary cases
+    for mark in range(size):
+        if sp_version < (0, 14) and (mark == 0 or mark > size - 100):
+            # old scipy does not support sparse matrices with 0 rows.
+            continue
+        f.seek(0)
+        X_0, y_0, q_0 = load_svmlight_file(f, n_features=n_features,
+                                           query_id=True, offset=0,
+                                           length=mark)
+        X_1, y_1, q_1 = load_svmlight_file(f, n_features=n_features,
+                                           query_id=True, offset=mark,
+                                           length=-1)
+        q_concat = np.concatenate([q_0, q_1])
+        y_concat = np.concatenate([y_0, y_1])
+        X_concat = sp.vstack([X_0, X_1])
+        assert_array_equal(y, y_concat)
+        assert_array_equal(query_id, q_concat)
+        assert_array_almost_equal(X.toarray(), X_concat.toarray())
+
+
+def test_load_with_offsets_error():
+    assert_raises_regex(ValueError, "n_features is required",
+                        load_svmlight_file, datafile, offset=3, length=3)
-- 
GitLab