diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst
index d5452b1e3856fe2f3854c52ef00578e6df4378ef..74070a2c09f7ce7f0cbb4aacf6e7f66803814f58 100644
--- a/doc/modules/preprocessing.rst
+++ b/doc/modules/preprocessing.rst
@@ -370,7 +370,7 @@ matrix from a list of multi-class labels::
 
     >>> lb = preprocessing.LabelBinarizer()
     >>> lb.fit([1, 2, 6, 4, 2])
-    LabelBinarizer(neg_label=0, pos_label=1)
+    LabelBinarizer(neg_label=0, pos_label=1, sparse_output=False)
     >>> lb.classes_
     array([1, 2, 4, 6])
     >>> lb.transform([1, 6])
diff --git a/sklearn/linear_model/ridge.py b/sklearn/linear_model/ridge.py
index f7d513e4581743a31642f084c81a572d330573a3..3d2b188184d7e6403c6f192537ca37bedb96be86 100644
--- a/sklearn/linear_model/ridge.py
+++ b/sklearn/linear_model/ridge.py
@@ -571,7 +571,7 @@ class RidgeClassifier(LinearClassifierMixin, _BaseRidge):
         """
         self._label_binarizer = LabelBinarizer(pos_label=1, neg_label=-1)
         Y = self._label_binarizer.fit_transform(y)
-        if not self._label_binarizer.multilabel_:
+        if not self._label_binarizer.y_type_.startswith('multilabel'):
             y = column_or_1d(y, warn=True)
 
         if self.class_weight:
diff --git a/sklearn/preprocessing/label.py b/sklearn/preprocessing/label.py
index 851dc5d9738d9b8bdf3929bb132e9d12458a4cec..6baff66c65f6290d10228860848ae0ce8836efbb 100644
--- a/sklearn/preprocessing/label.py
+++ b/sklearn/preprocessing/label.py
@@ -3,11 +3,13 @@
 #          Olivier Grisel <olivier.grisel@ensta.org>
 #          Andreas Mueller <amueller@ais.uni-bonn.de>
 #          Joel Nothman <joel.nothman@gmail.com>
+#          Hamzeh Alsalhi <ha258@cornell.edu>
 # License: BSD 3 clause
 
 from collections import defaultdict
 import itertools
 import array
+import warnings
 
 import numpy as np
 import scipy.sparse as sp
@@ -15,6 +17,7 @@ import scipy.sparse as sp
 from ..base import BaseEstimator, TransformerMixin
 
 from ..utils.fixes import np_version
+from ..utils.fixes import sparse_min_max
 from ..utils import deprecated, column_or_1d
 
 from ..utils.multiclass import unique_labels
@@ -188,6 +191,10 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
     pos_label : int (default: 1)
         Value with which positive labels must be encoded.
 
+    sparse_output : boolean (default: False)
+        True if the returned array from transform is desired to be in sparse
+        CSR format.
+
     Attributes
     ----------
     `classes_` : array of shape [n_class]
@@ -195,18 +202,27 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
 
     `multilabel_` : boolean
         True if the transformer was fitted on a multilabel rather than a
-        multiclass set of labels.
+        multiclass set of labels. The multilabel_ attribute is deprecated
+        and will be removed in 0.18
+
+    `sparse_input_` : boolean,
+        True if the input data to transform is given as a sparse matrix, False
+        otherwise.
+
+    `indicator_matrix_` : str
+        'sparse' when the input data to tansform is a multilable-indicator and
+        is sparse, None otherwise. The indicator_matrix_ attribute is
+        deprecated as of version 0.16 and will be removed in 0.18
+
 
     Examples
     --------
     >>> from sklearn import preprocessing
     >>> lb = preprocessing.LabelBinarizer()
     >>> lb.fit([1, 2, 6, 4, 2])
-    LabelBinarizer(neg_label=0, pos_label=1)
+    LabelBinarizer(neg_label=0, pos_label=1, sparse_output=False)
     >>> lb.classes_
     array([1, 2, 4, 6])
-    >>> lb.multilabel_
-    False
     >>> lb.transform([1, 6])
     array([[1, 0, 0, 0],
            [0, 0, 0, 1]])
@@ -221,11 +237,9 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
 
     >>> import numpy as np
     >>> lb.fit(np.array([[0, 1, 1], [1, 0, 0]]))
-    LabelBinarizer(neg_label=0, pos_label=1)
+    LabelBinarizer(neg_label=0, pos_label=1, sparse_output=False)
     >>> lb.classes_
     array([0, 1, 2])
-    >>> lb.multilabel_
-    True
 
     See also
     --------
@@ -233,12 +247,20 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
         LabelBinarizer with fixed classes.
     """
 
-    def __init__(self, neg_label=0, pos_label=1):
+    def __init__(self, neg_label=0, pos_label=1, sparse_output=False):
         if neg_label >= pos_label:
-            raise ValueError("neg_label must be strictly less than pos_label.")
+            raise ValueError("neg_label={0} must be strictly less than "
+                             "pos_label={1}.".format(neg_label, pos_label))
+
+        if sparse_output and (pos_label == 0 or neg_label != 0):
+            raise ValueError("Sparse binarization is only supported with non "
+                             "zero pos_label and zero neg_label, got "
+                             "pos_label={0} and neg_label={1}"
+                             "".format(pos_label, neg_label))
 
         self.neg_label = neg_label
         self.pos_label = pos_label
