diff --git a/doc/developers/performance.rst b/doc/developers/performance.rst index d962348e6d9cbf9628b2306292574d611c2a9c8a..df92089dd345b36f623d395c8e100bda3d80c254 100644 --- a/doc/developers/performance.rst +++ b/doc/developers/performance.rst @@ -174,17 +174,42 @@ order to better understand the profile of this specific function, let us install ``line-prof`` and wire it to IPython:: $ pip install line-profiler - $ vim ~/.ipython/ipy_user_conf.py -Ensure the following lines are present:: +- **Under IPython <= 0.10**, edit ``~/.ipython/ipy_user_conf.py`` and + ensure the following lines are present:: - import IPython.ipapi - ip = IPython.ipapi.get() + import IPython.ipapi + ip = IPython.ipapi.get() -Towards the end of the file, define the ``%lprun`` magic:: + Towards the end of the file, define the ``%lprun`` magic:: - import line_profiler - ip.expose_magic('lprun', line_profiler.magic_lprun) + import line_profiler + ip.expose_magic('lprun', line_profiler.magic_lprun) + +- **Under IPython 0.11+**, first create a configuration profile:: + + $ ipython profile create + + Then create a file named ``~/.ipython/extensions/line_profiler_ext.py`` with + the following content:: + + import line_profiler + + def load_ipython_extension(ip): + ip.define_magic('lprun', line_profiler.magic_lprun) + + Then register it in ``~/.ipython/profile_default/ipython_config.py``:: + + c.TerminalIPythonApp.extensions = [ + 'line_profiler_ext', + ] + + This will register the ``%lprun`` magic command in the IPython terminal + client. + + You can do a similar operation in ``ipython_notebook_config.py`` and + ``ipython_qtconsole_config.py`` to register the same extensions for the + HTML notebook and qtconsole clients.
Now restart IPython and let us use this new toy:: diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 76eb410a37d9215ec87b0b2b13885db135ac1d90..ab472c206221fcb8da83af6cc1a34634ee392586 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -325,7 +325,10 @@ Pairwise metrics metrics.pairwise.linear_kernel metrics.pairwise.polynomial_kernel metrics.pairwise.rbf_kernel - + metrics.pairwise.distance_metrics + metrics.pairwise.pairwise_distances + metrics.pairwise.kernel_metrics + metrics.pairwise.pairwise_kernels Covariance Estimators ===================== diff --git a/sklearn/gaussian_process/gaussian_process.py b/sklearn/gaussian_process/gaussian_process.py index 1639396e99fa170a6e2bdd0db20f91aac67feba2..7f7b1dc1e6ef8b108fd8401b64e82cdc4b0c9b47 100644 --- a/sklearn/gaussian_process/gaussian_process.py +++ b/sklearn/gaussian_process/gaussian_process.py @@ -9,7 +9,7 @@ import numpy as np from scipy import linalg, optimize, rand from ..base import BaseEstimator, RegressorMixin -from ..metrics.pairwise import l1_distances +from ..metrics.pairwise import manhattan_distances from . import regression_models as regression from . 
import correlation_models as correlation @@ -413,8 +413,7 @@ class GaussianProcess(BaseEstimator, RegressorMixin): MSE = np.zeros(n_eval) # Get pairwise componentwise L1-distances to the input training set - dx = l1_distances(X, self.X) - + dx = manhattan_distances(X, Y=self.X, sum_over_features=False) # Get regression function and correlation f = self.regr(X) r = self.corr(self.theta, dx).reshape(n_eval, n_samples) diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index 4ed164bef2a26a9e729c6d6f55b3fe6e56bc15a5..4eff13b4a6472b9153d584050c62b2aa71718070 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -14,4 +14,4 @@ from .cluster import homogeneity_completeness_v_measure from .cluster import homogeneity_score from .cluster import completeness_score from .cluster import v_measure_score -from .pairwise import euclidean_distances, pairwise_distances +from .pairwise import euclidean_distances, pairwise_distances, pairwise_kernels diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index a85e8ce7e8f5ab2b83cd44330e7969dcf78ef257..70bfd87adaa5cbcdaddf0a825a5b55fa81124db5 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -749,5 +749,6 @@ def hinge_loss(y_true, pred_decision, pos_label=1, neg_label=-1): margin = y_true * pred_decision losses = 1 - margin - losses[losses <= 0] = 0 # hinge doesn't penalize good enough predictions + # The hinge doesn't penalize good enough predictions. + losses[losses <= 0] = 0 return np.mean(losses) diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 2c513bbdb86dfb261d0f15e1298d92ec75e29d7d..7fb9a791a0435ca45fef62b9d911194158a5d4d2 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -1,7 +1,36 @@ -"""Utilities to evaluate pairwise distances or affinity of sets of samples""" +""" Utilities to evaluate pairwise distances or affinity of sets of samples. 
+ +This module contains both distance metrics and kernels. A brief summary is +given on the two here. + +Distance metrics are a function d(a, b) such that d(a, b) < d(a, c) if objects +a and b are considered "more similar" than objects a and c. Two objects exactly +alike would have a distance of zero. +One of the most popular examples is Euclidean distance. +To be a 'true' metric, it must obey the following four conditions: + +1. d(a, b) >= 0, for all a and b +2. d(a, b) == 0, if and only if a = b, positive definiteness +3. d(a, b) == d(b, a), symmetry +4. d(a, c) <= d(a, b) + d(b, c), the triangle inequality + + +Kernels are measures of similarity, i.e. s(a, b) > s(a, c) if objects a and b +are considered "more similar" than objects a and c. A kernel must also be +positive semi-definite. + +There are a number of ways to convert between a distance metric and a similarity +measure, such as a kernel. Let D be the distance, and S be the kernel: + +1. S = np.exp(-D * gamma), where one heuristic for choosing + gamma is 1 / num_features +2. S = 1. / (D / np.max(D)) + +""" # Authors: Alexandre Gramfort <alexandre.gramfort@inria.fr> # Mathieu Blondel <mathieu@mblondel.org> +# Robert Layton <robertlayton@gmail.com> # License: BSD Style. import numpy as np @@ -10,61 +39,50 @@ from scipy.sparse import csr_matrix, issparse from ..utils import safe_asanyarray, atleast2d_or_csr, deprecated from ..utils.extmath import safe_sparse_dot -################################################################################ -# Distances -def pairwise_distances(X, Y=None, metric="euclidean"): - """ Calculates the distance matrix from a vector matrix X. - This method takes either a vector array or a distance matrix, and returns - a distance matrix. If the input is a vector array, the distances are - computed. If the input is a distances matrix, it is returned instead.
+ If Y is None, it is set as a pointer to X (i.e. not a copy). + If Y is given, this does not happen. + All distance metrics should use this function first to assert that the + given parameters are correct and safe to use. - This method provides a safe way to take a distance matrix as input, while - preserving compatability with many other algorithms that take a vector - array. + Specifically, this function first ensures that both X and Y are arrays, + then checks that they are at least two dimensional. Finally, the function + checks that the size of the second dimension of the two arrays is equal. Parameters ---------- - X: array [n_samples, n_samples] if metric == "precomputed", or, - [n_samples, n_features] otherwise - Array of pairwise distances between samples, or a feature array. + X: {array-like, sparse matrix}, shape = [n_samples_a, n_features] - X: array [n_samples, n_features] - A second feature array only if X has shape [n_samples, n_features]. - - metric: string, or callable - The metric to use when calculating distance between instances in a - feature array. If metric is a string, it must be one of the options - allowed by scipy.spatial.distance.pdist for its metric parameter. - If metric is "precomputed", X is assumed to be a distance matrix and - must be square. - Alternatively, if metric is a callable function, it is called on each - pair of instances (rows) and the resulting value recorded. The callable - should take two arrays from X as input and return a value indicating - the distance between them. + Y: {array-like, sparse matrix}, shape = [n_samples_b, n_features] Returns ------- - D: array [n_samples, n_samples] - A distance matrix D such that D_{i, j} is the distance between the - ith and jth vectors of the given matrix X.
- - """ - if metric == "precomputed": - if X.shape[0] != X.shape[1]: - raise ValueError("X is not square!") - return X + safe_X: {array-like, sparse matrix}, shape = [n_samples_a, n_features] + An array equal to X, guarenteed to be a numpy array. - elif metric == "euclidean": - return euclidean_distances(X, Y) + safe_Y: {array-like, sparse matrix}, shape = [n_samples_b, n_features] + An array equal to Y if Y was not None, guarenteed to be a numpy array. + If Y was None, safe_Y will be a pointer to X. + """ + if Y is X or Y is None: + X = Y = safe_asanyarray(X) else: - # FIXME: the distance module doesn't support sparse matrices! - if Y is None: - return distance.squareform(distance.pdist(X, metric=metric)) - else: - return distance.cdist(X, Y, metric=metric) + X = safe_asanyarray(X) + Y = safe_asanyarray(Y) + X = atleast2d_or_csr(X) + Y = atleast2d_or_csr(Y) + if len(X.shape) < 2: + raise ValueError("X is required to be at least two dimensional.") + if len(Y.shape) < 2: + raise ValueError("Y is required to be at least two dimensional.") + if X.shape[1] != Y.shape[1]: + raise ValueError("Incompatible dimension for X and Y matrices") + return X, Y # Distances @@ -115,15 +133,7 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False): # should not need X_norm_squared because if you could precompute that as # well as Y, then you should just pre-compute the output and not even # call this function. - if Y is X or Y is None: - X = Y = safe_asanyarray(X) - else: - X = safe_asanyarray(X) - Y = safe_asanyarray(Y) - - if X.shape[1] != Y.shape[1]: - raise ValueError("Incompatible dimension for X and Y matrices") - + X, Y = check_pairwise_arrays(X, Y) if issparse(X): XX = X.multiply(X).sum(axis=1) else: @@ -167,56 +177,67 @@ def euclidian_distances(*args, **kwargs): return euclidean_distances(*args, **kwargs) +def manhattan_distances(X, Y=None, sum_over_features=True): + """ Compute the L1 distances between the vectors in X and Y. 
-def l1_distances(X, Y): - """ - Computes the componentwise L1 pairwise-distances between the vectors - in X and Y. + With sum_over_features equal to False it returns the componentwise + distances. Parameters ---------- X: array_like - An array with shape (n_samples_X, n_features) + An array with shape (n_samples_X, n_features). Y: array_like, optional An array with shape (n_samples_Y, n_features). + sum_over_features: bool, default=True + If True the function returns the pairwise distance matrix + else it returns the componentwise L1 pairwise-distances. + Returns ------- - D: array with shape (n_samples_X * n_samples_Y, n_features) - The array of componentwise L1 pairwise-distances. + D: array + If sum_over_features is False shape is + (n_samples_X * n_samples_Y, n_features) and D contains the + componentwise L1 pairwise-distances (ie. absolute difference), + else shape is (n_samples_X, n_samples_Y) and D contains + the pairwise l1 distances. Examples -------- - >>> from sklearn.metrics.pairwise import l1_distances - >>> l1_distances(3, 3) + >>> from sklearn.metrics.pairwise import manhattan_distances + >>> manhattan_distances(3, 3) array([[0]]) - >>> l1_distances(3, 2) + >>> manhattan_distances(3, 2) array([[1]]) - >>> l1_distances(2, 3) + >>> manhattan_distances(2, 3) array([[1]]) + >>> manhattan_distances([[1, 2], [3, 4]], [[1, 2], [0, 3]]) + array([[0, 2], + [4, 4]]) >>> import numpy as np >>> X = np.ones((1, 2)) - >>> y = 2*np.ones((2, 2)) - >>> l1_distances(X, y) + >>> y = 2 * np.ones((2, 2)) + >>> manhattan_distances(X, y, sum_over_features=False) array([[ 1., 1.], [ 1., 1.]]) """ - X, Y = np.atleast_2d(X), np.atleast_2d(Y) + X, Y = check_pairwise_arrays(X, Y) n_samples_X, n_features_X = X.shape n_samples_Y, n_features_Y = Y.shape if n_features_X != n_features_Y: raise Exception("X and Y should have the same number of features!") - else: - n_features = n_features_X D = np.abs(X[:, np.newaxis, :] - Y[np.newaxis, :, :]) - D = D.reshape((n_samples_X * 
n_samples_Y, n_features)) - + if sum_over_features: + D = np.sum(D, axis=2) + else: + D = D.reshape((n_samples_X * n_samples_Y, n_features_X)) return D # Kernels -def linear_kernel(X, Y): +def linear_kernel(X, Y=None): """ Compute the linear kernel between X and Y. @@ -230,10 +251,11 @@ def linear_kernel(X, Y): ------- Gram matrix: array of shape (n_samples_1, n_samples_2) """ + X, Y = check_pairwise_arrays(X, Y) return safe_sparse_dot(X, Y.T, dense_output=True) -def polynomial_kernel(X, Y, degree=3, gamma=0, coef0=1): +def polynomial_kernel(X, Y=None, degree=3, gamma=0, coef0=1): """ Compute the polynomial kernel between X and Y. @@ -251,6 +273,7 @@ def polynomial_kernel(X, Y, degree=3, gamma=0, coef0=1): ------- Gram matrix: array of shape (n_samples_1, n_samples_2) """ + X, Y = check_pairwise_arrays(X, Y) if gamma == 0: gamma = 1.0 / X.shape[1] @@ -261,7 +284,7 @@ def polynomial_kernel(X, Y, degree=3, gamma=0, coef0=1): return K -def sigmoid_kernel(X, Y, gamma=0, coef0=1): +def sigmoid_kernel(X, Y=None, gamma=0, coef0=1): """ Compute the sigmoid kernel between X and Y. @@ -279,6 +302,7 @@ def sigmoid_kernel(X, Y, gamma=0, coef0=1): ------- Gram matrix: array of shape (n_samples_1, n_samples_2) """ + X, Y = check_pairwise_arrays(X, Y) if gamma == 0: gamma = 1.0 / X.shape[1] @@ -289,7 +313,7 @@ def sigmoid_kernel(X, Y, gamma=0, coef0=1): return K -def rbf_kernel(X, Y, gamma=0): +def rbf_kernel(X, Y=None, gamma=0): """ Compute the rbf (gaussian) kernel between X and Y. @@ -307,6 +331,7 @@ def rbf_kernel(X, Y, gamma=0): ------- Gram matrix: array of shape (n_samples_1, n_samples_2) """ + X, Y = check_pairwise_arrays(X, Y) if gamma == 0: gamma = 1.0 / X.shape[1] @@ -315,3 +340,246 @@ def rbf_kernel(X, Y, gamma=0): np.exp(K, K) # exponentiate K in-place return K + +# Helper functions - distance +pairwise_distance_functions = { + # If updating this dictionary, update the doc in both distance_metrics() + # and also in pairwise_distances()! 
+ 'euclidean': euclidean_distances, + 'l2': euclidean_distances, + 'l1': manhattan_distances, + 'manhattan': manhattan_distances, + 'cityblock': manhattan_distances + } + + +def distance_metrics(): + """ Valid metrics for pairwise_distances + + This function simply returns the valid pairwise distance metrics. + It exists, however, to allow for a verbose description of the mapping for + each of the valid strings. + + The valid distance metrics, and the function they map to, are: + =========== ==================================== + metric Function + =========== ==================================== + 'cityblock' sklearn.pairwise.manhattan_distances + 'euclidean' sklearn.pairwise.euclidean_distances + 'l1' sklearn.pairwise.manhattan_distances + 'l2' sklearn.pairwise.euclidean_distances + 'manhattan' sklearn.pairwise.manhattan_distances + =========== ==================================== + """ + return pairwise_distance_functions + + +def pairwise_distances(X, Y=None, metric="euclidean", **kwds): + """ Compute the distance matrix from a vector array X and optional Y. + + This method takes either a vector array or a distance matrix, and returns + a distance matrix. If the input is a vector array, the distances are + computed. If the input is a distances matrix, it is returned instead. + + This method provides a safe way to take a distance matrix as input, while + preserving compatability with many other algorithms that take a vector + array. + + If Y is given (default is None), then the returned matrix is the pairwise + distance between the arrays from both X and Y. + + Please note that support for sparse matrices is currently limited to those + metrics listed in pairwise.pairwise_distance_functions. 
+ + Valid values for metric are: + - from scikits.learn: ['euclidean', 'l2', 'l1', 'manhattan', 'cityblock'] + - from scipy.spatial.distance: ['braycurtis', 'canberra', 'chebyshev', + 'correlation', 'cosine', 'dice', 'hamming', 'jaccard', 'kulsinski', + 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', + 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'yule'] + See the documentation for scipy.spatial.distance for details on these + metrics. + Note in the case of 'euclidean' and 'cityblock' (which are valid + scipy.spatial.distance metrics), the values will use the scikits.learn + implementation, which is faster and has support for sparse matrices. + For a verbose description of the metrics from scikits.learn, see the + __doc__ of the sklearn.metrics.pairwise.distance_metrics function. + + Parameters + ---------- + X: array [n_samples_a, n_samples_a] if metric == "precomputed", or, + [n_samples_a, n_features] otherwise + Array of pairwise distances between samples, or a feature array. + + Y: array [n_samples_b, n_features] + A second feature array only if X has shape [n_samples_a, n_features]. + + metric: string, or callable + The metric to use when calculating distance between instances in a + feature array. If metric is a string, it must be one of the options + allowed by scipy.spatial.distance.pdist for its metric parameter, or + a metric listed in pairwise.pairwise_distance_functions. + If metric is "precomputed", X is assumed to be a distance matrix and + must be square. + Alternatively, if metric is a callable function, it is called on each + pair of instances (rows) and the resulting value recorded. The callable + should take two arrays from X as input and return a value indicating + the distance between them. + + **kwds: optional keyword parameters + Any further parameters are passed directly to the distance function. + If using a scipy.spatial.distance metric, the parameters are still + metric dependent.
See the scipy docs for usage examples. + + Returns + ------- + D: array [n_samples_a, n_samples_a] or [n_samples_a, n_samples_b] + A distance matrix D such that D_{i, j} is the distance between the + ith and jth vectors of the given matrix X, if Y is None. + If Y is not None, then D_{i, j} is the distance between the ith array + from X and the jth array from Y. + + """ + if metric == "precomputed": + if X.shape[0] != X.shape[1]: + raise ValueError("X is not square!") + return X + elif metric in pairwise_distance_functions: + return pairwise_distance_functions[metric](X, Y, **kwds) + elif callable(metric): + # Check matrices first (this is usually done by the metric). + X, Y = check_pairwise_arrays(X, Y) + n_x, n_y = X.shape[0], Y.shape[0] + # Calculate distance for each element in X and Y. + D = np.zeros((n_x, n_y), dtype='float') + for i in range(n_x): + start = 0 + if X is Y: + start = i + for j in range(start, n_y): + # Kernel assumed to be symmetric. + D[i][j] = metric(X[i], Y[j], **kwds) + if X is Y: + D[j][i] = D[i][j] + return D + else: + # Note: the distance module doesn't support sparse matrices! + if type(X) is csr_matrix: + raise TypeError("scipy distance metrics do not" + " support sparse matrices.") + if Y is None: + return distance.squareform(distance.pdist(X, metric=metric, + **kwds)) + else: + if type(Y) is csr_matrix: + raise TypeError("scipy distance metrics do not" + " support sparse matrices.") + return distance.cdist(X, Y, metric=metric, **kwds) + + +# Helper functions - distance +pairwise_kernel_functions = { + # If updating this dictionary, update the doc in both distance_metrics() + # and also in pairwise_distances()! + 'rbf': rbf_kernel, + 'sigmoid': sigmoid_kernel, + 'polynomial': polynomial_kernel, + 'poly': polynomial_kernel, + 'linear': linear_kernel + } + + +def kernel_metrics(): + """ Valid metrics for pairwise_kernels + + This function simply returns the valid pairwise distance metrics. 
+ It exists, however, to allow for a verbose description of the mapping for + each of the valid strings. + + The valid distance metrics, and the function they map to, are: + ============ ================================== + metric Function + ============ ================================== + 'linear' sklearn.pairwise.linear_kernel + 'poly' sklearn.pairwise.polynomial_kernel + 'polynomial' sklearn.pairwise.polynomial_kernel + 'rbf' sklearn.pairwise.rbf_kernel + 'sigmoid' sklearn.pairwise.sigmoid_kernel + ============ ================================== + """ + return pairwise_kernel_functions + + +def pairwise_kernels(X, Y=None, metric="linear", **kwds): + """ Compute the kernel between arrays X and optional array Y. + + This method takes either a vector array or a kernel matrix, and returns + a kernel matrix. If the input is a vector array, the kernels are + computed. If the input is a kernel matrix, it is returned instead. + + This method provides a safe way to take a kernel matrix as input, while + preserving compatability with many other algorithms that take a vector + array. + + If Y is given (default is None), then the returned matrix is the pairwise + kernel between the arrays from both X and Y. + + Valid values for metric are: + ['rbf', 'sigmoid', 'polynomial', 'poly', 'linear'] + + Parameters + ---------- + X: array [n_samples_a, n_samples_a] if metric == "precomputed", or, + [n_samples_a, n_features] otherwise + Array of pairwise kernels between samples, or a feature array. + + Y: array [n_samples_b, n_features] + A second feature array only if X has shape [n_samples_a, n_features]. + + metric: string, or callable + The metric to use when calculating kernel between instances in a + feature array. If metric is a string, it must be one of the metrics + in pairwise.pairwise_kernel_functions. + If metric is "precomputed", X is assumed to be a kernel matrix and + must be square. 
+ Alternatively, if metric is a callable function, it is called on each + pair of instances (rows) and the resulting value recorded. The callable + should take two arrays from X as input and return a value indicating + the distance between them. + + **kwds: optional keyword parameters + Any further parameters are passed directly to the kernel function. + + Returns + ------- + K: array [n_samples_a, n_samples_a] or [n_samples_a, n_samples_b] + A kernel matrix K such that K_{i, j} is the kernel between the + ith and jth vectors of the given matrix X, if Y is None. + If Y is not None, then K_{i, j} is the kernel between the ith array + from X and the jth array from Y. + + """ + if metric == "precomputed": + if X.shape[0] != X.shape[1]: + raise ValueError("X is not square!") + return X + elif metric in pairwise_kernel_functions: + return pairwise_kernel_functions[metric](X, Y, **kwds) + elif callable(metric): + # Check matrices first (this is usually done by the metric). + X, Y = check_pairwise_arrays(X, Y) + n_x, n_y = X.shape[0], Y.shape[0] + # Calculate kernel for each element in X and Y. + K = np.zeros((n_x, n_y), dtype='float') + for i in range(n_x): + start = 0 + if X is Y: + start = i + for j in range(start, n_y): + # Kernel assumed to be symmetric. 
+ K[i][j] = metric(X[i], Y[j], **kwds) + if X is Y: + K[j][i] = K[i][j] + return K + else: + raise AttributeError("Unknown metric %s" % metric) diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py index 9bcd6159573a0a7217e15cf90bfd5b29b6eb996b..329dce0674ee822299ef3717cbfdb962c226936d 100644 --- a/sklearn/metrics/tests/test_pairwise.py +++ b/sklearn/metrics/tests/test_pairwise.py @@ -1,46 +1,123 @@ import numpy as np from numpy import linalg from numpy.testing import assert_array_almost_equal -from numpy.testing import assert_array_equal from numpy.testing import assert_equal from nose.tools import assert_raises from nose.tools import assert_true from scipy.sparse import csr_matrix +from scipy.spatial.distance import cosine, cityblock, minkowski -from ..pairwise import euclidean_distances, linear_kernel, polynomial_kernel, \ - rbf_kernel, sigmoid_kernel, pairwise_distances +from ..pairwise import (euclidean_distances, linear_kernel, polynomial_kernel, + rbf_kernel, sigmoid_kernel) +from .. import pairwise_distances, pairwise_kernels +from ..pairwise import pairwise_kernel_functions +from ..pairwise import check_pairwise_arrays np.random.seed(0) def test_pairwise_distances(): + """ Test the pairwise_distance helper function. """ rng = np.random.RandomState(0) + # Euclidean distance should be equivalent to calling the function. X = rng.random_sample((5, 4)) S = pairwise_distances(X, metric="euclidean") S2 = euclidean_distances(X) - assert_array_equal(S, S2) - - X2 = rng.random_sample((2, 4)) - S = pairwise_distances(X, X2, metric="euclidean") - S2 = euclidean_distances(X, X2) - assert_array_equal(S, S2) - + assert_array_almost_equal(S, S2) + # Euclidean distance, with Y != X. 
+ Y = rng.random_sample((2, 4)) + S = pairwise_distances(X, Y, metric="euclidean") + S2 = euclidean_distances(X, Y) + assert_array_almost_equal(S, S2) + # Test with tuples as X and Y + X_tuples = tuple([tuple([v for v in row]) for row in X]) + Y_tuples = tuple([tuple([v for v in row]) for row in Y]) + S2 = pairwise_distances(X_tuples, Y_tuples, metric="euclidean") + assert_array_almost_equal(S, S2) + # "cityblock" uses sklearn metric, cityblock (function) is scipy.spatial. S = pairwise_distances(X, metric="cityblock") + S2 = pairwise_distances(X, metric=cityblock) assert_equal(S.shape[0], S.shape[1]) assert_equal(S.shape[0], X.shape[0]) - - S = pairwise_distances(X, X2, metric="cityblock") + assert_array_almost_equal(S, S2) + # The manhattan metric should be equivalent to cityblock. + S = pairwise_distances(X, Y, metric="manhattan") + S2 = pairwise_distances(X, Y, metric=cityblock) assert_equal(S.shape[0], X.shape[0]) - assert_equal(S.shape[1], X2.shape[0]) - + assert_equal(S.shape[1], Y.shape[0]) + assert_array_almost_equal(S, S2) + # Test cosine as a string metric versus cosine callable + S = pairwise_distances(X, Y, metric="cosine") + S2 = pairwise_distances(X, Y, metric=cosine) + assert_equal(S.shape[0], X.shape[0]) + assert_equal(S.shape[1], Y.shape[0]) + assert_array_almost_equal(S, S2) + # Tests that precomputed metric returns pointer to, and not copy of, X. 
S = np.dot(X, X.T) S2 = pairwise_distances(S, metric="precomputed") assert_true(S is S2) assert_raises(ValueError, pairwise_distances, X, None, "precomputed") + # Test with sparse X and Y + X_sparse = csr_matrix(X) + Y_sparse = csr_matrix(Y) + S = pairwise_distances(X_sparse, Y_sparse, metric="euclidean") + S2 = euclidean_distances(X_sparse, Y_sparse) + assert_array_almost_equal(S, S2) + # Test with scipy.spatial.distance metric, with a kwd + kwds = {"p":2.0} + S = pairwise_distances(X, Y, metric="minkowski", **kwds) + S2 = pairwise_distances(X, Y, metric=minkowski, **kwds) + assert_array_almost_equal(S, S2) + # Test that scipy distance metrics throw an error if sparse matrix given + assert_raises(TypeError, pairwise_distances, X_sparse, metric="minkowski") + assert_raises(TypeError, pairwise_distances, X, Y_sparse, + metric="minkowski") + + +def test_pairwise_kernels(): + """ Test the pairwise_kernels helper function. """ + rng = np.random.RandomState(0) + X = rng.random_sample((5, 4)) + Y = rng.random_sample((2, 4)) + # Test with all metrics that should be in pairwise_kernel_functions. + test_metrics = ["rbf", "sigmoid", "polynomial", "linear"] + for metric in test_metrics: + function = pairwise_kernel_functions[metric] + # Test with Y=None + K1 = pairwise_kernels(X, metric=metric) + K2 = function(X) + assert_array_almost_equal(K1, K2) + # Test with Y=Y + K1 = pairwise_kernels(X, Y=Y, metric=metric) + K2 = function(X, Y=Y) + assert_array_almost_equal(K1, K2) + # Test with tuples as X and Y + X_tuples = tuple([tuple([v for v in row]) for row in X]) + Y_tuples = tuple([tuple([v for v in row]) for row in Y]) + K2 = pairwise_kernels(X_tuples, Y_tuples, metric=metric) + assert_array_almost_equal(K1, K2) + # Test with sparse X and Y + X_sparse = csr_matrix(X) + Y_sparse = csr_matrix(Y) + K1 = pairwise_kernels(X_sparse, Y=Y_sparse, metric=metric) + assert_array_almost_equal(K1, K2) + # Test with a callable function, with given keywords. 
+ metric = callable_rbf_kernel + kwds = {} + kwds['gamma'] = 0. + K1 = pairwise_kernels(X, Y=Y, metric=metric, **kwds) + K2 = rbf_kernel(X, Y=Y, **kwds) + assert_array_almost_equal(K1, K2) + + +def callable_rbf_kernel(x, y, **kwds): + """ Callable version of pairwise.rbf_kernel. """ + K = rbf_kernel(np.atleast_2d(x), np.atleast_2d(y), **kwds) + return K def test_euclidean_distances(): - """Check the pairwise Euclidean distances computation""" + """ Check the pairwise Euclidean distances computation""" X = [[0]] Y = [[1], [2]] D = euclidean_distances(X, Y) @@ -53,7 +130,7 @@ def test_euclidean_distances(): def test_kernel_symmetry(): - """valid kernels should be symmetric""" + """ Valid kernels should be symmetric""" rng = np.random.RandomState(0) X = rng.random_sample((5, 4)) for kernel in (linear_kernel, polynomial_kernel, rbf_kernel, @@ -87,3 +164,75 @@ def test_rbf_kernel(): K = rbf_kernel(X, X) # the diagonal elements of a rbf kernel are 1 assert_array_almost_equal(K.flat[::6], np.ones(5)) + + +def test_check_dense_matrices(): + """ Ensure that pairwise array check works for dense matrices.""" + # Check that if XB is None, XB is returned as reference to XA + XA = np.resize(np.arange(40), (5, 8)) + XA_checked, XB_checked = check_pairwise_arrays(XA, None) + assert_true(XA_checked is XB_checked) + assert_equal(XA, XA_checked) + + +def test_check_XB_returned(): + """ Ensure that if XA and XB are given correctly, they return as equal.""" + # Check that if XB is not None, it is returned equal. + # Note that the second dimension of XB is the same as XA. + XA = np.resize(np.arange(40), (5, 8)) + XB = np.resize(np.arange(32), (4, 8)) + XA_checked, XB_checked = check_pairwise_arrays(XA, XB) + assert_equal(XA, XA_checked) + assert_equal(XB, XB_checked) + + +def test_check_different_dimensions(): + """ Ensure an error is raised if the dimensions are different. 
""" + XA = np.resize(np.arange(45), (5, 9)) + XB = np.resize(np.arange(32), (4, 8)) + assert_raises(ValueError, check_pairwise_arrays, XA, XB) + + +def test_check_invalid_dimensions(): + """ Ensure an error is raised on 1D input arrays. """ + XA = np.arange(45) + XB = np.resize(np.arange(32), (4, 8)) + assert_raises(ValueError, check_pairwise_arrays, XA, XB) + XA = np.resize(np.arange(45), (5, 9)) + XB = np.arange(32) + assert_raises(ValueError, check_pairwise_arrays, XA, XB) + + +def test_check_sparse_arrays(): + """ Ensures that checks return valid sparse matrices. """ + rng = np.random.RandomState(0) + XA = rng.random_sample((5, 4)) + XA_sparse = csr_matrix(XA) + XB = rng.random_sample((5, 4)) + XB_sparse = csr_matrix(XB) + XA_checked, XB_checked = check_pairwise_arrays(XA_sparse, XB_sparse) + assert_equal(XA_sparse, XA_checked) + assert_equal(XB_sparse, XB_checked) + + +def tuplify(X): + """ Turns a numpy matrix (any n-dimensional array) into tuples.""" + s = X.shape + if len(s) > 1: + # Tuplify each sub-array in the input. + return tuple(tuplify(row) for row in X) + else: + # Single dimension input, just return tuple of contents. + return tuple(r for r in X) + + +def test_check_tuple_input(): + """ Ensures that checks return valid tuples. """ + rng = np.random.RandomState(0) + XA = rng.random_sample((5, 4)) + XA_tuples = tuplify(XA) + XB = rng.random_sample((5, 4)) + XB_tuples = tuplify(XB) + XA_checked, XB_checked = check_pairwise_arrays(XA_tuples, XB_tuples) + assert_equal(XA_tuples, XA_checked) + assert_equal(XB_tuples, XB_checked)