diff --git a/doc/developers/neighbors.rst b/doc/developers/neighbors.rst
index 1938f059318305c3b6fd97d8d49f99de245fdd61..ae98b5b0be323565accaef51f112996c2dbc6284 100644
--- a/doc/developers/neighbors.rst
+++ b/doc/developers/neighbors.rst
@@ -33,7 +33,8 @@ Performance
 -----------
 
 The algorithm has to iterate over n_samples, which is the main
-bottleneck. It would be great to vectorize this loop.
+bottleneck. It would be great to vectorize this loop. In addition,
+the rank updates could probably be moved outside the loop.
 
 Also, least squares solution could be computed more efficiently by a
 QR factorization, since probably we don't care about a minimum norm
@@ -41,9 +42,19 @@ solution for the undertermined case.
 
 The paper 'An introduction to Locally Linear Embeddings', Saul &
 Roweis solves the problem by the normal equation method over the
-covariance matrix. This has the disadvantage that it does not degrade
-grathefully when the covariance is singular, requiring to explicitly
-add regularization.
+covariance matrix. However, it does not degrade gracefully when the
+covariance is singular, which requires adding regularization explicitly.
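+
+For reference, the barycenter weights solve the constrained
+least-squares problem (a sketch of the standard LLE formulation,
+with :math:`x` the query point and the rows of :math:`Z` its
+neighbors):
+
+.. math:: \min_w \|x - Z^T w\|^2 \quad \text{s.t.} \quad \sum_j w_j = 1
+
+The normal-equation method rewrites this as :math:`C w = \mathbf{1}`
+with Gram matrix :math:`C = (Z - x)(Z - x)^T`, normalizing :math:`w`
+to sum to one afterwards.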
 
 
 Stability
diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 50bea1baf5d8669ad4958388b43d498e4534efa9..8955bd140c652d64e34a1438dfad0a8c91250b55 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -105,8 +105,8 @@ Nearest Neighbors
    :toctree: generated/
    :template: class.rst
 
-   neighbors.Neighbors
-   neighbors.NeighborsBarycenter
+   neighbors.NeighborsClassifier
+   neighbors.NeighborsRegressor
    ball_tree.BallTree
 
 .. autosummary::
diff --git a/doc/modules/neighbors.rst b/doc/modules/neighbors.rst
index e069ec15bfc4c33ae03be30e12035bfe24a8b5a1..68bcac1e212ea91cf070f88060b21208c29059c0 100644
--- a/doc/modules/neighbors.rst
+++ b/doc/modules/neighbors.rst
@@ -16,9 +16,19 @@ the decision boundary is very irregular.
 Classification
 ==============
 
-The :class:`Neighbors` estimators implements the nearest-neighbors
-classification method using a vote heuristic: the class most present in
-the k nearest neighbors of a point is assigned to this point.
+The :class:`NeighborsClassifier` implements the nearest-neighbors
+classification method using a vote heuristic: the class most present
+in the k nearest neighbors of a point is assigned to this point.
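+
+A minimal usage sketch (the data mirrors the class docstring)::
+
+    from scikits.learn import neighbors
+
+    samples = [[0, 0, 1], [1, 0, 0]]
+    labels = [0, 1]
+    neigh = neighbors.NeighborsClassifier(n_neighbors=1)
+    neigh.fit(samples, labels)
+    print neigh.predict([[0, 0, 0]])  # -> [1], as in the doctest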
 
 .. figure:: ../auto_examples/images/plot_neighbors.png
    :target: ../auto_examples/plot_neighbors.html
@@ -31,12 +41,27 @@ the k nearest neighbors of a point is assigned to this point.
   * :ref:`example_plot_neighbors.py`: an example of classification
     using nearest neighbor.
 
+
 Regression
 ==========
 
-The :class:`NeighborsBarycenter` estimator implements a nearest-neighbors
-regression method using barycenter weighting of the targets of the
-k-neighbors.
+The :class:`NeighborsRegressor` estimator implements a
+nearest-neighbors regression method by weighting the targets of the
+k-neighbors. Two different weighting strategies are implemented:
+``barycenter`` and ``mean``. ``barycenter`` applies the weights that
+best reconstruct the point from its neighbors, while ``mean`` applies
+constant weights to each point. The sketch and plot below show the
+behavior of both estimators on a simple regression task.
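+
+A minimal sketch of both modes (data taken from the class docstring)::
+
+    from scikits.learn import neighbors
+
+    X = [[0], [1], [2], [3]]
+    y = [0, 0, 1, 1]
+    for mode in ('mean', 'barycenter'):
+        knn = neighbors.NeighborsRegressor(n_neighbors=2, mode=mode)
+        print knn.fit(X, y).predict([[1.5]])  # both modes give 0.5 here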
 
 .. figure:: ../auto_examples/images/plot_neighbors_regression.png
    :target: ../auto_examples/plot_neighbors_regression.html
diff --git a/examples/plot_neighbors.py b/examples/plot_neighbors.py
index 4bdf51f44e778d217c7ac1b0d033b5450f786040..812a23f95a52ee49785b02c8a84ded165d9336ca 100644
--- a/examples/plot_neighbors.py
+++ b/examples/plot_neighbors.py
@@ -22,7 +22,6 @@ h = .02 # step size in the mesh
 
-# we create an instance of SVM and fit out data. We do not scale our
-# data since we want to plot the support vectors
-clf = neighbors.Neighbors()
+# we create an instance of NeighborsClassifier and fit our data
+clf = neighbors.NeighborsClassifier()
 clf.fit(X, Y)
 
-# Plot the decision boundary. For that, we will asign a color to each
+# Plot the decision boundary. For that, we will assign a color to each
diff --git a/examples/plot_neighbors_regression.py b/examples/plot_neighbors_regression.py
index 7d259ea4998ab36ca4e1e21077d43c1332c05521..e9bf38e283673752eb44b95c4d1b7330ef9aa650 100644
--- a/examples/plot_neighbors_regression.py
+++ b/examples/plot_neighbors_regression.py
@@ -5,14 +5,22 @@ k-Nearest Neighbors regression
 
-Demonstrate the resolution of a regression problem
-using a k-Nearest Neighbor and the interpolation of the
-target using barycenter computation.
+Demonstrates how to solve a regression problem with k-Nearest
+Neighbors, interpolating the target with both barycenter and
+constant weights.
 
 """
 print __doc__
 
+# Author: Alexandre Gramfort <alexandre.gramfort@inria.fr>
+#         Fabian Pedregosa <fabian.pedregosa@inria.fr>
+#
+# License: BSD, (C) INRIA
+
+
 ###############################################################################
 # Generate sample data
 import numpy as np
+import pylab as pl
+from scikits.learn import neighbors
 
 np.random.seed(0)
 X = np.sort(5*np.random.rand(40, 1), axis=0)
@@ -25,20 +33,18 @@ y[::5] += 1*(0.5 - np.random.rand(8))
 ###############################################################################
 # Fit regression model
 
-from scikits.learn import neighbors
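+# one regressor per weighting mode, each drawn in its own subplot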
+for i, mode in enumerate(('mean', 'barycenter')):
+    knn = neighbors.NeighborsRegressor(n_neighbors=4, mode=mode)
+    y_ = knn.fit(X, y).predict(T)
 
-knn_barycenter = neighbors.NeighborsBarycenter(n_neighbors=5)
-y_ = knn_barycenter.fit(X, y).predict(T)
+    pl.subplot(2, 1, 1 + i)
+    pl.scatter(X, y, c='k', label='data')
+    pl.plot(T, y_, c='g', label='prediction')
+    pl.axis('tight')
+    pl.legend()
+    pl.title('NeighborsRegressor with %s weights' % mode)
 
-###############################################################################
-# look at the results
-import pylab as pl
-pl.scatter(X, y, c='k', label='data')
-pl.hold('on')
-pl.plot(T, y_, c='g', label='k-NN prediction')
-pl.xlabel('data')
-pl.ylabel('target')
-pl.legend()
-pl.title('k-NN Regression')
+pl.subplots_adjust(0.1, 0.04, 0.95, 0.94, 0.3, 0.28)
 pl.show()
 
diff --git a/scikits/learn/neighbors.py b/scikits/learn/neighbors.py
index f933d21716c3cec6b0b4f3bab23996ad3d8b7fa0..18b5e474295de6aa7230aa81cb7f2d37e29e2512 100644
--- a/scikits/learn/neighbors.py
+++ b/scikits/learn/neighbors.py
@@ -1,6 +1,5 @@
-"""
-Nearest Neighbor related algorithms.
-"""
+"""Nearest Neighbor related algorithms"""
+
 # Author: Fabian Pedregosa <fabian.pedregosa@inria.fr>
 #         Alexandre Gramfort <alexandre.gramfort@inria.fr>
 #
@@ -12,19 +11,14 @@ from .base import BaseEstimator, ClassifierMixin, RegressorMixin
 from .ball_tree import BallTree
 
 
-class Neighbors(BaseEstimator, ClassifierMixin):
+class NeighborsClassifier(BaseEstimator, ClassifierMixin):
     """Classifier implementing k-Nearest Neighbor Algorithm.
 
     Parameters
     ----------
-    data : array-like, shape (n, k)
-        The data points to be indexed. This array is not copied, and so
-        modifying this data will result in bogus results.
-    labels : array
-        An array representing labels for the data (only arrays of
-        integers are supported).
     n_neighbors : int
         default number of neighbors.
+
     window_size : int
         Window size passed to BallTree
 
@@ -32,10 +26,10 @@ class Neighbors(BaseEstimator, ClassifierMixin):
     --------
     >>> samples = [[0, 0, 1], [1, 0, 0]]
     >>> labels = [0, 1]
-    >>> from scikits.learn.neighbors import Neighbors
-    >>> neigh = Neighbors(n_neighbors=1)
+    >>> from scikits.learn.neighbors import NeighborsClassifier
+    >>> neigh = NeighborsClassifier(n_neighbors=1)
     >>> neigh.fit(samples, labels)
-    Neighbors(n_neighbors=1, window_size=1)
+    NeighborsClassifier(n_neighbors=1, window_size=1)
     >>> print neigh.predict([[0,0,0]])
     [1]
 
@@ -102,16 +96,16 @@ class Neighbors(BaseEstimator, ClassifierMixin):
 
         Examples
         --------
-        In the following example, we construnct a Neighbors class from an
-        array representing our data set and ask who's the closest point to
-        [1,1,1]
+        In the following example, we construct a NeighborsClassifier
+        from an array representing our data set and ask which is the
+        closest point to [1,1,1]
 
         >>> samples = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]
         >>> labels = [0, 0, 1]
-        >>> from scikits.learn.neighbors import Neighbors
-        >>> neigh = Neighbors(n_neighbors=1)
+        >>> from scikits.learn.neighbors import NeighborsClassifier
+        >>> neigh = NeighborsClassifier(n_neighbors=1)
         >>> neigh.fit(samples, labels)
-        Neighbors(n_neighbors=1, window_size=1)
+        NeighborsClassifier(n_neighbors=1, window_size=1)
         >>> print neigh.kneighbors([1., 1., 1.])
         (array([ 0.5]), array([2]))
 
@@ -145,19 +139,6 @@ class Neighbors(BaseEstimator, ClassifierMixin):
         -------
         labels: array
             List of class labels (one for each data sample).
-
-        Examples
-        --------
-        >>> samples = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]
-        >>> labels = [0, 0, 1]
-        >>> from scikits.learn.neighbors import Neighbors
-        >>> neigh = Neighbors(n_neighbors=1)
-        >>> neigh.fit(samples, labels)
-        Neighbors(n_neighbors=1, window_size=1)
-        >>> neigh.predict([.2, .1, .2])
-        array([0])
-        >>> neigh.predict([[0., -1., 0.], [3., 2., 0.]])
-        array([0, 1])
         """
         X = np.atleast_2d(X)
         self._set_params(**params)
@@ -172,39 +153,38 @@ class Neighbors(BaseEstimator, ClassifierMixin):
 
 
 ###############################################################################
-# Neighbors Barycenter class for regression problems
+# NeighborsRegressor class for regression problems
 
-class NeighborsBarycenter(Neighbors, RegressorMixin):
+class NeighborsRegressor(NeighborsClassifier, RegressorMixin):
     """Regression based on k-Nearest Neighbor Algorithm.
 
     The target is predicted by local interpolation of the targets
     associated of the k-Nearest Neighbors in the training set.
-    The interpolation weights correspond to barycenter weights.
+
+    Different modes for estimating the result can be set via the
+    parameter `mode`: 'barycenter' applies the weights that best
+    reconstruct the point from its neighbors, while 'mean' applies
+    constant weights to each point.
 
     Parameters
     ----------
-    X : array-like, shape (n_samples, n_features)
-        The data points to be indexed. This array is not copied, and so
-        modifying this data will result in bogus results.
-
-    y : array-like, shape (n_samples)
-        An array representing labels for the data (only arrays of
-        integers are supported).
-
     n_neighbors : int
         default number of neighbors.
 
     window_size : int
         Window size passed to BallTree
 
+    mode : {'mean', 'barycenter'}
+        Weighting strategy applied to the targets of the k-neighbors.
+
     Examples
     --------
     >>> X = [[0], [1], [2], [3]]
     >>> y = [0, 0, 1, 1]
-    >>> from scikits.learn.neighbors import NeighborsBarycenter
-    >>> neigh = NeighborsBarycenter(n_neighbors=2)
+    >>> from scikits.learn.neighbors import NeighborsRegressor
+    >>> neigh = NeighborsRegressor(n_neighbors=2)
     >>> neigh.fit(X, y)
-    NeighborsBarycenter(n_neighbors=2, window_size=1)
+    NeighborsRegressor(n_neighbors=2, window_size=1, mode='mean')
     >>> print neigh.predict([[1.5]])
     [ 0.5]
 
@@ -213,6 +193,13 @@ class NeighborsBarycenter(Neighbors, RegressorMixin):
     http://en.wikipedia.org/wiki/K-nearest_neighbor_algorithm
     """
 
+
+    def __init__(self, n_neighbors=5, mode='mean', window_size=1):
+        self.n_neighbors = n_neighbors
+        self.window_size = window_size
+        self.mode = mode
+
+
     def predict(self, X, **params):
         """Predict the target for the provided data.
 
@@ -229,32 +216,31 @@ class NeighborsBarycenter(Neighbors, RegressorMixin):
         -------
         y: array
             List of target values (one for each data sample).
-
-        Examples
-        --------
-        >>> X = [[0], [1], [2]]
-        >>> y = [0, 0, 1]
-        >>> from scikits.learn.neighbors import NeighborsBarycenter
-        >>> neigh = NeighborsBarycenter(n_neighbors=2)
-        >>> neigh.fit(X, y)
-        NeighborsBarycenter(n_neighbors=2, window_size=1)
-        >>> neigh.predict([[.5], [1.5]])
-        array([ 0. ,  0.5])
         """
         X = np.atleast_2d(np.asanyarray(X))
         self._set_params(**params)
 
-        # get neighbors of X
+#
+#       .. compute neighbors ..
+#
         neigh_ind = self.ball_tree.query(
             X, k=self.n_neighbors, return_distance=False)
         neigh = self.ball_tree.data[neigh_ind]
 
-        # compute barycenters at each point
-        B = barycenter_weights(X, neigh)
-        labels = self._y[neigh_ind]
+#
+#       .. return labels ..
+#
+        if self.mode == 'barycenter':
+            W = barycenter_weights(X, neigh)
+            return (W * self._y[neigh_ind]).sum(axis=1)
 
-        return (B * labels).sum(axis=1)
+        elif self.mode == 'mean':
+            return np.mean(self._y[neigh_ind], axis=1)
 
+        else:
+            raise ValueError(
+                'Unsupported mode, must be one of "barycenter" or '
+                '"mean" but got %s instead' % self.mode)
 
 ###############################################################################
 # Utils k-NN based Functions
@@ -281,6 +267,9 @@ def barycenter_weights(X, Z, cond=None):
     -------
     B : array-like, shape (n_samples, n_neighbors)
 
+    Notes
+    -----
+    See the developer notes (doc/developers/neighbors.rst) for details.
     """
 #
 #       .. local variables ..
@@ -308,6 +297,7 @@ def barycenter_weights(X, Z, cond=None):
             C[:, 1:], X[i] - C[:, 0] / np.sqrt(n_neighbors), cond=cond,
             overwrite_a=True, overwrite_b=True)[0].ravel()
         B[i] = rank_update(alpha, v, np.dot(v.T, B[i]), a=B[i])
+
     return B
 
 
@@ -322,7 +312,7 @@ def kneighbors_graph(X, n_neighbors, mode='connectivity'):
     n_neighbors : int
         Number of neighbors for each sample.
 
-    mode : 'connectivity' | 'distance' | 'barycenter'
+    mode : {'connectivity', 'distance', 'barycenter'}
         Type of returned matrix: 'connectivity' will return the
         connectivity matrix with ones and zeros, in 'distance' the
         edges are euclidian distance between points. In 'barycenter'
@@ -331,7 +321,6 @@ def kneighbors_graph(X, n_neighbors, mode='connectivity'):
 
     Returns
     -------
-
     A : CSR sparse matrix, shape = [n_samples, n_samples]
         A[i,j] is assigned the weight of edge that connects i to j.
 
@@ -345,16 +334,20 @@ def kneighbors_graph(X, n_neighbors, mode='connectivity'):
             [ 0.,  1.,  1.],
             [ 1.,  0.,  1.]])
     """
+
+#
+#       .. local variables ..
+#
     from scipy import sparse
     X = np.asanyarray(X)
-
     n_samples = X.shape[0]
     ball_tree = BallTree(X)
-
-    # CSR matrix A is represented as A_data, A_ind and A_indptr.
     n_nonzero = n_neighbors * n_samples
     A_indptr = np.arange(0, n_nonzero + 1, n_neighbors)
 
+#
+#       .. construct CSR matrix ..
+#
     if mode is 'connectivity':
         A_data = np.ones((n_samples, n_neighbors))
         A_ind = ball_tree.query(
@@ -371,7 +364,9 @@ def kneighbors_graph(X, n_neighbors, mode='connectivity'):
         A_data = barycenter_weights(X, X[A_ind])
 
     else:
-        raise ValueError("Unsupported mode type")
+        raise ValueError(
+            'Unsupported mode, must be one of "connectivity", '
+            '"distance" or "barycenter" but got %s instead' % mode)
 
     A = sparse.csr_matrix(
         (A_data.reshape(-1), A_ind.reshape(-1), A_indptr),
diff --git a/scikits/learn/tests/test_neighbors.py b/scikits/learn/tests/test_neighbors.py
index 0ea6f7a7907e5bcfb47fc594932741df21b338b4..34d448f9a0250c0fee10603f50ba2ab35e37f19e 100644
--- a/scikits/learn/tests/test_neighbors.py
+++ b/scikits/learn/tests/test_neighbors.py
@@ -16,20 +16,20 @@ def test_neighbors_1D():
     Y = [0]*(n/2) + [1]*(n/2)
 
     # n_neighbors = 1
-    knn = neighbors.Neighbors(n_neighbors=1)
+    knn = neighbors.NeighborsClassifier(n_neighbors=1)
     knn.fit(X, Y)
     test = [[i + 0.01] for i in range(0, n/2)] + \
            [[i - 0.01] for i in range(n/2, n)]
     assert_array_equal(knn.predict(test), [0]*3 + [1]*3)
 
     # n_neighbors = 2
-    knn = neighbors.Neighbors(n_neighbors=2)
+    knn = neighbors.NeighborsClassifier(n_neighbors=2)
     knn.fit(X, Y)
     assert_array_equal(knn.predict(test), [0]*4 + [1]*2)
 
 
     # n_neighbors = 3
-    knn = neighbors.Neighbors(n_neighbors=3)
+    knn = neighbors.NeighborsClassifier(n_neighbors=3)
     knn.fit(X, Y)
     assert_array_equal(knn.predict([[i +0.01] for i in range(0, n/2)]),
                         [0 for i in range(n/2)])
@@ -49,22 +49,30 @@ def test_neighbors_2D():
         (-1, 0), (-1, -1), (0, -1)) # label 1
     n_2 = len(X)/2
     Y = [0]*n_2 + [1]*n_2
-    knn = neighbors.Neighbors()
+    knn = neighbors.NeighborsClassifier()
     knn.fit(X, Y)
 
     prediction = knn.predict([[0, .1], [0, -.1], [.1, 0], [-.1, 0]])
     assert_array_equal(prediction, [0, 1, 0, 1])
 
 
-def test_neighbors_barycenter():
+def test_neighbors_regressor():
     """
-    NeighborsBarycenter for regression using k-NN
+    NeighborsRegressor for regression using k-NN
     """
     X = [[0], [1], [2], [3]]
     y = [0, 0, 1, 1]
-    neigh = neighbors.NeighborsBarycenter(n_neighbors=2)
-    neigh.fit(X, y)
-    assert_array_almost_equal(neigh.predict([[1.5]]), [0.5])
+    neigh = neighbors.NeighborsRegressor(n_neighbors=3)
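+    # with n_neighbors=3, the neighbors of x=1. and x=1.5 are {0, 1, 2}
+    # (targets 0, 0, 1): 'mean' predicts 1/3 for both queries, while
+    # 'barycenter' interpolates (7/12 = 0.583 at x=1.5)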
+    neigh.fit(X, y, mode='barycenter')
+    assert_array_almost_equal(
+        neigh.predict([[1.], [1.5]]), [0.333, 0.583], decimal=3)
+    neigh.fit(X, y, mode='mean')
+    assert_array_almost_equal(
+        neigh.predict([[1.], [1.5]]), [0.333, 0.333], decimal=3)
+
 
 
 def test_kneighbors_graph():