diff --git a/doc/developers/performance.rst b/doc/developers/performance.rst
index d962348e6d9cbf9628b2306292574d611c2a9c8a..df92089dd345b36f623d395c8e100bda3d80c254 100644
--- a/doc/developers/performance.rst
+++ b/doc/developers/performance.rst
@@ -174,17 +174,42 @@ order to better understand the profile of this specific function, let
 us install ``line-prof`` and wire it to IPython::
 
   $ pip install line-profiler
-  $ vim ~/.ipython/ipy_user_conf.py
 
-Ensure the following lines are present::
+- **Under IPython <= 0.10**, edit ``~/.ipython/ipy_user_conf.py`` and
+  ensure the following lines are present::
 
-  import IPython.ipapi
-  ip = IPython.ipapi.get()
+    import IPython.ipapi
+    ip = IPython.ipapi.get()
 
-Towards the end of the file, define the ``%lprun`` magic::
+  Towards the end of the file, define the ``%lprun`` magic::
 
-  import line_profiler
-  ip.expose_magic('lprun', line_profiler.magic_lprun)
+    import line_profiler
+    ip.expose_magic('lprun', line_profiler.magic_lprun)
+
+- **Under IPython 0.11+**, first create a configuration profile::
+
+    $ ipython profile create
+
+  Then create a file named ``~/.ipython/extensions/line_profile_ext`` with
+  the following content::
+
+    import line_profiler
+
+    def load_ipython_extension(ip):
+        ip.define_magic('lprun', line_profiler.magic_lprun)
+
+  Then register it in ``~/.ipython/profile_default/ipython_config.py``::
+
+    c.TerminalIPythonApp.extensions = [
+        'line_profiler_ext',
+    ]
+
+  This will register the ``%lprun`` magic command in the IPython terminal
+  client.
+
+  You can do a similar operation ``ipython_notebook_config.py`` and
+  ``ipython_qtconsole_config`` to register the same extensions for the
+  HTML notebook and qtconsole clients.
 
 Now restart IPython and let us use this new toy::
 
diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 76eb410a37d9215ec87b0b2b13885db135ac1d90..ab472c206221fcb8da83af6cc1a34634ee392586 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -325,7 +325,10 @@ Pairwise metrics
    metrics.pairwise.linear_kernel
    metrics.pairwise.polynomial_kernel
    metrics.pairwise.rbf_kernel
-
+   metrics.pairwise.distance_metrics
+   metrics.pairwise.pairwise_distances
+   metrics.pairwise.kernel_metrics
+   metrics.pairwise.pairwise_kernels
 
 Covariance Estimators
 =====================
diff --git a/sklearn/gaussian_process/gaussian_process.py b/sklearn/gaussian_process/gaussian_process.py
index 1639396e99fa170a6e2bdd0db20f91aac67feba2..7f7b1dc1e6ef8b108fd8401b64e82cdc4b0c9b47 100644
--- a/sklearn/gaussian_process/gaussian_process.py
+++ b/sklearn/gaussian_process/gaussian_process.py
@@ -9,7 +9,7 @@ import numpy as np
 from scipy import linalg, optimize, rand
 
 from ..base import BaseEstimator, RegressorMixin
-from ..metrics.pairwise import l1_distances
+from ..metrics.pairwise import manhattan_distances
 from . import regression_models as regression
 from . import correlation_models as correlation
 
@@ -413,8 +413,7 @@ class GaussianProcess(BaseEstimator, RegressorMixin):
                 MSE = np.zeros(n_eval)
 
             # Get pairwise componentwise L1-distances to the input training set
-            dx = l1_distances(X, self.X)
-
+            dx = manhattan_distances(X, Y=self.X, sum_over_features=False)
             # Get regression function and correlation
             f = self.regr(X)
             r = self.corr(self.theta, dx).reshape(n_eval, n_samples)
diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py
index 4ed164bef2a26a9e729c6d6f55b3fe6e56bc15a5..4eff13b4a6472b9153d584050c62b2aa71718070 100644
--- a/sklearn/metrics/__init__.py
+++ b/sklearn/metrics/__init__.py
@@ -14,4 +14,4 @@ from .cluster import homogeneity_completeness_v_measure
 from .cluster import homogeneity_score
 from .cluster import completeness_score
 from .cluster import v_measure_score
-from .pairwise import euclidean_distances, pairwise_distances
+from .pairwise import euclidean_distances, pairwise_distances, pairwise_kernels
diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py
index a85e8ce7e8f5ab2b83cd44330e7969dcf78ef257..70bfd87adaa5cbcdaddf0a825a5b55fa81124db5 100644
--- a/sklearn/metrics/metrics.py
+++ b/sklearn/metrics/metrics.py
@@ -749,5 +749,6 @@ def hinge_loss(y_true, pred_decision, pos_label=1, neg_label=-1):
 
     margin = y_true * pred_decision
     losses = 1 - margin
