# (Hugging Face Spaces page header — "Spaces / Sleeping" — captured during
# scraping; not part of the application source.)
import logging
import os

import gradio as gr
from gradio.components import Textbox, Button, Slider, Checkbox
from urllib.error import HTTPError

from AinaTheme import theme
from rag import RAG
from utils import setup

# Hard cap on tokens the model may generate per request.
MAX_NEW_TOKENS = 700
# Expose sampling parameters in the UI only when the env var is literally "True".
SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="False") == "True"

logging.basicConfig(level=logging.INFO, format='[%(asctime)s][%(name)s][%(levelname)s] - %(message)s')
logger = logging.getLogger(__name__)

setup()

logger.info("Loading RAG model...")
logger.info("Show model parameters in UI: %s", SHOW_MODEL_PARAMETERS_IN_UI)

# Load the RAG pipeline; all configuration comes from environment variables.
# NOTE(review): int(os.getenv("RERANK_NUMBER_CONTEXTS")) raises TypeError if
# the variable is unset — fail-fast, but consider a clearer startup error.
rag = RAG(
    vs_hf_repo_path=os.getenv("VS_REPO_NAME"),
    vectorstore_path=os.getenv("VECTORSTORE_PATH"),
    hf_token=os.getenv("HF_TOKEN"),
    embeddings_model=os.getenv("EMBEDDINGS"),
    model_name=os.getenv("MODEL"),
    rerank_model=os.getenv("RERANK_MODEL"),
    rerank_number_contexts=int(os.getenv("RERANK_NUMBER_CONTEXTS"))
)
def generate(prompt, model_parameters):
    """Run RAG inference for *prompt* and return (output, context, source).

    On any failure a Gradio warning is shown and (None, None, None) is
    returned, so callers can always unpack three values.
    """
    try:
        output, context, source = rag.get_response(prompt, model_parameters)
        return output, context, source
    except HTTPError as err:
        # Original bug: the 400 branch fell through without returning the
        # 3-tuple, crashing the caller's unpack; non-400 showed no warning.
        if err.code == 400:
            gr.Warning(
                "The inference endpoint is only available Monday through Friday, from 08:00 to 20:00 CET."
            )
        else:
            gr.Warning(
                "Inference endpoint is not available right now. Please try again later."
            )
        return None, None, None
    except Exception:
        # Broad catch is deliberate: any backend failure should degrade to a
        # UI warning rather than crash the Space. (Was a bare `except:`.)
        gr.Warning(
            "Inference endpoint is not available right now. Please try again later."
        )
        return None, None, None
def submit_input(input_, num_chunks, max_new_tokens=MAX_NEW_TOKENS,
                 repetition_penalty=1.0, top_k=50, top_p=0.95,
                 do_sample=False, temperature=0.35):
    """Validate the user input, run RAG inference and format the results.

    Returns (output, sources_markup, context); (None, None, None) on empty
    input or inference failure.

    NOTE(review): the submit handler binds only [input_, num_chunks], so the
    sampling parameters now have backward-compatible defaults (previously
    every submit crashed on missing positional arguments). The default values
    are reviewer guesses — confirm against the deployed configuration.
    """
    if not input_.strip():
        gr.Warning("Not possible to inference an empty input")
        # Keep the three-output shape expected by the Gradio outputs list.
        return None, None, None

    model_parameters = {
        "NUM_CHUNKS": num_chunks,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "top_p": top_p,
        "do_sample": do_sample,
        "temperature": temperature
    }
    logging.getLogger(__name__).info("Model parameters: %s", model_parameters)

    output, context, source = generate(input_, model_parameters)
    if source is None:
        # generate() already surfaced a warning; don't iterate over None.
        return None, None, None

    sources_markup = "".join(
        f'<a href="{url}" target="_blank">{url}</a><br>' for url in source
    )
    return output, sources_markup, context
