Skip to content
Snippets Groups Projects
Commit 46a3108d authored by Alexandre Gramfort
Browse files

API rename k->n_neighbors

parent eb761614
No related branches found
No related tags found
No related merge requests found
......@@ -27,7 +27,7 @@ y[::5] += 1*(0.5 - np.random.rand(8))
from scikits.learn import neighbors
knn_barycenter = neighbors.NeighborsBarycenter(k=5)
knn_barycenter = neighbors.NeighborsBarycenter(n_neighbors=5)
y_ = knn_barycenter.fit(X, y).predict(T)
###############################################################################
......
......@@ -23,7 +23,7 @@ class Neighbors(BaseEstimator, ClassifierMixin):
labels : array
An array representing labels for the data (only arrays of
integers are supported).
k : int
n_neighbors : int
default number of neighbors.
window_size : int
Window size passed to BallTree
......@@ -33,9 +33,9 @@ class Neighbors(BaseEstimator, ClassifierMixin):
>>> samples = [[0.,0.,1.], [1.,0.,0.], [2.,2.,2.], [2.,5.,4.]]
>>> labels = [0,0,1,1]
>>> from scikits.learn.neighbors import Neighbors
>>> neigh = Neighbors(k=3)
>>> neigh = Neighbors(n_neighbors=3)
>>> neigh.fit(samples, labels)
Neighbors(k=3, window_size=1)
Neighbors(n_neighbors=3, window_size=1)
>>> print neigh.predict([[0,0,0]])
[ 0.]
......@@ -44,11 +44,11 @@ class Neighbors(BaseEstimator, ClassifierMixin):
http://en.wikipedia.org/wiki/K-nearest_neighbor_algorithm
"""
def __init__(self, k=5, window_size=1):
def __init__(self, n_neighbors=5, window_size=1):
"""Internally uses the ball tree datastructure and algorithm for fast
neighbors lookups on high dimensional datasets.
"""
self.k = k
self.n_neighbors = n_neighbors
self.window_size = window_size
def fit(self, X, Y=()):
......@@ -57,14 +57,14 @@ class Neighbors(BaseEstimator, ClassifierMixin):
self.ball_tree = BallTree(X, self.window_size)
return self
def kneighbors(self, data, k=None):
def kneighbors(self, data, n_neighbors=None):
"""Finds the K-neighbors of a point.
Parameters
----------
point : array-like
The new point.
k : int
n_neighbors : int
Number of neighbors to get (default is the value
passed to the constructor).
......@@ -85,9 +85,9 @@ class Neighbors(BaseEstimator, ClassifierMixin):
>>> samples = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]
>>> labels = [0, 0, 1]
>>> from scikits.learn.neighbors import Neighbors
>>> neigh = Neighbors(k=1)
>>> neigh = Neighbors(n_neighbors=1)
>>> neigh.fit(samples, labels)
Neighbors(k=1, window_size=1)
Neighbors(n_neighbors=1, window_size=1)
>>> print neigh.kneighbors([1., 1., 1.])
(array(0.5), array(2))
......@@ -99,18 +99,18 @@ class Neighbors(BaseEstimator, ClassifierMixin):
(array([ 0.5 , 1.11803399]), array([1, 2]))
"""
if k is None:
k = self.k
return self.ball_tree.query(data, k=k)
if n_neighbors is None:
n_neighbors = self.n_neighbors
return self.ball_tree.query(data, k=n_neighbors)
def predict(self, T, k=None):
def predict(self, T, n_neighbors=None):
"""Predict the class labels for the provided data.
Parameters
----------
test: array
A 2-D array representing the test point.
k : int
n_neighbors : int
Number of neighbors to get (default is the value
passed to the constructor).
......@@ -124,28 +124,28 @@ class Neighbors(BaseEstimator, ClassifierMixin):
>>> samples = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]
>>> labels = [0, 0, 1]
>>> from scikits.learn.neighbors import Neighbors
>>> neigh = Neighbors(k=1)
>>> neigh = Neighbors(n_neighbors=1)
>>> neigh.fit(samples, labels)
Neighbors(k=1, window_size=1)
Neighbors(n_neighbors=1, window_size=1)
>>> print neigh.predict([.2, .1, .2])
0
>>> print neigh.predict([[0., -1., 0.], [3., 2., 0.]])
[0 1]
"""
T = np.asanyarray(T)
if k is None:
k = self.k
return _predict_from_BallTree(self.ball_tree, self.Y, T, k=k)
if n_neighbors is None:
n_neighbors = self.n_neighbors
return _predict_from_BallTree(self.ball_tree, self.Y, T, n_neighbors)
def _predict_from_BallTree(ball_tree, Y, test, k):
def _predict_from_BallTree(ball_tree, Y, test, n_neighbors):
"""Predict target from BallTree object containing the data points.
This is a helper method, not meant to be used directly. It will
not check that input is of the correct type.
"""
Y_ = Y[ball_tree.query(test, k=k, return_distance=False)]
if k == 1:
Y_ = Y[ball_tree.query(test, k=n_neighbors, return_distance=False)]
if n_neighbors == 1:
return Y_
return (stats.mode(Y_, axis=1)[0]).ravel()
......@@ -167,7 +167,7 @@ class NeighborsBarycenter(BaseEstimator, RegressorMixin):
y : array
An array representing labels for the data (only arrays of
integers are supported).
k : int
n_neighbors : int
default number of neighbors.
window_size : int
Window size passed to BallTree
......@@ -177,9 +177,9 @@ class NeighborsBarycenter(BaseEstimator, RegressorMixin):
>>> X = [[0], [1], [2], [3]]
>>> y = [0, 0, 1, 1]
>>> from scikits.learn.neighbors import NeighborsBarycenter
>>> neigh = NeighborsBarycenter(k=2)
>>> neigh = NeighborsBarycenter(n_neighbors=2)
>>> neigh.fit(X, y)
NeighborsBarycenter(k=2, window_size=1)
NeighborsBarycenter(n_neighbors=2, window_size=1)
>>> print neigh.predict([[1.5]])
[ 0.5]
......@@ -188,11 +188,11 @@ class NeighborsBarycenter(BaseEstimator, RegressorMixin):
http://en.wikipedia.org/wiki/K-nearest_neighbor_algorithm
"""
def __init__(self, k=5, window_size=1):
def __init__(self, n_neighbors=5, window_size=1):
"""Internally uses the ball tree datastructure and algorithm for fast
neighbors lookups on high dimensional datasets.
"""
self.k = k
self.n_neighbors = n_neighbors
self.window_size = window_size
def fit(self, X, y, copy=True):
......@@ -200,14 +200,14 @@ class NeighborsBarycenter(BaseEstimator, RegressorMixin):
self.ball_tree = BallTree(X, self.window_size)
return self
def predict(self, T, k=None):
def predict(self, T, n_neighbors=None):
"""Predict the target for the provided data.
Parameters
----------
T : array
A 2-D array representing the test data.
k : int
n_neighbors : int
Number of neighbors to get (default is the value
passed to the constructor).
......@@ -221,18 +221,18 @@ class NeighborsBarycenter(BaseEstimator, RegressorMixin):
>>> X = [[0], [1], [2]]
>>> y = [0, 0, 1]
>>> from scikits.learn.neighbors import NeighborsBarycenter
>>> neigh = NeighborsBarycenter(k=2)
>>> neigh = NeighborsBarycenter(n_neighbors=2)
>>> neigh.fit(X, y)
NeighborsBarycenter(k=2, window_size=1)
NeighborsBarycenter(n_neighbors=2, window_size=1)
>>> print neigh.predict([[.5], [1.5]])
[ 0. 0.5]
"""
T = np.asanyarray(T)
if T.ndim == 1:
T = T[:,None]
if k is None:
k = self.k
A = kneighbors_graph(T, k=k, weight="barycenter",
if n_neighbors is None:
n_neighbors = self.n_neighbors
A = kneighbors_graph(T, n_neighbors=n_neighbors, weight="barycenter",
ball_tree=self.ball_tree).tocsr()
return A * self._y
......@@ -286,7 +286,8 @@ def barycenter_weights(x, X_neighbors, tol=1e-3):
return w
def kneighbors_graph(X, k, weight=None, ball_tree=None, window_size=1):
def kneighbors_graph(X, n_neighbors, weight=None, ball_tree=None,
window_size=1):
"""Computes the (weighted) graph of k-Neighbors
Parameters
......@@ -294,7 +295,7 @@ def kneighbors_graph(X, k, weight=None, ball_tree=None, window_size=1):
X : array-like, shape = [n_samples, n_features]
Coordinates of samples. One sample per row.
k : int
n_neighbors : int
Number of neighbors for each sample.
weight : None (default)
......@@ -319,7 +320,7 @@ def kneighbors_graph(X, k, weight=None, ball_tree=None, window_size=1):
Examples
--------
>>> X = [[0], [2], [1]]
>>> A = kneighbors_graph(X, k=2, weight=None)
>>> A = kneighbors_graph(X, n_neighbors=2, weight=None)
>>> print A.todense()
[[ 1. 0. 1.]
[ 0. 1. 1.]
......@@ -331,16 +332,16 @@ def kneighbors_graph(X, k, weight=None, ball_tree=None, window_size=1):
if ball_tree is None:
ball_tree = BallTree(X, window_size)
A = sparse.lil_matrix((n_samples, ball_tree.size))
dist, ind = ball_tree.query(X, k=k)
dist, ind = ball_tree.query(X, k=n_neighbors)
if weight is None:
for i, li in enumerate(ind):
if k > 1:
A[i, list(li)] = np.ones(k)
if n_neighbors > 1:
A[i, list(li)] = np.ones(n_neighbors)
else:
A[i, li] = 1.0
elif weight is "distance":
for i, li in enumerate(ind):
if k > 1:
if n_neighbors > 1:
A[i, list(li)] = dist[i, :]
else:
A[i, li] = dist[i, 0]
......@@ -348,7 +349,7 @@ def kneighbors_graph(X, k, weight=None, ball_tree=None, window_size=1):
# XXX : the next loop could be done in parallel
# by parallelizing groups of indices
for i, li in enumerate(ind):
if k > 1:
if n_neighbors > 1:
X_i = ball_tree.data[li]
A[i, list(li)] = barycenter_weights(X[i], X_i)
else:
......
......@@ -17,16 +17,16 @@ def test_neighbors_1D():
X = [[x] for x in range(0, n)]
Y = [0]*n_2 + [1]*n_2
# k = 1
knn = neighbors.Neighbors(k=1)
# n_neighbors = 1
knn = neighbors.Neighbors(n_neighbors=1)
knn.fit(X, Y)
test = [[i + 0.01] for i in range(0, n_2)] + \
[[i - 0.01] for i in range(n_2, n)]
assert_array_equal(knn.predict(test), [0, 0, 0, 1, 1, 1])
# same as before, but using predict() instead of Neighbors object
# k = 3
knn = neighbors.Neighbors(k=3)
# n_neighbors = 3
knn = neighbors.Neighbors(n_neighbors=3)
knn.fit(X, Y)
assert_array_equal(knn.predict([[i +0.01] for i in range(0, n_2)]),
[0 for i in range(n_2)])
......@@ -59,7 +59,7 @@ def test_neighbors_barycenter():
"""
X = [[0], [1], [2], [3]]
y = [0, 0, 1, 1]
neigh = neighbors.NeighborsBarycenter(k=2)
neigh = neighbors.NeighborsBarycenter(n_neighbors=2)
neigh.fit(X, y)
assert_equal(neigh.predict([[1.5]]), 0.5)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment