mirror of
https://github.com/donnemartin/data-science-ipython-notebooks.git
synced 2024-03-22 13:30:56 +08:00
342 lines
11 KiB
Python
342 lines
11 KiB
Python
"""
==========
Libsvm GUI
==========

A simple graphical frontend for Libsvm mainly intended for didactic
purposes. You can create data points by pointing and clicking and visualize
the decision region induced by different kernels and parameter settings.

To create positive examples click the left mouse button; to create
negative examples click the right button (or hold Shift while clicking).

If all examples are from the same class, it uses a one-class SVM.

"""
|
|
from __future__ import division, print_function

# Print the usage notes from the module docstring when the script starts.
print(__doc__)

# Author: Peter Prettenhofer <peter.prettenhofer@gmail.com>
#
# License: BSD 3 clause

import matplotlib
# Select the Tk backend before the backend_tkagg imports below; the figure
# is embedded in a Tkinter window by the View class.
matplotlib.use('TkAgg')

from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.backends.backend_tkagg import NavigationToolbar2TkAgg
from matplotlib.figure import Figure
from matplotlib.contour import ContourSet

# NOTE(review): ``Tkinter`` is the Python 2 module name (``tkinter`` on
# Python 3), and ``NavigationToolbar2TkAgg`` / ``sklearn.externals.six``
# only exist in older matplotlib / scikit-learn releases -- confirm the
# versions this script is meant to run against.
import Tkinter as Tk
import sys
import numpy as np

from sklearn import svm
from sklearn.datasets import dump_svmlight_file
from sklearn.externals.six.moves import xrange
|
|
|
|
# Fixed data-space extent of the plot. The view's axis limits and the grid
# on which the decision function is evaluated both use these bounds.
x_min, x_max = -50, 50
y_min, y_max = -50, 50
|
|
|
|
|
|
class Model(object):
    """The Model which holds the data.

    It implements the observable in the observer pattern and notifies the
    registered observers on change events.
    """

    def __init__(self):
        # Objects with an ``update(event, model)`` method, notified in
        # registration order by ``changed``.
        self.observers = []
        # Latest decision surface as an ``(X1, X2, Z)`` tuple, or None.
        self.surface = None
        # Training examples as ``(x, y, label)`` tuples.
        self.data = []
        # Most recently fitted estimator; assigned by the controller and read
        # by the view as ``model.clf`` (this used to be misspelled ``cls``).
        self.clf = None
        # Rendering mode of the surface: 0 = hyperplanes, 1 = filled surface.
        self.surface_type = 0

    def changed(self, event):
        """Notify every registered observer that *event* happened."""
        for observer in self.observers:
            observer.update(event, self)

    def add_observer(self, observer):
        """Register *observer* to be notified on change events."""
        self.observers.append(observer)

    def set_surface(self, surface):
        """Store the decision surface (an ``(X1, X2, Z)`` tuple)."""
        self.surface = surface

    def dump_svmlight_file(self, file):
        """Write the examples to *file* in svmlight format.

        The first two columns of each example are the features and the
        third is the label.

        Raises
        ------
        ValueError
            If there are no examples (slicing an empty array would otherwise
            fail with an opaque IndexError).
        """
        if not self.data:
            raise ValueError("No examples to dump.")
        data = np.array(self.data)
        X = data[:, 0:2]
        y = data[:, 2]
        dump_svmlight_file(X, y, file)