+        self.sparse_output = sparse_output
 
     @property
     @deprecated("Attribute `multilabel` was renamed to `multilabel_` in "
@@ -246,6 +268,20 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
     def multilabel(self):
         return self.multilabel_
 
+    @property
+    @deprecated("Attribute indicator_matrix_ is deprecated and will be "
+                "removed in 0.17. Use 'y_type_ == 'multilabel-indicator'' "
+                "instead")
+    def indicator_matrix_(self):
+        return self.y_type_ == 'multilabel-indicator'
+
+    @property
+    @deprecated("Attribute multilabel_ is deprecated and will be removed "
+                "in 0.17. Use 'y_type_.startswith('multilabel')' "
+                "instead")
+    def multilabel_(self):
+        return self.y_type_.startswith('multilabel')
+
     def _check_fitted(self):
         if not hasattr(self, "classes_"):
             raise ValueError("LabelBinarizer was not fitted yet.")
@@ -263,13 +299,9 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
         -------
         self : returns an instance of self.
         """
-        y_type = type_of_target(y)
-        self.multilabel_ = y_type.startswith('multilabel')
-        if self.multilabel_:
-            self.indicator_matrix_ = y_type == 'multilabel-indicator'
-
+        self.y_type_ = type_of_target(y)
+        self.sparse_input_ = sp.issparse(y)
         self.classes_ = unique_labels(y)
-
         return self
 
     def transform(self, y):
@@ -280,35 +312,30 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
 
         Parameters
         ----------
-        y : numpy array of shape (n_samples,) or (n_samples, n_classes)
-            Target values. The 2-d matrix should only contain 0 and 1,
-            represents multilabel classification.
+        y : numpy array or sparse matrix of shape (n_samples,) or
+            (n_samples, n_classes) Target values. The 2-d matrix should only
+            contain 0 and 1, represents multilabel classification. Sparse
+            matrix can be CSR, CSC, COO, DOK, or LIL.
 
         Returns
         -------
-        Y : numpy array of shape [n_samples, n_classes]
+        Y : numpy array or CSR matrix of shape [n_samples, n_classes]
             Shape will be [n_samples, 1] for binary problems.
         """
         self._check_fitted()
-
-        y_is_multilabel = type_of_target(y).startswith('multilabel')
-
-        if y_is_multilabel and not self.multilabel_:
-            raise ValueError("The object was not fitted with multilabel"
-                             " input.")
-
         return label_binarize(y, self.classes_,
-                              multilabel=self.multilabel_,
                               pos_label=self.pos_label,
-                              neg_label=self.neg_label)
+                              neg_label=self.neg_label,
+                              sparse_output=self.sparse_output)
 
     def inverse_transform(self, Y, threshold=None):
         """Transform binary labels back to multi-class labels
 
         Parameters
         ----------
-        Y : numpy array of shape [n_samples, n_classes]
-            Target values.
+        Y : numpy array or sparse matrix with shape [n_samples, n_classes]
+            Target values. All sparse matrices are converted to CSR before
+            inverse transformation.
 
         threshold : float or None
             Threshold used in the binary and multi-label cases.
@@ -323,9 +350,7 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
 
         Returns
         -------
-        y : numpy array of shape (n_samples,) or (n_samples, n_classes)
-            Target values. The 2-d matrix should only contain 0 and 1,
-            represents multilabel classification.
+        y : numpy array or CSR matrix of shape [n_samples] Target values.
 
         Notes
         -----
@@ -338,30 +363,24 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
         self._check_fitted()
 
         if threshold is None:
-            half = (self.pos_label - self.neg_label) / 2.0
-            threshold = self.neg_label + half
-
-        if self.multilabel_:
-            Y = np.array(Y > threshold, dtype=int)
-            # Return the predictions in the same format as in fit
-            if self.indicator_matrix_:
-                # Label indicator matrix format
-                return Y
-            else:
-                # Lists of tuples format
-                mlb = MultiLabelBinarizer(classes=self.classes_).fit([])
-                return mlb.inverse_transform(Y)
-
-        if len(Y.shape) == 1 or Y.shape[1] == 1:
-            y = np.array(Y.ravel() > threshold, dtype=int)
+            threshold = (self.pos_label + self.neg_label) / 2.
 
+        if self.y_type_ == "multiclass":
+            y_inv = _inverse_binarize_multiclass(Y, self.classes_)
         else:
-            y = Y.argmax(axis=1)
+            y_inv = _inverse_binarize_thresholding(Y, self.y_type_,
+                                                   self.classes_, threshold)
 
-        return self.classes_[y]
+        if self.sparse_input_:
+            y_inv = sp.csr_matrix(y_inv)
+        elif sp.issparse(y_inv):
+            y_inv = y_inv.toarray()
 
+        return y_inv
 
-def label_binarize(y, classes, multilabel=False, neg_label=0, pos_label=1):
+
+def label_binarize(y, classes, neg_label=0, pos_label=1,
+                   sparse_output=False, multilabel=None):
     """Binarize labels in a one-vs-all fashion
 
     Several regression and binary classification algorithms are
@@ -380,20 +399,18 @@ def label_binarize(y, classes, multilabel=False, neg_label=0, pos_label=1):
     classes : array-like of shape [n_classes]
         Uniquely holds the label for each class.
 
-    multilabel : boolean
-        Set to true if y is encoding a multilabel tasks (with a variable
-        number of label assignements per sample) rather than a multiclass task
-        where one sample has one and only one label assigned.
-
-    neg_label: int (default: 0)
+    neg_label : int (default: 0)
         Value with which negative labels must be encoded.
 
-    pos_label: int (default: 1)
+    pos_label : int (default: 1)
         Value with which positive labels must be encoded.
 
+    sparse_output : boolean (default: False),
+        Set to true if output binary array is desired in CSR sparse format
+
     Returns
     -------
-    Y : numpy array of shape [n_samples, n_classes]
+    Y : numpy array or CSR matrix of shape [n_samples, n_classes]
         Shape will be [n_samples, 1] for binary problems.
 
     Examples
@@ -419,43 +436,192 @@ def label_binarize(y, classes, multilabel=False, neg_label=0, pos_label=1):
 
     See also
     --------
