From 56bf5824ee5a05bc4abab11a63d49ade2d3663e2 Mon Sep 17 00:00:00 2001 From: Fabian Pedregosa <fabian.pedregosa@inria.fr> Date: Tue, 8 Feb 2011 21:35:12 +0100 Subject: [PATCH] Change the algorithm used in neighbors.barycenter. Use the more stable SVD via a constrained least squares. This avoids the use of regularization, so parameter eps was removed. Details are given in doc/developers/ --- doc/developers/neighbors.rst | 69 ++++++++++++++++++++++++ doc/index.rst | 1 + scikits/learn/neighbors.py | 75 ++++++++++++--------------- scikits/learn/tests/test_neighbors.py | 35 ++++++------- 4 files changed, 118 insertions(+), 62 deletions(-) create mode 100644 doc/developers/neighbors.rst diff --git a/doc/developers/neighbors.rst b/doc/developers/neighbors.rst new file mode 100644 index 0000000000..1938f05931 --- /dev/null +++ b/doc/developers/neighbors.rst @@ -0,0 +1,69 @@ + +.. _notes_neighbors: + + +.. currentmodule:: scikits.learn.neighbors + +===================================== +scikits.learn.neighbors working notes +===================================== + +barycenter +========== + +Function :func:`barycenter` tries to find appropriate weights to +reconstruct the point x from a subset (y1, y2, ..., yn), where weights +sum to one. + +This is just a simple case of Equality Constrained Least Squares +[#f1]_ with constraint dot(np.ones(n), x) = 1. In particular, the Q +matrix from the QR decomposition of B is the Householder reflection of +np.ones(n). + + +Purpose +------- + +This method was added to ease some computations in the future manifold +module, namely in LLE. However, it is still to be shown that it is +useful and efficient in that context. + + +Performance +----------- + +The algorithm has to iterate over n_samples, which is the main +bottleneck. It would be great to vectorize this loop. 
+ +Also, least squares solution could be computed more efficiently by a +QR factorization, since probably we don't care about a minimum norm +solution for the underdetermined case. + +The paper 'An introduction to Locally Linear Embeddings', Saul & +Roweis solves the problem by the normal equation method over the +covariance matrix. This has the disadvantage that it does not degrade +gracefully when the covariance is singular, requiring to explicitly +add regularization. + + +Stability +--------- + +Should be good as it uses SVD to solve the LS problem. TODO: explicit +bounds. + + +API +--- + +The API is convenient to use from NeighborsBarycenter and +kneighbors_graph, but might not be very easy to use directly due to +the fact that Z must be a 3-D array. + +It should be checked that it is usable in other contexts. + + +.. rubric:: Footnotes + +.. [#f1] Section 12.1.4 ('Equality Constrained Least Squares'), + 'Matrix Computations' by Golub & Van Loan diff --git a/doc/index.rst b/doc/index.rst index a4bc49fd61..edd183d483 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -93,5 +93,6 @@ Development :maxdepth: 2 developers/index + developers/neighbors performance about diff --git a/scikits/learn/neighbors.py b/scikits/learn/neighbors.py index 565db36c3d..c92c4d8767 100644 --- a/scikits/learn/neighbors.py +++ b/scikits/learn/neighbors.py @@ -213,7 +213,7 @@ class NeighborsBarycenter(Neighbors, RegressorMixin): http://en.wikipedia.org/wiki/K-nearest_neighbor_algorithm """ - def predict(self, X, eps=1e-6, **params): + def predict(self, X, **params): """Predict the target for the provided data. Parameters @@ -225,9 +225,6 @@ class NeighborsBarycenter(Neighbors, RegressorMixin): Number of neighbors to get (default is the value passed to the constructor). - eps : float, optional - Amount of regularization to add to the Gram matrix. 
- Returns ------- y: array @@ -253,7 +250,7 @@ class NeighborsBarycenter(Neighbors, RegressorMixin): neigh = self.ball_tree.data[neigh_ind] # compute barycenters at each point - B = barycenter(X, neigh, eps=eps) + B = barycenter(X, neigh) labels = self._y[neigh_ind] return (B * labels).sum(axis=1) @@ -262,8 +259,8 @@ class NeighborsBarycenter(Neighbors, RegressorMixin): ############################################################################### # Utils k-NN based Functions -def barycenter(X, Y, eps=1e-6): - """ +def barycenter(X, Z, cond=None): + """ Compute barycenter weights of X from Y along the first axis. We estimate the weights to assign to each point in Y[i] to recover @@ -273,56 +270,48 @@ def barycenter(X, Y, eps=1e-6): ---------- X : array-like, shape (n_samples, n_dim) - Y : array-like, shape (n_samples, n_neighbors, n_dim) + Z : array-like, shape (n_samples, n_neighbors, n_dim) - eps: float, optional - Amount of regularization to add to the diagonal of the Gram - matrix. + cond: float, optional + Cutoff for small singular values; used to determine effective + rank of Z[i]. Singular values smaller than ``cond * + largest_singular_value`` are considered zero. Returns ------- B : array-like, shape (n_samples, n_neighbors) - Reference - --------- - 'An introduction to Locally Linear Embeddings', Saul & Roweis """ +# +# .. local variables .. 
+# from scipy import linalg - - X = np.atleast_2d(X) + X, Z = map(np.asanyarray, (X, Z)) + n_samples, n_neighbors = X.shape[0], Z.shape[1] if X.dtype.kind == 'i': - # integer arrays truncate regularization X = X.astype(np.float) - - Y = np.asanyarray(Y) - n_samples, n_neighbors = X.shape[0], Y.shape[1] - B = np.empty((n_samples, n_neighbors), dtype=X.dtype) + v = np.ones(n_neighbors, dtype=X.dtype) + rank_update, = linalg.get_blas_funcs(('ger',), (X,)) - Z = X[:, np.newaxis] - Y - - # Compute the weights that best reconstruct each data point - # from its neighbors, minimizing the cost function: - # - # phi(X) = Sum_i|X_i - Sum_j{W_ij X_j}| - # - - for i, D in enumerate(Z): - - # Compute Gram matrix - Gram = np.dot(D, D.T) - - # Add regularization - Gram.flat[::n_neighbors + 1] += eps * np.trace(Gram) - - # Solve the system - B[i] = linalg.solve(Gram, np.ones(Gram.shape[0]), sym_pos=True) - B[i] /= np.sum(B[i]) - +# +# .. constrained least squares .. +# + v[0] -= np.sqrt(n_neighbors) + B[:, 0] = 1. / np.sqrt(n_neighbors) + if n_neighbors <= 1: + return B + alpha = - 1. 
/ (n_neighbors - np.sqrt(n_neighbors)) + for i, A in enumerate(Z.transpose(0, 2, 1)): + C = rank_update(alpha, np.dot(A, v), v, a=A) + B[i, 1:] = linalg.lstsq( + C[:, 1:], X[i] - C[:, 0] / np.sqrt(n_neighbors), cond=cond, + overwrite_a=True, overwrite_b=True)[0] + B[i] = rank_update(alpha, v, np.dot(v.T, B[i]), a=B[i]) return B -def kneighbors_graph(X, n_neighbors, mode='connectivity', eps=1e-6): +def kneighbors_graph(X, n_neighbors, mode='connectivity'): """Computes the (weighted) graph of k-Neighbors for points in X Parameters @@ -379,7 +368,7 @@ def kneighbors_graph(X, n_neighbors, mode='connectivity', eps=1e-6): ind = ball_tree.query( X, k=n_neighbors + 1, return_distance=False) A_ind = ind[:, 1:] - A_data = barycenter(X, X[A_ind], eps=eps) + A_data = barycenter(X, X[A_ind]) else: raise ValueError("Unsupported mode type") diff --git a/scikits/learn/tests/test_neighbors.py b/scikits/learn/tests/test_neighbors.py index 8d4e1c40e9..e13b008ecf 100644 --- a/scikits/learn/tests/test_neighbors.py +++ b/scikits/learn/tests/test_neighbors.py @@ -65,59 +65,56 @@ def test_neighbors_barycenter(): y = [0, 0, 1, 1] neigh = neighbors.NeighborsBarycenter(n_neighbors=2) neigh.fit(X, y) - assert_equal(neigh.predict([[1.5]]), 0.5) + assert_array_almost_equal(neigh.predict([[1.5]]), [0.5]) def test_kneighbors_graph(): """ Test kneighbors_graph to build the k-Nearest Neighbor graph. """ - X = [[0, 0], [1.01, 0], [2, 0]] + X = [[0, 1], [1.01, 1.], [2, 0]] # n_neighbors = 1 A = neighbors.kneighbors_graph(X, 1, mode='connectivity') assert_array_equal(A.todense(), np.eye(A.shape[0])) A = neighbors.kneighbors_graph(X, 1, mode='distance') - assert_array_equal( + assert_array_almost_equal( A.todense(), - [[ 0. , 1.01, 0. ], - [ 0. , 0. , 0.99], - [ 0. , 0.99, 0. ]]) + [[ 0. , 1.01 , 0. ], + [ 1.01 , 0. , 0. ], + [ 0. , 1.40716026, 0. 
]]) A = neighbors.kneighbors_graph(X, 1, mode='barycenter') - assert_array_equal( + assert_array_almost_equal( A.todense(), [[ 0., 1., 0.], - [ 0., 0., 1.], - [ 0., 1., 0.]]) + [ 1., 0., 0.], + [ 0., 1., 0.]]) # n_neigbors = 2 A = neighbors.kneighbors_graph(X, 2, mode='connectivity') assert_array_equal( A.todense(), [[ 1., 1., 0.], - [ 0., 1., 1.], + [ 1., 1., 0.], [ 0., 1., 1.]]) A = neighbors.kneighbors_graph(X, 2, mode='distance') assert_array_almost_equal( A.todense(), - [[ 0. , 1.01, 2. ], - [ 1.01, 0. , 0.99], - [ 2. , 0.99, 0. ]]) + [[ 0. , 1.01 , 2.23606798], + [ 1.01 , 0. , 1.40716026], + [ 2.23606798, 1.40716026, 0. ]]) A = neighbors.kneighbors_graph(X, 2, mode='barycenter') # check that columns sum to one assert_array_almost_equal(np.sum(A.todense(), 1), np.ones((3, 1))) assert_array_almost_equal( A.todense(), - [[ 0. , 2.02018645, -1.02018645], - [ 0.49500001, 0. , 0.50499999], - [-0.98018357, 1.98018357, 0. ]]) - # check that we can reconstruct X from A - assert_array_almost_equal( - X, A.dot(X), decimal=3) + [[ 0. , 1.5049745 , -0.5049745 ], + [ 0.596 , 0. , 0.404 ], + [-0.98019802, 1.98019802, 0. ]]) # n_neighbors = 3 A = neighbors.kneighbors_graph(X, 3, mode='connectivity') -- GitLab