|
|
|
|
|
|
class Controller(object):
    """Translate user actions into model updates and (re)fit the SVM.

    The hyper-parameter variables ``complexity``, ``gamma``, ``coef0`` and
    ``degree`` are not created here; they are attached to this object by the
    control bar and read in ``fit``.
    """

    def __init__(self, model):
        self.model = model
        # Key into the kernel map used by ``fit`` (0 linear, 1 rbf, 2 poly).
        self.kernel = Tk.IntVar()
        # Rendering mode copied onto the model after each fit.
        self.surface_type = Tk.IntVar()
        # Whether or not a model has been fitted
        self.fitted = False

    def fit(self):
        """Fit an SVM to the current examples and publish its decision surface.

        A one-class SVM is used when all examples share one label, a
        two-class ``SVC`` otherwise. With no examples at all, a message is
        printed and nothing is fitted.
        """
        print("fit the model")
        if not self.model.data:
            # np.array([]) is 1-D, so the column slicing below would raise
            # IndexError inside the Tk callback.
            print("No examples to fit; add some points first.")
            return
        train = np.array(self.model.data)
        X = train[:, 0:2]
        y = train[:, 2]

        C = float(self.complexity.get())
        gamma = float(self.gamma.get())
        coef0 = float(self.coef0.get())
        degree = int(self.degree.get())
        kernel_map = {0: "linear", 1: "rbf", 2: "poly"}
        kernel = kernel_map[self.kernel.get()]
        if len(np.unique(y)) == 1:
            # Single class: fit without labels (C is not passed here).
            clf = svm.OneClassSVM(kernel=kernel, gamma=gamma, coef0=coef0,
                                  degree=degree)
            clf.fit(X)
        else:
            clf = svm.SVC(kernel=kernel, C=C, gamma=gamma, coef0=coef0,
                          degree=degree)
            clf.fit(X, y)
        if hasattr(clf, 'score'):
            print("Accuracy:", clf.score(X, y) * 100)
        X1, X2, Z = self.decision_surface(clf)
        self.model.clf = clf
        self.model.set_surface((X1, X2, Z))
        self.model.surface_type = self.surface_type.get()
        self.fitted = True
        self.model.changed("surface")

    def decision_surface(self, cls):
        """Evaluate ``cls.decision_function`` on a unit-step grid.

        The grid spans ``[x_min, x_max] x [y_min, y_max]``. Returns the mesh
        coordinates ``X1``, ``X2`` and the values ``Z``, all of the same shape.
        """
        delta = 1
        x = np.arange(x_min, x_max + delta, delta)
        y = np.arange(y_min, y_max + delta, delta)
        X1, X2 = np.meshgrid(x, y)
        Z = cls.decision_function(np.c_[X1.ravel(), X2.ravel()])
        Z = Z.reshape(X1.shape)
        return X1, X2, Z

    def clear_data(self):
        """Drop all examples, forget the fit and notify observers."""
        self.model.data = []
        self.fitted = False
        self.model.changed("clear")

    def add_example(self, x, y, label):
        """Append example ``(x, y, label)`` and refit if a model exists."""
        self.model.data.append((x, y, label))
        self.model.changed("example_added")

        # update decision surface if already fitted.
        self.refit()

    def refit(self):
        """Refit the model if already fitted. """
        if self.fitted:
            self.fit()