-    losses[losses <= 0] = 0  # hinge doesn't penalize good enough predictions
+    # The hinge doesn't penalize good enough predictions.
+    losses[losses <= 0] = 0
     return np.mean(losses)
diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py
index 2c513bbdb86dfb261d0f15e1298d92ec75e29d7d..7fb9a791a0435ca45fef62b9d911194158a5d4d2 100644
--- a/sklearn/metrics/pairwise.py
+++ b/sklearn/metrics/pairwise.py
@@ -1,7 +1,36 @@
-"""Utilities to evaluate pairwise distances or affinity of sets of samples"""
+""" Utilities to evaluate pairwise distances or affinity of sets of samples.
+
+This module contains both distance metrics and kernels. A brief summary is
+given on the two here.
+
+Distance metrics are a function d(a, b) such that d(a, b) < d(a, c) if objects
+a and b are considered "more similar" to objects a and c. Two objects exactly
+alike would have a distance of zero.
+One of the most popular examples is Euclidean distance.
+To be a 'true' metric, it must obey the following four conditions:
+
+1. d(a, b) >= 0, for all a and b
+2. d(a, b) == 0, if and only if a = b, positive definiteness
+3. d(a, b) == d(b, a), symmetry
+4. d(a, c) <= d(a, b) + d(b, c), the triangle inequality
+
+
+Kernels are measures of similarity, i.e. s(a, b) > s(a, c) if objects a and b
+are considered "more similar" to objects a and c. A kernel must also be
+positive semi-definite.
+
+There are a number of ways to convert between a distance metric and a similarity
+measure, such as a kernel. Let D be the distance, and S be the kernel:
+
+1. S = np.exp(-D * gamma), where one heuristic for choosing
+   gamma is 1 / num_features
+2. S = 1. / (D / np.max(D))
+
+"""
 
 # Authors: Alexandre Gramfort <alexandre.gramfort@inria.fr>
 #          Mathieu Blondel <mathieu@mblondel.org>
+#          Robert Layton <robertlayton@gmail.com>
 # License: BSD Style.
 
 import numpy as np
@@ -10,61 +39,50 @@ from scipy.sparse import csr_matrix, issparse
 from ..utils import safe_asanyarray, atleast2d_or_csr, deprecated
 from ..utils.extmath import safe_sparse_dot
 
-################################################################################
-# Distances
 
-def pairwise_distances(X, Y=None, metric="euclidean"):
-    """ Calculates the distance matrix from a vector matrix X.
+# Utility Functions
+def check_pairwise_arrays(X, Y):
+    """ Set X and Y appropriately and checks inputs
 
-    This method takes either a vector array or a distance matrix, and returns
-    a distance matrix. If the input is a vector array, the distances are
-    computed. If the input is a distances matrix, it is returned instead.
+    If Y is None, it is set as a pointer to X (i.e. not a copy).
+    If Y is given, this does not happen.
+    All distance metrics should use this function first to assert that the
+    given parameters are correct and safe to use.
 
-    This method provides a safe way to take a distance matrix as input, while
-    preserving compatability with many other algorithms that take a vector
-    array.
+    Specifically, this function first ensures that both X and Y are arrays,
+    then checkes that they are at least two dimensional. Finally, the function
+    checks that the size of the second dimension of the two arrays is equal.
 
     Parameters
     ----------
-    X: array [n_samples, n_samples] if metric == "precomputed", or,
-             [n_samples, n_features] otherwise
-        Array of pairwise distances between samples, or a feature array.
+    X: {array-like, sparse matrix}, shape = [n_samples_a, n_features]
 
-    X: array [n_samples, n_features]
-        A second feature array only if X has shape [n_samples, n_features].
-
-    metric: string, or callable
-        The metric to use when calculating distance between instances in a
-        feature array. If metric is a string, it must be one of the options
-        allowed by scipy.spatial.distance.pdist for its metric parameter.
-        If metric is "precomputed", X is assumed to be a distance matrix and
-        must be square.
-        Alternatively, if metric is a callable function, it is called on each
-        pair of instances (rows) and the resulting value recorded. The callable
-        should take two arrays from X as input and return a value indicating
-        the distance between them.
+    Y: {array-like, sparse matrix}, shape = [n_samples_b, n_features]
 
     Returns
     -------
