From 56bf5824ee5a05bc4abab11a63d49ade2d3663e2 Mon Sep 17 00:00:00 2001
From: Fabian Pedregosa <fabian.pedregosa@inria.fr>
Date: Tue, 8 Feb 2011 21:35:12 +0100
Subject: [PATCH] Change the algorithm used in neighbors.barycenter.

Solve a constrained least squares problem via the more stable SVD.
This avoids the need for regularization, so the eps parameter was
removed. Details are given in doc/developers/neighbors.rst.
---
 doc/developers/neighbors.rst          | 69 ++++++++++++++++++++++++
 doc/index.rst                         |  1 +
 scikits/learn/neighbors.py            | 75 ++++++++++++---------------
 scikits/learn/tests/test_neighbors.py | 35 ++++++-------
 4 files changed, 118 insertions(+), 62 deletions(-)
 create mode 100644 doc/developers/neighbors.rst

diff --git a/doc/developers/neighbors.rst b/doc/developers/neighbors.rst
new file mode 100644
index 0000000000..1938f05931
--- /dev/null
+++ b/doc/developers/neighbors.rst
@@ -0,0 +1,69 @@
+
+.. _notes_neighbors:
+
+
+.. currentmodule:: scikits.learn.neighbors
+
+=====================================
+scikits.learn.neighbors working notes
+=====================================
+
+barycenter
+==========
+
+Function :func:`barycenter` tries to find appropriate weights to
+reconstruct the point x from a subset (y1, y2, ..., yn), with weights
+that sum to one.
+
+This is just a simple case of Equality Constrained Least Squares
+[#f1]_ with the constraint dot(np.ones(n), w) = 1 on the weight
+vector w. In particular, since the constraint matrix is np.ones(n),
+the Q matrix of its QR decomposition is the Householder reflection
+of np.ones(n).
+
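+As an illustration only (the implementation in :func:`barycenter` is
+organized differently and uses BLAS rank-1 updates), the idea can be
+sketched for a single point with plain numpy/scipy::
+
+    import numpy as np
+    from scipy import linalg
+
+    def barycenter_weights_1pt(x, Y):
+        """Weights w with sum(w) == 1 minimizing ||np.dot(Y.T, w) - x||.
+
+        x : array, shape (n_dim,); Y : array, shape (n_neighbors, n_dim).
+        """
+        n = Y.shape[0]
+        if n == 1:
+            return np.ones(1)
+        v = np.ones(n)
+        v[0] -= np.sqrt(n)                  # Householder vector
+        H = np.eye(n) - 2. * np.outer(v, v) / np.dot(v, v)
+        # H maps np.ones(n) onto sqrt(n) * e_1, so the constraint
+        # dot(np.ones(n), w) == 1 becomes u[0] == 1 / sqrt(n), u = H w
+        C = np.dot(Y.T, H)
+        u = np.empty(n)
+        u[0] = 1. / np.sqrt(n)
+        u[1:] = linalg.lstsq(C[:, 1:], x - C[:, 0] * u[0])[0]
+        return np.dot(H, u)                 # undo the reflection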
+
+Purpose
+-------
+
+This function was added to ease some computations in the upcoming
+manifold module, namely in LLE (Locally Linear Embedding). However,
+it remains to be shown that it is useful and efficient in that
+context.
+
+
+Performance
+-----------
+
+The algorithm has to iterate over n_samples, which is the main
+bottleneck. It would be great to vectorize this loop.
+
+Also, the least squares solution could be computed more efficiently
+by a QR factorization, since we probably do not care about a minimum
+norm solution in the underdetermined case.
+
+The paper 'An Introduction to Locally Linear Embedding' by Saul &
+Roweis solves the problem with the normal equations over the local
+covariance matrix. This has the disadvantage that it does not degrade
+gracefully when the covariance matrix is singular, requiring
+regularization to be added explicitly.
+
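+For comparison, a sketch of that normal equations approach (this is
+essentially what the previous implementation of :func:`barycenter`
+did, see the diff of neighbors.py)::
+
+    import numpy as np
+    from scipy import linalg
+
+    def barycenter_weights_gram(x, Y, eps=1e-6):
+        # local Gram matrix of the point relative to its neighbors
+        D = x - Y                       # shape (n_neighbors, n_dim)
+        G = np.dot(D, D.T)
+        # explicit regularization, needed when G is singular
+        G.flat[::G.shape[0] + 1] += eps * np.trace(G)
+        w = linalg.solve(G, np.ones(G.shape[0]), sym_pos=True)
+        return w / np.sum(w)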
+
+Stability
+---------
+
+Stability should be good, as the least squares problem is solved via
+the SVD. TODO: derive explicit bounds.
+
+
+API
+---
+
+The API is convenient to use from NeighborsBarycenter and
+kneighbors_graph, but might not be very easy to use directly, since
+the neighbors argument Z must be a 3-D array.
+
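+For reference, a minimal direct call (the array values below are made
+up for illustration; the argument names follow the new signature)::
+
+    import numpy as np
+    from scikits.learn.neighbors import barycenter
+
+    X = np.array([[0., 1.], [1., 0.]])    # (n_samples, n_dim)
+    Z = np.array([[[0., 0.], [1., 1.]],   # (n_samples, n_neighbors, n_dim)
+                  [[0., 0.], [1., 1.]]])
+    B = barycenter(X, Z)                  # (n_samples, n_neighbors)
+    # each row of B sums to one
+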
+It should be checked that it is usable in other contexts.
+
+
+.. rubric:: Footnotes
+
+.. [#f1] Section 12.1.4 ('Equality Constrained Least Squares'),
+         'Matrix Computations' by Golub & Van Loan 
diff --git a/doc/index.rst b/doc/index.rst
index a4bc49fd61..edd183d483 100644
--- a/doc/index.rst
+++ b/doc/index.rst
@@ -93,5 +93,6 @@ Development
    :maxdepth: 2
 
    developers/index
+   developers/neighbors
    performance
    about
diff --git a/scikits/learn/neighbors.py b/scikits/learn/neighbors.py
index 565db36c3d..c92c4d8767 100644
--- a/scikits/learn/neighbors.py
+++ b/scikits/learn/neighbors.py
@@ -213,7 +213,7 @@ class NeighborsBarycenter(Neighbors, RegressorMixin):
     http://en.wikipedia.org/wiki/K-nearest_neighbor_algorithm
     """
 
-    def predict(self, X, eps=1e-6, **params):
+    def predict(self, X, **params):
         """Predict the target for the provided data.
 
         Parameters
@@ -225,9 +225,6 @@ class NeighborsBarycenter(Neighbors, RegressorMixin):
             Number of neighbors to get (default is the value
             passed to the constructor).
 
-        eps : float, optional
-           Amount of regularization to add to the Gram matrix.
-
         Returns
         -------
         y: array
@@ -253,7 +250,7 @@ class NeighborsBarycenter(Neighbors, RegressorMixin):
         neigh = self.ball_tree.data[neigh_ind]
 
         # compute barycenters at each point
-        B = barycenter(X, neigh, eps=eps)
+        B = barycenter(X, neigh)
         labels = self._y[neigh_ind]
 
         return (B * labels).sum(axis=1)
@@ -262,8 +259,8 @@ class NeighborsBarycenter(Neighbors, RegressorMixin):
 ###############################################################################
 # Utils k-NN based Functions
 
-def barycenter(X, Y, eps=1e-6):
-    """
+def barycenter(X, Z, cond=None):
+    """ 
     Compute barycenter weights of X from Y along the first axis.
 
     We estimate the weights to assign to each point in Y[i] to recover
@@ -273,56 +270,48 @@ def barycenter(X, Y, eps=1e-6):
     ----------
     X : array-like, shape (n_samples, n_dim)
 
-    Y : array-like, shape (n_samples, n_neighbors, n_dim)
+    Z : array-like, shape (n_samples, n_neighbors, n_dim)
 
-    eps: float, optional
-        Amount of regularization to add to the diagonal of the Gram
-        matrix.
+    cond: float, optional
+        Cutoff for small singular values; used to determine the
+        effective rank of Z[i]. Singular values smaller than ``cond *
+        largest_singular_value`` are considered zero.
 
     Returns
     -------
     B : array-like, shape (n_samples, n_neighbors)
 
-    Reference
-    ---------
-    'An introduction to Locally Linear Embeddings', Saul & Roweis
     """
+#
+#       .. local variables ..
+#
     from scipy import linalg
-
-    X = np.atleast_2d(X)
+    X, Z = map(np.asanyarray, (X, Z))
+    n_samples, n_neighbors = X.shape[0], Z.shape[1]
     if X.dtype.kind == 'i':
-        # integer arrays truncate regularization
         X = X.astype(np.float)
-
-    Y = np.asanyarray(Y)
-    n_samples, n_neighbors = X.shape[0], Y.shape[1]
-
     B = np.empty((n_samples, n_neighbors), dtype=X.dtype)
+    v = np.ones(n_neighbors, dtype=X.dtype)
+    rank_update, = linalg.get_blas_funcs(('ger',), (X,))
 
-    Z = X[:, np.newaxis] - Y
-
-    # Compute the weights that best reconstruct each data point
-    # from its neighbors, minimizing the cost function:
-    #
-    #     phi(X) = Sum_i|X_i - Sum_j{W_ij X_j}|
-    #
-
-    for i, D in enumerate(Z):
-
-        # Compute Gram matrix
-        Gram = np.dot(D, D.T)
-
-        # Add regularization
-        Gram.flat[::n_neighbors + 1] += eps * np.trace(Gram)
-
-        # Solve the system
-        B[i] = linalg.solve(Gram, np.ones(Gram.shape[0]), sym_pos=True)
-        B[i] /= np.sum(B[i])
-
+#
+#       .. constrained least squares ..
+#
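+    # v is the Householder vector that maps np.ones(n_neighbors) onto
+    # sqrt(n_neighbors) * e_1; in the reflected coordinates the
+    # sum-to-one constraint fixes the first component of the solution,
+    # and the remaining components solve an ordinary least squares.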
+    v[0] -= np.sqrt(n_neighbors)
+    B[:, 0] = 1. / np.sqrt(n_neighbors)
+    if n_neighbors <= 1:
+        return B
+    alpha = - 1. / (n_neighbors - np.sqrt(n_neighbors))
+    for i, A in enumerate(Z.transpose(0, 2, 1)):
+        C = rank_update(alpha, np.dot(A, v), v, a=A)
+        B[i, 1:] = linalg.lstsq(
+            C[:, 1:], X[i] - C[:, 0] / np.sqrt(n_neighbors), cond=cond,
+            overwrite_a=True, overwrite_b=True)[0]
+        B[i] = rank_update(alpha, v, np.dot(v.T, B[i]), a=B[i])
     return B
 
 
-def kneighbors_graph(X, n_neighbors, mode='connectivity', eps=1e-6):
+def kneighbors_graph(X, n_neighbors, mode='connectivity'):
     """Computes the (weighted) graph of k-Neighbors for points in X
 
     Parameters
@@ -379,7 +368,7 @@ def kneighbors_graph(X, n_neighbors, mode='connectivity', eps=1e-6):
         ind = ball_tree.query(
             X, k=n_neighbors + 1, return_distance=False)
         A_ind = ind[:, 1:]
-        A_data = barycenter(X, X[A_ind], eps=eps)
+        A_data = barycenter(X, X[A_ind])
 
     else:
         raise ValueError("Unsupported mode type")
diff --git a/scikits/learn/tests/test_neighbors.py b/scikits/learn/tests/test_neighbors.py
index 8d4e1c40e9..e13b008ecf 100644
--- a/scikits/learn/tests/test_neighbors.py
+++ b/scikits/learn/tests/test_neighbors.py
@@ -65,59 +65,56 @@ def test_neighbors_barycenter():
     y = [0, 0, 1, 1]
     neigh = neighbors.NeighborsBarycenter(n_neighbors=2)
     neigh.fit(X, y)
-    assert_equal(neigh.predict([[1.5]]), 0.5)
+    assert_array_almost_equal(neigh.predict([[1.5]]), [0.5])
 
 
 def test_kneighbors_graph():
     """
     Test kneighbors_graph to build the k-Nearest Neighbor graph.
     """
-    X = [[0, 0], [1.01, 0], [2, 0]]
+    X = [[0, 1], [1.01, 1.], [2, 0]]
 
     # n_neighbors = 1
     A = neighbors.kneighbors_graph(X, 1, mode='connectivity')
     assert_array_equal(A.todense(), np.eye(A.shape[0]))
 
     A = neighbors.kneighbors_graph(X, 1, mode='distance')
-    assert_array_equal(
+    assert_array_almost_equal(
         A.todense(),
-        [[ 0.  ,  1.01,  0.  ],
-         [ 0.  ,  0.  ,  0.99],
-         [ 0.  ,  0.99,  0.  ]])
+        [[ 0.        ,  1.01      ,  0.        ],
+         [ 1.01      ,  0.        ,  0.        ],
+         [ 0.        ,  1.40716026,  0.        ]])
 
     A = neighbors.kneighbors_graph(X, 1, mode='barycenter')
-    assert_array_equal(
+    assert_array_almost_equal(
         A.todense(),
         [[ 0.,  1.,  0.],
-        [ 0.,  0.,  1.],
-        [ 0.,  1.,  0.]])
+         [ 1.,  0.,  0.],
+         [ 0.,  1.,  0.]])
 
     # n_neigbors = 2
     A = neighbors.kneighbors_graph(X, 2, mode='connectivity')
     assert_array_equal(
         A.todense(),
         [[ 1.,  1.,  0.],
-         [ 0.,  1.,  1.],
+         [ 1.,  1.,  0.],
          [ 0.,  1.,  1.]])
 
     A = neighbors.kneighbors_graph(X, 2, mode='distance')
     assert_array_almost_equal(
         A.todense(),
-        [[ 0.  ,  1.01,  2.  ],
-        [ 1.01,  0.  ,  0.99],
-        [ 2.  ,  0.99,  0.  ]])
+        [[ 0.        ,  1.01      ,  2.23606798],
+         [ 1.01      ,  0.        ,  1.40716026],
+         [ 2.23606798,  1.40716026,  0.        ]])
 
     A = neighbors.kneighbors_graph(X, 2, mode='barycenter')
     # check that columns sum to one
     assert_array_almost_equal(np.sum(A.todense(), 1), np.ones((3, 1)))
     assert_array_almost_equal(
         A.todense(),
-        [[ 0.        ,  2.02018645, -1.02018645],
-        [ 0.49500001,  0.        ,  0.50499999],
-        [-0.98018357,  1.98018357,  0.        ]])
-    # check that we can reconstruct X from A
-    assert_array_almost_equal(
-        X, A.dot(X), decimal=3)
+        [[ 0.        ,  1.5049745 , -0.5049745 ],
+         [ 0.596     ,  0.        ,  0.404     ],
+         [-0.98019802,  1.98019802,  0.        ]])
 
     # n_neighbors = 3
     A = neighbors.kneighbors_graph(X, 3, mode='connectivity')
-- 
GitLab