Skip to content
Snippets Groups Projects
Commit fa76ac38 authored by Joel Nothman's avatar Joel Nothman
Browse files

FIX OvR with constant label for non-predict methods

parent aaefdbd3
No related branches found
No related tags found
No related merge requests found
......@@ -128,6 +128,9 @@ class _ConstantPredictor(BaseEstimator):
def decision_function(self, X):
return np.repeat(self.y_, X.shape[0])
def predict_proba(self, X):
return np.repeat([[0, 1]], X.shape[0], axis=0)
class OneVsRestClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
"""One-vs-the-rest (OvR) multiclass/multilabel strategy
......
......@@ -20,7 +20,7 @@ from sklearn.metrics import recall_score
from sklearn.svm import LinearSVC
from sklearn.naive_bayes import MultinomialNB
from sklearn.linear_model import (LinearRegression, Lasso, ElasticNet, Ridge,
Perceptron)
Perceptron, LogisticRegression)
from sklearn.tree import DecisionTreeClassifier
from sklearn.grid_search import GridSearchCV
from sklearn.pipeline import Pipeline
......@@ -63,10 +63,14 @@ def test_ovr_always_present():
X[:5, :] = 0
y = [[int(i >= 5), 2, 3] for i in range(10)]
with warnings.catch_warnings(record=True):
ovr = OneVsRestClassifier(DecisionTreeClassifier())
ovr = OneVsRestClassifier(LogisticRegression())
ovr.fit(X, y)
y_pred = ovr.predict(X)
assert_array_equal(np.array(y_pred), np.array(y))
y_pred = ovr.decision_function(X)
assert_equal(np.unique(y_pred[:, -2:]), 1)
y_pred = ovr.predict_proba(X)
assert_array_equal(y_pred[:, -2:], np.ones((X.shape[0], 2)))
def test_ovr_multilabel():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment