-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathget_final_performance_numbers.py
55 lines (45 loc) · 1.85 KB
/
get_final_performance_numbers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
'''
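This script reads a training results log, finds the epoch with the best validation performance, and reports the validation and test metrics that were logged for that epoch. The best epoch is chosen by the first metric: Mean Reciprocal Rank when the file name contains "interaction", AUC otherwise.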
How to run:
$ python get_final_performance_numbers.py results/interaction_prediction_reddit.txt
Paper: Predicting Dynamic Embedding Trajectory in Temporal Interaction Networks. S. Kumar, X. Zhang, J. Leskovec. ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD), 2019.
'''
import sys
import numpy as np
def parse_performance_log(lines):
    """Collect per-epoch validation and test metrics from a results log.

    An epoch starts at a line containing "Validation performance of epoch N".
    Each following line containing "Validation:" or "Test:" contributes the
    float after its last ": " to that epoch's validation or test row.
    Metric lines that appear before the first epoch header are ignored.

    Args:
        lines: Iterable of log lines (e.g. an open file object).

    Returns:
        (validation_performances, test_performances): numpy arrays with one
        row per epoch, in log order. Column 0 holds the epoch number N from
        the header and the remaining columns hold the metric values in the
        order they were logged. Both arrays are empty when no epoch header is
        found. Every epoch is expected to log the same number of metrics.
    """
    validation_rows = []
    test_rows = []
    val = []
    test = []
    for line in lines:
        if "Validation performance of epoch" in line:
            if val:
                validation_rows.append(val)
                test_rows.append(test)
            epoch = int(line.strip().split("epoch ")[1].split()[0])
            val = [epoch]
            test = [epoch]
        if not val:
            # No epoch header seen yet: such a metric has no epoch to belong
            # to, and recording it would produce a malformed (ragged) row.
            continue
        if "Validation:" in line:
            val.append(float(line.strip().split(": ")[-1]))
        if "Test:" in line:
            test.append(float(line.strip().split(": ")[-1]))
    if val:
        validation_rows.append(val)
        test_rows.append(test)
    return np.array(validation_rows), np.array(test_rows)


def best_validation_epoch(validation_performances):
    """Return (row, epoch) for the epoch with the highest first validation metric.

    `row` indexes into the performance arrays. `epoch` is the epoch number the
    log reported for that row. The two differ whenever the log does not start
    at epoch 0 or skips epochs, so `epoch` is what should be shown to users.
    Ties resolve to the earliest logged epoch (numpy.argmax semantics).
    """
    row = int(np.argmax(validation_performances[:, 1]))
    return row, int(validation_performances[row, 0])


def main(argv=None):
    """Print best-validation-epoch metrics for the results file named in argv[1].

    File names containing "interaction" report Mean Reciprocal Rank and
    Recall@10; any other name reports AUC.

    Exits with a message (via sys.exit) when no file is given or the file
    contains no epoch entries.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        sys.exit("Usage: python get_final_performance_numbers.py <results_file>")
    fname = argv[1]

    # `with` guarantees the log file is closed even if parsing fails.
    with open(fname, "r") as f:
        validation_performances, test_performances = parse_performance_log(f)
    if len(validation_performances) == 0:
        sys.exit("No 'Validation performance of epoch' entries found in %s" % fname)

    if "interaction" in fname:
        metrics = ['Mean Reciprocal Rank', 'Recall@10']
    else:
        metrics = ['AUC']

    best_row, best_epoch = best_validation_epoch(validation_performances)

    # Single-argument print(...) behaves identically on Python 2 and 3.
    print('\n\n*** For file: %s ***' % fname)
    print("Best validation epoch: %d" % best_epoch)
    print('\n\n*** Best validation performance (epoch %d) ***' % best_epoch)
    for i, metric in enumerate(metrics):
        print(metric + ': ' + str(validation_performances[best_row][i + 1]))
    print('\n\n*** Final model performance on the test set, i.e., in epoch %d ***' % best_epoch)
    for i, metric in enumerate(metrics):
        print(metric + ': ' + str(test_performances[best_row][i + 1]))


if __name__ == "__main__":
    main()