diff --git a/sklearn/neighbors/base.py b/sklearn/neighbors/base.py
index 7cda17f1d40a7a5dd52acef70ad934a13ce20a4d..c2e554793212a2d50120fe6273ee3a69d1ec64a0 100644
--- a/sklearn/neighbors/base.py
+++ b/sklearn/neighbors/base.py
@@ -78,13 +78,19 @@ class NeighborsBase(BaseEstimator):
         self._fit_method = None
 
     def _fit(self, X):
-        if isinstance(X, BallTree):
+        if isinstance(X, NeighborsBase):
+            self._fit_X = X._fit_X
+            self._tree = X._tree
+            self._fit_method = X._fit_method
+            return self
+
+        elif isinstance(X, BallTree):
             self._fit_X = X.data
             self._tree = X
             self._fit_method = 'ball_tree'
             return self
 
-        if isinstance(X, cKDTree):
+        elif isinstance(X, cKDTree):
             self._fit_X = X.data
             self._tree = X
             self._fit_method = 'kd_tree'
@@ -92,6 +98,9 @@ class NeighborsBase(BaseEstimator):
 
         X = safe_asanyarray(X)
 
+        if X.ndim != 2:
+            raise ValueError("data type not understood")
+
         if issparse(X):
             if self.algorithm not in ('auto', 'brute'):
                 warnings.warn("cannot use tree with sparse input: "
@@ -320,7 +329,7 @@ class RadiusNeighborsMixin(object):
         >>> neigh.fit(samples) # doctest: +ELLIPSIS
         NearestNeighbors(algorithm='auto', leaf_size=30, ...)
         >>> print neigh.radius_neighbors([1., 1., 1.]) # doctest: +ELLIPSIS
-        (array([[ 1.5  0.5]]...), array([[1 2]]...)
+        (array([[ 1.5,  0.5]]...), array([[1, 2]]...)
 
         The first array returned contains the distances to all points which
         are closer than 1.6, while the second array returned contains their
@@ -328,8 +337,8 @@ class RadiusNeighborsMixin(object):
         Because the number of neighbors of each point is not necessarily
         equal, `radius_neighbors` returns an array of objects, where each
         object is a 1D array of indices.
-
         """
+
         if self._fit_method == None:
             raise ValueError("must fit neighbors before querying")
 
diff --git a/sklearn/neighbors/regression.py b/sklearn/neighbors/regression.py
index cbfd7b1837edc363cc08237b6e2ebde657d61777..9db821a97ff660f9380bc2eca53344b68edb35b9 100644
--- a/sklearn/neighbors/regression.py
+++ b/sklearn/neighbors/regression.py
@@ -1,4 +1,4 @@
-"""Nearest Neighbor Classification"""
+"""Nearest Neighbor Regression"""
 
 # Authors: Jake Vanderplas <vanderplas@astro.washington.edu>
 #          Fabian Pedregosa <fabian.pedregosa@inria.fr>
diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py
index b2159609e38a5670a61925baa288a7ae3fee3c14..d15fc3940980f73e24c459a083eb3620ae59ef95 100644
--- a/sklearn/neighbors/tests/test_neighbors.py
+++ b/sklearn/neighbors/tests/test_neighbors.py
@@ -1,7 +1,9 @@
 import numpy as np
 from numpy.testing import assert_array_almost_equal, assert_array_equal
+from numpy.testing import assert_raises
 from scipy.sparse import (bsr_matrix, coo_matrix, csc_matrix, csr_matrix,
                           dok_matrix, lil_matrix)
+from scipy.spatial import cKDTree
 
 from sklearn import neighbors, datasets
 
@@ -30,6 +32,7 @@ def test_unsupervised_kneighbors(n_samples=20, n_features=5,
 
     test = rng.rand(n_query_pts, n_features)
 
+    results_nodist = []
     results = []
 
     for algorithm in ALGORITHMS:
@@ -37,13 +40,33 @@ def test_unsupervised_kneighbors(n_samples=20, n_features=5,
                                            algorithm=algorithm)
         neigh.fit(X)
 
+        results_nodist.append(neigh.kneighbors(test, return_distance=False))
         results.append(neigh.kneighbors(test, return_distance=True))
 
     for i in range(len(results) - 1):
+        assert_array_almost_equal(results_nodist[i], results[i][1])
         assert_array_almost_equal(results[i][0], results[i + 1][0])
         assert_array_almost_equal(results[i][1], results[i + 1][1])
 
 
+def test_unsupervised_inputs():
+    X = np.random.random((10,3))
+
+    nbrs_fid = neighbors.NearestNeighbors(n_neighbors=1)
+    nbrs_fid.fit(X)
+    
+    dist1, ind1 = nbrs_fid.kneighbors(X)
+
+    nbrs = neighbors.NearestNeighbors(n_neighbors=1)    
+
+    for input in (nbrs_fid, neighbors.BallTree(X), cKDTree(X)):
+        nbrs.fit(input)
+        dist2, ind2 = nbrs.kneighbors(X)
+        
+        assert_array_almost_equal(dist1, dist2)
+        assert_array_almost_equal(ind1, ind2)
+
+
 def test_unsupervised_radius_neighbors(n_samples=20, n_features=5,
                                        n_query_pts=2, radius=0.5,
                                        random_state=0):
@@ -63,15 +86,20 @@ def test_unsupervised_radius_neighbors(n_samples=20, n_features=5,
                                            algorithm=algorithm)
         neigh.fit(X)
 
+        ind1 = neigh.radius_neighbors(test, return_distance=False)
+
         # sort the results: this is not done automatically for
         # radius searches
         dist, ind = neigh.radius_neighbors(test, return_distance=True)
-        for (d, i) in zip(dist, ind):
+        for (d, i, i1) in zip(dist, ind, ind1):
             j = d.argsort()
             d[:] = d[j]
             i[:] = i[j]
         results.append((dist, ind))
 
+        assert_array_almost_equal(np.concatenate(list(ind)),
+                                  np.concatenate(list(ind1)))
+
     for i in range(len(results) - 1):
         assert_array_almost_equal(np.concatenate(list(results[i][0])),
                                   np.concatenate(list(results[i + 1][0]))),
@@ -189,6 +217,7 @@ def test_radius_neighbors_regressor(n_samples=40,
     for algorithm in ALGORITHMS:
         for weights in ['uniform', 'distance', weight_func]:
             neigh = neighbors.RadiusNeighborsRegressor(radius=radius,
+                                                       weights=weights,
                                                        algorithm=algorithm)
             neigh.fit(X, y)
             epsilon = 1E-5 * (2 * rng.rand(1, n_features) - 1)
@@ -304,6 +333,44 @@ def test_radius_neighbors_graph():
          [ 1.01      ,  0.        ,  1.40716026],
          [ 0.        ,  1.40716026,  0.        ]])
 
+def test_neighbors_badargs():
+    neigh_types = [neighbors.KNeighborsClassifier,
+                   neighbors.RadiusNeighborsClassifier,
+                   neighbors.KNeighborsRegressor,
+                   neighbors.RadiusNeighborsRegressor]
+
+    
+    assert_raises(ValueError,
+                  neighbors.NearestNeighbors,
+                  algorithm='blah')
+
+    X = np.random.random((10,2))
+
+    for cls in neigh_types:
+        assert_raises(ValueError,
+                      cls,
+                      weights='blah')
+
+    for cls in neigh_types:
+        nbrs = cls()
+        assert_raises(ValueError,
+                      nbrs.predict,
+                      X)
+        
+
+        
+    nbrs = neighbors.NearestNeighbors().fit(X)
+
+    assert_raises(ValueError,
+                  nbrs.kneighbors_graph,
+                  X, mode='blah')
+    assert_raises(ValueError,
+                  nbrs.radius_neighbors_graph,
+                  X, mode='blah')
+
+
+    
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/sklearn/neighbors/unsupervised.py b/sklearn/neighbors/unsupervised.py
index 15d6c4478e2815eb78ac2b2cb4dcbf0259fc7016..74f42a0e2e937694bd5ac34af3d15b5320415287 100644
--- a/sklearn/neighbors/unsupervised.py
+++ b/sklearn/neighbors/unsupervised.py
@@ -47,10 +47,10 @@ class NearestNeighbors(NeighborsBase, KNeighborsMixin,
       NearestNeighbors(algorithm='auto', leaf_size=30, n_neighbors=2, radius=0.4)
 
       >>> neigh.kneighbors([[0, 0, 1.3]], 2, return_distance=False)
-      array([[2, 0]], dtype=int32)
+      array([[2, 0]])
 
       >>> neigh.radius_neighbors([0, 0, 1.3], 0.4, return_distance=False)
-      array([[2]], dtype=object)
+      array([[2]])
 
     See also
     --------