From c0b4aa41e36ba78185f78020682282e52c080484 Mon Sep 17 00:00:00 2001
From: Alexandre Gramfort <alexandre.gramfort@inria.fr>
Date: Sun, 28 Nov 2010 12:44:37 -0500
Subject: [PATCH] ENH : adding NeighborsBarycenter for regression pbs using
 k-Nearest Neighbors API : change api for kneighbors_graph function

---
 examples/plot_neighbors_regression.py |  44 +++++++
 scikits/learn/neighbors.py            | 180 +++++++++++++++++++++++---
 scikits/learn/tests/test_neighbors.py |  32 +++--
 3 files changed, 228 insertions(+), 28 deletions(-)
 create mode 100644 examples/plot_neighbors_regression.py

diff --git a/examples/plot_neighbors_regression.py b/examples/plot_neighbors_regression.py
new file mode 100644
index 0000000000..b5975a321f
--- /dev/null
+++ b/examples/plot_neighbors_regression.py
@@ -0,0 +1,44 @@
+"""
+==============================
+k-Nearest Neighbors regression
+==============================
+
+Demonstrate the resolution of a regression problem
+using a k-Nearest Neighbor and the interpolation of the
+target using barycenter computation.
+
+"""
+print __doc__
+
+###############################################################################
+# Generate sample data
+import numpy as np
+
+np.random.seed(0)
+X = np.sort(5*np.random.rand(40, 1), axis=0)
+T = np.linspace(0, 5, 500)
+y = np.sin(X).ravel()
+
+# Add noise to targets
+y[::5] += 1*(0.5 - np.random.rand(8))
+
+###############################################################################
+# Fit regression model
+
+from scikits.learn import neighbors
+
+knn_barycenter = neighbors.NeighborsBarycenter(k=5)
+y_ = knn_barycenter.fit(X, y).predict(T)
+
+###############################################################################
+# look at the results
+import pylab as pl
+pl.scatter(X, y, c='k', label='data')
+pl.hold('on')
+pl.plot(T, y_, c='g', label='k-NN prediction')
+pl.xlabel('data')
+pl.ylabel('target')
+pl.legend()
+pl.title('k-NN Regression')
+pl.show()
+
diff --git a/scikits/learn/neighbors.py b/scikits/learn/neighbors.py
index c90e37895a..7f2d15e752 100644
--- a/scikits/learn/neighbors.py
+++ b/scikits/learn/neighbors.py
@@ -6,8 +6,9 @@ neighbor searches in high dimensionality.
 """
 import numpy as np
 from scipy import stats
+from scipy import linalg
 
-from .base import BaseEstimator, ClassifierMixin
+from .base import BaseEstimator, ClassifierMixin, RegressorMixin
 from .ball_tree import BallTree
 
 
@@ -144,9 +145,123 @@ def _predict_from_BallTree(ball_tree, Y, test, k):
         return Y_
     return (stats.mode(Y_, axis=1)[0]).ravel()
 
+###############################################################################
+# Neighbors Barycenter class for regression problems
 
-def kneighbors_graph(X, k, with_dist=True):
-    """Computes the graph of k-Neighbors
+class NeighborsBarycenter(BaseEstimator, RegressorMixin):
+    """Regression based on k-Nearest Neighbor Algorithm.
+
+    Parameters
+    ----------
+    X : array-like, shape (n_samples, n_features)
+        The data points to be indexed. This array is not copied, and so
+        modifying this data will result in bogus results.
+    y : array
+        An array representing labels for the data (only arrays of
+        integers are supported).
+    k : int
+        default number of neighbors.
+    window_size : int
+        the default window size.
+
+    Examples
+    --------
+    >>> X = [[0], [1], [2], [3]]
+    >>> y = [0, 0, 1, 1]
+    >>> from scikits.learn.neighbors import NeighborsBarycenter
+    >>> neigh = NeighborsBarycenter(k=2)
+    >>> neigh.fit(X, y)
+    NeighborsBarycenter(k=2, window_size=1)
+    >>> print neigh.predict([[1.5]])
+    [ 0.5]
+    """
+
+    def __init__(self, k=5, window_size=1):
+        """Internally uses the ball tree datastructure and algorithm for fast
+        neighbors lookups on high dimensional datasets.
+        """
+        self.k = k
+        self.window_size = window_size
+
+    def fit(self, X, y, copy=True):
+        self._y = np.array(y, copy=copy)
+        self.ball_tree = BallTree(X, self.window_size)
+        return self
+
+    def predict(self, T, k=None):
+        """Predict the target for the provided data.
+
+        Parameters
+        ----------
+        T : array
+            A 2-D array representing the test data.
+        k : int
+            Number of neighbors to get (default is the value
+            passed to the constructor).
+
+        Returns
+        -------
+        y: array
+            List of target values (one for each data sample).
+
+        Examples
+        --------
+        >>> X = [[0], [1], [2]]
+        >>> y = [0, 0, 1]
+        >>> from scikits.learn.neighbors import NeighborsBarycenter
+        >>> neigh = NeighborsBarycenter(k=2)
+        >>> neigh.fit(X, y)
+        NeighborsBarycenter(k=2, window_size=1)
+        >>> print neigh.predict([[.5], [1.5]])
+        [ 0.   0.5]
+        """
+        T = np.asanyarray(T)
+        if T.ndim == 1:
+            T = T[:,None]
+        if k is None:
+            k = self.k
+        A = kneighbors_graph(T, k=k, weight="barycenter",
+                                  ball_tree=self.ball_tree).tocsr()
+        return A * self._y
+
+###############################################################################
+# Utils k-NN based Functions
+
+def barycenter_weights(x, X_neighbors, tol=1e-3):
+    """
+    Computes barycenter weights so that point may be reconstructed from its
+    neighbors
+
+    Parameters
+    ----------
+    x : array
+        a 1D array
+
+    X_neighbors : array
+        a 2D array containing samples
+
+    tol : float
+        tolerance
+
+    Returns
+    -------
+    array of barycenter weights that sum to 1
+    """
+    if x.ndim == 1:
+        x = x[None,:]
+    if X_neighbors.ndim == 1:
+        X_neighbors = X_neighbors[:,None]
+    z = x - X_neighbors
+    gram = np.dot(z, z.T)
+    diag_stride = gram.shape[0] + 1
+    gram.flat[::diag_stride] += tol * np.trace(gram)
+    w = linalg.solve(gram, np.ones(len(X_neighbors)))
+    w /= np.sum(w)
+    return w
+
+
+def kneighbors_graph(X, k, weight=None, ball_tree=None, window_size=1):
+    """Computes the (weighted) graph of k-Neighbors
 
     Parameters
     ----------
@@ -156,37 +271,62 @@ def kneighbors_graph(X, k, with_dist=True):
     k : int
         Number of neighbors for each sample.
 
+    weight : None (default)
+        Weights to apply on graph edges. If weight is None
+        then no weighting is applied (1 for each edge).
+        If weight equals "distance" the edge weight is the
+        euclidian distance. If weight equals "barycenter"
+        the weights are barycenter weights estimated by
+        solving a linear system for each point.
+
+    ball_tree : None or instance of precomputed BallTree
+
+    window_size : int
+        Window size pass to the BallTree
+
     Returns
     -------
     A : sparse matrix, shape = [n_samples, n_samples]
         A is returned as LInked List Sparse matrix
-        A[i,j] = 1 if sample j is a neighbor of sample i
+        A[i,j] = weight of edge that connects i to j
 
     Examples
     --------
-    >>> X = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]
-    >>> A = kneighbors_graph(X, k=2, with_dist=False)
+    >>> X = [[0], [2], [1]]
+    >>> A = kneighbors_graph(X, k=2, weight=None)
     >>> print A.todense()
-    [[ 0.  1.  0.]
-     [ 1.  0.  0.]
-     [ 0.  1.  0.]]
+    [[ 1.  0.  1.]
+     [ 0.  1.  1.]
+     [ 0.  1.  1.]]
     """
     from scipy import sparse
     X = np.asanyarray(X)
     n_samples = X.shape[0]
