diff --git a/examples/svm/plot_iris.py b/examples/svm/plot_iris.py
index dbf918e8c63e2890245a7f63e33dc4514bee6fc9..2fe6d64daca497557363bd2b4db05265c9b64372 100644
--- a/examples/svm/plot_iris.py
+++ b/examples/svm/plot_iris.py
@@ -3,31 +3,57 @@
 Plot different SVM classifiers in the iris dataset
 ==================================================
 
-Comparison of different linear SVM classifiers on the iris dataset. It
-will plot the decision surface for four different SVM classifiers.
+Comparison of different linear SVM classifiers on a 2D projection of the iris
+dataset. We only consider the first 2 features of this dataset:
+
+- Sepal length
+- Sepal width
+
+This example shows how to plot the decision surface for four SVM classifiers
+with different kernels.
+
+The linear models ``LinearSVC()`` and ``SVC(kernel='linear')`` yield slightly
+different decision boundaries. This can be a consequence of the following
+differences:
+
+- ``LinearSVC`` minimizes the squared hinge loss while ``SVC`` minimizes the
+  regular hinge loss.
+
+- ``LinearSVC`` uses the One-vs-All (also known as One-vs-Rest) multiclass
+  reduction while ``SVC`` uses the One-vs-One multiclass reduction.
+
+Both linear models have linear decision boundaries (intersecting hyperplanes)
+while the non-linear kernel models (polynomial or Gaussian RBF) have more
+flexible non-linear decision boundaries with shapes that depend on the kind of
+kernel and its parameters.
+
+.. NOTE:: while plotting the decision function of classifiers for toy 2D
+   datasets can help get an intuitive understanding of their respective
+   expressive power, be aware that those intuitions don't always generalize to
+   more realistic high-dimensional problems.
 """
 print(__doc__)
 
 import numpy as np
-import pylab as pl
+import matplotlib.pyplot as plt
 from sklearn import svm, datasets
 
 # import some data to play with
 iris = datasets.load_iris()
 X = iris.data[:, :2]  # we only take the first two features. We could
                       # avoid this ugly slicing by using a two-dim dataset
-Y = iris.target
+y = iris.target
 
 h = .02  # step size in the mesh
 
 # we create an instance of SVM and fit our data. We do not scale our
 # data since we want to plot the support vectors
 C = 1.0  # SVM regularization parameter
-svc = svm.SVC(kernel='linear', C=C).fit(X, Y)
-rbf_svc = svm.SVC(kernel='rbf', gamma=0.7, C=C).fit(X, Y)
-poly_svc = svm.SVC(kernel='poly', degree=3, C=C).fit(X, Y)
-lin_svc = svm.LinearSVC(C=C).fit(X, Y)
+svc = svm.SVC(kernel='linear', C=C).fit(X, y)
+rbf_svc = svm.SVC(kernel='rbf', gamma=0.7, C=C).fit(X, y)
+poly_svc = svm.SVC(kernel='poly', degree=3, C=C).fit(X, y)
+lin_svc = svm.LinearSVC(C=C).fit(X, y)
 
 # create a mesh to plot in
 x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
 y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
@@ -37,31 +63,31 @@ xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
 
 # title for the plots
 titles = ['SVC with linear kernel',
+          'LinearSVC (linear kernel)',
           'SVC with RBF kernel',
-          'SVC with polynomial (degree 3) kernel',
-          'LinearSVC (linear kernel)']
+          'SVC with polynomial (degree 3) kernel']
 
 
-for i, clf in enumerate((svc, rbf_svc, poly_svc, lin_svc)):
+for i, clf in enumerate((svc, lin_svc, rbf_svc, poly_svc)):
     # Plot the decision boundary. For that, we will assign a color to each
     # point in the mesh [x_min, x_max]x[y_min, y_max].
-    pl.subplot(2, 2, i + 1)
-    pl.subplots_adjust(wspace=0.4, hspace=0.4)
+    plt.subplot(2, 2, i + 1)
+    plt.subplots_adjust(wspace=0.4, hspace=0.4)
 
     Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
 
     # Put the result into a color plot
     Z = Z.reshape(xx.shape)
-    pl.contourf(xx, yy, Z, cmap=pl.cm.Paired)
+    plt.contourf(xx, yy, Z, cmap=plt.cm.Paired, alpha=0.8)
 
     # Plot also the training points
-    pl.scatter(X[:, 0], X[:, 1], c=Y, cmap=pl.cm.Paired)
-    pl.xlabel('Sepal length')
-    pl.ylabel('Sepal width')
-    pl.xlim(xx.min(), xx.max())
-    pl.ylim(yy.min(), yy.max())
-    pl.xticks(())
-    pl.yticks(())
-    pl.title(titles[i])
-
-pl.show()
+    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
+    plt.xlabel('Sepal length')
+    plt.ylabel('Sepal width')
+    plt.xlim(xx.min(), xx.max())
+    plt.ylim(yy.min(), yy.max())
+    plt.xticks(())
+    plt.yticks(())
+    plt.title(titles[i])
+
+plt.show()
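
A note on the One-vs-One / One-vs-Rest claim in the new docstring: the
difference shows up directly in the width of ``decision_function``, since
``SVC`` trains one classifier per pair of classes while ``LinearSVC`` trains
one per class. The sketch below is not part of the patch and assumes a
scikit-learn recent enough to expose ``decision_function_shape``; it uses the
10-class digits data because with the 3 iris classes both layouts happen to
produce 3 columns, which would hide the distinction.

# Illustrative sketch only, not part of this patch.
from sklearn import svm, datasets

digits = datasets.load_digits()  # 10 classes, so OvO and OvR widths differ
X, y = digits.data, digits.target

ovo = svm.SVC(kernel='linear', decision_function_shape='ovo').fit(X, y)
ovr = svm.LinearSVC(max_iter=10000).fit(X, y)

print(ovo.decision_function(X[:1]).shape)  # (1, 45): 10 * 9 / 2 pairwise classifiers
print(ovr.decision_function(X[:1]).shape)  # (1, 10): one score per class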
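
Similarly, the hinge vs. squared hinge point can be checked on the same two
iris features the example plots. Again a sketch rather than part of the patch;
``loss='hinge'`` / ``loss='squared_hinge'`` are the current parameter values
(older releases spelled them ``loss='l1'`` / ``loss='l2'``):

# Sketch only, not part of this patch. LinearSVC defaults to the squared
# hinge loss; loss='hinge' is the loss SVC minimizes, and the fitted
# hyperplanes differ slightly between the two settings.
from sklearn import svm, datasets

iris = datasets.load_iris()
X, y = iris.data[:, :2], iris.target

squared = svm.LinearSVC(C=1.0, loss='squared_hinge', max_iter=100000).fit(X, y)
hinge = svm.LinearSVC(C=1.0, loss='hinge', max_iter=100000).fit(X, y)

print(squared.coef_)  # one hyperplane per class; weights differ between losses
print(hinge.coef_)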