-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathutils.py
86 lines (66 loc) · 2.58 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import subprocess
from itertools import cycle
from tempfile import mkstemp

import matplotlib.pyplot as plt
import numpy as np
from IPython.core.display import HTML
from scipy import interp
from sklearn.metrics import roc_curve, auc
from sklearn.tree import export_graphviz
def plot_surface(clf, X, y,
                 xlim=(140, 210), ylim=(21, 29), n_steps=250,
                 subplot=None, show=True):
    """Draw a classifier's score surface over a 2-D grid with the samples on top.

    A regular ``n_steps`` x ``n_steps`` grid spanning ``xlim`` and ``ylim`` is
    scored with ``clf.decision_function`` when the classifier has one, and
    otherwise with column 1 of ``clf.predict_proba``. The scores are drawn as
    filled contours and the first two columns of ``X`` are scattered on top,
    coloured by ``y``.

    Parameters
    ----------
    clf : object
        Classifier exposing ``decision_function`` or ``predict_proba``.
    X : array
        Samples; columns 0 and 1 are used as the plot coordinates.
    y : array
        Values passed to ``plt.scatter`` as the point colours.
    xlim, ylim : pair of numbers
        Grid extent and axis limits for each direction.
    n_steps : int
        Number of grid points along each axis.
    subplot : sequence or None
        Arguments unpacked into ``plt.subplot``; when ``None`` a new figure
        is created instead.
    show : bool
        Whether to call ``plt.show()`` once drawing is done.
    """
    # Draw into the requested subplot slot, or start a fresh figure.
    if subplot is not None:
        plt.subplot(*subplot)
    else:
        plt.figure()

    # Build the evaluation grid and flatten it into (x, y) rows for scoring.
    x_axis = np.linspace(xlim[0], xlim[1], n_steps)
    y_axis = np.linspace(ylim[0], ylim[1], n_steps)
    xx, yy = np.meshgrid(x_axis, y_axis)
    grid_points = np.c_[xx.ravel(), yy.ravel()]

    # decision_function takes priority; predict_proba column 1 is the fallback.
    if hasattr(clf, "decision_function"):
        scores = clf.decision_function(grid_points)
    else:
        scores = clf.predict_proba(grid_points)[:, 1]
    surface = scores.reshape(xx.shape)

    plt.contourf(xx, yy, surface, alpha=0.8, cmap=plt.cm.RdBu_r)
    plt.scatter(X[:, 0], X[:, 1], c=y)
    plt.xlim(*xlim)
    plt.ylim(*ylim)

    if show:
        plt.show()
def draw_tree(clf, feature_names, **kwargs):
    """Render a tree as SVG through Graphviz and wrap it for notebook display.

    The tree is written to a temporary ``.dot`` file with ``export_graphviz``,
    converted to SVG by the external ``dot`` program, and the SVG markup is
    returned inside an IPython ``HTML`` object. Both temporary files are
    removed before returning.

    Parameters
    ----------
    clf : object
        Tree estimator, passed straight to ``export_graphviz``.
    feature_names : sequence
        Forwarded to ``export_graphviz`` as ``feature_names``.
    **kwargs
        Extra keyword arguments forwarded to ``export_graphviz``.

    Returns
    -------
    IPython.core.display.HTML
        Object wrapping the SVG markup produced by ``dot``.

    Raises
    ------
    subprocess.CalledProcessError
        If ``dot`` exits with a non-zero status.
    OSError
        If the ``dot`` executable cannot be started.
    """
    temp_paths = []
    try:
        for suffix in ('.dot', '.svg'):
            # mkstemp returns an already-open OS-level descriptor; only the
            # path is needed, so close it right away instead of leaking it.
            fd, path = mkstemp(suffix=suffix)
            temp_paths.append(path)
            os.close(fd)
        dot_path, svg_path = temp_paths

        export_graphviz(clf, out_file=dot_path,
                        feature_names=feature_names,
                        **kwargs)
        subprocess.check_call(["dot", "-Tsvg", dot_path, "-o", svg_path])

        # Read through a context manager so the handle is closed promptly.
        with open(svg_path) as svg_file:
            svg_markup = svg_file.read()
    finally:
        # The temp paths are never exposed to the caller, so nothing can rely
        # on them afterwards. Cleanup is best-effort: a failed delete must not
        # mask the real result or the real error.
        for path in temp_paths:
            try:
                os.remove(path)
            except OSError:
                pass
    return HTML(svg_markup)
# Taken from http://scikit-learn.org/stable/auto_examples/model_selection/plot_roc_crossval.html
def draw_roc_curve(classifier, cv, X, y):
    """Plot one ROC curve per cross-validation fold plus their mean.

    For every ``(train, test)`` split the classifier is refit on the training
    rows and scored on the test rows using column 1 of ``predict_proba``.
    Each fold's ROC curve is plotted, then the curves are averaged on a common
    100-point false-positive-rate grid and the mean curve is plotted with its
    AUC. The figure is shown with ``plt.show()``.

    Parameters
    ----------
    classifier : object
        Estimator with ``fit`` (returning the fitted estimator) and
        ``predict_proba``. It is refit in place on every fold.
    cv : iterable or splitter
        Either an iterable of ``(train, test)`` index pairs (a list, a
        generator, or an old-style CV object), or an object exposing
        ``split(X, y)`` that yields such pairs.
    X, y : arrays
        Data and labels; both are indexed with the train/test indices.

    Raises
    ------
    ValueError
        If ``cv`` yields no folds.
    """
    # Newer scikit-learn splitters are not iterable themselves; they produce
    # the index pairs through split(X, y).
    folds = cv.split(X, y) if hasattr(cv, "split") else cv

    mean_fpr = np.linspace(0, 1, 100)
    tpr_sum = np.zeros_like(mean_fpr)
    lw = 2
    n_folds = 0

    for i, (train, test) in enumerate(folds):
        probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test])
        # Compute the ROC curve and the area under the curve for this fold.
        fpr, tpr, _ = roc_curve(y[test], probas_[:, 1])
        # Resample onto the shared FPR grid so folds can be averaged.
        tpr_sum += np.interp(mean_fpr, fpr, tpr)
        # Pin the running sum at FPR = 0 so the mean curve starts at the origin.
        tpr_sum[0] = 0.0
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, lw=1,
                 label='ROC fold %d (area = %0.2f)' % (i, roc_auc))
        # Count folds here: generators and splitters have no len().
        n_folds = i + 1

    if n_folds == 0:
        raise ValueError("cv produced no (train, test) splits")

    plt.plot([0, 1], [0, 1], linestyle='--', lw=lw, color='k',
             label='Luck')

    mean_tpr = tpr_sum / n_folds
    # Force the mean curve to end at (1, 1).
    mean_tpr[-1] = 1.0
    mean_auc = auc(mean_fpr, mean_tpr)
    plt.plot(mean_fpr, mean_tpr, color='g', linestyle='--',
             label='Mean ROC (area = %0.2f)' % mean_auc, lw=lw)

    plt.xlim([-0.05, 1.05])
    plt.ylim([-0.05, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic')
    plt.legend(loc="lower right")
    plt.show()