From 538917cae1d2994cf457473d28d2dd8da514f301 Mon Sep 17 00:00:00 2001
From: Fabian Pedregosa <fabian.pedregosa@inria.fr>
Date: Wed, 17 Mar 2010 13:25:37 +0000
Subject: [PATCH] Fix bug when target vector dtype != int.

The idea of using np.bincount for estimating the mode was a good one,
but infortunately it will only work with ints. Oh, well, let's just
use scipy.stats.mode

From: Fabian Pedregosa <fabian.pedregosa@inria.fr>

git-svn-id: https://scikit-learn.svn.sourceforge.net/svnroot/scikit-learn/trunk@553 22fbfee3-77ab-4535-9bad-27d1bd3bc7d8
---
 scikits/learn/neighbors.py | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/scikits/learn/neighbors.py b/scikits/learn/neighbors.py
index 2f07525861..c09a80a5a8 100644
--- a/scikits/learn/neighbors.py
+++ b/scikits/learn/neighbors.py
@@ -4,10 +4,9 @@ k-Nearest Neighbor Algorithm.
 Uses BallTree algorithm, which is an efficient way to perform fast
 neighbor searches in high dimensionality.
 """
-
-from scipy.stats import mode
-from BallTree import BallTree
 import numpy as np
+from scipy import stats
+from BallTree import BallTree
 
 class Neighbors:
   """
@@ -97,7 +96,7 @@ class Neighbors:
     return self.ball_tree.query(data, k=k)
 
 
-  def predict(self, test, k=None):
+  def predict(self, T, k=None):
     """
     Predict the class labels for the provided data.
 
@@ -127,8 +126,9 @@ class Neighbors:
     >>> print neigh.predict([[0., -1., 0.], [3., 2., 0.]])
     [0 1]
     """
+    T = np.asanyarray(T)
     if k is None: k = self._k
-    return _predict_from_BallTree(self.ball_tree, self.y, np.asarray(test), k=k)
+    return _predict_from_BallTree(self.ball_tree, self.y, T, k=k)
 
 
 def _predict_from_BallTree(ball_tree, Y, test, k):
@@ -138,14 +138,11 @@ def _predict_from_BallTree(ball_tree, Y, test, k):
     This is a helper method, not meant to be used directly. It will
     not check that input is of the correct type.
     """
-    Y_hat = Y[ball_tree.query(test, k=k, return_distance=False)]
+    Y_ = Y[ball_tree.query(test, k=k, return_distance=False)]
     if k == 1: return Y_hat
     # search most common values along axis 1 of labels
     # much faster than scipy.stats.mode
-    return np.apply_along_axis(
-        lambda x: np.bincount(x).argmax(),
-        axis=1,
-        arr=Y_hat)
+    return stats.mode(Y_, axis=1)[0]
 
 
 def predict(X, Y, test, k=5):
-- 
GitLab