From a39c8ab75c3a2685bc3a038ee9d9e5e9fb4970b8 Mon Sep 17 00:00:00 2001
From: Olivier Grisel <olivier.grisel@ensta.org>
Date: Fri, 16 Jun 2017 18:11:47 +0200
Subject: [PATCH] ENH svmlight chunk loader (#935)

---
 doc/whats_new.rst                          |   5 +
 sklearn/datasets/_svmlight_format.pyx      |  15 ++-
 sklearn/datasets/svmlight_format.py        |  70 +++++++++---
 .../datasets/tests/test_svmlight_format.py | 108 +++++++++++++++++-
 4 files changed, 181 insertions(+), 17 deletions(-)

diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 6792906bb9..70310366d0 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -202,11 +202,16 @@ Enhancements

    - Prevent cast from float32 to float64 in
-     :class:`linear_model.Ridge` when using svd, sparse_cg, cholesky or lsqr solvers
+     :class:`sklearn.linear_model.Ridge` when using svd, sparse_cg, cholesky or lsqr solvers
      by :user:`Joan Massich <massich>`, :user:`Nicolas Cordier <ncordier>`

    - Add ``max_train_size`` parameter to :class:`model_selection.TimeSeriesSplit`
      :issue:`8282` by :user:`Aman Dalmia <dalmia>`.

+   - Make it possible to load a chunk of an svmlight formatted file by
+     passing a range of bytes to :func:`datasets.load_svmlight_file`.
+     :issue:`935` by :user:`Olivier Grisel <ogrisel>`.
+
 Bug fixes
 .........

diff --git a/sklearn/datasets/_svmlight_format.pyx b/sklearn/datasets/_svmlight_format.pyx
index 3596f5eef1..152bd4325d 100644
--- a/sklearn/datasets/_svmlight_format.pyx
+++ b/sklearn/datasets/_svmlight_format.pyx
@@ -26,7 +26,7 @@ cdef bytes COLON = u':'.encode('ascii')
 @cython.boundscheck(False)
 @cython.wraparound(False)
 def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
-                        bint query_id):
+                        bint query_id, long long offset, long long length):
     cdef array.array data, indices, indptr
     cdef bytes line
     cdef char *hash_ptr
@@ -35,6 +35,7 @@ def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
     cdef Py_ssize_t i
     cdef bytes qid_prefix = b('qid')
     cdef Py_ssize_t n_features
+    cdef long long offset_max = offset + length if length > 0 else -1

     # Special-case float32 but use float64 for everything else;
     # the Python code will do further conversions.
@@ -52,6 +53,12 @@ def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
     else:
         labels = array.array("d")

+    if offset > 0:
+        f.seek(offset)
+        # drop the current line that might be truncated and is to be
+        # fetched by another call
+        f.readline()
+
     for line in f:
         # skip comments
         line_cstr = line
@@ -90,7 +97,7 @@ def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
                 idx = int(idx_s)
                 if idx < 0 or not zero_based and idx == 0:
                     raise ValueError(
-                    "Invalid index %d in SVMlight/LibSVM data file." % idx)
+                        "Invalid index %d in SVMlight/LibSVM data file." % idx)
                 if idx <= prev_idx:
                     raise ValueError("Feature indices in SVMlight/LibSVM data "
                                      "file should be sorted and unique.")
@@ -106,4 +113,8 @@ def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
         array.resize_smart(indptr, len(indptr) + 1)
         indptr[len(indptr) - 1] = len(data)

+        if offset_max != -1 and f.tell() > offset_max:
+            # Stop here and let another call deal with the following.
+            break
+
     return (dtype, data, indices, indptr, labels, query)
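The chunking contract implemented by the Cython loop above can be stated in pure Python: a chunk that does not start at byte 0 first discards the (possibly truncated) line it lands on, and a bounded chunk keeps reading past its end mark until it finishes the line it already started. Any line that crosses or starts exactly at a boundary is therefore owned by the earlier chunk, so no line is read twice or skipped. A minimal pure-Python sketch of this protocol (illustrative only, not the shipped Cython code)::

    def read_chunk_lines(path, offset, length):
        # length <= 0 means "read through to the end of the file"
        offset_max = offset + length if length > 0 else -1
        with open(path, 'rb') as f:
            if offset > 0:
                f.seek(offset)
                # drop the line the offset lands on: it belongs to the
                # chunk that started before this one
                f.readline()
            for line in f:
                yield line
                if offset_max != -1 and f.tell() > offset_max:
                    # the remaining lines belong to the next chunk
                    break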
diff --git a/sklearn/datasets/svmlight_format.py b/sklearn/datasets/svmlight_format.py
index a567e2091e..c919dc8c0a 100644
--- a/sklearn/datasets/svmlight_format.py
+++ b/sklearn/datasets/svmlight_format.py
@@ -31,7 +31,8 @@ from ..utils import check_array


 def load_svmlight_file(f, n_features=None, dtype=np.float64,
-                       multilabel=False, zero_based="auto", query_id=False):
+                       multilabel=False, zero_based="auto", query_id=False,
+                       offset=0, length=-1):
     """Load datasets in the svmlight / libsvm format into sparse CSR matrix

     This format is a text-based format, with one sample per line. It does
@@ -76,6 +77,8 @@ def load_svmlight_file(f, n_features=None, dtype=np.float64,
         bigger sliced dataset: each subset might not have examples of
         every feature, hence the inferred shape might vary from one
         slice to another.
+        n_features is only required if ``offset`` or ``length`` are passed a
+        non-default value.

     multilabel : boolean, optional, default False
         Samples may have several labels each (see
@@ -88,7 +91,10 @@ def load_svmlight_file(f, n_features=None, dtype=np.float64,
         If set to "auto", a heuristic check is applied to determine this
         from the file contents. Both kinds of files occur "in the wild",
         but they are unfortunately not self-identifying. Using "auto" or
-        True should always be safe.
+        True should always be safe when no ``offset`` or ``length`` is passed.
+        If ``offset`` or ``length`` are passed, the "auto" mode falls back
+        to ``zero_based=True`` to avoid having the heuristic check yield
+        inconsistent results on different segments of the file.

     query_id : boolean, default False
         If True, will return the query_id array for each file.
@@ -97,6 +103,15 @@ def load_svmlight_file(f, n_features=None, dtype=np.float64,
         Data type of dataset to be loaded. This will be the data type of the
         output numpy arrays ``X`` and ``y``.

+    offset : integer, optional, default 0
+        Ignore the offset first bytes by seeking forward, then
+        discarding the following bytes up until the next new line
+        character.
+
+    length : integer, optional, default -1
+        If strictly positive, stop reading any new line of data once the
+        position in the file has reached the (offset + length) bytes threshold.
+
     Returns
     -------
     X : scipy.sparse matrix of shape (n_samples, n_features)
@@ -129,7 +144,7 @@ def load_svmlight_file(f, n_features=None, dtype=np.float64,
         X, y = get_data()
     """
     return tuple(load_svmlight_files([f], n_features, dtype, multilabel,
-                                     zero_based, query_id))
+                                     zero_based, query_id, offset, length))


 def _gen_open(f):
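For reference, a usage sketch of the new parameters from the public API; the file name and byte positions below are hypothetical, and ``n_features`` must be given because no single chunk sees the whole file::

    import numpy as np
    import scipy.sparse as sp
    from sklearn.datasets import load_svmlight_file

    # all complete lines starting within the first 500000 bytes
    X_0, y_0 = load_svmlight_file("data.svm", n_features=100,
                                  offset=0, length=500000)
    # everything afterwards (the default length=-1 reads to end of file)
    X_1, y_1 = load_svmlight_file("data.svm", n_features=100,
                                  offset=500000)

    # the two chunks partition the samples, so stacking them restores
    # the result of a single full load
    X = sp.vstack([X_0, X_1])
    y = np.concatenate([y_0, y_1])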
@@ -149,15 +164,18 @@ def _gen_open(f):
         return open(f, "rb")


-def _open_and_load(f, dtype, multilabel, zero_based, query_id):
+def _open_and_load(f, dtype, multilabel, zero_based, query_id,
+                   offset=0, length=-1):
     if hasattr(f, "read"):
         actual_dtype, data, ind, indptr, labels, query = \
-            _load_svmlight_file(f, dtype, multilabel, zero_based, query_id)
+            _load_svmlight_file(f, dtype, multilabel, zero_based, query_id,
+                                offset, length)
     # XXX remove closing when Python 2.7+/3.1+ required
     else:
         with closing(_gen_open(f)) as f:
             actual_dtype, data, ind, indptr, labels, query = \
-                _load_svmlight_file(f, dtype, multilabel, zero_based, query_id)
+                _load_svmlight_file(f, dtype, multilabel, zero_based, query_id,
+                                    offset, length)

     # convert from array.array, give data the right dtype
     if not multilabel:
@@ -172,7 +190,8 @@ def _open_and_load(f, dtype, multilabel, zero_based, query_id):


 def load_svmlight_files(files, n_features=None, dtype=np.float64,
-                        multilabel=False, zero_based="auto", query_id=False):
+                        multilabel=False, zero_based="auto", query_id=False,
+                        offset=0, length=-1):
     """Load dataset from multiple files in SVMlight format

     This function is equivalent to mapping load_svmlight_file over a list of
@@ -216,7 +235,10 @@ def load_svmlight_files(files, n_features=None, dtype=np.float64,
         If set to "auto", a heuristic check is applied to determine this
         from the file contents. Both kinds of files occur "in the wild",
         but they are unfortunately not self-identifying. Using "auto" or
-        True should always be safe.
+        True should always be safe when no offset or length is passed.
+        If offset or length are passed, the "auto" mode falls back
+        to zero_based=True to avoid having the heuristic check yield
+        inconsistent results on different segments of the file.

     query_id : boolean, defaults to False
         If True, will return the query_id array for each file.
@@ -225,6 +247,15 @@ def load_svmlight_files(files, n_features=None, dtype=np.float64,
         Data type of dataset to be loaded. This will be the data type of the
         output numpy arrays ``X`` and ``y``.

+    offset : integer, optional, default 0
+        Ignore the offset first bytes by seeking forward, then
+        discarding the following bytes up until the next new line
+        character.
+
+    length : integer, optional, default -1
+        If strictly positive, stop reading any new line of data once the
+        position in the file has reached the (offset + length) bytes threshold.
+
     Returns
     -------
     [X1, y1, ..., Xn, yn]

     where each (Xi, yi) pair is the result from load_svmlight_file(files[i]).

     See also
     --------
     load_svmlight_file
     """
-    r = [_open_and_load(f, dtype, multilabel, bool(zero_based), bool(query_id))
+    if (offset != 0 or length > 0) and zero_based == "auto":
+        # disable heuristic search to avoid getting inconsistent results on
+        # different segments of the file
+        zero_based = True
+
+    if (offset != 0 or length > 0) and n_features is None:
+        raise ValueError(
+            "n_features is required when offset or length is specified.")
+
+    r = [_open_and_load(f, dtype, multilabel, bool(zero_based), bool(query_id),
+                        offset=offset, length=length)
          for f in files]

-    if (zero_based is False
-            or zero_based == "auto" and all(np.min(tmp[1]) > 0 for tmp in r)):
-        for ind in r:
-            indices = ind[1]
+    if (zero_based is False or
+            zero_based == "auto" and all(len(tmp[1]) and np.min(tmp[1]) > 0
+                                         for tmp in r)):
+        for _, indices, _, _, _ in r:
             indices -= 1

-    n_f = max(ind[1].max() for ind in r) + 1
+    n_f = max(ind[1].max() if len(ind[1]) else 0 for ind in r) + 1
+
     if n_features is None:
         n_features = n_f
     elif n_features < n_f:
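Both guards added at the top of ``load_svmlight_files`` above are visible from the public API. A short sketch of the two behaviours (hypothetical file name)::

    from sklearn.datasets import load_svmlight_file

    # 1) chunked loading without n_features is rejected, since each chunk
    #    would otherwise infer its own number of columns:
    try:
        load_svmlight_file("data.svm", offset=100, length=1000)
    except ValueError as e:
        print(e)  # n_features is required when offset or length is specified.

    # 2) with a byte range, zero_based="auto" silently falls back to
    #    zero_based=True; a file known to use one-based indices should be
    #    loaded with an explicit zero_based=False, which is applied
    #    consistently to every chunk:
    X, y = load_svmlight_file("data.svm", n_features=100,
                              offset=100, length=1000, zero_based=False)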
diff --git a/sklearn/datasets/tests/test_svmlight_format.py b/sklearn/datasets/tests/test_svmlight_format.py
index 47d956d1ee..c98206065f 100644
--- a/sklearn/datasets/tests/test_svmlight_format.py
+++ b/sklearn/datasets/tests/test_svmlight_format.py
@@ -1,3 +1,4 @@
+from __future__ import division
 from bz2 import BZ2File
 import gzip
 from io import BytesIO
@@ -13,8 +14,10 @@ from sklearn.utils.testing import assert_equal
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_array_almost_equal
 from sklearn.utils.testing import assert_raises
+from sklearn.utils.testing import assert_raises_regex
 from sklearn.utils.testing import raises
 from sklearn.utils.testing import assert_in
+from sklearn.utils.fixes import sp_version

 import sklearn
 from sklearn.datasets import (load_svmlight_file, load_svmlight_files,
@@ -401,4 +404,107 @@ def test_load_with_long_qid():
         f.seek(0)
         X, y = load_svmlight_file(f, query_id=False, zero_based=True)
         assert_array_equal(y, true_y)
-        assert_array_equal(X.toarray(), true_X)
\ No newline at end of file
+        assert_array_equal(X.toarray(), true_X)
+
+
+def test_load_zeros():
+    f = BytesIO()
+    true_X = sp.csr_matrix(np.zeros(shape=(3, 4)))
+    true_y = np.array([0, 1, 0])
+    dump_svmlight_file(true_X, true_y, f)
+
+    for zero_based in ['auto', True, False]:
+        f.seek(0)
+        X, y = load_svmlight_file(f, n_features=4, zero_based=zero_based)
+        assert_array_equal(y, true_y)
+        assert_array_equal(X.toarray(), true_X.toarray())
+
+
+def test_load_with_offsets():
+    def check_load_with_offsets(sparsity, n_samples, n_features):
+        rng = np.random.RandomState(0)
+        X = rng.uniform(low=0.0, high=1.0, size=(n_samples, n_features))
+        if sparsity:
+            X[X < sparsity] = 0.0
+        X = sp.csr_matrix(X)
+        y = rng.randint(low=0, high=2, size=n_samples)
+
+        f = BytesIO()
+        dump_svmlight_file(X, y, f)
+        f.seek(0)
+
+        size = len(f.getvalue())
+
+        # put some marks that are likely to fall anywhere in a row
+        mark_0 = 0
+        mark_1 = size // 3
+        length_0 = mark_1 - mark_0
+        mark_2 = 4 * size // 5
+        length_1 = mark_2 - mark_1
+
+        # load the original sparse matrix into 3 independent CSR matrices
+        X_0, y_0 = load_svmlight_file(f, n_features=n_features,
+                                      offset=mark_0, length=length_0)
+        X_1, y_1 = load_svmlight_file(f, n_features=n_features,
+                                      offset=mark_1, length=length_1)
+        X_2, y_2 = load_svmlight_file(f, n_features=n_features,
+                                      offset=mark_2)
+
+        y_concat = np.concatenate([y_0, y_1, y_2])
+        X_concat = sp.vstack([X_0, X_1, X_2])
+        assert_array_equal(y, y_concat)
+        assert_array_almost_equal(X.toarray(), X_concat.toarray())
+
+    # try several sparsity levels and shapes of uniformly random matrices
+    # so that the byte marks fall at many different positions within a row
+    for sparsity in [0, 0.1, .5, 0.99, 1]:
+        for n_samples in [13, 101]:
+            for n_features in [2, 7, 41]:
+                yield check_load_with_offsets, sparsity, n_samples, n_features
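The tests take the total size from ``f.getvalue()`` because they operate on in-memory buffers; for a file on disk, equivalent byte marks could be derived from ``os.stat``. A hypothetical helper along these lines (not part of this patch)::

    import os

    def chunk_ranges(path, n_chunks):
        # Evenly spaced byte marks; thanks to the loader's
        # skip-partial-line rule they need not fall on line boundaries.
        size = os.stat(path).st_size
        marks = [i * size // n_chunks for i in range(n_chunks)]
        lengths = [b - a for a, b in zip(marks, marks[1:])] + [-1]
        # the last chunk uses length=-1 to read through to end of file
        return list(zip(marks, lengths))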
+
+
+def test_load_offset_exhaustive_splits():
+    rng = np.random.RandomState(0)
+    X = np.array([
+        [0, 0, 0, 0, 0, 0],
+        [1, 2, 3, 4, 0, 6],
+        [1, 2, 3, 4, 0, 6],
+        [0, 0, 0, 0, 0, 0],
+        [1, 0, 3, 0, 0, 0],
+        [0, 0, 0, 0, 0, 1],
+        [1, 0, 0, 0, 0, 0],
+    ])
+    X = sp.csr_matrix(X)
+    n_samples, n_features = X.shape
+    y = rng.randint(low=0, high=2, size=n_samples)
+    query_id = np.arange(n_samples) // 2
+
+    f = BytesIO()
+    dump_svmlight_file(X, y, f, query_id=query_id)
+    f.seek(0)
+
+    size = len(f.getvalue())
+
+    # load the same data in 2 parts with all the possible byte offsets to
+    # locate the split so as to test for particular boundary cases
+    for mark in range(size):
+        if sp_version < (0, 14) and (mark == 0 or mark > size - 100):
+            # old scipy does not support sparse matrices with 0 rows.
+            continue
+        f.seek(0)
+        X_0, y_0, q_0 = load_svmlight_file(f, n_features=n_features,
+                                           query_id=True, offset=0,
+                                           length=mark)
+        X_1, y_1, q_1 = load_svmlight_file(f, n_features=n_features,
+                                           query_id=True, offset=mark,
+                                           length=-1)
+        q_concat = np.concatenate([q_0, q_1])
+        y_concat = np.concatenate([y_0, y_1])
+        X_concat = sp.vstack([X_0, X_1])
+        assert_array_equal(y, y_concat)
+        assert_array_equal(query_id, q_concat)
+        assert_array_almost_equal(X.toarray(), X_concat.toarray())
+
+
+def test_load_with_offsets_error():
+    assert_raises_regex(ValueError, "n_features is required",
+                        load_svmlight_file, datafile, offset=3, length=3)
-- 
GitLab
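One natural application of byte-range loading, not exercised by the patch itself, is reading a single large file in parallel, one chunk per worker. A sketch assuming joblib is installed (the function name is hypothetical)::

    import os

    import numpy as np
    import scipy.sparse as sp
    from joblib import Parallel, delayed
    from sklearn.datasets import load_svmlight_file

    def load_svmlight_parallel(path, n_features, n_jobs=4):
        # one byte range per worker; the last range reads to end of file
        size = os.stat(path).st_size
        marks = [i * size // n_jobs for i in range(n_jobs)]
        lengths = [b - a for a, b in zip(marks, marks[1:])] + [-1]
        results = Parallel(n_jobs=n_jobs)(
            delayed(load_svmlight_file)(path, n_features=n_features,
                                        offset=offset, length=length)
            for offset, length in zip(marks, lengths))
        # a range that contains no complete line yields an empty matrix,
        # which recent scipy versions stack without complaint
        X = sp.vstack([X_i for X_i, _ in results])
        y = np.concatenate([y_i for _, y_i in results])
        return X, y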