-    label_binarize : function to perform the transform operation of
-        LabelBinarizer with fixed classes.
+    LabelBinarizer : class used to wrap the functionality of label_binarize and
+        allow for fitting to classes independently of the transform operation
     """
-    y_type = type_of_target(y)
+    if neg_label >= pos_label:
+        raise ValueError("neg_label={0} must be strictly less than "
+                         "pos_label={1}.".format(neg_label, pos_label))
+
+    if (sparse_output and (pos_label == 0 or neg_label != 0)):
+        raise ValueError("Sparse binarization is only supported with non "
+                         "zero pos_label and zero neg_label, got "
+                         "pos_label={0} and neg_label={1}"
+                         "".format(pos_label, neg_label))
+
+    if multilabel is not None:
+        warnings.warn("The multilabel parameter is deprecated as of version "
+                      "0.15 and will be removed in 0.17. The parameter is no "
+                      "longer necessary because the value is automatically "
+                      "inferred.", DeprecationWarning)
+
+    # To account for pos_label == 0 in the dense case
+    pos_switch = pos_label == 0
+    if pos_switch:
+        pos_label = -neg_label
 
-    if multilabel or len(classes) > 2:
-        Y = np.zeros((len(y), len(classes)), dtype=np.int)
-    else:
-        Y = np.zeros((len(y), 1), dtype=np.int)
+    y_type = type_of_target(y)
 
-    Y += neg_label
+    n_samples = y.shape[0] if sp.issparse(y) else len(y)
+    n_classes = len(classes)
+    classes = np.asarray(classes)
 
-    if multilabel:
-        if y_type == "multilabel-indicator":
-            Y[y == 1] = pos_label
+    if y_type == "binary":
+        if len(classes) == 1:
+            Y = np.zeros((len(y), 1), dtype=np.int)
+            Y += neg_label
             return Y
-        elif y_type == "multilabel-sequences":
-            return MultiLabelBinarizer(classes=classes).fit_transform(y)
+        elif len(classes) >= 3:
+            y_type = "multiclass"
+
+    sorted_class = np.sort(classes)
+    if (y_type == "multilabel-indicator" and classes.size != y.shape[1] or
+            not set(classes).issuperset(unique_labels(y))):
+        raise ValueError("classes {0} missmatch with the labels {1}"
+                         "found in the data".format(classes, unique_labels(y)))
+
+    if y_type in ("binary", "multiclass"):
+        y = column_or_1d(y)
+        indptr = np.arange(n_samples + 1)
+        indices = np.searchsorted(sorted_class, y)
+        data = np.empty_like(indices)
+        data.fill(pos_label)
+
+        Y = sp.csr_matrix((data, indices, indptr),
+                          shape=(n_samples, n_classes))
+
+    elif y_type == "multilabel-indicator":
+        Y = sp.csr_matrix(y)
+        if pos_label != 1:
+            data = np.empty_like(Y.data)
+            data.fill(pos_label)
+            Y.data = data
+
+    elif y_type == "multilabel-sequences":
+        Y = MultiLabelBinarizer(classes=classes,
+                                sparse_output=sparse_output).fit_transform(y)
+
+        if sp.issparse(Y):
+            Y.data[:] = pos_label
         else:
-            raise ValueError("y should be in a multilabel format, "
-                             "got %r" % (y,))
+            Y[Y == 1] = pos_label
+        return Y
+
+    if not sparse_output:
+        Y = Y.toarray()
+
+        if neg_label != 0:
+            Y[Y == 0] = neg_label
+
+        if pos_switch:
+            Y[Y == pos_label] = 0
+
+    # preserve label ordering
+    if np.any(classes != sorted_class):
+        indices = np.argsort(classes)
+        Y = Y[:, indices]
+
+    if y_type == "binary":
+        Y = Y[:, -1].reshape((-1, 1))
+
+    return Y
 
+
+def _inverse_binarize_multiclass(y, classes):
+    """Inverse label binarization transformation for multiclass.
+
+    Multiclass uses the maximal score instead of a threshold
+    """
+    classes = np.asarray(classes)
+
+    if sp.issparse(y):
+        # Find the argmax for each row in y where y is a CSR matrix
+
+        y = y.tocsr()
+        n_samples, n_outputs = y.shape
+        outputs = np.arange(n_outputs)
+        row_max = sparse_min_max(y, 1)[1]
+        row_nnz = np.diff(y.indptr)
+
+        y_data_repeated_max = np.repeat(row_max, row_nnz)
+        # picks out all indices obtaining the maximum per row
+        y_i_all_argmax = np.flatnonzero(y_data_repeated_max == y.data)
+
+        # For corner case where last row has a max of 0
+        if row_max[-1] == 0:
+            y_i_all_argmax = np.append(y_i_all_argmax, [len(y.data)])
+
+        # Gets the index of the first argmax in each row from y_i_all_argmax
+        index_first_argmax = np.searchsorted(y_i_all_argmax, y.indptr[:-1])
+        # first argmax of each row
+        y_ind_ext = np.append(y.indices, [0])
+        y_i_argmax = y_ind_ext[y_i_all_argmax[index_first_argmax]]
+        # Handle rows of all 0
+        y_i_argmax[np.where(row_nnz == 0)[0]] = 0
+
+        # Handles rows with max of 0 that contain negative numbers
+        samples = np.arange(n_samples)[(row_nnz > 0) &
+                                       (row_max.ravel() == 0)]
+        for i in samples:
+            ind = y.indices[y.indptr[i]:y.indptr[i+1]]
+            y_i_argmax[i] = classes[np.setdiff1d(outputs, ind)][0]
+
+        return classes[y_i_argmax]
     else:
-        y = column_or_1d(y)
+        return classes.take(y.argmax(axis=1), mode="clip")
 
-        if len(classes) == 2:
-            Y[y == classes[1], 0] = pos_label
-            return Y
 
-        elif len(classes) >= 2:
-            for i, k in enumerate(classes):
-                Y[y == k, i] = pos_label
-            return Y
+def _inverse_binarize_thresholding(y, output_type, classes, threshold):
+    """Inverse label binarization transformation using thresholding."""
+
+    if output_type == "binary" and y.ndim == 2 and y.shape[1] > 2:
+        raise ValueError("output_type='binary', but y.shape = {0}".
+                         format(y.shape))
+
+    if output_type != "binary" and y.shape[1] != len(classes):
+        raise ValueError("The number of class is not equal to the number of "
+                         "dimension of y.")
 
+    classes = np.asarray(classes)
+
+    # Perform thresholding
+    if sp.issparse(y):
+        if threshold > 0:
+            if y.format not in ('csr', 'csc'):
+                y = y.tocsr()
+            y.data = np.array(y.data > threshold, dtype=np.int)
+            y.eliminate_zeros()
         else:
-            # Only one class, returns a matrix with all negative labels.
-            return Y
+            y = np.array(y.toarray() > threshold, dtype=np.int)
+    else:
+        y = np.array(y > threshold, dtype=np.int)
+
+    # Inverse transform data
+    if output_type == "binary":
+        if y.ndim == 2 and y.shape[1] == 2:
+            return classes[y[:, 1]]
+        else:
+            if len(classes) == 1:
+                y = np.empty(len(y), dtype=classes.dtype)
+                y.fill(classes[0])
+                return y
+            else:
+                return classes[y.ravel()]
+
+    elif output_type == "multilabel-indicator":
+        return y
+
+    elif output_type == "multilabel-sequences":
+        warnings.warn('Direct support for sequence of sequences multilabel '
+                      'representation will be unavailable from version 0.17. '
+                      'Use sklearn.preprocessing.MultiLabelBinarizer to '
+                      'convert to a label indicator representation.',
+                      DeprecationWarning)
+        mlb = MultiLabelBinarizer(classes=classes).fit([])
+        return mlb.inverse_transform(y)
+
+    else:
+        raise ValueError("{0} format is not supported".format(output_type))
 
 
 class MultiLabelBinarizer(BaseEstimator, TransformerMixin):
@@ -493,8 +659,9 @@ class MultiLabelBinarizer(BaseEstimator, TransformerMixin):
     ['comedy', 'sci-fi', 'thriller']
 
     """
