diff --git a/examples/svm/svm_gui.py b/examples/svm/svm_gui.py index ebc929936310be9609885e34c1908cb39777ee2b..d2e2a160de9cd3771e26feb600c744b81a09224d 100644 --- a/examples/svm/svm_gui.py +++ b/examples/svm/svm_gui.py @@ -10,7 +10,7 @@ the decision region induced by different kernels and parameter settings. To create positive examples click the left mouse button; to create negative examples click the right button. -If all examples are from the same class, it uses a one-class svm. +If all examples are from the same class, it uses a one-class SVM. Requirements ------------ @@ -36,6 +36,7 @@ matplotlib.use('TkAgg') from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg from matplotlib.backends.backend_tkagg import NavigationToolbar2TkAgg from matplotlib.figure import Figure +from matplotlib.contour import ContourSet import Tkinter as Tk import sys @@ -48,6 +49,10 @@ x_min, x_max = -50, 50 class Model(object): + """The Model which hold the data. It implements the + observable in the observer pattern and notifies the + registered observers on change event. + """ def __init__(self): self.observers = [] self.surface = None @@ -56,10 +61,12 @@ class Model(object): self.surface_type = 0 def changed(self, event): + """Notify the observers. """ for observer in self.observers: observer.update(event, self) def add_observer(self, observer): + """Register an observer. """ self.observers.append(observer) def set_surface(self, surface): @@ -104,7 +111,7 @@ class Controller(object): x = np.arange(x_min, x_max + delta, delta) y = np.arange(y_min, y_max + delta, delta) X1, X2 = np.meshgrid(x, y) - Z = cls.predict_margin(np.c_[X1.ravel(), X2.ravel()]) + Z = cls.decision_function(np.c_[X1.ravel(), X2.ravel()]) Z = Z.reshape(X1.shape) return X1, X2, Z @@ -118,6 +125,11 @@ class Controller(object): class View(object): + """The view of the Model. This class implements the + Observer in the observer pattern. It is registered to + the Model and will be updated if the change event in + the model is triggered. + """ def __init__(self, root, controller): f = Figure() ax = f.add_subplot(111) @@ -137,7 +149,6 @@ class View(object): self.ax = ax self.canvas = canvas self.controller = controller - self.hascolormaps = False self.contours = [] self.c_labels = None self.plot_kernels() @@ -173,23 +184,33 @@ class View(object): self.plot_kernels() if event == "surface": + self.remove_surface() + self.plot_support_vectors(model.clf.support_vectors_) self.plot_decision_surface(model.surface, model.surface_type) self.canvas.draw() - def plot_decision_surface(self, surface, type): - X1, X2, Z = surface - + def remove_surface(self): + """Remove old decision surface.""" if len(self.contours) > 0: for contour in self.contours: - for lineset in contour.collections: - lineset.remove() + if isinstance(contour, ContourSet): + for lineset in contour.collections: + lineset.remove() + else: + contour.remove() self.contours = [] - if self.c_labels: - for label in self.c_labels: - label.remove() + def plot_support_vectors(self, support_vectors): + """Plot the support vectors by placing circles over the + corresponding data points and adds the circle collection + to the contours list.""" + cs = self.ax.scatter(support_vectors[:, 0], support_vectors[:, 1], + s=80, edgecolors="k", facecolors="none") + self.contours.append(cs) + def plot_decision_surface(self, surface, type): + X1, X2, Z = surface if type == 0: levels = [-1.0, 0.0, 1.0] linestyles = ['dashed', 'solid', 'dashed']