-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtrain_svm.py
executable file
·116 lines (93 loc) · 3.7 KB
/
train_svm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/env python
import itertools
import pickle

import matplotlib.pyplot as plt
import numpy as np
from sklearn import cross_validation
from sklearn import metrics
from sklearn import svm
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder, StandardScaler
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Plot a confusion matrix as an annotated heat map on the current figure.

    Normalization can be applied by setting `normalize=True`: each row is
    divided by its sum, so a cell shows the fraction of that row's samples.
    Nothing is printed; call ``plt.show()`` (or save the figure) to view it.

    :param cm: 2-D confusion matrix; rows are labelled 'True label' and
        columns 'Predicted label' (the layout produced by
        sklearn.metrics.confusion_matrix).
    :param classes: tick labels for the rows/columns, in matrix order.
    :param normalize: plot row-normalized fractions instead of raw values.
    :param title: title for the plot.
    :param cmap: matplotlib colormap passed to ``imshow``.
    """
    cm = np.asarray(cm)
    if normalize:
        row_sums = cm.sum(axis=1)[:, np.newaxis].astype('float')
        # A row with no samples would give 0/0 = NaN cells (and a NaN colour
        # threshold below); divide such rows by 1 so they stay all zeros.
        row_sums[row_sums == 0] = 1.0
        cm = cm.astype('float') / row_sums

    # Raw integer counts are annotated as whole numbers; fractions and any
    # float-valued matrix keep two decimal places.
    if normalize or not np.issubdtype(cm.dtype, np.integer):
        cell_fmt = '.2f'
    else:
        cell_fmt = 'd'

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Cells above half the maximum get white text so the annotation stays
    # legible on the darker end of the colormap (dark = high for Blues).
    thresh = cm.max() / 2. if cm.size else 0.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], cell_fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
# Load training data from disk.
# NOTE(security): pickle.load can execute arbitrary code -- only load a
# training_set.sav produced by a trusted source.
with open('training_set.sav', 'rb') as training_file:
    training_set = pickle.load(training_file)

# Format the features and labels for use with scikit-learn. Each item is
# indexed as (features, label); samples whose feature vector contains any NaN
# are dropped.
feature_list = []
label_list = []
for item in training_set:
    if not np.isnan(item[0]).any():
        feature_list.append(item[0])
        label_list.append(item[1])

print('Features in Training Set: {}'.format(len(training_set)))
print('Invalid Features in Training set: {}'.format(len(training_set)-len(feature_list)))

if not feature_list:
    raise ValueError('training_set.sav contains no NaN-free feature vectors; '
                     'nothing to train on.')

X = np.array(feature_list)
# Fit a per-column scaler on every valid sample; this is the scaler saved
# alongside the final classifier below.
X_scaler = StandardScaler().fit(X)
# Apply the scaler to X
X_train = X_scaler.transform(X)
y_train = np.array(label_list)

# Convert label strings to numerical encoding
encoder = LabelEncoder()
y_train = encoder.fit_transform(y_train)

# Create classifier
clf = svm.SVC(kernel='linear')

# Set up 5-fold cross-validation
kf = cross_validation.KFold(len(X_train),
                            n_folds=5,
                            shuffle=True,
                            random_state=1)

# Cross-validate a scaler+SVC pipeline on the *unscaled* features so the
# scaler is re-fit on each training fold only. Scoring `clf` on X_train would
# use scaling statistics computed from the held-out folds too, leaking them
# into training and inflating the CV estimate.
cv_model = make_pipeline(StandardScaler(), svm.SVC(kernel='linear'))

# Perform cross-validation
scores = cross_validation.cross_val_score(cv=kf,
                                          estimator=cv_model,
                                          X=X,
                                          y=y_train,
                                          scoring='accuracy')
print('Scores: ' + str(scores))
print('Accuracy: %0.2f (+/- %0.2f)' % (scores.mean(), 2*scores.std()))

# Gather out-of-fold predictions (same folds as above)
predictions = cross_validation.cross_val_predict(cv=kf,
                                                 estimator=cv_model,
                                                 X=X,
                                                 y=y_train)

accuracy_score = metrics.accuracy_score(y_train, predictions)
print('accuracy score: '+str(accuracy_score))

confusion_matrix = metrics.confusion_matrix(y_train, predictions)
class_names = encoder.classes_.tolist()

# Train the final classifier on all (scaled) data
clf.fit(X=X_train, y=y_train)
model = {'classifier': clf, 'classes': encoder.classes_, 'scaler': X_scaler}

# Save classifier to disk; the context manager guarantees the file is
# flushed and closed.
with open('model.sav', 'wb') as model_file:
    pickle.dump(model, model_file)

# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(confusion_matrix, classes=class_names,
                      title='Confusion matrix, without normalization')

# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(confusion_matrix, classes=class_names, normalize=True,
                      title='Normalized confusion matrix')

plt.show()