-    def __init__(self, classes=None):
+    def __init__(self, classes=None, sparse_output=False):
         self.classes = classes
+        self.sparse_output = sparse_output
 
     def fit(self, y):
         """Fit the label sets binarizer, storing `classes_`
@@ -531,7 +698,7 @@ class MultiLabelBinarizer(BaseEstimator, TransformerMixin):
 
         Returns
         -------
-        y_indicator : array, shape (n_samples, n_classes)
+        y_indicator : array or CSR matrix, shape (n_samples, n_classes)
             A matrix such that `y_indicator[i, j] = 1` iff `classes_[j]` is in
             `y[i]`, and 0 otherwise.
         """
@@ -552,7 +719,11 @@ class MultiLabelBinarizer(BaseEstimator, TransformerMixin):
         class_mapping[:] = tmp
         self.classes_, inverse = np.unique(class_mapping, return_inverse=True)
         yt.indices = np.take(inverse, yt.indices)
-        return yt.toarray()
+
+        if not self.sparse_output:
+            yt = yt.toarray()
+
+        return yt
 
     def transform(self, y):
         """Transform the given label sets
@@ -566,13 +737,17 @@ class MultiLabelBinarizer(BaseEstimator, TransformerMixin):
 
         Returns
         -------
-        y_indicator : array, shape (n_samples, n_classes)
+        y_indicator : array or CSR matrix, shape (n_samples, n_classes)
             A matrix such that `y_indicator[i, j] = 1` iff `classes_[j]` is in
             `y[i]`, and 0 otherwise.
         """
         class_to_index = dict(zip(self.classes_, range(len(self.classes_))))
         yt = self._transform(y, class_to_index)
-        return yt.toarray()
+
+        if not self.sparse_output:
+            yt = yt.toarray()
+
+        return yt
 
     def _transform(self, y, class_mapping):
         """Transforms the label sets with a given mapping
@@ -603,7 +778,7 @@ class MultiLabelBinarizer(BaseEstimator, TransformerMixin):
 
         Parameters
         ----------
-        yt : array of shape (n_samples, n_classes)
+        yt : array or sparse matrix of shape (n_samples, n_classes)
             A matrix containing only 1s ands 0s.
 
         Returns
@@ -612,13 +787,20 @@ class MultiLabelBinarizer(BaseEstimator, TransformerMixin):
             The set of labels for each sample such that `y[i]` consists of
             `classes_[j]` for each `yt[i, j] == 1`.
         """
-        yt = np.asarray(yt)
         if yt.shape[1] != len(self.classes_):
             raise ValueError('Expected indicator for {0} classes, but got {1}'
                              .format(len(self.classes_), yt.shape[1]))
-        unexpected = np.setdiff1d(yt, [0, 1])
-        if len(unexpected) > 0:
-            raise ValueError('Expected only 0s and 1s in label indicator. '
-                             'Also got {0}'.format(unexpected))
 
-        return [tuple(self.classes_.compress(indicators)) for indicators in yt]
+        if sp.issparse(yt):
+            yt = yt.tocsr()
+            if len(yt.data) != 0 and len(np.setdiff1d(yt.data, [0, 1])) > 0:
+                raise ValueError('Expected only 0s and 1s in label indicator.')
+            return [tuple(self.classes_.take(yt.indices[start:end]))
+                    for start, end in zip(yt.indptr[:-1], yt.indptr[1:])]
+        else:
+            unexpected = np.setdiff1d(yt, [0, 1])
+            if len(unexpected) > 0:
+                raise ValueError('Expected only 0s and 1s in label indicator. '
+                                 'Also got {0}'.format(unexpected))
+            return [tuple(self.classes_.compress(indicators)) for indicators
+                    in yt]
diff --git a/sklearn/preprocessing/tests/test_label.py b/sklearn/preprocessing/tests/test_label.py
index 17619cab7915aef1c3da3934ca76851df0ef3228..2cc786605f3808b1214573c630db1eb04856831e 100644
--- a/sklearn/preprocessing/tests/test_label.py
+++ b/sklearn/preprocessing/tests/test_label.py
@@ -1,5 +1,14 @@
 import numpy as np
 
+from scipy.sparse import issparse
+from scipy.sparse import coo_matrix
+from scipy.sparse import csc_matrix
+from scipy.sparse import csr_matrix
+from scipy.sparse import dok_matrix
+from scipy.sparse import lil_matrix
+
+from sklearn.utils.multiclass import type_of_target
+
 from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_equal
@@ -7,6 +16,7 @@ from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_true
 from sklearn.utils.testing import assert_false
 from sklearn.utils.testing import assert_warns
+from sklearn.utils.testing import assert_warns_message
 from sklearn.utils.testing import ignore_warnings
 
 from sklearn.preprocessing.label import LabelBinarizer
@@ -14,6 +24,8 @@ from sklearn.preprocessing.label import MultiLabelBinarizer
 from sklearn.preprocessing.label import LabelEncoder
 from sklearn.preprocessing.label import label_binarize
 
+from sklearn.preprocessing.label import _inverse_binarize_thresholding
+from sklearn.preprocessing.label import _inverse_binarize_multiclass
 
 from sklearn import datasets
 from sklearn.linear_model.stochastic_gradient import SGDClassifier
@@ -30,14 +42,28 @@ def toarray(a):
 def test_label_binarizer():
     lb = LabelBinarizer()
 
+    # one-class case defaults to negative label
+    inp = ["pos", "pos", "pos", "pos"]
+    expected = np.array([[0, 0, 0, 0]]).T
+    got = lb.fit_transform(inp)
+    assert_false(assert_warns(DeprecationWarning, getattr, lb, "multilabel_"))
+    assert_array_equal(lb.classes_, ["pos"])
+    assert_array_equal(expected, got)
+    assert_array_equal(lb.inverse_transform(got), inp)
+
     # two-class case
     inp = ["neg", "pos", "pos", "neg"]
     expected = np.array([[0, 1, 1, 0]]).T
     got = lb.fit_transform(inp)
-    assert_false(lb.multilabel_)
+    assert_false(assert_warns(DeprecationWarning, getattr, lb, "multilabel_"))
     assert_array_equal(lb.classes_, ["neg", "pos"])
     assert_array_equal(expected, got)
-    assert_array_equal(lb.inverse_transform(got), inp)
+
+    to_invert = np.array([[1, 0],
+                          [0, 1],
+                          [0, 1],
+                          [1, 0]])
+    assert_array_equal(lb.inverse_transform(to_invert), inp)
 
     # multi-class case
     inp = ["spam", "ham", "eggs", "ham", "0"]
@@ -48,7 +74,7 @@ def test_label_binarizer():
                          [1, 0, 0, 0]])
     got = lb.fit_transform(inp)
     assert_array_equal(lb.classes_, ['0', 'eggs', 'ham', 'spam'])
-    assert_false(lb.multilabel_)
+    assert_false(assert_warns(DeprecationWarning, getattr, lb, "multilabel_"))
     assert_array_equal(expected, got)
     assert_array_equal(lb.inverse_transform(got), inp)
 
@@ -70,10 +96,13 @@ def test_label_binarizer_column_y():
     out_2 = lb_2.fit_transform(inp_array)
 
     assert_array_equal(out_1, multilabel_indicator)
-    assert_true(lb_1.multilabel_)
+    assert_true(assert_warns(DeprecationWarning, getattr, lb_1, "multilabel_"))
+    assert_false(assert_warns(DeprecationWarning, getattr, lb_1,
+                              "indicator_matrix_"))
 
     assert_array_equal(out_2, binaryclass_array)
-    assert_false(lb_2.multilabel_)
+    assert_false(assert_warns(DeprecationWarning, getattr, lb_2,
+                              "multilabel_"))
 
     # second for multiclass classification vs multi-label with multiple
     # classes
@@ -90,23 +119,26 @@ def test_label_binarizer_column_y():
     out_2 = lb_2.fit_transform(inp_array)
 
     assert_array_equal(out_1, out_2)
-    assert_true(lb_1.multilabel_)
+    assert_true(assert_warns(DeprecationWarning, getattr, lb_1, "multilabel_"))
 
     assert_array_equal(out_2, indicator)
-    assert_false(lb_2.multilabel_)
+    assert_false(assert_warns(DeprecationWarning, getattr, lb_2,
+                              "multilabel_"))
 
 
 def test_label_binarizer_set_label_encoding():
-    lb = LabelBinarizer(neg_label=-2, pos_label=2)
+    lb = LabelBinarizer(neg_label=-2, pos_label=0)
 
-    # two-class case
+    # two-class case with pos_label=0
     inp = np.array([0, 1, 1, 0])
-    expected = np.array([[-2, 2, 2, -2]]).T
+    expected = np.array([[-2, 0, 0, -2]]).T
     got = lb.fit_transform(inp)
-    assert_false(lb.multilabel_)
+    assert_false(assert_warns(DeprecationWarning, getattr, lb, "multilabel_"))
     assert_array_equal(expected, got)
     assert_array_equal(lb.inverse_transform(got), inp)
 
+    lb = LabelBinarizer(neg_label=-2, pos_label=2)
+
     # multi-class case
     inp = np.array([3, 2, 1, 2, 0])
     expected = np.array([[-2, -2, -2, +2],
@@ -115,59 +147,53 @@ def test_label_binarizer_set_label_encoding():
                          [-2, -2, +2, -2],
                          [+2, -2, -2, -2]])
     got = lb.fit_transform(inp)
-    assert_false(lb.multilabel_)
+    assert_false(assert_warns(DeprecationWarning, getattr, lb, "multilabel_"))
     assert_array_equal(expected, got)
     assert_array_equal(lb.inverse_transform(got), inp)
 
 
-def test_label_binarizer_multilabel():
-    lb = LabelBinarizer()
-
-    # test input as lists of tuples
-    inp = [(2, 3), (1,), (1, 2)]
-    indicator_mat = np.array([[0, 1, 1],
-                              [1, 0, 0],
-                              [1, 1, 0]])
-    got = assert_warns(DeprecationWarning, lb.fit_transform, inp)
-    assert_true(lb.multilabel_)
-    assert_array_equal(indicator_mat, got)
-    assert_equal(lb.inverse_transform(got), inp)
-
-    # test input as label indicator matrix
-    lb.fit(indicator_mat)
-    assert_array_equal(indicator_mat,
-                       lb.inverse_transform(indicator_mat))
-
-    # regression test for the two-class multilabel case
-    lb = LabelBinarizer()
-    inp = [[1, 0], [0], [1], [0, 1]]
-    expected = np.array([[1, 1],
-                         [1, 0],
-                         [0, 1],
-                         [1, 1]])
-    got = assert_warns(DeprecationWarning, lb.fit_transform, inp)
-    assert_true(lb.multilabel_)
-    assert_array_equal(expected, got)
-    assert_equal([set(x) for x in lb.inverse_transform(got)],
-                 [set(x) for x in inp])
-
-
+@ignore_warnings
 def test_label_binarizer_errors():
     """Check that invalid arguments yield ValueError"""
     one_class = np.array([0, 0, 0, 0])
     lb = LabelBinarizer().fit(one_class)
-    assert_false(lb.multilabel_)
+    assert_false(assert_warns(DeprecationWarning, getattr, lb, "multilabel_"))
 
-    multi_label = np.array([[0, 0, 1, 0], [1, 0, 1, 0]])
+    multi_label = [(2, 3), (0,), (0, 2)]
     assert_raises(ValueError, lb.transform, multi_label)
 
     lb = LabelBinarizer()
     assert_raises(ValueError, lb.transform, [])
     assert_raises(ValueError, lb.inverse_transform, [])
 
+    y = np.array([[0, 1, 0], [1, 1, 1]])
+    classes = np.arange(3)
+    assert_raises(ValueError, label_binarize, y, classes, multilabel=True,
+                  neg_label=2, pos_label=1)
+    assert_raises(ValueError, label_binarize, y, classes, multilabel=True,
+                  neg_label=2, pos_label=2)
+
     assert_raises(ValueError, LabelBinarizer, neg_label=2, pos_label=1)
     assert_raises(ValueError, LabelBinarizer, neg_label=2, pos_label=2)
 
