diff --git a/doc/modules/neighbors.rst b/doc/modules/neighbors.rst
index 49ab552e9e93fc63d24affe123e602a4feab2039..5e4bf87e42620e5403f606c910128e6aa67e5848 100644
--- a/doc/modules/neighbors.rst
+++ b/doc/modules/neighbors.rst
@@ -33,8 +33,8 @@ handwritten digits or satellite image scenes. It is often successful in
 classification situations where the decision boundary is very irregular.
 
 The classes in :mod:`sklearn.neighbors` can handle either Numpy arrays or
-`scipy.sparse` matrices as input. It currently supports only the Euclidean
-distance metric.
+`scipy.sparse` matrices as input. Arbitrary Minkowski metrics are supported
+for searches.
 
 
 Unsupervised Nearest Neighbors
diff --git a/doc/tutorial/statistical_inference/supervised_learning.rst b/doc/tutorial/statistical_inference/supervised_learning.rst
index 5c1171fe063b1043573c72251ab0dc03be635d5b..1088214755d5e382307b1cc9357dadfcf9d5c246 100644
--- a/doc/tutorial/statistical_inference/supervised_learning.rst
+++ b/doc/tutorial/statistical_inference/supervised_learning.rst
@@ -95,7 +95,7 @@ Scikit-learn documentation for more information about this type of classifier.)
     >>> from sklearn.neighbors import KNeighborsClassifier
     >>> knn = KNeighborsClassifier()
     >>> knn.fit(iris_X_train, iris_y_train)
-    KNeighborsClassifier(algorithm='auto', leaf_size=30, n_neighbors=5,
+    KNeighborsClassifier(algorithm='auto', leaf_size=30, n_neighbors=5, p=2,
            warn_on_equidistant=True, weights='uniform')
     >>> knn.predict(iris_X_test)
     array([1, 2, 1, 0, 0, 0, 2, 1, 2, 0])
diff --git a/sklearn/neighbors/base.py b/sklearn/neighbors/base.py
index 87c3407c99b01423bcfe2df5720da9b29464ca14..87820d51f505b2984781955e92aaa4660d861bc9 100644
--- a/sklearn/neighbors/base.py
+++ b/sklearn/neighbors/base.py
@@ -13,7 +13,7 @@ from scipy.spatial.ckdtree import cKDTree
 
 from .ball_tree import BallTree
 from ..base import BaseEstimator
-from ..metrics import euclidean_distances
+from ..metrics import pairwise_distances
 from ..utils import safe_asarray, atleast2d_or_csr
 
 
@@ -71,15 +71,18 @@ class NeighborsBase(BaseEstimator):
     #     rely on soon-to-be-updated functionality in the pairwise module.
     def _init_params(self, n_neighbors=None, radius=None,
                      algorithm='auto', leaf_size=30,
-                     warn_on_equidistant=True):
+                     warn_on_equidistant=True, p=2):
         self.n_neighbors = n_neighbors
         self.radius = radius
         self.algorithm = algorithm
         self.leaf_size = leaf_size
         self.warn_on_equidistant = warn_on_equidistant
+        self.p = p
 
         if algorithm not in ['auto', 'brute', 'kd_tree', 'ball_tree']:
             raise ValueError("unrecognized algorithm: '%s'" % algorithm)
 
+        if p < 1:
+            raise ValueError("p must be greater than or equal to 1")
         self._fit_X = None
         self._tree = None
@@ -131,7 +134,7 @@ class NeighborsBase(BaseEstimator):
         if self._fit_method == 'kd_tree':
             self._tree = cKDTree(X, self.leaf_size)
         elif self._fit_method == 'ball_tree':
-            self._tree = BallTree(X, self.leaf_size)
+            self._tree = BallTree(X, self.leaf_size, p=self.p)
         elif self._fit_method == 'brute':
             self._tree = None
         else:
@@ -202,7 +205,16 @@ class KNeighborsMixin(object):
             n_neighbors = self.n_neighbors
 
         if self._fit_method == 'brute':
-            dist = euclidean_distances(X, self._fit_X, squared=True)
+            if self.p == 1:
+                dist = pairwise_distances(X, self._fit_X, 'manhattan')
+            elif self.p == 2:
+                dist = pairwise_distances(X, self._fit_X, 'euclidean',
+                                          squared=True)
+            elif self.p == np.inf:
+                dist = pairwise_distances(X, self._fit_X, 'chebyshev')
+            else:
+                dist = pairwise_distances(X, self._fit_X, 'minkowski',
+                                          p=self.p)
             # XXX: should be implemented with a partial sort
             neigh_ind = dist.argsort(axis=1)
             if self.warn_on_equidistant and n_neighbors < self._fit_X.shape[0]:
@@ -214,7 +226,10 @@ class KNeighborsMixin(object):
             neigh_ind = neigh_ind[:, :n_neighbors]
             if return_distance:
                 j = np.arange(neigh_ind.shape[0])[:, None]
-                return np.sqrt(dist[j, neigh_ind]), neigh_ind
+                if self.p == 2:
+                    return np.sqrt(dist[j, neigh_ind]), neigh_ind
+                else:
+                    return dist[j, neigh_ind], neigh_ind
             else:
                 return neigh_ind
         elif self._fit_method == 'ball_tree':
@@ -224,7 +239,7 @@ class KNeighborsMixin(object):
                 warn_equidistant()
             return result
         elif self._fit_method == 'kd_tree':
-            dist, ind = self._tree.query(X, n_neighbors)
+            dist, ind = self._tree.query(X, n_neighbors, p=self.p)
             # kd_tree returns a 1D array for n_neighbors = 1
             if n_neighbors == 1:
                 dist = dist[:, None]
@@ -366,10 +381,19 @@ class RadiusNeighborsMixin(object):
             radius = self.radius
 
         if self._fit_method == 'brute':
-            dist = euclidean_distances(X, self._fit_X, squared=True)
-            rad2 = radius ** 2
+            if self.p == 1:
+                dist = pairwise_distances(X, self._fit_X, 'manhattan')
+            elif self.p == 2:
+                dist = pairwise_distances(X, self._fit_X, 'euclidean',
+                                          squared=True)
+                radius *= radius
+            elif self.p == np.inf:
+                dist = pairwise_distances(X, self._fit_X, 'chebyshev')
+            else:
+                dist = pairwise_distances(X, self._fit_X, 'minkowski',
+                                          p=self.p)
 
-            neigh_ind = [np.where(d < rad2)[0] for d in dist]
+            neigh_ind = [np.where(d < radius)[0] for d in dist]
 
             # if there are the same number of neighbors for each point,
             # we can do a normal array.  Otherwise, we return an object
@@ -382,9 +406,14 @@ class RadiusNeighborsMixin(object):
                 dtype_F = object
 
             if return_distance:
-                dist = np.array([np.sqrt(d[neigh_ind[i]]) \
-                                     for i, d in enumerate(dist)],
-                                dtype=dtype_F)
+                if self.p == 2:
+                    dist = np.array([np.sqrt(d[neigh_ind[i]]) \
+                                         for i, d in enumerate(dist)],
+                                    dtype=dtype_F)
+                else:
+                    dist = np.array([d[neigh_ind[i]] \
+                                         for i, d in enumerate(dist)],
+                                    dtype=dtype_F)
                 return dist, neigh_ind
             else:
                 return neigh_ind
@@ -400,7 +429,8 @@ class RadiusNeighborsMixin(object):
         elif self._fit_method == 'kd_tree':
             Npts = self._fit_X.shape[0]
             dist, ind = self._tree.query(X, Npts,
-                                         distance_upper_bound=radius)
+                                         distance_upper_bound=radius,
+                                         p=self.p)
 
             ind = [ind_i[:ind_i.searchsorted(Npts)]
                    for ind_i in ind]
diff --git a/sklearn/neighbors/classification.py b/sklearn/neighbors/classification.py
index f8ded2b1f87935046b6567e499d04169e4ee9f80..90aebfa35e5ecc4340339aaf25cbdfff2c47f243 100644
--- a/sklearn/neighbors/classification.py
+++ b/sklearn/neighbors/classification.py
@@ -68,6 +68,12 @@ class KNeighborsClassifier(NeighborsBase, KNeighborsMixin,
         ordering of the training data.
         If the fit method is ``'kd_tree'``, no warnings will be generated.
 
+    p: integer, optional (default = 2)
+        Parameter for the Minkowski metric from
+        sklearn.metrics.pairwise.pairwise_distances. When p = 1, this is
+        equivalent to using manhattan_distance, and euclidean_distance for
+        p = 2. For arbitrary p, minkowski_distance is used.
+
     Examples
     --------
     >>> X = [[0], [1], [2], [3]]
@@ -97,11 +103,12 @@ class KNeighborsClassifier(NeighborsBase, KNeighborsMixin,
 
     def __init__(self, n_neighbors=5, weights='uniform',
                  algorithm='auto', leaf_size=30,
-                 warn_on_equidistant=True):
+                 warn_on_equidistant=True, p=2):
         self._init_params(n_neighbors=n_neighbors,
                           algorithm=algorithm,
                           leaf_size=leaf_size,
-                          warn_on_equidistant=warn_on_equidistant)
+                          warn_on_equidistant=warn_on_equidistant,
+                          p=p)
         self.weights = _check_weights(weights)
 
     def predict(self, X):
@@ -174,6 +181,12 @@ class RadiusNeighborsClassifier(NeighborsBase, RadiusNeighborsMixin,
         required to store the tree.  The optimal value depends on the
         nature of the problem.
 
+    p: integer, optional (default = 2)
+        Parameter for the Minkowski metric from
+        sklearn.metrics.pairwise.pairwise_distances. When p = 1, this is
+        equivalent to using manhattan_distance, and euclidean_distance for
+        p = 2. For arbitrary p, minkowski_distance is used.
+
     Examples
     --------
     >>> X = [[0], [1], [2], [3]]
@@ -201,10 +214,11 @@ class RadiusNeighborsClassifier(NeighborsBase, RadiusNeighborsMixin,
     """
 
     def __init__(self, radius=1.0, weights='uniform',
-                 algorithm='auto', leaf_size=30):
+                 algorithm='auto', leaf_size=30, p=2):
         self._init_params(radius=radius,
                           algorithm=algorithm,
-                          leaf_size=leaf_size)
+                          leaf_size=leaf_size,
+                          p=p)
         self.weights = _check_weights(weights)
 
     def predict(self, X):
diff --git a/sklearn/neighbors/regression.py b/sklearn/neighbors/regression.py
index 9f918db7810892d9494fdd0ef4aea428dbf948b0..a1acd28e93640c228d7498e262be1d51af064f8b 100644
--- a/sklearn/neighbors/regression.py
+++ b/sklearn/neighbors/regression.py
@@ -70,6 +70,12 @@ class KNeighborsRegressor(NeighborsBase, KNeighborsMixin,
         ordering of the training data.
         If the fit method is ``'kd_tree'``, no warnings will be generated.
 
+    p: integer, optional (default = 2)
+        Parameter for the Minkowski metric from
+        sklearn.metrics.pairwise.pairwise_distances. When p = 1, this is
+        equivalent to using manhattan_distance, and euclidean_distance for
+        p = 2. For arbitrary p, minkowski_distance is used.
+
     Examples
     --------
     >>> X = [[0], [1], [2], [3]]
@@ -97,11 +103,13 @@ class KNeighborsRegressor(NeighborsBase, KNeighborsMixin,
     """
 
     def __init__(self, n_neighbors=5, weights='uniform',
-                 algorithm='auto', leaf_size=30, warn_on_equidistant=True):
+                 algorithm='auto', leaf_size=30, warn_on_equidistant=True,
+                 p=2):
         self._init_params(n_neighbors=n_neighbors,
                           algorithm=algorithm,
                           leaf_size=leaf_size,
-                          warn_on_equidistant=warn_on_equidistant)
+                          warn_on_equidistant=warn_on_equidistant,
+                          p=p)
         self.weights = _check_weights(weights)
 
     def predict(self, X):
@@ -177,6 +185,12 @@ class RadiusNeighborsRegressor(NeighborsBase, RadiusNeighborsMixin,
         required to store the tree.  The optimal value depends on the
         nature of the problem.
 
+    p: integer, optional (default = 2)
+        Parameter for the Minkowski metric from
+        sklearn.metrics.pairwise.pairwise_distances. When p = 1, this is
+        equivalent to using manhattan_distance, and euclidean_distance for
+        p = 2. For arbitrary p, minkowski_distance is used.
+
     Examples
     --------
     >>> X = [[0], [1], [2], [3]]
@@ -204,10 +218,11 @@ class RadiusNeighborsRegressor(NeighborsBase, RadiusNeighborsMixin,
     """
 
     def __init__(self, radius=1.0, weights='uniform',
-                 algorithm='auto', leaf_size=30):
+                 algorithm='auto', leaf_size=30, p=2):
         self._init_params(radius=radius,
                           algorithm=algorithm,
-                          leaf_size=leaf_size)
+                          leaf_size=leaf_size,
+                          p=p)
         self.weights = _check_weights(weights)
 
     def predict(self, X):
diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py
index 810d6358a9564e4005a91ab224d5495011fee7f7..86934f5353f2edd06cb84e9cbf4a342d31ca133a 100644
--- a/sklearn/neighbors/tests/test_neighbors.py
+++ b/sklearn/neighbors/tests/test_neighbors.py
@@ -19,6 +19,7 @@ SPARSE_TYPES = (bsr_matrix, coo_matrix, csc_matrix, csr_matrix, dok_matrix,
 SPARSE_OR_DENSE = SPARSE_TYPES + (np.asarray,)
 
 ALGORITHMS = ('ball_tree', 'brute', 'kd_tree', 'auto')
+P = (1, 2, 3, 4, np.inf)
 
 
 def test_warn_on_equidistant(n_samples=100, n_features=3, k=3):
@@ -73,21 +74,24 @@ def test_unsupervised_kneighbors(n_samples=20, n_features=5,
 
     test = rng.rand(n_query_pts, n_features)
 
-    results_nodist = []
-    results = []
-    for algorithm in ALGORITHMS:
-        neigh = neighbors.NearestNeighbors(n_neighbors=n_neighbors,
-                                           algorithm=algorithm)
-        neigh.fit(X)
+    for p in P:
+        results_nodist = []
+        results = []
+
+        for algorithm in ALGORITHMS:
+            neigh = neighbors.NearestNeighbors(n_neighbors=n_neighbors,
+                                               algorithm=algorithm,
+                                               p=p)
+            neigh.fit(X)
 
-        results_nodist.append(neigh.kneighbors(test, return_distance=False))
-        results.append(neigh.kneighbors(test, return_distance=True))
+            results_nodist.append(neigh.kneighbors(test,
+                                                   return_distance=False))
+            results.append(neigh.kneighbors(test, return_distance=True))
 
-    for i in range(len(results) - 1):
-        assert_array_almost_equal(results_nodist[i], results[i][1])
-        assert_array_almost_equal(results[i][0], results[i + 1][0])
-        assert_array_almost_equal(results[i][1], results[i + 1][1])
+        for i in range(len(results) - 1):
+            assert_array_almost_equal(results_nodist[i], results[i][1])
+            assert_array_almost_equal(results[i][0], results[i + 1][0])
+            assert_array_almost_equal(results[i][1], results[i + 1][1])
 
 
 def test_unsupervised_inputs():
@@ -119,32 +123,35 @@ def test_unsupervised_radius_neighbors(n_samples=20, n_features=5,
 
     test = rng.rand(n_query_pts, n_features)
 
-    results = []
-
-    for algorithm in ALGORITHMS:
-        neigh = neighbors.NearestNeighbors(radius=radius,
-                                           algorithm=algorithm)
-        neigh.fit(X)
-
-        ind1 = neigh.radius_neighbors(test, return_distance=False)
-
-        # sort the results: this is not done automatically for
-        # radius searches
-        dist, ind = neigh.radius_neighbors(test, return_distance=True)
-        for (d, i, i1) in zip(dist, ind, ind1):
-            j = d.argsort()
-            d[:] = d[j]
-            i[:] = i[j]
-        results.append((dist, ind))
-
-        assert_array_almost_equal(np.concatenate(list(ind)),
-                                  np.concatenate(list(ind1)))
-
-    for i in range(len(results) - 1):
-        assert_array_almost_equal(np.concatenate(list(results[i][0])),
-                                  np.concatenate(list(results[i + 1][0]))),
-        assert_array_almost_equal(np.concatenate(list(results[i][1])),
-                                  np.concatenate(list(results[i + 1][1])))
+    for p in P:
+        results = []
+
+        for algorithm in ALGORITHMS:
+            neigh = neighbors.NearestNeighbors(radius=radius,
+                                               algorithm=algorithm,
+                                               p=p)
+            neigh.fit(X)
+
+            ind1 = neigh.radius_neighbors(test, return_distance=False)
+
+            # sort the results: this is not done automatically for
+            # radius searches
+            dist, ind = neigh.radius_neighbors(test, return_distance=True)
+            for (d, i, i1) in zip(dist, ind, ind1):
+                j = d.argsort()
+                d[:] = d[j]
+                i[:] = i[j]
+                i1[:] = i1[j]
+            results.append((dist, ind))
+
+            assert_array_almost_equal(np.concatenate(list(ind)),
+                                      np.concatenate(list(ind1)))
+
+        for i in range(len(results) - 1):
+            assert_array_almost_equal(np.concatenate(list(results[i][0])),
+                                      np.concatenate(list(results[i + 1][0]))),
+            assert_array_almost_equal(np.concatenate(list(results[i][1])),
+                                      np.concatenate(list(results[i + 1][1])))
 
 
 def test_kneighbors_classifier(n_samples=40,
diff --git a/sklearn/neighbors/unsupervised.py b/sklearn/neighbors/unsupervised.py
index 6149bd69bd4a89ec980c4f859b28c1dd973220c2..0b53270f6ddb9f0c593e45104ab69fb6bd1af9b6 100644
--- a/sklearn/neighbors/unsupervised.py
+++ b/sklearn/neighbors/unsupervised.py
@@ -45,6 +45,12 @@ class NearestNeighbors(NeighborsBase, KNeighborsMixin,
         ordering of the training data.
         If the fit method is ``'kd_tree'``, no warnings will be generated.
 
+    p: integer, optional (default = 2)
+        Parameter for the Minkowski metric from
+        sklearn.metrics.pairwise.pairwise_distances. When p = 1, this is
+        equivalent to using manhattan_distance, and euclidean_distance for
+        p = 2. For arbitrary p, minkowski_distance is used.
+
    Examples
    --------
    >>> from sklearn.neighbors import NearestNeighbors
@@ -78,9 +84,10 @@ class NearestNeighbors(NeighborsBase, KNeighborsMixin,
 
     def __init__(self, n_neighbors=5, radius=1.0,
                  algorithm='auto', leaf_size=30,
-                 warn_on_equidistant=True):
+                 warn_on_equidistant=True, p=2):
         self._init_params(n_neighbors=n_neighbors,
                           radius=radius,
                           algorithm=algorithm,
                           leaf_size=leaf_size,
-                          warn_on_equidistant=warn_on_equidistant)
+                          warn_on_equidistant=warn_on_equidistant,
+                          p=p)
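
Reviewer sketch, not part of the patch: the snippet below assumes the diff
above has been applied and exercises the new ``p`` keyword end to end with
the brute-force route, cross-checking the returned distances against a
direct NumPy computation. The names ``rng``, ``query`` and ``expected`` are
illustrative only::

    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    rng = np.random.RandomState(0)
    X = rng.rand(20, 3)
    query = rng.rand(1, 3)

    # Per the docstrings added above: p=1 -> Manhattan, p=2 -> Euclidean
    # (the old default behavior), p=np.inf -> Chebyshev.
    for p in (1, 2, np.inf):
        neigh = NearestNeighbors(n_neighbors=3, algorithm='brute', p=p)
        neigh.fit(X)
        dist, ind = neigh.kneighbors(query)

        # Recompute the Minkowski distances by hand and compare.
        diff = np.abs(X - query)
        if p == np.inf:
            expected = diff.max(axis=1)
        else:
            expected = (diff ** p).sum(axis=1) ** (1.0 / p)
        assert np.allclose(dist.ravel(), np.sort(expected)[:3])

For ``algorithm='ball_tree'`` and ``'kd_tree'`` the same ``p`` is forwarded
to ``BallTree`` and ``cKDTree.query`` respectively, so all three routes
should agree; that is what the updated tests iterate over ``P`` to assert.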
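A second sketch, also not part of the patch, on why the brute-force radius
search for p = 2 can keep squared Euclidean distances and square the radius
instead (``radius *= radius``): x -> x**2 is monotonic for non-negative x,
so ``d < r`` holds exactly when ``d**2 < r**2``, and the square root over
the full distance matrix is avoided. A minimal check::

    import numpy as np

    rng = np.random.RandomState(42)
    d = rng.rand(1000)   # non-negative distances
    r = 0.5              # query radius

    # Both comparisons select exactly the same points.
    assert np.array_equal(d < r, d ** 2 < r ** 2)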