# Spinoza — Gradio application entry point (Hugging Face Space).
import gradio as gr | |
import time | |
import yaml | |
from langchain.prompts.chat import ChatPromptTemplate | |
from huggingface_hub import hf_hub_download | |
from spinoza_project.source.backend.llm_utils import ( | |
get_llm, | |
get_llm_api, | |
get_vectorstore, | |
get_vectorstore_api, | |
) | |
from spinoza_project.source.backend.document_store import pickle_to_document_store | |
from spinoza_project.source.backend.get_prompts import get_qa_prompts | |
from spinoza_project.source.frontend.utils import ( | |
make_html_source, | |
make_html_presse_source, | |
make_html_afp_source, | |
make_html_politique_source, | |
parse_output_llm_with_sources, | |
init_env, | |
) | |
from spinoza_project.source.backend.prompt_utils import ( | |
to_chat_instruction, | |
SpecialTokens, | |
) | |
from assets.utils_javascript import ( | |
accordion_trigger, | |
accordion_trigger_end, | |
accordion_trigger_spinoza, | |
accordion_trigger_spinoza_end, | |
update_footer, | |
) | |
# Environment setup: load env vars/secrets, the global config, and the
# per-source prompt definitions.
init_env()

with open("./spinoza_project/config.yaml") as config_file:
    config = yaml.full_load(config_file)

# One prompt YAML per source listed under config["prompt_naming"].
prompts = {}
for source_name in config["prompt_naming"]:
    with open(f"./spinoza_project/prompt_{source_name}.yaml") as prompt_file:
        prompts[source_name] = yaml.full_load(prompt_file)
## Building LLM
print("Building LLM")
model = "gpt35turbo"  # NOTE(review): not referenced anywhere below — presumably informational
llm = get_llm_api()

## Loading tools
print("Loading Databases")
# The press and AFP corpora are served through the vectorstore API...
bdd_presse = get_vectorstore_api("presse")
bdd_afp = get_vectorstore_api("afp")

# ...while every other tab uses a pickled local document store downloaded
# from the Hugging Face dataset hub.
qdrants = {}
for tab in config["prompt_naming"]:
    if tab in ("Presse", "AFP"):
        continue
    local_pickle = hf_hub_download(
        repo_id="TestSpinoza/spinoza-database",
        filename=f"database_{tab}.pickle",
        repo_type="dataset",
        force_download=True,
    )
    qdrants[tab] = pickle_to_document_store(local_pickle)
## Load Prompts
print("Loading Prompts")
chat_qa_prompts, chat_reformulation_prompts, chat_summarize_memory_prompts = {}, {}, {}
for source, prompt in prompts.items():
    chat_qa_prompt, chat_reformulation_prompt = get_qa_prompts(config, prompt)
    chat_qa_prompts[source] = chat_qa_prompt
    chat_reformulation_prompts[source] = chat_reformulation_prompt

with open("./assets/style.css", "r") as f:
    css = f.read()

special_tokens = SpecialTokens(config)

# Instruction for the final "Spinoza" synthesis pass that merges the
# per-agent answers into one French markdown summary.
# Fix: corrected the spelling/grammar errors of the original template
# ("secialized awnsers", "thechnical", "folowing", "keep tracking of",
# "mentionned", "Awnser") and removed the stray '",' line endings that were
# left over from a list literal.
synthesis_template = """You are a factual journalist that summarizes the specialized answers from technical sources.
Based on the following question:
{question}
And the following expert answer:
{answers}
- When using legal answers, keep track of the name of the articles.
- When using ADEME answers, name the sources that are mainly used.
- List the different elements mentioned, and highlight the agreement points between the sources, as well as the contradictions or differences.
- Contradictions don't lie in whether or not a subject is dealt with, but more in the opinion given or the way the subject is dealt with.
- Generate the answer as markdown, with an aerated layout, and headlines in bold
- When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
- Do not use the sentence 'Doc i says ...' to say where information came from.
- If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
- Start by highlighting contradictions, then do a general summary and finally get into the details that might be interesting for article writing. Where relevant, quote them.
- Answer in French / Répond en Français
"""

