# Hugging Face Spaces page header captured with this file (Space status: Sleeping).
import html
import json
import os
import re

import gradio as gr
import pandas as pd
# Questions offered in the UI checkbox group; all of them are selected by default.
QUESTIONS = [
    "What is the DOI of this study?",
    "What is the Citation ID of this study?",
    "What is the First author of this study?",
    "What is the year of this study?",
    "What is the animal type of this study?",
    "What is the exposure age of this study?",
    "Is there any behavior test done in this study?",
    "What's the Intervention 1's name of this study?(anesthetics only)",
    "What's the Intervention 2's name of this study?(anesthetics only)",
    "What's the genetic chain of this study?",
]
# Prompt sent to the chat model. ``str.format`` fills the first ``{}`` with the
# document text and the second ``{}`` with the numbered question list; the
# doubled braces (``{{`` / ``}}``) render as literal braces in the JSON skeleton.
template = '''We now have a following <document> in the medical field:
"""
{}
"""
We have some introduction here:
1. DOI: The DOI link for the article, usually can be found in the first line of the .txt file for the article. E.g., “DOI: 10.3892/mmr.2019.10397”.
2. Citation ID: The number in the file name. E.g., “1134”.
3. First author: The last name in the file name. E.g., “Guan”.
4. Year: The year in the file name. E.g., “2019”.
5. Animal type: The rodent type used in the article, should be one of the choices: mice, rats. E.g., “rats”.
6. Exposure age: The age when the animals were exposed to anesthetics, should be mentioned as "PND1", "PND7","postnatal day 7", "Gestational day 21", etc, which should be extract as: 'PND XX' , 'Gestational day xx'. E.g., “PND7”.
7. Behavior test: Whether there is any behavior test in the article, should be one of the choices: "Y", "N". "Y" is chosen if there are any of the behavior tests described and done in the article, which mentioned as: "Open field test", "Morris water task", "fear conditioning test", "Dark/light avoidance"; "passive/active avoidance test"; "elevated maze", "Forced swim test", "Object recognition test", "Social interaction/preference“. E.g., “N”.
8. Intervention 1 & Intervention 2: Intervention 1 and Intervention 2 are both anesthetic drugs, which listed as: "isoflurane", "sevoflurane", "desflurane", "ketamine", "propofol", "Midazolam", "Nitrous oxide“. If none, put “NA”. E.g., “propofol”.
9. Genetic chain: Genetic chain is the genetic type of the animals being used in the article, here is the examples:
"C57BL/6", "C57BL/6J" should be extracted as "C57BL/6"; "Sprague Dawley", "Sprague-Dawley", "SD" should be extracted as "Sprague Dawley"; "CD-1" should be extracted as "CD-1"; "Wistar/ST" should be extracted as "Wistar/ST"; "Wistar" should be extracted as "Wistar"; "FMR-1 KO" should be extracted as "FMR-1 KO“. E.g., “Sprague Dawley”.
We have some <question>s begin with "Question" here:
"""
{}
"""
Please finish the following task:
1. Please select the <original sentences> related the each <question> from the <document>.
2. Please use the <original sentences> to answer the <question>.
3. Please provide <original sentences> coming from the <document>.
4. Output the <answer> in the following json format:
{{
"Question 1": {{
"question": {{}},
"answer": {{}},
"original sentences": []
}},
"Question 2": {{
"question": {{}},
"answer": {{}},
"original sentences": []
}},
...
}}
'''
import requests | |
class OpenAI:
    """Minimal client for the OpenAI chat-completions HTTP endpoint.

    The instance keeps an optional running conversation in ``self.history``
    (a list of ``{'role': ..., 'content': ...}`` dicts). The history is only
    read and extended by calls made with ``with_history=True``.
    """

    API_URL = 'https://api.openai.com/v1/chat/completions'

    def __init__(self, init_prompt=None):
        """Start with an empty history, or with *init_prompt* as system message."""
        self.history = []
        if init_prompt is not None:
            self.history.append({'role': 'system', 'content': init_prompt})

    def clear_history(self):
        """Drop every stored message (including an initial system prompt)."""
        self.history = []

    def show_history(self):
        """Print each stored message as ``role: content``."""
        for message in self.history:
            print(f"{message['role']}: {message['content']}")

    def get_raw_history(self):
        """Return the stored message list itself (not a copy)."""
        return self.history

    def __call__(self, prompt, with_history=False, model='gpt-3.5-turbo',
                 temperature=0, api_key=None, timeout=120):
        """Send *prompt* to the chat endpoint and return the reply text.

        Args:
            prompt: Content of the user message.
            with_history: When true, the stored history is sent before the
                new message and both the message and the reply are stored
                afterwards. When false, only *prompt* is sent and the
                history is left untouched.
            model: Model name placed in the request body.
            temperature: Sampling temperature placed in the request body.
            api_key: Bearer token for the ``Authorization`` header.
            timeout: Seconds to wait for the HTTP response.

        Returns:
            The ``content`` string of the first returned choice.

        Raises:
            RuntimeError: If the response is not JSON, reports an error, or
                contains no choices.
        """
        new_message = {'role': 'user', 'content': prompt}
        if with_history:
            messages = self.history + [new_message]
        else:
            messages = [new_message]

        resp = requests.post(
            self.API_URL,
            json={
                'model': model,
                'messages': messages,
                'temperature': temperature,
            },
            headers={
                'Authorization': f"Bearer {api_key}"
            },
            timeout=timeout,
        )

        try:
            payload = resp.json()
        except ValueError as err:
            raise RuntimeError(
                f"OpenAI API returned a non-JSON response (HTTP {resp.status_code})"
            ) from err

        if not isinstance(payload, dict) or not payload.get('choices'):
            # Error responses carry an 'error' object instead of 'choices';
            # surface its message rather than failing with a bare KeyError.
            detail = None
            if isinstance(payload, dict) and isinstance(payload.get('error'), dict):
                detail = payload['error'].get('message')
            raise RuntimeError(
                f"OpenAI API request failed (HTTP {resp.status_code}): "
                f"{detail or 'no choices in response'}"
            )

        reply = payload['choices'][0]['message']
        if with_history:
            # Store the exchange only after success so a failed request does
            # not leave an unanswered user message in the history.
            self.history.append(new_message)
            self.history.append(reply)
        return reply['content']
