# wirag / app.py
# Commit 2ad5136 by nurasaki: "Added vdb-v3-wikksplitter metadata"
import os
import gradio as gr
from gradio.components import Textbox, Button, Slider, Checkbox
from AinaTheme import theme
from urllib.error import HTTPError
from rag import RAG
from utils import setup
# Token budget constant (not referenced elsewhere in the visible code).
MAX_NEW_TOKENS = 700
# Only the exact string "True" enables this flag; unset or any other value means False.
SHOW_MODEL_PARAMETERS_IN_UI = (
    os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="False") == "True"
)
import logging

logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(name)s][%(levelname)s] - %(message)s',
)

# Project-level initialisation provided by utils.setup.
setup()

print("Loading RAG model...")
print("Show model parameters in UI: ", SHOW_MODEL_PARAMETERS_IN_UI)

# Load the RAG model. Every constructor argument comes from an environment
# variable. Only the rerank context count is converted, so int() raises
# TypeError when RERANK_NUMBER_CONTEXTS is unset.
rag = RAG(**{
    "vs_hf_repo_path": os.getenv("VS_REPO_NAME"),
    "vectorstore_path": os.getenv("VECTORSTORE_PATH"),
    "hf_token": os.getenv("HF_TOKEN"),
    "embeddings_model": os.getenv("EMBEDDINGS"),
    "model_name": os.getenv("MODEL"),
    "rerank_model": os.getenv("RERANK_MODEL"),
    "rerank_number_contexts": int(os.getenv("RERANK_NUMBER_CONTEXTS")),
})
def generate(prompt, model_parameters):
    """Run RAG inference for ``prompt`` and report endpoint failures in the UI.

    Args:
        prompt: The user question, passed to ``rag.get_response``.
        model_parameters: Dict of retrieval/generation settings, forwarded unchanged.

    Returns:
        The ``(output, context, source)`` triple from ``rag.get_response``.
        On any failure it returns ``(None, None, None)``, after logging the
        error and showing a Gradio warning.
    """
    try:
        output, context, source = rag.get_response(prompt, model_parameters)
    except HTTPError as err:
        if err.code == 400:
            # Per the warning text, a 400 is taken to mean the request fell
            # outside the endpoint's service hours.
            logging.warning("Inference endpoint returned HTTP 400: %s", err)
            gr.Warning(
                "The inference endpoint is only available Monday through Friday, from 08:00 to 20:00 CET."
            )
        else:
            # Previously other HTTP errors were swallowed without any feedback.
            logging.exception("Inference endpoint returned HTTP %s", err.code)
            gr.Warning(
                "Inference endpoint is not available right now. Please try again later."
            )
    except Exception:
        # Catch Exception, not everything: KeyboardInterrupt/SystemExit must propagate.
        logging.exception("RAG inference failed")
        gr.Warning(
            "Inference endpoint is not available right now. Please try again later."
        )
    else:
        return output, context, source
    return None, None, None
def submit_input(input_, num_chunks, max_new_tokens=None, repetition_penalty=None,
                 top_k=None, top_p=None, do_sample=None, temperature=None):
    """Handle the Submit click: validate the input, run RAG inference, format results.

    The submit click handler wires only the query text and the chunk count.
    The generation parameters therefore have defaults, so the call no longer
    fails with a missing-argument TypeError.

    Args:
        input_: User query text.
        num_chunks: Number of chunks to retrieve (sent as ``NUM_CHUNKS``).
        max_new_tokens: Generation budget. ``None`` means use ``MAX_NEW_TOKENS``.
        repetition_penalty, top_k, top_p, do_sample, temperature: Generation
            settings. ``None`` is forwarded as-is.
            NOTE(review): confirm rag.get_response tolerates None for these keys.

    Returns:
        ``(output, sources_markup, context)``, one value per output component
        wired by the submit handler. Returns ``(None, None, None)`` when the
        input is blank or inference failed (``sources_markup`` is ``""`` on failure).
    """
    if input_ is None or not input_.strip():
        gr.Warning("Not possible to inference an empty input")
        # The handler wires three outputs, so return one value for each.
        return None, None, None

    if max_new_tokens is None:
        # Resolved at call time so the module constant is the single source of truth.
        max_new_tokens = MAX_NEW_TOKENS

    model_parameters = {
        "NUM_CHUNKS": num_chunks,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "top_p": top_p,
        "do_sample": do_sample,
        "temperature": temperature,
    }
    print("Model parameters: ", model_parameters)

    output, context, source = generate(input_, model_parameters)

    # generate() returns (None, None, None) on failure; don't iterate None.
    sources_markup = "".join(
        f'<a href="{url}" target="_blank">{url}</a><br>' for url in source or ()
    )
    return output, sources_markup, context
