From 0adaacc40c4a9a9aa616ca23388ba0d90eeeb4b5 Mon Sep 17 00:00:00 2001
From: GaelVaroquaux <gael.varoquaux@normalesup.org>
Date: Sun, 22 Apr 2012 18:18:32 +0200
Subject: [PATCH] API: n_test -> test_size in Bootstrap

---
 doc/modules/cross_validation.rst       |  2 +-
 doc/whats_new.rst                      |  9 ++-
 sklearn/cross_validation.py            | 97 ++++++++++++++++----------
 sklearn/tests/test_cross_validation.py |  9 ++-
 sklearn/utils/__init__.py              |  2 +
 5 files changed, 73 insertions(+), 46 deletions(-)

diff --git a/doc/modules/cross_validation.rst b/doc/modules/cross_validation.rst
index dc235ac526..5bb98b3b4c 100644
--- a/doc/modules/cross_validation.rst
+++ b/doc/modules/cross_validation.rst
@@ -390,7 +390,7 @@ smaller than the total dataset if it is very large.
   >>> len(bs)
   3
   >>> print bs
-  Bootstrap(9, n_bootstraps=3, n_train=5, n_test=4, random_state=0)
+  Bootstrap(9, n_bootstraps=3, train_size=5, test_size=4, random_state=0)
 
   >>> for train_index, test_index in bs:
   ...    print train_index, test_index
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index a43b6f7197..ec3b6eac70 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -162,11 +162,16 @@ API changes summary
    - The SVMlight format loader now supports files with both zero-based and
      one-based column indices, since both occur "in the wild".
 
-   - Options in class :class:`ShuffleSplit` are now consistent with
-     :class:`StratifiedShuffleSplit`. Options ``test_fraction`` and
+   - Arguments in class :class:`ShuffleSplit` are now consistent with
+     :class:`StratifiedShuffleSplit`. Arguments ``test_fraction`` and
      ``train_fraction`` are deprecated and renamed to ``test_size`` and
      ``train_size`` and can accept both ``float`` and ``int``.
 
+   - Arguments in class :class:`Bootstrap` are now consistent with
+     :class:`StratifiedShuffleSplit`. Arguments ``n_test`` and
+     ``n_train`` are deprecated and renamed to ``test_size`` and
+     ``train_size``, which accept either a ``float`` or an ``int``.
+
    - Argument ``p`` added to classes in :ref:`neighbors` to specify an 
      arbitrary Minkowski metric for nearest neighbors searches.
 
diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py
index e4aa7a5a77..8e61a0a7dc 100644
--- a/sklearn/cross_validation.py
+++ b/sklearn/cross_validation.py
@@ -569,7 +569,7 @@ class Bootstrap(object):
     n_bootstraps : int (default is 3)
         Number of bootstrapping iterations
 
-    n_train : int or float (default is 0.5)
+    train_size : int or float (default is 0.5)
         If int, number of samples to include in the training split
         (should be smaller than the total number of samples passed
         in the dataset).
@@ -577,7 +577,7 @@ class Bootstrap(object):
         If float, should be between 0.0 and 1.0 and represent the
         proportion of the dataset to include in the train split.
 
-    n_test : int or float or None (default is None)
+    test_size : int or float or None (default is None)
         If int, number of samples to include in the training set
         (should be smaller than the total number of samples passed
         in the dataset).
@@ -597,7 +597,7 @@ class Bootstrap(object):
     >>> len(bs)
     3
     >>> print bs
-    Bootstrap(9, n_bootstraps=3, n_train=5, n_test=4, random_state=0)
+    Bootstrap(9, n_bootstraps=3, train_size=5, test_size=4, random_state=0)
     >>> for train_index, test_index in bs:
     ...    print "TRAIN:", train_index, "TEST:", test_index
     ...
@@ -613,32 +613,46 @@ class Bootstrap(object):
     # Static marker to be able to introspect the CV type
     indices = True
 