synthesis_prompt = to_chat_instruction(synthesis_template, special_tokens)
synthesis_prompt_template = ChatPromptTemplate.from_messages([synthesis_prompt])
def zip_longest_fill(*args, fillvalue=None):
    """Zip iterables, repeating an exhausted iterator's LAST value.

    Like ``itertools.zip_longest`` except that once an iterator is
    exhausted its slot repeats the last value it produced (useful for
    LLM streams that yield progressively growing strings); ``fillvalue``
    is used only for iterators that never produced anything.

    Fixes over the original implementation:
    - ``fillvalue`` was accepted but ignored (always ``None``);
    - iteration stopped early when a still-active iterator yielded the
      same value twice in a row while the others were exhausted,
      silently dropping elements;
    - one duplicate trailing tuple was yielded after all iterators were
      exhausted.
    """
    iterators = [iter(it) for it in args]
    if not iterators:
        return
    last = [fillvalue] * len(iterators)   # last value seen per slot
    active = [True] * len(iterators)      # exhaustion tracked explicitly
    while True:
        for i, it in enumerate(iterators):
            if active[i]:
                try:
                    last[i] = next(it)
                except StopIteration:
                    active[i] = False
        if not any(active):
            return
        yield tuple(last)
def format_question(question):
    """Return the question formatted for display (currently a plain pass-through)."""
    return "{}".format(question)
def parse_question(question):
    """Strip the HTML paragraph wrapper and an optional '### ' heading prefix."""
    cleaned = question.replace("<p>", "").replace("</p>\n", "")
    marker = "### "
    if marker not in cleaned:
        return cleaned
    # Keep only the text right after the first heading marker.
    return cleaned.split(marker)[1]
def reformulate(question, tab, config=config):
    """Stream a tab-specific reformulation of the user question via the LLM.

    Unknown tabs get a short stream of Nones so downstream zipping keeps shape.
    """
    if tab not in config["tabs"]:
        return iter([None] * 5)
    prompt = chat_reformulation_prompts[config["source_mapping"][tab]]
    return llm.stream(prompt, {"question": parse_question(question)})
def reformulate_single_question(question, tab, config=config):
    """Yield one tab's reformulation stream, throttled so the UI can render."""
    stream = reformulate(question, tab, config=config)
    for chunk in stream:
        time.sleep(0.02)  # small delay between incremental UI updates
        yield chunk
def reformulate_questions(question, config=config):
    """Yield lockstep tuples of reformulations, one entry per configured tab."""
    streams = [reformulate(question, tab, config=config) for tab in config["tabs"]]
    for combined in zip_longest_fill(*streams):
        time.sleep(0.02)  # small delay between incremental UI updates
        yield combined
def add_question(question):
    """Identity pass-through used to copy the textbox value into Gradio state."""
    return question
def answer(question, source, tab, config=config):
    """Stream the QA answer for one tab given its retrieved source text.

    Unknown tabs get a short stream of Nones; a near-empty source produces a
    fixed French "no sources found" message instead of an LLM call.
    """
    if tab not in config["tabs"]:
        return iter([None] * 5)
    if len(source) < 10:
        # Practically empty retrieval context: do not query the LLM.
        return iter(["Aucune source trouvée, veuillez reformuler votre question"])
    cleaned_sources = source.replace("<p>", "").replace("</p>\n", "")
    return llm.stream(
        chat_qa_prompts[config["source_mapping"][tab]],
        {"question": parse_question(question), "sources": cleaned_sources},
    )
def answer_single_question(source, question, tab, config=config):
    """Yield one tab's answer stream, throttled so the UI can render."""
    stream = answer(question, source, tab, config=config)
    for chunk in stream:
        time.sleep(0.02)  # small delay between incremental UI updates
        yield chunk
def answer_questions(*questions_sources, config=config):
    """Stream per-tab chatbot updates.

    ``questions_sources`` is the per-tab questions followed by the per-tab
    source texts (first half / second half of the varargs).
    """
    half = len(questions_sources) // 2
    questions = list(questions_sources[:half])
    sources = list(questions_sources[half:])
    streams = [
        answer(question, source, tab, config=config)
        for question, source, tab in zip(questions, sources, config["tabs"])
    ]
    for partial in zip_longest_fill(*streams):
        time.sleep(0.02)  # small delay between incremental UI updates
        # One single-turn chatbot history per tab: (question, parsed answer).
        yield [
            [(question, parse_output_llm_with_sources(ans))]
            for question, ans in zip(questions, partial)
        ]
