-
Notifications
You must be signed in to change notification settings - Fork 0
/
plotdata.py
38 lines (32 loc) · 1.39 KB
/
plotdata.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import matplotlib.pyplot as plt
import numpy as np
def _clip_boundary(w, low=-1.0, high=1.0):
    '''
    Clip the line w[0] + w[1]*x + w[2]*y = 0 to the square [low, high]^2.

    INPUT:  w: line parameters; any array-like whose first three flattened
               entries are (bias, x weight, y weight). Column vectors are
               accepted because the input is flattened to plain floats.
            low, high: bounds of the square plot region on both axes.
    OUTPUT: (xs, ys) arrays describing the visible segment, or None when the
            parameters do not define a line or the line misses the region.
    '''
    w0, w1, w2 = np.asarray(w, dtype=float).ravel()[:3]

    if w2 == 0:
        # No y dependence: either not a line at all, or a vertical line that
        # cannot be written as y = slope * x + intercept.
        if w1 == 0:
            return None
        x0 = -w0 / w1
        if not low <= x0 <= high:
            return None
        return np.array([x0, x0]), np.array([low, high])

    slope, intercept = -w1 / w2, -w0 / w2
    if slope == 0:
        # Horizontal line: visible across the whole x range or not at all.
        if not low <= intercept <= high:
            return None
        x_lo, x_hi = low, high
    else:
        # x coordinates where the line crosses y = high and y = low; the
        # visible part lies between them, further limited to [low, high].
        x_top = (high - intercept) / slope
        x_bottom = (low - intercept) / slope
        x_lo = max(min(x_top, x_bottom), low)
        x_hi = min(max(x_top, x_bottom), high)
        if x_lo >= x_hi:
            # The line does not pass through the plot region.
            return None

    # linspace avoids the zero/negative float step pitfalls of arange and
    # always includes both end points of the segment.
    xs = np.linspace(x_lo, x_hi, 101)
    return xs, slope * xs + intercept


def plotdata(X, y, wf, wg, desc):
    '''
    PLOTDATA Plot data set.
    Draws the samples (label 1 with red marker faces, label -1 with green
    ones), the true decision boundary as a solid blue line and the learnt
    boundary as a dashed blue line, then shows the figure.

    INPUT:  X: sample features, P-by-N matrix (only P == 2 is supported).
            y: sample labels, 1-by-N row vector with entries 1 / -1.
            wf: true target function parameters, (P+1)-by-1 column vector.
            wg: learnt target function parameters, (P+1)-by-1 column vector.
            desc: title of figure.

    Boundaries are clipped to the square [-1, 1] x [-1, 1]; a boundary that
    is degenerate or does not cross that square is not drawn.
    '''
    if X.shape[0] != 2:
        print('Here we only support 2-d X data')
        return

    # Flatten the labels once so they can be used as a column mask on X.
    labels = np.asarray(y).ravel()
    for label, face in ((1, 'r'), (-1, 'g')):
        mask = labels == label
        plt.plot(X[0, mask], X[1, mask], 'o', markerfacecolor=face,
                 markersize=10)

    # Solid line: true target function; dashed line: learnt function.
    for weights, style in ((wf, '-'), (wg, '--')):
        segment = _clip_boundary(weights)
        if segment is None:
            continue
        xs, ys = segment
        plt.plot(xs, ys, color='b', linewidth=2, linestyle=style)

    plt.title(desc)
    plt.show()