From ebcfcafcea873ef5067450bf548cd683ca9ab770 Mon Sep 17 00:00:00 2001 From: Vathsala Achar <vathsala.sachar@gmail.com> Date: Thu, 6 Apr 2017 23:42:34 +0100 Subject: [PATCH] [MRG+1] OneClassSVM predict now returns int (#8711) * Added predict method to class OneClassSVM - This method overrides the default behaviour to only return integer class values. * Test for the change in OneClassSVM predict method - small addition to test that the predicted result is integer * Changed test to check for intp datatype * Updated whats new --- doc/whats_new.rst | 3 +++ sklearn/svm/classes.py | 21 +++++++++++++++++++++ sklearn/svm/tests/test_svm.py | 3 ++- 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 6978d3943b..ce72f193ed 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -262,6 +262,9 @@ Bug fixes - Fixed a bug in :class:`manifold.TSNE` where it stored the incorrect ``kl_divergence_``. :issue:`6507` by :user:`Sebastian Saeger <ssaeger>`. + - Fixed a bug in :class:`svm.OneClassSVM` where it returned floats instead of + integer classes. :issue:`8676` by :user:`Vathsala Achar <VathsalaAchar>`. + API changes summary ------------------- diff --git a/sklearn/svm/classes.py b/sklearn/svm/classes.py index 2de3029cb2..8b7d2f42bd 100644 --- a/sklearn/svm/classes.py +++ b/sklearn/svm/classes.py @@ -1064,3 +1064,24 @@ class OneClassSVM(BaseLibSVM): """ dec = self._decision_function(X) return dec + + def predict(self, X): + """ + Perform classification on samples in X. + + For an one-class model, +1 or -1 is returned. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape (n_samples, n_features) + For kernel="precomputed", the expected shape of X is + [n_samples_test, n_samples_train] + + Returns + ------- + y_pred : array, shape (n_samples,) + Class labels for samples in X. + """ + y = super(OneClassSVM, self).predict(X) + return np.asarray(y, dtype=np.intp) + diff --git a/sklearn/svm/tests/test_svm.py b/sklearn/svm/tests/test_svm.py index 0f85be117a..daf35f82a3 100644 --- a/sklearn/svm/tests/test_svm.py +++ b/sklearn/svm/tests/test_svm.py @@ -241,7 +241,8 @@ def test_oneclass(): clf.fit(X) pred = clf.predict(T) - assert_array_almost_equal(pred, [-1, -1, -1]) + assert_array_equal(pred, [-1, -1, -1]) + assert_equal(pred.dtype, np.dtype('intp')) assert_array_almost_equal(clf.intercept_, [-1.008], decimal=3) assert_array_almost_equal(clf.dual_coef_, [[0.632, 0.233, 0.633, 0.234, 0.632, 0.633]], -- GitLab