+    assert_raises(ValueError, LabelBinarizer, neg_label=1, pos_label=2,
+                  sparse_output=True)
+
+    # Fail on y_type
+    assert_raises(ValueError, _inverse_binarize_thresholding,
+                  y=csr_matrix([[1, 2], [2, 1]]), output_type="foo",
+                  classes=[1, 2], threshold=0)
+
+    # Fail on the number of classes
+    assert_raises(ValueError, _inverse_binarize_thresholding,
+                  y=csr_matrix([[1, 2], [2, 1]]), output_type="foo",
+                  classes=[1, 2, 3], threshold=0)
+
+    # Fail on the dimension of 'binary'
+    assert_raises(ValueError, _inverse_binarize_thresholding,
+                  y=np.array([[1, 2, 3], [2, 1, 3]]), output_type="binary",
+                  classes=[1, 2, 3], threshold=0)
+
 
 def test_label_encoder():
     """Test LabelEncoder's transform and inverse_transform methods"""
@@ -192,19 +218,6 @@ def test_label_encoder_fit_transform():
     assert_array_equal(ret, [1, 1, 2, 0])
 
 
-def test_label_encoder_string_labels():
-    """Test LabelEncoder's transform and inverse_transform methods with
-    non-numeric labels"""
-    le = LabelEncoder()
-    le.fit(["paris", "paris", "tokyo", "amsterdam"])
-    assert_array_equal(le.classes_, ["amsterdam", "paris", "tokyo"])
-    assert_array_equal(le.transform(["tokyo", "tokyo", "paris"]),
-                       [2, 2, 1])
-    assert_array_equal(le.inverse_transform([2, 2, 1]),
-                       ["tokyo", "tokyo", "paris"])
-    assert_raises(ValueError, le.transform, ["london"])
-
-
 def test_label_encoder_errors():
     """Check that invalid arguments yield ValueError"""
     le = LabelEncoder()
@@ -212,52 +225,44 @@ def test_label_encoder_errors():
     assert_raises(ValueError, le.inverse_transform, [])
 
 
-def test_label_binarizer_iris():
-    lb = LabelBinarizer()
-    Y = lb.fit_transform(iris.target)
-    clfs = [SGDClassifier().fit(iris.data, Y[:, k])
-            for k in range(len(lb.classes_))]
-    Y_pred = np.array([clf.decision_function(iris.data) for clf in clfs]).T
-    y_pred = lb.inverse_transform(Y_pred)
-    accuracy = np.mean(iris.target == y_pred)
-    y_pred2 = SGDClassifier().fit(iris.data, iris.target).predict(iris.data)
-    accuracy2 = np.mean(iris.target == y_pred2)
-    assert_almost_equal(accuracy, accuracy2)
-
-
-def test_label_binarizer_multilabel_unlabeled():
-    """Check that LabelBinarizer can handle an unlabeled sample"""
-    lb = LabelBinarizer()
-    y = [[1, 2], [1], []]
-    Y = np.array([[1, 1],
-                  [1, 0],
-                  [0, 0]])
-    assert_array_equal(assert_warns(DeprecationWarning,
-                                    lb.fit_transform, y), Y)
-
-
-def test_label_binarize_with_multilabel_indicator():
-    """Check that passing a binary indicator matrix is not noop"""
-
-    classes = np.arange(3)
-    neg_label = -1
-    pos_label = 2
-
-    y = np.array([[0, 1, 0], [1, 1, 1]])
-    expected = np.array([[-1, 2, -1], [2, 2, 2]])
-
-    # With label binarize
-    output = label_binarize(y, classes, multilabel=True, neg_label=neg_label,
-                            pos_label=pos_label)
-    assert_array_equal(output, expected)
-
-    # With the transformer
-    lb = LabelBinarizer(pos_label=pos_label, neg_label=neg_label)
-    output = lb.fit_transform(y)
-    assert_array_equal(output, expected)
+def test_sparse_output_mutlilabel_binarizer():
+    # test input as iterable of iterables
+    inputs = [
+        lambda: [(2, 3), (1,), (1, 2)],
+        lambda: (set([2, 3]), set([1]), set([1, 2])),
+        lambda: iter([iter((2, 3)), iter((1,)), set([1, 2])]),
+    ]
+    indicator_mat = np.array([[0, 1, 1],
+                              [1, 0, 0],
+                              [1, 1, 0]])
 
-    output = lb.fit(y).transform(y)
-    assert_array_equal(output, expected)
+    inverse = inputs[0]()
+    for sparse_output in [True, False]:
+        for inp in inputs:
+            # With fit_tranform
+            mlb = MultiLabelBinarizer(sparse_output=sparse_output)
+            got = mlb.fit_transform(inp())
+            assert_equal(issparse(got), sparse_output)
+            if sparse_output:
+                got = got.toarray()
+            assert_array_equal(indicator_mat, got)
+            assert_array_equal([1, 2, 3], mlb.classes_)
+            assert_equal(mlb.inverse_transform(got), inverse)
+
+            # With fit
+            mlb = MultiLabelBinarizer(sparse_output=sparse_output)
+            got = mlb.fit(inp()).transform(inp())
+            assert_equal(issparse(got), sparse_output)
+            if sparse_output:
+                got = got.toarray()
+            assert_array_equal(indicator_mat, got)
+            assert_array_equal([1, 2, 3], mlb.classes_)
+            assert_equal(mlb.inverse_transform(got), inverse)
+
+    assert_raises(ValueError, mlb.inverse_transform,
+                  csr_matrix(np.array([[0, 1, 1],
+                                       [2, 0, 0],
+                                       [1, 1, 0]])))
 
 
 def test_mutlilabel_binarizer():
@@ -399,3 +404,140 @@ def test_multilabel_binarizer_inverse_validation():
     # Wrong shape
     assert_raises(ValueError, mlb.inverse_transform, np.array([[1]]))
     assert_raises(ValueError, mlb.inverse_transform, np.array([[1, 1, 1]]))