|
|
|
|
|
|
class View(object):
    """Matplotlib view embedded in a Tk window.

    Plots the training examples, the support vectors and the decision
    surface, and turns mouse clicks into new examples via the controller.
    It is meant to be registered as an observer of the model.
    """

    def __init__(self, root, controller):
        f = Figure()
        ax = f.add_subplot(111)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlim((x_min, x_max))
        ax.set_ylim((y_min, y_max))
        canvas = FigureCanvasTkAgg(f, master=root)
        # ``draw`` renders the figure; ``show`` was only an alias of it and is
        # not available in newer matplotlib releases.
        canvas.draw()
        canvas.get_tk_widget().pack(side=Tk.TOP, fill=Tk.BOTH, expand=1)
        canvas.mpl_connect('key_press_event', self.onkeypress)
        canvas.mpl_connect('key_release_event', self.onkeyrelease)
        canvas.mpl_connect('button_press_event', self.onclick)
        toolbar = NavigationToolbar2TkAgg(canvas, root)
        toolbar.update()
        # True while Shift is held; a Shift-click adds a negative example.
        self.shift_down = False
        self.controllbar = ControllBar(root, controller)
        self.f = f
        self.ax = ax
        self.canvas = canvas
        self.controller = controller
        # Artists belonging to the current fit, removed before redrawing.
        self.contours = []
        self.c_labels = None
        self.plot_kernels()

    def plot_kernels(self):
        """Write the kernel formulas below the plotting area."""
        # Raw strings: the LaTeX backslashes are not Python escape sequences.
        self.ax.text(-50, -60, r"Linear: $u^T v$")
        self.ax.text(-20, -60, r"RBF: $\exp (-\gamma \| u-v \|^2)$")
        self.ax.text(10, -60, r"Poly: $(\gamma \, u^T v + r)^d$")

    def onkeypress(self, event):
        """Remember that Shift is held down."""
        if event.key == "shift":
            self.shift_down = True

    def onkeyrelease(self, event):
        """Remember that Shift was released."""
        if event.key == "shift":
            self.shift_down = False

    def onclick(self, event):
        """Add a negative example on right- or Shift-click, positive on left-click.

        Clicks outside the axes (no data coordinates) are ignored.
        """
        # Compare against None: 0.0 is a valid coordinate on the axes.
        if event.xdata is not None and event.ydata is not None:
            if self.shift_down or event.button == 3:
                self.controller.add_example(event.xdata, event.ydata, -1)
            elif event.button == 1:
                self.controller.add_example(event.xdata, event.ydata, 1)

    def update_example(self, model, idx):
        """Plot ``model.data[idx]``: white marker for +1, black for -1.

        Raises
        ------
        ValueError
            If the example's label is neither 1 nor -1.
        """
        x, y, label = model.data[idx]
        if label == 1:
            color = 'w'
        elif label == -1:
            color = 'k'
        else:
            raise ValueError("unknown label: %r" % (label,))
        self.ax.plot([x], [y], "%so" % color, scalex=0.0, scaley=0.0)

    def update(self, event, model):
        """Observer callback: redraw according to the model *event*."""
        if event == "examples_loaded":
            for i in xrange(len(model.data)):
                self.update_example(model, i)

        if event == "example_added":
            self.update_example(model, -1)

        if event == "clear":
            self.ax.clear()
            self.ax.set_xticks([])
            self.ax.set_yticks([])
            # ``Axes.clear`` also resets the view limits; restore the fixed
            # extent so clicks map onto the same coordinate range as the
            # grid the decision surface is evaluated on.
            self.ax.set_xlim((x_min, x_max))
            self.ax.set_ylim((y_min, y_max))
            self.contours = []
            self.c_labels = None
            self.plot_kernels()

        if event == "surface":
            self.remove_surface()
            self.plot_support_vectors(model.clf.support_vectors_)
            self.plot_decision_surface(model.surface, model.surface_type)

        self.canvas.draw()

    def remove_surface(self):
        """Remove old decision surface."""
        for contour in self.contours:
            if isinstance(contour, ContourSet):
                for lineset in contour.collections:
                    lineset.remove()
            else:
                contour.remove()
        self.contours = []

    def plot_support_vectors(self, support_vectors):
        """Plot the support vectors by placing circles over the
        corresponding data points and adds the circle collection
        to the contours list."""
        cs = self.ax.scatter(support_vectors[:, 0], support_vectors[:, 1],
                             s=80, edgecolors="k", facecolors="none")
        self.contours.append(cs)

    def plot_decision_surface(self, surface, type):
        """Draw the ``(X1, X2, Z)`` decision surface.

        ``type`` 0 draws the -1/0/+1 level lines; ``type`` 1 draws a filled
        contour plus the zero level line. Any other value raises ValueError.
        """
        X1, X2, Z = surface
        if type == 0:
            levels = [-1.0, 0.0, 1.0]
            linestyles = ['dashed', 'solid', 'dashed']
            colors = 'k'
            self.contours.append(self.ax.contour(X1, X2, Z, levels,
                                                 colors=colors,
                                                 linestyles=linestyles))
        elif type == 1:
            self.contours.append(self.ax.contourf(X1, X2, Z, 10,
                                                  cmap=matplotlib.cm.bone,
                                                  origin='lower', alpha=0.85))
            self.contours.append(self.ax.contour(X1, X2, Z, [0.0], colors='k',
                                                 linestyles=['solid']))
        else:
            raise ValueError("surface type unknown")
