diff --git a/sklearn/datasets/svmlight_format.py b/sklearn/datasets/svmlight_format.py index f3fe0f84daa290e33d791155a1f34d659938244c..06ab8d5b73711110db39643a0eeedc8782fc0379 100644 --- a/sklearn/datasets/svmlight_format.py +++ b/sklearn/datasets/svmlight_format.py @@ -25,6 +25,7 @@ import numpy as np import scipy.sparse as sp from ._svmlight_format import _load_svmlight_file +from .. import __version__ from ..utils import atleast2d_or_csr @@ -200,6 +201,10 @@ def _dump_svmlight(X, y, f, one_based): else: line_pattern = u"%f %s\n" + f.write("# Generated by dump_svmlight_file from scikit-learn %s\n" + % __version__) + f.write("# Column indices are %s-based\n" % ["zero", "one"][one_based]) + for i in xrange(X.shape[0]): s = u" ".join([value_pattern % (j + one_based, X[i, j]) for j in X[i].nonzero()[is_sp]]) diff --git a/sklearn/datasets/tests/test_svmlight_format.py b/sklearn/datasets/tests/test_svmlight_format.py index c02c7cfa2a07c16cebf83764a81daf5e49ed7c1b..db3c53d57f57514ca1dfe6af81b817f89478788a 100644 --- a/sklearn/datasets/tests/test_svmlight_format.py +++ b/sklearn/datasets/tests/test_svmlight_format.py @@ -11,8 +11,10 @@ from numpy.testing import assert_array_equal from numpy.testing import assert_array_almost_equal from nose.tools import assert_raises, raises +import sklearn from sklearn.datasets import (load_svmlight_file, load_svmlight_files, dump_svmlight_file) +from sklearn.utils.testing import assert_in currdir = os.path.dirname(os.path.abspath(__file__)) datafile = os.path.join(currdir, "data", "svmlight_classification.txt") @@ -175,6 +177,12 @@ def test_dump(): dump_svmlight_file(X.astype(dtype), y, f, zero_based=zero_based) f.seek(0) + + comment = f.readline() + assert_in("scikit-learn %s" % sklearn.__version__, comment) + comment = f.readline() + assert_in(["one", "zero"][zero_based] + "-based", comment) + X2, y2 = load_svmlight_file(f, dtype=dtype, zero_based=zero_based) assert_equal(X2.dtype, dtype)