def change_interactive(text):
    """Recompute button interactivity whenever the input text changes.

    The Clear button always stays enabled. The Submit button is enabled only
    while ``text`` is non-empty.

    Returns:
        A pair of ``gr.update`` objects for ``(clear_btn, submit_btn)``.
    """
    submit_enabled = len(text) > 0
    return gr.update(interactive=True), gr.update(interactive=submit_enabled)
def clear():
    """Reset the UI after the Clear button is pressed.

    Returns five values, in the order the Clear handler wires its outputs:
    the input text, the output text, the sources markup and the context
    markup are all emptied (``None``). The chunk counter is replaced by a
    fresh ``gr.Number`` holding 5.
    """
    reset_chunks = gr.Number(value=5, label="Num. Retrieved Chunks", minimum=1, interactive=True)
    return (None,) * 4 + (reset_chunks,)
def gradio_app():
    """Build the Viquipèdia retrieval demo UI and launch it.

    Layout: an input panel (query box, Clear/Submit buttons, chunk count,
    examples) sits next to an output panel (answer text, a hidden sources
    accordion, and a full-context accordion). The Submit handler is exposed
    through the API as ``get-results``.
    """
    with gr.Blocks(theme=theme) as demo:
        # App description
        # =====================================================================
        with gr.Row():
            with gr.Column():
                gr.Markdown("""# Demo de Retrieval (only) Viquipèdia""")

        with gr.Row(equal_height=False):
            # User input
            # =================================================================
            with gr.Column(scale=2, variant="panel"):
                input_ = Textbox(
                    lines=5,
                    label="Input",
                    placeholder="Qui va crear la guerra de les Galaxies ?",
                )
                with gr.Row(variant="default"):
                    clear_btn = Button("Clear",)
                    submit_btn = Button("Submit", variant="primary", interactive=False)
                with gr.Row(variant="default"):
                    num_chunks = gr.Number(value=5, label="Num. Retrieved Chunks", minimum=1, interactive=True)

                # Examples are listed by hand. Questions about the "Imperi
                # Galàctic" were dropped because that entry is missing from
                # the vector DB (https://ca.wikipedia.org/wiki/Imperi_Gal%C3%A0ctic).
                gr.Examples(
                    examples=[
                        ["Qui va crear la guerra de les Galaxies?"],
                        ["Quin era el nom real de Voltaire?"],
                        ["Què fan al BSC?"],
                    ],
                    inputs=[input_],  # inputs only; examples do not run the model
                )

            # Output
            # =================================================================
            with gr.Column(scale=10, variant="panel"):
                output = Textbox(
                    lines=10,
                    max_lines=25,
                    label="Output",
                    interactive=False,
                    show_copy_button=True
                )
                with gr.Accordion("Sources and context:", open=False, visible=False):
                    source_context = gr.Markdown(
                        label="Sources",
                        show_label=False,
                    )
                with gr.Accordion("See full context evaluation:", open=False):
                    context_evaluation = gr.Markdown(
                        label="Full context",
                        show_label=False,
                    )

        # Event handlers
        # =====================================================================
        input_.change(
            fn=change_interactive,
            inputs=[input_],
            outputs=[clear_btn, submit_btn],
            api_name=False,
        )
        # Client-side length counter. No element with id 'inputlenght' exists
        # in this layout, so guard against null instead of throwing on every
        # keystroke. Only one input is wired, so `m` is undefined and the
        # "too long" colour never applies.
        input_.change(
            fn=None,
            inputs=[input_],
            api_name=False,
            js="""(i, m) => {
        const counter = document.getElementById('inputlenght');
        if (!counter) { return; }
        counter.textContent = i.length + ' ';
        counter.style.color = (i.length > m) ? "#ef4444" : "";
    }""",
        )
        clear_btn.click(
            fn=clear,
            inputs=[],
            outputs=[input_, output, source_context, context_evaluation, num_chunks],
            queue=False,
            api_name=False
        )
        # NOTE(review): only the query and chunk count are wired here. Every
        # other parameter of submit_input must have a default, or matching
        # components must be added to this inputs list.
        submit_btn.click(
            fn=submit_input,
            inputs=[input_, num_chunks],
            outputs=[output, source_context, context_evaluation],
            api_name="get-results"
        )

    # Launch after the Blocks context has closed and the layout is finalised.
    demo.launch(show_api=True)


if __name__ == "__main__":
    gradio_app()