# NOTE: Hugging Face Space app.py (author javi8979, commit 7a6a245).
# The original scrape carried web-page chrome ("raw / history blame / 9.22 kB")
# on these lines; it is commented out here so the file parses as Python.
import gradio as gr
from huggingface_hub import snapshot_download
from vllm import LLM, SamplingParams
# ------------------------
# 1) Load the Model
# ------------------------
# Pull only what inference needs from the Hub: the quantized GGUF weights
# plus the tokenizer/config files. snapshot_download returns the local
# cache directory containing the matched files.
model_dir = snapshot_download(
    repo_id="BSC-LT/salamandraTA-7B-instruct-GGUF",
    revision="main",
    allow_patterns=[
        "salamandrata_7b_inst_q4.gguf",
        "*tokenizer*",
        "tokenizer_config.json",
        "tokenizer.model",
        "config.json",
    ],
)
model_name = "salamandrata_7b_inst_q4.gguf"
# vLLM loads the GGUF weights file directly; the tokenizer is read from the
# snapshot directory alongside it.
llm = LLM(model=f"{model_dir}/{model_name}", tokenizer=model_dir)
# We can define a single helper function to call the model:
def call_model(prompt: str, temperature: float = 0.1, max_tokens: int = 256,
               stop_token_ids=None):
    """Send a single-turn user prompt to the LLM via vLLM's chat interface.

    Args:
        prompt: The user message to send to the model.
        temperature: Sampling temperature (low by default for deterministic
            translation-style tasks).
        max_tokens: Maximum number of tokens to generate.
        stop_token_ids: Optional list of token IDs that stop generation.
            Defaults to ``[5]`` — presumably the model's end-of-turn token;
            TODO confirm against the tokenizer. A ``None`` sentinel is used
            instead of a list default to avoid a shared mutable default.

    Returns:
        The text of the first completion, or ``""`` if the model returned
        no outputs.
    """
    messages = [{'role': 'user', 'content': prompt}]
    outputs = llm.chat(
        messages,
        sampling_params=SamplingParams(
            temperature=temperature,
            stop_token_ids=[5] if stop_token_ids is None else stop_token_ids,
            max_tokens=max_tokens,
        ),
    )
    # llm.chat returns a list of request outputs, each carrying a list of
    # candidate completions in .outputs; take the first of the first.
    return outputs[0].outputs[0].text if outputs else ""
# ------------------------
# 2) Task-specific functions
# ------------------------
def general_translation(source_lang, target_lang, text):
    """Translate ``text`` from ``source_lang`` into ``target_lang``.

    Builds the model's general-translation prompt and delegates generation
    to :func:`call_model`.
    """
    prompt_lines = [
        f"Translate the following text from {source_lang} into {target_lang}.",
        f"{source_lang}: {text}",
        f"{target_lang}:",
    ]
    return call_model("\n".join(prompt_lines), temperature=0.1)
def post_editing(source_lang, target_lang, source_text, machine_translation):
    """Ask the model to correct a machine translation, or keep it as-is.

    The prompt shows the source sentence and its MT output, then asks the
    model to produce a corrected version.
    """
    prompt_lines = [
        f"Please fix any mistakes in the following {source_lang}-{target_lang} machine translation or keep it unedited if it's correct.",
        f"Source: {source_text}",
        f"MT: {machine_translation}",
        "Corrected:",
    ]
    return call_model("\n".join(prompt_lines), temperature=0.1)
def document_level_translation(source_lang, target_lang, document_text):
    """Translate a document (possibly multi-paragraph) between languages."""
    header = f"Please translate this text from {source_lang} into {target_lang}."
    prompt = f"{header}\n{source_lang}: {document_text}\n{target_lang}:"
    return call_model(prompt, temperature=0.1)
def named_entity_recognition(tokenized_text):
    """Run NER over a space-separated token string.

    The model is asked to tag each token as ORG, PER, LOC, MISC (with
    B-/I- prefixes) or O. The input is whitespace-split and the resulting
    Python list's repr is embedded in the prompt.
    """
    tokens = tokenized_text.strip().split()
    prompt_lines = [
        "Analyse the following tokenized text and mark the tokens containing named entities.",
        "Use the following annotation guidelines with these tags for named entities:",
        "- ORG (Refers to named groups or organizations)",
        "- PER (Refers to individual people or named groups of people)",
        "- LOC (Refers to physical places or natural landmarks)",
        "- MISC (Refers to entities that don't fit into standard categories).",
        "Prepend B- to the first token of a given entity and I- to the remaining ones if they exist.",
        "If a token is not a named entity, label it as O.",
        f"Input: {tokens}",
        "Marked:",
    ]
    return call_model("\n".join(prompt_lines), temperature=0.1)
