forked from anthonywchen/RARR
-
Notifications
You must be signed in to change notification settings - Fork 0
/
question_generation.py
82 lines (69 loc) · 2.55 KB
/
question_generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Utils for running question generation."""
import os
import time
from typing import List, Optional

import openai
# Configure the OpenAI client at import time from the environment so the key is
# never hard-coded in source. os.getenv returns None when OPENAI_API_KEY is unset.
openai.api_key = os.getenv("OPENAI_API_KEY")
def parse_api_response(api_response: str) -> List[str]:
    """Extract questions from the GPT-3 API response.

    Our prompt returns questions as a string with the format of an ordered list,
    where each question is introduced by the marker "I googled:". This function
    parses that response into a list of questions.

    Args:
        api_response: Question generation response from GPT-3.

    Returns:
        questions: The text following the first marker on each line that
            contains it, stripped of surrounding whitespace, in order of
            appearance. Lines without the marker, and lines with nothing after
            the marker, are skipped.
    """
    search_string = "I googled:"
    questions = []
    for line in api_response.split("\n"):
        # Split on the FIRST marker only, so a question that itself contains the
        # marker text is kept whole instead of being truncated at the second one.
        _, marker, question = line.partition(search_string)
        question = question.strip()
        # `marker` is empty when the line has no marker; `question` is empty when
        # nothing follows the marker. Neither yields a usable search query.
        if marker and question:
            questions.append(question)
    return questions
def run_rarr_question_generation(
    claim: str,
    model: str,
    prompt: str,
    temperature: float,
    num_rounds: int,
    context: Optional[str] = None,
    num_retries: int = 5,
) -> List[str]:
    """Generates questions that interrogate the information in a claim.

    Given a piece of text (claim), we use GPT-3 to generate questions that question the
    information in the claim. We run num_rounds of sampling to get a diverse set of
    questions; duplicates across rounds are collapsed.

    Args:
        claim: Text to generate questions off of.
        model: Name of the OpenAI GPT-3 model to use.
        prompt: The prompt template to query GPT-3 with. It is filled in with
            str.format using `claim` and, when a context is given, `context`.
        temperature: Temperature to use for sampling questions. 0 represents greedy
            decoding.
        num_rounds: Number of times to sample questions.
        context: Optional surrounding text for the claim. When empty or None, the
            template is formatted with the claim only.
        num_retries: Maximum number of API attempts per round. A round whose attempts
            all fail contributes no questions.

    Returns:
        questions: A sorted list of the unique questions gathered over all rounds.
    """
    if context:
        gpt3_input = prompt.format(context=context, claim=claim).strip()
    else:
        gpt3_input = prompt.format(claim=claim).strip()

    # A set de-duplicates questions that are sampled in more than one round.
    questions = set()
    for _ in range(num_rounds):
        for _attempt in range(num_retries):
            # Keep the try body down to the API call so only API failures are retried.
            try:
                response = openai.Completion.create(
                    model=model,
                    prompt=gpt3_input,
                    temperature=temperature,
                    max_tokens=256,
                )
            except openai.error.OpenAIError as exception:
                print(f"{exception}. Retrying...")
                time.sleep(1)
                continue
            questions.update(parse_api_response(response.choices[0].text.strip()))
            break
        else:
            # Reached only when no attempt succeeded; make the dropped round visible
            # instead of silently returning fewer questions.
            print(
                f"No successful response after {num_retries} attempt(s); "
                "skipping this round."
            )

    # sorted() already returns a new list, giving callers a deterministic order.
    return sorted(questions)