-    D: array [n_samples, n_samples]
-        A distance matrix D such that D_{i, j} is the distance between the
-        ith and jth vectors of the given matrix X.
-
-    """
-    if metric == "precomputed":
-        if X.shape[0] != X.shape[1]:
-            raise ValueError("X is not square!")
-        return X
+    safe_X: {array-like, sparse matrix}, shape = [n_samples_a, n_features]
+        An array equal to X, guarenteed to be a numpy array.
 
-    elif metric == "euclidean":
-        return euclidean_distances(X, Y)
+    safe_Y: {array-like, sparse matrix}, shape = [n_samples_b, n_features]
+        An array equal to Y if Y was not None, guarenteed to be a numpy array.
+        If Y was None, safe_Y will be a pointer to X.
 
+    """
+    if Y is X or Y is None:
+        X = Y = safe_asanyarray(X)
     else:
-        # FIXME: the distance module doesn't support sparse matrices!
-        if Y is None:
-            return distance.squareform(distance.pdist(X, metric=metric))
-        else:
-            return distance.cdist(X, Y, metric=metric)
+        X = safe_asanyarray(X)
+        Y = safe_asanyarray(Y)
+    X = atleast2d_or_csr(X)
+    Y = atleast2d_or_csr(Y)
+    if len(X.shape) < 2:
+        raise ValueError("X is required to be at least two dimensional.")
+    if len(Y.shape) < 2:
+        raise ValueError("Y is required to be at least two dimensional.")
+    if X.shape[1] != Y.shape[1]:
+        raise ValueError("Incompatible dimension for X and Y matrices")
+    return X, Y
 
 
 # Distances
@@ -115,15 +133,7 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
     # should not need X_norm_squared because if you could precompute that as
     # well as Y, then you should just pre-compute the output and not even
     # call this function.
-    if Y is X or Y is None:
-        X = Y = safe_asanyarray(X)
-    else:
-        X = safe_asanyarray(X)
-        Y = safe_asanyarray(Y)
-
-    if X.shape[1] != Y.shape[1]:
-        raise ValueError("Incompatible dimension for X and Y matrices")
-
+    X, Y = check_pairwise_arrays(X, Y)
     if issparse(X):
         XX = X.multiply(X).sum(axis=1)
     else:
@@ -167,56 +177,67 @@ def euclidian_distances(*args, **kwargs):
     return euclidean_distances(*args, **kwargs)
 
 
+def manhattan_distances(X, Y=None, sum_over_features=True):
+    """ Compute the L1 distances between the vectors in X and Y.
 
-def l1_distances(X, Y):
-    """
-    Computes the componentwise L1 pairwise-distances between the vectors
-    in X and Y.
+    With sum_over_features equal to False it returns the componentwise
+    distances.
 
     Parameters
     ----------
     X: array_like
-        An array with shape (n_samples_X, n_features)
+        An array with shape (n_samples_X, n_features).
 
     Y: array_like, optional
         An array with shape (n_samples_Y, n_features).
 
+    sum_over_features: bool, default=True
+        If True the function returns the pairwise distance matrix
+        else it returns the componentwise L1 pairwise-distances.
+
     Returns
     -------
-    D: array with shape (n_samples_X * n_samples_Y, n_features)
-        The array of componentwise L1 pairwise-distances.
+    D: array
+        If sum_over_features is False shape is
+        (n_samples_X * n_samples_Y, n_features) and D contains the
+        componentwise L1 pairwise-distances (ie. absolute difference),
+        else shape is (n_samples_X, n_samples_Y) and D contains
+        the pairwise l1 distances.
 
     Examples
     --------
-    >>> from sklearn.metrics.pairwise import l1_distances
-    >>> l1_distances(3, 3)
+    >>> from sklearn.metrics.pairwise import manhattan_distances
+    >>> manhattan_distances(3, 3)
     array([[0]])
-    >>> l1_distances(3, 2)
+    >>> manhattan_distances(3, 2)
     array([[1]])
-    >>> l1_distances(2, 3)
+    >>> manhattan_distances(2, 3)
     array([[1]])
+    >>> manhattan_distances([[1, 2], [3, 4]], [[1, 2], [0, 3]])
+    array([[0, 2],
+           [4, 4]])
     >>> import numpy as np
     >>> X = np.ones((1, 2))
-    >>> y = 2*np.ones((2, 2))
-    >>> l1_distances(X, y)
+    >>> y = 2 * np.ones((2, 2))
+    >>> manhattan_distances(X, y, sum_over_features=False)
     array([[ 1.,  1.],
            [ 1.,  1.]])
     """
-    X, Y = np.atleast_2d(X), np.atleast_2d(Y)
+    X, Y = check_pairwise_arrays(X, Y)
     n_samples_X, n_features_X = X.shape
     n_samples_Y, n_features_Y = Y.shape
     if n_features_X != n_features_Y:
         raise Exception("X and Y should have the same number of features!")
-    else:
-        n_features = n_features_X
     D = np.abs(X[:, np.newaxis, :] - Y[np.newaxis, :, :])
-    D = D.reshape((n_samples_X * n_samples_Y, n_features))
-
+    if sum_over_features:
+        D = np.sum(D, axis=2)
+    else:
+        D = D.reshape((n_samples_X * n_samples_Y, n_features_X))
     return D
 
 
 # Kernels
-def linear_kernel(X, Y):
+def linear_kernel(X, Y=None):
     """
     Compute the linear kernel between X and Y.
 
