forked from cjx0525/BGCN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval_main.py
108 lines (88 loc) · 3.64 KB
/
eval_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import setproctitle
import dataset
from model import BGCN, BGCN_Info
from utils import check_overfitting, early_stop, get_perf, logger
from train import simple_train, multi_train
from metric import Recall, NDCG, MRR
from config import CONFIG
from test import test, mlp_test
import loss
import time
import csv
# from utils.visshow import VisShow
# Last path component of the log directory built in main(); empty by default.
TAG = ''
def main():
    """Evaluate saved checkpoints on the bundle test set.

    For every directory in ``DIRS`` this reads ``model.csv``, rebuilds one
    model per CSV row from the hyper-parameters stored in that row, loads
    the matching ``<hash>_Recall@10.pth`` checkpoint, runs ``test`` on the
    bundle test loader and records the metrics through the logger.

    Expected ``model.csv`` layout (as consumed below):
        * a ``hash`` column,
        * `` embed_L2_norm``, `` mess_dropout`` and `` node_dropout`` columns
          (the header names carry a leading space),
        * overflow fields beyond the header, which ``csv.DictReader`` stores
          under the ``None`` key; the first one holds the learning rate as
          the text following its first ``:``.

    Raises:
        ValueError: if ``CONFIG['model']`` is not a model this script knows
            how to build (currently only ``'BGCN'``).
    """
    # set env
    setproctitle.setproctitle(f"test{CONFIG['name']}")
    os.environ["CUDA_VISIBLE_DEVICES"] = CONFIG['gpu_id']
    device = torch.device('cuda')

    # load data
    bundle_train_data, bundle_test_data, item_data, assist_data = \
        dataset.get_dataset(CONFIG['path'], CONFIG['dataset_name'],
                            task=CONFIG['eval_task'])
    test_loader = DataLoader(bundle_test_data, 4096, False,
                             num_workers=16, pin_memory=True)

    # graph
    ub_graph = bundle_train_data.ground_truth_u_b
    ui_graph = item_data.ground_truth_u_i
    bi_graph = assist_data.ground_truth_b_i

    # metric
    metrics = [Recall(20), NDCG(20), Recall(40), NDCG(40), Recall(80), NDCG(80)]
    TARGET = 'Recall@20'

    # log
    log = logger.Logger(os.path.join(
        CONFIG['log'], CONFIG['model'],
        f"{CONFIG['dataset_name']}_{CONFIG['eval_task']}", TAG), 'best', checkpoint_target=TARGET)

    # vis = VisShow('localhost', 16666,
    # f'{CONFIG['dataset_name']}-{MODELTYPE.__name__}-{decay}-{lr}-{theta}-3layer')

    # Placeholder: replace with the directories that hold model.csv and the
    # saved checkpoints.
    DIRS = [
        'your_model_dirs',
    ]

    for DIR in DIRS:
        # newline='' is what the csv module requires of file objects it reads.
        with open(os.path.join(DIR, 'model.csv'), 'r', newline='') as f:
            rows = list(csv.DictReader(f))

        for row in rows:
            # Fields beyond the header land under the None key (DictReader's
            # default restkey); the first one carries "<label>:<lr>".
            lr_field = row[None][0]
            dd = {'hash': row['hash'],
                  'embed_L2_norm': float(row[' embed_L2_norm']),
                  'mess_dropout': float(row[' mess_dropout']),
                  'node_dropout': float(row[' node_dropout']),
                  'lr': float(lr_field[lr_field.find(':') + 1:])}
            # print(dd)

            # model
            if CONFIG['model'] == 'BGCN':
                graph = [ub_graph, ui_graph, bi_graph]
                # NOTE(review): 64 and 1 are fixed here; presumably the
                # embedding size and layer count -- confirm against BGCN_Info.
                info = BGCN_Info(64, dd['embed_L2_norm'], dd['mess_dropout'], dd['node_dropout'], 1)
                model = BGCN(info, assist_data, graph, device, pretrain=None).to(device)
            else:
                # Without this branch `model`/`info` would be unbound and the
                # failure would surface as an opaque UnboundLocalError.
                raise ValueError(
                    f"unsupported model for evaluation: {CONFIG['model']!r}")

            assert model.__class__.__name__ == CONFIG['model']

            # map_location keeps loading independent of the CUDA index the
            # checkpoint was saved from.
            model.load_state_dict(torch.load(
                os.path.join(DIR, dd['hash']+"_Recall@10.pth"),
                map_location=device))

            # log
            log.update_modelinfo(info, {'lr': dd['lr']}, metrics)

            # temp
            time_path = time.strftime('%m-%d-%H-%M-%S-', time.localtime(time.time()))
            visual_path = os.path.join(
                CONFIG['visual'], CONFIG['model'],
                f"{CONFIG['dataset_name']}_{CONFIG['eval_task']}",
                f"{time_path}-{CONFIG['note']}")
            visual_name = str(dd)
            visual = [visual_path, visual_name]

            # test
            epoch = 1
            test(model, epoch+1, test_loader, device, CONFIG, metrics, visual)

            # log
            log.update_log(metrics, model)

            # # show(log.metrics_log)
            log.close_log(TARGET)

    log.close()
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()