"""Gradio demo for the SalamandraTA-7B-Instruct model served with vLLM.

The script downloads the quantised GGUF checkpoint of
``BSC-LT/salamandraTA-7B-instruct-GGUF`` from the Hugging Face Hub, loads it
with vLLM, and exposes five prompt-based tasks in separate tabs:

1. general translation,
2. post-editing,
3. document-level translation,
4. named-entity recognition,
5. grammar checking.
"""

import ast
import os

import gradio as gr
from huggingface_hub import snapshot_download
from vllm import LLM, SamplingParams

# ------------------------
# 1) Load the Model
# ------------------------
REPO_ID = "BSC-LT/salamandraTA-7B-instruct-GGUF"
model_name = "salamandrata_7b_inst_q4.gguf"

# Download only the quantised weights plus the tokenizer/config files that
# vLLM needs. Pin ``revision`` to a commit hash for reproducible deployments.
model_dir = snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    allow_patterns=[
        model_name,
        "*tokenizer*",
        "tokenizer_config.json",
        "tokenizer.model",
        "config.json",
    ],
)

# vLLM is pointed at the GGUF file itself. The tokenizer is loaded from the
# files downloaded alongside it in ``model_dir``.
llm = LLM(model=os.path.join(model_dir, model_name), tokenizer=model_dir)

# Multi-paragraph documents need more room than single sentences; the
# previous shared 256-token cap truncated longer translations.
DOCUMENT_MAX_TOKENS = 1024


def call_model(prompt: str, temperature: float = 0.1, max_tokens: int = 256) -> str:
    """Send ``prompt`` as a single user turn and return the generated text.

    Args:
        prompt: Full instruction prompt, sent as one ``user`` message through
            vLLM's ``LLM.chat`` interface.
        temperature: Sampling temperature. The task functions use a low value
            to keep outputs stable.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The first completion with surrounding whitespace removed, or ``""``
        when vLLM returned no completion.
    """
    messages = [{"role": "user", "content": prompt}]
    outputs = llm.chat(
        messages,
        sampling_params=SamplingParams(
            temperature=temperature,
            # NOTE(review): ID 5 is kept from the original demo, presumably the
            # chat end-of-turn token -- confirm against the tokenizer config.
            stop_token_ids=[5],
            max_tokens=max_tokens,
        ),
    )
    # ``llm.chat`` returns a list of RequestOutput objects. Each one carries its
    # candidate completions (CompletionOutput) in ``.outputs``. Guard both
    # levels so an empty result yields "" instead of an IndexError.
    if not outputs or not outputs[0].outputs:
        return ""
    return outputs[0].outputs[0].text.strip()


# ------------------------
# 2) Task-specific functions
# ------------------------
def _is_blank(value) -> bool:
    """Return True when a UI field is missing or contains only whitespace."""
    return value is None or not str(value).strip()


def general_translation(source_lang, target_lang, text):
    """Translate ``text`` from ``source_lang`` into ``target_lang``.

    Returns ``""`` without calling the model when ``text`` is blank.
    """
    if _is_blank(text):
        return ""
    prompt = (
        f"Translate the following text from {source_lang} into {target_lang}.\n"
        f"{source_lang}: {text}\n"
        f"{target_lang}:"
    )
    return call_model(prompt, temperature=0.1)


def post_editing(source_lang, target_lang, source_text, machine_translation):
    """Ask the model to fix ``machine_translation`` of ``source_text``.

    The model is told to keep the translation unedited if it is already
    correct. Returns ``""`` without calling the model when either text field
    is blank, since there is nothing to post-edit.
    """
    if _is_blank(source_text) or _is_blank(machine_translation):
        return ""
    prompt = (
        f"Please fix any mistakes in the following {source_lang}-{target_lang} machine translation or keep it unedited if it's correct.\n"
        f"Source: {source_text}\n"
        f"MT: {machine_translation}\n"
        f"Corrected:"
    )
    return call_model(prompt, temperature=0.1)


def document_level_translation(source_lang, target_lang, document_text):
    """Translate a (possibly multi-paragraph) document in a single request.

    Uses a larger generation budget (``DOCUMENT_MAX_TOKENS``) than the
    sentence-level tasks so long documents are not cut off. Returns ``""``
    when ``document_text`` is blank.
    """
    if _is_blank(document_text):
        return ""
    prompt = (
        f"Please translate this text from {source_lang} into {target_lang}.\n"
        f"{source_lang}: {document_text}\n"
        f"{target_lang}:"
    )
    return call_model(prompt, temperature=0.1, max_tokens=DOCUMENT_MAX_TOKENS)


def _parse_tokens(tokenized_text):
    """Turn the NER input field into a list of token strings.

    Accepts either a Python list literal of strings (e.g. ``['La', 'RFEF']``)
    or plain space-separated tokens. Anything that does not evaluate to a list
    of strings falls back to whitespace splitting.
    """
    stripped = (tokenized_text or "").strip()
    if stripped.startswith("["):
        try:
            # literal_eval only evaluates Python literals, so untrusted UI
            # input cannot execute code here.
            parsed = ast.literal_eval(stripped)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            parsed = None
        if isinstance(parsed, list) and all(isinstance(tok, str) for tok in parsed):
            return parsed
    return stripped.split()


