diff --git a/sklearn/neighbors/base.py b/sklearn/neighbors/base.py index e18658b7025a0dfe184f6f4950143beff62b4358..0f455675f03c60a6c410d6e4a94d9717d387a241 100644 --- a/sklearn/neighbors/base.py +++ b/sklearn/neighbors/base.py @@ -41,51 +41,32 @@ def _check_weights(weights): "'distance', or a callable function") -def _dist_to_weight(dist): - """ Calculates weights from distances. Replaces line - "weights = 1. / dist", which gives warning div by zeros, - if one sample in dist is zero. - - Takes dist matrix, which can be multidimensional. - Returns weights matrix of same dimension.""" - - # Dist could be multidimensional, flatten it so it's values - # can be looped. - dist_values = dist.ravel() - retval = np.zeros(len(dist_values)) - for i, d in enumerate(dist_values): - retval[i] = (1.0 / d) if (d != 0.0) else np.inf - return retval.reshape(dist.shape) - - def _get_weights(dist, weights): """Get the weights from an array of distances and a parameter ``weights`` - ``weights`` can be either a string or an executable. - - returns ``weights_arr``, an array of the same size as ``dist`` - if ``weights == 'uniform'``, then returns None + Parameters + =========== + dist: ndarray + The input distances + weights: {'uniform', 'distance' or a callable} + The kind of weighting used + + Returns + ======== + weights_arr: array of the same shape as ``dist`` + if ``weights == 'uniform'``, then returns None """ - if dist.dtype == np.dtype(object): - if weights in (None, 'uniform'): - return None - elif weights == 'distance': - return [_dist_to_weight(d) for d in dist] - elif callable(weights): - return [weights(d) for d in dist] - else: - raise ValueError("weights not recognized: should be 'uniform', " - "'distance', or a callable function") + if weights in (None, 'uniform'): + return None + elif weights == 'distance': + with np.errstate(divide='ignore'): + dist = 1./dist + return dist + elif callable(weights): + return weights(dist) else: - if weights in (None, 'uniform'): - return None - elif weights == 'distance': - return _dist_to_weight(dist) - elif callable(weights): - return weights(dist) - else: - raise ValueError("weights not recognized: should be 'uniform', " - "'distance', or a callable function") + raise ValueError("weights not recognized: should be 'uniform', " + "'distance', or a callable function") class NeighborsBase(BaseEstimator): diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index cac97b09349b790a29c8e26783b0a907488fffc2..19a56da28bcc8cc49db4b6f94a182434900d7d3b 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -32,11 +32,9 @@ def _weight_func(dist): # Dist could be multidimensional, flatten it so all values # can be looped - dist_values = dist.ravel() - retval = np.zeros(len(dist_values)) - for i, d in enumerate(dist_values): - retval[i] = (d ** -2) if (d != 0.0) else np.inf - return retval.reshape(dist.shape) + with np.errstate(divide='ignore'): + retval = 1./dist + return retval**2 def test_warn_on_equidistant(n_samples=100, n_features=3, k=3):