diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 8863918e5302a2e37c929eeffae747da0fac5be0..88b8d47512f61302d0d7c07a7b5729102e3051ee 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -588,6 +588,9 @@ Enhancements - :class:`svm.SVC` fitted on sparse input now implements ``decision_function``. By `Rob Zinkov`_ and `Andreas Müller`_. + - :func:`cross_validation.train_test_split` now preserves the input type, + instead of converting to numpy arrays. + Documentation improvements .......................... diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py index 259d7320026baa23ccdb2e76fbacc59c2ccaae76..1b172a2ed001475a2f4f0de0f55ad1d279feca89 100644 --- a/sklearn/cross_validation.py +++ b/sklearn/cross_validation.py @@ -1792,9 +1792,13 @@ def train_test_split(*arrays, **options): Parameters ---------- - *arrays : sequence of arrays or scipy.sparse matrices with same shape[0] - Python lists or tuples occurring in arrays are converted to 1D numpy - arrays. + *arrays : sequence of indexables with same length / shape[0] + + Allowed inputs are lists, numpy arrays, scipy-sparse + matrices or pandas dataframes. + + .. versionadded:: 0.16 + Preserves input type instead of always casting to a numpy array. test_size : float, int, or None (default is None) If float, should be between 0.0 and 1.0 and represent the @@ -1818,8 +1822,11 @@ def train_test_split(*arrays, **options): Returns ------- - splitting : list of arrays, length=2 * len(arrays) - List containing train-test split of input array. + splitting : list, length = 2 * len(arrays) + List containing train-test split of inputs. + + .. versionadded:: 0.16 + Output type is the same as the input type. Examples --------