Skip to content
Snippets Groups Projects
Commit e1ba1108 authored by Lars Buitinck
Browse files

ENH make generated SVMlight files self-describing in a comment

Comment is *not* used to parse the files, since that would break
compatibility with other tools; it's just there for the user.
parent 2530791e
Branches
Tags
No related merge requests found
...@@ -25,6 +25,7 @@ import numpy as np ...@@ -25,6 +25,7 @@ import numpy as np
import scipy.sparse as sp import scipy.sparse as sp
from ._svmlight_format import _load_svmlight_file from ._svmlight_format import _load_svmlight_file
from .. import __version__
from ..utils import atleast2d_or_csr from ..utils import atleast2d_or_csr
...@@ -200,6 +201,10 @@ def _dump_svmlight(X, y, f, one_based): ...@@ -200,6 +201,10 @@ def _dump_svmlight(X, y, f, one_based):
else: else:
line_pattern = u"%f %s\n" line_pattern = u"%f %s\n"
f.write("# Generated by dump_svmlight_file from scikit-learn %s\n"
% __version__)
f.write("# Column indices are %s-based\n" % ["zero", "one"][one_based])
for i in xrange(X.shape[0]): for i in xrange(X.shape[0]):
s = u" ".join([value_pattern % (j + one_based, X[i, j]) s = u" ".join([value_pattern % (j + one_based, X[i, j])
for j in X[i].nonzero()[is_sp]]) for j in X[i].nonzero()[is_sp]])
......
...@@ -11,8 +11,10 @@ from numpy.testing import assert_array_equal ...@@ -11,8 +11,10 @@ from numpy.testing import assert_array_equal
from numpy.testing import assert_array_almost_equal from numpy.testing import assert_array_almost_equal
from nose.tools import assert_raises, raises from nose.tools import assert_raises, raises
import sklearn
from sklearn.datasets import (load_svmlight_file, load_svmlight_files, from sklearn.datasets import (load_svmlight_file, load_svmlight_files,
dump_svmlight_file) dump_svmlight_file)
from sklearn.utils.testing import assert_in
currdir = os.path.dirname(os.path.abspath(__file__)) currdir = os.path.dirname(os.path.abspath(__file__))
datafile = os.path.join(currdir, "data", "svmlight_classification.txt") datafile = os.path.join(currdir, "data", "svmlight_classification.txt")
...@@ -175,6 +177,12 @@ def test_dump(): ...@@ -175,6 +177,12 @@ def test_dump():
dump_svmlight_file(X.astype(dtype), y, f, dump_svmlight_file(X.astype(dtype), y, f,
zero_based=zero_based) zero_based=zero_based)
f.seek(0) f.seek(0)
comment = f.readline()
assert_in("scikit-learn %s" % sklearn.__version__, comment)
comment = f.readline()
assert_in(["one", "zero"][zero_based] + "-based", comment)
X2, y2 = load_svmlight_file(f, dtype=dtype, X2, y2 = load_svmlight_file(f, dtype=dtype,
zero_based=zero_based) zero_based=zero_based)
assert_equal(X2.dtype, dtype) assert_equal(X2.dtype, dtype)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment