diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 7ac7c5fcc8241121ac630fcc4702bd99a355b1f5..a450c175ae8fca3f49a25a4639775e8e6d0030a4 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -397,6 +397,11 @@ Bug fixes
     - Fix :class:`linear_model.ElasticNet` sparse decision function to match
       output with dense in the multioutput case.
 
+    - Fix in :class:`sklearn.model_selection.StratifiedShuffleSplit` to
+      return splits of size ``train_size`` and ``test_size`` in all cases
+      (`#6472 <https://github.com/scikit-learn/scikit-learn/pull/6472>`).
+      By `Andreas Müller`_.
+
 API changes summary
 -------------------
 
diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py
index 508b0460ec1542282e170c2a36292e5b965d8e08..010f7106a4870e0c25c9adcfc7a75d67e618eab0 100644
--- a/sklearn/cross_validation.py
+++ b/sklearn/cross_validation.py
@@ -27,6 +27,7 @@ from .utils import indexable, check_random_state, safe_indexing
 from .utils.validation import (_is_arraylike, _num_samples,
                                column_or_1d)
 from .utils.multiclass import type_of_target
+from .utils.random import choice
 from .externals.joblib import Parallel, delayed, logger
 from .externals.six import with_metaclass
 from .externals.six.moves import zip
@@ -414,9 +415,9 @@ class LabelKFold(_BaseKFold):
 
         if n_folds > n_labels:
             raise ValueError(
-                    ("Cannot have number of folds n_folds={0} greater"
-                     " than the number of labels: {1}.").format(n_folds,
-                                                                n_labels))
+                ("Cannot have number of folds n_folds={0} greater"
+                 " than the number of labels: {1}.").format(n_folds,
+                                                            n_labels))
 
         # Weight labels by their number of occurrences
         n_samples_per_label = np.bincount(labels)
@@ -906,6 +907,59 @@ def _validate_shuffle_split(n, test_size, train_size):
     return int(n_train), int(n_test)
 
 
+def _approximate_mode(class_counts, n_draws, rng):
+    """Computes approximate mode of multivariate hypergeometric.
+
+    This is an approximation to the mode of the multivariate
+    hypergeometric given by class_counts and n_draws.
+    It shouldn't be off by more than one.
+
+    It is the mostly likely outcome of drawing n_draws many
+    samples from the population given by class_counts.
+
+    Parameters
+    ----------
+    class_counts : ndarray of int
+        Population per class.
+    n_draws : int
+        Number of draws (samples to draw) from the overall population.
+    rng : random state
+        Used to break ties.
+
+    Returns
+    -------
+    sampled_classes : ndarray of int
+        Number of samples drawn from each class.
+        np.sum(sampled_classes) == n_draws
+    """
+    # this computes a bad approximation to the mode of the
+    # multivariate hypergeometric given by class_counts and n_draws
+    continuous = n_draws * class_counts / class_counts.sum()
+    # floored means we don't overshoot n_samples, but probably undershoot
+    floored = np.floor(continuous)
+    # we add samples according to how much "left over" probability
+    # they had, until we arrive at n_samples
+    need_to_add = int(n_draws - floored.sum())
+    if need_to_add > 0:
+        remainder = continuous - floored
+        values = np.sort(np.unique(remainder))[::-1]
+        # add according to remainder, but break ties
+        # randomly to avoid biases
+        for value in values:
+            inds, = np.where(remainder == value)
+            # if we need_to_add less than what's in inds
+            # we draw randomly from them.
+            # if we need to add more, we add them all and
+            # go to the next value
+            add_now = min(len(inds), need_to_add)
+            inds = choice(inds, size=add_now, replace=False, random_state=rng)
+            floored[inds] += 1
+            need_to_add -= add_now
+            if need_to_add == 0:
+                    break
+    return floored.astype(np.int)
+
+
 class StratifiedShuffleSplit(BaseShuffleSplit):
     """Stratified ShuffleSplit cross validation iterator
 
@@ -991,39 +1045,24 @@ class StratifiedShuffleSplit(BaseShuffleSplit):
     def _iter_indices(self):
         rng = check_random_state(self.random_state)
         cls_count = bincount(self.y_indices)
-        p_i = cls_count / float(self.n)
-        n_i = np.round(self.n_train * p_i).astype(int)
-        t_i = np.minimum(cls_count - n_i,
-                         np.round(self.n_test * p_i).astype(int))
 
         for n in range(self.n_iter):
+            # if there are ties in the class-counts, we want
+            # to make sure to break them anew in each iteration
+            n_i = _approximate_mode(cls_count, self.n_train, rng)
+            class_counts_remaining = cls_count - n_i
+            t_i = _approximate_mode(class_counts_remaining, self.n_test, rng)
+
             train = []
             test = []
 
-            for i, cls in enumerate(self.classes):
+            for i, _ in enumerate(self.classes):
                 permutation = rng.permutation(cls_count[i])
-                cls_i = np.where((self.y == cls))[0][permutation]
-
-                train.extend(cls_i[:n_i[i]])
-                test.extend(cls_i[n_i[i]:n_i[i] + t_i[i]])
-
-            # Because of rounding issues (as n_train and n_test are not
-            # dividers of the number of elements per class), we may end
-            # up here with less samples in train and test than asked for.
-            if len(train) + len(test) < self.n_train + self.n_test:
-                # We complete by affecting randomly the missing indexes
-                missing_idx = np.where(bincount(train + test,
-                                                minlength=len(self.y)) == 0,
-                                       )[0]
-                missing_idx = rng.permutation(missing_idx)
-                n_missing_train = self.n_train - len(train)
-                n_missing_test = self.n_test - len(test)
-
-                if n_missing_train > 0:
-                    train.extend(missing_idx[:n_missing_train])
-                if n_missing_test > 0:
-                    test.extend(missing_idx[-n_missing_test:])
+                perm_indices_class_i = np.where(
+                    (i == self.y_indices))[0][permutation]
 
+                train.extend(perm_indices_class_i[:n_i[i]])
+                test.extend(perm_indices_class_i[n_i[i]:n_i[i] + t_i[i]])
             train = rng.permutation(train)
             test = rng.permutation(test)
 
diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py
index 5989edd30b1093ad1c7fff0ca076dcf51c886c29..a4762652854374c5f59a97b68be0d70c7b67cd4b 100644
--- a/sklearn/model_selection/_split.py
+++ b/sklearn/model_selection/_split.py
@@ -30,6 +30,7 @@ from ..externals.six import with_metaclass
 from ..externals.six.moves import zip
 from ..utils.fixes import bincount
 from ..utils.fixes import signature
+from ..utils.random import choice
 from ..base import _pprint
 from ..gaussian_process.kernels import Kernel as GPKernel
 
@@ -1098,6 +1099,73 @@ class LabelShuffleSplit(ShuffleSplit):
             yield train, test
 
 
+def _approximate_mode(class_counts, n_draws, rng):
+    """Computes approximate mode of multivariate hypergeometric.
+
+    This is an approximation to the mode of the multivariate
+    hypergeometric given by class_counts and n_draws.
+    It shouldn't be off by more than one.
+
+    It is the mostly likely outcome of drawing n_draws many
+    samples from the population given by class_counts.
+
+    Parameters
+    ----------
+    class_counts : ndarray of int
+        Population per class.
+    n_draws : int
+        Number of draws (samples to draw) from the overall population.
+    rng : random state
+        Used to break ties.
+
+    Returns
+    -------
+    sampled_classes : ndarray of int
+        Number of samples drawn from each class.
+        np.sum(sampled_classes) == n_draws
+
+    Examples
+    --------
+    >>> from sklearn.model_selection._split import _approximate_mode
+    >>> _approximate_mode(class_counts=np.array([4, 2]), n_draws=3, rng=0)
+    array([2, 1])
+    >>> _approximate_mode(class_counts=np.array([5, 2]), n_draws=4, rng=0)
+    array([3, 1])
+    >>> _approximate_mode(class_counts=np.array([2, 2, 2, 1]),
+    ...                   n_draws=2, rng=0)
+    array([0, 1, 1, 0])
+    >>> _approximate_mode(class_counts=np.array([2, 2, 2, 1]),
+    ...                   n_draws=2, rng=42)
+    array([1, 1, 0, 0])
+    """
+    # this computes a bad approximation to the mode of the
+    # multivariate hypergeometric given by class_counts and n_draws
+    continuous = n_draws * class_counts / class_counts.sum()
+    # floored means we don't overshoot n_samples, but probably undershoot
+    floored = np.floor(continuous)
+    # we add samples according to how much "left over" probability
+    # they had, until we arrive at n_samples
+    need_to_add = int(n_draws - floored.sum())
+    if need_to_add > 0:
+        remainder = continuous - floored
+        values = np.sort(np.unique(remainder))[::-1]
+        # add according to remainder, but break ties
+        # randomly to avoid biases
+        for value in values:
+            inds, = np.where(remainder == value)
+            # if we need_to_add less than what's in inds
+            # we draw randomly from them.
+            # if we need to add more, we add them all and
+            # go to the next value
+            add_now = min(len(inds), need_to_add)
+            inds = choice(inds, size=add_now, replace=False, random_state=rng)
+            floored[inds] += 1
+            need_to_add -= add_now
+            if need_to_add == 0:
+                break
+    return floored.astype(np.int)
+
+
 class StratifiedShuffleSplit(BaseShuffleSplit):
     """Stratified ShuffleSplit cross-validator
 
