Skip to content
Snippets Groups Projects
Commit ecb869cc authored by Mathieu Blondel
Browse files

Support multilabel case in LabelBinarizer.

parent 138e688e
Branches
Tags
No related merge requests found
...@@ -128,6 +128,10 @@ class Binarizer(BaseEstimator): ...@@ -128,6 +128,10 @@ class Binarizer(BaseEstimator):
return X return X
def _is_multilabel(y):
return isinstance(y[0], tuple) or isinstance(y[0], list)
class LabelBinarizer(BaseEstimator, TransformerMixin): class LabelBinarizer(BaseEstimator, TransformerMixin):
"""Binarize labels in a one-vs-all fashion. """Binarize labels in a one-vs-all fashion.
...@@ -160,6 +164,10 @@ class LabelBinarizer(BaseEstimator, TransformerMixin): ...@@ -160,6 +164,10 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
>>> clf.transform([1, 6]) >>> clf.transform([1, 6])
array([[ 1., 0., 0., 0.], array([[ 1., 0., 0., 0.],
[ 0., 0., 0., 1.]]) [ 0., 0., 0., 1.]])
>>> clf.fit_transform([(1,2),(3,)])
array([[ 1., 1., 0.],
[ 0., 0., 1.]])
""" """
def fit(self, y): def fit(self, y):
...@@ -174,6 +182,10 @@ class LabelBinarizer(BaseEstimator, TransformerMixin): ...@@ -174,6 +182,10 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
------- -------
self : returns an instance of self. self : returns an instance of self.
""" """
self.multilabel = _is_multilabel(y)
if self.multilabel:
self.classes_ = np.unique(reduce(lambda a,b:a+b, y))
else:
self.classes_ = np.unique(y) self.classes_ = np.unique(y)
return self return self
...@@ -192,13 +204,30 @@ class LabelBinarizer(BaseEstimator, TransformerMixin): ...@@ -192,13 +204,30 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
------- -------
Y : numpy array of shape [n_samples, n_classes] Y : numpy array of shape [n_samples, n_classes]
""" """
if len(self.classes_) == 2: if len(self.classes_) == 2:
Y = np.zeros((len(y), 1)) Y = np.zeros((len(y), 1))
else:
Y = np.zeros((len(y), len(self.classes_)))
if self.multilabel:
if not _is_multilabel(y):
raise ValueError, "y should be a list of label lists/tuples"
# inverse map: label => column index
imap = dict((v,k) for k,v in enumerate(self.classes_))
for i, label_tuple in enumerate(y):
for label in label_tuple:
Y[i, imap[label]] = 1
return Y
elif len(self.classes_) == 2:
Y[y == self.classes_[1], 0] = 1 Y[y == self.classes_[1], 0] = 1
return Y return Y
elif len(self.classes_) >= 2: elif len(self.classes_) >= 2:
Y = np.zeros((len(y), len(self.classes_)))
for i, k in enumerate(self.classes_): for i, k in enumerate(self.classes_):
Y[y == k, i] = 1 Y[y == k, i] = 1
return Y return Y
...@@ -225,8 +254,15 @@ class LabelBinarizer(BaseEstimator, TransformerMixin): ...@@ -225,8 +254,15 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
this allows to use the output of a linear model's decision_function this allows to use the output of a linear model's decision_function
method directly as the input of inverse_transform. method directly as the input of inverse_transform.
""" """
if self.multilabel:
Y = np.array(Y > 0, dtype=int)
return [tuple(self.classes_[np.flatnonzero(Y[i])])
for i in range(Y.shape[0])]
if len(Y.shape) == 1 or Y.shape[1] == 1: if len(Y.shape) == 1 or Y.shape[1] == 1:
y = np.array(Y.ravel() > 0, dtype=int) y = np.array(Y.ravel() > 0, dtype=int)
else: else:
y = Y.argmax(axis=1) y = Y.argmax(axis=1)
return self.classes_[y] return self.classes_[y]
...@@ -150,6 +150,17 @@ def test_label_binarizer(): ...@@ -150,6 +150,17 @@ def test_label_binarizer():
assert_array_equal(expected, got) assert_array_equal(expected, got)
assert_array_equal(lb.inverse_transform(got), inp) assert_array_equal(lb.inverse_transform(got), inp)
def test_label_binarizer_multilabel():
    """Round-trip tuples of labels through LabelBinarizer.

    Fits and transforms a small multilabel target, compares the result
    with the expected indicator matrix, then checks that
    ``inverse_transform`` gives back the original label tuples.
    """
    binarizer = LabelBinarizer()
    samples = [(2, 3), (1,), (1, 2)]
    # One row per sample, one column per distinct label (1, 2, 3).
    indicator = np.array([[0, 1, 1],
                          [1, 0, 0],
                          [1, 1, 0]])
    encoded = binarizer.fit_transform(samples)
    assert_array_equal(indicator, encoded)
    assert_equal(binarizer.inverse_transform(encoded), samples)
def test_label_binarizer_iris(): def test_label_binarizer_iris():
lb = LabelBinarizer() lb = LabelBinarizer()
Y = lb.fit_transform(iris.target) Y = lb.fit_transform(iris.target)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment