diff --git a/README.rst b/README.rst
index a1349738143c1dd2675fd2a753d0fa76b4e5edc0..1815a48d50f61cb14d65ac75421ba9831f72036a 100644
--- a/README.rst
+++ b/README.rst
@@ -79,8 +79,8 @@ Bugs
 ----
 
 Please submit bugs you might encounter, as well as patches and feature
-requests to the tracker located at the address
-https://sourceforge.net/apps/trac/scikit-learn/report
+requests to the tracker located on GitHub:
+https://github.com/scikit-learn/scikit-learn/issues
 
 
 Testing
diff --git a/doc/modules/neighbors.rst b/doc/modules/neighbors.rst
index 68bcac1e212ea91cf070f88060b21208c29059c0..dc9df8c36ac8f6db2cb7ba7f1f39bdcc6f294438 100644
--- a/doc/modules/neighbors.rst
+++ b/doc/modules/neighbors.rst
@@ -20,6 +20,19 @@ The :class:`NeighborsClassifier` implements the nearest-neighbors
 classification method using a vote heuristic: the class most present in
 the k nearest neighbors of a point is assigned to this point.
 
+It is possible to use different nearest neighbor search algorithms by
+using the keyword ``algorithm``. Possible values are ``'auto'``,
+``'ball_tree'``, ``'brute'`` and ``'brute_inplace'``. ``'ball_tree'``
+will create an instance of :class:`BallTree` to conduct the search,
+which is usually very efficient in low-dimensional spaces. In higher
+dimensions, a brute-force approach is preferred, and thus the
+parameters ``'brute'`` and ``'brute_inplace'`` can be used. Both
+conduct a brute-force search, the difference being that
+``'brute_inplace'`` does not perform any precomputations, and thus is
+better suited for low-memory settings. Finally, ``'auto'`` is a simple
+heuristic that will guess the best approach based on the current
+dataset.
+
 .. figure:: ../auto_examples/images/plot_neighbors.png
    :target: ../auto_examples/plot_neighbors.html
    :align: center
diff --git a/scikits/learn/cluster/k_means_.py b/scikits/learn/cluster/k_means_.py
index d64f85659871a4de2e46718069ef16167fc38896..fc9baec2cf9b706b28e8b53ecbd7cea48c63b510 100644
--- a/scikits/learn/cluster/k_means_.py
+++ b/scikits/learn/cluster/k_means_.py
@@ -184,9 +184,13 @@ def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
         if verbose:
             print 'Initialization complete'
         # iterations
+        x_squared_norms = X.copy()
+        x_squared_norms **= 2
+        x_squared_norms = x_squared_norms.sum(axis=1)
         for i in range(max_iter):
             centers_old = centers.copy()
-            labels, inertia = _e_step(X, centers)
+            labels, inertia = _e_step(X, centers,
+                                      x_squared_norms=x_squared_norms)
             centers = _m_step(X, labels, k)
             if verbose:
                 print 'Iteration %i, inertia %s' % (i, inertia)
@@ -228,12 +232,18 @@ def _m_step(x, z, k):
         The resulting centers
     """
     dim = x.shape[1]
-    centers = np.repeat(np.reshape(x.mean(0), (1, dim)), k, 0)
+    centers = np.empty((k, dim))
+    X_center = None
     for q in range(k):
-        if np.sum(z == q) == 0:
-            pass
+        this_center_mask = (z == q)
+        if not np.any(this_center_mask):
+            # The centroid of empty clusters is set to the center of
+            # everything
+            if X_center is None:
+                X_center = x.mean(axis=0)
+            centers[q] = X_center
         else:
-            centers[q] = np.mean(x[z == q], axis=0)
+            centers[q] = np.mean(x[this_center_mask], axis=0)
     return centers
@@ -265,8 +275,10 @@ def _e_step(x, centers, precompute_distances=True, x_squared_norms=None):
     if precompute_distances:
         distances = euclidean_distances(centers, x, x_squared_norms,
                                         squared=True)
-    z = -np.ones(n_samples).astype(np.int)
-    mindist = np.infty * np.ones(n_samples)
+    z = np.empty(n_samples, dtype=np.int)
+    z.fill(-1)
+    mindist = np.empty(n_samples)
+    mindist.fill(np.infty)
     for q in range(k):
         if precompute_distances:
             dist = distances[q]
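Note: the hunk above hoists the squared norms of ``X`` out of the main
k-means loop so every E-step can reuse them. A minimal NumPy sketch of the
identity this exploits (variable names are illustrative, not part of the
patch)::

    import numpy as np

    X = np.random.rand(100, 5)
    centers = np.random.rand(3, 5)

    # computed once, before the loop, as in the patch
    x_squared_norms = (X ** 2).sum(axis=1)

    # ||c - x||^2 = ||c||^2 - 2 <c, x> + ||x||^2: one matrix product per E-step
    c_squared_norms = (centers ** 2).sum(axis=1)
    fast = c_squared_norms[:, np.newaxis] - 2 * np.dot(centers, X.T) \
           + x_squared_norms

    # reference: naive pairwise squared distances
    naive = ((centers[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2).sum(axis=2)
    assert np.allclose(fast, naive)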
diff --git a/scikits/learn/linear_model/setup.py b/scikits/learn/linear_model/setup.py
index 707f673a567ac65d838558418db9cb7a641ee4d2..a8f7a079f628ceae4fb0a52ae7bac21dc6848a49 100644
--- a/scikits/learn/linear_model/setup.py
+++ b/scikits/learn/linear_model/setup.py
@@ -1,11 +1,9 @@
 from os.path import join
-import warnings
 import numpy
-import sys
 
 def configuration(parent_package='', top_path=None):
     from numpy.distutils.misc_util import Configuration
-    from numpy.distutils.system_info import get_info, get_standard_file, BlasNotFoundError
+    from numpy.distutils.system_info import get_info
     config = Configuration('linear_model', parent_package, top_path)
 
     # cd fast needs CBLAS
diff --git a/scikits/learn/metrics/__init__.py b/scikits/learn/metrics/__init__.py
index 8bc4c44b47da540351fc3cc0b07badf99ac3b983..4b761ca18ae898dddbb4380dfd47af695a05e8c5 100644
--- a/scikits/learn/metrics/__init__.py
+++ b/scikits/learn/metrics/__init__.py
@@ -8,3 +8,5 @@ from .metrics import confusion_matrix, roc_curve, auc, precision_score, \
     precision_recall_fscore_support, classification_report, \
     precision_recall_curve, explained_variance_score, r2_score, \
     zero_one, mean_square_error
+
+from .pairwise import euclidean_distances
diff --git a/scikits/learn/metrics/pairwise.py b/scikits/learn/metrics/pairwise.py
index 5f138b55f406ed40c8f2109fc69b43ccf3396dd2..4b60e2c3d61677a3a0c22acc4bea509a6ae2b443 100644
--- a/scikits/learn/metrics/pairwise.py
+++ b/scikits/learn/metrics/pairwise.py
@@ -8,10 +8,7 @@ sets of points.
 import numpy as np
 
 
-def euclidean_distances(X, Y,
-                        Y_norm_squared=None,
-                        squared=False):
+def euclidean_distances(X, Y, Y_norm_squared=None, squared=False):
     """
     Considering the rows of X (and Y=X) as vectors, compute the
     distance matrix between each pair of vectors.
@@ -61,7 +58,9 @@ def euclidean_distances(X, Y,
     if X is Y:  # shortcut in the common case euclidean_distances(X, X)
         YY = XX.T
     elif Y_norm_squared is None:
-        YY = np.sum(Y * Y, axis=1)[np.newaxis, :]
+        YY = Y.copy()
+        YY **= 2
+        YY = np.sum(YY, axis=1)[np.newaxis, :]
     else:
         YY = np.asanyarray(Y_norm_squared)
         if YY.shape != (Y.shape[0],):
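Note: re-exporting ``euclidean_distances`` from ``scikits.learn.metrics``
(above) is what lets the neighbors code below import it. A doctest-style
sketch of the ``Y_norm_squared`` shortcut, assuming only the signature shown
in this hunk::

    >>> import numpy as np
    >>> from scikits.learn.metrics import euclidean_distances
    >>> X = np.random.rand(5, 3)
    >>> Y = np.random.rand(4, 3)
    >>> D1 = euclidean_distances(X, Y, squared=True)
    >>> # precomputed squared norms of Y let callers skip that work
    >>> D2 = euclidean_distances(X, Y, Y_norm_squared=(Y ** 2).sum(axis=1),
    ...                          squared=True)
    >>> np.allclose(D1, D2)
    True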
diff --git a/scikits/learn/neighbors.py b/scikits/learn/neighbors.py
index 18b5e474295de6aa7230aa81cb7f2d37e29e2512..07fca967db8ca255728b0c7d7da03074cdc3fa08 100644
--- a/scikits/learn/neighbors.py
+++ b/scikits/learn/neighbors.py
@@ -8,7 +8,7 @@ import numpy as np
 
 from .base import BaseEstimator, ClassifierMixin, RegressorMixin
-from .ball_tree import BallTree
+from .ball_tree import BallTree, knn_brute
 
 
 class NeighborsClassifier(BaseEstimator, ClassifierMixin):
@@ -16,12 +16,18 @@ class NeighborsClassifier(BaseEstimator, ClassifierMixin):
 
     Parameters
     ----------
-    n_neighbors : int
-        default number of neighbors.
+    n_neighbors : int, optional
+        Default number of neighbors. Defaults to 5.
 
-    window_size : int
+    window_size : int, optional
         Window size passed to BallTree
 
+    algorithm : {'auto', 'ball_tree', 'brute', 'brute_inplace'}, optional
+        Algorithm used to compute the nearest neighbors. 'ball_tree'
+        will construct a BallTree, 'brute' and 'brute_inplace' will
+        perform a brute-force search. 'auto' will guess the most
+        appropriate algorithm based on the current dataset.
+
     Examples
     --------
     >>> samples = [[0, 0, 1], [1, 0, 0]]
@@ -29,24 +35,25 @@ class NeighborsClassifier(BaseEstimator, ClassifierMixin):
     >>> from scikits.learn.neighbors import NeighborsClassifier
     >>> neigh = NeighborsClassifier(n_neighbors=1)
     >>> neigh.fit(samples, labels)
-    NeighborsClassifier(n_neighbors=1, window_size=1)
+    NeighborsClassifier(n_neighbors=1, window_size=1, algorithm='auto')
     >>> print neigh.predict([[0,0,0]])
     [1]
 
-    Notes
-    -----
-    Internally uses the ball tree datastructure and algorithm for fast
-    neighbors lookups on high dimensional datasets.
+    See also
+    --------
+    BallTree
 
     References
     ----------
     http://en.wikipedia.org/wiki/K-nearest_neighbor_algorithm
     """
 
-    def __init__(self, n_neighbors=5, window_size=1):
+    def __init__(self, n_neighbors=5, algorithm='auto', window_size=1):
         self.n_neighbors = n_neighbors
         self.window_size = window_size
+        self.algorithm = algorithm
+
 
     def fit(self, X, Y, **params):
         """
         Fit the model using X, y as training data.
@@ -62,12 +69,19 @@ class NeighborsClassifier(BaseEstimator, ClassifierMixin):
         params : list of keyword, optional
             Overwrite keywords from __init__
         """
+        X = np.asanyarray(X)
         self._y = np.asanyarray(Y)
         self._set_params(**params)
 
-        self.ball_tree = BallTree(X, self.window_size)
+        if self.algorithm == 'ball_tree' or \
+           (self.algorithm == 'auto' and X.shape[1] < 20):
+            self.ball_tree = BallTree(X, self.window_size)
+        else:
+            self.ball_tree = None
+            self._fit_X = X
         return self
 
+
     def kneighbors(self, data, return_distance=True, **params):
         """Finds the K-neighbors of a point.
@@ -105,7 +119,7 @@ class NeighborsClassifier(BaseEstimator, ClassifierMixin):
         >>> from scikits.learn.neighbors import NeighborsClassifier
         >>> neigh = NeighborsClassifier(n_neighbors=1)
         >>> neigh.fit(samples, labels)
-        NeighborsClassifier(n_neighbors=1, window_size=1)
+        NeighborsClassifier(n_neighbors=1, window_size=1, algorithm='auto')
         >>> print neigh.kneighbors([1., 1., 1.])
         (array([ 0.5]), array([2]))
 
@@ -123,6 +137,7 @@ class NeighborsClassifier(BaseEstimator, ClassifierMixin):
         return self.ball_tree.query(
             data, k=self.n_neighbors, return_distance=return_distance)
 
+
     def predict(self, X, **params):
         """Predict the class labels for the provided data.
 
@@ -143,10 +158,21 @@ class NeighborsClassifier(BaseEstimator, ClassifierMixin):
         X = np.atleast_2d(X)
         self._set_params(**params)
 
-        ind = self.ball_tree.query(
-            X, self.n_neighbors, return_distance=False)
-        pred_labels = self._y[ind]
+        # .. get neighbors ..
+        if self.ball_tree is None:
+            if self.algorithm == 'brute_inplace':
+                neigh_ind = knn_brute(self._fit_X, X, self.n_neighbors)
+            else:
+                from .metrics import euclidean_distances
+                dist = euclidean_distances(
+                    X, self._fit_X, squared=True)
+                neigh_ind = dist.argsort(axis=1)[:, :self.n_neighbors]
+        else:
+            neigh_ind = self.ball_tree.query(
+                X, self.n_neighbors, return_distance=False)
 
+        # .. most popular label ..
+        pred_labels = self._y[neigh_ind]
         from scipy import stats
         mode, _ = stats.mode(pred_labels, axis=1)
         return mode.flatten().astype(np.int)
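Note: the brute-force branch added to ``predict`` above reduces to an
argsort over a distance matrix followed by a majority vote. A
self-contained sketch of that logic in plain NumPy/SciPy (names are
illustrative)::

    import numpy as np
    from scipy import stats

    X_train = np.array([[0., 0.], [1., 1.], [2., 2.], [3., 3.]])
    y_train = np.array([0, 0, 1, 1])
    X_test = np.array([[0.1, 0.1], [2.9, 2.9]])
    k = 3

    # squared distances are enough to rank neighbors
    dist = ((X_test[:, np.newaxis, :]
             - X_train[np.newaxis, :, :]) ** 2).sum(axis=2)
    neigh_ind = dist.argsort(axis=1)[:, :k]

    # most popular label among the k nearest
    mode, _ = stats.mode(y_train[neigh_ind], axis=1)
    print mode.flatten().astype(np.int)   # prints [0 1]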
@@ -168,15 +194,21 @@ class NeighborsRegressor(NeighborsClassifier, RegressorMixin):
 
     Parameters
     ----------
-    n_neighbors : int
-        default number of neighbors.
+    n_neighbors : int, optional
+        Default number of neighbors. Defaults to 5.
 
-    window_size : int
+    window_size : int, optional
         Window size passed to BallTree
 
-    mode : {'mean', 'barycenter'}
+    mode : {'mean', 'barycenter'}, optional
         Weights to apply to labels.
 
+    algorithm : {'auto', 'ball_tree', 'brute', 'brute_inplace'}, optional
+        Algorithm used to compute the nearest neighbors. 'ball_tree'
+        will construct a BallTree, 'brute' and 'brute_inplace' will
+        perform a brute-force search. 'auto' will guess the most
+        appropriate algorithm based on the current dataset.
+
     Examples
     --------
     >>> X = [[0], [1], [2], [3]]
@@ -184,7 +216,8 @@ class NeighborsRegressor(NeighborsClassifier, RegressorMixin):
     >>> from scikits.learn.neighbors import NeighborsRegressor
     >>> neigh = NeighborsRegressor(n_neighbors=2)
     >>> neigh.fit(X, y)
-    NeighborsRegressor(n_neighbors=2, window_size=1, mode='mean')
+    NeighborsRegressor(n_neighbors=2, window_size=1, mode='mean',
+              algorithm='auto')
     >>> print neigh.predict([[1.5]])
     [ 0.5]
@@ -194,10 +227,12 @@ class NeighborsRegressor(NeighborsClassifier, RegressorMixin):
 
     """
 
-    def __init__(self, n_neighbors=5, mode='mean', window_size=1):
+    def __init__(self, n_neighbors=5, mode='mean', algorithm='auto',
+                 window_size=1):
         self.n_neighbors = n_neighbors
         self.window_size = window_size
         self.mode = mode
+        self.algorithm = algorithm
 
 
     def predict(self, X, **params):
@@ -220,16 +255,22 @@ class NeighborsRegressor(NeighborsClassifier, RegressorMixin):
         X = np.atleast_2d(np.asanyarray(X))
         self._set_params(**params)
 
-#
-#       .. compute neighbors ..
-#
-        neigh_ind = self.ball_tree.query(
-            X, k=self.n_neighbors, return_distance=False)
-        neigh = self.ball_tree.data[neigh_ind]
-
-#
-#       .. return labels ..
-#
+        # .. get neighbors ..
+        if self.ball_tree is None:
+            if self.algorithm == 'brute_inplace':
+                neigh_ind = knn_brute(self._fit_X, X, self.n_neighbors)
+            else:
+                from .metrics.pairwise import euclidean_distances
+                dist = euclidean_distances(
+                    X, self._fit_X, squared=False)
+                neigh_ind = dist.argsort(axis=1)[:, :self.n_neighbors]
+            neigh = self._fit_X[neigh_ind]
+        else:
+            neigh_ind = self.ball_tree.query(
+                X, self.n_neighbors, return_distance=False)
+            neigh = self.ball_tree.data[neigh_ind]
+
+        # .. return labels ..
         if self.mode == 'barycenter':
             W = barycenter_weights(X, neigh)
             return (W * self._y[neigh_ind]).sum(axis=1)
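Note: with ``mode='mean'`` the regressor's brute-force path simply averages
the labels of the ``k`` nearest points. A one-dimensional sketch reproducing
the docstring example above (illustrative, not the patched code)::

    import numpy as np

    X_train = np.array([[0.], [1.], [2.], [3.]])
    y_train = np.array([0., 0., 1., 1.])
    X_test = np.array([[1.5]])
    k = 2

    dist = np.abs(X_test - X_train.T)      # 1-D case: plain distances
    neigh_ind = dist.argsort(axis=1)[:, :k]
    print y_train[neigh_ind].mean(axis=1)  # prints [ 0.5]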
diff --git a/scikits/learn/setup.py b/scikits/learn/setup.py
index b2e940a486a73e66103b892ca5c98c39f2016ef2..2a903f27a20bc5c9b9b1ee5e0062a120e9c24f79 100644
--- a/scikits/learn/setup.py
+++ b/scikits/learn/setup.py
@@ -1,7 +1,6 @@
 from os.path import join
 import warnings
 import numpy
-import sys
 
 
 def configuration(parent_package='', top_path=None):
@@ -36,12 +35,7 @@ def configuration(parent_package='', top_path=None):
             ('NO_ATLAS_INFO', 1) in blas_info.get('define_macros', [])):
         config.add_library('cblas',
                            sources=[join('src', 'cblas', '*.c')])
-        cblas_libs = ['cblas']
-        blas_info.pop('libraries', None)
         warnings.warn(BlasNotFoundError.__doc__)
-    else:
-        cblas_libs = blas_info.pop('libraries', [])
-
 
     config.add_extension('ball_tree',
                          sources=[join('src', 'BallTree.cpp')],
diff --git a/scikits/learn/src/BallTree.cpp b/scikits/learn/src/BallTree.cpp
index c1becd13845f49c957a469181eb691209b7fbf26..f98cdfa6684d4ab9e89b56fd412ac9326b7bb8e4 100644
--- a/scikits/learn/src/BallTree.cpp
+++ b/scikits/learn/src/BallTree.cpp
@@ -712,22 +712,6 @@ BallTree_knn_brute(PyObject *self, PyObject *args, PyObject *kwds){
     for(int i=0;i<N;i++)
         delete Points[i];
 
-    //if only one neighbor is requested, then resize the neighbors array
-    if(k==1){
-        PyArray_Dims dims;
-        dims.ptr = PyArray_DIMS(arr2);
-        dims.len = PyArray_NDIM(arr2)-1;
-
-        //PyArray_Resize returns None - this needs to be picked
-        // up and dereferenced.
-        PyObject *NoneObj = PyArray_Resize( (PyArrayObject*)nbrs, &dims,
-                                            0, NPY_ANYORDER );
-        if (NoneObj == NULL){
-            goto fail;
-        }
-        Py_DECREF(NoneObj);
-    }
-
     return nbrs;
 
   fail:
diff --git a/scikits/learn/svm/base.py b/scikits/learn/svm/base.py
index 1f21c3d670f31070dc3ad6ac344728d0f5e99730..3be7947c743688e23d5b103bcfbc0e0d79ef3a8a 100644
--- a/scikits/learn/svm/base.py
+++ b/scikits/learn/svm/base.py
@@ -425,17 +425,14 @@ class BaseLibLinear(BaseEstimator):
         X = np.atleast_2d(np.asanyarray(X, dtype=np.float64, order='C'))
         self._check_n_features(X)
 
-        coef = self.raw_coef_
-
-        dec_func = _liblinear.decision_function_wrap(X, coef,
-                                self._get_solver_type(),
-                                self.eps, self.C,
-                                self.class_weight_label,
-                                self.class_weight, self.label_,
-                                self._get_bias())
+        dec_func = _liblinear.decision_function_wrap(
+            X, self.raw_coef_, self._get_solver_type(), self.eps,
+            self.C, self.class_weight_label, self.class_weight,
+            self.label_, self._get_bias())
 
         if len(self.label_) <= 2:
-            # one class
+            # in the two-class case, the decision sign needs to be
+            # flipped due to liblinear's design
             return -dec_func
         else:
             return dec_func
@@ -451,14 +448,22 @@ class BaseLibLinear(BaseEstimator):
     @property
     def intercept_(self):
         if self.fit_intercept:
-            return self.intercept_scaling * self.raw_coef_[:, -1]
+            ret = self.intercept_scaling * self.raw_coef_[:, -1]
+            if len(self.label_) <= 2:
+                ret *= -1
+            return ret
         return 0.0
 
     @property
     def coef_(self):
         if self.fit_intercept:
-            return self.raw_coef_[:, : -1]
-        return self.raw_coef_
+            ret = self.raw_coef_[:, : -1]
+        else:
+            ret = self.raw_coef_
+        if len(self.label_) <= 2:
+            return -ret
+        else:
+            return ret
 
     def predict_proba(self, T):
         # only available for logistic regression
diff --git a/scikits/learn/svm/sparse/base.py b/scikits/learn/svm/sparse/base.py
index 36e85bebac68e3eaf5e12f7982c8392c40c8f41c..62b3d3ed18a4f8ee137de5c94f4d9f6545ecb3cb 100644
--- a/scikits/learn/svm/sparse/base.py
+++ b/scikits/learn/svm/sparse/base.py
@@ -265,7 +265,8 @@ class SparseBaseLibLinear(BaseLibLinear):
                       self._get_bias())
 
         if len(self.label_) <= 2:
-            # one class
+            # in the two-class case, the decision sign needs to be
+            # flipped due to liblinear's design
             return -dec_func
         else:
             return dec_func
diff --git a/scikits/learn/svm/tests/test_svm.py b/scikits/learn/svm/tests/test_svm.py
index 46e0e3be0c13502bf1f3c22618c28f8330ed9000..18bbab4818d611f0e8f8c4cc38ecc8f85ec5ac78 100644
--- a/scikits/learn/svm/tests/test_svm.py
+++ b/scikits/learn/svm/tests/test_svm.py
@@ -417,7 +417,7 @@ def test_dense_liblinear_intercept_handling(classifier=svm.LinearSVC):
     clf.intercept_scaling = 100
     clf.fit(X, y)
     intercept1 = clf.intercept_
-    assert intercept1 > 1
+    assert intercept1 < -1
 
     # when intercept_scaling is sufficiently high, the intercept value
     # doesn't depend on intercept_scaling value
@@ -435,14 +435,26 @@ def test_liblinear_predict():
     returns the same as the one in libliblinear
 
     """
+    # multi-class case
     clf = svm.LinearSVC().fit(iris.data, iris.target)
-
     weights = clf.coef_.T
     bias = clf.intercept_
     H = np.dot(iris.data, weights) + bias
-
     assert_array_equal(clf.predict(iris.data), H.argmax(axis=1))
 
+    # binary-class case
+    X = [[2, 1],
+         [3, 1],
+         [1, 3],
+         [2, 3]]
+    y = [0, 0, 1, 1]
+
+    clf = svm.LinearSVC().fit(X, y)
+    weights = np.ravel(clf.coef_)
+    bias = clf.intercept_
+    H = np.dot(X, weights) + bias
+    assert_array_equal(clf.predict(X), (H > 0).astype(int))
+
 
 if __name__ == '__main__':
     import nose
     nose.runmodule()
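Note: after the sign flip in ``coef_``/``intercept_`` above, a binary
LinearSVC predicts class 1 exactly when ``np.dot(x, weights) + bias > 0``,
which is what the new binary-class test checks. A sketch with hand-picked
(not fitted) weights::

    import numpy as np

    weights = np.array([-0.2, 0.4])   # illustrative values only
    bias = -0.3

    X = np.array([[2, 1], [3, 1], [1, 3], [2, 3]])
    H = np.dot(X, weights) + bias
    print (H > 0).astype(int)         # prints [0 0 1 1]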
diff --git a/scikits/learn/tests/test_neighbors.py b/scikits/learn/tests/test_neighbors.py
index 34d448f9a0250c0fee10603f50ba2ab35e37f19e..3921fadf6b74c090b2f8610915eb18784a8f0815 100644
--- a/scikits/learn/tests/test_neighbors.py
+++ b/scikits/learn/tests/test_neighbors.py
@@ -1,7 +1,14 @@
 import numpy as np
-from numpy.testing import assert_array_almost_equal, assert_array_equal
+from numpy.testing import assert_array_almost_equal, assert_array_equal, \
+    assert_
 
-from scikits.learn import neighbors
+from scikits.learn import neighbors, datasets
+
+# load and shuffle iris dataset
+iris = datasets.load_iris()
+perm = np.random.permutation(iris.target.size)
+iris.data = iris.data[perm]
+iris.target = iris.target[perm]
 
 
 def test_neighbors_1D():
@@ -15,61 +22,49 @@ def test_neighbors_1D():
     X = [[x] for x in range(0, n)]
     Y = [0]*(n/2) + [1]*(n/2)
 
-    # n_neighbors = 1
-    knn = neighbors.NeighborsClassifier(n_neighbors=1)
-    knn.fit(X, Y)
-    test = [[i + 0.01] for i in range(0, n/2)] + \
-           [[i - 0.01] for i in range(n/2, n)]
-    assert_array_equal(knn.predict(test), [0]*3 + [1]*3)
-
-    # n_neighbors = 2
-    knn = neighbors.NeighborsClassifier(n_neighbors=2)
-    knn.fit(X, Y)
-    assert_array_equal(knn.predict(test), [0]*4 + [1]*2)
-
-
-    # n_neighbors = 3
-    knn = neighbors.NeighborsClassifier(n_neighbors=3)
-    knn.fit(X, Y)
-    assert_array_equal(knn.predict([[i + 0.01] for i in range(0, n/2)]),
-                       [0 for i in range(n/2)])
-    assert_array_equal(knn.predict([[i - 0.01] for i in range(n/2, n)]),
-                       [1 for i in range(n/2)])
-
-
-def test_neighbors_2D():
+    for s in ('auto', 'ball_tree', 'brute', 'inplace'):
+        # n_neighbors = 1
+        knn = neighbors.NeighborsClassifier(n_neighbors=1, algorithm=s)
+        knn.fit(X, Y)
+        test = [[i + 0.01] for i in range(0, n/2)] + \
+               [[i - 0.01] for i in range(n/2, n)]
+        assert_array_equal(knn.predict(test), [0]*3 + [1]*3)
+
+        # n_neighbors = 2
+        knn = neighbors.NeighborsClassifier(n_neighbors=2, algorithm=s)
+        knn.fit(X, Y)
+        assert_array_equal(knn.predict(test), [0]*4 + [1]*2)
+
+        # n_neighbors = 3
+        knn = neighbors.NeighborsClassifier(n_neighbors=3, algorithm=s)
+        knn.fit(X, Y)
+        assert_array_equal(knn.predict([[i + 0.01] for i in range(0, n/2)]),
+                           [0 for i in range(n/2)])
+        assert_array_equal(knn.predict([[i - 0.01] for i in range(n/2, n)]),
+                           [1 for i in range(n/2)])
+
+
+def test_neighbors_iris():
     """
-    Nearest Neighbor in the plane.
+    Sanity checks on the iris dataset
 
     Puts three points of each label in the plane and performs a
     nearest neighbor query on points near the decision boundary.
""" - X = ( - (0, 1), (1, 1), (1, 0), # label 0 - (-1, 0), (-1, -1), (0, -1)) # label 1 - n_2 = len(X)/2 - Y = [0]*n_2 + [1]*n_2 - knn = neighbors.NeighborsClassifier() - knn.fit(X, Y) - prediction = knn.predict([[0, .1], [0, -.1], [.1, 0], [-.1, 0]]) - assert_array_equal(prediction, [0, 1, 0, 1]) + for s in ('auto', 'ball_tree', 'brute', 'inplace'): + clf = neighbors.NeighborsClassifier() + clf.fit(iris.data, iris.target, n_neighbors=1, algorithm=s) + assert_array_equal(clf.predict(iris.data), iris.target) + clf.fit(iris.data, iris.target, n_neighbors=9, algorithm=s) + assert_(np.mean(clf.predict(iris.data)== iris.target) > 0.95) -def test_neighbors_regressor(): - """ - NeighborsRegressor for regression using k-NN - """ - X = [[0], [1], [2], [3]] - y = [0, 0, 1, 1] - neigh = neighbors.NeighborsRegressor(n_neighbors=3) - neigh.fit(X, y, mode='barycenter') - assert_array_almost_equal( - neigh.predict([[1.], [1.5]]), [0.333, 0.583], decimal=3) - neigh.fit(X, y, mode='mean') - assert_array_almost_equal( - neigh.predict([[1.], [1.5]]), [0.333, 0.333], decimal=3) - + for m in ('barycenter', 'mean'): + rgs = neighbors.NeighborsRegressor() + rgs.fit(iris.data, iris.target, mode=m, algorithm=s) + assert_(np.mean( + rgs.predict(iris.data).round() == iris.target) > 0.95) def test_kneighbors_graph():