def get_sources(
    questions, qdrants=qdrants, bdd_presse=bdd_presse, bdd_afp=bdd_afp, config=config
):
    """Retrieve and format the top-k documents for each tab's question.

    Returns ``(formated, text)`` where ``formated`` is one concatenated HTML
    string for the sources side panel and ``text`` holds one plain-text blob
    per tab (the context fed to the QA prompt).

    Refactor: the original duplicated the retrieval + score-filtering logic
    across four nearly identical branches; only the store selection, the
    query prefix and the HTML formatter actually differ per tab.
    """
    k = config["num_document_retrieved"]
    min_similarity = config["min_similarity"]
    text, formated = [], []
    for i, (question, tab) in enumerate(zip(questions, list(config["tabs"].keys()))):
        query = question.replace("<p>", "").replace("</p>\n", "")
        # API-backed stores (Presse, AFP) are queried as-is; local pickled
        # stores get the configured query preprompt prepended.
        if tab == "Presse":
            store, full_query = bdd_presse, query
        elif tab == "AFP":
            store, full_query = bdd_afp, query
        else:
            store = qdrants[config["source_mapping"][tab]]
            full_query = config["query_preprompt"] + query
        sources = store.similarity_search_with_relevance_scores(full_query, k=k)
        sources = [(doc, score) for doc, score in sources if score >= min_similarity]
        # Document numbering is global across tabs: tab i owns ids k*i+1 .. k*(i+1).
        doc_ids = range(k * i + 1, k * (i + 1) + 1)
        if tab == "Presse":
            formated.extend(
                make_html_presse_source(doc, j, score)
                for j, (doc, score) in zip(doc_ids, sources)
            )
        elif tab == "AFP":
            formated.extend(
                make_html_afp_source(doc, j, score)
                for j, (doc, score) in zip(doc_ids, sources)
            )
        elif tab == "Documents Stratégiques":
            formated.extend(
                make_html_politique_source(doc, j, score, config)
                for j, (doc, score) in zip(doc_ids, sources)
            )
        else:
            formated.extend(
                make_html_source(doc, j, score, config)
                for j, (doc, score) in zip(doc_ids, sources)
            )
        # One plain-text context blob per tab, documents separated by blank lines.
        text.append(
            "\n\n".join(
                f"Doc {j} with source type {doc.metadata.get('file_source_type')}:\n"
                + doc.page_content
                for j, (doc, score) in zip(doc_ids, sources)
            )
        )
    formated = "".join(formated)
    return formated, text
def retrieve_sources(
    *questions, qdrants=qdrants, bdd_presse=bdd_presse, bdd_afp=bdd_afp, config=config
):
    """Gradio wrapper around get_sources: returns (html_panel, *per_tab_texts)."""
    html_panel, per_tab_texts = get_sources(
        questions, qdrants, bdd_presse, bdd_afp, config
    )
    return (html_panel, *per_tab_texts)
def get_synthesis(question, *answers, config=config):
    """Stream Spinoza's synthesis of the per-agent answers in chatbot format.

    Fixes over the original:
    - this function is a generator, so ``return "<message>"`` only set the
      StopIteration value and the fallback message never reached the UI;
      it is now yielded in chatbot ``[(question, text)]`` format;
    - the local ``answer`` list shadowed the module-level ``answer``
      function; renamed to ``kept_answers``.
    """
    kept_answers = []
    for i, tab in enumerate(config["tabs"]):
        # Ignore near-empty agent answers (fewer than 100 characters).
        if len(str(answers[i])) >= 100:
            kept_answers.append(
                f"{tab}\n{answers[i]}".replace("<p>", "").replace("</p>\n", "")
            )
    if not kept_answers:
        yield [
            (
                question,
                "Aucune source n'a pu être identifiée pour répondre, veuillez modifier votre question",
            )
        ]
        return
    for elt in llm.stream(
        synthesis_prompt_template,
        {
            "question": question.replace("<p>", "").replace("</p>\n", ""),
            "answers": "\n\n".join(kept_answers),
        },
    ):
        time.sleep(0.01)  # small delay between incremental UI updates
        yield [(question, parse_output_llm_with_sources(elt))]
