# Hugging Face Space — status: Running on Zero
import ast
import os

import gradio as gr
from huggingface_hub import snapshot_download
from vllm import LLM, SamplingParams
# ------------------------
# 1) Load the Model
# ------------------------
# Quantized (q4) GGUF weights file inside the repository. It is defined once so
# the download filter and the load path below cannot drift apart.
model_name = "salamandrata_7b_inst_q4.gguf"

# Download only the GGUF weights plus the tokenizer and config files.
# NOTE: pin `revision` to a commit hash instead of "main" for reproducible builds.
model_dir = snapshot_download(
    repo_id="BSC-LT/salamandraTA-7B-instruct-GGUF",
    revision="main",
    allow_patterns=[
        model_name,
        "*tokenizer*",  # already matches the two explicit tokenizer entries below
        "tokenizer_config.json",
        "tokenizer.model",
        "config.json",
    ],
)

# Create an LLM instance from vLLM: weights come from the single GGUF file,
# the tokenizer from the downloaded repository directory.
llm = LLM(model=os.path.join(model_dir, model_name), tokenizer=model_dir)
# Single helper used by every task function below to query the model.
def call_model(
    prompt: str,
    temperature: float = 0.1,
    max_tokens: int = 256,
    stop_token_ids: tuple[int, ...] = (5,),
) -> str:
    """Send *prompt* as one user chat turn through vLLM's chat interface.

    Args:
        prompt: Text of the single ``user`` message.
        temperature: Sampling temperature passed to ``SamplingParams``.
        max_tokens: Maximum number of tokens to generate.
        stop_token_ids: Token ids that end generation. The default ``(5,)``
            keeps the original hard-coded value; presumably it is the model's
            end-of-turn token — TODO confirm against the tokenizer.

    Returns:
        The text of the first completion of the first result, or ``""`` when
        vLLM returns no result or a result without completions.
    """
    messages = [{'role': 'user', 'content': prompt}]
    outputs = llm.chat(
        messages,
        sampling_params=SamplingParams(
            temperature=temperature,
            stop_token_ids=list(stop_token_ids),
            max_tokens=max_tokens,
        ),
    )
    # `llm.chat` returns one result per request, and each result's `.outputs` is
    # its list of completions. Both levels are guarded so an empty reply yields
    # "" instead of raising IndexError.
    if not outputs or not outputs[0].outputs:
        return ""
    return outputs[0].outputs[0].text
# ------------------------
# 2) Task-specific functions
# ------------------------
def general_translation(source_lang, target_lang, text):
    """Translate a short text between two languages.

    The prompt is three lines: an instruction naming both languages, the
    source text labelled with *source_lang*, and an open *target_lang*
    label for the model to complete.

    Returns:
        The model reply from ``call_model`` (temperature 0.1).
    """
    prompt_lines = [
        f"Translate the following text from {source_lang} into {target_lang}.",
        f"{source_lang}: {text}",
        f"{target_lang}:",
    ]
    return call_model("\n".join(prompt_lines), temperature=0.1)
def post_editing(source_lang, target_lang, source_text, machine_translation):
    """Ask the model to correct a machine translation.

    The model sees the source text and the MT output. It is told to fix any
    mistakes in the MT, or return it unchanged if it is already correct.

    Returns:
        The model reply from ``call_model`` (temperature 0.1).
    """
    instruction = (
        f"Please fix any mistakes in the following {source_lang}-{target_lang} "
        "machine translation or keep it unedited if it's correct."
    )
    sections = (
        instruction,
        f"Source: {source_text}",
        f"MT: {machine_translation}",
        "Corrected:",
    )
    return call_model("\n".join(sections), temperature=0.1)
def document_level_translation(source_lang, target_lang, document_text):
    """Translate a whole (possibly multi-paragraph) document in one request.

    The document text is embedded as-is after a *source_lang* label, and the
    prompt ends with an open *target_lang* label for the model to complete.

    Returns:
        The model reply from ``call_model`` (temperature 0.1).
    """
    instruction = f"Please translate this text from {source_lang} into {target_lang}."
    source_block = f"{source_lang}: {document_text}"
    prompt = f"{instruction}\n{source_block}\n{target_lang}:"
    return call_model(prompt, temperature=0.1)
def _split_ner_tokens(tokenized_text):
    """Turn the NER textbox contents into a list of token strings.

    A Python list literal of strings (e.g. ``"['La', 'defensa']"``) is parsed
    with ``ast.literal_eval``, which only evaluates literals and is safe on
    user input. Anything else is split on whitespace.
    """
    stripped = tokenized_text.strip()
    if stripped.startswith("["):
        try:
            parsed = ast.literal_eval(stripped)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            parsed = None
        if isinstance(parsed, list) and all(isinstance(tok, str) for tok in parsed):
            return parsed
    # Fallback (the original behaviour): whitespace-separated tokens.
    return stripped.split()


def named_entity_recognition(tokenized_text):
    """Ask the model to tag named entities in pre-tokenized text.

    The prompt lists the ORG / PER / LOC / MISC tags, asks for B-/I- prefixes
    on entity tokens and O for everything else.

    Args:
        tokenized_text: Either whitespace-separated tokens, or a Python list
            literal of strings such as ``"['La', 'defensa', 'del']"``.

    Returns:
        The model's raw reply from ``call_model`` (temperature 0.1); it is not
        parsed here.
    """
    tokens = _split_ner_tokens(tokenized_text)
    prompt = (
        "Analyse the following tokenized text and mark the tokens containing named entities.\n"
        "Use the following annotation guidelines with these tags for named entities:\n"
        "- ORG (Refers to named groups or organizations)\n"
        "- PER (Refers to individual people or named groups of people)\n"
        "- LOC (Refers to physical places or natural landmarks)\n"
        "- MISC (Refers to entities that don't fit into standard categories).\n"
        "Prepend B- to the first token of a given entity and I- to the remaining ones if they exist.\n"
        "If a token is not a named entity, label it as O.\n"
        # The token list is embedded via its Python repr, e.g. ['La', 'defensa', ...].
        f"Input: {tokens}\n"
        "Marked:"
    )
    return call_model(prompt, temperature=0.1)
