forked from leriomaggio/python-ml-course
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_interactive_tree.py
More file actions
74 lines (61 loc) · 2.33 KB
/
Copy pathplot_interactive_tree.py
File metadata and controls
74 lines (61 loc) · 2.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.tree import DecisionTreeClassifier
from sklearn.externals.six import StringIO # doctest: +SKIP
from sklearn.tree import export_graphviz
from scipy import ndimage
try:
from scipy.misc import imread
except ImportError:
from scipy.ndimage import imread
import re
X, y = make_blobs(centers=[[0, 0], [1, 1]], random_state=61526, n_samples=50)
def tree_image(tree, fout=None):
try:
import pydot
except ImportError:
# make a hacky white plot
x = np.ones((10, 10))
x[0, 0] = 0
return x
dot_data = StringIO()
export_graphviz(tree, out_file=dot_data)
data = re.sub(r"gini = 0\.[0-9]+\\n", "", dot_data.getvalue())
data = re.sub(r"samples = [0-9]+\\n", "", data)
data = re.sub(r"\\nsamples = [0-9]+", "", data)
graph = pydot.graph_from_dot_data(data)
if fout is None:
fout = "tmp.png"
graph.write_png(fout)
return imread(fout)
def plot_tree(max_depth=1):
fig, ax = plt.subplots(1, 2, figsize=(15, 7))
h = 0.02
x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
if max_depth != 0:
tree = DecisionTreeClassifier(max_depth=max_depth, random_state=1).fit(X, y)
Z = tree.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1]
Z = Z.reshape(xx.shape)
faces = tree.tree_.apply(np.c_[xx.ravel(), yy.ravel()].astype(np.float32))
faces = faces.reshape(xx.shape)
border = ndimage.laplace(faces) != 0
ax[0].contourf(xx, yy, Z, alpha=.4)
ax[0].scatter(xx[border], yy[border], marker='.', s=1)
ax[0].set_title("max_depth = %d" % max_depth)
ax[1].imshow(tree_image(tree))
ax[1].axis("off")
else:
ax[0].set_title("data set")
ax[1].set_visible(False)
ax[0].scatter(X[:, 0], X[:, 1], c=np.array(['b', 'r'])[y], s=60)
ax[0].set_xlim(x_min, x_max)
ax[0].set_ylim(y_min, y_max)
ax[0].set_xticks(())
ax[0].set_yticks(())
def plot_tree_interactive():
from IPython.html.widgets import interactive, IntSlider
slider = IntSlider(min=0, max=8, step=1, value=0)
return interactive(plot_tree, max_depth=slider)