Skip to content
Snippets Groups Projects
Commit 538917ca authored by Fabian Pedregosa
Browse files

Fix bug when target vector dtype != int.

The idea of using np.bincount for estimating the mode was a good one,
but unfortunately it will only work with ints. Oh, well, let's just
use scipy.stats.mode.

From: Fabian Pedregosa <fabian.pedregosa@inria.fr>

git-svn-id: https://scikit-learn.svn.sourceforge.net/svnroot/scikit-learn/trunk@553 22fbfee3-77ab-4535-9bad-27d1bd3bc7d8
parent d973250d
No related branches found
No related tags found
No related merge requests found
......@@ -4,10 +4,9 @@ k-Nearest Neighbor Algorithm.
Uses BallTree algorithm, which is an efficient way to perform fast
neighbor searches in high dimensionality.
"""
from scipy.stats import mode
from BallTree import BallTree
import numpy as np
from scipy import stats
from BallTree import BallTree
class Neighbors:
"""
......@@ -97,7 +96,7 @@ class Neighbors:
return self.ball_tree.query(data, k=k)
def predict(self, test, k=None):
def predict(self, T, k=None):
"""
Predict the class labels for the provided data.
......@@ -127,8 +126,9 @@ class Neighbors:
>>> print neigh.predict([[0., -1., 0.], [3., 2., 0.]])
[0 1]
"""
T = np.asanyarray(T)
if k is None: k = self._k
return _predict_from_BallTree(self.ball_tree, self.y, np.asarray(test), k=k)
return _predict_from_BallTree(self.ball_tree, self.y, T, k=k)
def _predict_from_BallTree(ball_tree, Y, test, k):
......@@ -138,14 +138,11 @@ def _predict_from_BallTree(ball_tree, Y, test, k):
This is a helper method, not meant to be used directly. It will
not check that input is of the correct type.
"""
Y_hat = Y[ball_tree.query(test, k=k, return_distance=False)]
Y_ = Y[ball_tree.query(test, k=k, return_distance=False)]
if k == 1: return Y_hat
# search most common values along axis 1 of labels
# much faster than scipy.stats.mode
return np.apply_along_axis(
lambda x: np.bincount(x).argmax(),
axis=1,
arr=Y_hat)
return stats.mode(Y_, axis=1)[0]
def predict(X, Y, test, k=5):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment