From c0b4aa41e36ba78185f78020682282e52c080484 Mon Sep 17 00:00:00 2001
From: Alexandre Gramfort <alexandre.gramfort@inria.fr>
Date: Sun, 28 Nov 2010 12:44:37 -0500
Subject: [PATCH] ENH : adding NeighborsBarycenter for regression pbs using k-Nearest Neighbors

API : change api for kneighbors_graph function

Review fixes folded into this patch: the weight option is compared with
== instead of is (string identity is an interning accident), fit and the
ValueError raised by kneighbors_graph are documented, and the docstring
of y no longer claims that only integer labels are supported.
---
 examples/plot_neighbors_regression.py |  44 +++++++
 scikits/learn/neighbors.py            | 192 +++++++++++++++++++++++---
 scikits/learn/tests/test_neighbors.py |  32 +++--
 3 files changed, 240 insertions(+), 28 deletions(-)
 create mode 100644 examples/plot_neighbors_regression.py

diff --git a/examples/plot_neighbors_regression.py b/examples/plot_neighbors_regression.py
new file mode 100644
index 0000000000..b5975a321f
--- /dev/null
+++ b/examples/plot_neighbors_regression.py
@@ -0,0 +1,44 @@
+"""
+==============================
+k-Nearest Neighbors regression
+==============================
+
+Demonstrate the resolution of a regression problem
+using a k-Nearest Neighbor and the interpolation of the
+target using barycenter computation.
+
+"""
+print __doc__
+
+###############################################################################
+# Generate sample data
+import numpy as np
+
+np.random.seed(0)
+X = np.sort(5*np.random.rand(40, 1), axis=0)
+T = np.linspace(0, 5, 500)
+y = np.sin(X).ravel()
+
+# Add noise to targets
+y[::5] += 1*(0.5 - np.random.rand(8))
+
+###############################################################################
+# Fit regression model
+
+from scikits.learn import neighbors
+
+knn_barycenter = neighbors.NeighborsBarycenter(k=5)
+y_ = knn_barycenter.fit(X, y).predict(T)
+
+###############################################################################
+# look at the results
+import pylab as pl
+pl.scatter(X, y, c='k', label='data')
+pl.hold('on')
+pl.plot(T, y_, c='g', label='k-NN prediction')
+pl.xlabel('data')
+pl.ylabel('target')
+pl.legend()
+pl.title('k-NN Regression')
+pl.show()
+
diff --git a/scikits/learn/neighbors.py b/scikits/learn/neighbors.py
index c90e37895a..7f2d15e752 100644
--- a/scikits/learn/neighbors.py
+++ b/scikits/learn/neighbors.py
@@ -6,8 +6,9 @@ neighbor searches in high dimensionality.
 """
 import numpy as np
 from scipy import stats
+from scipy import linalg
 
-from .base import BaseEstimator, ClassifierMixin
+from .base import BaseEstimator, ClassifierMixin, RegressorMixin
 from .ball_tree import BallTree
 
 
@@ -144,9 +145,130 @@ def _predict_from_BallTree(ball_tree, Y, test, k):
         return Y_
     return (stats.mode(Y_, axis=1)[0]).ravel()
 
+###############################################################################
+# Neighbors Barycenter class for regression problems
 
-def kneighbors_graph(X, k, with_dist=True):
-    """Computes the graph of k-Neighbors
+class NeighborsBarycenter(BaseEstimator, RegressorMixin):
+    """Regression based on k-Nearest Neighbor Algorithm.
+
+    Parameters
+    ----------
+    X : array-like, shape (n_samples, n_features)
+        The data points to be indexed. This array is not copied, and so
+        modifying this data will result in bogus results.
+    y : array
+        An array of target values for the data, one value per
+        sample in X.
+    k : int
+        default number of neighbors.
+    window_size : int
+        the default window size.
+
+    Examples
+    --------
+    >>> X = [[0], [1], [2], [3]]
+    >>> y = [0, 0, 1, 1]
+    >>> from scikits.learn.neighbors import NeighborsBarycenter
+    >>> neigh = NeighborsBarycenter(k=2)
+    >>> neigh.fit(X, y)
+    NeighborsBarycenter(k=2, window_size=1)
+    >>> print neigh.predict([[1.5]])
+    [ 0.5]
+    """
+
+    def __init__(self, k=5, window_size=1):
+        """Internally uses the ball tree datastructure and algorithm for fast
+        neighbors lookups on high dimensional datasets.
+        """
+        self.k = k
+        self.window_size = window_size
+
+    def fit(self, X, y, copy=True):
+        """Index the samples X in a BallTree and store the targets y.
+
+        If copy is True (default), y is copied before being stored.
+        Returns self so that calls can be chained.
+        """
+        self._y = np.array(y, copy=copy)
+        self.ball_tree = BallTree(X, self.window_size)
+        return self
+
+    def predict(self, T, k=None):
+        """Predict the target for the provided data.
+
+        Parameters
+        ----------
+        T : array
+            A 2-D array representing the test data.
+        k : int
+            Number of neighbors to get (default is the value
+            passed to the constructor).
+
+        Returns
+        -------
+        y: array
+            List of target values (one for each data sample).
+
+        Examples
+        --------
+        >>> X = [[0], [1], [2]]
+        >>> y = [0, 0, 1]
+        >>> from scikits.learn.neighbors import NeighborsBarycenter
+        >>> neigh = NeighborsBarycenter(k=2)
+        >>> neigh.fit(X, y)
+        NeighborsBarycenter(k=2, window_size=1)
+        >>> print neigh.predict([[.5], [1.5]])
+        [ 0.   0.5]
+        """
+        T = np.asanyarray(T)
+        if T.ndim == 1:
+            T = T[:,None]
+        if k is None:
+            k = self.k
+        A = kneighbors_graph(T, k=k, weight="barycenter",
+                             ball_tree=self.ball_tree).tocsr()
+        return A * self._y
+
+###############################################################################
+# Utils k-NN based Functions
+
+def barycenter_weights(x, X_neighbors, tol=1e-3):
+    """
+    Computes barycenter weights so that point may be reconstructed from its
+    neighbors
+
+    Parameters
+    ----------
+    x : array
+        a 1D array
+
+    X_neighbors : array
+        a 2D array containing samples
+
+    tol : float
+        tolerance
+
+    Returns
+    -------
+    array of barycenter weights that sum to 1
+    """
+    if x.ndim == 1:
+        x = x[None,:]
+    if X_neighbors.ndim == 1:
+        X_neighbors = X_neighbors[:,None]
+    z = x - X_neighbors
+    gram = np.dot(z, z.T)
+    diag_stride = gram.shape[0] + 1
+    # Add tol * trace(gram) to the diagonal: the Gram matrix is singular
+    # whenever there are more neighbors than features.
+    gram.flat[::diag_stride] += tol * np.trace(gram)
+    w = linalg.solve(gram, np.ones(len(X_neighbors)))
+    w /= np.sum(w)
+    return w
+
+
+def kneighbors_graph(X, k, weight=None, ball_tree=None, window_size=1):
+    """Computes the (weighted) graph of k-Neighbors
 
     Parameters
     ----------
@@ -156,37 +278,67 @@ def kneighbors_graph(X, k, with_dist=True):
     k : int
         Number of neighbors for each sample.
 
+    weight : None (default), "distance" or "barycenter"
+        Weights to apply on graph edges. If weight is None
+        then no weighting is applied (1 for each edge).
+        If weight equals "distance" the edge weight is the
+        Euclidean distance. If weight equals "barycenter"
+        the weights are barycenter weights estimated by
+        solving a linear system for each point.
+
+    ball_tree : None or instance of precomputed BallTree
+
+    window_size : int
+        Window size passed to the BallTree
+
     Returns
     -------
     A : sparse matrix, shape = [n_samples, n_samples]
         A is returned as LInked List Sparse matrix
-        A[i,j] = 1 if sample j is a neighbor of sample i
+        A[i,j] = weight of edge that connects i to j
 
+    Raises
+    ------
+    ValueError
+        If weight is not None, "distance" or "barycenter".
+
     Examples
     --------
-    >>> X = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]
-    >>> A = kneighbors_graph(X, k=2, with_dist=False)
+    >>> X = [[0], [2], [1]]
+    >>> A = kneighbors_graph(X, k=2, weight=None)
     >>> print A.todense()
-    [[ 0.  1.  0.]
-     [ 1.  0.  0.]
-     [ 0.  1.  0.]]
+    [[ 1.  0.  1.]
+     [ 0.  1.  1.]
+     [ 0.  1.  1.]]
     """
     from scipy import sparse
     X = np.asanyarray(X)
     n_samples = X.shape[0]