@@ -230,10 +251,11 @@ def linear_kernel(X, Y):
     -------
     Gram matrix: array of shape (n_samples_1, n_samples_2)
     """
+    X, Y = check_pairwise_arrays(X, Y)
     return safe_sparse_dot(X, Y.T, dense_output=True)
 
 
-def polynomial_kernel(X, Y, degree=3, gamma=0, coef0=1):
+def polynomial_kernel(X, Y=None, degree=3, gamma=0, coef0=1):
     """
     Compute the polynomial kernel between X and Y.
 
@@ -251,6 +273,7 @@ def polynomial_kernel(X, Y, degree=3, gamma=0, coef0=1):
     -------
     Gram matrix: array of shape (n_samples_1, n_samples_2)
     """
+    X, Y = check_pairwise_arrays(X, Y)
     if gamma == 0:
         gamma = 1.0 / X.shape[1]
 
@@ -261,7 +284,7 @@ def polynomial_kernel(X, Y, degree=3, gamma=0, coef0=1):
     return K
 
 
-def sigmoid_kernel(X, Y, gamma=0, coef0=1):
+def sigmoid_kernel(X, Y=None, gamma=0, coef0=1):
     """
     Compute the sigmoid kernel between X and Y.
 
@@ -279,6 +302,7 @@ def sigmoid_kernel(X, Y, gamma=0, coef0=1):
     -------
     Gram matrix: array of shape (n_samples_1, n_samples_2)
     """
+    X, Y = check_pairwise_arrays(X, Y)
     if gamma == 0:
         gamma = 1.0 / X.shape[1]
 
@@ -289,7 +313,7 @@ def sigmoid_kernel(X, Y, gamma=0, coef0=1):
     return K
 
 
-def rbf_kernel(X, Y, gamma=0):
+def rbf_kernel(X, Y=None, gamma=0):
     """
     Compute the rbf (gaussian) kernel between X and Y.
 