|
|
|
|
|
|
class ControllBar(object):
    """Control panel placed next to the plot.

    From left to right it builds:

    - the kernel radio buttons,
    - the hyper-parameter entry fields,
    - the surface-type radio buttons,
    - the Fit/Clear buttons.

    As a side effect it stores one ``Tk.StringVar`` per hyper-parameter on
    *controller* (``complexity``, ``gamma``, ``degree``, ``coef0``).
    """

    def __init__(self, root, controller):
        panel = Tk.Frame(root)

        # Kernel choice; the radio value selects the kernel in the controller.
        kernel_frame = Tk.Frame(panel)
        for value, text in enumerate(("Linear", "RBF", "Poly")):
            Tk.Radiobutton(kernel_frame, text=text,
                           variable=controller.kernel, value=value,
                           command=controller.refit).pack(anchor=Tk.W)
        kernel_frame.pack(side=Tk.LEFT)

        # One labelled entry row per hyper-parameter, in display order:
        # (controller attribute, caption, initial text).
        entries = (
            ("complexity", "C:", "1.0"),
            ("gamma", "gamma:", "0.01"),
            ("degree", "degree:", "3"),
            ("coef0", "coef0:", "0"),
        )
        entry_frame = Tk.Frame(panel)
        for attr, caption, initial in entries:
            setattr(controller, attr, Tk.StringVar())
            text_var = getattr(controller, attr)
            text_var.set(initial)
            row = Tk.Frame(entry_frame)
            Tk.Label(row, text=caption, anchor="e", width=7).pack(side=Tk.LEFT)
            Tk.Entry(row, width=6, textvariable=text_var).pack(side=Tk.LEFT)
            row.pack()
        entry_frame.pack(side=Tk.LEFT)

        # How the decision function is rendered (0 hyperplanes, 1 surface).
        surface_frame = Tk.Frame(panel)
        for value, text in enumerate(("Hyperplanes", "Surface")):
            Tk.Radiobutton(surface_frame, text=text,
                           variable=controller.surface_type, value=value,
                           command=controller.refit).pack(anchor=Tk.W)
        surface_frame.pack(side=Tk.LEFT)

        Tk.Button(panel, text='Fit', width=5, command=controller.fit).pack()
        panel.pack(side=Tk.LEFT)
        Tk.Button(panel, text='Clear', width=5,
                  command=controller.clear_data).pack(side=Tk.LEFT)
|
|
|
|
|
|
def get_parser():
    """Build the command-line parser for this script.

    The only option is ``--output``, the path to which ``main`` dumps the
    collected examples in svmlight format.
    """
    from optparse import OptionParser

    parser = OptionParser()
    parser.add_option("--output", dest="output", type="str", action="store",
                      help="Path where to dump data.")
    return parser
|
|
|
|
|
|
def main(argv):
    """Run the GUI and optionally dump the examples on exit.

    Parameters
    ----------
    argv : list of str
        Full argument vector; ``argv[0]`` (the program name) is skipped.
        With ``--output PATH``, the examples drawn in the GUI are written to
        PATH in svmlight format once the window is closed.
    """
    op = get_parser()
    opts, args = op.parse_args(argv[1:])
    root = Tk.Tk()
    model = Model()
    controller = Controller(model)
    root.wm_title("Scikit-learn Libsvm GUI")
    view = View(root, controller)
    model.add_observer(view)
    Tk.mainloop()

    if opts.output:
        if model.data:
            model.dump_svmlight_file(opts.output)
        else:
            # Dumping an empty data set would fail after the window closed.
            print("No examples were created; nothing written to %s"
                  % opts.output)


if __name__ == "__main__":
    main(sys.argv)
|