-    def __init__(self, n, n_bootstraps=3, n_train=0.5, n_test=None,
-                 random_state=None):
+    def __init__(self, n, n_bootstraps=3, train_size=.5, test_size=None,
+                 n_train=None, n_test=None, random_state=None):
         self.n = n
         self.n_bootstraps = n_bootstraps
-
-        if isinstance(n_train, float) and n_train >= 0.0 and n_train <= 1.0:
-            self.n_train = ceil(n_train * n)
-        elif isinstance(n_train, int):
-            self.n_train = n_train
+        if n_train is not None:
+            train_size = n_train
+            warnings.warn(
+                "n_train is deprecated in 0.11 and scheduled for "
+                "removal in 0.12, use train_size instead",
+                DeprecationWarning, stacklevel=2)
+        if n_test is not None:
+            test_size = n_test
+            warnings.warn(
+                "n_test is deprecated in 0.11 and scheduled for "
+                "removal in 0.12, use test_size instead",
+                DeprecationWarning, stacklevel=2)
+        if (isinstance(train_size, float) and train_size >= 0.0
+                            and train_size <= 1.0):
+            self.train_size = ceil(train_size * n)
+        elif isinstance(train_size, int):
+            self.train_size = train_size
         else:
-            raise ValueError("Invalid value for n_train: %r" % n_train)
-        if self.n_train > n:
-            raise ValueError("n_train=%d should not be larger than n=%d" %
-                             (self.n_train, n))
-
-        if isinstance(n_test, float) and n_test >= 0.0 and n_test <= 1.0:
-            self.n_test = ceil(n_test * n)
-        elif isinstance(n_test, int):
-            self.n_test = n_test
-        elif n_test is None:
-            self.n_test = self.n - self.n_train
+            raise ValueError("Invalid value for train_size: %r" %
+                             train_size)
+        if self.train_size > n:
+            raise ValueError("train_size=%d should not be larger than n=%d" %
+                             (self.train_size, n))
+
+        if (isinstance(test_size, float) and test_size >= 0.0
+                    and test_size <= 1.0):
+            self.test_size = ceil(test_size * n)
+        elif isinstance(test_size, int):
+            self.test_size = test_size
+        elif test_size is None:
+            self.test_size = self.n - self.train_size
         else:
-            raise ValueError("Invalid value for n_test: %r" % n_test)
-        if self.n_test > n:
-            raise ValueError("n_test=%d should not be larger than n=%d" %
-                             (self.n_test, n))
+            raise ValueError("Invalid value for test_size: %r" % test_size)
+        if self.test_size > n:
+            raise ValueError("test_size=%d should not be larger than n=%d" %
+                             (self.test_size, n))
 
         self.random_state = random_state
 
@@ -647,22 +661,25 @@ class Bootstrap(object):
         for i in range(self.n_bootstraps):
             # random partition
             permutation = rng.permutation(self.n)
-            ind_train = permutation[:self.n_train]
-            ind_test = permutation[self.n_train:self.n_train + self.n_test]
+            ind_train = permutation[:self.train_size]
+            ind_test = permutation[self.train_size:self.train_size
+                                   + self.test_size]
 
             # bootstrap in each split individually
-            train = rng.randint(0, self.n_train, size=(self.n_train,))
-            test = rng.randint(0, self.n_test, size=(self.n_test,))
+            train = rng.randint(0, self.train_size,
+                                size=(self.train_size,))
+            test = rng.randint(0, self.test_size,
+                                size=(self.test_size,))
             yield ind_train[train], ind_test[test]
 
     def __repr__(self):
