diff --git a/examples/svm/plot_rbf_parameters.py b/examples/svm/plot_rbf_parameters.py
index 072f4559d56f3df2d3a225e34791f1ecb44b979c..0d8c1eab85b7a61b50664bfac45c78c0b8d2e1b9 100644
--- a/examples/svm/plot_rbf_parameters.py
+++ b/examples/svm/plot_rbf_parameters.py
@@ -3,48 +3,111 @@
 RBF SVM parameters
 ==================
 
-This example illustrates the effect of the parameters `gamma`
-and `C` of the rbf kernel SVM.
-
-Intuitively, the `gamma` parameter defines how far the influence
-of a single training example reaches, with low values meaning 'far'
-and high values meaning 'close'.
-The `C` parameter trades off misclassification of training examples
-against simplicity of the decision surface. A low C makes
-the decision surface smooth, while a high C aims at classifying
-all training examples correctly.
-
-Two plots are generated. The first is a visualization of the
-decision function for a variety of parameter values, and the second
-is a heatmap of the classifier's cross-validation accuracy as
-a function of `C` and `gamma`. For this example we explore a relatively
-large grid for illustration purposes. In practice, a logarithmic
-grid from `10**-3` to `10**3` is usually sufficient.
+This example illustrates the effect of the parameters ``gamma`` and ``C`` of
+the Radial Basis Function (RBF) kernel SVM.
+
+Intuitively, the ``gamma`` parameter defines how far the influence of a single
+training example reaches, with low values meaning 'far' and high values meaning
+'close'. The ``gamma`` parameter can be seen as the inverse of the radius of
+influence of samples selected by the model as support vectors.
+
+The ``C`` parameter trades off misclassification of training examples against
+simplicity of the decision surface. A low ``C`` makes the decision surface
+smooth, while a high ``C`` aims at classifying all training examples correctly
+by giving the model the freedom to select more samples as support vectors.
+
+The first plot is a visualization of the decision function for a variety of
+parameter values on a simplified classification problem involving only 2 input
+features and 2 possible target classes (binary classification). Note that this
+kind of plot is not possible for problems with more features or target
+classes.
+
+The second plot is a heatmap of the classifier's cross-validation accuracy as a
+function of ``C`` and ``gamma``. For this example we explore a relatively large
+grid for illustration purposes. In practice, a logarithmic grid from
+:math:`10^{-3}` to :math:`10^3` is usually sufficient. If the best parameters
+lie on the boundaries of the grid, it can be extended in that direction in a
+subsequent search.
+
+Note that the heatmap plot has a special colorbar with a midpoint value close
+to the score values of the best performing models so as to make it easy to tell
+them apart in the blink of an eye.
+
+The behavior of the model is very sensitive to the ``gamma`` parameter. If
+``gamma`` is too large, the radius of the area of influence of the support
+vectors only includes the support vector itself and no amount of
+regularization with ``C`` will be able to prevent overfitting.
+
+When ``gamma`` is very small, the model is too constrained and cannot capture
+the complexity or "shape" of the data. The region of influence of any selected
+support vector would include the whole training set. The resulting model will
+behave similarly to a linear model with a set of hyperplanes that separate the
+centers of high density of any pair of classes.
+
+For intermediate values, we can see on the second plot that good models can
+be found on a diagonal of ``C`` and ``gamma``. Smooth models (lower ``gamma``
+values) can be made more complex by selecting a larger number of support
+vectors (larger ``C`` values), hence the diagonal of well-performing models.
+
+Finally one can also observe that for some intermediate values of ``gamma`` we
+get equally performing models when ``C`` becomes very large: it is not
+necessary to regularize by limiting the number of support vectors. The radius
+of the RBF kernel alone acts as a good structural regularizer. In practice
+though it might still be interesting to limit the number of support vectors
+with a lower value of ``C`` so as to favor models that use less memory and
+that are faster to predict.
+
+We should also note that small differences in scores result from the random
+splits of the cross-validation procedure. Those spurious variations can be
+smoothed out by increasing the number of CV iterations ``n_iter`` at the
+expense of compute time. Increasing the number of steps in ``C_range`` and
+``gamma_range`` will increase the resolution of the hyper-parameter heat
+map.
+
 '''
 print(__doc__)
 
 import numpy as np
 import matplotlib.pyplot as plt
+from matplotlib.colors import Normalize
 
 from sklearn.svm import SVC
 from sklearn.preprocessing import StandardScaler
 from sklearn.datasets import load_iris
-from sklearn.cross_validation import StratifiedKFold
+from sklearn.cross_validation import StratifiedShuffleSplit
 from sklearn.grid_search import GridSearchCV
 
+
+# Utility class to move the midpoint of a colormap to be around
+# the values of interest.
+
+class MidpointNormalize(Normalize):
+
+    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
+        self.midpoint = midpoint
+        Normalize.__init__(self, vmin, vmax, clip)
+
+    def __call__(self, value, clip=None):
+        x, y = [self.vmin, self.midpoint, self.vmax], [0, 0.5, 1]
+        return np.ma.masked_array(np.interp(value, x, y))
+
 ##############################################################################
 # Load and prepare data set
 #
 # dataset for grid search
+
 iris = load_iris()
 X = iris.data
-Y = iris.target
+y = iris.target
+
+# Dataset for decision function visualization: we only keep the first two
+# features in X and sub-sample the dataset to keep only 2 classes and
+# make it a binary classification problem.
 
-# dataset for decision function visualization
 X_2d = X[:, :2]
-X_2d = X_2d[Y > 0]
-Y_2d = Y[Y > 0]
-Y_2d -= 1
+X_2d = X_2d[y > 0]
+y_2d = y[y > 0]
+y_2d -= 1
 
 # It is usually a good idea to scale the data for SVM training.
 # We are cheating a bit in this example in scaling all of the data,
@@ -52,43 +115,45 @@ Y_2d -= 1
 # just applying it on the test set.
 
 scaler = StandardScaler()
-
 X = scaler.fit_transform(X)
 X_2d = scaler.fit_transform(X_2d)
 
 ##############################################################################
-# Train classifier
+# Train classifiers
 #
 # For an initial search, a logarithmic grid with basis
 # 10 is often helpful. Using a basis of 2, a finer
 # tuning can be achieved but at a much higher cost.
 
-C_range = 10.0 ** np.arange(-2, 9)
-gamma_range = 10.0 ** np.arange(-5, 4)
+C_range = np.logspace(-2, 10, 13)
+gamma_range = np.logspace(-9, 3, 13)
 param_grid = dict(gamma=gamma_range, C=C_range)
-cv = StratifiedKFold(y=Y, n_folds=3)
+cv = StratifiedShuffleSplit(y, n_iter=5, test_size=0.2, random_state=42)
 grid = GridSearchCV(SVC(), param_grid=param_grid, cv=cv)
-grid.fit(X, Y)
+grid.fit(X, y)
 
-print("The best classifier is: ", grid.best_estimator_)
+print("The best parameters are %s with a score of %0.2f"
+      % (grid.best_params_, grid.best_score_))
 
 # Now we need to fit a classifier for all parameters in the 2d version
 # (we use a smaller set of parameters here because it takes a while to train)
-C_2d_range = [1, 1e2, 1e4]
+
+C_2d_range = [1e-2, 1, 1e2]
 gamma_2d_range = [1e-1, 1, 1e1]
 classifiers = []
 for C in C_2d_range:
     for gamma in gamma_2d_range:
         clf = SVC(C=C, gamma=gamma)
-        clf.fit(X_2d, Y_2d)
+        clf.fit(X_2d, y_2d)
         classifiers.append((C, gamma, clf))
 
 ##############################################################################
 # visualization
 #
 # draw visualization of parameter effects
+
 plt.figure(figsize=(8, 6))
-xx, yy = np.meshgrid(np.linspace(-5, 5, 200), np.linspace(-5, 5, 200))
+xx, yy = np.meshgrid(np.linspace(-3, 3, 200), np.linspace(-3, 3, 200))
 for (k, (C, gamma, clf)) in enumerate(classifiers):
     # evaluate decision function in a grid
     Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
@@ -96,32 +161,39 @@ for (k, (C, gamma, clf)) in enumerate(classifiers):
 
     # visualize decision function for these parameters
     plt.subplot(len(C_2d_range), len(gamma_2d_range), k + 1)
-    plt.title("gamma 10^%d, C 10^%d" % (np.log10(gamma), np.log10(C)),
+    plt.title("gamma=10^%d, C=10^%d" % (np.log10(gamma), np.log10(C)),
               size='medium')
 
     # visualize parameter's effect on decision function
-    plt.pcolormesh(xx, yy, -Z, cmap=plt.cm.jet)
-    plt.scatter(X_2d[:, 0], X_2d[:, 1], c=Y_2d, cmap=plt.cm.jet)
+    plt.pcolormesh(xx, yy, -Z, cmap=plt.cm.RdBu)
+    plt.scatter(X_2d[:, 0], X_2d[:, 1], c=y_2d, cmap=plt.cm.RdBu_r)
     plt.xticks(())
     plt.yticks(())
     plt.axis('tight')
 
 # plot the scores of the grid
 # grid_scores_ contains parameter settings and scores
-score_dict = grid.grid_scores_
-
 # We extract just the scores
-scores = [x[1] for x in score_dict]
+scores = [x[1] for x in grid.grid_scores_]
 scores = np.array(scores).reshape(len(C_range), len(gamma_range))
 
-# draw heatmap of accuracy as a function of gamma and C
+# Draw heatmap of the validation accuracy as a function of gamma and C
+#
+# The scores are encoded as colors with the hot colormap which varies from dark
+# red to bright yellow. As the most interesting scores are all located in the
+# 0.92 to 0.97 range we use a custom normalizer to set the midpoint to 0.92 so
+# as to make it easier to visualize the small variations of score values in the
+# interesting range while not brutally collapsing all the low score values to
+# the same color.
+
 plt.figure(figsize=(8, 6))
-plt.subplots_adjust(left=0.05, right=0.95, bottom=0.15, top=0.95)
-plt.imshow(scores, interpolation='nearest', cmap=plt.cm.spectral)
+plt.subplots_adjust(left=.2, right=0.95, bottom=0.15, top=0.95)
+plt.imshow(scores, interpolation='nearest', cmap=plt.cm.hot,
+           norm=MidpointNormalize(vmin=0.2, midpoint=0.92))
 plt.xlabel('gamma')
 plt.ylabel('C')
 plt.colorbar()
 plt.xticks(np.arange(len(gamma_range)), gamma_range, rotation=45)
 plt.yticks(np.arange(len(C_range)), C_range)
-
+plt.title('Validation accuracy')
 plt.show()
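
As a standalone illustration, separate from the patch itself, the sketch below
exercises the ``MidpointNormalize`` class exactly as it is added above. It shows
how raw accuracy scores are mapped onto the [0, 1] colormap range: everything
below the chosen midpoint is compressed into [0, 0.5] while the interesting high
scores are spread over [0.5, 1]. The ``vmax=1.0`` argument is supplied here only
so the snippet runs on its own (the patch leaves ``vmax`` to be filled in from
the plotted data).

import numpy as np
from matplotlib.colors import Normalize


class MidpointNormalize(Normalize):
    # Same helper as in the patch: a piecewise-linear mapping of
    # [vmin, midpoint, vmax] onto [0, 0.5, 1].
    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        Normalize.__init__(self, vmin, vmax, clip)

    def __call__(self, value, clip=None):
        x, y = [self.vmin, self.midpoint, self.vmax], [0, 0.5, 1]
        return np.ma.masked_array(np.interp(value, x, y))


# vmin and midpoint match the values passed to imshow in the patch; vmax is
# set explicitly here only to make the snippet self-contained.
norm = MidpointNormalize(vmin=0.2, vmax=1.0, midpoint=0.92)
print(norm(np.array([0.2, 0.92, 0.97, 1.0])))
# -> approximately [0.0, 0.5, 0.81, 1.0]: the 0.92-0.97 band of interesting
#    scores receives most of the colormap's dynamic range.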
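
A side note for readers following along on a newer scikit-learn: the patch
targets ``sklearn.cross_validation`` and ``sklearn.grid_search``, which later
releases (0.18 and up) consolidated into ``sklearn.model_selection``. The sketch
below is a rough equivalent of the grid-search step under that newer API; it
assumes the candidates in ``cv_results_`` are ordered with ``gamma`` varying
fastest (parameter names are iterated in sorted order), which is what makes the
final ``reshape`` line up with the ``(C, gamma)`` heatmap.

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, StratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

iris = load_iris()
X = StandardScaler().fit_transform(iris.data)
y = iris.target

C_range = np.logspace(-2, 10, 13)
gamma_range = np.logspace(-9, 3, 13)
param_grid = dict(gamma=gamma_range, C=C_range)

# n_splits replaces the older n_iter constructor argument of
# StratifiedShuffleSplit; the labels are passed to split() by GridSearchCV.
cv = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=42)
grid = GridSearchCV(SVC(), param_grid=param_grid, cv=cv)
grid.fit(X, y)

print("The best parameters are %s with a score of %0.2f"
      % (grid.best_params_, grid.best_score_))

# cv_results_ replaces grid_scores_; the mean test scores reshape into the
# same (len(C_range), len(gamma_range)) heatmap as in the example above.
scores = grid.cv_results_['mean_test_score'].reshape(len(C_range),
                                                     len(gamma_range))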