diff --git a/scikits/learn/svm.py b/scikits/learn/svm.py
index 8282943708f73e9b278e43a6c85f00c5d6dd5cd5..32647428b039efb09adbb6c3285901e49645a6a5 100644
--- a/scikits/learn/svm.py
+++ b/scikits/learn/svm.py
@@ -8,14 +8,6 @@ class BaseLibsvm(object):
     support vector machine classification and regression.
 
     Should not be used directly, use derived classes instead
-
-    Parameters
-    ----------
-    X : array-like, shape = [N, D]
-        It will be converted to a floating-point array.
-    y : array, shape = [N]
-        target vector relative to X
-        It will be converted to a floating-point array.
     """
     support_ = np.empty((0,0), dtype=np.float64, order='C')
     dual_coef_ = np.empty((0,0), dtype=np.float64, order='C')
@@ -83,9 +75,8 @@ class BaseLibsvm(object):
        test vectors T.
 
        For a classification model, the predicted class for each
-       sample in T is returned.
-       For a regression model, the function value of T calculated is
-       returned.
+       sample in T is returned. For a regression model, the function
+       value of T calculated is returned.
 
        For an one-class model, +1 or -1 is returned.
 
@@ -136,6 +127,14 @@ class BaseLibsvm(object):
                  self.nSV_, self.label_, self.probA_,
                  self.probB_)
 
+
+    @property
+    def coef_(self):
+        if self._kernel_types[self.kernel] != 'linear':
+            raise NotImplementedError('coef_ is only available when using a linear kernel')
+        return np.dot(self.dual_coef_, self.support_)
+
+
 ###
 # Public API
 # No processing should go into these classes
@@ -222,12 +221,6 @@ class SVC(BaseLibsvm):
                  cache_size, eps, C, nr_weight, nu, p,
                  shrinking, probability)
 
-    @property
-    def coef_(self):
-        if self._kernel_types[self.kernel] != 'linear':
-            raise NotImplementedError('coef_ is only available when using a linear kernel')
-        return np.dot(self.dual_coef_, self.support_)
-
 class SVR(BaseLibsvm):
     """
     Support Vector Regression.
diff --git a/scikits/learn/tests/test_svm.py b/scikits/learn/tests/test_svm.py
index 976d6844249caf1ba560b9418788a781113c4657..81e93316322c6d15161e1555525f9472129fa826 100644
--- a/scikits/learn/tests/test_svm.py
+++ b/scikits/learn/tests/test_svm.py
@@ -56,16 +56,15 @@ def test_SVR():
 
     TODO: simplify this. btw, is it correct ?
     """
-    clf = svm.SVR()
+    clf = svm.SVR(kernel='linear')
     clf.fit(X, Y)
     pred = clf.predict(T)
 
-    assert_array_almost_equal(clf.dual_coef_,
-                              [[-0.01441007, -0.51530605, -0.01365979,
-                                0.51569493, 0.01387495, 0.01380604]])
-    print clf.support_
-    assert_array_almost_equal(clf.support_, X)
-    assert_array_almost_equal(pred,[ 1.10001274, 1.86682485, 1.73300377])
+    assert_array_almost_equal(clf.dual_coef_, [[-0.1, 0.1]])
+    assert_array_almost_equal(clf.coef_, [[0.2, 0.2]])
+    assert_array_almost_equal(clf.support_, [[-1, -1], [1, 1]])
+    assert_array_almost_equal(clf.intercept_, [1.5])
+    assert_array_almost_equal(pred, [1.1, 2.3, 2.5])
 
 
 def test_oneclass():
@@ -134,7 +133,9 @@ def test_error():
 
 def test_LinearSVC():
     clf = svm.LinearSVC()
     clf.fit(X, Y)
+    assert_array_equal(clf.predict(T), true_result)
+    assert_array_almost_equal(clf.intercept_, [0])
 
     # the same with l1 penalty
     clf = svm.LinearSVC(penalty='L1', dual=False)