-        return ('%s(%d, n_bootstraps=%d, n_train=%d, n_test=%d, '
+        return ('%s(%d, n_bootstraps=%d, train_size=%d, test_size=%d, '
                 'random_state=%d)' % (
                     self.__class__.__name__,
                     self.n,
                     self.n_bootstraps,
-                    self.n_train,
-                    self.n_test,
+                    self.train_size,
+                    self.test_size,
                     self.random_state,
                 ))
 
@@ -746,12 +763,14 @@ class ShuffleSplit(object):
         if test_fraction is not None:
             warnings.warn(
                 "test_fraction is deprecated in 0.11 and scheduled for "
-                "removal in 0.12, use test_size instead")
+                "removal in 0.12, use test_size instead",
+                DeprecationWarning, stacklevel=2)
             test_size = test_fraction
         if train_fraction is not None:
             warnings.warn(
                 "train_fraction is deprecated in 0.11 and scheduled for "
-                "removal in 0.12, use train_size instead")
+                "removal in 0.12, use train_size instead",
+                DeprecationWarning, stacklevel=2)
             train_size = train_fraction
 
         self.test_size = test_size
@@ -1259,16 +1278,18 @@ def train_test_split(*arrays, **options):
     test_fraction = options.pop('test_fraction', None)
     if test_fraction is not None:
         warnings.warn(
-            "test_fraction is deprecated in 0.11 and scheduled for "
-            "removal in 0.12, use test_size instead")
+                "test_fraction is deprecated in 0.11 and scheduled for "
+                "removal in 0.12, use test_size instead",
+                DeprecationWarning, stacklevel=2)
     else:
         test_fraction = 0.25
 
     train_fraction = options.pop('train_fraction', None)
     if train_fraction is not None:
         warnings.warn(
-            "train_fraction is deprecated in 0.11 and scheduled for "
-            "removal in 0.12, use train_size instead")
+                "train_fraction is deprecated in 0.11 and scheduled for "
+                "removal in 0.12, use train_size instead",
+                DeprecationWarning, stacklevel=2)
 
     test_size = options.pop('test_size', test_fraction)
     train_size = options.pop('train_size', train_fraction)
diff --git a/sklearn/tests/test_cross_validation.py b/sklearn/tests/test_cross_validation.py
index 352a73bf87..d82922a970 100644
--- a/sklearn/tests/test_cross_validation.py
+++ b/sklearn/tests/test_cross_validation.py
@@ -163,7 +163,6 @@ def test_train_test_split_errors():
 
 
 def test_shuffle_split_warnings():
-    # change warnings.warn to catch the message
     expected_message = ("test_fraction is deprecated in 0.11 and scheduled "
                         "for removal in 0.12, use test_size instead",
                         "train_fraction is deprecated in 0.11 and scheduled "
@@ -317,10 +316,10 @@ def test_cross_val_generator_with_indices():
 
 
 def test_bootstrap_errors():
-    assert_raises(ValueError, cval.Bootstrap, 10, n_train=100)
-    assert_raises(ValueError, cval.Bootstrap, 10, n_test=100)
-    assert_raises(ValueError, cval.Bootstrap, 10, n_train=1.1)
-    assert_raises(ValueError, cval.Bootstrap, 10, n_test=1.1)
+    assert_raises(ValueError, cval.Bootstrap, 10, train_size=100)
+    assert_raises(ValueError, cval.Bootstrap, 10, test_size=100)
+    assert_raises(ValueError, cval.Bootstrap, 10, train_size=1.1)
+    assert_raises(ValueError, cval.Bootstrap, 10, test_size=1.1)
 
 
 def test_shufflesplit_errors():
diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py
index 0b354d3368..748cec14b9 100644
--- a/sklearn/utils/__init__.py
+++ b/sklearn/utils/__init__.py
@@ -8,6 +8,8 @@ import warnings
 from .validation import *
 from .murmurhash import murmurhash3_32
 
+# Make sure that DeprecationWarnings get printed
+warnings.simplefilter("always", DeprecationWarning)
 
 class deprecated(object):
     """Decorator to mark a function or class as deprecated.
-- 
GitLab