# Visual theme shared by the whole interface.
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)

# NOTE(review): style.css was already read earlier; this second read is
# redundant but kept as-is.
with open("./assets/style.css", "r") as f:
    css = f.read()

# Markdown shown in the "Source information" tab.
with open("./assets/source_information.md", "r") as f:
    source_information = f.read()
def start_agents():
    """Toast that processing started and seed the Spinoza chat with a waiting message."""
    gr.Info(message="The agents and Spinoza are loading...", duration=3)
    waiting_message = (
        "I am waiting until all the agents are done to generate an answer..."
    )
    return [(None, waiting_message)]
def end_agents():
    """Show a toast once the full answer pipeline has completed."""
    gr.Info(
        message="The agents and Spinoza have finished answering your question",
        duration=3,
    )
def next_call():
    """No-op placeholder used to sequence JS-only steps in the event chain."""
    return None
# Greeting shown in the Spinoza chatbot before any question is asked.
init_prompt = """
Hello, I am Spinoza, a conversational assistant designed to help you in your journalistic journey. I will answer your questions based **on the provided sources**.
⚠️ Limitations
*Please note that this chatbot is in an early stage, it is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*
What do you want to learn ?
"""
# ---------------------------------------------------------------------------
# UI layout: six agent chatbots (one per accordion) plus the Spinoza
# synthesis chatbot, a sources side panel, and two informational tabs.
# The ask.submit chain at the bottom drives the whole pipeline:
# reformulate -> retrieve sources -> per-agent answers -> synthesis.
# ---------------------------------------------------------------------------
with gr.Blocks(
    title=f"🔍 Spinoza",
    css=css,
    js=update_footer(),
    theme=theme,
) as demo:
    chatbots = {}
    question = gr.State("")        # NOTE(review): not referenced below — presumably reserved
    docs_textbox = gr.State([""])  # NOTE(review): not referenced below — presumably reserved
    # One state slot per tab: reformulated question, HTML sources, text sources.
    agent_questions = {elt: gr.State("") for elt in config["tabs"]}
    component_sources = {elt: gr.State("") for elt in config["tabs"]}
    text_sources = {elt: gr.State("") for elt in config["tabs"]}
    tab_states = {elt: gr.State(elt) for elt in config["tabs"]}

    with gr.Tab("Q&A", elem_id="main-component"):
        with gr.Row(elem_id="chatbot-row"):
            with gr.Column(scale=2, elem_id="center-panel"):
                with gr.Group(elem_id="chatbot-group"):
                    # NOTE(review): the agent accordions are matched to tabs
                    # purely by position (list(config["tabs"].keys())[n]) —
                    # the config key order must match these hard-coded
                    # titles; confirm when editing config.yaml.
                    with gr.Accordion(
                        "Science agent",
                        open=False,
                        elem_id="accordion-science",
                        elem_classes="accordion",
                    ):
                        chatbots[list(config["tabs"].keys())[0]] = gr.Chatbot(
                            show_copy_button=True,
                            show_share_button=False,
                            show_label=False,
                            elem_id="chatbot-science",
                            layout="panel",
                            avatar_images=(
                                "./assets/logos/help.png",
                                None,
                            ),
                        )
                    with gr.Accordion(
                        "Law agent",
                        open=False,
                        elem_id="accordion-legal",
                        elem_classes="accordion",
                    ):
                        chatbots[list(config["tabs"].keys())[1]] = gr.Chatbot(
                            show_copy_button=True,
                            show_share_button=False,
                            show_label=False,
                            elem_id="chatbot-legal",
                            layout="panel",
                            avatar_images=(
                                "./assets/logos/help.png",
                                None,
                            ),
                        )
                    with gr.Accordion(
                        "Politics agent",
                        open=False,
                        elem_id="accordion-politique",
                        elem_classes="accordion",
                    ):
                        chatbots[list(config["tabs"].keys())[2]] = gr.Chatbot(
                            show_copy_button=True,
                            show_share_button=False,
                            show_label=False,
                            elem_id="chatbot-politique",
                            layout="panel",
                            avatar_images=(
                                "./assets/logos/help.png",
                                None,
                            ),
                        )
                    with gr.Accordion(
                        "ADEME agent",
                        open=False,
                        elem_id="accordion-ademe",
                        elem_classes="accordion",
                    ):
                        chatbots[list(config["tabs"].keys())[3]] = gr.Chatbot(
                            show_copy_button=True,
                            show_share_button=False,
                            show_label=False,
                            elem_id="chatbot-ademe",
                            layout="panel",
                            avatar_images=(
                                "./assets/logos/help.png",
                                None,
                            ),
                        )
                    with gr.Accordion(
                        "Press agent",
                        open=False,
                        elem_id="accordion-presse",
                        elem_classes="accordion",
                    ):
                        chatbots[list(config["tabs"].keys())[4]] = gr.Chatbot(
                            show_copy_button=True,
                            show_share_button=False,
                            show_label=False,
                            elem_id="chatbot-presse",
                            layout="panel",
                            avatar_images=(
                                "./assets/logos/help.png",
                                None,
                            ),
                        )
                    with gr.Accordion(
                        "AFP agent",
                        open=False,
                        elem_id="accordion-afp",
                        elem_classes="accordion",
                    ):
                        chatbots[list(config["tabs"].keys())[5]] = gr.Chatbot(
                            show_copy_button=True,
                            show_share_button=False,
                            show_label=False,
                            elem_id="chatbot-afp",
                            layout="panel",
                            avatar_images=(
                                "./assets/logos/help.png",
                                None,
                            ),
                        )
                    # The synthesis chatbot is the only accordion open by default.
                    with gr.Accordion(
                        "Spinoza",
                        open=True,
                        elem_id="accordion-spinoza",
                        elem_classes="accordion",
                    ):
                        chatbots["spinoza"] = gr.Chatbot(
                            value=[(None, init_prompt)],
                            show_copy_button=True,
                            show_share_button=False,
                            show_label=False,
                            elem_id="chatbot-spinoza",
                            layout="panel",
                            avatar_images=(
                                "./assets/logos/help.png",
                                "./assets/logos/spinoza.png",
                            ),
                        )
                # Single free-text input driving the whole pipeline.
                with gr.Row(elem_id="input-message"):
                    ask = gr.Textbox(
                        placeholder="Ask me anything here!",
                        show_label=False,
                        scale=7,
                        lines=1,
                        interactive=True,
                        elem_id="input-textbox",
                    )
            # Right panel: HTML list of the retrieved sources.
            with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
                with gr.TabItem("Sources", elem_id="tab-sources", id=0):
                    sources_textbox = gr.HTML(
                        show_label=False, elem_id="sources-textbox"
                    )
    with gr.Tab("Source information", elem_id="source-component"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown(source_information)
    with gr.Tab("Contact", elem_id="contact-component"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("For any issue contact **[email protected]**.")

    # Event chain: each .then() step runs after the previous one finishes;
    # the js= hooks open/close the accordions around each stage.
    ask.submit(
        start_agents, inputs=[], outputs=[chatbots["spinoza"]], js=accordion_trigger()
    ).then(
        fn=reformulate_questions,
        inputs=[ask],
        outputs=[agent_questions[tab] for tab in config["tabs"]],
    ).then(
        fn=retrieve_sources,
        inputs=[agent_questions[tab] for tab in config["tabs"]],
        outputs=[sources_textbox] + [text_sources[tab] for tab in config["tabs"]],
    ).then(
        fn=answer_questions,
        inputs=[agent_questions[tab] for tab in config["tabs"]]
        + [text_sources[tab] for tab in config["tabs"]],
        outputs=[chatbots[tab] for tab in config["tabs"]],
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_end()
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza()
    ).then(
        # NOTE(review): the question passed to get_synthesis is the second
        # tab's reformulation (index [1]), not the raw user question —
        # looks deliberate but worth confirming.
        fn=get_synthesis,
        inputs=[agent_questions[list(config["tabs"].keys())[1]]]
        + [chatbots[tab] for tab in config["tabs"]],
        outputs=[chatbots["spinoza"]],
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza_end()
    ).then(
        fn=end_agents, inputs=[], outputs=[]
    )
if __name__ == "__main__":
    # queue() is required for streaming/generator outputs; debug=True surfaces
    # server-side errors in the console.
    demo.queue().launch(debug=True)