diff --git a/doc/developers/neighbors.rst b/doc/developers/neighbors.rst
index 1938f059318305c3b6fd97d8d49f99de245fdd61..ae98b5b0be323565accaef51f112996c2dbc6284 100644
--- a/doc/developers/neighbors.rst
+++ b/doc/developers/neighbors.rst
@@ -33,7 +33,8 @@ Performance
 -----------
 
 The algorithm has to iterate over n_samples, which is the main
-bottleneck. It would be great to vectorize this loop.
+bottleneck. It would be great to vectorize this loop. Also, the rank
+updates could probably be moved outside the loop.
 
 Also, least squares solution could be computed more efficiently by a
 QR factorization, since probably we don't care about a minimum norm
@@ -41,9 +42,8 @@ solution for the undertermined case.
 
 The paper 'An introduction to Locally Linear Embeddings', Saul &
 Roweis solves the problem by the normal equation method over the
-covariance matrix. This has the disadvantage that it does not degrade
-grathefully when the covariance is singular, requiring to explicitly
-add regularization.
+covariance matrix. However, it does not degrade gracefully when the
+covariance is singular and requires regularization to be added explicitly.
 
 
 Stability
diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 50bea1baf5d8669ad4958388b43d498e4534efa9..8955bd140c652d64e34a1438dfad0a8c91250b55 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -105,8 +105,8 @@ Nearest Neighbors
    :toctree: generated/
    :template: class.rst
 
-   neighbors.Neighbors
-   neighbors.NeighborsBarycenter
+   neighbors.NeighborsClassifier
+   neighbors.NeighborsRegressor
    ball_tree.BallTree
 
 .. autosummary::
diff --git a/doc/modules/neighbors.rst b/doc/modules/neighbors.rst
index e069ec15bfc4c33ae03be30e12035bfe24a8b5a1..68bcac1e212ea91cf070f88060b21208c29059c0 100644
--- a/doc/modules/neighbors.rst
+++ b/doc/modules/neighbors.rst
@@ -16,9 +16,9 @@ the decision boundary is very irregular.
 Classification
 ==============
 
-The :class:`Neighbors` estimators implements the nearest-neighbors
-classification method using a vote heuristic: the class most present in
-the k nearest neighbors of a point is assigned to this point.
+The :class:`NeighborsClassifier` implements the nearest-neighbors
+classification method using a vote heuristic: the class most present
+in the k nearest neighbors of a point is assigned to this point.
 
 .. figure:: ../auto_examples/images/plot_neighbors.png
    :target: ../auto_examples/plot_neighbors.html
@@ -31,12 +31,17 @@ the k nearest neighbors of a point is assigned to this point.
 
   * :ref:`example_plot_neighbors.py`: an example of
     classification using nearest neighbor.
 
+
 Regression
 ==========
 
-The :class:`NeighborsBarycenter` estimator implements a nearest-neighbors
-regression method using barycenter weighting of the targets of the
-k-neighbors.
+The :class:`NeighborsRegressor` estimator implements a
+nearest-neighbors regression method by weighting the targets of the
+k-neighbors. Two different weighting strategies are implemented:
+``barycenter`` and ``mean``. ``barycenter`` will apply the weights
+that best reconstruct the point from its neighbors while ``mean``
+will apply constant weights to each point. This plot shows the
+behavior of both modes on a simple regression task.
 
 .. figure:: ../auto_examples/images/plot_neighbors_regression.png
    :target: ../auto_examples/plot_neighbors_regression.html
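For reference, the vote heuristic described in doc/modules/neighbors.rst
above boils down to a row-wise majority vote over the labels of the k
nearest neighbors. A minimal sketch of just that step, assuming the
neighbor labels have already been gathered (this is not the predict()
implementation from the patch below)::

    import numpy as np
    from scipy import stats

    # labels of the k=3 nearest neighbors of two query points
    neigh_labels = np.array([[0, 0, 1],
                             [1, 1, 0]])

    # majority vote along each row: the most frequent label wins,
    # so the first point is assigned 0 and the second one 1
    votes, _ = stats.mode(neigh_labels, axis=1)
    print votes.ravel()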
diff --git a/examples/plot_neighbors.py b/examples/plot_neighbors.py
index 4bdf51f44e778d217c7ac1b0d033b5450f786040..812a23f95a52ee49785b02c8a84ded165d9336ca 100644
--- a/examples/plot_neighbors.py
+++ b/examples/plot_neighbors.py
@@ -22,7 +22,7 @@ h = .02 # step size in the mesh
 
 # we create an instance of SVM and fit out data. We do not scale our
 # data since we want to plot the support vectors
-clf = neighbors.Neighbors()
+clf = neighbors.NeighborsClassifier()
 clf.fit(X, Y)
 
 # Plot the decision boundary. For that, we will asign a color to each
diff --git a/examples/plot_neighbors_regression.py b/examples/plot_neighbors_regression.py
index 7d259ea4998ab36ca4e1e21077d43c1332c05521..e9bf38e283673752eb44b95c4d1b7330ef9aa650 100644
--- a/examples/plot_neighbors_regression.py
+++ b/examples/plot_neighbors_regression.py
@@ -5,14 +5,22 @@ k-Nearest Neighbors regression
 
 Demonstrate the resolution of a regression
 problem using a k-Nearest Neighbor and the interpolation of the
-target using barycenter computation.
+target using both barycenter and constant weights.
 
 """
 print __doc__
 
+# Author: Alexandre Gramfort <alexandre.gramfort@inria.fr>
+#         Fabian Pedregosa <fabian.pedregosa@inria.fr>
+#
+# License: BSD, (C) INRIA
+
+
 ###############################################################################
 # Generate sample data
 import numpy as np
+import pylab as pl
+from scikits.learn import neighbors
 
 np.random.seed(0)
 X = np.sort(5*np.random.rand(40, 1), axis=0)
@@ -25,20 +33,17 @@ y[::5] += 1*(0.5 - np.random.rand(8))
 
 ###############################################################################
 # Fit regression model
-from scikits.learn import neighbors
+for i, mode in enumerate(('mean', 'barycenter')):
+    knn = neighbors.NeighborsRegressor(n_neighbors=4, mode=mode)
+    y_ = knn.fit(X, y).predict(T)
 
-knn_barycenter = neighbors.NeighborsBarycenter(n_neighbors=5)
-y_ = knn_barycenter.fit(X, y).predict(T)
+    pl.subplot(2, 1, 1 + i)
+    pl.scatter(X, y, c='k', label='data')
+    pl.plot(T, y_, c='g', label='prediction')
+    pl.axis('tight')
+    pl.legend()
+    pl.title('NeighborsRegressor with %s weights' % mode)
 
-###############################################################################
-# look at the results
-import pylab as pl
-pl.scatter(X, y, c='k', label='data')
-pl.hold('on')
-pl.plot(T, y_, c='g', label='k-NN prediction')
-pl.xlabel('data')
-pl.ylabel('target')
-pl.legend()
-pl.title('k-NN Regression')
+pl.subplots_adjust(0.1, 0.04, 0.95, 0.94, 0.3, 0.28)
 
 pl.show()
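The ``barycenter`` mode exercised in the example above solves, for each
test point, a small constrained least-squares problem: find the weights
that sum to one and best reconstruct the point from its k neighbors. A
minimal sketch of the normal-equations approach discussed in
doc/developers/neighbors.rst (a hypothetical helper, not the
``barycenter_weights`` implementation patched below, which goes through
a least-squares solver with rank updates)::

    import numpy as np

    def barycenter_weights_sketch(x, Z, reg=1e-3):
        # Gram matrix of the neighbor offsets; it is singular whenever
        # x lies in the affine span of its neighbors, which is why an
        # explicit regularization term is needed here
        G = np.dot(Z - x, (Z - x).T)
        G += reg * np.trace(G) * np.eye(len(Z))
        # the weights solve G w = 1 and are normalized to sum to one
        w = np.linalg.solve(G, np.ones(len(Z)))
        return w / w.sum()

    # reconstructing 1.5 from its neighbors 1.0 and 2.0 yields weights
    # close to [0.5, 0.5], hence the interpolated prediction 0.5
    print barycenter_weights_sketch(np.array([1.5]), np.array([[1.], [2.]]))

The sum-to-one constraint is what makes the resulting prediction an
affine combination of the neighbors' targets rather than a plain
projection.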
diff --git a/scikits/learn/neighbors.py b/scikits/learn/neighbors.py
index f933d21716c3cec6b0b4f3bab23996ad3d8b7fa0..18b5e474295de6aa7230aa81cb7f2d37e29e2512 100644
--- a/scikits/learn/neighbors.py
+++ b/scikits/learn/neighbors.py
@@ -1,6 +1,5 @@
-"""
-Nearest Neighbor related algorithms.
-"""
+"""Nearest Neighbor related algorithms"""
+
 # Author: Fabian Pedregosa <fabian.pedregosa@inria.fr>
 #         Alexandre Gramfort <alexandre.gramfort@inria.fr>
 #
@@ -12,19 +11,14 @@
 from .base import BaseEstimator, ClassifierMixin, RegressorMixin
 from .ball_tree import BallTree
 
-class Neighbors(BaseEstimator, ClassifierMixin):
+class NeighborsClassifier(BaseEstimator, ClassifierMixin):
     """Classifier implementing k-Nearest Neighbor Algorithm.
 
     Parameters
     ----------
-    data : array-like, shape (n, k)
-        The data points to be indexed. This array is not copied, and so
-        modifying this data will result in bogus results.
-    labels : array
-        An array representing labels for the data (only arrays of
-        integers are supported).
     n_neighbors : int
         default number of neighbors.
+
     window_size : int
         Window size passed to BallTree
 
@@ -32,10 +26,10 @@
     --------
     >>> samples = [[0, 0, 1], [1, 0, 0]]
     >>> labels = [0, 1]
-    >>> from scikits.learn.neighbors import Neighbors
-    >>> neigh = Neighbors(n_neighbors=1)
+    >>> from scikits.learn.neighbors import NeighborsClassifier
+    >>> neigh = NeighborsClassifier(n_neighbors=1)
     >>> neigh.fit(samples, labels)
-    Neighbors(n_neighbors=1, window_size=1)
+    NeighborsClassifier(n_neighbors=1, window_size=1)
     >>> print neigh.predict([[0,0,0]])
     [1]
 
@@ -102,16 +96,16 @@
 
         Examples
        --------
-        In the following example, we construnct a Neighbors class from an
-        array representing our data set and ask who's the closest point to
-        [1,1,1]
+        In the following example, we construct a NeighborsClassifier
+        class from an array representing our data set and ask which is
+        the closest point to [1,1,1]
 
         >>> samples = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]
        >>> labels = [0, 0, 1]
-        >>> from scikits.learn.neighbors import Neighbors
-        >>> neigh = Neighbors(n_neighbors=1)
+        >>> from scikits.learn.neighbors import NeighborsClassifier
+        >>> neigh = NeighborsClassifier(n_neighbors=1)
         >>> neigh.fit(samples, labels)
-        Neighbors(n_neighbors=1, window_size=1)
+        NeighborsClassifier(n_neighbors=1, window_size=1)
         >>> print neigh.kneighbors([1., 1., 1.])
         (array([ 0.5]), array([2]))
 
@@ -145,19 +139,6 @@
         -------
         labels: array
             List of class labels (one for each data sample).
-
-        Examples
-        --------
-        >>> samples = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]
-        >>> labels = [0, 0, 1]
-        >>> from scikits.learn.neighbors import Neighbors
-        >>> neigh = Neighbors(n_neighbors=1)
-        >>> neigh.fit(samples, labels)
-        Neighbors(n_neighbors=1, window_size=1)
-        >>> neigh.predict([.2, .1, .2])
-        array([0])
-        >>> neigh.predict([[0., -1., 0.], [3., 2., 0.]])
-        array([0, 1])
         """
         X = np.atleast_2d(X)
         self._set_params(**params)
@@ -172,39 +153,38 @@
 
 
 ###############################################################################
-# Neighbors Barycenter class for regression problems
+# NeighborsRegressor class for regression problems
 
-class NeighborsBarycenter(Neighbors, RegressorMixin):
+class NeighborsRegressor(NeighborsClassifier, RegressorMixin):
     """Regression based on k-Nearest Neighbor Algorithm.
 
     The target is predicted by local interpolation of the targets
     associated of the k-Nearest Neighbors in the training set.
-    The interpolation weights correspond to barycenter weights.
+
+    Different modes for estimating the result can be set via parameter
+    mode. 'barycenter' will apply the weights that best reconstruct
+    the point from its neighbors while 'mean' will apply constant
+    weights to each point.
 
     Parameters
     ----------
-    X : array-like, shape (n_samples, n_features)
-        The data points to be indexed. This array is not copied, and so
-        modifying this data will result in bogus results.
-
-    y : array-like, shape (n_samples)
-        An array representing labels for the data (only arrays of
-        integers are supported).
-
     n_neighbors : int
         default number of neighbors.
 
     window_size : int
         Window size passed to BallTree
 
+    mode : {'mean', 'barycenter'}
+        Weights to apply to the labels of the neighbors.
+
     Examples
     --------
     >>> X = [[0], [1], [2], [3]]
     >>> y = [0, 0, 1, 1]
-    >>> from scikits.learn.neighbors import NeighborsBarycenter
-    >>> neigh = NeighborsBarycenter(n_neighbors=2)
+    >>> from scikits.learn.neighbors import NeighborsRegressor
+    >>> neigh = NeighborsRegressor(n_neighbors=2)
     >>> neigh.fit(X, y)
-    NeighborsBarycenter(n_neighbors=2, window_size=1)
+    NeighborsRegressor(n_neighbors=2, window_size=1, mode='mean')
     >>> print neigh.predict([[1.5]])
     [ 0.5]
 
@@ -213,6 +193,13 @@
 
     http://en.wikipedia.org/wiki/K-nearest_neighbor_algorithm
     """
+
+    def __init__(self, n_neighbors=5, mode='mean', window_size=1):
+        self.n_neighbors = n_neighbors
+        self.window_size = window_size
+        self.mode = mode
+
+
     def predict(self, X, **params):
         """Predict the target for the provided data.
 
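Before the rewritten predict() below, here is what the two modes
compute once the neighbors' targets are gathered; the expressions
mirror the (W * self._y[neigh_ind]).sum(axis=1) and
np.mean(self._y[neigh_ind], axis=1) lines of the patch, with toy
values chosen to match the doctest above::

    import numpy as np

    # targets of the k=2 nearest neighbors of one query point (1.5,
    # whose neighbors 1 and 2 carry targets 0 and 1)
    y_neigh = np.array([[0., 1.]])

    # 'mean' mode: constant weights
    print np.mean(y_neigh, axis=1)        # [ 0.5]

    # 'barycenter' mode: reconstruction weights, here (0.5, 0.5)
    # since 1.5 = 0.5 * 1 + 0.5 * 2
    W = np.array([[0.5, 0.5]])
    print (W * y_neigh).sum(axis=1)       # [ 0.5]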
@@ -229,32 +216,31 @@
         -------
         y: array
             List of target values (one for each data sample).
-
-        Examples
-        --------
-        >>> X = [[0], [1], [2]]
-        >>> y = [0, 0, 1]
-        >>> from scikits.learn.neighbors import NeighborsBarycenter
-        >>> neigh = NeighborsBarycenter(n_neighbors=2)
-        >>> neigh.fit(X, y)
-        NeighborsBarycenter(n_neighbors=2, window_size=1)
-        >>> neigh.predict([[.5], [1.5]])
-        array([ 0. ,  0.5])
         """
         X = np.atleast_2d(np.asanyarray(X))
         self._set_params(**params)
 
-        # get neighbors of X
+        #
+        # .. compute neighbors ..
+        #
         neigh_ind = self.ball_tree.query(
             X, k=self.n_neighbors, return_distance=False)
         neigh = self.ball_tree.data[neigh_ind]
 
-        # compute barycenters at each point
-        B = barycenter_weights(X, neigh)
-        labels = self._y[neigh_ind]
+        #
+        # .. return labels ..
+        #
+        if self.mode == 'barycenter':
+            W = barycenter_weights(X, neigh)
+            return (W * self._y[neigh_ind]).sum(axis=1)
 
-        return (B * labels).sum(axis=1)
+        elif self.mode == 'mean':
+            return np.mean(self._y[neigh_ind], axis=1)
+
+        else:
+            raise ValueError(
+                'Unsupported mode, must be one of "barycenter" or '
+                '"mean" but got %s instead' % self.mode)
 
 
 ###############################################################################
 # Utils k-NN based Functions
 
@@ -281,6 +267,9 @@ def barycenter_weights(X, Z, cond=None):
     -------
     B : array-like, shape (n_samples, n_neighbors)
 
+    Notes
+    -----
+    See the developers' documentation for more information.
     """
     #
     # .. local variables ..
@@ -308,6 +297,7 @@ def barycenter_weights(X, Z, cond=None):
             C[:, 1:], X[i] - C[:, 0] / np.sqrt(n_neighbors),
             cond=cond, overwrite_a=True, overwrite_b=True)[0].ravel()
         B[i] = rank_update(alpha, v, np.dot(v.T, B[i]), a=B[i])
+
     return B
 
 
@@ -322,7 +312,7 @@ def kneighbors_graph(X, n_neighbors, mode='connectivity'):
     n_neighbors : int
         Number of neighbors for each sample.
 
-    mode : 'connectivity' | 'distance' | 'barycenter'
+    mode : {'connectivity', 'distance', 'barycenter'}
         Type of returned matrix: 'connectivity' will return the
         connectivity matrix with ones and zeros, in 'distance' the
         edges are euclidian distance between points. In 'barycenter'
@@ -331,7 +321,6 @@ def kneighbors_graph(X, n_neighbors, mode='connectivity'):
 
     Returns
     -------
-
     A : CSR sparse matrix, shape = [n_samples, n_samples]
         A[i,j] is assigned the weight of edge that connects i to j.
 
@@ -345,16 +334,20 @@ def kneighbors_graph(X, n_neighbors, mode='connectivity'):
            [ 0.,  1.,  1.],
            [ 1.,  0.,  1.]])
     """
+
+    #
+    # .. local variables ..
+    #
     from scipy import sparse
     X = np.asanyarray(X)
-
     n_samples = X.shape[0]
     ball_tree = BallTree(X)
-
-    # CSR matrix A is represented as A_data, A_ind and A_indptr.
     n_nonzero = n_neighbors * n_samples
     A_indptr = np.arange(0, n_nonzero + 1, n_neighbors)
 
+    #
+    # .. construct CSR matrix ..
+    #
     if mode is 'connectivity':
         A_data = np.ones((n_samples, n_neighbors))
         A_ind = ball_tree.query(
             X, k=n_neighbors, return_distance=False)
@@ -371,7 +364,9 @@ def kneighbors_graph(X, n_neighbors, mode='connectivity'):
         A_data = barycenter_weights(X, X[A_ind])
 
     else:
-        raise ValueError("Unsupported mode type")
+        raise ValueError(
+            'Unsupported mode, must be one of "connectivity", '
+            '"distance" or "barycenter" but got %s instead' % mode)
 
     A = sparse.csr_matrix(
         (A_data.reshape(-1), A_ind.reshape(-1), A_indptr),
diff --git a/scikits/learn/tests/test_neighbors.py b/scikits/learn/tests/test_neighbors.py
index 0ea6f7a7907e5bcfb47fc594932741df21b338b4..34d448f9a0250c0fee10603f50ba2ab35e37f19e 100644
--- a/scikits/learn/tests/test_neighbors.py
+++ b/scikits/learn/tests/test_neighbors.py
@@ -16,20 +16,20 @@ def test_neighbors_1D():
     Y = [0]*(n/2) + [1]*(n/2)
 
     # n_neighbors = 1
-    knn = neighbors.Neighbors(n_neighbors=1)
+    knn = neighbors.NeighborsClassifier(n_neighbors=1)
     knn.fit(X, Y)
     test = [[i + 0.01] for i in range(0, n/2)] + \
            [[i - 0.01] for i in range(n/2, n)]
     assert_array_equal(knn.predict(test), [0]*3 + [1]*3)
 
     # n_neighbors = 2
-    knn = neighbors.Neighbors(n_neighbors=2)
+    knn = neighbors.NeighborsClassifier(n_neighbors=2)
     knn.fit(X, Y)
     assert_array_equal(knn.predict(test), [0]*4 + [1]*2)
 
     # n_neighbors = 3
-    knn = neighbors.Neighbors(n_neighbors=3)
+    knn = neighbors.NeighborsClassifier(n_neighbors=3)
     knn.fit(X, Y)
     assert_array_equal(knn.predict([[i +0.01] for i in range(0, n/2)]),
                        [0 for i in range(n/2)])
 
@@ -49,22 +49,27 @@ def test_neighbors_2D():
                           (-1, 0), (-1, -1), (0, -1))  # label 1
     n_2 = len(X)/2
     Y = [0]*n_2 + [1]*n_2
-    knn = neighbors.Neighbors()
+    knn = neighbors.NeighborsClassifier()
     knn.fit(X, Y)
 
     prediction = knn.predict([[0, .1], [0, -.1], [.1, 0], [-.1, 0]])
     assert_array_equal(prediction, [0, 1, 0, 1])
 
 
-def test_neighbors_barycenter():
+def test_neighbors_regressor():
     """
-    NeighborsBarycenter for regression using k-NN
+    NeighborsRegressor for regression using k-NN
     """
     X = [[0], [1], [2], [3]]
     y = [0, 0, 1, 1]
-    neigh = neighbors.NeighborsBarycenter(n_neighbors=2)
-    neigh.fit(X, y)
-    assert_array_almost_equal(neigh.predict([[1.5]]), [0.5])
+    neigh = neighbors.NeighborsRegressor(n_neighbors=3)
+    neigh.fit(X, y, mode='barycenter')
+    assert_array_almost_equal(
+        neigh.predict([[1.], [1.5]]), [0.333, 0.583], decimal=3)
+    neigh.fit(X, y, mode='mean')
+    assert_array_almost_equal(
+        neigh.predict([[1.], [1.5]]), [0.333, 0.333], decimal=3)
+
 
 def test_kneighbors_graph():
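Taken together, a short usage sketch of the renamed API. The printed
values for the classifier and for the ``mean`` mode follow the doctests
above; the exact ``barycenter`` outputs depend on the regularized
weight solver and are therefore only indicated approximately::

    import numpy as np
    from scikits.learn.neighbors import (NeighborsClassifier,
                                         NeighborsRegressor,
                                         kneighbors_graph)

    clf = NeighborsClassifier(n_neighbors=1)
    clf.fit([[0, 0, 1], [1, 0, 0]], [0, 1])
    print clf.predict([[0, 0, 0]])          # [1]

    X, y = [[0], [1], [2], [3]], [0, 0, 1, 1]
    reg = NeighborsRegressor(n_neighbors=2, mode='mean')
    print reg.fit(X, y).predict([[1.5]])    # [ 0.5]
    reg.fit(X, y, mode='barycenter')
    print reg.predict([[1.5]])              # ~0.5: weights (0.5, 0.5)

    # in 'barycenter' mode each row of the graph holds reconstruction
    # weights, so the rows should sum to one
    A = kneighbors_graph(np.arange(5.)[:, np.newaxis], n_neighbors=2,
                         mode='barycenter')
    print A.sum(axis=1)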