+
+
+def test_label_binarize_with_class_order():
+    out = label_binarize([1, 6], classes=[1, 2, 4, 6])
+    expected = np.array([[1, 0, 0, 0], [0, 0, 0, 1]])
+    assert_array_equal(out, expected)
+
+    # Modified class order
+    out = label_binarize([1, 6], classes=[1, 6, 4, 2])
+    expected = np.array([[1, 0, 0, 0], [0, 1, 0, 0]])
+    assert_array_equal(out, expected)
+
+
+def check_binarized_results(y, classes, pos_label, neg_label, expected):
+    for sparse_output in [True, False]:
+        if ((pos_label == 0 or neg_label != 0) and sparse_output):
+            assert_raises(ValueError, label_binarize, y, classes,
+                          neg_label=neg_label, pos_label=pos_label,
+                          sparse_output=sparse_output)
+            continue
+
+        # check label_binarize
+        binarized = label_binarize(y, classes, neg_label=neg_label,
+                                   pos_label=pos_label,
+                                   sparse_output=sparse_output)
+        assert_array_equal(toarray(binarized), expected)
+        assert_equal(issparse(binarized), sparse_output)
+
+        # check inverse
+        y_type = type_of_target(y)
+        if y_type == "multiclass":
+            inversed = _inverse_binarize_multiclass(binarized, classes=classes)
+
+        else:
+            inversed = _inverse_binarize_thresholding(binarized,
+                                                      output_type=y_type,
+                                                      classes=classes,
+                                                      threshold=((neg_label +
+                                                                 pos_label) /
+                                                                 2.))
+
+        assert_array_equal(toarray(inversed), toarray(y))
+
+        # Check label binarizer
+        lb = LabelBinarizer(neg_label=neg_label, pos_label=pos_label,
+                            sparse_output=sparse_output)
+        binarized = lb.fit_transform(y)
+        assert_array_equal(toarray(binarized), expected)
+        assert_equal(issparse(binarized), sparse_output)
+        inverse_output = lb.inverse_transform(binarized)
+        assert_array_equal(toarray(inverse_output), toarray(y))
+        assert_equal(issparse(inverse_output), issparse(y))
+
+
+def test_label_binarize_binary():
+    y = [0, 1, 0]
+    classes = [0, 1]
+    pos_label = 2
+    neg_label = -1
+    expected = np.array([[2, -1], [-1, 2], [2, -1]])[:, 1].reshape((-1, 1))
+
+    yield check_binarized_results, y, classes, pos_label, neg_label, expected
+
+
+def test_label_binarize_multiclass():
+    y = [0, 1, 2]
+    classes = [0, 1, 2]
+    pos_label = 2
+    neg_label = 0
+    expected = 2 * np.eye(3)
+
+    yield check_binarized_results, y, classes, pos_label, neg_label, expected
+
+    assert_raises(ValueError, label_binarize, y, classes, neg_label=-1,
+                  pos_label=pos_label, sparse_output=True)
+
+
+def test_label_binarize_multilabel():
+    y_seq = [(1,), (0, 1, 2), tuple()]
+    y_ind = np.array([[0, 1, 0], [1, 1, 1], [0, 0, 0]])
+    classes = [0, 1, 2]
+    pos_label = 2
+    neg_label = 0
+    expected = pos_label * y_ind
+    y_sparse = [sparse_matrix(y_ind)
+                for sparse_matrix in [coo_matrix, csc_matrix, csr_matrix,
+                                      dok_matrix, lil_matrix]]
+
+    for y in [y_ind] + y_sparse:
+        yield (check_binarized_results, y, classes, pos_label, neg_label,
+               expected)
+
+    deprecation_message = ("Direct support for sequence of sequences " +
+                           "multilabel representation will be unavailable " +
+                           "from version 0.17. Use sklearn.preprocessing." +
+                           "MultiLabelBinarizer to convert to a label " +
+                           "indicator representation.")
+
+    assert_warns_message(DeprecationWarning, deprecation_message,
+                         check_binarized_results, y_seq, classes, pos_label,
+                         neg_label, expected)
+
+    assert_raises(ValueError, label_binarize, y, classes, neg_label=-1,
+                  pos_label=pos_label, sparse_output=True)
+
+
+def test_deprecation_inverse_binarize_thresholding():
+    deprecation_message = ("Direct support for sequence of sequences " +
+                           "multilabel representation will be unavailable " +
+                           "from version 0.17. Use sklearn.preprocessing." +
+                           "MultiLabelBinarizer to convert to a label " +
+                           "indicator representation.")
+
+    assert_warns_message(DeprecationWarning, deprecation_message,
+                         _inverse_binarize_thresholding,
+                         y=csr_matrix([[1, 0], [0, 1]]),
+                         output_type="multilabel-sequences",
+                         classes=[1, 2], threshold=0)
+
+
+def test_invalid_input_label_binarize():
+    assert_raises(ValueError, label_binarize, [0.5, 2], classes=[1, 2])
+    assert_raises(ValueError, label_binarize, [0, 2], classes=[0, 2],
+                  pos_label=0, neg_label=1)
+    assert_raises(ValueError, label_binarize, [1, 2], classes=[0, 2])
+
+
+def test_inverse_binarize_multiclass():
+    got = _inverse_binarize_multiclass(csr_matrix([[0, 1, 0],
+                                                   [-1, 0, -1],
+                                                   [0, 0, 0]]),
+                                       np.arange(3))
+    assert_array_equal(got, np.array([1, 1, 0]))
+
+if __name__ == "__main__":
+    import nose
+    nose.runmodule()
diff --git a/sklearn/utils/multiclass.py b/sklearn/utils/multiclass.py
index 4dc4bd5ee8a8195bbfe646a4aafeea619879f563..eb8e08ecd6191aaccbbccabeae68ae5778668475 100644
--- a/sklearn/utils/multiclass.py
+++ b/sklearn/utils/multiclass.py
@@ -1,4 +1,4 @@
-# Author: Arnaud Joly, Joel Nothman
+# Author: Arnaud Joly, Joel Nothman, Hamzeh Alsalhi
 #
 # License: BSD 3 clause
 """
@@ -10,6 +10,11 @@ from collections import Sequence
 from itertools import chain
 import warnings
 
+from scipy.sparse import issparse
+from scipy.sparse.base import spmatrix
+from scipy.sparse import dok_matrix
+from scipy.sparse import lil_matrix
+
 import numpy as np
 
 from ..externals.six import string_types
@@ -139,9 +144,18 @@ def is_label_indicator_matrix(y):
     """
     if not (hasattr(y, "shape") and y.ndim == 2 and y.shape[1] > 1):
         return False
-    labels = np.unique(y)
-    return len(labels) <= 2 and (y.dtype.kind in 'biu'  # bool, int, uint
-                                 or _is_integral_float(labels))
+
+    if issparse(y):
+        if isinstance(y, (dok_matrix, lil_matrix)):
+            y = y.tocsr()
+        return (len(y.data) == 0 or np.ptp(y.data) == 0 and
+                (y.dtype.kind in 'biu' or  # bool, int, uint
+                 _is_integral_float(np.unique(y.data))))
+    else:
+        labels = np.unique(y)
+
+        return len(labels) < 3 and (y.dtype.kind in 'biu' or  # bool, int, uint
+                                    _is_integral_float(labels))
 
 
 def is_sequence_of_sequences(y):
