-
Notifications
You must be signed in to change notification settings - Fork 0
/
analyze.py
82 lines (67 loc) · 2.74 KB
/
analyze.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
'''
File: /analyze.py
Project: learning-hive
Created Date: Monday March 20th 2023
Author: Long Le (vlongle@seas.upenn.edu)
Copyright (c) 2023 Long Le
'''
from shell.utils.record import Record
from shell.utils.metric import Metric
import re
import os
def get_save_dirs(result_dir):
    """Collect one entry per agent directory found under ``result_dir``.

    The walked layout is
    ``result_dir/<job_name>/<dataset_name>/<algo>/<seed>/<agent_id>``.
    Agent-level entries named ``hydra_out`` or ``agent_69420`` are skipped.
    Non-directory entries at the job/dataset/algo/seed levels (e.g. stray
    files such as a previously written ``.csv``) are ignored instead of
    making ``os.listdir`` raise ``NotADirectoryError``.

    Args:
        result_dir: root directory to walk.

    Returns:
        A list of dicts with keys ``path``, ``dataset``, ``algo``,
        ``use_contrastive`` (True when the job name contains the substring
        "contrastive"), ``seed`` and ``agent_id``, in that key order.
        Directory listings are sorted, so the list order is deterministic.
    """
    skipped_agents = {"hydra_out", "agent_69420"}

    def _subdirs(parent):
        # Only descend into directories; sort so output order does not
        # depend on the filesystem's listing order.
        return sorted(
            entry for entry in os.listdir(parent)
            if os.path.isdir(os.path.join(parent, entry))
        )

    save_dirs = []
    for job_name in _subdirs(result_dir):
        use_contrastive = "contrastive" in job_name
        job_path = os.path.join(result_dir, job_name)
        for dataset_name in _subdirs(job_path):
            dataset_path = os.path.join(job_path, dataset_name)
            for algo in _subdirs(dataset_path):
                algo_path = os.path.join(dataset_path, algo)
                for seed in _subdirs(algo_path):
                    seed_path = os.path.join(algo_path, seed)
                    for agent_id in sorted(os.listdir(seed_path)):
                        if agent_id in skipped_agents:
                            continue
                        save_dirs.append({
                            "path": os.path.join(seed_path, agent_id),
                            "dataset": dataset_name,
                            "algo": algo,
                            "use_contrastive": use_contrastive,
                            "seed": seed,
                            "agent_id": agent_id,
                        })
    return save_dirs
def analyze_save_dirs(save_dirs, pattern=None, num_init_tasks=4, name="result"):
    """Compute metrics for each save directory and write them to a Record.

    Args:
        save_dirs: iterable of dicts such as those returned by
            ``get_save_dirs``. Each must have a ``"path"`` key; every other
            key is copied into the output row as a metadata column.
        pattern: regex searched (``re.search`` semantics) against each path.
            Paths that do not match are printed as skipped. ``None`` matches
            everything.
        num_init_tasks: forwarded to ``Metric``.
        name: file stem; the record is created as ``f"{name}.csv"``.

    Returns:
        The ``Record`` after all rows have been written and ``save()`` has
        been called.

    The input dicts are not modified, so the same ``save_dirs`` list can be
    analyzed several times (e.g. with different patterns).
    """
    # Compile once instead of on every iteration.
    regex = re.compile(r".*" if pattern is None else pattern)
    record = Record(f"{name}.csv")
    for save_dir in save_dirs:
        path = save_dir["path"]
        if not regex.search(path):
            print('SKIPPING', path)
            continue
        # Everything except the path becomes a metadata column (same key
        # order as the input dict). Built as a copy so the caller's dict is
        # left intact.
        metadata = {key: value for key, value in save_dir.items() if key != "path"}
        m = Metric(path, num_init_tasks)
        record.write(
            {
                "final_acc": m.compute_final_accuracy(),
                "auc": m.compute_auc(mode='avg'),
            } | metadata
        )
    record.save()
    return record
def analyze(result_dir, pattern=None, num_init_tasks=4):
    """Collect and score every agent directory under ``result_dir``.

    The results are saved to ``<result_dir>.csv``. For a path without a
    trailing separator, that file sits next to the directory. The path used
    for the file name is normalized first. Without this, an argument like
    ``"results/"`` would write a hidden ``results/.csv`` inside the results
    tree, which a later walk would then treat as a job entry.

    Args:
        result_dir: root directory passed to ``get_save_dirs``.
        pattern: optional path regex forwarded to ``analyze_save_dirs``.
        num_init_tasks: forwarded to ``analyze_save_dirs``.

    Returns:
        The ``Record`` returned by ``analyze_save_dirs``.
    """
    save_dirs = get_save_dirs(result_dir)
    return analyze_save_dirs(
        save_dirs,
        pattern=pattern,
        num_init_tasks=num_init_tasks,
        name=os.path.normpath(result_dir),
    )
if __name__ == "__main__":
    # Previously used experiment location:
    #   root_save_dir = "experiment_results"
    #   vanilla_dir = "vanilla_jorge_setting_basis_no_sparse"
    root_save_dir = "budget_experiment_results/jorge_setting_recv_variable_shared_memory_size"
    vanilla_dir = "mem_size_300_comm_freq_9_num_queries_30"
    result_dir = os.path.join(root_save_dir, vanilla_dir)
    record = analyze(result_dir)

    # Mean of each metric over all rows sharing (algo, dataset, use_contrastive).
    per_group = record.df.groupby(["algo", "dataset", "use_contrastive"])

    print("=====FINAL ACC======")
    # Scaled by 100 for display (presumably to a percentage).
    print(per_group["final_acc"].mean() * 100)

    print("=====AUC======")
    print(per_group["auc"].mean())