diff --git a/scikits/learn/svm/base.py b/scikits/learn/svm/base.py
index 1f21c3d670f31070dc3ad6ac344728d0f5e99730..7e1da0ba3526ec325b63fd3091e5dc10de782a31 100644
--- a/scikits/learn/svm/base.py
+++ b/scikits/learn/svm/base.py
@@ -435,7 +435,8 @@ class BaseLibLinear(BaseEstimator):
                                       self._get_bias())
 
         if len(self.label_) <= 2:
-            # one class
+            # in the two-class case, the decision sign needs to be flipped
+            # due to liblinear's design
             return -dec_func
         else:
             return dec_func
@@ -451,14 +452,22 @@ class BaseLibLinear(BaseEstimator):
     @property
     def intercept_(self):
         if self.fit_intercept:
-            return self.intercept_scaling * self.raw_coef_[:, -1]
+            ret = self.intercept_scaling * self.raw_coef_[:, -1]
+            if len(self.label_) <= 2:
+                ret *= -1
+            return ret
         return 0.0
 
     @property
     def coef_(self):
         if self.fit_intercept:
-            return self.raw_coef_[:, : -1]
-        return self.raw_coef_
+            ret = self.raw_coef_[:, : -1]
+        else:
+            ret = self.raw_coef_
+        if len(self.label_) <= 2:
+            return -ret
+        else:
+            return ret
 
     def predict_proba(self, T):
         # only available for logistic regression
diff --git a/scikits/learn/svm/sparse/base.py b/scikits/learn/svm/sparse/base.py
index 36e85bebac68e3eaf5e12f7982c8392c40c8f41c..62b3d3ed18a4f8ee137de5c94f4d9f6545ecb3cb 100644
--- a/scikits/learn/svm/sparse/base.py
+++ b/scikits/learn/svm/sparse/base.py
@@ -265,7 +265,8 @@ class SparseBaseLibLinear(BaseLibLinear):
                                       self._get_bias())
 
         if len(self.label_) <= 2:
-            # one class
+            # in the two-class case, the decision sign needs to be flipped
+            # due to liblinear's design
             return -dec_func
         else:
             return dec_func
diff --git a/scikits/learn/svm/tests/test_svm.py b/scikits/learn/svm/tests/test_svm.py
index 46e0e3be0c13502bf1f3c22618c28f8330ed9000..18bbab4818d611f0e8f8c4cc38ecc8f85ec5ac78 100644
--- a/scikits/learn/svm/tests/test_svm.py
+++ b/scikits/learn/svm/tests/test_svm.py
@@ -417,7 +417,7 @@ def test_dense_liblinear_intercept_handling(classifier=svm.LinearSVC):
     clf.intercept_scaling = 100
     clf.fit(X, y)
     intercept1 = clf.intercept_
-    assert intercept1 > 1
+    assert intercept1 < -1
 
     # when intercept_scaling is sufficiently high, the intercept value
     # doesn't depend on intercept_scaling value
@@ -435,14 +435,26 @@ def test_liblinear_predict():
     returns the same as the one in libliblinear
     """
 
+    # multi-class case
     clf = svm.LinearSVC().fit(iris.data, iris.target)
-
     weights = clf.coef_.T
     bias = clf.intercept_
     H = np.dot(iris.data, weights) + bias
-
     assert_array_equal(clf.predict(iris.data), H.argmax(axis=1))
 
+    # binary-class case
+    X = [[2, 1],
+         [3, 1],
+         [1, 3],
+         [2, 3]]
+    y = [0, 0, 1, 1]
+
+    clf = svm.LinearSVC().fit(X, y)
+    weights = np.ravel(clf.coef_)
+    bias = clf.intercept_
+    H = np.dot(X, weights) + bias
+    assert_array_equal(clf.predict(X), (H > 0).astype(int))
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
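
The binary-class test added above encodes the contract that the sign flips in coef_ and intercept_ restore: because of the way liblinear orders the two classes internally, the raw weights come out with the opposite sign, so they are negated to keep coef_, intercept_, decision_function and predict mutually consistent. Below is a minimal standalone sketch of that check, using the scikits.learn namespace of this tree (on later releases the import would be from sklearn import svm).

import numpy as np
from scikits.learn import svm   # on later releases: from sklearn import svm

X = np.array([[2, 1], [3, 1], [1, 3], [2, 3]], dtype=float)
y = np.array([0, 0, 1, 1])

clf = svm.LinearSVC().fit(X, y)

# binary case: coef_ holds a single weight vector; a positive margin
# should map to class 1 once the sign convention is applied consistently
H = np.dot(X, np.ravel(clf.coef_)) + clf.intercept_
assert np.array_equal(clf.predict(X), (H > 0).astype(int))

Before this change, H computed from the public coef_ and intercept_ had the opposite sign of the internal decision values in the two-class case, which is exactly what the new assertion guards against.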