nn_eval.py
"""NN metric for graph generation. An example to compare graphs generated by different models."""
import dgl
import pickle as pk
import networkx as nx
from datasets import StructureDataset, networkx_graphs
from evaluation.gin_evaluator import *
import random
import torch
import numpy as np

seed = 42
N_GIN = 10


def nn_eval(dataset_n, gen_graph_paths, N_gin=1):
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # create individual evaluators for each GNN-based metric (one set per GIN feature extractor)
    evaluators = []
    for _ in range(N_gin):
        gin = load_feature_extractor(device)
        evaluators.append(MMDEvaluation(model=gin, kernel='rbf', sigma='range', multiplier='mean'))
        evaluators.append(prdcEvaluation(model=gin, use_pr=True))
        evaluators.append(prdcEvaluation(model=gin, use_pr=False))

    root_path = 'data'
    dataset = StructureDataset(root_path, dataset_n)
    n_graphs = len(dataset)

    # 50/50 split of the real dataset: one half is scored against the other
    # to give a ground-truth baseline for every metric
    idx = list(range(n_graphs))
    random.seed(seed)
    random.shuffle(idx)
    train50_idx = idx[:int(n_graphs * 0.5)]
    test50_idx = idx[int(n_graphs * 0.5):]
    train50_graphs = networkx_graphs(dataset[train50_idx])
    train50_graphs = [dgl.from_networkx(g).to(device) for g in train50_graphs]
    test50_graphs = networkx_graphs(dataset[test50_idx])
    test50_graphs = [dgl.from_networkx(g).to(device) for g in test50_graphs]

    # NN-based evaluation of the 50/50 split
    metrics = {
        'mmd_rbf': [],
        'f1_pr': [],
        'f1_dc': []
    }
    for evaluator in evaluators:
        res, time = evaluator.evaluate(generated_dataset=train50_graphs, reference_dataset=test50_graphs)
        for key in res:
            if key in metrics:
                metrics[key].append(res[key])

    # output ground-truth (50/50 split) results
    print('50/50 split')
    print('MMD_RBF: mean {:.6f} std {:.6f}'.format(np.mean(metrics['mmd_rbf']), np.std(metrics['mmd_rbf'])))
    print('F1_PR: mean {:.6f} std {:.6f}'.format(np.mean(metrics['f1_pr']), np.std(metrics['f1_pr'])))
    print('F1_DC: mean {:.6f} std {:.6f}'.format(np.mean(metrics['f1_dc']), np.std(metrics['f1_dc'])))
    print('-' * 80)
    # construct reference graphs from the held-out test split (last 20% of the dataset)
    n_train = int(n_graphs * 0.8)
    test_dataset = dataset[n_train:]
    ref_graphs = networkx_graphs(test_dataset)
    # convert graphs from NetworkX to DGL
    ref_graphs = [dgl.from_networkx(g).to(device) for g in ref_graphs]

    # load the generated graphs of each model and score them against the reference set
    for name, gen_graph_path in gen_graph_paths.items():
        with open(gen_graph_path, 'rb') as f:
            graph_adj = pk.load(f)
        if isinstance(graph_adj[0], np.ndarray):
            # pickled adjacency matrices (use nx.from_numpy_array on NetworkX >= 3.0)
            gen_graphs = [nx.from_numpy_matrix(adj_i) for adj_i in graph_adj]
        else:
            # already pickled as NetworkX graphs
            gen_graphs = graph_adj
        # convert graphs from NetworkX to DGL
        gen_graphs = [dgl.from_networkx(g).to(device) for g in gen_graphs]

        # NN-based evaluation against the reference graphs
        metrics = {
            'mmd_rbf': [],
            'f1_pr': [],
            'f1_dc': []
        }
        for evaluator in evaluators:
            res, time = evaluator.evaluate(generated_dataset=gen_graphs, reference_dataset=ref_graphs)
            for key in res:
                if key in metrics:
                    metrics[key].append(res[key])

        # output results for this model
        print(name)
        print('MMD_RBF: mean {:.6f} std {:.6f}'.format(np.mean(metrics['mmd_rbf']), np.std(metrics['mmd_rbf'])))
        print('F1_PR: mean {:.6f} std {:.6f}'.format(np.mean(metrics['f1_pr']), np.std(metrics['f1_pr'])))
        print('F1_DC: mean {:.6f} std {:.6f}'.format(np.mean(metrics['f1_dc']), np.std(metrics['f1_dc'])))
        print('-' * 80)
    return


def eval_coms():
    dataset_n = 'Community_small'
    gen_graph_paths = dict()
    print('dataset:', dataset_n)
    # add your file paths
    gen_graph_paths['ER'] = ''
    gen_graph_paths['VGAE'] = ''
    gen_graph_paths['GraphRNN'] = ''
    gen_graph_paths['GRAN'] = ''
    gen_graph_paths['EDPGNN'] = ''
    gen_graph_paths['BIGG'] = ''
    gen_graph_paths['GraphGDP'] = ''
    nn_eval(dataset_n, gen_graph_paths, N_GIN)
    return
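

# ----------------------------------------------------------------------------
# Illustrative sketch, not called by the pipeline above: nn_eval() expects each
# entry of `gen_graph_paths` to point at a pickle holding either a list of
# NetworkX graphs or a list of dense numpy adjacency matrices. A compatible
# file could be written as below; the helper name `save_generated_graphs` and
# its arguments are hypothetical, not part of the original script.
def save_generated_graphs(graphs, out_path):
    """Pickle a list of NetworkX graphs as dense numpy adjacency matrices."""
    adjs = [nx.to_numpy_array(g) for g in graphs]
    with open(out_path, 'wb') as f:
        pk.dump(adjs, f)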


if __name__ == '__main__':
    # fix the random seeds for reproducibility
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Community_small dataset
    eval_coms()