@@ -1181,12 +1249,14 @@ class StratifiedShuffleSplit(BaseShuffleSplit):
                              (n_test, n_classes))
 
         rng = check_random_state(self.random_state)
-        p_i = class_counts / float(n_samples)
-        n_i = np.round(n_train * p_i).astype(int)
-        t_i = np.minimum(class_counts - n_i,
-                         np.round(n_test * p_i).astype(int))
 
         for _ in range(self.n_splits):
+            # if there are ties in the class-counts, we want
+            # to make sure to break them anew in each iteration
+            n_i = _approximate_mode(class_counts, n_train, rng)
+            class_counts_remaining = class_counts - n_i
+            t_i = _approximate_mode(class_counts_remaining, n_test, rng)
+
             train = []
             test = []
 
@@ -1196,23 +1266,6 @@ class StratifiedShuffleSplit(BaseShuffleSplit):
 
                 train.extend(perm_indices_class_i[:n_i[i]])
                 test.extend(perm_indices_class_i[n_i[i]:n_i[i] + t_i[i]])
-
-            # Because of rounding issues (as n_train and n_test are not
-            # dividers of the number of elements per class), we may end
-            # up here with less samples in train and test than asked for.
-            if len(train) + len(test) < n_train + n_test:
-                # We complete by affecting randomly the missing indexes
-                missing_indices = np.where(bincount(train + test,
-                                                    minlength=len(y)) == 0)[0]
-                missing_indices = rng.permutation(missing_indices)
-                n_missing_train = n_train - len(train)
-                n_missing_test = n_test - len(test)
-
-                if n_missing_train > 0:
-                    train.extend(missing_indices[:n_missing_train])
-                if n_missing_test > 0:
-                    test.extend(missing_indices[-n_missing_test:])
-
             train = rng.permutation(train)
             test = rng.permutation(test)
 
diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py
index d28148efe6956bc9599e8b6790f8ec8016f93978..d4130182b0e10c468b924c685e6240203c32a1e8 100644
--- a/sklearn/model_selection/tests/test_split.py
+++ b/sklearn/model_selection/tests/test_split.py
@@ -535,17 +535,33 @@ def test_stratified_shuffle_split_init():
                   StratifiedShuffleSplit(test_size=2).split(X, y))
 
 
+def test_stratified_shuffle_split_respects_test_size():
+    y = np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2])
+    test_size = 5
+    train_size = 10
+    sss = StratifiedShuffleSplit(6, test_size=test_size, train_size=train_size,
+                                 random_state=0).split(np.ones(len(y)), y)
+    for train, test in sss:
+        assert_equal(len(train), train_size)
+        assert_equal(len(test), test_size)
+
+
 def test_stratified_shuffle_split_iter():
     ys = [np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
           np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
-          np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]),
+          np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] * 2),
           np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4]),
-          np.array([-1] * 800 + [1] * 50)
+          np.array([-1] * 800 + [1] * 50),
+          np.concatenate([[i] * (100 + i) for i in range(11)])
           ]
 
     for y in ys:
         sss = StratifiedShuffleSplit(6, test_size=0.33,
                                      random_state=0).split(np.ones(len(y)), y)
+        # this is how test-size is computed internally
+        # in _validate_shuffle_split
+        test_size = np.ceil(0.33 * len(y))
+        train_size = len(y) - test_size
         for train, test in sss:
             assert_array_equal(np.unique(y[train]), np.unique(y[test]))
             # Checks if folds keep classes proportions
@@ -556,7 +572,9 @@ def test_stratified_shuffle_split_iter():
                                   return_inverse=True)[1]) /
                       float(len(y[test])))
             assert_array_almost_equal(p_train, p_test, 1)
-            assert_equal(y[train].size + y[test].size, y.size)
+            assert_equal(len(train) + len(test), y.size)
+            assert_equal(len(train), train_size)
+            assert_equal(len(test), test_size)
             assert_array_equal(np.lib.arraysetops.intersect1d(train, test), [])
 
 
@@ -572,8 +590,8 @@ def test_stratified_shuffle_split_even():
         threshold = 0.05 / n_splits
         bf = stats.binom(n_splits, p)
         for count in idx_counts:
-            p = bf.pmf(count)
-            assert_true(p > threshold,
+            prob = bf.pmf(count)
+            assert_true(prob > threshold,
                         "An index is not drawn with chance corresponding "
                         "to even draws")
 
@@ -593,9 +611,8 @@ def test_stratified_shuffle_split_even():
                     counter[id] += 1
         assert_equal(n_splits_actual, n_splits)
 
-        n_train, n_test = _validate_shuffle_split(n_samples,
-                                                  test_size=1./n_folds,
-                                                  train_size=1.-(1./n_folds))
+        n_train, n_test = _validate_shuffle_split(
+            n_samples, test_size=1. / n_folds, train_size=1. - (1. / n_folds))
 
         assert_equal(len(train), n_train)
         assert_equal(len(test), n_test)
@@ -656,7 +673,7 @@ def test_label_shuffle_split():
     for l in labels:
         X = y = np.ones(len(l))
         n_splits = 6
-        test_size = 1./3
+        test_size = 1. / 3
         slo = LabelShuffleSplit(n_splits, test_size=test_size, random_state=0)
 
         # Make sure the repr works
diff --git a/sklearn/tests/test_cross_validation.py b/sklearn/tests/test_cross_validation.py
index 06a0a7efd8c228eb807484412f7322cef36d01d3..4d756bdaa0cf85ba122485a0c7aa8bc9869b0bde 100644
--- a/sklearn/tests/test_cross_validation.py
+++ b/sklearn/tests/test_cross_validation.py
@@ -479,7 +479,7 @@ def test_stratified_shuffle_split_init():
 def test_stratified_shuffle_split_iter():
     ys = [np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
           np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
-          np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]),
+          np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] * 2),
           np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4]),
           np.array([-1] * 800 + [1] * 50)
           ]
@@ -487,16 +487,22 @@ def test_stratified_shuffle_split_iter():
     for y in ys:
         sss = cval.StratifiedShuffleSplit(y, 6, test_size=0.33,
                                           random_state=0)
+        test_size = np.ceil(0.33 * len(y))
+        train_size = len(y) - test_size
         for train, test in sss:
             assert_array_equal(np.unique(y[train]), np.unique(y[test]))
             # Checks if folds keep classes proportions
-            p_train = (np.bincount(np.unique(y[train], return_inverse=True)[1])
-                       / float(len(y[train])))
-            p_test = (np.bincount(np.unique(y[test], return_inverse=True)[1])
-                      / float(len(y[test])))
+            p_train = (np.bincount(np.unique(y[train],
+                                   return_inverse=True)[1]) /
+                       float(len(y[train])))
+            p_test = (np.bincount(np.unique(y[test],
+                                  return_inverse=True)[1]) /
+                      float(len(y[test])))
             assert_array_almost_equal(p_train, p_test, 1)
-            assert_equal(y[train].size + y[test].size, y.size)
-            assert_array_equal(np.intersect1d(train, test), [])
+            assert_equal(len(train) + len(test), y.size)
+            assert_equal(len(train), train_size)
+            assert_equal(len(test), test_size)
+            assert_array_equal(np.lib.arraysetops.intersect1d(train, test), [])
 
 
 def test_stratified_shuffle_split_even():
diff --git a/sklearn/utils/random.py b/sklearn/utils/random.py
index 34738d8653b742cc4f912d3eefee7aa052e1682f..5805f9be2c8fa51c52205a2f8893fd5e5cdbd024 100644
--- a/sklearn/utils/random.py
+++ b/sklearn/utils/random.py
@@ -123,7 +123,7 @@ def choice(a, size=None, replace=True, p=None, random_state=None):
         if pop_size is 0:
             raise ValueError("a must be non-empty")
 
-    if None != p:
+    if p is not None:
         p = np.array(p, dtype=np.double, ndmin=1, copy=False)
         if p.ndim != 1:
             raise ValueError("p must be 1-dimensional")
@@ -142,7 +142,7 @@ def choice(a, size=None, replace=True, p=None, random_state=None):
 
     # Actual sampling
     if replace:
-        if None != p:
+        if p is not None:
             cdf = p.cumsum()
             cdf /= cdf[-1]
             uniform_samples = random_state.random_sample(shape)
@@ -156,7 +156,7 @@ def choice(a, size=None, replace=True, p=None, random_state=None):
             raise ValueError("Cannot take a larger sample than "
                              "population when 'replace=False'")
 
-        if None != p:
+        if p is not None:
             if np.sum(p > 0) < size:
                 raise ValueError("Fewer non-zero entries in p than size")
             n_uniq = 0