def change_interactive(text):
    """Keep Clear always enabled; enable Submit only when *text* is non-empty."""
    submit_enabled = len(text) != 0
    return gr.update(interactive=True), gr.update(interactive=submit_enabled)
def clear():
    """Reset the form: blank input/output/sources/context, restore chunk count."""
    reset_chunks = gr.Number(
        value=5, label="Num. Retrieved Chunks", minimum=1, interactive=True
    )
    return None, None, None, None, reset_chunks
def gradio_app():
    """Assemble the Gradio Blocks UI for the Wikipedia retrieval demo and launch it."""
    with gr.Blocks(theme=theme) as demo:
        # App Description
        # =================================================================
        with gr.Row():
            with gr.Column():
                gr.Markdown("""# Demo de Retrieval (only) Viquipèdia""")

        with gr.Row(equal_height=False):
            # User Input
            # =============================================================
            with gr.Column(scale=2, variant="panel"):
                input_ = Textbox(
                    lines=5,
                    label="Input",
                    placeholder="Qui va crear la guerra de les Galaxies ?",
                )
                with gr.Row(variant="default"):
                    clear_btn = Button("Clear",)
                    submit_btn = Button("Submit", variant="primary", interactive=False)
                with gr.Row(variant="default"):
                    num_chunks = gr.Number(value=5, label="Num. Retrieved Chunks", minimum=1, interactive=True)

                # Add Examples manually
                gr.Examples(
                    examples=[
                        ["Qui va crear la guerra de les Galaxies?"],
                        ["Quin era el nom real de Voltaire?"],
                        ["Què fan al BSC?"],
                        # This entry does not exist in the vector DB:
                        # https://ca.wikipedia.org/wiki/Imperi_Gal%C3%A0ctic
                        # ["Què és un Imperi Galàctic?"],
                        # ["Què és l'Imperi Galàctic d'Isaac Asimov?"],
                        # ["Què és l'Imperi Galàctic de la Guerra de les Galàxies?"]
                    ],
                    inputs=[input_],  # only inputs
                )

            # Output
            # =============================================================
            with gr.Column(scale=10, variant="panel"):
                output = Textbox(
                    lines=10,
                    max_lines=25,
                    label="Output",
                    interactive=False,
                    show_copy_button=True
                )
                # NOTE(review): nesting reconstructed from a flattened source;
                # the two accordions are treated as siblings under this column
                # (nesting the second inside the first would hide it, since
                # the first is visible=False) — confirm against the original.
                with gr.Accordion("Sources and context:", open=False, visible=False):
                    source_context = gr.Markdown(
                        label="Sources",
                        show_label=False,
                    )
                with gr.Accordion("See full context evaluation:", open=False):
                    context_evaluation = gr.Markdown(
                        label="Full context",
                        show_label=False,
                    )

        # Event Handlers
        # =================================================================
        input_.change(
            fn=change_interactive,
            inputs=[input_],
            outputs=[clear_btn, submit_btn],
            api_name=False,
        )
        # NOTE(review): this JS handler declares two parameters (i, m) but only
        # one input is bound, so `m` is undefined and `i.length > m` is always
        # false; it also targets element id 'inputlenght' (sic), which no
        # component here declares via elem_id. Dead/broken in this view —
        # confirm against the full app before removing.
        input_.change(
            fn=None,
            inputs=[input_],
            api_name=False,
            js="""(i, m) => {
            document.getElementById('inputlenght').textContent = i.length + ' '
            document.getElementById('inputlenght').style.color = (i.length > m) ? "#ef4444" : "";
            }""",
        )
        clear_btn.click(
            fn=clear,
            inputs=[],
            outputs=[input_, output, source_context, context_evaluation, num_chunks],
            queue=False,
            api_name=False
        )
        submit_btn.click(
            fn=submit_input,
            inputs=[input_] + [num_chunks],
            outputs=[output, source_context, context_evaluation],
            api_name="get-results"
        )

    demo.launch(show_api=True)
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server.
    gradio_app()