-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_perturbed.py
62 lines (50 loc) · 2.36 KB
/
run_perturbed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from main import calibrate, evaluate_ood, evaluate_corrupted, missing_outputs, missing_corrupted
# Experiment configuration: which in-distribution datasets are calibrated,
# which calibration methods are run, which architectures are used per dataset,
# and which data splits are used. Commented-out entries are kept as toggles.
datasets = [
    'cifar10',
    'cifar100',
    # 'svhn',  # trailing comma kept so re-enabling this entry stays a separate item
    'imagenet',
]
methods = [
    'TemperatureScaling-P',
    'TemperatureScalingMSE-P',
    # 'VectorScaling',
    # 'MatrixScaling',
    # 'MatrixScalingODIR',
    # 'DirichletL2',
    # 'DirichletODIR',
    'EnsembleTemperatureScaling-P',
    'EnsembleTemperatureScalingCE-P',
    'IRM-P',
    'IRMTS-P',
    'IROvA-P',
    'IROvATS-P',
]

# CIFAR-10, CIFAR-100 and SVHN use the same architecture list.
_cifar_svhn_archs = ['densenet40_k12', 'resnet20', 'resnet56', 'resnet110', 'wrn16_10', 'wrn28_10', 'wrn40_8']
architectures = {
    # list(...) gives each dataset its own list object, so sorting one in place
    # (see the commented-out sort lines below) does not affect the others.
    'cifar10': list(_cifar_svhn_archs),
    'cifar100': list(_cifar_svhn_archs),
    'svhn': list(_cifar_svhn_archs),
    'imagenet': ['resnet50', 'vgg19', 'resnext101_32x8d', 'densenet161', 'wide_resnet101_2'],
}
# architectures['cifar10'].sort(reverse=True)
# architectures['cifar100'].sort(reverse=True)
splitIDs = [0, 1, 2, 3, 4]
# Run calibration for the configured datasets / architectures / methods /
# splits (the per-combination behaviour lives in main.calibrate).
calibrate(datasets, architectures, methods, splitIDs)

# Distribution-shift evaluation. Each pair is
# (in-distribution dataset, shifted test sets evaluated for it).
for _source, ood_datasets in (
    ('cifar10', ['stl10', 'cifar10.1-v4', 'cifar10.1-v6']),
    ('imagenet', ['imagenet-v2-mf', 'imagenet-v2-thr', 'imagenet-v2-ti',
                  'imagenet-sketch', 'imagenet-a', 'imagenet-r']),
):
    evaluate_ood(_source, ood_datasets, architectures[_source], methods, splitIDs)
# Corruption types evaluated for every dataset below. The names appear to
# follow the CIFAR-10-C / ImageNet-C corruption suite — confirm against the
# data loading in main.
corruptions = sorted(
    [
        'brightness', 'contrast', 'defocus_blur', 'elastic_transform', 'fog', 'frost', 'gaussian_blur',
        'gaussian_noise', 'glass_blur', 'impulse_noise', 'jpeg_compression', 'motion_blur',
        'pixelate', 'saturate', 'shot_noise', 'snow', 'spatter', 'speckle_noise', 'zoom_blur',
    ],
    # Reverse-alphabetical: only the ordering changes, not the set of corruptions.
    reverse=True,
)
# Presumably the corruption severity levels 1-5 — TODO confirm in main.
intensities = [1, 2, 3, 4, 5]
for dataset in ['cifar10', 'cifar100', 'imagenet']:
    evaluate_corrupted(dataset, corruptions, intensities, architectures[dataset], methods, splitIDs)
# Second pass restricted to ImageNet: re-run calibration and the corruption
# evaluation for it alone.
# NOTE(review): ImageNet is already included in the calibration and corruption
# passes earlier in this script, so this repeats that work. It is presumably
# meant to fill in outputs missing after the first pass (main also exports
# missing_outputs / missing_corrupted) — confirm whether calibrate and
# evaluate_corrupted skip existing results; if they do not, this block is
# redundant.
datasets = ['imagenet']
calibrate(datasets, architectures, methods, splitIDs)
for dataset in datasets:
    evaluate_corrupted(dataset, corruptions, intensities, architectures[dataset], methods, splitIDs)