def named_entity_recognition(tokenized_text):
    """Label each token as ORG, PER, LOC, MISC (with B-/I- prefixes) or O.

    ``tokenized_text`` is either space-separated tokens or a Python list
    literal of strings. The token list is embedded in the prompt as a Python
    list repr. Returns ``""`` when no tokens are provided.
    """
    tokens = _parse_tokens(tokenized_text)
    if not tokens:
        return ""
    prompt = (
        "Analyse the following tokenized text and mark the tokens containing named entities.\n"
        "Use the following annotation guidelines with these tags for named entities:\n"
        "- ORG (Refers to named groups or organizations)\n"
        "- PER (Refers to individual people or named groups of people)\n"
        "- LOC (Refers to physical places or natural landmarks)\n"
        "- MISC (Refers to entities that don't fit into standard categories).\n"
        "Prepend B- to the first token of a given entity and I- to the remaining ones if they exist.\n"
        "If a token is not a named entity, label it as O.\n"
        f"Input: {tokens}\n"
        "Marked:"
    )
    return call_model(prompt, temperature=0.1)


def grammar_checker(source_lang, sentence):
    """Fix mistakes in a ``source_lang`` sentence, or keep it unedited if correct.

    Returns ``""`` without calling the model when ``sentence`` is blank.
    """
    if _is_blank(sentence):
        return ""
    prompt = (
        f"Please fix any mistakes in the following {source_lang} sentence or keep it unedited if it's correct.\n"
        f"Sentence: {sentence}\n"
        f"Corrected:"
    )
    return call_model(prompt, temperature=0.1)


# ------------------------
# 3) Gradio UI
# ------------------------
with gr.Blocks() as demo:
    gr.Markdown("## SalamandraTA-7B-Instruct Demo")
    gr.Markdown(
        "This Gradio app demonstrates various use-cases for the **SalamandraTA-7B-Instruct** model, including:\n"
        "1. General Translation\n"
        "2. Post-editing\n"
        "3. Document-level Translation\n"
        "4. Named-Entity Recognition (NER)\n"
        "5. Grammar Checking"
    )

    with gr.Tab("1. General Translation"):
        gr.Markdown("### General Translation")
        src_lang_gt = gr.Textbox(label="Source Language", value="Spanish")
        tgt_lang_gt = gr.Textbox(label="Target Language", value="English")
        text_gt = gr.Textbox(
            label="Text to Translate",
            lines=4,
            value="Ayer se fue, tomó sus cosas y se puso a navegar.",
        )
        translate_button = gr.Button("Translate")
        output_gt = gr.Textbox(label="Translation Output", lines=4)
        translate_button.click(
            fn=general_translation,
            inputs=[src_lang_gt, tgt_lang_gt, text_gt],
            outputs=output_gt,
        )

    with gr.Tab("2. Post-editing"):
        gr.Markdown("### Post-editing (Source → Target)")
        src_lang_pe = gr.Textbox(label="Source Language", value="Catalan")
        tgt_lang_pe = gr.Textbox(label="Target Language", value="English")
        source_text_pe = gr.Textbox(
            label="Source Text",
            lines=2,
            value="Rafael Nadal i Maria Magdalena van inspirar a una generació sencera.",
        )
        # Deliberately poor MT so the post-editing step has something to fix.
        mt_text_pe = gr.Textbox(
            label="Machine Translation",
            lines=2,
            value="Rafael Christmas and Maria the Muffin inspired an entire generation each in their own way.",
        )
        post_edit_button = gr.Button("Post-edit")
        output_pe = gr.Textbox(label="Post-edited Text", lines=4)
        post_edit_button.click(
            fn=post_editing,
            inputs=[src_lang_pe, tgt_lang_pe, source_text_pe, mt_text_pe],
            outputs=output_pe,
        )

    with gr.Tab("3. Document-level Translation"):
        gr.Markdown("### Document-level Translation")
        src_lang_doc = gr.Textbox(label="Source Language", value="English")
        tgt_lang_doc = gr.Textbox(label="Target Language", value="Asturian")
        doc_text = gr.Textbox(
            label="Document Text (multiple paragraphs allowed)",
            lines=8,
            value=(
                "President Donald Trump, who campaigned on promises to crack down on illegal immigration, "
                "has raised alarms in the U.S. dairy industry with his threat to impose 25% tariffs on Mexico "
                "and Canada by February 2025."
            ),
        )
        doc_button = gr.Button("Translate Document")
        doc_output = gr.Textbox(label="Document-level Translation Output", lines=8)
        doc_button.click(
            fn=document_level_translation,
            inputs=[src_lang_doc, tgt_lang_doc, doc_text],
            outputs=doc_output,
        )

    with gr.Tab("4. Named-Entity Recognition"):
        gr.Markdown("### Named-Entity Recognition (NER)")
        # The default is pre-tokenized (final "." separated) to match the label.
        text_ner = gr.Textbox(
            label="Tokenized Text (space-separated tokens or a Python list of strings)",
            lines=2,
            value="La defensa del antiguo responsable de la RFEF confirma que interpondrá un recurso .",
        )
        ner_button = gr.Button("Run NER")
        ner_output = gr.Textbox(label="NER Output", lines=6)
        ner_button.click(
            fn=named_entity_recognition,
            inputs=[text_ner],
            outputs=ner_output,
        )

    with gr.Tab("5. Grammar Checker"):
        gr.Markdown("### Grammar Checker")
        src_lang_gc = gr.Textbox(label="Source Language", value="Catalan")
        # Deliberately contains Spanish-influenced errors for the model to fix.
        text_gc = gr.Textbox(
            label="Sentence to Check",
            lines=2,
            value="Entonses, el meu jefe m’ha dit que he de treballar els fins de setmana.",
        )
        gc_button = gr.Button("Check Grammar")
        gc_output = gr.Textbox(label="Corrected Sentence", lines=2)
        gc_button.click(
            fn=grammar_checker,
            inputs=[src_lang_gc, text_gc],
            outputs=gc_output,
        )


if __name__ == "__main__":
    demo.launch()