-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot.py
96 lines (84 loc) · 3.16 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from matplotlib.lines import lineStyles
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
# Directory, relative to the current working directory, from which the result
# text files are read and into which the generated figures are written.
RESULT_PATH = "./results"
def plot_loss_acc(num_clients=(4, 8, 10, 15, 20), result_path=None):
    """Plot loss and accuracy curves for several numbers of clients.

    Reading: for each client count ``m`` in *num_clients*, the file
    ``results_{m}_partial_multi.txt`` is read from *result_path*. Its first
    (header) line is skipped. Every remaining non-blank line is parsed as a
    comma-separated record: the 2nd field holds the accuracy and the 3rd
    field holds the loss, each written as ``name:value``.

    Output: one curve per client count is drawn on a loss figure and on an
    accuracy figure. The figures are saved to *result_path* as
    ``loss_different_M.png`` and ``accuracy_different_M.png`` and shown with
    ``plt.show()``.

    Args:
        num_clients: Client counts whose result files are plotted. Defaults
            to the counts the script originally hard-coded.
        result_path: Directory holding the result files and receiving the
            figures. Defaults to the module-level ``RESULT_PATH``.

    Raises:
        OSError: If a result file cannot be opened.
        IndexError, ValueError: If a data line does not have the expected
            ``..., name:acc, name:loss, ...`` layout.
    """
    if result_path is None:
        result_path = RESULT_PATH

    # One list of per-epoch values per client count, in the order of num_clients.
    losses = []
    accuracies = []
    for m in num_clients:
        file_path = os.path.join(result_path, f'results_{m}_partial_multi.txt')
        client_loss = []
        client_acc = []
        with open(file_path, 'r', encoding='utf-8') as f:
            next(f, None)  # skip the header line
            for line in f:
                # Tolerate blank lines (e.g. an extra newline at the end of
                # the file) instead of failing on the field split below.
                if not line.strip():
                    continue
                fields = line.split(',')
                client_acc.append(float(fields[1].split(':')[1]))
                client_loss.append(float(fields[2].split(':')[1]))
        losses.append(client_loss)
        accuracies.append(client_acc)

    # Loss figure (dashed curves).
    plt.figure()
    for m, client_loss in zip(num_clients, losses):
        plt.plot(client_loss, label=f'{m} clients', linestyle='--')
    plt.title('Loss for different number of clients')
    plt.xlabel('Epoch')
    plt.ylabel('Loss (Cross Entropy)')
    plt.legend()
    plt.savefig(os.path.join(result_path, 'loss_different_M.png'))
    plt.show()

    # Accuracy figure.
    plt.figure()
    for m, client_acc in zip(num_clients, accuracies):
        plt.plot(client_acc, label=f'{m} clients')
    plt.title('Accuracy for different number of clients')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy for global model')
    plt.legend()
    plt.savefig(os.path.join(result_path, 'accuracy_different_M.png'))
    plt.show()
def plot_for_stage3(num_clients=(5, 10, 15, 20), num_rounds=10, result_path=None):
    """Plot per-round loss and accuracy from the socket experiment results.

    Reading: ``results_socket.txt`` in *result_path* is scanned line by line.
    A line containing ``num_clients`` starts a section; the integer after its
    last ``=`` is the client count. The next *num_rounds* lines are that
    section's round records. In each record, the 2nd comma-separated field
    holds the loss and the 3rd holds the accuracy, with the value after the
    field's last ``:``. Other lines outside sections are ignored.

    Output: one curve per client count in *num_clients* is drawn on a loss
    figure and on an accuracy figure. They are saved to *result_path* as
    ``loss_socket.png`` and ``accuracy_socket.png``.

    Args:
        num_clients: Client counts to plot, in legend order. Sections for
            other client counts are parsed but not plotted.
        num_rounds: Number of round lines following each ``num_clients``
            header line.
        result_path: Directory holding ``results_socket.txt`` and receiving
            the figures. Defaults to the module-level ``RESULT_PATH``.

    Raises:
        OSError: If the result file cannot be opened.
        ValueError: If a section has fewer than *num_rounds* round lines, or
            a header/record line cannot be parsed as numbers.
        IndexError: If a round record has fewer than three comma-separated
            fields.
    """
    if result_path is None:
        result_path = RESULT_PATH

    # Pre-seed the requested counts so each gets a (possibly empty) curve.
    results = {num: {'loss': [], 'acc': []} for num in num_clients}
    with open(os.path.join(result_path, 'results_socket.txt'), 'r', encoding='utf-8') as f:
        lines = f.readlines()

    i = 0
    while i < len(lines):
        header = lines[i].strip()
        if 'num_clients' not in header:
            i += 1
            continue
        num = int(header.split('=')[-1].strip())
        round_lines = lines[i + 1:i + 1 + num_rounds]
        if len(round_lines) < num_rounds:
            raise ValueError(
                f'results_socket.txt: section for num_clients={num} starting at '
                f'line {i + 1} has {len(round_lines)} round lines, expected {num_rounds}'
            )
        # Unrequested client counts are still consumed so the scan stays in
        # sync with the section layout; they are just never plotted.
        entry = results.setdefault(num, {'loss': [], 'acc': []})
        for round_line in round_lines:
            fields = round_line.strip().split(',')
            entry['loss'].append(float(fields[1].split(':')[-1].strip()))
            entry['acc'].append(float(fields[2].split(':')[-1].strip()))
        i += num_rounds + 1

    # Loss figure (dashed curves). The x-axis is per round, as for accuracy.
    plt.figure()
    for num in num_clients:
        plt.plot(results[num]['loss'], label=f'{num} clients', linestyle='--')
    plt.title('Loss for different number of clients')
    plt.xlabel('Round')
    plt.ylabel('Loss with different client number')
    # The whole y-axis (ticks and the label above) is hidden on this figure;
    # presumably only the shape of the curves is meant to be compared.
    plt.gca().get_yaxis().set_visible(False)
    plt.legend()
    plt.savefig(os.path.join(result_path, 'loss_socket.png'))

    # Accuracy figure.
    plt.figure()
    for num in num_clients:
        plt.plot(results[num]['acc'], label=f'{num} clients')
    plt.title('Accuracy for different number of clients')
    plt.xlabel('Round')
    plt.ylabel('Accuracy with different client number')
    plt.legend()
    plt.savefig(os.path.join(result_path, 'accuracy_socket.png'))
if __name__ == '__main__':
    # By default only the stage-3 (socket) figures are generated; add a call
    # to plot_loss_acc() here to regenerate the partial-multi figures too.
    plot_for_stage3()