OpenFactCheck-Prerelease
/
src
/openfactcheck
/solvers
/webservice
/rarr_utils
/question_generation.py
| """Utils for running question generation.""" | |
| import os | |
| import time | |
| from typing import List | |
| import openai | |
| openai.api_key = os.getenv("OPENAI_API_KEY") | |
| def parse_api_response(api_response: str) -> List[str]: | |
| """Extract questions from the GPT-3 API response. | |
| Our prompt returns questions as a string with the format of an ordered list. | |
| This function parses this response in a list of questions. | |
| Args: | |
| api_response: Question generation response from GPT-3. | |
| Returns: | |
| questions: A list of questions. | |
| """ | |
| search_string = "I googled:" | |
| questions = [] | |
| for question in api_response.split("\n"): | |
| # Remove the search string from each question | |
| if search_string not in question: | |
| continue | |
| question = question.split(search_string)[1].strip() | |
| questions.append(question) | |
| return questions | |
| def run_rarr_question_generation( | |
| claim: str, | |
| model: str, | |
| prompt: str, | |
| temperature: float, | |
| num_rounds: int, | |
| context: str = None, | |
| num_retries: int = 5, | |
| ) -> List[str]: | |
| """Generates questions that interrogate the information in a claim. | |
| Given a piece of text (claim), we use GPT-3 to generate questions that question the | |
| information in the claim. We run num_rounds of sampling to get a diverse set of questions. | |
| Args: | |
| claim: Text to generate questions off of. | |
| model: Name of the OpenAI GPT-3 model to use. | |
| prompt: The prompt template to query GPT-3 with. | |
| temperature: Temperature to use for sampling questions. 0 represents greedy deconding. | |
| num_rounds: Number of times to sample questions. | |
| Returns: | |
| questions: A list of questions. | |
| """ | |
| if context: | |
| gpt3_input = prompt.format(context=context, claim=claim).strip() | |
| else: | |
| gpt3_input = prompt.format(claim=claim).strip() | |
| questions = set() | |
| for _ in range(num_rounds): | |
| for _ in range(num_retries): | |
| try: | |
| response = openai.completions.create( | |
| model=model, | |
| prompt=gpt3_input, | |
| temperature=temperature, | |
| max_tokens=256, | |
| ) | |
| cur_round_questions = parse_api_response( | |
| response.choices[0].text.strip() | |
| ) | |
| questions.update(cur_round_questions) | |
| break | |
| except openai.OpenAIError as exception: | |
| print(f"{exception}. Retrying...") | |
| time.sleep(1) | |
| questions = list(sorted(questions)) | |
| return questions | |