diff --git a/scikits/learn/src/ball_tree.pyx b/scikits/learn/src/ball_tree.pyx index d779017f087f16adb3300922e8dd34815130818b..0347cbf0f2287bb9195a2aa213f53389fc7cee5f 100644 --- a/scikits/learn/src/ball_tree.pyx +++ b/scikits/learn/src/ball_tree.pyx @@ -1,4 +1,8 @@ """ Cython bindings for the C++ BallTree code. + +A Ball Tree is a data structure which can be used +to perform fast neighbor searches in data sets of +low to medium dimensionality. """ # Author: Thouis Jones # License: BSD @@ -30,11 +34,33 @@ cdef Point *make_point(vals): for idx, v in enumerate(vals.flat): SET(pt, idx, v) return pt - + ################################################################################ # Cython wrapper cdef class BallTree: + """ + Ball Tree for fast nearest-neighbor searches : + + BallTree(M, leafsize=20) + + Parameters + ---------- + M : array-like, shape = [N,D] + N is the number of points in the data set, and + D is the dimension of the parameter space. + Note: if M is an aligned array of doubles (not + necessarily contiguous) then data will not be + copied. Otherwise, an internal copy will be made. + + leafsize : positive integer (default = 20) + number of points at which to switch to brute-force. Currently not + implemented. + + Notes + ----- + brute-force search was removed. docs should be accordingly. + """ cdef cBallTree *bt_ptr cdef vector[Point_p] *ptdata cdef size_t num_points @@ -65,6 +91,36 @@ cdef class BallTree: del self.bt_ptr def query(self, x, k=1, return_distance=True): + """ + query(x, k=1, return_distance=True) + + query the Ball Tree for the k nearest neighbors + + Parameters + ---------- + x : array-like, last dimension self.dim + An array of points to query + k : integer (default = 1) + The number of nearest neighbors to return + return_distance : boolean (default = True) + if True, return a tuple (d,i) + if False, return array i + + Returns + ------- + i : if return_distance == False + (d,i) : if return_distance == True + + d : array of doubles - shape: x.shape[:-1] + (k,) + each entry gives the list of distances to the + neighbors of the corresponding point + (note that distances are not sorted) + + i : array of integers - shape: x.shape[:-1] + (k,) + each entry gives the list of indices of + neighbors of the corresponding point + (note that neighbors are not sorted) + """ x = np.atleast_2d(x) assert x.shape[-1] == self.num_dims assert k <= self.num_points