class Backend:
    """State and event handlers behind the review UI.

    An instance remembers the uploaded document text, the parsed model
    output, the index of the question currently shown, and two data frames:
    the untouched model answers and a copy that receives the reviewer's
    corrections.
    """

    # User-facing strings shared by several handlers.
    _NO_MORE = "No more questions!"
    _JUDGE_HINT = 'Please judge on the generated answer'
    _SAVE_HINT = 'Still need to click the button above to save the results'
    _NEED_DOCUMENT = "You need to submit the document first"

    def __init__(self):
        self.agent = OpenAI()

    def read_file(self, file):
        """Return the text of the uploaded file (read from ``file.name``)."""
        with open(file.name, 'r', encoding='utf-8') as f:
            return f.read()

    @staticmethod
    def _as_sentence_list(value):
        """Normalize the model's 'original sentences' value to a list of str."""
        if value is None:
            return []
        if isinstance(value, str):
            # A bare string would otherwise be iterated character by character.
            return [value]
        return [str(item) for item in value]

    def highlight_text(self, text, highlight_list):
        """Render *text* as scrollable HTML with each highlight wrapped in <mark>.

        The text and the highlights are HTML-escaped first so document content
        cannot inject markup. Empty highlights are ignored, and all highlights
        are applied in a single pass so one highlight can never match inside
        the ``<mark>`` markup inserted for another.
        """
        rendered = html.escape(text, quote=False)
        needles = sorted(
            {
                html.escape(hl, quote=False)
                for hl in highlight_list
                if isinstance(hl, str) and hl
            },
            key=len,
            reverse=True,  # prefer the longest match when highlights overlap
        )
        if needles:
            pattern = re.compile('|'.join(re.escape(needle) for needle in needles))
            rendered = pattern.sub(
                lambda m: f'<mark style="background: #5FACF0">{m.group(0)}</mark>',
                rendered,
            )
        # add line break
        rendered = rendered.replace('\n', " <br /> ")
        # add scroll bar
        return f'<div style="height: 500px; overflow: auto;">{rendered}</div>'

    def _clear_judgement(self):
        """Forget the reviewer's correct/incorrect clicks for the shown question."""
        for flag in ('clicked_correct_answer', 'clicked_correct_reference'):
            if hasattr(self, flag):
                delattr(self, flag)

    def _load_current(self):
        """Load the current question into instance state.

        Returns:
            ``(question, answer, highlighted_html)`` for the current index.
        """
        res = self.gpt_result[f'Question {self.curret_question + 1}']
        self.answer = res['answer']
        sentences = self._as_sentence_list(res['original sentences'])
        rendered = self.highlight_text(self.text, sentences)
        self.highlighted_out = '\n'.join(sentences)
        return res['question'], self.answer, rendered

    def process_file(self, file, question, openai_key):
        """Ask the model the selected questions about the uploaded file.

        Returns:
            ``(question, answer, highlighted_html, answer, reference_text)``
            for the first question.

        Raises:
            gr.Error: If no file or question is given, the request fails, or
                the reply is not the expected JSON object.
        """
        if file is None:
            raise gr.Error("You need to upload a .txt file first")

        # Filter first, then number, so the labels are always consecutive and
        # match the 'Question N' keys used for navigation.
        selected = [q for q in (question or []) if 'Input question' not in q]
        if not selected:
            raise gr.Error("You need to select at least one question")
        question_block = '\n'.join(
            f'Question {number}: {q}' for number, q in enumerate(selected, start=1)
        )

        self.text = self.read_file(file)
        prompt = template.format(self.text, question_block)

        try:
            raw = self.agent(prompt, with_history=False, temperature=0.1,
                             model='gpt-3.5-turbo-16k', api_key=openai_key)
        except Exception as err:  # UI boundary: report instead of crashing the handler
            raise gr.Error(f"OpenAI request failed: {err}") from err

        try:
            res = json.loads(raw)
        except (TypeError, ValueError) as err:
            raise gr.Error("The model did not return valid JSON, please submit again") from err

        expected_keys = [f'Question {i}' for i in range(1, len(res) + 1)] if isinstance(res, dict) else []
        if not expected_keys or any(
            key not in res or not isinstance(res[key], dict) for key in expected_keys
        ):
            raise gr.Error("The model returned an unexpected format, please submit again")

        # for multiple questions
        self.gpt_result = res
        self.curret_question = 0
        self.totel_question = len(res)
        # judgements made on a previously submitted document no longer apply
        self._clear_judgement()

        # make a dataframe to record everything
        self.ori_answer_df = pd.DataFrame(res).T
        self.answer_df = pd.DataFrame(res).T

        # default first question
        shown_question, answer, rendered = self._load_current()
        return shown_question, answer, rendered, answer, self.highlighted_out

    def process_results(self, answer_correct, correct_answer, reference_correct, correct_reference):
        """Record the reviewer's verdict and corrections for the current question.

        Raises:
            gr.Error: If a judgement is missing, no document was submitted,
                or the current index is outside the question range.
        """
        if not hasattr(self, 'clicked_correct_answer'):
            raise gr.Error("You need to judge whether the generated answer is correct first")
        if not hasattr(self, 'clicked_correct_reference'):
            raise gr.Error("You need to judge whether the highlighted reference is correct first")
        if not hasattr(self, 'answer_df'):
            raise gr.Error(self._NEED_DOCUMENT)
        if not 0 <= self.curret_question < self.totel_question:
            raise gr.Error("No more questions, please return back")
        # Validate before writing so a failure never leaves a half-filled row.
        if self.clicked_correct_answer and not hasattr(self, 'answer'):
            raise gr.Error(self._NEED_DOCUMENT)
        if self.clicked_correct_reference and not hasattr(self, 'highlighted_out'):
            raise gr.Error(self._NEED_DOCUMENT)

        row = f'Question {self.curret_question + 1}'
        # record the answer
        self.answer_df.loc[row, 'answer_correct'] = answer_correct
        self.answer_df.loc[row, 'reference_correct'] = reference_correct
        # When judged correct, keep the generated value; otherwise store the
        # reviewer's replacement text.
        self.answer_df.loc[row, 'correct_answer'] = (
            self.answer if self.clicked_correct_answer else correct_answer
        )
        self.answer_df.loc[row, 'correct_reference'] = (
            self.highlighted_out if self.clicked_correct_reference else correct_reference
        )
        gr.Info('Results saved!')
        return "Results saved!"

    def _step(self, delta):
        """Move the current index by *delta* and return the 8 UI output values."""
        if not hasattr(self, 'gpt_result'):
            raise gr.Error(self._NEED_DOCUMENT)
        # Clamp to one position past either end: repeated clicks must not let
        # the index drift, otherwise stepping back would look up a
        # non-existent 'Question N' key.
        self.curret_question = max(
            -1, min(self.totel_question, self.curret_question + delta)
        )
        self._clear_judgement()
        if not 0 <= self.curret_question < self.totel_question:
            return (self._NO_MORE, self._NO_MORE, self._NO_MORE, 'No more questions!',
                    'No more questions!', self._SAVE_HINT, None, None)
        shown_question, answer, rendered = self._load_current()
        return (shown_question, answer, rendered, self._JUDGE_HINT,
                self._JUDGE_HINT, self._SAVE_HINT, None, None)

    def process_next(self):
        """Show the next question (or the 'no more questions' placeholders)."""
        return self._step(1)

    def process_last(self):
        """Show the previous question (or the 'no more questions' placeholders)."""
        return self._step(-1)

    def _export(self, frame_name, path, name):
        """Write the data frame stored under *frame_name* to ``path/name``."""
        if not hasattr(self, frame_name):
            raise gr.Error(self._NEED_DOCUMENT)
        os.makedirs(path, exist_ok=True)
        target = os.path.join(path, name)
        getattr(self, frame_name).to_excel(target, index=False)
        return target

    def download_answer(self, path='./tmp', name='answer.xlsx'):
        """Write the unmodified model answers to an Excel file; return its path."""
        return self._export('ori_answer_df', path, name)

    def download_corrected(self, path='./tmp', name='corrected_answer.xlsx'):
        """Write the reviewer-corrected answers to an Excel file; return its path."""
        return self._export('answer_df', path, name)

    def change_correct_answer(self, correctness):
        """Remember the answer verdict and return the text for the edit box."""
        if correctness == "Correct":
            self.clicked_correct_answer = True
            return "No need to change"
        if hasattr(self, 'answer'):
            self.clicked_correct_answer = False
            # pre-fill the box with the generated answer so it can be edited
            return self.answer
        return "No answer yet, you need to submit the document first"

    def change_correct_reference(self, correctness):
        """Remember the reference verdict and return the text for the edit box."""
        if correctness == "Correct":
            self.clicked_correct_reference = True
            return "No need to change"
        if hasattr(self, 'highlighted_out'):
            self.clicked_correct_reference = False
            # pre-fill the box with the generated reference so it can be edited
            return self.highlighted_out
        return "No answer yet, you need to submit the document first"
# Build the review UI and bind every widget event to a Backend handler.
# NOTE(review): the container nesting below was reconstructed from a source
# whose indentation was lost; confirm the layout against the running app.
with gr.Blocks(theme="dark") as demo:
    # One Backend instance is created here and bound to all handlers.
    # NOTE(review): its state therefore looks shared by every visitor of the
    # app; confirm whether per-session state is needed.
    backend = Backend()
    with gr.Row():
        with gr.Row():
            with gr.Group():
                gr.Markdown('<center><h1>Input</h1></center>')
                gr.Markdown('<center><p>Please First Upload the File</p></center>')
                openai_key = gr.Textbox(
                    label='Enter your OpenAI API key here',
                    type='password')
                file = gr.File(label='Upload your .txt file here', file_types=['.txt'])
                questions = gr.CheckboxGroup(choices=QUESTIONS, value=QUESTIONS, label="Questions", info="Please select the question you want to ask")
                btn_submit_txt = gr.Button(value='Submit txt')
                btn_submit_txt.style(full_width=True)
            with gr.Group():
                gr.Markdown('<center><h1>Output</h1></center>')
                gr.Markdown('<center><p>The answer to your question is :</p></center>')
                question_box = gr.Textbox(label='Question')
                answer_box = gr.Textbox(label='Answer')
                # gr.HTML replaces the deprecated gr.outputs.HTML alias.
                highlighted_text = gr.HTML(label="Highlighted Text")
                with gr.Row():
                    btn_last_question = gr.Button(value='Last Question')
                    btn_next_question = gr.Button(value='Next Question')
            with gr.Group():
                gr.Markdown('<center><h1>Correct the Result</h1></center>')
                gr.Markdown('<center><p>Please Correct the Results</p></center>')
                with gr.Row():
                    save_results = gr.Textbox(placeholder="Still need to click the button above to save the results", label='Save Results')
                with gr.Group():
                    gr.Markdown('<center><p>Please Choose: </p></center>')
                    answer_correct = gr.Radio(choices=["Correct", "Incorrect"], label='Is the Generated Answer Correct?', info="Pease select whether the generated text is correct")
                    correct_answer = gr.Textbox(placeholder="Please judge on the generated answer", label='Correct Answer', interactive=True)
                    reference_correct = gr.Radio(choices=["Correct", "Incorrect"], label="Is the Reference Correct?", info="Pease select whether the reference is correct")
                    correct_reference = gr.Textbox(placeholder="Please judge on the generated answer", label='Correct Reference', interactive=True)
                    btn_submit_correctness = gr.Button(value='Submit Correctness')
                    btn_submit_correctness.style(full_width=True)
            with gr.Group():
                gr.Markdown('<center><h1>Download</h1></center>')
                gr.Markdown('<center><p>Download the processed data and corrected data</p></center>')
                answer_file = gr.File(label='Download processed data', file_types=['.xlsx'])
                btn_download_answer = gr.Button(value='Download processed data')
                btn_download_answer.style(full_width=True)
                corrected_file = gr.File(label='Download corrected data', file_types=['.xlsx'])
                btn_download_corrected = gr.Button(value='Download corrected data')
                btn_download_corrected.style(full_width=True)
    with gr.Row():
        # NOTE(review): no click handler is registered for this button below.
        reset = gr.Button(value='Reset')
        reset.style(full_width=True)

    # Answer change: picking Correct/Incorrect fills the matching edit box.
    answer_correct.input(
        backend.change_correct_answer,
        inputs=[answer_correct],
        outputs=[correct_answer],
    )
    reference_correct.input(
        backend.change_correct_reference,
        inputs=[reference_correct],
        outputs=[correct_reference],
    )

    # Submit button
    btn_submit_txt.click(
        backend.process_file,
        inputs=[file, questions, openai_key],
        outputs=[question_box, answer_box, highlighted_text, correct_answer, correct_reference],
    )
    btn_submit_correctness.click(  # TODO
        backend.process_results,
        inputs=[answer_correct, correct_answer, reference_correct, correct_reference],
        outputs=[save_results],
    )

    # Switch question button: both handlers return eight values, the last two
    # (None, None) clear the Correct/Incorrect radio selections.
    btn_last_question.click(
        backend.process_last,
        outputs=[question_box, answer_box, highlighted_text, correct_answer, correct_reference, save_results, answer_correct, reference_correct],
    )
    btn_next_question.click(
        backend.process_next,
        outputs=[question_box, answer_box, highlighted_text, correct_answer, correct_reference, save_results, answer_correct, reference_correct],
    )

    # Download button
    btn_download_answer.click(
        backend.download_answer,
        outputs=[answer_file],
    )
    btn_download_corrected.click(
        backend.download_corrected,
        outputs=[corrected_file],
    )

demo.queue()
demo.launch()