Skip to content
Snippets Groups Projects
Commit 01e93e74 authored by Alexandre Gramfort's avatar Alexandre Gramfort
Browse files

ENH : adding kneighbors_graph to build the graph of neighbors as a sparse matrix

MISC : pep8 in neighbors module
parent c9abe509
Branches
No related tags found
No related merge requests found
......@@ -10,9 +10,9 @@ from scipy import stats
from .base import BaseEstimator, ClassifierMixin
from .ball_tree import BallTree
class Neighbors(BaseEstimator, ClassifierMixin):
"""
Classifier implementing k-Nearest Neighbor Algorithm.
"""Classifier implementing k-Nearest Neighbor Algorithm.
Parameters
----------
......@@ -40,8 +40,7 @@ class Neighbors(BaseEstimator, ClassifierMixin):
"""
def __init__(self, k=5, window_size=1):
"""
Internally uses the ball tree datastructure and algorithm for fast
"""Internally uses the ball tree datastructure and algorithm for fast
neighbors lookups on high dimensional datasets.
"""
self.k = k
......@@ -54,8 +53,7 @@ class Neighbors(BaseEstimator, ClassifierMixin):
return self
def kneighbors(self, data, k=None):
"""
Finds the K-neighbors of a point.
"""Finds the K-neighbors of a point.
Parameters
----------
......@@ -68,7 +66,7 @@ class Neighbors(BaseEstimator, ClassifierMixin):
Returns
-------
dist : array
Array representing the lenghts to point.
Array representing the lengths to point.
ind : array
Array representing the indices of the nearest points in the
population matrix.
......@@ -100,10 +98,8 @@ class Neighbors(BaseEstimator, ClassifierMixin):
k = self.k
return self.ball_tree.query(data, k=k)
def predict(self, T, k=None):
"""
Predict the class labels for the provided data.
"""Predict the class labels for the provided data.
Parameters
----------
......@@ -138,12 +134,59 @@ class Neighbors(BaseEstimator, ClassifierMixin):
def _predict_from_BallTree(ball_tree, Y, test, k):
"""
Predict target from BallTree object containing the data points.
"""Predict target from BallTree object containing the data points.
This is a helper method, not meant to be used directly. It will
not check that input is of the correct type.
"""
Y_ = Y[ball_tree.query(test, k=k, return_distance=False)]
if k == 1: return Y_
if k == 1:
return Y_
return (stats.mode(Y_, axis=1)[0]).ravel()
def kneighbors_graph(X, k, with_dist=True):
    """Computes the graph of k-Neighbors

    Parameters
    ----------
    X : array-like, shape = [n_samples, n_features]
        Coordinates of samples. One sample per row.

    k : int
        Number of neighbors queried for each sample. The first match of
        every query is skipped (it is expected to be the sample itself),
        so each row of the graph holds at most k - 1 entries; k=1
        therefore yields an empty graph.

    with_dist : bool, optional (default=True)
        If True, A[i, j] holds the distance between sample i and its
        neighbor j. If False, A[i, j] is set to 1.0 for every neighbor.

    Returns
    -------
    A : sparse matrix, shape = [n_samples, n_samples]
        A is returned as a linked list (LIL) sparse matrix.
        A[i, j] is set (1.0, or the distance when with_dist is True)
        if sample j is a neighbor of sample i.

    Examples
    --------
    >>> X = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]
    >>> A = kneighbors_graph(X, k=2, with_dist=False)
    >>> print(A)
    (0, 1) 1.0
    (1, 0) 1.0
    (2, 1) 1.0
    """
    from scipy import sparse

    X = np.asanyarray(X)
    n_samples = X.shape[0]
    A = sparse.lil_matrix((n_samples, n_samples))

    knn = Neighbors(k=k)
    dist, ind = knn.fit(X).kneighbors(X)

    # NOTE(review): the query may hand back 1-D results when k == 1 --
    # reshaping to one row per sample keeps the slicing below uniform.
    ind = np.asarray(ind).reshape(n_samples, -1)
    dist = np.asarray(dist).reshape(n_samples, -1)

    # Column 0 of every row is skipped: it is assumed to be the query
    # sample itself. Slicing (instead of indexing column 1 directly)
    # also covers k == 1, where no other neighbor is available.
    for i, (row_ind, row_dist) in enumerate(zip(ind, dist)):
        for j, d in zip(row_ind[1:], row_dist[1:]):
            A[i, int(j)] = d if with_dist else 1.0

    return A
from numpy.testing import assert_array_equal
from numpy.testing import assert_array_equal, assert_array_almost_equal
from .. import neighbors
def test_neighbors_1D():
"""
Nearest Neighbors in a line.
......@@ -18,7 +19,8 @@ def test_neighbors_1D():
# k = 1
knn = neighbors.Neighbors(k=1)
knn.fit(X, Y)
test = [[i +0.01] for i in range(0,n_2)] + [[i-0.01] for i in range(n_2, n)]
test = [[i + 0.01] for i in range(0, n_2)] + \
[[i - 0.01] for i in range(n_2, n)]
assert_array_equal(knn.predict(test), [0, 0, 0, 1, 1, 1])
# same as before, but using predict() instead of Neighbors object
......@@ -49,3 +51,19 @@ def test_neighbors_2D():
prediction = knn.predict([[0, .1], [0, -.1], [.1, 0], [-.1, 0]])
assert_array_equal(prediction, [0, 1, 0, 1])
def test_kneighbors_graph():
    """
    Test kneighbors_graph to build the k-Nearest Neighbor graph.
    """
    X = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]

    # k=2: every sample is linked to its single closest other sample.
    A = neighbors.kneighbors_graph(X, 2, with_dist=False)
    assert_array_equal(A.todense(),
                       [[0, 1, 0], [1, 0, 0], [0, 1, 0]])

    A = neighbors.kneighbors_graph(X, 2)
    assert_array_almost_equal(A.todense(),
                              [[0, 0.5, 0], [0.5, 0, 0], [0, 1.2247, 0]], 4)

    # Corner case k == n_samples: every sample is linked to all the
    # others. Check the resulting matrices instead of only verifying
    # that the calls do not raise.
    A = neighbors.kneighbors_graph(X, 3, with_dist=False)
    assert_array_equal(A.todense(),
                       [[0, 1, 1], [1, 0, 1], [1, 1, 0]])

    A = neighbors.kneighbors_graph(X, 3)
    assert_array_almost_equal(A.todense(),
                              [[0, 0.5, 1.5],
                               [0.5, 0, 1.2247],
                               [1.5, 1.2247, 0]], 4)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment