diff --git a/doc/modules/neighbors.rst b/doc/modules/neighbors.rst
index 49ab552e9e93fc63d24affe123e602a4feab2039..5e4bf87e42620e5403f606c910128e6aa67e5848 100644
--- a/doc/modules/neighbors.rst
+++ b/doc/modules/neighbors.rst
@@ -33,8 +33,8 @@ handwritten digits or satellite image scenes. It is often successful
 in classification situations where the decision boundary is very irregular.
 
 The classes in :mod:`sklearn.neighbors` can handle either Numpy arrays or
-`scipy.sparse` matrices as input.  It currently supports only the Euclidean
-distance metric.
+`scipy.sparse` matrices as input.  Arbitrary Minkowski metrics are supported
+for neighbors searches.
 
 
 Unsupervised Nearest Neighbors
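
For reference, a minimal sketch of the behaviour described above, assuming
the ``p`` keyword wired up in the rest of this patch (not part of the diff
itself):

    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    X = np.array([[0., 0.], [3., 4.]])
    query = np.array([[0., 0.]])
    for p in (1, 2):
        nn = NearestNeighbors(n_neighbors=2, p=p).fit(X)
        dist, ind = nn.kneighbors(query)
        # p=1 (manhattan): distances 0 and 3 + 4 = 7
        # p=2 (euclidean): distances 0 and sqrt(9 + 16) = 5
        print(p, dist[0])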
diff --git a/doc/tutorial/statistical_inference/supervised_learning.rst b/doc/tutorial/statistical_inference/supervised_learning.rst
index 5c1171fe063b1043573c72251ab0dc03be635d5b..1088214755d5e382307b1cc9357dadfcf9d5c246 100644
--- a/doc/tutorial/statistical_inference/supervised_learning.rst
+++ b/doc/tutorial/statistical_inference/supervised_learning.rst
@@ -95,7 +95,7 @@ Scikit-learn documentation for more information about this type of classifier.)
     >>> from sklearn.neighbors import KNeighborsClassifier
     >>> knn = KNeighborsClassifier()
     >>> knn.fit(iris_X_train, iris_y_train)
-    KNeighborsClassifier(algorithm='auto', leaf_size=30, n_neighbors=5,
+    KNeighborsClassifier(algorithm='auto', leaf_size=30, n_neighbors=5, p=2,
                warn_on_equidistant=True, weights='uniform')
     >>> knn.predict(iris_X_test)
     array([1, 2, 1, 0, 0, 0, 2, 1, 2, 0])
diff --git a/sklearn/neighbors/base.py b/sklearn/neighbors/base.py
index 87c3407c99b01423bcfe2df5720da9b29464ca14..87820d51f505b2984781955e92aaa4660d861bc9 100644
--- a/sklearn/neighbors/base.py
+++ b/sklearn/neighbors/base.py
@@ -13,7 +13,7 @@ from scipy.spatial.ckdtree import cKDTree
 
 from .ball_tree import BallTree
 from ..base import BaseEstimator
-from ..metrics import euclidean_distances
+from ..metrics import pairwise_distances
 from ..utils import safe_asarray, atleast2d_or_csr
 
 
@@ -71,15 +71,18 @@ class NeighborsBase(BaseEstimator):
     # rely on soon-to-be-updated functionality in the pairwise module.
     def _init_params(self, n_neighbors=None, radius=None,
                      algorithm='auto', leaf_size=30,
-                     warn_on_equidistant=True):
+                     warn_on_equidistant=True, p=2):
         self.n_neighbors = n_neighbors
         self.radius = radius
         self.algorithm = algorithm
         self.leaf_size = leaf_size
         self.warn_on_equidistant = warn_on_equidistant
+        self.p = p
 
         if algorithm not in ['auto', 'brute', 'kd_tree', 'ball_tree']:
             raise ValueError("unrecognized algorithm: '%s'" % algorithm)
+        if p < 1:
+            raise ValueError("p must be greater than or equal to 1")
 
         self._fit_X = None
         self._tree = None
@@ -131,7 +134,7 @@ class NeighborsBase(BaseEstimator):
         if self._fit_method == 'kd_tree':
             self._tree = cKDTree(X, self.leaf_size)
         elif self._fit_method == 'ball_tree':
-            self._tree = BallTree(X, self.leaf_size)
+            self._tree = BallTree(X, self.leaf_size, p=self.p)
         elif self._fit_method == 'brute':
             self._tree = None
         else:
@@ -202,7 +205,16 @@ class KNeighborsMixin(object):
             n_neighbors = self.n_neighbors
 
         if self._fit_method == 'brute':
-            dist = euclidean_distances(X, self._fit_X, squared=True)
+            if self.p == 1:
+                dist = pairwise_distances(X, self._fit_X, 'manhattan')
+            elif self.p == 2:
+                dist = pairwise_distances(X, self._fit_X, 'euclidean',
+                                          squared=True)
+            elif self.p == np.inf:
+                dist = pairwise_distances(X, self._fit_X, 'chebyshev')
+            else:
+                dist = pairwise_distances(X, self._fit_X, 'minkowski',
+                                          p=self.p)
             # XXX: should be implemented with a partial sort
             neigh_ind = dist.argsort(axis=1)
             if self.warn_on_equidistant and n_neighbors < self._fit_X.shape[0]:
@@ -214,7 +226,10 @@ class KNeighborsMixin(object):
             neigh_ind = neigh_ind[:, :n_neighbors]
             if return_distance:
                 j = np.arange(neigh_ind.shape[0])[:, None]
-                return np.sqrt(dist[j, neigh_ind]), neigh_ind
+                if self.p == 2:
+                    return np.sqrt(dist[j, neigh_ind]), neigh_ind
+                else:
+                    return dist[j, neigh_ind], neigh_ind
             else:
                 return neigh_ind
         elif self._fit_method == 'ball_tree':
@@ -224,7 +239,7 @@ class KNeighborsMixin(object):
                 warn_equidistant()
             return result
         elif self._fit_method == 'kd_tree':
-            dist, ind = self._tree.query(X, n_neighbors)
+            dist, ind = self._tree.query(X, n_neighbors, p=self.p)
             # kd_tree returns a 1D array for n_neighbors = 1
             if n_neighbors == 1:
                 dist = dist[:, None]
@@ -366,10 +381,19 @@ class RadiusNeighborsMixin(object):
             radius = self.radius
 
         if self._fit_method == 'brute':
-            dist = euclidean_distances(X, self._fit_X, squared=True)
-            rad2 = radius ** 2
+            if self.p == 1:
+                dist = pairwise_distances(X, self._fit_X, 'manhattan')
+            elif self.p == 2:
+                dist = pairwise_distances(X, self._fit_X, 'euclidean',
+                                          squared=True)
+                radius *= radius
+            elif self.p == np.inf:
+                dist = pairwise_distances(X, self._fit_X, 'chebyshev')
+            else:
+                dist = pairwise_distances(X, self._fit_X, 'minkowski',
+                                          p=self.p)
 
-            neigh_ind = [np.where(d < rad2)[0] for d in dist]
+            neigh_ind = [np.where(d < radius)[0] for d in dist]
 
             # if there are the same number of neighbors for each point,
             # we can do a normal array.  Otherwise, we return an object
@@ -382,9 +406,14 @@ class RadiusNeighborsMixin(object):
                 dtype_F = object
 
             if return_distance:
-                dist = np.array([np.sqrt(d[neigh_ind[i]]) \
-                                     for i, d in enumerate(dist)],
-                                dtype=dtype_F)
+                if self.p == 2:
+                    dist = np.array([np.sqrt(d[neigh_ind[i]])
+                                     for i, d in enumerate(dist)],
+                                    dtype=dtype_F)
+                else:
+                    dist = np.array([d[neigh_ind[i]]
+                                     for i, d in enumerate(dist)],
+                                    dtype=dtype_F)
                 return dist, neigh_ind
             else:
                 return neigh_ind
@@ -400,7 +429,8 @@ class RadiusNeighborsMixin(object):
         elif self._fit_method == 'kd_tree':
             Npts = self._fit_X.shape[0]
             dist, ind = self._tree.query(X, Npts,
-                                         distance_upper_bound=radius)
+                                         distance_upper_bound=radius,
+                                         p=self.p)
 
             ind = [ind_i[:ind_i.searchsorted(Npts)] for ind_i in ind]
 
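
The brute-force branches above fast-path p = 1, 2 and inf to dedicated
metrics and fall back to the generic Minkowski computation otherwise. A
quick consistency sketch of that dispatch, assuming pairwise_distances
accepts the 'chebyshev' and 'minkowski' metrics this patch relies on (see
the comment about the pairwise module above), with scipy's cdist as an
independent reference:

    import numpy as np
    from scipy.spatial.distance import cdist
    from sklearn.metrics import pairwise_distances

    rng = np.random.RandomState(0)
    X, Y = rng.rand(5, 3), rng.rand(4, 3)

    # p = 1 and p = inf map to dedicated metric names
    assert np.allclose(pairwise_distances(X, Y, 'manhattan'),
                       cdist(X, Y, 'cityblock'))
    assert np.allclose(pairwise_distances(X, Y, 'chebyshev'),
                       cdist(X, Y, 'chebyshev'))
    # p = 2 stays squared internally; the sqrt is applied only on return
    assert np.allclose(np.sqrt(pairwise_distances(X, Y, 'euclidean',
                                                  squared=True)),
                       cdist(X, Y, 'euclidean'))
    # any other p uses the generic Minkowski metric
    assert np.allclose(pairwise_distances(X, Y, 'minkowski', p=3),
                       cdist(X, Y, 'minkowski', p=3))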
diff --git a/sklearn/neighbors/classification.py b/sklearn/neighbors/classification.py
index f8ded2b1f87935046b6567e499d04169e4ee9f80..90aebfa35e5ecc4340339aaf25cbdfff2c47f243 100644
--- a/sklearn/neighbors/classification.py
+++ b/sklearn/neighbors/classification.py
@@ -68,6 +68,12 @@ class KNeighborsClassifier(NeighborsBase, KNeighborsMixin,
         ordering of the training data.
         If the fit method is ``'kd_tree'``, no warnings will be generated.
 
+    p : integer, optional (default = 2)
+        Parameter for the Minkowski metric from
+        sklearn.metrics.pairwise.pairwise_distances. When p = 1, this is
+        equivalent to using manhattan_distance (l1), and euclidean_distance
+        (l2) for p = 2. For arbitrary p, minkowski_distance (l_p) is used.
+
     Examples
     --------
     >>> X = [[0], [1], [2], [3]]
@@ -97,11 +103,12 @@ class KNeighborsClassifier(NeighborsBase, KNeighborsMixin,
     def __init__(self, n_neighbors=5,
                  weights='uniform',
                  algorithm='auto', leaf_size=30,
-                 warn_on_equidistant=True):
+                 warn_on_equidistant=True, p=2):
         self._init_params(n_neighbors=n_neighbors,
                           algorithm=algorithm,
                           leaf_size=leaf_size,
-                          warn_on_equidistant=warn_on_equidistant)
+                          warn_on_equidistant=warn_on_equidistant,
+                          p=p)
         self.weights = _check_weights(weights)
 
     def predict(self, X):
@@ -174,6 +181,12 @@ class RadiusNeighborsClassifier(NeighborsBase, RadiusNeighborsMixin,
         required to store the tree.  The optimal value depends on the
         nature of the problem.
 
+    p : integer, optional (default = 2)
+        Parameter for the Minkowski metric from
+        sklearn.metrics.pairwise.pairwise_distances. When p = 1, this is
+        equivalent to using manhattan_distance (l1), and euclidean_distance
+        (l2) for p = 2. For arbitrary p, minkowski_distance (l_p) is used.
+
     Examples
     --------
     >>> X = [[0], [1], [2], [3]]
@@ -201,10 +214,11 @@ class RadiusNeighborsClassifier(NeighborsBase, RadiusNeighborsMixin,
     """
 
     def __init__(self, radius=1.0, weights='uniform',
-                 algorithm='auto', leaf_size=30):
+                 algorithm='auto', leaf_size=30, p=2):
         self._init_params(radius=radius,
                           algorithm=algorithm,
-                          leaf_size=leaf_size)
+                          leaf_size=leaf_size,
+                          p=p)
         self.weights = _check_weights(weights)
 
     def predict(self, X):
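
A short usage sketch for the new keyword on the classifiers, mirroring the
docstring example above (the data is made up for illustration):

    from sklearn.neighbors import KNeighborsClassifier

    X = [[0], [1], [2], [3]]
    y = [0, 0, 1, 1]
    neigh = KNeighborsClassifier(n_neighbors=3, p=1)  # manhattan metric
    neigh.fit(X, y)
    # the three nearest points to 1.1 are 1, 2 and 0, with labels 0, 1, 0
    print(neigh.predict([[1.1]]))  # expected: [0]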
diff --git a/sklearn/neighbors/regression.py b/sklearn/neighbors/regression.py
index 9f918db7810892d9494fdd0ef4aea428dbf948b0..a1acd28e93640c228d7498e262be1d51af064f8b 100644
--- a/sklearn/neighbors/regression.py
+++ b/sklearn/neighbors/regression.py
@@ -70,6 +70,12 @@ class KNeighborsRegressor(NeighborsBase, KNeighborsMixin,
         ordering of the training data.
         If the fit method is ``'kd_tree'``, no warnings will be generated.
 
+    p : integer, optional (default = 2)
+        Parameter for the Minkowski metric from
+        sklearn.metrics.pairwise.pairwise_distances. When p = 1, this is
+        equivalent to using manhattan_distance (l1), and euclidean_distance
+        (l2) for p = 2. For arbitrary p, minkowski_distance (l_p) is used.
+
     Examples
     --------
     >>> X = [[0], [1], [2], [3]]
@@ -97,11 +103,13 @@ class KNeighborsRegressor(NeighborsBase, KNeighborsMixin,
     """
 
     def __init__(self, n_neighbors=5, weights='uniform',
-                 algorithm='auto', leaf_size=30, warn_on_equidistant=True):
+                 algorithm='auto', leaf_size=30, warn_on_equidistant=True,
+                 p=2):
         self._init_params(n_neighbors=n_neighbors,
                           algorithm=algorithm,
                           leaf_size=leaf_size,
-                          warn_on_equidistant=warn_on_equidistant)
+                          warn_on_equidistant=warn_on_equidistant,
+                          p=p)
         self.weights = _check_weights(weights)
 
     def predict(self, X):
@@ -177,6 +185,12 @@ class RadiusNeighborsRegressor(NeighborsBase, RadiusNeighborsMixin,
         required to store the tree.  The optimal value depends on the
         nature of the problem.
 
+    p : integer, optional (default = 2)
+        Parameter for the Minkowski metric from
+        sklearn.metrics.pairwise.pairwise_distances. When p = 1, this is
+        equivalent to using manhattan_distance (l1), and euclidean_distance
+        (l2) for p = 2. For arbitrary p, minkowski_distance (l_p) is used.
+
     Examples
     --------
     >>> X = [[0], [1], [2], [3]]
@@ -204,10 +218,11 @@ class RadiusNeighborsRegressor(NeighborsBase, RadiusNeighborsMixin,
     """
 
     def __init__(self, radius=1.0, weights='uniform',
-                 algorithm='auto', leaf_size=30):
+                 algorithm='auto', leaf_size=30, p=2):
         self._init_params(radius=radius,
                           algorithm=algorithm,
-                          leaf_size=leaf_size)
+                          leaf_size=leaf_size,
+                          p=p)
         self.weights = _check_weights(weights)
 
     def predict(self, X):
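
The same keyword on the regressors, as a sketch with made-up data:

    from sklearn.neighbors import KNeighborsRegressor

    X = [[0], [1], [2], [3]]
    y = [0., 0., 1., 1.]
    neigh = KNeighborsRegressor(n_neighbors=2, p=1)
    neigh.fit(X, y)
    # the two nearest points to 1.5 are 1 and 2, so the prediction is the
    # mean of their targets: (0. + 1.) / 2 = 0.5
    print(neigh.predict([[1.5]]))  # expected: [ 0.5]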
diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py
index 810d6358a9564e4005a91ab224d5495011fee7f7..86934f5353f2edd06cb84e9cbf4a342d31ca133a 100644
--- a/sklearn/neighbors/tests/test_neighbors.py
+++ b/sklearn/neighbors/tests/test_neighbors.py
@@ -19,6 +19,7 @@ SPARSE_TYPES = (bsr_matrix, coo_matrix, csc_matrix, csr_matrix, dok_matrix,
 SPARSE_OR_DENSE = SPARSE_TYPES + (np.asarray,)
 
 ALGORITHMS = ('ball_tree', 'brute', 'kd_tree', 'auto')
+P = (1, 2, 3, 4, np.inf)
 
 
 def test_warn_on_equidistant(n_samples=100, n_features=3, k=3):
@@ -73,21 +74,24 @@ def test_unsupervised_kneighbors(n_samples=20, n_features=5,
 
     test = rng.rand(n_query_pts, n_features)
 
-    results_nodist = []
-    results = []
 
-    for algorithm in ALGORITHMS:
-        neigh = neighbors.NearestNeighbors(n_neighbors=n_neighbors,
-                                           algorithm=algorithm)
-        neigh.fit(X)
+    for p in P:
+        results_nodist = []
+        results = []
+
+        for algorithm in ALGORITHMS:
+            neigh = neighbors.NearestNeighbors(n_neighbors=n_neighbors,
+                                               algorithm=algorithm,
+                                               p=p)
+            neigh.fit(X)
 
-        results_nodist.append(neigh.kneighbors(test, return_distance=False))
-        results.append(neigh.kneighbors(test, return_distance=True))
+            results_nodist.append(neigh.kneighbors(test, return_distance=False))
+            results.append(neigh.kneighbors(test, return_distance=True))
 
-    for i in range(len(results) - 1):
-        assert_array_almost_equal(results_nodist[i], results[i][1])
-        assert_array_almost_equal(results[i][0], results[i + 1][0])
-        assert_array_almost_equal(results[i][1], results[i + 1][1])
+        for i in range(len(results) - 1):
+            assert_array_almost_equal(results_nodist[i], results[i][1])
+            assert_array_almost_equal(results[i][0], results[i + 1][0])
+            assert_array_almost_equal(results[i][1], results[i + 1][1])
 
 
 def test_unsupervised_inputs():
@@ -119,32 +123,35 @@ def test_unsupervised_radius_neighbors(n_samples=20, n_features=5,
 
     test = rng.rand(n_query_pts, n_features)
 
-    results = []
-
-    for algorithm in ALGORITHMS:
-        neigh = neighbors.NearestNeighbors(radius=radius,
-                                           algorithm=algorithm)
-        neigh.fit(X)
-
-        ind1 = neigh.radius_neighbors(test, return_distance=False)
-
-        # sort the results: this is not done automatically for
-        # radius searches
-        dist, ind = neigh.radius_neighbors(test, return_distance=True)
-        for (d, i, i1) in zip(dist, ind, ind1):
-            j = d.argsort()
-            d[:] = d[j]
-            i[:] = i[j]
-        results.append((dist, ind))
-
-        assert_array_almost_equal(np.concatenate(list(ind)),
-                                  np.concatenate(list(ind1)))
-
-    for i in range(len(results) - 1):
-        assert_array_almost_equal(np.concatenate(list(results[i][0])),
-                                  np.concatenate(list(results[i + 1][0]))),
-        assert_array_almost_equal(np.concatenate(list(results[i][1])),
-                                  np.concatenate(list(results[i + 1][1])))
+    for p in P:
+        results = []
+
+        for algorithm in ALGORITHMS:
+            neigh = neighbors.NearestNeighbors(radius=radius,
+                                               algorithm=algorithm,
+                                               p=p)
+            neigh.fit(X)
+
+            ind1 = neigh.radius_neighbors(test, return_distance=False)
+
+            # sort the results: this is not done automatically for
+            # radius searches
+            dist, ind = neigh.radius_neighbors(test, return_distance=True)
+            for (d, i, i1) in zip(dist, ind, ind1):
+                j = d.argsort()
+                d[:] = d[j]
+                i[:] = i[j]
+                i1[:] = i1[j]
+            results.append((dist, ind))
+
+            assert_array_almost_equal(np.concatenate(list(ind)),
+                                      np.concatenate(list(ind1)))
+
+        for i in range(len(results) - 1):
+            assert_array_almost_equal(np.concatenate(list(results[i][0])),
+                                      np.concatenate(list(results[i + 1][0]))),
+            assert_array_almost_equal(np.concatenate(list(results[i][1])),
+                                      np.concatenate(list(results[i + 1][1])))
 
 
 def test_kneighbors_classifier(n_samples=40,
diff --git a/sklearn/neighbors/unsupervised.py b/sklearn/neighbors/unsupervised.py
index 6149bd69bd4a89ec980c4f859b28c1dd973220c2..0b53270f6ddb9f0c593e45104ab69fb6bd1af9b6 100644
--- a/sklearn/neighbors/unsupervised.py
+++ b/sklearn/neighbors/unsupervised.py
@@ -45,6 +45,12 @@ class NearestNeighbors(NeighborsBase, KNeighborsMixin,
         ordering of the training data.
         If the fit method is ``'kd_tree'``, no warnings will be generated.
 
+    p : integer, optional (default = 2)
+        Parameter for the Minkowski metric from
+        sklearn.metrics.pairwise.pairwise_distances. When p = 1, this is
+        equivalent to using manhattan_distance (l1), and euclidean_distance
+        (l2) for p = 2. For arbitrary p, minkowski_distance (l_p) is used.
+
     Examples
     --------
       >>> from sklearn.neighbors import NearestNeighbors
@@ -78,9 +84,10 @@ class NearestNeighbors(NeighborsBase, KNeighborsMixin,
 
     def __init__(self, n_neighbors=5, radius=1.0,
                  algorithm='auto', leaf_size=30,
-                 warn_on_equidistant=True):
+                 warn_on_equidistant=True, p=2):
         self._init_params(n_neighbors=n_neighbors,
                           radius=radius,
                           algorithm=algorithm,
                           leaf_size=leaf_size,
-                          warn_on_equidistant=warn_on_equidistant)
+                          warn_on_equidistant=warn_on_equidistant,
+                          p=p)
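
Finally, the p = inf case, which the brute-force path routes to the
chebyshev metric; a sketch assuming the keyword plumbing in this patch:

    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    X = np.array([[0., 0.], [3., 4.]])
    nn = NearestNeighbors(n_neighbors=2, algorithm='brute', p=np.inf)
    nn.fit(X)
    dist, ind = nn.kneighbors(np.array([[0., 0.]]))
    # chebyshev distance from [0, 0] to [3, 4] is max(3, 4) = 4
    print(dist[0])  # expected: [ 0.  4.]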