diff --git a/sklearn/neighbors/base.py b/sklearn/neighbors/base.py
index 7cda17f1d40a7a5dd52acef70ad934a13ce20a4d..c2e554793212a2d50120fe6273ee3a69d1ec64a0 100644
--- a/sklearn/neighbors/base.py
+++ b/sklearn/neighbors/base.py
@@ -78,13 +78,19 @@ class NeighborsBase(BaseEstimator):
         self._fit_method = None
 
     def _fit(self, X):
-        if isinstance(X, BallTree):
+        if isinstance(X, NeighborsBase):
+            self._fit_X = X._fit_X
+            self._tree = X._tree
+            self._fit_method = X._fit_method
+            return self
+
+        elif isinstance(X, BallTree):
             self._fit_X = X.data
             self._tree = X
             self._fit_method = 'ball_tree'
             return self
 
-        if isinstance(X, cKDTree):
+        elif isinstance(X, cKDTree):
             self._fit_X = X.data
             self._tree = X
             self._fit_method = 'kd_tree'
@@ -92,6 +98,9 @@ class NeighborsBase(BaseEstimator):
 
         X = safe_asanyarray(X)
 
+        if X.ndim != 2:
+            raise ValueError("data type not understood")
+
         if issparse(X):
             if self.algorithm not in ('auto', 'brute'):
                 warnings.warn("cannot use tree with sparse input: "
@@ -320,7 +329,7 @@ class RadiusNeighborsMixin(object):
        >>> neigh.fit(samples) # doctest: +ELLIPSIS
        NearestNeighbors(algorithm='auto', leaf_size=30, ...)
        >>> print neigh.radius_neighbors([1., 1., 1.]) # doctest: +ELLIPSIS
-       (array([[ 1.5  0.5]]...), array([[1 2]]...)
+       (array([[ 1.5,  0.5]]...), array([[1, 2]]...)
 
        The first array returned contains the distances to all points which
        are closer than 1.6, while the second array returned contains their
@@ -328,8 +337,8 @@
 
        Because the number of neighbors of each point is not necessarily
        equal, `radius_neighbors` returns an array of objects, where each
        object is a 1D array of indices.
-       """
+       """
 
        if self._fit_method == None:
            raise ValueError("must fit neighbors before querying")
diff --git a/sklearn/neighbors/regression.py b/sklearn/neighbors/regression.py
index cbfd7b1837edc363cc08237b6e2ebde657d61777..9db821a97ff660f9380bc2eca53344b68edb35b9 100644
--- a/sklearn/neighbors/regression.py
+++ b/sklearn/neighbors/regression.py
@@ -1,4 +1,4 @@
-"""Nearest Neighbor Classification"""
+"""Nearest Neighbor Regression"""
 
 # Authors: Jake Vanderplas <vanderplas@astro.washington.edu>
 #          Fabian Pedregosa <fabian.pedregosa@inria.fr>
diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py
index b2159609e38a5670a61925baa288a7ae3fee3c14..d15fc3940980f73e24c459a083eb3620ae59ef95 100644
--- a/sklearn/neighbors/tests/test_neighbors.py
+++ b/sklearn/neighbors/tests/test_neighbors.py
@@ -1,7 +1,9 @@
 import numpy as np
 from numpy.testing import assert_array_almost_equal, assert_array_equal
+from numpy.testing import assert_raises
 from scipy.sparse import (bsr_matrix, coo_matrix, csc_matrix, csr_matrix,
                           dok_matrix, lil_matrix)
+from scipy.spatial import cKDTree
 
 from sklearn import neighbors, datasets
 
@@ -30,6 +32,7 @@ def test_unsupervised_kneighbors(n_samples=20, n_features=5,
 
     test = rng.rand(n_query_pts, n_features)
 
+    results_nodist = []
     results = []
 
     for algorithm in ALGORITHMS:
@@ -37,13 +40,33 @@
                                            algorithm=algorithm)
         neigh.fit(X)
 
+        results_nodist.append(neigh.kneighbors(test, return_distance=False))
         results.append(neigh.kneighbors(test, return_distance=True))
 
     for i in range(len(results) - 1):
+        assert_array_almost_equal(results_nodist[i], results[i][1])
         assert_array_almost_equal(results[i][0], results[i + 1][0])
         assert_array_almost_equal(results[i][1], results[i + 1][1])
 
 
+def test_unsupervised_inputs():
+    X = np.random.random((10,3))
+
+    nbrs_fid = neighbors.NearestNeighbors(n_neighbors=1)
+    nbrs_fid.fit(X)
+
+    dist1, ind1 = nbrs_fid.kneighbors(X)
+
+    nbrs = neighbors.NearestNeighbors(n_neighbors=1)
+
+    for input in (nbrs_fid, neighbors.BallTree(X), cKDTree(X)):
+        nbrs.fit(input)
+        dist2, ind2 = nbrs.kneighbors(X)
+
+        assert_array_almost_equal(dist1, dist2)
+        assert_array_almost_equal(ind1, ind2)
+
+
 def test_unsupervised_radius_neighbors(n_samples=20, n_features=5,
                                        n_query_pts=2, radius=0.5,
                                        random_state=0):
@@ -63,15 +86,20 @@
                                            algorithm=algorithm)
         neigh.fit(X)
 
+        ind1 = neigh.radius_neighbors(test, return_distance=False)
+
         # sort the results: this is not done automatically for
         # radius searches
         dist, ind = neigh.radius_neighbors(test, return_distance=True)
-        for (d, i) in zip(dist, ind):
+        for (d, i, i1) in zip(dist, ind, ind1):
             j = d.argsort()
             d[:] = d[j]
             i[:] = i[j]
         results.append((dist, ind))
 
+        assert_array_almost_equal(np.concatenate(list(ind)),
+                                  np.concatenate(list(ind1)))
+
     for i in range(len(results) - 1):
         assert_array_almost_equal(np.concatenate(list(results[i][0])),
                                   np.concatenate(list(results[i + 1][0]))),
@@ -189,6 +217,7 @@ def test_radius_neighbors_regressor(n_samples=40,
     for algorithm in ALGORITHMS:
         for weights in ['uniform', 'distance', weight_func]:
             neigh = neighbors.RadiusNeighborsRegressor(radius=radius,
+                                                        weights=weights,
                                                         algorithm=algorithm)
             neigh.fit(X, y)
             epsilon = 1E-5 * (2 * rng.rand(1, n_features) - 1)
@@ -304,6 +333,44 @@ def test_radius_neighbors_graph():
          [ 1.01      ,  0.        ,  1.40716026],
          [ 0.        ,  1.40716026,  0.        ]])
 
+
+def test_neighbors_badargs():
+    neigh_types = [neighbors.KNeighborsClassifier,
+                   neighbors.RadiusNeighborsClassifier,
+                   neighbors.KNeighborsRegressor,
+                   neighbors.RadiusNeighborsRegressor]
+
+
+    assert_raises(ValueError,
+                  neighbors.NearestNeighbors,
+                  algorithm='blah')
+
+    X = np.random.random((10,2))
+
+    for cls in neigh_types:
+        assert_raises(ValueError,
+                      cls,
+                      weights='blah')
+
+    for cls in neigh_types:
+        nbrs = cls()
+        assert_raises(ValueError,
+                      nbrs.predict,
+                      X)
+
+
+
+    nbrs = neighbors.NearestNeighbors().fit(X)
+
+    assert_raises(ValueError,
+                  nbrs.kneighbors_graph,
+                  X, mode='blah')
+    assert_raises(ValueError,
+                  nbrs.radius_neighbors_graph,
+                  X, mode='blah')
+
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/sklearn/neighbors/unsupervised.py b/sklearn/neighbors/unsupervised.py
index 15d6c4478e2815eb78ac2b2cb4dcbf0259fc7016..74f42a0e2e937694bd5ac34af3d15b5320415287 100644
--- a/sklearn/neighbors/unsupervised.py
+++ b/sklearn/neighbors/unsupervised.py
@@ -47,10 +47,10 @@ class NearestNeighbors(NeighborsBase, KNeighborsMixin,
     NearestNeighbors(algorithm='auto', leaf_size=30, n_neighbors=2, radius=0.4)
 
     >>> neigh.kneighbors([[0, 0, 1.3]], 2, return_distance=False)
-    array([[2, 0]], dtype=int32)
+    array([[2, 0]])
 
     >>> neigh.radius_neighbors([0, 0, 1.3], 0.4, return_distance=False)
-    array([[2]], dtype=object)
+    array([[2]])
 
     See also
     --------
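Usage sketch (not part of the patch): the `_fit` changes above let `fit` accept an already-fitted `NeighborsBase` estimator, a prebuilt `BallTree`, or a `scipy.spatial.cKDTree` in place of a raw data array, reusing the existing tree and `_fit_method` instead of rebuilding, and raw input that is not a 2-D array is now rejected with `ValueError("data type not understood")`. The snippet mirrors `test_unsupervised_inputs` and `test_neighbors_badargs`; the variable names are illustrative only.

    import numpy as np
    from scipy.spatial import cKDTree
    from sklearn import neighbors

    X = np.random.random((10, 3))

    # Fit once on the raw data to get reference results.
    nbrs = neighbors.NearestNeighbors(n_neighbors=1).fit(X)
    dist_ref, ind_ref = nbrs.kneighbors(X)

    # Reuse a fitted estimator or a prebuilt tree as the training input:
    # the stored tree is taken over directly instead of being rebuilt.
    for precomputed in (nbrs, neighbors.BallTree(X), cKDTree(X)):
        other = neighbors.NearestNeighbors(n_neighbors=1).fit(precomputed)
        dist, ind = other.kneighbors(X)
        assert np.allclose(dist, dist_ref) and np.array_equal(ind, ind_ref)

    # Non-2-D raw input is rejected early by the new ndim check.
    try:
        neighbors.NearestNeighbors().fit(np.ones(10))
    except ValueError:
        pass  # raises "data type not understood"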