From 538917cae1d2994cf457473d28d2dd8da514f301 Mon Sep 17 00:00:00 2001 From: Fabian Pedregosa <fabian.pedregosa@inria.fr> Date: Wed, 17 Mar 2010 13:25:37 +0000 Subject: [PATCH] Fix bug when target vector dtype != int. The idea of using np.bincount for estimating the mode was a good one, but unfortunately it will only work with ints. Oh, well, let's just use scipy.stats.mode From: Fabian Pedregosa <fabian.pedregosa@inria.fr> git-svn-id: https://scikit-learn.svn.sourceforge.net/svnroot/scikit-learn/trunk@553 22fbfee3-77ab-4535-9bad-27d1bd3bc7d8 --- scikits/learn/neighbors.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/scikits/learn/neighbors.py b/scikits/learn/neighbors.py index 2f07525861..c09a80a5a8 100644 --- a/scikits/learn/neighbors.py +++ b/scikits/learn/neighbors.py @@ -4,10 +4,9 @@ k-Nearest Neighbor Algorithm. Uses BallTree algorithm, which is an efficient way to perform fast neighbor searches in high dimensionality. """ - -from scipy.stats import mode -from BallTree import BallTree import numpy as np +from scipy import stats +from BallTree import BallTree class Neighbors: """ @@ -97,7 +96,7 @@ class Neighbors: return self.ball_tree.query(data, k=k) - def predict(self, test, k=None): + def predict(self, T, k=None): """ Predict the class labels for the provided data. @@ -127,8 +126,9 @@ class Neighbors: >>> print neigh.predict([[0., -1., 0.], [3., 2., 0.]]) [0 1] """ + T = np.asanyarray(T) if k is None: k = self._k - return _predict_from_BallTree(self.ball_tree, self.y, np.asarray(test), k=k) + return _predict_from_BallTree(self.ball_tree, self.y, T, k=k) def _predict_from_BallTree(ball_tree, Y, test, k): @@ -138,14 +138,11 @@ def _predict_from_BallTree(ball_tree, Y, test, k): This is a helper method, not meant to be used directly. It will not check that input is of the correct type. 
""" - Y_hat = Y[ball_tree.query(test, k=k, return_distance=False)] + Y_ = Y[ball_tree.query(test, k=k, return_distance=False)] if k == 1: return Y_hat # search most common values along axis 1 of labels # much faster than scipy.stats.mode - return np.apply_along_axis( - lambda x: np.bincount(x).argmax(), - axis=1, - arr=Y_hat) + return stats.mode(Y_, axis=1)[0] def predict(X, Y, test, k=5): -- GitLab