@@ -163,7 +177,7 @@ def is_sequence_of_sequences(y):
     try:
         out = (not isinstance(y[0], np.ndarray) and isinstance(y[0], Sequence)
                and not isinstance(y[0], string_types))
-    except IndexError:
+    except (IndexError, TypeError):
         return False
     if out:
         warnings.warn('Direct support for sequence of sequences multilabel '
@@ -256,7 +270,7 @@ def type_of_target(y):
     'multilabel-indicator'
     """
     # XXX: is there a way to duck-type this condition?
-    valid = (isinstance(y, (np.ndarray, Sequence))
+    valid = (isinstance(y, (np.ndarray, Sequence, spmatrix))
              and not isinstance(y, string_types))
     if not valid:
         raise ValueError('Expected array-like (array or non-string sequence), '
diff --git a/sklearn/utils/tests/test_multiclass.py b/sklearn/utils/tests/test_multiclass.py
index e3abdfeea5217f0684a411cb0f54a5d04fc4c6f2..688803e45180233c82e3af8fa863d1efb0331905 100644
--- a/sklearn/utils/tests/test_multiclass.py
+++ b/sklearn/utils/tests/test_multiclass.py
@@ -4,6 +4,13 @@ from functools import partial
 from sklearn.externals.six.moves import xrange
 from sklearn.externals.six import iteritems
 
+from scipy.sparse import issparse
+from scipy.sparse import csc_matrix
+from scipy.sparse import csr_matrix
+from scipy.sparse import coo_matrix
+from scipy.sparse import dok_matrix
+from scipy.sparse import lil_matrix
+
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_equal
 from sklearn.utils.testing import assert_true
@@ -20,17 +27,20 @@ from sklearn.utils.multiclass import type_of_target
 
 EXAMPLES = {
     'multilabel-indicator': [
-        np.random.RandomState(42).randint(2, size=(10, 10)),
-        np.array([[0, 1], [1, 0]]),
-        np.array([[0, 1], [1, 0]], dtype=np.bool),
-        np.array([[0, 1], [1, 0]], dtype=np.int8),
-        np.array([[0, 1], [1, 0]], dtype=np.uint8),
-        np.array([[0, 1], [1, 0]], dtype=np.float),
-        np.array([[0, 1], [1, 0]], dtype=np.float32),
-        np.array([[0, 0], [0, 0]]),
+        # valid when the data is formated as sparse or dense, identified
+        # by CSR format when the testing takes place
+        csr_matrix(np.random.RandomState(42).randint(2, size=(10, 10))),
+        csr_matrix(np.array([[0, 1], [1, 0]])),
+        csr_matrix(np.array([[0, 1], [1, 0]], dtype=np.bool)),
+        csr_matrix(np.array([[0, 1], [1, 0]], dtype=np.int8)),
+        csr_matrix(np.array([[0, 1], [1, 0]], dtype=np.uint8)),
+        csr_matrix(np.array([[0, 1], [1, 0]], dtype=np.float)),
+        csr_matrix(np.array([[0, 1], [1, 0]], dtype=np.float32)),
+        csr_matrix(np.array([[0, 0], [0, 0]])),
+        csr_matrix(np.array([[0, 1]])),
+        # Only valid when data is dense
         np.array([[-1, 1], [1, -1]]),
         np.array([[-3, 3], [3, -3]]),
-        np.array([[0, 1]]),
     ],
     'multilabel-sequences': [
         [[0, 1]],
@@ -196,14 +206,14 @@ def test_unique_labels_non_specific():
 
 @ignore_warnings
 def test_unique_labels_mixed_types():
-    #Mix of multilabel-indicator and multilabel-sequences
+    # Mix of multilabel-indicator and multilabel-sequences
     mix_multilabel_format = product(EXAMPLES["multilabel-indicator"],
                                     EXAMPLES["multilabel-sequences"])
     for y_multilabel, y_multiclass in mix_multilabel_format:
         assert_raises(ValueError, unique_labels, y_multiclass, y_multilabel)
         assert_raises(ValueError, unique_labels, y_multilabel, y_multiclass)
 
-    #Mix with binary or multiclass and multilabel
+    # Mix with binary or multiclass and multilabel
     mix_clf_format = product(EXAMPLES["multilabel-indicator"] +
                              EXAMPLES["multilabel-sequences"],
                              EXAMPLES["multiclass"] +
@@ -239,14 +249,43 @@ def test_is_multilabel():
 
 def test_is_label_indicator_matrix():
     for group, group_examples in iteritems(EXAMPLES):
-        if group == 'multilabel-indicator':
-            assert_, exp = assert_true, 'True'
+        if group in ['multilabel-indicator']:
+            dense_assert_, dense_exp = assert_true, 'True'
         else:
-            assert_, exp = assert_false, 'False'
+            dense_assert_, dense_exp = assert_false, 'False'
+
         for example in group_examples:
-            assert_(is_label_indicator_matrix(example),
-                    msg='is_label_indicator_matrix(%r) should be %s'
-                    % (example, exp))
+            # Only mark explicitly defined sparse examples as valid sparse
+            # multilabel-indicators
+            if group == 'multilabel-indicator' and issparse(example):
+                sparse_assert_, sparse_exp = assert_true, 'True'
+            else:
+                sparse_assert_, sparse_exp = assert_false, 'False'
+
+            if (issparse(example) or
+                (isinstance(example, np.ndarray) and
+                 example.ndim == 2 and
+                 example.dtype.kind in 'biuf' and
+                 example.shape[1] > 0)):
+                    examples_sparse = [sparse_matrix(example)
+                                       for sparse_matrix in [coo_matrix,
+                                                             csc_matrix,
+                                                             csr_matrix,
+                                                             dok_matrix,
+                                                             lil_matrix]]
+                    for exmpl_sparse in examples_sparse:
+                        sparse_assert_(is_label_indicator_matrix(exmpl_sparse),
+                                       msg=('is_label_indicator_matrix(%r)'
+                                       ' should be %s')
+                                       % (exmpl_sparse, sparse_exp))
+
+            # Densify sparse examples before testing
+            if issparse(example):
+                example = example.toarray()
+
+            dense_assert_(is_label_indicator_matrix(example),
+                          msg='is_label_indicator_matrix(%r) should be %s'
+                          % (example, dense_exp))
 
 
 def test_is_sequence_of_sequences():
@@ -274,3 +313,8 @@ def test_type_of_target():
 
     for example in NON_ARRAY_LIKE_EXAMPLES:
         assert_raises(ValueError, type_of_target, example)
+
+
+if __name__ == "__main__":
+    import nose
+    nose.runmodule()