diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index d5452b1e3856fe2f3854c52ef00578e6df4378ef..74070a2c09f7ce7f0cbb4aacf6e7f66803814f58 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -370,7 +370,7 @@ matrix from a list of multi-class labels:: >>> lb = preprocessing.LabelBinarizer() >>> lb.fit([1, 2, 6, 4, 2]) - LabelBinarizer(neg_label=0, pos_label=1) + LabelBinarizer(neg_label=0, pos_label=1, sparse_output=False) >>> lb.classes_ array([1, 2, 4, 6]) >>> lb.transform([1, 6]) diff --git a/sklearn/linear_model/ridge.py b/sklearn/linear_model/ridge.py index f7d513e4581743a31642f084c81a572d330573a3..3d2b188184d7e6403c6f192537ca37bedb96be86 100644 --- a/sklearn/linear_model/ridge.py +++ b/sklearn/linear_model/ridge.py @@ -571,7 +571,7 @@ class RidgeClassifier(LinearClassifierMixin, _BaseRidge): """ self._label_binarizer = LabelBinarizer(pos_label=1, neg_label=-1) Y = self._label_binarizer.fit_transform(y) - if not self._label_binarizer.multilabel_: + if not self._label_binarizer.y_type_.startswith('multilabel'): y = column_or_1d(y, warn=True) if self.class_weight: diff --git a/sklearn/preprocessing/label.py b/sklearn/preprocessing/label.py index 851dc5d9738d9b8bdf3929bb132e9d12458a4cec..6baff66c65f6290d10228860848ae0ce8836efbb 100644 --- a/sklearn/preprocessing/label.py +++ b/sklearn/preprocessing/label.py @@ -3,11 +3,13 @@ # Olivier Grisel <olivier.grisel@ensta.org> # Andreas Mueller <amueller@ais.uni-bonn.de> # Joel Nothman <joel.nothman@gmail.com> +# Hamzeh Alsalhi <ha258@cornell.edu> # License: BSD 3 clause from collections import defaultdict import itertools import array +import warnings import numpy as np import scipy.sparse as sp @@ -15,6 +17,7 @@ import scipy.sparse as sp from ..base import BaseEstimator, TransformerMixin from ..utils.fixes import np_version +from ..utils.fixes import sparse_min_max from ..utils import deprecated, column_or_1d from ..utils.multiclass import unique_labels @@ -188,6 +191,10 @@ class LabelBinarizer(BaseEstimator, TransformerMixin): pos_label : int (default: 1) Value with which positive labels must be encoded. + sparse_output : boolean (default: False) + True if the returned array from transform is desired to be in sparse + CSR format. + Attributes ---------- `classes_` : array of shape [n_class] @@ -195,18 +202,27 @@ class LabelBinarizer(BaseEstimator, TransformerMixin): `multilabel_` : boolean True if the transformer was fitted on a multilabel rather than a - multiclass set of labels. + multiclass set of labels. The multilabel_ attribute is deprecated + and will be removed in 0.18 + + `sparse_input_` : boolean, + True if the input data to transform is given as a sparse matrix, False + otherwise. + + `indicator_matrix_` : str + 'sparse' when the input data to tansform is a multilable-indicator and + is sparse, None otherwise. The indicator_matrix_ attribute is + deprecated as of version 0.16 and will be removed in 0.18 + Examples -------- >>> from sklearn import preprocessing >>> lb = preprocessing.LabelBinarizer() >>> lb.fit([1, 2, 6, 4, 2]) - LabelBinarizer(neg_label=0, pos_label=1) + LabelBinarizer(neg_label=0, pos_label=1, sparse_output=False) >>> lb.classes_ array([1, 2, 4, 6]) - >>> lb.multilabel_ - False >>> lb.transform([1, 6]) array([[1, 0, 0, 0], [0, 0, 0, 1]]) @@ -221,11 +237,9 @@ class LabelBinarizer(BaseEstimator, TransformerMixin): >>> import numpy as np >>> lb.fit(np.array([[0, 1, 1], [1, 0, 0]])) - LabelBinarizer(neg_label=0, pos_label=1) + LabelBinarizer(neg_label=0, pos_label=1, sparse_output=False) >>> lb.classes_ array([0, 1, 2]) - >>> lb.multilabel_ - True See also -------- @@ -233,12 +247,20 @@ class LabelBinarizer(BaseEstimator, TransformerMixin): LabelBinarizer with fixed classes. """ - def __init__(self, neg_label=0, pos_label=1): + def __init__(self, neg_label=0, pos_label=1, sparse_output=False): if neg_label >= pos_label: - raise ValueError("neg_label must be strictly less than pos_label.") + raise ValueError("neg_label={0} must be strictly less than " + "pos_label={1}.".format(neg_label, pos_label)) + + if sparse_output and (pos_label == 0 or neg_label != 0): + raise ValueError("Sparse binarization is only supported with non " + "zero pos_label and zero neg_label, got " + "pos_label={0} and neg_label={1}" + "".format(pos_label, neg_label)) self.neg_label = neg_label self.pos_label = pos_label + self.sparse_output = sparse_output @property @deprecated("Attribute `multilabel` was renamed to `multilabel_` in " @@ -246,6 +268,20 @@ class LabelBinarizer(BaseEstimator, TransformerMixin): def multilabel(self): return self.multilabel_ + @property + @deprecated("Attribute indicator_matrix_ is deprecated and will be " + "removed in 0.17. Use 'y_type_ == 'multilabel-indicator'' " + "instead") + def indicator_matrix_(self): + return self.y_type_ == 'multilabel-indicator' + + @property + @deprecated("Attribute multilabel_ is deprecated and will be removed " + "in 0.17. Use 'y_type_.startswith('multilabel')' " + "instead") + def multilabel_(self): + return self.y_type_.startswith('multilabel') + def _check_fitted(self): if not hasattr(self, "classes_"): raise ValueError("LabelBinarizer was not fitted yet.") @@ -263,13 +299,9 @@ class LabelBinarizer(BaseEstimator, TransformerMixin): ------- self : returns an instance of self. """ - y_type = type_of_target(y) - self.multilabel_ = y_type.startswith('multilabel') - if self.multilabel_: - self.indicator_matrix_ = y_type == 'multilabel-indicator' - + self.y_type_ = type_of_target(y) + self.sparse_input_ = sp.issparse(y) self.classes_ = unique_labels(y) - return self def transform(self, y): @@ -280,35 +312,30 @@ class LabelBinarizer(BaseEstimator, TransformerMixin): Parameters ---------- - y : numpy array of shape (n_samples,) or (n_samples, n_classes) - Target values. The 2-d matrix should only contain 0 and 1, - represents multilabel classification. + y : numpy array or sparse matrix of shape (n_samples,) or + (n_samples, n_classes) Target values. The 2-d matrix should only + contain 0 and 1, represents multilabel classification. Sparse + matrix can be CSR, CSC, COO, DOK, or LIL. Returns ------- - Y : numpy array of shape [n_samples, n_classes] + Y : numpy array or CSR matrix of shape [n_samples, n_classes] Shape will be [n_samples, 1] for binary problems. """ self._check_fitted() - - y_is_multilabel = type_of_target(y).startswith('multilabel') - - if y_is_multilabel and not self.multilabel_: - raise ValueError("The object was not fitted with multilabel" - " input.") - return label_binarize(y, self.classes_, - multilabel=self.multilabel_, pos_label=self.pos_label, - neg_label=self.neg_label) + neg_label=self.neg_label, + sparse_output=self.sparse_output) def inverse_transform(self, Y, threshold=None): """Transform binary labels back to multi-class labels Parameters ---------- - Y : numpy array of shape [n_samples, n_classes] - Target values. + Y : numpy array or sparse matrix with shape [n_samples, n_classes] + Target values. All sparse matrices are converted to CSR before + inverse transformation. threshold : float or None Threshold used in the binary and multi-label cases. @@ -323,9 +350,7 @@ class LabelBinarizer(BaseEstimator, TransformerMixin): Returns ------- - y : numpy array of shape (n_samples,) or (n_samples, n_classes) - Target values. The 2-d matrix should only contain 0 and 1, - represents multilabel classification. + y : numpy array or CSR matrix of shape [n_samples] Target values. Notes ----- @@ -338,30 +363,24 @@ class LabelBinarizer(BaseEstimator, TransformerMixin): self._check_fitted() if threshold is None: - half = (self.pos_label - self.neg_label) / 2.0 - threshold = self.neg_label + half - - if self.multilabel_: - Y = np.array(Y > threshold, dtype=int) - # Return the predictions in the same format as in fit - if self.indicator_matrix_: - # Label indicator matrix format - return Y - else: - # Lists of tuples format - mlb = MultiLabelBinarizer(classes=self.classes_).fit([]) - return mlb.inverse_transform(Y) - - if len(Y.shape) == 1 or Y.shape[1] == 1: - y = np.array(Y.ravel() > threshold, dtype=int) + threshold = (self.pos_label + self.neg_label) / 2. + if self.y_type_ == "multiclass": + y_inv = _inverse_binarize_multiclass(Y, self.classes_) else: - y = Y.argmax(axis=1) + y_inv = _inverse_binarize_thresholding(Y, self.y_type_, + self.classes_, threshold) - return self.classes_[y] + if self.sparse_input_: + y_inv = sp.csr_matrix(y_inv) + elif sp.issparse(y_inv): + y_inv = y_inv.toarray() + return y_inv -def label_binarize(y, classes, multilabel=False, neg_label=0, pos_label=1): + +def label_binarize(y, classes, neg_label=0, pos_label=1, + sparse_output=False, multilabel=None): """Binarize labels in a one-vs-all fashion Several regression and binary classification algorithms are @@ -380,20 +399,18 @@ def label_binarize(y, classes, multilabel=False, neg_label=0, pos_label=1): classes : array-like of shape [n_classes] Uniquely holds the label for each class. - multilabel : boolean - Set to true if y is encoding a multilabel tasks (with a variable - number of label assignements per sample) rather than a multiclass task - where one sample has one and only one label assigned. - - neg_label: int (default: 0) + neg_label : int (default: 0) Value with which negative labels must be encoded. - pos_label: int (default: 1) + pos_label : int (default: 1) Value with which positive labels must be encoded. + sparse_output : boolean (default: False), + Set to true if output binary array is desired in CSR sparse format + Returns ------- - Y : numpy array of shape [n_samples, n_classes] + Y : numpy array or CSR matrix of shape [n_samples, n_classes] Shape will be [n_samples, 1] for binary problems. Examples @@ -419,43 +436,192 @@ def label_binarize(y, classes, multilabel=False, neg_label=0, pos_label=1): See also -------- - label_binarize : function to perform the transform operation of - LabelBinarizer with fixed classes. + LabelBinarizer : class used to wrap the functionality of label_binarize and + allow for fitting to classes independently of the transform operation """ - y_type = type_of_target(y) + if neg_label >= pos_label: + raise ValueError("neg_label={0} must be strictly less than " + "pos_label={1}.".format(neg_label, pos_label)) + + if (sparse_output and (pos_label == 0 or neg_label != 0)): + raise ValueError("Sparse binarization is only supported with non " + "zero pos_label and zero neg_label, got " + "pos_label={0} and neg_label={1}" + "".format(pos_label, neg_label)) + + if multilabel is not None: + warnings.warn("The multilabel parameter is deprecated as of version " + "0.15 and will be removed in 0.17. The parameter is no " + "longer necessary because the value is automatically " + "inferred.", DeprecationWarning) + + # To account for pos_label == 0 in the dense case + pos_switch = pos_label == 0 + if pos_switch: + pos_label = -neg_label - if multilabel or len(classes) > 2: - Y = np.zeros((len(y), len(classes)), dtype=np.int) - else: - Y = np.zeros((len(y), 1), dtype=np.int) + y_type = type_of_target(y) - Y += neg_label + n_samples = y.shape[0] if sp.issparse(y) else len(y) + n_classes = len(classes) + classes = np.asarray(classes) - if multilabel: - if y_type == "multilabel-indicator": - Y[y == 1] = pos_label + if y_type == "binary": + if len(classes) == 1: + Y = np.zeros((len(y), 1), dtype=np.int) + Y += neg_label return Y - elif y_type == "multilabel-sequences": - return MultiLabelBinarizer(classes=classes).fit_transform(y) + elif len(classes) >= 3: + y_type = "multiclass" + + sorted_class = np.sort(classes) + if (y_type == "multilabel-indicator" and classes.size != y.shape[1] or + not set(classes).issuperset(unique_labels(y))): + raise ValueError("classes {0} missmatch with the labels {1}" + "found in the data".format(classes, unique_labels(y))) + + if y_type in ("binary", "multiclass"): + y = column_or_1d(y) + indptr = np.arange(n_samples + 1) + indices = np.searchsorted(sorted_class, y) + data = np.empty_like(indices) + data.fill(pos_label) + + Y = sp.csr_matrix((data, indices, indptr), + shape=(n_samples, n_classes)) + + elif y_type == "multilabel-indicator": + Y = sp.csr_matrix(y) + if pos_label != 1: + data = np.empty_like(Y.data) + data.fill(pos_label) + Y.data = data + + elif y_type == "multilabel-sequences": + Y = MultiLabelBinarizer(classes=classes, + sparse_output=sparse_output).fit_transform(y) + + if sp.issparse(Y): + Y.data[:] = pos_label else: - raise ValueError("y should be in a multilabel format, " - "got %r" % (y,)) + Y[Y == 1] = pos_label + return Y + + if not sparse_output: + Y = Y.toarray() + + if neg_label != 0: + Y[Y == 0] = neg_label + + if pos_switch: + Y[Y == pos_label] = 0 + + # preserve label ordering + if np.any(classes != sorted_class): + indices = np.argsort(classes) + Y = Y[:, indices] + + if y_type == "binary": + Y = Y[:, -1].reshape((-1, 1)) + + return Y + +def _inverse_binarize_multiclass(y, classes): + """Inverse label binarization transformation for multiclass. + + Multiclass uses the maximal score instead of a threshold + """ + classes = np.asarray(classes) + + if sp.issparse(y): + # Find the argmax for each row in y where y is a CSR matrix + + y = y.tocsr() + n_samples, n_outputs = y.shape + outputs = np.arange(n_outputs) + row_max = sparse_min_max(y, 1)[1] + row_nnz = np.diff(y.indptr) + + y_data_repeated_max = np.repeat(row_max, row_nnz) + # picks out all indices obtaining the maximum per row + y_i_all_argmax = np.flatnonzero(y_data_repeated_max == y.data) + + # For corner case where last row has a max of 0 + if row_max[-1] == 0: + y_i_all_argmax = np.append(y_i_all_argmax, [len(y.data)]) + + # Gets the index of the first argmax in each row from y_i_all_argmax + index_first_argmax = np.searchsorted(y_i_all_argmax, y.indptr[:-1]) + # first argmax of each row + y_ind_ext = np.append(y.indices, [0]) + y_i_argmax = y_ind_ext[y_i_all_argmax[index_first_argmax]] + # Handle rows of all 0 + y_i_argmax[np.where(row_nnz == 0)[0]] = 0 + + # Handles rows with max of 0 that contain negative numbers + samples = np.arange(n_samples)[(row_nnz > 0) & + (row_max.ravel() == 0)] + for i in samples: + ind = y.indices[y.indptr[i]:y.indptr[i+1]] + y_i_argmax[i] = classes[np.setdiff1d(outputs, ind)][0] + + return classes[y_i_argmax] else: - y = column_or_1d(y) + return classes.take(y.argmax(axis=1), mode="clip") - if len(classes) == 2: - Y[y == classes[1], 0] = pos_label - return Y - elif len(classes) >= 2: - for i, k in enumerate(classes): - Y[y == k, i] = pos_label - return Y +def _inverse_binarize_thresholding(y, output_type, classes, threshold): + """Inverse label binarization transformation using thresholding.""" + + if output_type == "binary" and y.ndim == 2 and y.shape[1] > 2: + raise ValueError("output_type='binary', but y.shape = {0}". + format(y.shape)) + + if output_type != "binary" and y.shape[1] != len(classes): + raise ValueError("The number of class is not equal to the number of " + "dimension of y.") + classes = np.asarray(classes) + + # Perform thresholding + if sp.issparse(y): + if threshold > 0: + if y.format not in ('csr', 'csc'): + y = y.tocsr() + y.data = np.array(y.data > threshold, dtype=np.int) + y.eliminate_zeros() else: - # Only one class, returns a matrix with all negative labels. - return Y + y = np.array(y.toarray() > threshold, dtype=np.int) + else: + y = np.array(y > threshold, dtype=np.int) + + # Inverse transform data + if output_type == "binary": + if y.ndim == 2 and y.shape[1] == 2: + return classes[y[:, 1]] + else: + if len(classes) == 1: + y = np.empty(len(y), dtype=classes.dtype) + y.fill(classes[0]) + return y + else: + return classes[y.ravel()] + + elif output_type == "multilabel-indicator": + return y + + elif output_type == "multilabel-sequences": + warnings.warn('Direct support for sequence of sequences multilabel ' + 'representation will be unavailable from version 0.17. ' + 'Use sklearn.preprocessing.MultiLabelBinarizer to ' + 'convert to a label indicator representation.', + DeprecationWarning) + mlb = MultiLabelBinarizer(classes=classes).fit([]) + return mlb.inverse_transform(y) + + else: + raise ValueError("{0} format is not supported".format(output_type)) class MultiLabelBinarizer(BaseEstimator, TransformerMixin): @@ -493,8 +659,9 @@ class MultiLabelBinarizer(BaseEstimator, TransformerMixin): ['comedy', 'sci-fi', 'thriller'] """ - def __init__(self, classes=None): + def __init__(self, classes=None, sparse_output=False): self.classes = classes + self.sparse_output = sparse_output def fit(self, y): """Fit the label sets binarizer, storing `classes_` @@ -531,7 +698,7 @@ class MultiLabelBinarizer(BaseEstimator, TransformerMixin): Returns ------- - y_indicator : array, shape (n_samples, n_classes) + y_indicator : array or CSR matrix, shape (n_samples, n_classes) A matrix such that `y_indicator[i, j] = 1` iff `classes_[j]` is in `y[i]`, and 0 otherwise. """ @@ -552,7 +719,11 @@ class MultiLabelBinarizer(BaseEstimator, TransformerMixin): class_mapping[:] = tmp self.classes_, inverse = np.unique(class_mapping, return_inverse=True) yt.indices = np.take(inverse, yt.indices) - return yt.toarray() + + if not self.sparse_output: + yt = yt.toarray() + + return yt def transform(self, y): """Transform the given label sets @@ -566,13 +737,17 @@ class MultiLabelBinarizer(BaseEstimator, TransformerMixin): Returns ------- - y_indicator : array, shape (n_samples, n_classes) + y_indicator : array or CSR matrix, shape (n_samples, n_classes) A matrix such that `y_indicator[i, j] = 1` iff `classes_[j]` is in `y[i]`, and 0 otherwise. """ class_to_index = dict(zip(self.classes_, range(len(self.classes_)))) yt = self._transform(y, class_to_index) - return yt.toarray() + + if not self.sparse_output: + yt = yt.toarray() + + return yt def _transform(self, y, class_mapping): """Transforms the label sets with a given mapping @@ -603,7 +778,7 @@ class MultiLabelBinarizer(BaseEstimator, TransformerMixin): Parameters ---------- - yt : array of shape (n_samples, n_classes) + yt : array or sparse matrix of shape (n_samples, n_classes) A matrix containing only 1s ands 0s. Returns @@ -612,13 +787,20 @@ class MultiLabelBinarizer(BaseEstimator, TransformerMixin): The set of labels for each sample such that `y[i]` consists of `classes_[j]` for each `yt[i, j] == 1`. """ - yt = np.asarray(yt) if yt.shape[1] != len(self.classes_): raise ValueError('Expected indicator for {0} classes, but got {1}' .format(len(self.classes_), yt.shape[1])) - unexpected = np.setdiff1d(yt, [0, 1]) - if len(unexpected) > 0: - raise ValueError('Expected only 0s and 1s in label indicator. ' - 'Also got {0}'.format(unexpected)) - return [tuple(self.classes_.compress(indicators)) for indicators in yt] + if sp.issparse(yt): + yt = yt.tocsr() + if len(yt.data) != 0 and len(np.setdiff1d(yt.data, [0, 1])) > 0: + raise ValueError('Expected only 0s and 1s in label indicator.') + return [tuple(self.classes_.take(yt.indices[start:end])) + for start, end in zip(yt.indptr[:-1], yt.indptr[1:])] + else: + unexpected = np.setdiff1d(yt, [0, 1]) + if len(unexpected) > 0: + raise ValueError('Expected only 0s and 1s in label indicator. ' + 'Also got {0}'.format(unexpected)) + return [tuple(self.classes_.compress(indicators)) for indicators + in yt] diff --git a/sklearn/preprocessing/tests/test_label.py b/sklearn/preprocessing/tests/test_label.py index 17619cab7915aef1c3da3934ca76851df0ef3228..2cc786605f3808b1214573c630db1eb04856831e 100644 --- a/sklearn/preprocessing/tests/test_label.py +++ b/sklearn/preprocessing/tests/test_label.py @@ -1,5 +1,14 @@ import numpy as np +from scipy.sparse import issparse +from scipy.sparse import coo_matrix +from scipy.sparse import csc_matrix +from scipy.sparse import csr_matrix +from scipy.sparse import dok_matrix +from scipy.sparse import lil_matrix + +from sklearn.utils.multiclass import type_of_target + from sklearn.utils.testing import assert_almost_equal from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_equal @@ -7,6 +16,7 @@ from sklearn.utils.testing import assert_raises from sklearn.utils.testing import assert_true from sklearn.utils.testing import assert_false from sklearn.utils.testing import assert_warns +from sklearn.utils.testing import assert_warns_message from sklearn.utils.testing import ignore_warnings from sklearn.preprocessing.label import LabelBinarizer @@ -14,6 +24,8 @@ from sklearn.preprocessing.label import MultiLabelBinarizer from sklearn.preprocessing.label import LabelEncoder from sklearn.preprocessing.label import label_binarize +from sklearn.preprocessing.label import _inverse_binarize_thresholding +from sklearn.preprocessing.label import _inverse_binarize_multiclass from sklearn import datasets from sklearn.linear_model.stochastic_gradient import SGDClassifier @@ -30,14 +42,28 @@ def toarray(a): def test_label_binarizer(): lb = LabelBinarizer() + # one-class case defaults to negative label + inp = ["pos", "pos", "pos", "pos"] + expected = np.array([[0, 0, 0, 0]]).T + got = lb.fit_transform(inp) + assert_false(assert_warns(DeprecationWarning, getattr, lb, "multilabel_")) + assert_array_equal(lb.classes_, ["pos"]) + assert_array_equal(expected, got) + assert_array_equal(lb.inverse_transform(got), inp) + # two-class case inp = ["neg", "pos", "pos", "neg"] expected = np.array([[0, 1, 1, 0]]).T got = lb.fit_transform(inp) - assert_false(lb.multilabel_) + assert_false(assert_warns(DeprecationWarning, getattr, lb, "multilabel_")) assert_array_equal(lb.classes_, ["neg", "pos"]) assert_array_equal(expected, got) - assert_array_equal(lb.inverse_transform(got), inp) + + to_invert = np.array([[1, 0], + [0, 1], + [0, 1], + [1, 0]]) + assert_array_equal(lb.inverse_transform(to_invert), inp) # multi-class case inp = ["spam", "ham", "eggs", "ham", "0"] @@ -48,7 +74,7 @@ def test_label_binarizer(): [1, 0, 0, 0]]) got = lb.fit_transform(inp) assert_array_equal(lb.classes_, ['0', 'eggs', 'ham', 'spam']) - assert_false(lb.multilabel_) + assert_false(assert_warns(DeprecationWarning, getattr, lb, "multilabel_")) assert_array_equal(expected, got) assert_array_equal(lb.inverse_transform(got), inp) @@ -70,10 +96,13 @@ def test_label_binarizer_column_y(): out_2 = lb_2.fit_transform(inp_array) assert_array_equal(out_1, multilabel_indicator) - assert_true(lb_1.multilabel_) + assert_true(assert_warns(DeprecationWarning, getattr, lb_1, "multilabel_")) + assert_false(assert_warns(DeprecationWarning, getattr, lb_1, + "indicator_matrix_")) assert_array_equal(out_2, binaryclass_array) - assert_false(lb_2.multilabel_) + assert_false(assert_warns(DeprecationWarning, getattr, lb_2, + "multilabel_")) # second for multiclass classification vs multi-label with multiple # classes @@ -90,23 +119,26 @@ def test_label_binarizer_column_y(): out_2 = lb_2.fit_transform(inp_array) assert_array_equal(out_1, out_2) - assert_true(lb_1.multilabel_) + assert_true(assert_warns(DeprecationWarning, getattr, lb_1, "multilabel_")) assert_array_equal(out_2, indicator) - assert_false(lb_2.multilabel_) + assert_false(assert_warns(DeprecationWarning, getattr, lb_2, + "multilabel_")) def test_label_binarizer_set_label_encoding(): - lb = LabelBinarizer(neg_label=-2, pos_label=2) + lb = LabelBinarizer(neg_label=-2, pos_label=0) - # two-class case + # two-class case with pos_label=0 inp = np.array([0, 1, 1, 0]) - expected = np.array([[-2, 2, 2, -2]]).T + expected = np.array([[-2, 0, 0, -2]]).T got = lb.fit_transform(inp) - assert_false(lb.multilabel_) + assert_false(assert_warns(DeprecationWarning, getattr, lb, "multilabel_")) assert_array_equal(expected, got) assert_array_equal(lb.inverse_transform(got), inp) + lb = LabelBinarizer(neg_label=-2, pos_label=2) + # multi-class case inp = np.array([3, 2, 1, 2, 0]) expected = np.array([[-2, -2, -2, +2], @@ -115,59 +147,53 @@ def test_label_binarizer_set_label_encoding(): [-2, -2, +2, -2], [+2, -2, -2, -2]]) got = lb.fit_transform(inp) - assert_false(lb.multilabel_) + assert_false(assert_warns(DeprecationWarning, getattr, lb, "multilabel_")) assert_array_equal(expected, got) assert_array_equal(lb.inverse_transform(got), inp) -def test_label_binarizer_multilabel(): - lb = LabelBinarizer() - - # test input as lists of tuples - inp = [(2, 3), (1,), (1, 2)] - indicator_mat = np.array([[0, 1, 1], - [1, 0, 0], - [1, 1, 0]]) - got = assert_warns(DeprecationWarning, lb.fit_transform, inp) - assert_true(lb.multilabel_) - assert_array_equal(indicator_mat, got) - assert_equal(lb.inverse_transform(got), inp) - - # test input as label indicator matrix - lb.fit(indicator_mat) - assert_array_equal(indicator_mat, - lb.inverse_transform(indicator_mat)) - - # regression test for the two-class multilabel case - lb = LabelBinarizer() - inp = [[1, 0], [0], [1], [0, 1]] - expected = np.array([[1, 1], - [1, 0], - [0, 1], - [1, 1]]) - got = assert_warns(DeprecationWarning, lb.fit_transform, inp) - assert_true(lb.multilabel_) - assert_array_equal(expected, got) - assert_equal([set(x) for x in lb.inverse_transform(got)], - [set(x) for x in inp]) - - +@ignore_warnings def test_label_binarizer_errors(): """Check that invalid arguments yield ValueError""" one_class = np.array([0, 0, 0, 0]) lb = LabelBinarizer().fit(one_class) - assert_false(lb.multilabel_) + assert_false(assert_warns(DeprecationWarning, getattr, lb, "multilabel_")) - multi_label = np.array([[0, 0, 1, 0], [1, 0, 1, 0]]) + multi_label = [(2, 3), (0,), (0, 2)] assert_raises(ValueError, lb.transform, multi_label) lb = LabelBinarizer() assert_raises(ValueError, lb.transform, []) assert_raises(ValueError, lb.inverse_transform, []) + y = np.array([[0, 1, 0], [1, 1, 1]]) + classes = np.arange(3) + assert_raises(ValueError, label_binarize, y, classes, multilabel=True, + neg_label=2, pos_label=1) + assert_raises(ValueError, label_binarize, y, classes, multilabel=True, + neg_label=2, pos_label=2) + assert_raises(ValueError, LabelBinarizer, neg_label=2, pos_label=1) assert_raises(ValueError, LabelBinarizer, neg_label=2, pos_label=2) + assert_raises(ValueError, LabelBinarizer, neg_label=1, pos_label=2, + sparse_output=True) + + # Fail on y_type + assert_raises(ValueError, _inverse_binarize_thresholding, + y=csr_matrix([[1, 2], [2, 1]]), output_type="foo", + classes=[1, 2], threshold=0) + + # Fail on the number of classes + assert_raises(ValueError, _inverse_binarize_thresholding, + y=csr_matrix([[1, 2], [2, 1]]), output_type="foo", + classes=[1, 2, 3], threshold=0) + + # Fail on the dimension of 'binary' + assert_raises(ValueError, _inverse_binarize_thresholding, + y=np.array([[1, 2, 3], [2, 1, 3]]), output_type="binary", + classes=[1, 2, 3], threshold=0) + def test_label_encoder(): """Test LabelEncoder's transform and inverse_transform methods""" @@ -192,19 +218,6 @@ def test_label_encoder_fit_transform(): assert_array_equal(ret, [1, 1, 2, 0]) -def test_label_encoder_string_labels(): - """Test LabelEncoder's transform and inverse_transform methods with - non-numeric labels""" - le = LabelEncoder() - le.fit(["paris", "paris", "tokyo", "amsterdam"]) - assert_array_equal(le.classes_, ["amsterdam", "paris", "tokyo"]) - assert_array_equal(le.transform(["tokyo", "tokyo", "paris"]), - [2, 2, 1]) - assert_array_equal(le.inverse_transform([2, 2, 1]), - ["tokyo", "tokyo", "paris"]) - assert_raises(ValueError, le.transform, ["london"]) - - def test_label_encoder_errors(): """Check that invalid arguments yield ValueError""" le = LabelEncoder() @@ -212,52 +225,44 @@ def test_label_encoder_errors(): assert_raises(ValueError, le.inverse_transform, []) -def test_label_binarizer_iris(): - lb = LabelBinarizer() - Y = lb.fit_transform(iris.target) - clfs = [SGDClassifier().fit(iris.data, Y[:, k]) - for k in range(len(lb.classes_))] - Y_pred = np.array([clf.decision_function(iris.data) for clf in clfs]).T - y_pred = lb.inverse_transform(Y_pred) - accuracy = np.mean(iris.target == y_pred) - y_pred2 = SGDClassifier().fit(iris.data, iris.target).predict(iris.data) - accuracy2 = np.mean(iris.target == y_pred2) - assert_almost_equal(accuracy, accuracy2) - - -def test_label_binarizer_multilabel_unlabeled(): - """Check that LabelBinarizer can handle an unlabeled sample""" - lb = LabelBinarizer() - y = [[1, 2], [1], []] - Y = np.array([[1, 1], - [1, 0], - [0, 0]]) - assert_array_equal(assert_warns(DeprecationWarning, - lb.fit_transform, y), Y) - - -def test_label_binarize_with_multilabel_indicator(): - """Check that passing a binary indicator matrix is not noop""" - - classes = np.arange(3) - neg_label = -1 - pos_label = 2 - - y = np.array([[0, 1, 0], [1, 1, 1]]) - expected = np.array([[-1, 2, -1], [2, 2, 2]]) - - # With label binarize - output = label_binarize(y, classes, multilabel=True, neg_label=neg_label, - pos_label=pos_label) - assert_array_equal(output, expected) - - # With the transformer - lb = LabelBinarizer(pos_label=pos_label, neg_label=neg_label) - output = lb.fit_transform(y) - assert_array_equal(output, expected) +def test_sparse_output_mutlilabel_binarizer(): + # test input as iterable of iterables + inputs = [ + lambda: [(2, 3), (1,), (1, 2)], + lambda: (set([2, 3]), set([1]), set([1, 2])), + lambda: iter([iter((2, 3)), iter((1,)), set([1, 2])]), + ] + indicator_mat = np.array([[0, 1, 1], + [1, 0, 0], + [1, 1, 0]]) - output = lb.fit(y).transform(y) - assert_array_equal(output, expected) + inverse = inputs[0]() + for sparse_output in [True, False]: + for inp in inputs: + # With fit_tranform + mlb = MultiLabelBinarizer(sparse_output=sparse_output) + got = mlb.fit_transform(inp()) + assert_equal(issparse(got), sparse_output) + if sparse_output: + got = got.toarray() + assert_array_equal(indicator_mat, got) + assert_array_equal([1, 2, 3], mlb.classes_) + assert_equal(mlb.inverse_transform(got), inverse) + + # With fit + mlb = MultiLabelBinarizer(sparse_output=sparse_output) + got = mlb.fit(inp()).transform(inp()) + assert_equal(issparse(got), sparse_output) + if sparse_output: + got = got.toarray() + assert_array_equal(indicator_mat, got) + assert_array_equal([1, 2, 3], mlb.classes_) + assert_equal(mlb.inverse_transform(got), inverse) + + assert_raises(ValueError, mlb.inverse_transform, + csr_matrix(np.array([[0, 1, 1], + [2, 0, 0], + [1, 1, 0]]))) def test_mutlilabel_binarizer(): @@ -399,3 +404,140 @@ def test_multilabel_binarizer_inverse_validation(): # Wrong shape assert_raises(ValueError, mlb.inverse_transform, np.array([[1]])) assert_raises(ValueError, mlb.inverse_transform, np.array([[1, 1, 1]])) + + +def test_label_binarize_with_class_order(): + out = label_binarize([1, 6], classes=[1, 2, 4, 6]) + expected = np.array([[1, 0, 0, 0], [0, 0, 0, 1]]) + assert_array_equal(out, expected) + + # Modified class order + out = label_binarize([1, 6], classes=[1, 6, 4, 2]) + expected = np.array([[1, 0, 0, 0], [0, 1, 0, 0]]) + assert_array_equal(out, expected) + + +def check_binarized_results(y, classes, pos_label, neg_label, expected): + for sparse_output in [True, False]: + if ((pos_label == 0 or neg_label != 0) and sparse_output): + assert_raises(ValueError, label_binarize, y, classes, + neg_label=neg_label, pos_label=pos_label, + sparse_output=sparse_output) + continue + + # check label_binarize + binarized = label_binarize(y, classes, neg_label=neg_label, + pos_label=pos_label, + sparse_output=sparse_output) + assert_array_equal(toarray(binarized), expected) + assert_equal(issparse(binarized), sparse_output) + + # check inverse + y_type = type_of_target(y) + if y_type == "multiclass": + inversed = _inverse_binarize_multiclass(binarized, classes=classes) + + else: + inversed = _inverse_binarize_thresholding(binarized, + output_type=y_type, + classes=classes, + threshold=((neg_label + + pos_label) / + 2.)) + + assert_array_equal(toarray(inversed), toarray(y)) + + # Check label binarizer + lb = LabelBinarizer(neg_label=neg_label, pos_label=pos_label, + sparse_output=sparse_output) + binarized = lb.fit_transform(y) + assert_array_equal(toarray(binarized), expected) + assert_equal(issparse(binarized), sparse_output) + inverse_output = lb.inverse_transform(binarized) + assert_array_equal(toarray(inverse_output), toarray(y)) + assert_equal(issparse(inverse_output), issparse(y)) + + +def test_label_binarize_binary(): + y = [0, 1, 0] + classes = [0, 1] + pos_label = 2 + neg_label = -1 + expected = np.array([[2, -1], [-1, 2], [2, -1]])[:, 1].reshape((-1, 1)) + + yield check_binarized_results, y, classes, pos_label, neg_label, expected + + +def test_label_binarize_multiclass(): + y = [0, 1, 2] + classes = [0, 1, 2] + pos_label = 2 + neg_label = 0 + expected = 2 * np.eye(3) + + yield check_binarized_results, y, classes, pos_label, neg_label, expected + + assert_raises(ValueError, label_binarize, y, classes, neg_label=-1, + pos_label=pos_label, sparse_output=True) + + +def test_label_binarize_multilabel(): + y_seq = [(1,), (0, 1, 2), tuple()] + y_ind = np.array([[0, 1, 0], [1, 1, 1], [0, 0, 0]]) + classes = [0, 1, 2] + pos_label = 2 + neg_label = 0 + expected = pos_label * y_ind + y_sparse = [sparse_matrix(y_ind) + for sparse_matrix in [coo_matrix, csc_matrix, csr_matrix, + dok_matrix, lil_matrix]] + + for y in [y_ind] + y_sparse: + yield (check_binarized_results, y, classes, pos_label, neg_label, + expected) + + deprecation_message = ("Direct support for sequence of sequences " + + "multilabel representation will be unavailable " + + "from version 0.17. Use sklearn.preprocessing." + + "MultiLabelBinarizer to convert to a label " + + "indicator representation.") + + assert_warns_message(DeprecationWarning, deprecation_message, + check_binarized_results, y_seq, classes, pos_label, + neg_label, expected) + + assert_raises(ValueError, label_binarize, y, classes, neg_label=-1, + pos_label=pos_label, sparse_output=True) + + +def test_deprecation_inverse_binarize_thresholding(): + deprecation_message = ("Direct support for sequence of sequences " + + "multilabel representation will be unavailable " + + "from version 0.17. Use sklearn.preprocessing." + + "MultiLabelBinarizer to convert to a label " + + "indicator representation.") + + assert_warns_message(DeprecationWarning, deprecation_message, + _inverse_binarize_thresholding, + y=csr_matrix([[1, 0], [0, 1]]), + output_type="multilabel-sequences", + classes=[1, 2], threshold=0) + + +def test_invalid_input_label_binarize(): + assert_raises(ValueError, label_binarize, [0.5, 2], classes=[1, 2]) + assert_raises(ValueError, label_binarize, [0, 2], classes=[0, 2], + pos_label=0, neg_label=1) + assert_raises(ValueError, label_binarize, [1, 2], classes=[0, 2]) + + +def test_inverse_binarize_multiclass(): + got = _inverse_binarize_multiclass(csr_matrix([[0, 1, 0], + [-1, 0, -1], + [0, 0, 0]]), + np.arange(3)) + assert_array_equal(got, np.array([1, 1, 0])) + +if __name__ == "__main__": + import nose + nose.runmodule() diff --git a/sklearn/utils/multiclass.py b/sklearn/utils/multiclass.py index 4dc4bd5ee8a8195bbfe646a4aafeea619879f563..eb8e08ecd6191aaccbbccabeae68ae5778668475 100644 --- a/sklearn/utils/multiclass.py +++ b/sklearn/utils/multiclass.py @@ -1,4 +1,4 @@ -# Author: Arnaud Joly, Joel Nothman +# Author: Arnaud Joly, Joel Nothman, Hamzeh Alsalhi # # License: BSD 3 clause """ @@ -10,6 +10,11 @@ from collections import Sequence from itertools import chain import warnings +from scipy.sparse import issparse +from scipy.sparse.base import spmatrix +from scipy.sparse import dok_matrix +from scipy.sparse import lil_matrix + import numpy as np from ..externals.six import string_types @@ -139,9 +144,18 @@ def is_label_indicator_matrix(y): """ if not (hasattr(y, "shape") and y.ndim == 2 and y.shape[1] > 1): return False - labels = np.unique(y) - return len(labels) <= 2 and (y.dtype.kind in 'biu' # bool, int, uint - or _is_integral_float(labels)) + + if issparse(y): + if isinstance(y, (dok_matrix, lil_matrix)): + y = y.tocsr() + return (len(y.data) == 0 or np.ptp(y.data) == 0 and + (y.dtype.kind in 'biu' or # bool, int, uint + _is_integral_float(np.unique(y.data)))) + else: + labels = np.unique(y) + + return len(labels) < 3 and (y.dtype.kind in 'biu' or # bool, int, uint + _is_integral_float(labels)) def is_sequence_of_sequences(y): @@ -163,7 +177,7 @@ def is_sequence_of_sequences(y): try: out = (not isinstance(y[0], np.ndarray) and isinstance(y[0], Sequence) and not isinstance(y[0], string_types)) - except IndexError: + except (IndexError, TypeError): return False if out: warnings.warn('Direct support for sequence of sequences multilabel ' @@ -256,7 +270,7 @@ def type_of_target(y): 'multilabel-indicator' """ # XXX: is there a way to duck-type this condition? - valid = (isinstance(y, (np.ndarray, Sequence)) + valid = (isinstance(y, (np.ndarray, Sequence, spmatrix)) and not isinstance(y, string_types)) if not valid: raise ValueError('Expected array-like (array or non-string sequence), ' diff --git a/sklearn/utils/tests/test_multiclass.py b/sklearn/utils/tests/test_multiclass.py index e3abdfeea5217f0684a411cb0f54a5d04fc4c6f2..688803e45180233c82e3af8fa863d1efb0331905 100644 --- a/sklearn/utils/tests/test_multiclass.py +++ b/sklearn/utils/tests/test_multiclass.py @@ -4,6 +4,13 @@ from functools import partial from sklearn.externals.six.moves import xrange from sklearn.externals.six import iteritems +from scipy.sparse import issparse +from scipy.sparse import csc_matrix +from scipy.sparse import csr_matrix +from scipy.sparse import coo_matrix +from scipy.sparse import dok_matrix +from scipy.sparse import lil_matrix + from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_true @@ -20,17 +27,20 @@ from sklearn.utils.multiclass import type_of_target EXAMPLES = { 'multilabel-indicator': [ - np.random.RandomState(42).randint(2, size=(10, 10)), - np.array([[0, 1], [1, 0]]), - np.array([[0, 1], [1, 0]], dtype=np.bool), - np.array([[0, 1], [1, 0]], dtype=np.int8), - np.array([[0, 1], [1, 0]], dtype=np.uint8), - np.array([[0, 1], [1, 0]], dtype=np.float), - np.array([[0, 1], [1, 0]], dtype=np.float32), - np.array([[0, 0], [0, 0]]), + # valid when the data is formated as sparse or dense, identified + # by CSR format when the testing takes place + csr_matrix(np.random.RandomState(42).randint(2, size=(10, 10))), + csr_matrix(np.array([[0, 1], [1, 0]])), + csr_matrix(np.array([[0, 1], [1, 0]], dtype=np.bool)), + csr_matrix(np.array([[0, 1], [1, 0]], dtype=np.int8)), + csr_matrix(np.array([[0, 1], [1, 0]], dtype=np.uint8)), + csr_matrix(np.array([[0, 1], [1, 0]], dtype=np.float)), + csr_matrix(np.array([[0, 1], [1, 0]], dtype=np.float32)), + csr_matrix(np.array([[0, 0], [0, 0]])), + csr_matrix(np.array([[0, 1]])), + # Only valid when data is dense np.array([[-1, 1], [1, -1]]), np.array([[-3, 3], [3, -3]]), - np.array([[0, 1]]), ], 'multilabel-sequences': [ [[0, 1]], @@ -196,14 +206,14 @@ def test_unique_labels_non_specific(): @ignore_warnings def test_unique_labels_mixed_types(): - #Mix of multilabel-indicator and multilabel-sequences + # Mix of multilabel-indicator and multilabel-sequences mix_multilabel_format = product(EXAMPLES["multilabel-indicator"], EXAMPLES["multilabel-sequences"]) for y_multilabel, y_multiclass in mix_multilabel_format: assert_raises(ValueError, unique_labels, y_multiclass, y_multilabel) assert_raises(ValueError, unique_labels, y_multilabel, y_multiclass) - #Mix with binary or multiclass and multilabel + # Mix with binary or multiclass and multilabel mix_clf_format = product(EXAMPLES["multilabel-indicator"] + EXAMPLES["multilabel-sequences"], EXAMPLES["multiclass"] + @@ -239,14 +249,43 @@ def test_is_multilabel(): def test_is_label_indicator_matrix(): for group, group_examples in iteritems(EXAMPLES): - if group == 'multilabel-indicator': - assert_, exp = assert_true, 'True' + if group in ['multilabel-indicator']: + dense_assert_, dense_exp = assert_true, 'True' else: - assert_, exp = assert_false, 'False' + dense_assert_, dense_exp = assert_false, 'False' + for example in group_examples: - assert_(is_label_indicator_matrix(example), - msg='is_label_indicator_matrix(%r) should be %s' - % (example, exp)) + # Only mark explicitly defined sparse examples as valid sparse + # multilabel-indicators + if group == 'multilabel-indicator' and issparse(example): + sparse_assert_, sparse_exp = assert_true, 'True' + else: + sparse_assert_, sparse_exp = assert_false, 'False' + + if (issparse(example) or + (isinstance(example, np.ndarray) and + example.ndim == 2 and + example.dtype.kind in 'biuf' and + example.shape[1] > 0)): + examples_sparse = [sparse_matrix(example) + for sparse_matrix in [coo_matrix, + csc_matrix, + csr_matrix, + dok_matrix, + lil_matrix]] + for exmpl_sparse in examples_sparse: + sparse_assert_(is_label_indicator_matrix(exmpl_sparse), + msg=('is_label_indicator_matrix(%r)' + ' should be %s') + % (exmpl_sparse, sparse_exp)) + + # Densify sparse examples before testing + if issparse(example): + example = example.toarray() + + dense_assert_(is_label_indicator_matrix(example), + msg='is_label_indicator_matrix(%r) should be %s' + % (example, dense_exp)) def test_is_sequence_of_sequences(): @@ -274,3 +313,8 @@ def test_type_of_target(): for example in NON_ARRAY_LIKE_EXAMPLES: assert_raises(ValueError, type_of_target, example) + + +if __name__ == "__main__": + import nose + nose.runmodule()