-    A = sparse.lil_matrix((n_samples, n_samples))
-    knn = Neighbors(k=k)
-    dist, ind = knn.fit(X).kneighbors(X)
-    if with_dist:
+    if ball_tree is None:
+        ball_tree = BallTree(X, window_size)
+    A = sparse.lil_matrix((n_samples, ball_tree.size))
+    dist, ind = ball_tree.query(X, k=k)
+    if weight is None:
         for i, li in enumerate(ind):
-            if k > 2:
-                A[i, list(li[1:])] = dist[i, 1:]
+            if k > 1:
+                A[i, list(li)] = np.ones(k)
             else:
-                A[i, li[1]] = dist[i, 1]
-    else:
+                A[i, li] = 1.0
+    elif weight is "distance":
+        for i, li in enumerate(ind):
+            if k > 1:
+                A[i, list(li)] = dist[i, :]
+            else:
+                A[i, li] = dist[i, 0]
+    elif weight is "barycenter":
+        # XXX : the next loop could be done in parallel
+        # by parallelizing groups of indices
         for i, li in enumerate(ind):
-            if k > 2:
-                A[i, list(li[1:])] = np.ones(k-1)
+            if k > 1:
+                X_i = ball_tree.data[li]
+                A[i, list(li)] = barycenter_weights(X[i], X_i)
             else:
-                A[i, li[1]] = 1.0
+                A[i, li] = 1.0
+    else:
+        raise ValueError("Unknown weight type")
     return A
diff --git a/scikits/learn/tests/test_neighbors.py b/scikits/learn/tests/test_neighbors.py
index 501c5fc280..42f9ab2029 100644
--- a/scikits/learn/tests/test_neighbors.py
+++ b/scikits/learn/tests/test_neighbors.py
@@ -1,5 +1,6 @@
 
-from numpy.testing import assert_array_equal, assert_array_almost_equal
+from numpy.testing import assert_array_equal, assert_array_almost_equal, \
+                          assert_equal
 
 from .. import neighbors
 
@@ -52,18 +53,33 @@ def test_neighbors_2D():
     assert_array_equal(prediction, [0, 1, 0, 1])
 
 
+def test_neighbors_barycenter():
+    """
+    NeighborsBarycenter for regression using k-NN
+    """
+    X = [[0], [1], [2], [3]]
+    y = [0, 0, 1, 1]
+    neigh = neighbors.NeighborsBarycenter(k=2)
+    neigh.fit(X, y)
+    assert_equal(neigh.predict([[1.5]]), 0.5)
+
+
 def test_kneighbors_graph():
     """
     Test kneighbors_graph to build the k-Nearest Neighbor graph.
     """
-    X = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]
-    A = neighbors.kneighbors_graph(X, 2, with_dist=False)
+    X = [[0], [1.01], [2]]
+    A = neighbors.kneighbors_graph(X, 2, weight=None)
     assert_array_equal(A.todense(),
-                              [[0, 1, 0], [1, 0, 0], [0, 1, 0]])
-    A = neighbors.kneighbors_graph(X, 2)
+                              [[1, 1, 0], [0, 1, 1], [0, 1, 1]])
+    A = neighbors.kneighbors_graph(X, 2, weight="distance")
+    assert_array_almost_equal(A.todense(),
+                              [[0, 1.01, 0], [0, 0, 0.99], [0, 0.99, 0]], 4)
+    A = neighbors.kneighbors_graph(X, 2, weight="barycenter")
     assert_array_almost_equal(A.todense(),
-                              [[0, 0.5, 0], [0.5, 0, 0], [0, 1.2247, 0]], 4)
+                              [[0.99, 0, 0], [0, 0.99, 0], [0, 0, 0.99]], 2)
 
     # Also check corner cases
-    A = neighbors.kneighbors_graph(X, 3, with_dist=False)
-    A = neighbors.kneighbors_graph(X, 3)
+    A = neighbors.kneighbors_graph(X, 3, weight=None)
+    A = neighbors.kneighbors_graph(X, 3, weight="distance")
+    A = neighbors.kneighbors_graph(X, 3, weight="barycenter")
-- 
GitLab