-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path config_v_table.py
250 lines (225 loc) · 11.3 KB
/
config_v_table.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
def get_V_table(model_path):
    """Return the hard-coded per-key bound table for the model in *model_path*.

    Both the game and the model variant are detected by plain substring
    search (``key in model_path``). Games are tried in the order Freeway,
    Pong, Boxing, Breakout, highway-fast. Within a game, variant keys are
    tried in the listed order and the first hit wins. The order therefore
    matters when a path contains more than one key.

    Args:
        model_path: String identifying the checkpoint (typically a path).

    Returns:
        A newly built dict mapping a float key to a pair of floats.
        - The keys are presumably perturbation budgets (epsilons); confirm
          with the code that produced these numbers.
        - In every row below, the first float is smaller than the second,
          which is consistent with (lower, upper) bounds. This is also
          unconfirmed.

    Raises:
        NotImplementedError: if no game key occurs in *model_path*, or the
            game is recognised but none of its variant keys occurs.
    """
    # Key ladders shared by several tables. The returned dict keeps this
    # insertion order.
    coarse = (0.001, 0.005, 0.01, 0.03, 0.05, 0.1)
    wide = coarse + (0.5, 0.75, 1.0, 2.0, 4.0)
    freeway_keys = (0.005, 0.05, 0.1, 0.5, 0.75, 1.0)

    def by_key(keys, *rows):
        # Pair the i-th key with the i-th row, keeping key order.
        return dict(zip(keys, rows))

    # Entries are (game substring, label used in the error message,
    # ordered (variant substring, table) pairs).
    # NOTE(review): variant order differs between games. Boxing tests 'nat'
    # first, but highway-fast tests it after 'adv'/'aug'/'cov'. A path that
    # contains several variant keys therefore resolves differently per game.
    catalogue = (
        ('Freeway', 'Freeway', (
            # NOTE(review): Freeway tables are constant across all keys;
            # confirm this is intended.
            ('noisynet', dict.fromkeys(
                freeway_keys, (1.0452193021774292, 1.0613726377487183))),
            ('graddqn', dict.fromkeys(
                freeway_keys, (0.6000280380249023, 0.6044848561286926))),
        )),
        ('Pong', 'Pong', (
            ('noisynet', by_key(
                coarse,
                (-0.6386245489120483, 1.9888107776641846),
                (-0.6386322379112244, 1.9888076782226562),
                (-0.6386420130729675, 1.9888032674789429),
                (-0.6638827323913574, 1.9887866973876953),
                (-0.6638806462287903, 1.9887686967849731),
                (-0.6637315154075623, 1.9886873960494995),
            )),
            ('graddqn', by_key(
                coarse,
                (-0.5181906819343567, 1.9751276969909668),
                (-0.5183050632476807, 1.9750744104385376),
                (-0.5183289051055908, 1.9750686883926392),
                (-0.5183154344558716, 1.9750584363937378),
                (-0.5183108448982239, 1.9750334024429321),
                (-0.5182011127471924, 1.9749038219451904),
            )),
        )),
        ('Boxing', 'Boxing', (
            ('nat', by_key(
                coarse,
                (12.652398109436035, 30.271860122680664),
                (12.652400970458984, 30.271867752075195),
                (12.651599884033203, 30.27104949951172),
                (12.651833534240723, 30.368236541748047),
                (12.652006149291992, 30.368276596069336),
                (12.652134895324707, 30.368314743041992),
            )),
            ('noisynet', by_key(
                coarse,
                (7.715142726898193, 22.17021942138672),
                (7.715142726898193, 22.17022132873535),
                (7.56257963180542, 21.91171646118164),
                (7.562463283538818, 21.911514282226562),
                (7.562436103820801, 21.91152572631836),
                (7.562304496765137, 21.911609649658203),
            )),
            ('adv', by_key(
                coarse,
                (7.171239852905273, 23.252723693847656),
                (6.63673734664917, 23.034860610961914),
                (8.145294189453125, 22.700572967529297),
                (8.565078735351562, 21.688188552856445),
                (9.53758430480957, 20.1334228515625),
                (3.132399559020996, 17.30183982849121),
            )),
            ('aug', by_key(
                coarse,
                (-1.0133867263793945, 26.743619918823242),
                (-2.126743793487549, 26.73563575744629),
                (0.7752446532249451, 26.628562927246094),
                (1.4728885889053345, 25.073265075683594),
                (2.6650748252868652, 22.28669548034668),
                (5.46824836730957, 17.07016944885254),
            )),
            ('cov', by_key(
                coarse,
                (-1.2643269300460815, 25.42894744873047),
                (-1.2028146982192993, 25.4093074798584),
                (-1.0432186126708984, 25.311626434326172),
                (1.247150182723999, 24.118837356567383),
                (1.92692232131958, 22.684253692626953),
                (2.8393449783325195, 17.277660369873047),
            )),
            ('pgd', by_key(
                coarse,
                (0.3573600649833679, 26.907806396484375),
                (0.4471278786659241, 26.96663475036621),
                (0.503233790397644, 26.83462142944336),
                (2.7629988193511963, 25.80197525024414),
                (4.403608798980713, 23.638290405273438),
                (4.832557201385498, 17.85064125061035),
            )),
        )),
        ('Breakout', 'Breakout', (
            ('adv', by_key(
                coarse,
                (3.8931267261505127, 7.003189563751221),
                (0.4728318154811859, 6.747408390045166),
                (4.160265922546387, 6.580664157867432),
                (5.200388431549072, 5.783012866973877),
                (9.369736671447754, 10.651130676269531),
                (18.00510597229004, 20.888399124145508),
            )),
            ('aug', by_key(
                coarse,
                (0.47669175267219543, 6.696699619293213),
                (2.3722803592681885, 6.7005228996276855),
                (2.5842974185943604, 6.246527671813965),
                (6.578291416168213, 7.349504470825195),
                (12.224047660827637, 12.728035926818848),
                (12.373668670654297, 13.091817855834961),
            )),
            ('cov', by_key(
                coarse,
                (0.020712073892354965, 3.6237778663635254),
                (0.028907131403684616, 3.6185104846954346),
                (0.809413731098175, 3.572075843811035),
                (2.60768461227417, 3.5341570377349854),
                (5.902832508087158, 7.44727087020874),
                (7.929012298583984, 9.810087203979492),
            )),
            ('pgd', by_key(
                coarse,
                (0.07901041209697723, 9.283035278320312),
                (4.477279186248779, 8.087729454040527),
                (4.583154201507568, 7.359407424926758),
                (9.355469703674316, 10.174246788024902),
                (11.46679973602295, 12.618890762329102),
                (15.9259033203125, 16.664899826049805),
            )),
        )),
        ('highway-fast', 'highway-fast-v0', (
            ('adv', by_key(
                wide,
                (-3.4863009452819824, 5.3150715827941895),
                (-3.4234046936035156, 5.315459251403809),
                (-3.203420400619507, 5.315192222595215),
                (-1.8876430988311768, 5.169591903686523),
                (1.780735731124878, 4.960638523101807),
                (2.3625402450561523, 4.808628082275391),
                (-0.7457254528999329, 4.161691665649414),
                (-2.062584161758423, 4.6026458740234375),
                (-1.5561974048614502, 5.107702255249023),
                (4.423400402069092, 7.6058573722839355),
                (-16.66010284423828, 11.822466850280762),
            )),
            ('aug', by_key(
                wide,
                (1.4310184717178345, 5.15714693069458),
                (1.4330947399139404, 5.157279968261719),
                (1.4301035404205322, 5.156999588012695),
                (1.4469431638717651, 5.156375408172607),
                (1.4998865127563477, 5.1572442054748535),
                (1.664524793624878, 5.142214298248291),
                (1.6625908613204956, 4.226973056793213),
                (1.4756033420562744, 4.000681400299072),
                (1.4228792190551758, 3.8414368629455566),
                (1.3527854681015015, 3.7976841926574707),
                (1.4060451984405518, 4.4868083000183105),
            )),
            ('cov', by_key(
                wide,
                (-1.2227998971939087, 5.041808128356934),
                (-1.2165334224700928, 5.043108940124512),
                (-1.1906334161758423, 5.043541431427002),
                (-1.3780547380447388, 5.007018566131592),
                (-0.3419512212276459, 4.930624485015869),
                (1.651313304901123, 4.730605125427246),
                (2.033402442932129, 3.5774776935577393),
                (1.7956444025039673, 3.427471876144409),
                (1.7279866933822632, 3.4084274768829346),
                (1.8115736246109009, 3.8032267093658447),
                (2.2257373332977295, 5.115963935852051),
            )),
            ('nat', by_key(
                wide,
                (0.8100096583366394, 5.687267780303955),
                (0.8065616488456726, 5.6893792152404785),
                (0.8075225949287415, 5.690341949462891),
                (0.8846161365509033, 5.687906742095947),
                (1.0077862739562988, 5.678463459014893),
                (1.383544921875, 5.625143527984619),
                (3.714292049407959, 6.164247035980225),
                (5.006048202514648, 7.397580146789551),
                (5.754027843475342, 8.927143096923828),
                (10.911373138427734, 15.006096839904785),
                (20.779056549072266, 27.66459846496582),
            )),
            ('pgd', by_key(
                wide,
                (-0.851825475692749, 5.20785665512085),
                (-0.7977882623672485, 5.203657627105713),
                (-0.6973153352737427, 5.170176982879639),
                (-0.20668944716453552, 5.027972221374512),
                (0.20738109946250916, 4.939896583557129),
                (0.843664824962616, 4.747650623321533),
                (0.6983765959739685, 3.9741010665893555),
                (0.2753481864929199, 4.257822513580322),
                (-0.16384287178516388, 4.687958717346191),
                (-1.8197461366653442, 6.749190330505371),
                (-5.130617141723633, 11.154892921447754),
            )),
            ('graddqn', by_key(
                wide,
                (-0.05645060911774635, 5.392153263092041),
                (-0.04270013049244881, 5.39345121383667),
                (-0.017048824578523636, 5.3908586502075195),
                (0.12727545201778412, 5.315982341766357),
                (0.8526111245155334, 5.1892571449279785),
                (1.86343514919281, 4.712134838104248),
                (0.1932944357395172, 3.5127270221710205),
                (-1.093806505203247, 3.9216549396514893),
                (-1.8417786359786987, 4.39701509475708),
                (-5.960436820983887, 6.3666887283325195),
                (-15.604986190795898, 10.33242416381836),
            )),
            ('noisynet', by_key(
                wide,
                (-0.8183189630508423, 5.455911159515381),
                (-0.8158490657806396, 5.454329013824463),
                (-0.7469663619995117, 5.444551467895508),
                (0.48924943804740906, 5.324298858642578),
                (0.8704175353050232, 5.201502323150635),
                (0.9230291843414307, 4.761870384216309),
                (2.863206148147583, 6.975791931152344),
                (3.914341688156128, 9.476221084594727),
                (4.8058929443359375, 11.978668212890625),
                (7.97699499130249, 21.93963050842285),
                (14.582740783691406, 41.81871795654297),
            )),
            ('radialrl', by_key(
                wide,
                (0.40167906880378723, 5.535007476806641),
                (0.4204297661781311, 5.520941734313965),
                (0.5291378498077393, 5.497609615325928),
                (1.4488271474838257, 5.3698577880859375),
                (1.2086281776428223, 5.230834007263184),
                (1.633813500404358, 4.9519805908203125),
                (2.039306402206421, 5.960858345031738),
                (1.2049446105957031, 7.674386978149414),
                (-0.2793647050857544, 9.459101676940918),
                (-3.4751057624816895, 16.651653289794922),
                (-9.37889575958252, 31.04629898071289),
            )),
            ('CARRL', by_key(
                wide,
                (-1.3827874660491943, 3.772693634033203),
                (-1.3903032541275024, 3.7710087299346924),
                (-1.3847306966781616, 3.7683959007263184),
                (-0.9196563959121704, 3.7495474815368652),
                (-1.212925672531128, 3.714653968811035),
                (-0.5634633898735046, 3.751926898956299),
                (2.554965019226074, 4.816079139709473),
                (3.3965229988098145, 6.093468189239502),
                # NOTE(review): this 1.0 row is byte-identical to the 'adv'
                # table's 2.0 row above. It may be a copy-paste; verify
                # against the original experiment.
                (4.423400402069092, 7.6058573722839355),
                (8.880249977111816, 13.824429512023926),
                (19.479106903076172, 26.363550186157227),
            )),
        )),
    )

    for game_key, label, variants in catalogue:
        if game_key not in model_path:
            continue
        for variant_key, table in variants:
            if variant_key in model_path:
                return table
        # The game matched but none of its variants did.
        raise NotImplementedError(f'{label}: model_path = {model_path} not implemented!')
    raise NotImplementedError(f'unknown game: model_path = {model_path} not implemented!')