-
Notifications
You must be signed in to change notification settings - Fork 0
/
analyze_divergence.py
62 lines (53 loc) · 2.39 KB
/
analyze_divergence.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
'''
File: /analyze_divergence.py
Project: learning-hive
Created Date: Friday April 7th 2023
Author: Long Le (vlongle@seas.upenn.edu)
Copyright (c) 2023 Long Le
'''
import re
from shell.utils.metric import DivergenceMetric
from shell.utils.record import Record
import os
import pandas as pd
# Root of the experiment tree, laid out as
# <result_dir>/<job>/<dataset>/<algo>/<seed>/<agent_id>/
result_dir = "experiment_results/jorge_setting_fedavg"
# Regex searched against each agent's save_dir; only matching agents are
# loaded. The default r".*" matches every directory.
pattern = r".*"

# Entries under a seed directory that are not per-agent result directories.
_SKIPPED_AGENT_DIRS = {"hydra_out", "agent_69420"}
# Keys over which avg_params is summarized; averaging happens across the
# remaining dimensions (seed and agent_id).
_GROUP_KEYS = ["task_id", "communication_round", "epoch",
               "algo", "dataset", "use_contrastive"]
# Columns kept from each agent's DivergenceMetric dataframe.
_KEPT_COLUMNS = ["task_id", "communication_round", "epoch", "avg_params"]


def _subdirs(path):
    """Return the names of the immediate subdirectories of *path*.

    Plain files (stray logs, ``.DS_Store``, ...) are skipped so the nested
    walk does not fail with ``NotADirectoryError``.
    """
    return [name for name in os.listdir(path)
            if os.path.isdir(os.path.join(path, name))]


def _iter_agent_dirs(root):
    """Yield ``(job_name, dataset_name, algo, seed, agent_id, save_dir)``.

    One tuple is produced for every directory at depth five under *root*
    (``<job>/<dataset>/<algo>/<seed>/<agent_id>``), in ``os.listdir`` order.
    """
    for job_name in _subdirs(root):
        job_dir = os.path.join(root, job_name)
        for dataset_name in _subdirs(job_dir):
            dataset_dir = os.path.join(job_dir, dataset_name)
            for algo in _subdirs(dataset_dir):
                algo_dir = os.path.join(dataset_dir, algo)
                for seed in _subdirs(algo_dir):
                    seed_dir = os.path.join(algo_dir, seed)
                    for agent_id in _subdirs(seed_dir):
                        yield (job_name, dataset_name, algo, seed, agent_id,
                               os.path.join(seed_dir, agent_id))


def _load_agent_divergence(save_dir, *, seed, agent_id, algo, dataset_name,
                           use_contrastive):
    """Load one agent's rows with ``info == 'before'``, tagged with run labels.

    Keeps ``task_id``, ``communication_round``, ``epoch`` and ``avg_params``
    from ``DivergenceMetric(save_dir).df``. It then appends ``seed``,
    ``agent_id``, ``algo``, ``dataset`` and ``use_contrastive`` columns.
    """
    df = DivergenceMetric(save_dir).df
    df = df[df["info"] == "before"]
    # .assign builds a new frame instead of writing into a filtered slice,
    # which avoids SettingWithCopyWarning.
    return df[_KEPT_COLUMNS].assign(
        seed=seed,
        agent_id=agent_id,
        algo=algo,
        dataset=dataset_name,
        use_contrastive=use_contrastive,
    )


def _collect_divergence(root, path_pattern):
    """Stack the divergence rows of every matching agent under *root*.

    Raises:
        SystemExit: if no agent directory produced any data, so the script
            stops with a clear message instead of a KeyError in groupby.
    """
    matcher = re.compile(path_pattern)  # compile once, not per directory
    frames = []
    for job_name, dataset_name, algo, seed, agent_id, save_dir in \
            _iter_agent_dirs(root):
        if agent_id in _SKIPPED_AGENT_DIRS:
            continue
        if not matcher.search(save_dir):
            continue
        print(save_dir)
        frames.append(_load_agent_divergence(
            save_dir,
            seed=seed,
            agent_id=agent_id,
            algo=algo,
            dataset_name=dataset_name,
            # Jobs whose directory name mentions "contrastive" are flagged.
            use_contrastive="contrastive" in job_name,
        ))
    if not frames:
        raise SystemExit(
            f"No divergence records found under {root!r} "
            f"matching pattern {path_pattern!r}")
    # A single concat over the collected list is linear, unlike growing a
    # DataFrame inside the loop.
    return pd.concat(frames)


def _summarize(df):
    """Aggregate avg_params over seeds and agents within each _GROUP_KEYS group.

    Produces the mean (``avg_params``), the standard error
    (``avg_params_stderr``) and the standard deviation (``avg_params_std``).
    """
    return df.groupby(_GROUP_KEYS).agg(
        avg_params=("avg_params", "mean"),
        avg_params_stderr=("avg_params", "sem"),
        avg_params_std=("avg_params", "std"),
    ).reset_index()


concat_df = _summarize(_collect_divergence(result_dir, pattern))
# Written next to the result directory, e.g.
# experiment_results/jorge_setting_fedavg_divergence.csv
concat_df.to_csv(f"{result_dir}_divergence.csv", index=False)