-
Notifications
You must be signed in to change notification settings - Fork 3
/
choose_by_cate.py
72 lines (56 loc) · 1.7 KB
/
choose_by_cate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import json
# Dataset locations: the JSON inputs below are read from `filepath`.
# NOTE(review): `vectorpath` is not used by the code in this script.
filepath = './data/clothing/Men/'
vectorpath = '../../../../../mnt/dev_sdb1/cuizeyu/amazon_image_feature/clothing/Men/'
# Alternative locations (e.g. for the Taobao data set):
# filepath = './data/'
# vectorpath = '../../../../../mnt/dev_sdb1/cuizeyu/amazon_image_feature/clothing/Men/'

# Item id -> category lookup. dict() accepts either a JSON object or a JSON
# list of [id, category] pairs.
with open(filepath + 'item_category.json', 'r') as f:
    item_category = dict(json.load(f))

# Evaluation triples, iterated below as (query, positive, negative).
with open(filepath + 'valid_taobao.json', 'r') as f:
    test_data = json.load(f)

# Compatible category pairs. JSON decodes each pair as a list, which is not
# hashable, so every pair is converted to a 2-tuple before building the set.
# Unpacking `a, b` fails loudly on a malformed row. The previous double-zip
# trick silently truncated *every* pair to the shortest row's length instead.
with open(filepath + "cate_2match.json", 'r') as f:
    cate_2match = set((a, b) for a, b in json.load(f))
def choose(q, p, n, categories=None, match_pairs=None):
    """Score one (query, positive, negative) triple by category compatibility.

    Two items "match" when their (category, category) pair appears in the
    compatible-pairs set in either order.

    Args:
        q: query item id.
        p: positive item id.
        n: negative item id.
        categories: mapping item id -> category. Defaults to the module-level
            ``item_category`` table.
        match_pairs: set of compatible (category, category) tuples. Defaults
            to the module-level ``cate_2match`` set.

    Returns:
        1.0 if only the positive matches the query, 0.5 if both match,
        0.3 if neither matches (also prints "WAAO~"), and 0.0 if only the
        negative matches.

    Raises:
        KeyError: if an item id is missing from ``categories``.
    """
    if categories is None:
        categories = item_category
    if match_pairs is None:
        match_pairs = cate_2match

    def _cate_match(x, y):
        # Compatibility is symmetric: accept the pair in either order.
        c_x = categories[x]
        c_y = categories[y]
        return (c_x, c_y) in match_pairs or (c_y, c_x) in match_pairs

    # Evaluate each pair once instead of re-querying in every branch.
    pos_match = _cate_match(q, p)
    neg_match = _cate_match(q, n)
    if pos_match and not neg_match:
        return 1.
    if pos_match and neg_match:
        return 0.5
    if not pos_match and not neg_match:
        # NOTE(review): a standard AUC scores every tie as 0.5. This
        # "neither matches" tie is scored lower; confirm that this is intended.
        print("WAAO~")
        return 0.3
    return 0.
def error_rate(q, p, r, categories=None, match_pairs=None):
    """Return 1.0 if the category-match prediction for (q, p) equals ``r``.

    Despite its name, this is a per-pair *correctness* indicator. Its mean
    over pairs is an accuracy, and the error rate is one minus that mean.

    Args:
        q: first item id.
        p: second item id.
        r: expected label; True if (q, p) should be a compatible pair.
        categories: mapping item id -> category. Defaults to the module-level
            ``item_category`` table.
        match_pairs: set of compatible (category, category) tuples. Defaults
            to the module-level ``cate_2match`` set.

    Returns:
        1.0 when the predicted match agrees with ``r``, otherwise 0.0.

    Raises:
        KeyError: if an item id is missing from ``categories``.
    """
    if categories is None:
        categories = item_category
    if match_pairs is None:
        match_pairs = cate_2match
    c_q = categories[q]
    c_p = categories[p]
    # Compatibility is symmetric: accept the pair in either order.
    predicted = (c_q, c_p) in match_pairs or (c_p, c_q) in match_pairs
    return 1. if predicted == r else 0.
# Guard the averages below against an empty evaluation file.
if not test_data:
    raise ValueError("no evaluation triples loaded from valid_taobao.json")

# Ranking score: mean of choose() over all (query, positive, negative) triples.
tt = 0.
ss = 0.
for q, p, n in test_data:
    tt += choose(q, p, n)
    ss += 1.
print("auc %s" % (tt / ss,))

# Pairwise accuracy: the positive pair should be predicted as a match and the
# negative pair as a non-match. The accumulators are reset so the ranking
# totals above do not leak into this metric. The negative check must use `n`;
# checking (q, p) against both labels always scores exactly 0.5.
tt = 0.
ss = 0.
for q, p, n in test_data:
    tt += error_rate(q, p, True)
    tt += error_rate(q, n, False)
    ss += 2.
print("acc %s" % (tt / ss,))
print("error_rate %s" % (1 - tt / ss,))