-    A = sparse.lil_matrix((n_samples, n_samples))
-    knn = Neighbors(k=k)
-    dist, ind = knn.fit(X).kneighbors(X)
-    if with_dist:
+    if ball_tree is None:
+        ball_tree = BallTree(X, window_size)
+    A = sparse.lil_matrix((n_samples, ball_tree.size))
+    dist, ind = ball_tree.query(X, k=k)
+    if weight is None:
         for i, li in enumerate(ind):
-            if k > 2:
-                A[i, list(li[1:])] = dist[i, 1:]
+            if k > 1:
+                A[i, list(li)] = np.ones(k)
             else:
-                A[i, li[1]] = dist[i, 1]
-    else:
+                A[i, li] = 1.0
+    elif weight == "distance":
+        for i, li in enumerate(ind):
+            if k > 1:
+                A[i, list(li)] = dist[i, :]
+            else:
+                A[i, li] = dist[i, 0]
+    elif weight == "barycenter":
+        # XXX : the next loop could be done in parallel
+        # by parallelizing groups of indices
         for i, li in enumerate(ind):
-            if k > 2:
-                A[i, list(li[1:])] = np.ones(k-1)
+            if k > 1:
+                X_i = ball_tree.data[li]
+                A[i, list(li)] = barycenter_weights(X[i], X_i)
             else:
-                A[i, li[1]] = 1.0
+                A[i, li] = 1.0
+    else:
+        raise ValueError("Unknown weight type")
     return A
diff --git a/scikits/learn/tests/test_neighbors.py b/scikits/learn/tests/test_neighbors.py
index 501c5fc280..42f9ab2029 100644
--- a/scikits/learn/tests/test_neighbors.py
+++ b/scikits/learn/tests/test_neighbors.py
@@ -1,5 +1,6 @@
 
-from numpy.testing import assert_array_equal, assert_array_almost_equal
+from numpy.testing import assert_array_equal, assert_array_almost_equal, \
+                          assert_equal
 from .. import neighbors
 
 
@@ -52,18 +53,33 @@ def test_neighbors_2D():
     assert_array_equal(prediction, [0, 1, 0, 1])
 
 
+def test_neighbors_barycenter():
+    """
+    NeighborsBarycenter for regression using k-NN
+    """
+    X = [[0], [1], [2], [3]]
+    y = [0, 0, 1, 1]
+    neigh = neighbors.NeighborsBarycenter(k=2)
+    neigh.fit(X, y)
+    assert_equal(neigh.predict([[1.5]]), 0.5)
+
+
 def test_kneighbors_graph():
     """
     Test kneighbors_graph to build the k-Nearest Neighbor graph.
     """
-    X = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]
-    A = neighbors.kneighbors_graph(X, 2, with_dist=False)
+    X = [[0], [1.01], [2]]
+    A = neighbors.kneighbors_graph(X, 2, weight=None)
     assert_array_equal(A.todense(),
-                      [[0, 1, 0], [1, 0, 0], [0, 1, 0]])
-    A = neighbors.kneighbors_graph(X, 2)
+                      [[1, 1, 0], [0, 1, 1], [0, 1, 1]])
+    A = neighbors.kneighbors_graph(X, 2, weight="distance")
+    assert_array_almost_equal(A.todense(),
+                      [[0, 1.01, 0], [0, 0, 0.99], [0, 0.99, 0]], 4)
+    A = neighbors.kneighbors_graph(X, 2, weight="barycenter")
     assert_array_almost_equal(A.todense(),
-                      [[0, 0.5, 0], [0.5, 0, 0], [0, 1.2247, 0]], 4)
+                      [[0.99, 0, 0], [0, 0.99, 0], [0, 0, 0.99]], 2)
 
     # Also check corner cases
-    A = neighbors.kneighbors_graph(X, 3, with_dist=False)
-    A = neighbors.kneighbors_graph(X, 3)
+    A = neighbors.kneighbors_graph(X, 3, weight=None)
+    A = neighbors.kneighbors_graph(X, 3, weight="distance")
+    A = neighbors.kneighbors_graph(X, 3, weight="barycenter")
-- 
GitLab