def grammar_checker(source_lang, sentence):
    """Ask the model to correct a single sentence written in *source_lang*.

    The model is told to return the sentence unchanged when it has no mistakes.

    Returns:
        The model reply from ``call_model`` (temperature 0.1).
    """
    template = (
        "Please fix any mistakes in the following {lang} sentence "
        "or keep it unedited if it's correct.\n"
        "Sentence: {sentence}\n"
        "Corrected:"
    )
    prompt = template.format(lang=source_lang, sentence=sentence)
    return call_model(prompt, temperature=0.1)
# ------------------------
# 3) Gradio UI
# ------------------------
# One tab per task. Each button passes its tab's textboxes, in the listed order,
# to the matching task function and writes the reply into that tab's output box.
with gr.Blocks() as demo:
    gr.Markdown("## SalamandraTA-7B-Instruct Demo")
    gr.Markdown(
        "This Gradio app demonstrates various use-cases for the **SalamandraTA-7B-Instruct** model, including:\n"
        "1. General Translation\n"
        "2. Post-editing\n"
        "3. Document-level Translation\n"
        "4. Named-Entity Recognition (NER)\n"
        "5. Grammar Checking"
    )

    with gr.Tab("1. General Translation"):
        gr.Markdown("### General Translation")
        src_lang_gt = gr.Textbox(label="Source Language", value="Spanish")
        tgt_lang_gt = gr.Textbox(label="Target Language", value="English")
        text_gt = gr.Textbox(label="Text to Translate", lines=4, value="Ayer se fue, tomó sus cosas y se puso a navegar.")
        translate_button = gr.Button("Translate")
        output_gt = gr.Textbox(label="Translation Output", lines=4)
        translate_button.click(fn=general_translation,
                               inputs=[src_lang_gt, tgt_lang_gt, text_gt],
                               outputs=output_gt)

    with gr.Tab("2. Post-editing"):
        gr.Markdown("### Post-editing (Source → Target)")
        src_lang_pe = gr.Textbox(label="Source Language", value="Catalan")
        tgt_lang_pe = gr.Textbox(label="Target Language", value="English")
        source_text_pe = gr.Textbox(label="Source Text", lines=2, value="Rafael Nadal i Maria Magdalena van inspirar a una generació sencera.")
        # Deliberately bad MT so the post-editing step has something to fix.
        mt_text_pe = gr.Textbox(label="Machine Translation", lines=2, value="Rafael Christmas and Maria the Muffin inspired an entire generation each in their own way.")
        post_edit_button = gr.Button("Post-edit")
        output_pe = gr.Textbox(label="Post-edited Text", lines=4)
        post_edit_button.click(fn=post_editing,
                               inputs=[src_lang_pe, tgt_lang_pe, source_text_pe, mt_text_pe],
                               outputs=output_pe)

    with gr.Tab("3. Document-level Translation"):
        gr.Markdown("### Document-level Translation")
        src_lang_doc = gr.Textbox(label="Source Language", value="English")
        tgt_lang_doc = gr.Textbox(label="Target Language", value="Asturian")
        doc_text = gr.Textbox(label="Document Text (multiple paragraphs allowed)",
                              lines=8,
                              value=("President Donald Trump, who campaigned on promises to crack down on illegal immigration, "
                                     "has raised alarms in the U.S. dairy industry with his threat to impose 25% tariffs on Mexico "
                                     "and Canada by February 2025."))
        doc_button = gr.Button("Translate Document")
        doc_output = gr.Textbox(label="Document-level Translation Output", lines=8)
        doc_button.click(fn=document_level_translation,
                         inputs=[src_lang_doc, tgt_lang_doc, doc_text],
                         outputs=doc_output)

    with gr.Tab("4. Named-Entity Recognition"):
        gr.Markdown("### Named-Entity Recognition (NER)")
        # The label promises space-separated tokens and the handler splits on
        # whitespace, so the example keeps the final "." as its own token
        # instead of gluing it onto "recurso".
        text_ner = gr.Textbox(
            label="Tokenized Text (space-separated tokens)",
            lines=2,
            value="La defensa del antiguo responsable de la RFEF confirma que interpondrá un recurso ."
        )
        ner_button = gr.Button("Run NER")
        ner_output = gr.Textbox(label="NER Output", lines=6)
        ner_button.click(fn=named_entity_recognition,
                         inputs=[text_ner],
                         outputs=ner_output)

    with gr.Tab("5. Grammar Checker"):
        gr.Markdown("### Grammar Checker")
        src_lang_gc = gr.Textbox(label="Source Language", value="Catalan")
        # Example sentence with intentional mistakes for the checker to correct.
        text_gc = gr.Textbox(label="Sentence to Check",
                             lines=2,
                             value="Entonses, el meu jefe m’ha dit que he de treballar els fins de setmana.")
        gc_button = gr.Button("Check Grammar")
        gc_output = gr.Textbox(label="Corrected Sentence", lines=2)
        gc_button.click(fn=grammar_checker,
                        inputs=[src_lang_gc, text_gc],
                        outputs=gc_output)

demo.launch()