@@ -307,6 +331,7 @@ def rbf_kernel(X, Y, gamma=0):
     -------
     Gram matrix: array of shape (n_samples_1, n_samples_2)
     """
+    X, Y = check_pairwise_arrays(X, Y)
     if gamma == 0:
         gamma = 1.0 / X.shape[1]
 
@@ -315,3 +340,246 @@ def rbf_kernel(X, Y, gamma=0):
     np.exp(K, K)    # exponentiate K in-place
     return K
 
+
+# Helper functions - distance
+pairwise_distance_functions = {
+    # If updating this dictionary, update the doc in both distance_metrics()
+    # and also in pairwise_distances()!
+    'euclidean': euclidean_distances,
+    'l2': euclidean_distances,
+    'l1': manhattan_distances,
+    'manhattan': manhattan_distances,
+    'cityblock': manhattan_distances
+    }
+
+
+def distance_metrics():
+    """ Valid metrics for pairwise_distances
+
+    This function simply returns the valid pairwise distance metrics.
+    It exists, however, to allow for a verbose description of the mapping for
+    each of the valid strings.
+
+    The valid distance metrics, and the function they map to, are:
+      ===========     ====================================
+      metric          Function
+      ===========     ====================================
+      'cityblock'     sklearn.pairwise.manhattan_distances
+      'euclidean'     sklearn.pairwise.euclidean_distances
+      'l1'            sklearn.pairwise.manhattan_distances
+      'l2'            sklearn.pairwise.euclidean_distances
+      'manhattan'     sklearn.pairwise.manhattan_distances
+      ===========     ====================================
+    """
+    return pairwise_distance_functions
+
+
+def pairwise_distances(X, Y=None, metric="euclidean", **kwds):
+    """ Compute the distance matrix from a vector array X and optional Y.
+
+    This method takes either a vector array or a distance matrix, and returns
+    a distance matrix. If the input is a vector array, the distances are
+    computed. If the input is a distances matrix, it is returned instead.
+
+    This method provides a safe way to take a distance matrix as input, while
+    preserving compatability with many other algorithms that take a vector
+    array.
+
+    If Y is given (default is None), then the returned matrix is the pairwise
+    distance between the arrays from both X and Y.
+
+    Please note that support for sparse matrices is currently limited to those
+    metrics listed in pairwise.pairwise_distance_functions.
+
+    Valid values for metric are:
+    - from scikits.learn: ['euclidean', 'l2', 'l1', 'manhattan', 'cityblock']
+    - from scipy.spatial.distance: ['braycurtis', 'canberra', 'chebyshev',
+      'correlation', 'cosine', 'dice', 'hamming', 'jaccard', 'kulsinski',
+      'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao',
+      'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeucludean', 'yule']
+      See the documentation for scipy.spatial.distance for details on these
+      metrics.
+    Note in the case of 'euclidean' and 'cityblock' (which are valid
+    scipy.spatial.distance metrics), the values will use the scikits.learn
+    implementation, which is faster and has support for sparse matrices.
+    For a verbose description of the metrics from scikits.learn, see the
+    __doc__ of the sklearn.pairwise.distance_metrics function.
+
+    Parameters
+    ----------
+    X: array [n_samples_a, n_samples_a] if metric == "precomputed", or,
+             [n_samples_a, n_features] otherwise
+        Array of pairwise distances between samples, or a feature array.
+
+    Y: array [n_samples_b, n_features]
+        A second feature array only if X has shape [n_samples_a, n_features].
+
+    metric: string, or callable
+        The metric to use when calculating distance between instances in a
+        feature array. If metric is a string, it must be one of the options
+        allowed by scipy.spatial.distance.pdist for its metric parameter, or
+        a metric listed in pairwise.pairwise_distance_functions.
+        If metric is "precomputed", X is assumed to be a distance matrix and
+        must be square.
+        Alternatively, if metric is a callable function, it is called on each
+        pair of instances (rows) and the resulting value recorded. The callable
+        should take two arrays from X as input and return a value indicating
+        the distance between them.
+
+    **kwds: optional keyword parameters
+        Any further parameters are passed directly to the distance function.
+        If using a scipy.spatial.distance metric, the parameters are still
+        metric dependent. See the scipy docs for usage examples.
+
+    Returns
+    -------
+    D: array [n_samples_a, n_samples_a] or [n_samples_a, n_samples_b]
+        A distance matrix D such that D_{i, j} is the distance between the
+        ith and jth vectors of the given matrix X, if Y is None.
+        If Y is not None, then D_{i, j} is the distance between the ith array
+        from X and the jth array from Y.
+
+    """
+    if metric == "precomputed":
+        if X.shape[0] != X.shape[1]:
+            raise ValueError("X is not square!")
+        return X
+    elif metric in pairwise_distance_functions:
+        return pairwise_distance_functions[metric](X, Y, **kwds)
+    elif callable(metric):
+        # Check matrices first (this is usually done by the metric).
+        X, Y = check_pairwise_arrays(X, Y)
+        n_x, n_y = X.shape[0], Y.shape[0]
+        # Calculate distance for each element in X and Y.
+        D = np.zeros((n_x, n_y), dtype='float')
+        for i in range(n_x):
+            start = 0
+            if X is Y:
+                start = i
+            for j in range(start, n_y):
+                # Kernel assumed to be symmetric.
+                D[i][j] = metric(X[i], Y[j], **kwds)
+                if X is Y:
+                    D[j][i] = D[i][j]
+        return D
+    else:
+        # Note: the distance module doesn't support sparse matrices!
+        if type(X) is csr_matrix:
+            raise TypeError("scipy distance metrics do not"
+                            " support sparse matrices.")
+        if Y is None:
+            return distance.squareform(distance.pdist(X, metric=metric,
+                                                      **kwds))
+        else:
+            if type(Y) is csr_matrix:
+                raise TypeError("scipy distance metrics do not"
+                                " support sparse matrices.")
+            return distance.cdist(X, Y, metric=metric, **kwds)
+
+
+# Helper functions - distance
+pairwise_kernel_functions = {
+    # If updating this dictionary, update the doc in both distance_metrics()
+    # and also in pairwise_distances()!
+    'rbf': rbf_kernel,
+    'sigmoid': sigmoid_kernel,
+    'polynomial': polynomial_kernel,
+    'poly': polynomial_kernel,
+    'linear': linear_kernel
+    }
+
+
+def kernel_metrics():
+    """ Valid metrics for pairwise_kernels
+
+    This function simply returns the valid pairwise distance metrics.
+    It exists, however, to allow for a verbose description of the mapping for
+    each of the valid strings.
+
+    The valid distance metrics, and the function they map to, are:
+      ============   ==================================
+      metric         Function
+      ============   ==================================
+      'linear'       sklearn.pairwise.linear_kernel
+      'poly'         sklearn.pairwise.polynomial_kernel
+      'polynomial'   sklearn.pairwise.polynomial_kernel
+      'rbf'          sklearn.pairwise.rbf_kernel
+      'sigmoid'      sklearn.pairwise.sigmoid_kernel
+      ============   ==================================
+    """
+    return pairwise_kernel_functions
+
+
+def pairwise_kernels(X, Y=None, metric="linear", **kwds):
+    """ Compute the kernel between arrays X and optional array Y.
+
+    This method takes either a vector array or a kernel matrix, and returns
+    a kernel matrix. If the input is a vector array, the kernels are
+    computed. If the input is a kernel matrix, it is returned instead.
+
+    This method provides a safe way to take a kernel matrix as input, while
+    preserving compatability with many other algorithms that take a vector
+    array.
+
+    If Y is given (default is None), then the returned matrix is the pairwise
+    kernel between the arrays from both X and Y.
+
+    Valid values for metric are:
+    ['rbf', 'sigmoid', 'polynomial', 'poly', 'linear']
+
+    Parameters
+    ----------
+    X: array [n_samples_a, n_samples_a] if metric == "precomputed", or,
+             [n_samples_a, n_features] otherwise
+        Array of pairwise kernels between samples, or a feature array.
+
+    Y: array [n_samples_b, n_features]
+        A second feature array only if X has shape [n_samples_a, n_features].
+
+    metric: string, or callable
+        The metric to use when calculating kernel between instances in a
+        feature array. If metric is a string, it must be one of the metrics
+        in pairwise.pairwise_kernel_functions.
+        If metric is "precomputed", X is assumed to be a kernel matrix and
+        must be square.
+        Alternatively, if metric is a callable function, it is called on each
+        pair of instances (rows) and the resulting value recorded. The callable
+        should take two arrays from X as input and return a value indicating
+        the distance between them.
+
+    **kwds: optional keyword parameters
+        Any further parameters are passed directly to the kernel function.
+
+    Returns
+    -------
+    K: array [n_samples_a, n_samples_a] or [n_samples_a, n_samples_b]
+        A kernel matrix K such that K_{i, j} is the kernel between the
+        ith and jth vectors of the given matrix X, if Y is None.
+        If Y is not None, then K_{i, j} is the kernel between the ith array
+        from X and the jth array from Y.
+
+    """
+    if metric == "precomputed":
+        if X.shape[0] != X.shape[1]:
+            raise ValueError("X is not square!")
+        return X
+    elif metric in pairwise_kernel_functions:
+        return pairwise_kernel_functions[metric](X, Y, **kwds)
+    elif callable(metric):
+        # Check matrices first (this is usually done by the metric).
+        X, Y = check_pairwise_arrays(X, Y)
+        n_x, n_y = X.shape[0], Y.shape[0]
+        # Calculate kernel for each element in X and Y.
+        K = np.zeros((n_x, n_y), dtype='float')
+        for i in range(n_x):
+            start = 0
+            if X is Y:
+                start = i
+            for j in range(start, n_y):
+                # Kernel assumed to be symmetric.
+                K[i][j] = metric(X[i], Y[j], **kwds)
+                if X is Y:
+                    K[j][i] = K[i][j]
+        return K
+    else:
+        raise AttributeError("Unknown metric %s" % metric)
diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py
index 9bcd6159573a0a7217e15cf90bfd5b29b6eb996b..329dce0674ee822299ef3717cbfdb962c226936d 100644
--- a/sklearn/metrics/tests/test_pairwise.py
+++ b/sklearn/metrics/tests/test_pairwise.py
@@ -1,46 +1,123 @@
 import numpy as np
 from numpy import linalg
 from numpy.testing import assert_array_almost_equal
-from numpy.testing import assert_array_equal
 from numpy.testing import assert_equal
 from nose.tools import assert_raises
 from nose.tools import assert_true
 from scipy.sparse import csr_matrix
+from scipy.spatial.distance import cosine, cityblock, minkowski
 
-from ..pairwise import euclidean_distances, linear_kernel, polynomial_kernel, \
-                       rbf_kernel, sigmoid_kernel, pairwise_distances
+from ..pairwise import (euclidean_distances, linear_kernel, polynomial_kernel,
+                        rbf_kernel, sigmoid_kernel)
+from .. import pairwise_distances, pairwise_kernels
+from ..pairwise import pairwise_kernel_functions
+from ..pairwise import check_pairwise_arrays
 
 np.random.seed(0)
 
 
 def test_pairwise_distances():
+    """ Test the pairwise_distance helper function. """
     rng = np.random.RandomState(0)
+    # Euclidean distance should be equivalent to calling the function.
     X = rng.random_sample((5, 4))
     S = pairwise_distances(X, metric="euclidean")
     S2 = euclidean_distances(X)
-    assert_array_equal(S, S2)
-
-    X2 = rng.random_sample((2, 4))
-    S = pairwise_distances(X, X2, metric="euclidean")
-    S2 = euclidean_distances(X, X2)
-    assert_array_equal(S, S2)
-
+    assert_array_almost_equal(S, S2)
+    # Euclidean distance, with Y != X.
+    Y = rng.random_sample((2, 4))
+    S = pairwise_distances(X, Y, metric="euclidean")
+    S2 = euclidean_distances(X, Y)
+    assert_array_almost_equal(S, S2)
+    # Test with tuples as X and Y
+    X_tuples = tuple([tuple([v for v in row]) for row in X])
+    Y_tuples = tuple([tuple([v for v in row]) for row in Y])
+    S2 = pairwise_distances(X_tuples, Y_tuples, metric="euclidean")
+    assert_array_almost_equal(S, S2)
+    # "cityblock" uses sklearn metric, cityblock (function) is scipy.spatial.
     S = pairwise_distances(X, metric="cityblock")
+    S2 = pairwise_distances(X, metric=cityblock)
     assert_equal(S.shape[0], S.shape[1])
     assert_equal(S.shape[0], X.shape[0])
-
-    S = pairwise_distances(X, X2, metric="cityblock")
+    assert_array_almost_equal(S, S2)
+    # The manhattan metric should be equivalent to cityblock.
+    S = pairwise_distances(X, Y, metric="manhattan")
+    S2 = pairwise_distances(X, Y, metric=cityblock)
     assert_equal(S.shape[0], X.shape[0])
-    assert_equal(S.shape[1], X2.shape[0])
-
+    assert_equal(S.shape[1], Y.shape[0])
+    assert_array_almost_equal(S, S2)
+    # Test cosine as a string metric versus cosine callable
+    S = pairwise_distances(X, Y, metric="cosine")
+    S2 = pairwise_distances(X, Y, metric=cosine)
+    assert_equal(S.shape[0], X.shape[0])
+    assert_equal(S.shape[1], Y.shape[0])
+    assert_array_almost_equal(S, S2)
+    # Tests that precomputed metric returns pointer to, and not copy of, X.
     S = np.dot(X, X.T)
     S2 = pairwise_distances(S, metric="precomputed")
     assert_true(S is S2)
     assert_raises(ValueError, pairwise_distances, X, None, "precomputed")
+    # Test with sparse X and Y
+    X_sparse = csr_matrix(X)
+    Y_sparse = csr_matrix(Y)
+    S = pairwise_distances(X_sparse, Y_sparse, metric="euclidean")
+    S2 = euclidean_distances(X_sparse, Y_sparse)
+    assert_array_almost_equal(S, S2)
+    # Test with scipy.spatial.distance metric, with a kwd
+    kwds = {"p":2.0}
+    S = pairwise_distances(X, Y, metric="minkowski", **kwds)
+    S2 = pairwise_distances(X, Y, metric=minkowski, **kwds)
+    assert_array_almost_equal(S, S2)
+    # Test that scipy distance metrics throw an error if sparse matrix given
+    assert_raises(TypeError, pairwise_distances, X_sparse, metric="minkowski")
+    assert_raises(TypeError, pairwise_distances, X, Y_sparse,
+                  metric="minkowski")
+    
+
+def test_pairwise_kernels():
+    """ Test the pairwise_kernels helper function. """
+    rng = np.random.RandomState(0)
+    X = rng.random_sample((5, 4))
+    Y = rng.random_sample((2, 4))
+    # Test with all metrics that should be in pairwise_kernel_functions.
+    test_metrics = ["rbf", "sigmoid", "polynomial", "linear"]
+    for metric in test_metrics:
+        function = pairwise_kernel_functions[metric]
+        # Test with Y=None
+        K1 = pairwise_kernels(X, metric=metric)
+        K2 = function(X)
+        assert_array_almost_equal(K1, K2)
+        # Test with Y=Y
+        K1 = pairwise_kernels(X, Y=Y, metric=metric)
+        K2 = function(X, Y=Y)
+        assert_array_almost_equal(K1, K2)
+        # Test with tuples as X and Y
+        X_tuples = tuple([tuple([v for v in row]) for row in X])
+        Y_tuples = tuple([tuple([v for v in row]) for row in Y])
+        K2 = pairwise_kernels(X_tuples, Y_tuples, metric=metric)
+        assert_array_almost_equal(K1, K2)
+        # Test with sparse X and Y
+        X_sparse = csr_matrix(X)
+        Y_sparse = csr_matrix(Y)
+        K1 = pairwise_kernels(X_sparse, Y=Y_sparse, metric=metric)
+        assert_array_almost_equal(K1, K2)
+    # Test with a callable function, with given keywords.
+    metric = callable_rbf_kernel
+    kwds = {}
+    kwds['gamma'] = 0.
+    K1 = pairwise_kernels(X, Y=Y, metric=metric, **kwds)
+    K2 = rbf_kernel(X, Y=Y, **kwds)
+    assert_array_almost_equal(K1, K2)
+
+
+def callable_rbf_kernel(x, y, **kwds):
+    """ Callable version of pairwise.rbf_kernel. """
+    K = rbf_kernel(np.atleast_2d(x), np.atleast_2d(y), **kwds)
+    return K
 
 
 def test_euclidean_distances():
-    """Check the pairwise Euclidean distances computation"""
+    """ Check the pairwise Euclidean distances computation"""
     X = [[0]]
     Y = [[1], [2]]
     D = euclidean_distances(X, Y)
@@ -53,7 +130,7 @@ def test_euclidean_distances():
 
 
 def test_kernel_symmetry():
-    """valid kernels should be symmetric"""
+    """ Valid kernels should be symmetric"""
     rng = np.random.RandomState(0)
     X = rng.random_sample((5, 4))
     for kernel in (linear_kernel, polynomial_kernel, rbf_kernel,
@@ -87,3 +164,75 @@ def test_rbf_kernel():
     K = rbf_kernel(X, X)
     # the diagonal elements of a rbf kernel are 1
     assert_array_almost_equal(K.flat[::6], np.ones(5))
+
+
+def test_check_dense_matrices():
+    """ Ensure that pairwise array check works for dense matrices."""
+    # Check that if XB is None, XB is returned as reference to XA
+    XA = np.resize(np.arange(40), (5, 8))
+    XA_checked, XB_checked = check_pairwise_arrays(XA, None)
+    assert_true(XA_checked is XB_checked)
+    assert_equal(XA, XA_checked)
+
+
+def test_check_XB_returned():
+    """ Ensure that if XA and XB are given correctly, they return as equal."""
+    # Check that if XB is not None, it is returned equal.
+    # Note that the second dimension of XB is the same as XA.
+    XA = np.resize(np.arange(40), (5, 8))
+    XB = np.resize(np.arange(32), (4, 8))
+    XA_checked, XB_checked = check_pairwise_arrays(XA, XB)
+    assert_equal(XA, XA_checked)
+    assert_equal(XB, XB_checked)
+
+
+def test_check_different_dimensions():
+    """ Ensure an error is raised if the dimensions are different. """
+    XA = np.resize(np.arange(45), (5, 9))
+    XB = np.resize(np.arange(32), (4, 8))
+    assert_raises(ValueError, check_pairwise_arrays, XA, XB)
+
+
+def test_check_invalid_dimensions():
+    """ Ensure an error is raised on 1D input arrays. """
+    XA = np.arange(45)
+    XB = np.resize(np.arange(32), (4, 8))
+    assert_raises(ValueError, check_pairwise_arrays, XA, XB)
+    XA = np.resize(np.arange(45), (5, 9))
+    XB = np.arange(32)
+    assert_raises(ValueError, check_pairwise_arrays, XA, XB)
+
+
+def test_check_sparse_arrays():
+    """ Ensures that checks return valid sparse matrices. """
+    rng = np.random.RandomState(0)
+    XA = rng.random_sample((5, 4))
+    XA_sparse = csr_matrix(XA)
+    XB = rng.random_sample((5, 4))
+    XB_sparse = csr_matrix(XB)
+    XA_checked, XB_checked = check_pairwise_arrays(XA_sparse, XB_sparse)
+    assert_equal(XA_sparse, XA_checked)
+    assert_equal(XB_sparse, XB_checked)
+
+
+def tuplify(X):
+    """ Turns a numpy matrix (any n-dimensional array) into tuples."""
+    s = X.shape
+    if len(s) > 1:
+        # Tuplify each sub-array in the input.
+        return tuple(tuplify(row) for row in X)
+    else:
+        # Single dimension input, just return tuple of contents.
+        return tuple(r for r in X)
+
+
+def test_check_tuple_input():
+    """ Ensures that checks return valid tuples. """
+    rng = np.random.RandomState(0)
+    XA = rng.random_sample((5, 4))
+    XA_tuples = tuplify(XA)
+    XB = rng.random_sample((5, 4))
+    XB_tuples = tuplify(XB)
+    XA_checked, XB_checked = check_pairwise_arrays(XA_tuples, XB_tuples)
+    assert_equal(XA_tuples, XA_checked)
+    assert_equal(XB_tuples, XB_checked)