diff --git a/scikits/learn/metrics/pairwise.py b/scikits/learn/metrics/pairwise.py index ef32cae79729271a17b0d611977dd4dd17d45aa2..4dc2c05e76cb875ae4d5e0bc1b264ae2ae420596 100644 --- a/scikits/learn/metrics/pairwise.py +++ b/scikits/learn/metrics/pairwise.py @@ -9,36 +9,49 @@ sets of points. import numpy as np -def euclidian_distances(X, Y=None): +def euclidian_distances(X, Y): """ Considering the rows of X (and Y=X) as vectors, compute the - distance matrix between each pair of vector + distance matrix between each pair of vectors. Parameters ---------- - X, array of shape (n_samples_1, n_features) + X: array of shape (n_samples_1, n_features) - Y, array of shape (n_samples_2, n_features), default None - if Y is None, then Y=X is used instead + Y: array of shape (n_samples_2, n_features) Returns ------- - distances, array of shape (n_samples_1, n_samples_2) - """ + distances: array of shape (n_samples_1, n_samples_2) + + Examples + -------- + >>> X = [[0, 1], [1, 1]] + >>> # distance between rows of X + >>> euclidian_distances(X, X) + array([[ 0., 1.], + [ 1., 0.]]) + >>> # get distance to origin + >>> euclidian_distances(X, [[0, 0]]) + array([[ 1. 
], + [ 1.41421356]]) + """ + # shortcut in the common case euclidian_distances(X, X) + compute_Y = X is not Y + X = np.asanyarray(X) Y = np.asanyarray(Y) - if Y is None: - Y = X + if X.shape[1] != Y.shape[1]: - raise ValueError, "incompatible dimension for X and Y matrices" + raise ValueError("Incompatible dimension for X and Y matrices") XX = np.sum(X * X, axis=1)[:, np.newaxis] - if Y is None: - YY = XX.T - else: + if compute_Y: YY = np.sum(Y * Y, axis=1)[np.newaxis, :] + else: + YY = XX.T + distances = XX + YY # Using broadcasting distances -= 2 * np.dot(X, Y.T) distances = np.maximum(distances, 0) - distances = np.sqrt(distances) - return distances + return np.sqrt(distances) diff --git a/scikits/learn/metrics/tests/test_pairwise.py b/scikits/learn/metrics/tests/test_pairwise.py index 1d8b1dc54eb8d033517e61536b116691ffe8ae87..26b1c200b5259e00c2f6210e7ca13695c31c361e 100644 --- a/scikits/learn/metrics/tests/test_pairwise.py +++ b/scikits/learn/metrics/tests/test_pairwise.py @@ -9,3 +9,4 @@ def test_euclidian_distances(): Y = [[1], [2]] D = euclidian_distances(X, Y) assert_array_almost_equal(D, [[1., 2.]]) +