Skip to content
Snippets Groups Projects
Commit b4a6b576 authored by Fabian Pedregosa's avatar Fabian Pedregosa
Browse files

DOC: add docstrings to BallTree.

parent 0ea7d895
No related branches found
No related tags found
No related merge requests found
""" Cython bindings for the C++ BallTree code.
A Ball Tree is a data structure which can be used
to perform fast neighbor searches in data sets of
low to medium dimensionality.
"""
# Author: Thouis Jones
# License: BSD
......@@ -35,6 +39,28 @@ cdef Point *make_point(vals):
################################################################################
# Cython wrapper
cdef class BallTree:
"""
Ball Tree for fast nearest-neighbor searches :
BallTree(M, leafsize=20)
Parameters
----------
M : array-like, shape = [N,D]
N is the number of points in the data set, and
D is the dimension of the parameter space.
Note: if M is an aligned array of doubles (not
necessarily contiguous) then data will not be
copied. Otherwise, an internal copy will be made.
leafsize : positive integer (default = 20)
number of points at which to switch to brute-force. Currently not
implemented.
Notes
-----
brute-force search was removed. docs should be accordingly.
"""
cdef cBallTree *bt_ptr
cdef vector[Point_p] *ptdata
cdef size_t num_points
......@@ -65,6 +91,36 @@ cdef class BallTree:
del self.bt_ptr
def query(self, x, k=1, return_distance=True):
"""
query(x, k=1, return_distance=True)
query the Ball Tree for the k nearest neighbors
Parameters
----------
x : array-like, last dimension self.dim
An array of points to query
k : integer (default = 1)
The number of nearest neighbors to return
return_distance : boolean (default = True)
if True, return a tuple (d,i)
if False, return array i
Returns
-------
i : if return_distance == False
(d,i) : if return_distance == True
d : array of doubles - shape: x.shape[:-1] + (k,)
each entry gives the list of distances to the
neighbors of the corresponding point
(note that distances are not sorted)
i : array of integers - shape: x.shape[:-1] + (k,)
each entry gives the list of indices of
neighbors of the corresponding point
(note that neighbors are not sorted)
"""
x = np.atleast_2d(x)
assert x.shape[-1] == self.num_dims
assert k <= self.num_points
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment