diff --git a/README.rst b/README.rst
index a1349738143c1dd2675fd2a753d0fa76b4e5edc0..1815a48d50f61cb14d65ac75421ba9831f72036a 100644
--- a/README.rst
+++ b/README.rst
@@ -79,8 +79,8 @@ Bugs
 ----
 
 Please submit bugs you might encounter, as well as patches and feature
-requests to the tracker located at the address
-https://sourceforge.net/apps/trac/scikit-learn/report
+requests to the tracker located on GitHub:
+https://github.com/scikit-learn/scikit-learn/issues
 
 
 Testing
diff --git a/doc/modules/neighbors.rst b/doc/modules/neighbors.rst
index 68bcac1e212ea91cf070f88060b21208c29059c0..dc9df8c36ac8f6db2cb7ba7f1f39bdcc6f294438 100644
--- a/doc/modules/neighbors.rst
+++ b/doc/modules/neighbors.rst
@@ -20,6 +20,30 @@ The :class:`NeighborsClassifier` implements the nearest-neighbors
 classification method using a vote heuristic: the class most present
 in the k nearest neighbors of a point is assigned to this point.
 
+It is possible to use different nearest neighbor search algorithms by
+using the keyword ``algorithm``. Possible values are ``'auto'``,
+``'ball_tree'``, ``'brute'`` and ``'brute_inplace'``. ``'ball_tree'``
+will create an instance of :class:`BallTree` to conduct the search,
+which is usually very efficient in low-dimensional spaces. In higher
dimensions, a brute-force approach is preferred, so the values
``'brute'`` and ``'brute_inplace'`` can be used. Both conduct a
brute-force search, the difference being that ``'brute_inplace'`` does
not perform any precomputations, and thus is better suited for
low-memory settings. Finally, ``'auto'`` is a simple heuristic that
+will guess the best approach based on the current dataset.
+
+
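+As a minimal illustration (the sample points below are arbitrary toy
+data), switching search strategies only requires changing the keyword::
+
+    >>> from scikits.learn.neighbors import NeighborsClassifier
+    >>> samples = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]
+    >>> neigh = NeighborsClassifier(n_neighbors=1, algorithm='ball_tree')
+    >>> neigh.fit(samples, [0, 0, 1])
+    NeighborsClassifier(n_neighbors=1, window_size=1, algorithm='ball_tree')
+    >>> print neigh.predict([[1., 1., 1.]])
+    [1]
+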
 .. figure:: ../auto_examples/images/plot_neighbors.png
    :target: ../auto_examples/plot_neighbors.html
    :align: center
diff --git a/scikits/learn/cluster/k_means_.py b/scikits/learn/cluster/k_means_.py
index d64f85659871a4de2e46718069ef16167fc38896..fc9baec2cf9b706b28e8b53ecbd7cea48c63b510 100644
--- a/scikits/learn/cluster/k_means_.py
+++ b/scikits/learn/cluster/k_means_.py
@@ -184,9 +184,15 @@ def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
         if verbose:
             print 'Initialization complete'
         # iterations
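+        # precompute the row-wise squared norms of X; they are constant
+        # across iterations and are reused in each _e_step call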
+        x_squared_norms = X.copy()
+        x_squared_norms **= 2
+        x_squared_norms = x_squared_norms.sum(axis=1)
         for i in range(max_iter):
             centers_old = centers.copy()
-            labels, inertia = _e_step(X, centers)
+            labels, inertia = _e_step(X, centers,
+                                        x_squared_norms=x_squared_norms)
             centers = _m_step(X, labels, k)
             if verbose:
                 print 'Iteration %i, inertia %s' % (i, inertia)
@@ -228,12 +232,18 @@ def _m_step(x, z, k):
         The resulting centers
     """
     dim = x.shape[1]
-    centers = np.repeat(np.reshape(x.mean(0), (1, dim)), k, 0)
+    centers = np.empty((k, dim))
+    X_center = None
     for q in range(k):
-        if np.sum(z == q) == 0:
-            pass
+        this_center_mask = (z == q)
+        if not np.any(this_center_mask):
+            # The centroid of empty clusters is set to the center of
+            # everything
+            if X_center is None:
+                X_center = x.mean(axis=0)
+            centers[q] = X_center
         else:
-            centers[q] = np.mean(x[z == q], axis=0)
+            centers[q] = np.mean(x[this_center_mask], axis=0)
     return centers
 
 
@@ -265,8 +275,12 @@ def _e_step(x, centers, precompute_distances=True, x_squared_norms=None):
     if precompute_distances:
         distances = euclidean_distances(centers, x, x_squared_norms,
                                         squared=True)
-    z = -np.ones(n_samples).astype(np.int)
-    mindist = np.infty * np.ones(n_samples)
+    z = np.empty(n_samples, dtype=np.int)
+    z.fill(-1)
+    mindist = np.empty(n_samples)
+    mindist.fill(np.infty)
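+    # z[i] will hold the index of the closest center for sample i,
+    # mindist[i] the corresponding squared distance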
     for q in range(k):
         if precompute_distances:
             dist = distances[q]
diff --git a/scikits/learn/linear_model/setup.py b/scikits/learn/linear_model/setup.py
index 707f673a567ac65d838558418db9cb7a641ee4d2..a8f7a079f628ceae4fb0a52ae7bac21dc6848a49 100644
--- a/scikits/learn/linear_model/setup.py
+++ b/scikits/learn/linear_model/setup.py
@@ -1,11 +1,9 @@
 from os.path import join
-import warnings
 import numpy
-import sys
 
 def configuration(parent_package='', top_path=None):
     from numpy.distutils.misc_util import Configuration
-    from numpy.distutils.system_info import get_info, get_standard_file, BlasNotFoundError
+    from numpy.distutils.system_info import get_info
     config = Configuration('linear_model', parent_package, top_path)
 
     # cd fast needs CBLAS
diff --git a/scikits/learn/metrics/__init__.py b/scikits/learn/metrics/__init__.py
index 8bc4c44b47da540351fc3cc0b07badf99ac3b983..4b761ca18ae898dddbb4380dfd47af695a05e8c5 100644
--- a/scikits/learn/metrics/__init__.py
+++ b/scikits/learn/metrics/__init__.py
@@ -8,3 +8,5 @@ from .metrics import confusion_matrix, roc_curve, auc, precision_score, \
                 precision_recall_fscore_support, classification_report, \
                 precision_recall_curve, explained_variance_score, r2_score, \
                 zero_one, mean_square_error
+
+from .pairwise import euclidean_distances
diff --git a/scikits/learn/metrics/pairwise.py b/scikits/learn/metrics/pairwise.py
index 5f138b55f406ed40c8f2109fc69b43ccf3396dd2..4b60e2c3d61677a3a0c22acc4bea509a6ae2b443 100644
--- a/scikits/learn/metrics/pairwise.py
+++ b/scikits/learn/metrics/pairwise.py
@@ -8,10 +8,7 @@ sets of points.
 
 import numpy as np
 
-
-def euclidean_distances(X, Y,
-        Y_norm_squared=None,
-        squared=False):
+def euclidean_distances(X, Y, Y_norm_squared=None, squared=False):
     """
     Considering the rows of X (and Y=X) as vectors, compute the
     distance matrix between each pair of vectors.
@@ -61,7 +58,9 @@ def euclidean_distances(X, Y,
     if X is Y: # shortcut in the common case euclidean_distances(X, X)
         YY = XX.T
     elif Y_norm_squared is None:
-        YY = np.sum(Y * Y, axis=1)[np.newaxis, :]
+        YY = Y.copy()
+        YY **= 2
+        YY = np.sum(YY, axis=1)[np.newaxis, :]  # row-wise squared norms of Y
     else:
         YY = np.asanyarray(Y_norm_squared)
         if YY.shape != (Y.shape[0],):
diff --git a/scikits/learn/neighbors.py b/scikits/learn/neighbors.py
index 18b5e474295de6aa7230aa81cb7f2d37e29e2512..07fca967db8ca255728b0c7d7da03074cdc3fa08 100644
--- a/scikits/learn/neighbors.py
+++ b/scikits/learn/neighbors.py
@@ -8,7 +8,7 @@
 import numpy as np
 
 from .base import BaseEstimator, ClassifierMixin, RegressorMixin
-from .ball_tree import BallTree
+from .ball_tree import BallTree, knn_brute
 
 
 class NeighborsClassifier(BaseEstimator, ClassifierMixin):
@@ -16,12 +16,18 @@ class NeighborsClassifier(BaseEstimator, ClassifierMixin):
 
     Parameters
     ----------
-    n_neighbors : int
-        default number of neighbors.
+    n_neighbors : int, optional
+        Default number of neighbors. Defaults to 5.
 
-    window_size : int
+    window_size : int, optional
         Window size passed to BallTree
 
+    algorithm : {'auto', 'ball_tree', 'brute', 'brute_inplace'}, optional
+        Algorithm used to compute the nearest neighbors. 'ball_tree'
+        will construct a BallTree, while 'brute' and 'brute_inplace'
+        will perform brute-force search. 'auto' will guess the most
+        appropriate algorithm based on the current dataset.
+
     Examples
     --------
     >>> samples = [[0, 0, 1], [1, 0, 0]]
@@ -29,24 +35,25 @@ class NeighborsClassifier(BaseEstimator, ClassifierMixin):
     >>> from scikits.learn.neighbors import NeighborsClassifier
     >>> neigh = NeighborsClassifier(n_neighbors=1)
     >>> neigh.fit(samples, labels)
-    NeighborsClassifier(n_neighbors=1, window_size=1)
+    NeighborsClassifier(n_neighbors=1, window_size=1, algorithm='auto')
     >>> print neigh.predict([[0,0,0]])
     [1]
 
-    Notes
-    -----
-    Internally uses the ball tree datastructure and algorithm for fast
-    neighbors lookups on high dimensional datasets.
+    See also
+    --------
+    BallTree
 
     References
     ----------
     http://en.wikipedia.org/wiki/K-nearest_neighbor_algorithm
     """
 
-    def __init__(self, n_neighbors=5, window_size=1):
+    def __init__(self, n_neighbors=5, algorithm='auto', window_size=1):
         self.n_neighbors = n_neighbors
         self.window_size = window_size
+        self.algorithm = algorithm
 
+
     def fit(self, X, Y, **params):
         """
         Fit the model using X, y as training data.
@@ -62,12 +69,19 @@ class NeighborsClassifier(BaseEstimator, ClassifierMixin):
         params : list of keyword, optional
             Overwrite keywords from __init__
         """
+        X = np.asanyarray(X)
         self._y = np.asanyarray(Y)
         self._set_params(**params)
 
-        self.ball_tree = BallTree(X, self.window_size)
+        if self.algorithm == 'ball_tree' or \
+           (self.algorithm == 'auto' and X.shape[1] < 20):
+            self.ball_tree = BallTree(X, self.window_size)
+        else:
+            self.ball_tree = None
+            self._fit_X = X
         return self
 
+
     def kneighbors(self, data, return_distance=True, **params):
         """Finds the K-neighbors of a point.
 
@@ -105,7 +119,7 @@ class NeighborsClassifier(BaseEstimator, ClassifierMixin):
         >>> from scikits.learn.neighbors import NeighborsClassifier
         >>> neigh = NeighborsClassifier(n_neighbors=1)
         >>> neigh.fit(samples, labels)
-        NeighborsClassifier(n_neighbors=1, window_size=1)
+        NeighborsClassifier(n_neighbors=1, window_size=1, algorithm='auto')
         >>> print neigh.kneighbors([1., 1., 1.])
         (array([ 0.5]), array([2]))
 
@@ -123,6 +137,7 @@ class NeighborsClassifier(BaseEstimator, ClassifierMixin):
         return self.ball_tree.query(
             data, k=self.n_neighbors, return_distance=return_distance)
 
+
     def predict(self, X, **params):
         """Predict the class labels for the provided data.
 
@@ -143,10 +158,23 @@ class NeighborsClassifier(BaseEstimator, ClassifierMixin):
         X = np.atleast_2d(X)
         self._set_params(**params)
 
-        ind = self.ball_tree.query(
-            X, self.n_neighbors, return_distance=False)
-        pred_labels = self._y[ind]
+        # .. get neighbors ..
+        if self.ball_tree is None:
+            if self.algorithm == 'brute_inplace':
+                neigh_ind = knn_brute(self._fit_X, X, self.n_neighbors)
+            else:
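+                # plain brute force: compute all pairwise squared
+                # distances and keep the n_neighbors smallest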
+                from .metrics import euclidean_distances
+                dist = euclidean_distances(
+                    X, self._fit_X, squared=True)
+                neigh_ind = dist.argsort(axis=1)[:, :self.n_neighbors]
+        else:
+            neigh_ind = self.ball_tree.query(
+                X, self.n_neighbors, return_distance=False)
 
+        # .. most popular label ..
+        pred_labels = self._y[neigh_ind]
         from scipy import stats
         mode, _ = stats.mode(pred_labels, axis=1)
         return mode.flatten().astype(np.int)
@@ -168,15 +194,21 @@ class NeighborsRegressor(NeighborsClassifier, RegressorMixin):
 
     Parameters
     ----------
-    n_neighbors : int
-        default number of neighbors.
+    n_neighbors : int, optional
+        Default number of neighbors. Defaults to 5.
 
-    window_size : int
+    window_size : int, optional
         Window size passed to BallTree
 
-    mode : {'mean', 'barycenter'}
+    mode : {'mean', 'barycenter'}, optional
         Weights to apply to labels.
 
+    algorithm : {'auto', 'ball_tree', 'brute', 'brute_inplace'}, optional
+        Algorithm used to compute the nearest neighbors. 'ball_tree'
+        will construct a BallTree, while 'brute' and 'brute_inplace'
+        will perform brute-force search. 'auto' will guess the most
+        appropriate algorithm based on the current dataset.
+
     Examples
     --------
     >>> X = [[0], [1], [2], [3]]
@@ -184,7 +216,8 @@ class NeighborsRegressor(NeighborsClassifier, RegressorMixin):
     >>> from scikits.learn.neighbors import NeighborsRegressor
     >>> neigh = NeighborsRegressor(n_neighbors=2)
     >>> neigh.fit(X, y)
-    NeighborsRegressor(n_neighbors=2, window_size=1, mode='mean')
+    NeighborsRegressor(n_neighbors=2, window_size=1, mode='mean',
+              algorithm='auto')
     >>> print neigh.predict([[1.5]])
     [ 0.5]
 
@@ -194,10 +227,12 @@ class NeighborsRegressor(NeighborsClassifier, RegressorMixin):
     """
 
 
-    def __init__(self, n_neighbors=5, mode='mean', window_size=1):
+    def __init__(self, n_neighbors=5, mode='mean', algorithm='auto',
+                 window_size=1):
         self.n_neighbors = n_neighbors
         self.window_size = window_size
         self.mode = mode
+        self.algorithm = algorithm
 
 
     def predict(self, X, **params):
@@ -220,16 +255,24 @@ class NeighborsRegressor(NeighborsClassifier, RegressorMixin):
         X = np.atleast_2d(np.asanyarray(X))
         self._set_params(**params)
 
-#
-#       .. compute neighbors ..
-#
-        neigh_ind = self.ball_tree.query(
-            X, k=self.n_neighbors, return_distance=False)
-        neigh = self.ball_tree.data[neigh_ind]
-
-#
-#       .. return labels ..
-#
+        # .. get neighbors ..
+        if self.ball_tree is None:
+            if self.algorithm == 'brute_inplace':
+                neigh_ind = knn_brute(self._fit_X, X, self.n_neighbors)
+            else:
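+                # plain brute force: rank the training points by
+                # distance and keep the n_neighbors closest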
+                from .metrics.pairwise import euclidean_distances
+                dist = euclidean_distances(
+                    X, self._fit_X, squared=False)
+                neigh_ind = dist.argsort(axis=1)[:, :self.n_neighbors]
+            neigh = self._fit_X[neigh_ind]
+        else:
+            neigh_ind = self.ball_tree.query(
+                X, self.n_neighbors, return_distance=False)
+            neigh = self.ball_tree.data[neigh_ind]
+
+        # .. return labels ..
         if self.mode == 'barycenter':
             W = barycenter_weights(X, neigh)
             return (W * self._y[neigh_ind]).sum(axis=1)
diff --git a/scikits/learn/setup.py b/scikits/learn/setup.py
index b2e940a486a73e66103b892ca5c98c39f2016ef2..2a903f27a20bc5c9b9b1ee5e0062a120e9c24f79 100644
--- a/scikits/learn/setup.py
+++ b/scikits/learn/setup.py
@@ -1,7 +1,6 @@
 from os.path import join
 import warnings
 import numpy
-import sys
 
 
 def configuration(parent_package='', top_path=None):
@@ -36,12 +35,7 @@ def configuration(parent_package='', top_path=None):
         ('NO_ATLAS_INFO', 1) in blas_info.get('define_macros', [])):
         config.add_library('cblas',
                            sources=[join('src', 'cblas', '*.c')])
-        cblas_libs = ['cblas']
-        blas_info.pop('libraries', None)
         warnings.warn(BlasNotFoundError.__doc__)
-    else:
-        cblas_libs = blas_info.pop('libraries', [])
-
 
     config.add_extension('ball_tree',
                          sources=[join('src', 'BallTree.cpp')],
diff --git a/scikits/learn/src/BallTree.cpp b/scikits/learn/src/BallTree.cpp
index c1becd13845f49c957a469181eb691209b7fbf26..f98cdfa6684d4ab9e89b56fd412ac9326b7bb8e4 100644
--- a/scikits/learn/src/BallTree.cpp
+++ b/scikits/learn/src/BallTree.cpp
@@ -712,22 +712,6 @@ BallTree_knn_brute(PyObject *self, PyObject *args, PyObject *kwds){
     for(int i=0;i<N;i++)
         delete Points[i];
 
-  //if only one neighbor is requested, then resize the neighbors array
-    if(k==1){
-        PyArray_Dims dims;
-        dims.ptr = PyArray_DIMS(arr2);
-        dims.len = PyArray_NDIM(arr2)-1;
-
-    //PyArray_Resize returns None - this needs to be picked
-    // up and dereferenced.
-        PyObject *NoneObj = PyArray_Resize( (PyArrayObject*)nbrs, &dims,
-            0, NPY_ANYORDER );
-        if (NoneObj == NULL){
-            goto fail;
-        }
-        Py_DECREF(NoneObj);
-    }
-
     return nbrs;
 
     fail:
diff --git a/scikits/learn/svm/base.py b/scikits/learn/svm/base.py
index 1f21c3d670f31070dc3ad6ac344728d0f5e99730..3be7947c743688e23d5b103bcfbc0e0d79ef3a8a 100644
--- a/scikits/learn/svm/base.py
+++ b/scikits/learn/svm/base.py
@@ -425,17 +425,14 @@ class BaseLibLinear(BaseEstimator):
         X = np.atleast_2d(np.asanyarray(X, dtype=np.float64, order='C'))
         self._check_n_features(X)
 
-        coef = self.raw_coef_
-
-        dec_func = _liblinear.decision_function_wrap(X, coef,
-                                      self._get_solver_type(),
-                                      self.eps, self.C,
-                                      self.class_weight_label,
-                                      self.class_weight, self.label_,
-                                      self._get_bias())
+        dec_func = _liblinear.decision_function_wrap(
+            X, self.raw_coef_, self._get_solver_type(), self.eps,
+            self.C, self.class_weight_label, self.class_weight,
+            self.label_, self._get_bias())
 
         if len(self.label_) <= 2:
-            # one class
+            # in the two-class case, the decision sign needs to be
+            # flipped due to liblinear's design
             return -dec_func
         else:
             return dec_func
@@ -451,14 +448,24 @@
     @property
     def intercept_(self):
         if self.fit_intercept:
-            return self.intercept_scaling * self.raw_coef_[:, -1]
+            ret = self.intercept_scaling * self.raw_coef_[:, -1]
+            if len(self.label_) <= 2:
+                ret *= -1
+            return ret
         return 0.0
 
     @property
     def coef_(self):
         if self.fit_intercept:
-            return self.raw_coef_[:, : -1]
-        return self.raw_coef_
+            ret = self.raw_coef_[:, : -1]
+        else:
+            ret = self.raw_coef_
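+        # liblinear stores the two-class coefficients with an inverted
+        # sign; flip them so coef_ is consistent with decision_function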
+        if len(self.label_) <= 2:
+            return -ret
+        else:
+            return ret
 
     def predict_proba(self, T):
         # only available for logistic regression
diff --git a/scikits/learn/svm/sparse/base.py b/scikits/learn/svm/sparse/base.py
index 36e85bebac68e3eaf5e12f7982c8392c40c8f41c..62b3d3ed18a4f8ee137de5c94f4d9f6545ecb3cb 100644
--- a/scikits/learn/svm/sparse/base.py
+++ b/scikits/learn/svm/sparse/base.py
@@ -265,7 +265,8 @@ class SparseBaseLibLinear(BaseLibLinear):
             self._get_bias())
 
         if len(self.label_) <= 2:
-            # one class
+            # in the two-class case, the decision sign needs to be
+            # flipped due to liblinear's design
             return -dec_func
         else:
             return dec_func
diff --git a/scikits/learn/svm/tests/test_svm.py b/scikits/learn/svm/tests/test_svm.py
index 46e0e3be0c13502bf1f3c22618c28f8330ed9000..18bbab4818d611f0e8f8c4cc38ecc8f85ec5ac78 100644
--- a/scikits/learn/svm/tests/test_svm.py
+++ b/scikits/learn/svm/tests/test_svm.py
@@ -417,7 +417,7 @@ def test_dense_liblinear_intercept_handling(classifier=svm.LinearSVC):
     clf.intercept_scaling = 100
     clf.fit(X, y)
     intercept1 = clf.intercept_
-    assert intercept1 > 1
+    assert intercept1 < -1
 
     # when intercept_scaling is sufficiently high, the intercept value
     # doesn't depend on intercept_scaling value
@@ -435,14 +435,26 @@ def test_liblinear_predict():
     returns the same as the one in libliblinear
 
     """
+    # multi-class case
     clf = svm.LinearSVC().fit(iris.data, iris.target)
-
     weights = clf.coef_.T
     bias = clf.intercept_
     H = np.dot(iris.data, weights) + bias
-
     assert_array_equal(clf.predict(iris.data), H.argmax(axis=1))
 
+    # binary-class case
+    X = [[2, 1],
+         [3, 1],
+         [1, 3],
+         [2, 3]]
+    y = [0, 0, 1, 1]
+
+    clf = svm.LinearSVC().fit(X, y)
+    weights = np.ravel(clf.coef_)
+    bias = clf.intercept_
+    H = np.dot(X, weights) + bias
+    assert_array_equal(clf.predict(X), (H > 0).astype(int))
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/scikits/learn/tests/test_neighbors.py b/scikits/learn/tests/test_neighbors.py
index 34d448f9a0250c0fee10603f50ba2ab35e37f19e..3921fadf6b74c090b2f8610915eb18784a8f0815 100644
--- a/scikits/learn/tests/test_neighbors.py
+++ b/scikits/learn/tests/test_neighbors.py
@@ -1,7 +1,14 @@
 import numpy as np
-from numpy.testing import assert_array_almost_equal, assert_array_equal
+from numpy.testing import assert_array_almost_equal, assert_array_equal, \
+     assert_
 
-from scikits.learn import neighbors
+from scikits.learn import neighbors, datasets
+
+# load and shuffle iris dataset
+iris = datasets.load_iris()
+perm = np.random.permutation(iris.target.size)
+iris.data = iris.data[perm]
+iris.target = iris.target[perm]
 
 
 def test_neighbors_1D():
@@ -15,61 +22,49 @@ def test_neighbors_1D():
     X = [[x] for x in range(0, n)]
     Y = [0]*(n/2) + [1]*(n/2)
 
-    # n_neighbors = 1
-    knn = neighbors.NeighborsClassifier(n_neighbors=1)
-    knn.fit(X, Y)
-    test = [[i + 0.01] for i in range(0, n/2)] + \
-           [[i - 0.01] for i in range(n/2, n)]
-    assert_array_equal(knn.predict(test), [0]*3 + [1]*3)
-
-    # n_neighbors = 2
-    knn = neighbors.NeighborsClassifier(n_neighbors=2)
-    knn.fit(X, Y)
-    assert_array_equal(knn.predict(test), [0]*4 + [1]*2)
-
-
-    # n_neighbors = 3
-    knn = neighbors.NeighborsClassifier(n_neighbors=3)
-    knn.fit(X, Y)
-    assert_array_equal(knn.predict([[i +0.01] for i in range(0, n/2)]),
-                        [0 for i in range(n/2)])
-    assert_array_equal(knn.predict([[i-0.01] for i in range(n/2, n)]),
-                        [1 for i in range(n/2)])
-
-
-def test_neighbors_2D():
+    for s in ('auto', 'ball_tree', 'brute', 'brute_inplace'):
+        # n_neighbors = 1
+        knn = neighbors.NeighborsClassifier(n_neighbors=1, algorithm=s)
+        knn.fit(X, Y)
+        test = [[i + 0.01] for i in range(0, n/2)] + \
+               [[i - 0.01] for i in range(n/2, n)]
+        assert_array_equal(knn.predict(test), [0]*3 + [1]*3)
+
+        # n_neighbors = 2
+        knn = neighbors.NeighborsClassifier(n_neighbors=2, algorithm=s)
+        knn.fit(X, Y)
+        assert_array_equal(knn.predict(test), [0]*4 + [1]*2)
+
+        # n_neighbors = 3
+        knn = neighbors.NeighborsClassifier(n_neighbors=3, algorithm=s)
+        knn.fit(X, Y)
+        assert_array_equal(knn.predict([[i +0.01] for i in range(0, n/2)]),
+                            [0 for i in range(n/2)])
+        assert_array_equal(knn.predict([[i-0.01] for i in range(n/2, n)]),
+                            [1 for i in range(n/2)])
+
+
+def test_neighbors_iris():
     """
-    Nearest Neighbor in the plane.
+    Sanity checks on the iris dataset.
 
-    Puts three points of each label in the plane and performs a
-    nearest neighbor query on points near the decision boundary.
+    Checks that NeighborsClassifier and NeighborsRegressor give
+    sensible predictions for all the implemented algorithms.
     """
-    X = (
-        (0, 1), (1, 1), (1, 0), # label 0
-        (-1, 0), (-1, -1), (0, -1)) # label 1
-    n_2 = len(X)/2
-    Y = [0]*n_2 + [1]*n_2
-    knn = neighbors.NeighborsClassifier()
-    knn.fit(X, Y)
 
-    prediction = knn.predict([[0, .1], [0, -.1], [.1, 0], [-.1, 0]])
-    assert_array_equal(prediction, [0, 1, 0, 1])
+    for s in ('auto', 'ball_tree', 'brute', 'brute_inplace'):
+        clf = neighbors.NeighborsClassifier()
+        clf.fit(iris.data, iris.target, n_neighbors=1, algorithm=s)
+        assert_array_equal(clf.predict(iris.data), iris.target)
 
+        clf.fit(iris.data, iris.target, n_neighbors=9, algorithm=s)
+        assert_(np.mean(clf.predict(iris.data) == iris.target) > 0.95)
 
-def test_neighbors_regressor():
-    """
-    NeighborsRegressor for regression using k-NN
-    """
-    X = [[0], [1], [2], [3]]
-    y = [0, 0, 1, 1]
-    neigh = neighbors.NeighborsRegressor(n_neighbors=3)
-    neigh.fit(X, y, mode='barycenter')
-    assert_array_almost_equal(
-        neigh.predict([[1.], [1.5]]), [0.333, 0.583], decimal=3)
-    neigh.fit(X, y, mode='mean')
-    assert_array_almost_equal(
-        neigh.predict([[1.], [1.5]]), [0.333, 0.333], decimal=3)
-    
+        for m in ('barycenter', 'mean'):
+            rgs = neighbors.NeighborsRegressor()
+            rgs.fit(iris.data, iris.target, mode=m, algorithm=s)
+            assert_(np.mean(
+                rgs.predict(iris.data).round() == iris.target) > 0.95)
 
 
 def test_kneighbors_graph():