forked from Yushi-Hu/IC-DST
run_zeroshot_codex_experiment.py · executable file · 237 lines (197 loc) · 8.02 KB
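# Zero-shot IC-DST experiment: prompt Codex with a single hand-written
# demonstration (no retrieved examples) to predict turn-level dialogue state
# changes on the MultiWOZ 2.1/2.4 test set, accumulate the predicted states
# across turns, and report joint goal accuracy, slot accuracy, and slot F1.
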
import os
import json
import argparse
import copy
from collections import defaultdict
from tqdm import tqdm
from utils.helper import SpeedLimitTimer, PreviousStateRecorder
from utils.typo_fix import typo_fix
from config import CONFIG
from codex_completion import codex_completion
from utils.sql import sql_pred_parse, sv_dict_to_string
from prompting import get_prompt, conversion, table_prompt
from retriever.code.embed_based_retriever import EmbeddingRetriever
from evaluate_metrics import evaluate
# input arguments
parser = argparse.ArgumentParser()
parser.add_argument('--output_dir', type=str, default="./expts/zero-shot",
                    help="directory to save running log and configs")
parser.add_argument('--mwz_ver', type=str, default="2.1",
                    choices=['2.1', '2.4'], help="version of MultiWOZ")
parser.add_argument('--test_fn', type=str, default='',
                    help="file to evaluate on, empty means use the test set")
args = parser.parse_args()

# create the output folder
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "exp_config.json"), 'w') as f:
    json.dump(vars(args), f, indent=4)

NUM_EXAMPLE = 10
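# Note: NUM_EXAMPLE is not referenced elsewhere in this zero-shot script;
# it is presumably kept for parity with the retrieval-based few-shot variant.
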
# read the ontology and the test set
if args.mwz_ver == '2.1':
    ontology_path = CONFIG["ontology_21"]
    if args.test_fn == "":
        test_set_path = "./data/mw21_100p_test.json"
else:
    ontology_path = CONFIG["ontology_24"]
    if args.test_fn == "":
        test_set_path = "./data/mw24_100p_test.json"

# evaluate on some other file
if args.test_fn:
    test_set_path = args.test_fn

with open(ontology_path) as f:
    ontology = json.load(f)
with open(test_set_path) as f:
    test_set = json.load(f)
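
# In the zero-shot setting, every prompt uses this single hand-written
# in-context example instead of examples retrieved from a training set.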
demonstration_examples = [
    {
        "dialog": {
            "sys": [""],
            "usr": [
                "i am looking for a guest house to stay in the west. i do not need internet .",
            ]
        },
        "turn_slot_values": {
            "hotel-type": "guest house",
            "hotel-area": "west",
            "hotel-internet": "no"
        },
        "last_slot_values": {}
    }
]


def run(test_set, turn=-1, use_gold=False):
    # turn and use_gold are for analysis purposes
    # turn = -1 means evaluate all dialogues
    # turn = 0 means evaluate single-turn dialogues
    # turn = 1 means evaluate two-turn dialogues, etc.
    # when use_gold = True, the context is the gold context (for analysis purposes)

    # openai limitation 20 queries/min
    timer = SpeedLimitTimer(second_per_step=3.1)

    result_dict = defaultdict(list)  # used to record the accuracy

    selected_set = test_set
    # if needed, only evaluate on particular turns (for analysis purposes)
    if turn >= 0:
        if not use_gold:
            raise ValueError(
                "can only evaluate a particular turn when using gold context")
        selected_set = [d for d in test_set if len(
            d['dialog']['usr']) == turn + 1]

    prediction_recorder = PreviousStateRecorder()  # state recorder

    # start experiment
    all_result = []
    n_total = 0
    n_correct = 0
    total_acc = 0
    total_f1 = 0

    for data_item in tqdm(selected_set):
        n_total += 1

        completion = ""
        if use_gold:
            prompt_text = get_prompt(
                data_item, examples=demonstration_examples)
        else:
            predicted_context = prediction_recorder.state_retrieval(data_item)
            modified_item = copy.deepcopy(data_item)
            modified_item['last_slot_values'] = predicted_context
            examples = demonstration_examples
            prompt_text = get_prompt(
                data_item, examples=examples, given_context=predicted_context)

        # print the prompt (without the sql table)
        print(prompt_text.replace(conversion(table_prompt), ""))

        # record the prompt
        data_item['prompt'] = prompt_text

        # codex completion
        complete_flag = False
        parse_error_count = 0
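        # Retry until a usable completion is obtained: drop an in-context example
        # on context-length errors, back off on other API errors, and give up
        # after 5 consecutive completions that fail to parse.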
        while not complete_flag:
            try:
                completion = codex_completion(prompt_text)
                # convert back the sql completion result
                completion = conversion(completion, reverse=True)
            except Exception as e:
                if e.user_message.startswith("This model's maximum context length"):
                    print("prompt overlength")
                    examples = examples[1:]
                    prompt_text = get_prompt(
                        data_item, examples=examples, given_context=predicted_context)
                else:
                    # throughput too high
                    timer.sleep(10)
            else:
                try:
                    # check whether the completion can be parsed
                    temp_parse = sql_pred_parse(completion)
                except:
                    parse_error_count += 1
                    if parse_error_count >= 5:
                        complete_flag = True
                else:
                    complete_flag = True
            # limit query speed
            timer.step()

        # aggregate the prediction and the history states
        predicted_slot_values = {}
        try:
            predicted_slot_values = sql_pred_parse(completion)  # a dictionary
        except:
            print("the output is not a valid SQL query")
            data_item['not_valid'] = 1
        predicted_slot_values = typo_fix(
            predicted_slot_values, ontology=ontology, version=args.mwz_ver)
        predicted_slot_values = {
            k: v for k, v in predicted_slot_values.items() if k in ontology}

        context_slot_values = data_item['last_slot_values']  # a dictionary

        # merge context and prediction
        if use_gold:
            all_slot_values = context_slot_values.copy()
        else:
            all_slot_values = prediction_recorder.state_retrieval(
                data_item).copy()
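
        # A predicted value of "[DELETE]" removes the slot from the accumulated
        # state; any other predicted value adds or overwrites that slot.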
        for s, v in predicted_slot_values.items():
            if s in all_slot_values and v == "[DELETE]":
                del all_slot_values[s]
            elif v != "[DELETE]":
                all_slot_values[s] = v

        # some slots may contain multiple values; keep only the first one
        all_slot_values = {k: v.split('|')[0]
                           for k, v in all_slot_values.items()}

        # record current turn prediction
        prediction_recorder.add_state(data_item, all_slot_values)

        # record the predictions
        data_item['pred'] = all_slot_values
        data_item['ontology_path'] = ontology_path
        data_item['completion'] = completion
        all_result.append(data_item)

        # print the result
        print(completion)
        print(
            f"this is the {n_total - 1}th example. {data_item['ID']}_turn_{data_item['turn_id']}")
        print(
            f"pred turn change: {sv_dict_to_string(predicted_slot_values, sep='-')}")
        print(
            f"gold turn change: {sv_dict_to_string(data_item['turn_slot_values'], sep='-')}")
        print(f"pred states: {sv_dict_to_string(all_slot_values, sep='-')}")
        print(
            f"gold states: {sv_dict_to_string(data_item['slot_values'], sep='-')}")
        this_jga, this_acc, this_f1 = evaluate(
            all_slot_values, data_item['slot_values'])
        total_acc += this_acc
        total_f1 += this_f1

        if this_jga:
            n_correct += 1
            result_dict[data_item['turn_id']].append(1)
            print("\n=====================correct!=======================")
        else:
            result_dict[data_item['turn_id']].append(0)
            print("\n=====================wrong!=======================")

        print("\n")
        print(f"correct {n_correct}/{n_total} = {n_correct / n_total}")
        print(f"Slot Acc {total_acc/n_total}")
        print(f"Joint F1 {total_f1/n_total}")
        print()

    # calculate the accuracy of each turn
    for k, v in result_dict.items():
        print(f"accuracy of turn {k} is {sum(v)}/{len(v)} = {sum(v) / len(v)}")

    return all_result


if __name__ == "__main__":
    all_results = run(test_set)

    with open(os.path.join(args.output_dir, "running_log.json"), 'w') as f:
        json.dump(all_results, f, indent=4)
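
# Example invocation (assuming OpenAI/Codex access is configured as expected by
# codex_completion and config.CONFIG):
#   python run_zeroshot_codex_experiment.py --mwz_ver 2.4 --output_dir ./expts/zero-shot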