def grammar_checker(source_lang, sentence):
    """Ask the model to fix grammar in ``sentence`` (written in ``source_lang``),
    or keep it unedited if it is already correct."""
    prompt_lines = [
        f"Please fix any mistakes in the following {source_lang} sentence or keep it unedited if it's correct.",
        f"Sentence: {sentence}",
        "Corrected:",
    ]
    return call_model("\n".join(prompt_lines), temperature=0.1)
# ------------------------
# 3) Gradio UI
# ------------------------
# Build the Gradio UI: one tab per task, each wired to its task function.
with gr.Blocks() as demo:
    # App header and feature overview.
    gr.Markdown("## SalamandraTA-7B-Instruct Demo")
    gr.Markdown(
        "This Gradio app demonstrates various use-cases for the **SalamandraTA-7B-Instruct** model, including:\n"
        "1. General Translation\n"
        "2. Post-editing\n"
        "3. Document-level Translation\n"
        "4. Named-Entity Recognition (NER)\n"
        "5. Grammar Checking"
    )
    # Tab 1: sentence-level translation between two free-text languages.
    with gr.Tab("1. General Translation"):
        gr.Markdown("### General Translation")
        src_lang_gt = gr.Textbox(label="Source Language", value="Spanish")
        tgt_lang_gt = gr.Textbox(label="Target Language", value="English")
        text_gt = gr.Textbox(label="Text to Translate", lines=4, value="Ayer se fue, tomó sus cosas y se puso a navegar.")
        translate_button = gr.Button("Translate")
        output_gt = gr.Textbox(label="Translation Output", lines=4)
        translate_button.click(fn=general_translation,
                               inputs=[src_lang_gt, tgt_lang_gt, text_gt],
                               outputs=output_gt)
    # Tab 2: correct an existing machine translation against its source.
    with gr.Tab("2. Post-editing"):
        gr.Markdown("### Post-editing (Source → Target)")
        src_lang_pe = gr.Textbox(label="Source Language", value="Catalan")
        tgt_lang_pe = gr.Textbox(label="Target Language", value="English")
        source_text_pe = gr.Textbox(label="Source Text", lines=2, value="Rafael Nadal i Maria Magdalena van inspirar a una generació sencera.")
        mt_text_pe = gr.Textbox(label="Machine Translation", lines=2, value="Rafael Christmas and Maria the Muffin inspired an entire generation each in their own way.")
        post_edit_button = gr.Button("Post-edit")
        output_pe = gr.Textbox(label="Post-edited Text", lines=4)
        post_edit_button.click(fn=post_editing,
                               inputs=[src_lang_pe, tgt_lang_pe, source_text_pe, mt_text_pe],
                               outputs=output_pe)
    # Tab 3: longer, multi-paragraph translation.
    with gr.Tab("3. Document-level Translation"):
        gr.Markdown("### Document-level Translation")
        src_lang_doc = gr.Textbox(label="Source Language", value="English")
        tgt_lang_doc = gr.Textbox(label="Target Language", value="Asturian")
        doc_text = gr.Textbox(label="Document Text (multiple paragraphs allowed)",
                              lines=8,
                              value=("President Donald Trump, who campaigned on promises to crack down on illegal immigration, "
                                     "has raised alarms in the U.S. dairy industry with his threat to impose 25% tariffs on Mexico "
                                     "and Canada by February 2025."))
        doc_button = gr.Button("Translate Document")
        doc_output = gr.Textbox(label="Document-level Translation Output", lines=8)
        doc_button.click(fn=document_level_translation,
                         inputs=[src_lang_doc, tgt_lang_doc, doc_text],
                         outputs=doc_output)
    # Tab 4: token-level NER tagging (input is whitespace-split by the handler).
    with gr.Tab("4. Named-Entity Recognition"):
        gr.Markdown("### Named-Entity Recognition (NER)")
        text_ner = gr.Textbox(
            label="Tokenized Text (space-separated tokens)",
            lines=2,
            value="La defensa del antiguo responsable de la RFEF confirma que interpondrá un recurso."
        )
        ner_button = gr.Button("Run NER")
        ner_output = gr.Textbox(label="NER Output", lines=6)
        ner_button.click(fn=named_entity_recognition,
                         inputs=[text_ner],
                         outputs=ner_output)
    # Tab 5: single-sentence grammar correction.
    with gr.Tab("5. Grammar Checker"):
        gr.Markdown("### Grammar Checker")
        src_lang_gc = gr.Textbox(label="Source Language", value="Catalan")
        text_gc = gr.Textbox(label="Sentence to Check",
                             lines=2,
                             value="Entonses, el meu jefe m’ha dit que he de treballar els fins de setmana.")
        gc_button = gr.Button("Check Grammar")
        gc_output = gr.Textbox(label="Corrected Sentence", lines=2)
        gc_button.click(fn=grammar_checker,
                        inputs=[src_lang_gc, text_gc],
                        outputs=gc_output)
# Start the Gradio server (blocking call).
demo.launch()