# SalamandraTA-7B-Instruct Gradio demo (Hugging Face Space, running on ZeroGPU).
# NOTE(review): the lines previously here were non-Python web-page residue
# (Space status badges, file size, commit hashes, and a line-number gutter)
# left over from a page scrape; they have been replaced with this comment
# header so the file parses.
import gradio as gr
from huggingface_hub import snapshot_download
from vllm import LLM, SamplingParams
# ------------------------
# 1) Load the Model
# ------------------------
# Name of the quantized GGUF checkpoint inside the Hub repository.
model_name = "salamandrata_7b_inst_q4.gguf"
# Download only the weights and tokenizer artifacts we actually need.
model_dir = snapshot_download(
    repo_id="BSC-LT/salamandraTA-7B-instruct-GGUF",
    revision="main",
    allow_patterns=[
        model_name,
        "*tokenizer*",
        "tokenizer_config.json",
        "tokenizer.model",
        "config.json",
    ],
)
# Build the vLLM engine from the local GGUF file; the tokenizer is loaded
# from the snapshot directory rather than the GGUF itself.
llm = LLM(model=f"{model_dir}/{model_name}", tokenizer=model_dir)
# We can define a single helper function to call the model:
def call_model(prompt: str, temperature: float = 0.1, max_tokens: int = 256):
    """
    Run a single-turn chat completion through the vLLM engine.

    Args:
        prompt: User message forwarded to the model as one chat turn.
        temperature: Sampling temperature (low values keep output near-greedy).
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        The text of the first completion, or "" when the engine
        produces no output at all.
    """
    params = SamplingParams(
        temperature=temperature,
        stop_token_ids=[5],  # model-specific stop token; adjust if needed
        max_tokens=max_tokens,
    )
    results = llm.chat([{'role': 'user', 'content': prompt}], sampling_params=params)
    if not results:
        return ""
    # Each result carries a list of candidate generations; take the first.
    return results[0].outputs[0].text
# ------------------------
# 2) Task-specific functions
# ------------------------
def general_translation(source_lang, target_lang, text):
    """
    Translate *text* from *source_lang* into *target_lang*.

    Builds a three-line instruction prompt and delegates generation
    to call_model with a low temperature for deterministic output.
    """
    prompt_lines = [
        f"Translate the following text from {source_lang} into {target_lang}.",
        f"{source_lang}: {text}",
        f"{target_lang}:",
    ]
    return call_model("\n".join(prompt_lines), temperature=0.1)
def post_editing(source_lang, target_lang, source_text, machine_translation):
    """
    Ask the model to post-edit a machine translation.

    The prompt instructs the model to correct mistakes in the given
    source/MT pair, or to leave the MT untouched if it is already correct.
    """
    prompt_lines = [
        f"Please fix any mistakes in the following {source_lang}-{target_lang} machine translation or keep it unedited if it's correct.",
        f"Source: {source_text}",
        f"MT: {machine_translation}",
        "Corrected:",
    ]
    return call_model("\n".join(prompt_lines), temperature=0.1)
def document_level_translation(source_lang, target_lang, document_text):
    """
    Translate a (possibly multi-paragraph) document as one unit.

    Same prompt shape as general_translation but phrased for whole
    documents rather than single sentences.
    """
    prompt_lines = [
        f"Please translate this text from {source_lang} into {target_lang}.",
        f"{source_lang}: {document_text}",
        f"{target_lang}:",
    ]
    return call_model("\n".join(prompt_lines), temperature=0.1)
def named_entity_recognition(tokenized_text):
    """
    Run BIO-style named-entity tagging over a token sequence.

    The input string is split on whitespace into tokens, and the Python
    list repr of those tokens is embedded in the prompt — this matches
    the input format the model expects for its NER task.
    """
    token_list = tokenized_text.strip().split()
    instruction_lines = [
        "Analyse the following tokenized text and mark the tokens containing named entities.",
        "Use the following annotation guidelines with these tags for named entities:",
        "- ORG (Refers to named groups or organizations)",
        "- PER (Refers to individual people or named groups of people)",
        "- LOC (Refers to physical places or natural landmarks)",
        "- MISC (Refers to entities that don't fit into standard categories).",
        "Prepend B- to the first token of a given entity and I- to the remaining ones if they exist.",
        "If a token is not a named entity, label it as O.",
        f"Input: {token_list}",
        "Marked:",
    ]
    return call_model("\n".join(instruction_lines), temperature=0.1)
def grammar_checker(source_lang, sentence):
    """
    Grammar-check a single sentence in *source_lang*.

    The model is asked to fix mistakes, or to return the sentence
    unchanged when it is already correct.
    """
    prompt_lines = [
        f"Please fix any mistakes in the following {source_lang} sentence or keep it unedited if it's correct.",
        f"Sentence: {sentence}",
        "Corrected:",
    ]
    return call_model("\n".join(prompt_lines), temperature=0.1)
# ------------------------
# 3) Gradio UI
# ------------------------
# One tab per task; each tab wires its inputs/button/output to the matching
# task function defined above.
with gr.Blocks() as demo:
    gr.Markdown("## SalamandraTA-7B-Instruct Demo")
    gr.Markdown(
        "This Gradio app demonstrates various use-cases for the **SalamandraTA-7B-Instruct** model, including:\n"
        "1. General Translation\n"
        "2. Post-editing\n"
        "3. Document-level Translation\n"
        "4. Named-Entity Recognition (NER)\n"
        "5. Grammar Checking"
    )
    with gr.Tab("1. General Translation"):
        gr.Markdown("### General Translation")
        src_lang_gt = gr.Textbox(label="Source Language", value="Spanish")
        tgt_lang_gt = gr.Textbox(label="Target Language", value="English")
        text_gt = gr.Textbox(label="Text to Translate", lines=4, value="Ayer se fue, tomó sus cosas y se puso a navegar.")
        translate_button = gr.Button("Translate")
        output_gt = gr.Textbox(label="Translation Output", lines=4)
        translate_button.click(fn=general_translation,
                               inputs=[src_lang_gt, tgt_lang_gt, text_gt],
                               outputs=output_gt)
    with gr.Tab("2. Post-editing"):
        gr.Markdown("### Post-editing (Source → Target)")
        src_lang_pe = gr.Textbox(label="Source Language", value="Catalan")
        tgt_lang_pe = gr.Textbox(label="Target Language", value="English")
        source_text_pe = gr.Textbox(label="Source Text", lines=2, value="Rafael Nadal i Maria Magdalena van inspirar a una generació sencera.")
        mt_text_pe = gr.Textbox(label="Machine Translation", lines=2, value="Rafael Christmas and Maria the Muffin inspired an entire generation each in their own way.")
        post_edit_button = gr.Button("Post-edit")
        output_pe = gr.Textbox(label="Post-edited Text", lines=4)
        post_edit_button.click(fn=post_editing,
                               inputs=[src_lang_pe, tgt_lang_pe, source_text_pe, mt_text_pe],
                               outputs=output_pe)
    with gr.Tab("3. Document-level Translation"):
        gr.Markdown("### Document-level Translation")
        src_lang_doc = gr.Textbox(label="Source Language", value="English")
        tgt_lang_doc = gr.Textbox(label="Target Language", value="Asturian")
        doc_text = gr.Textbox(label="Document Text (multiple paragraphs allowed)",
                              lines=8,
                              value=("President Donald Trump, who campaigned on promises to crack down on illegal immigration, "
                                     "has raised alarms in the U.S. dairy industry with his threat to impose 25% tariffs on Mexico "
                                     "and Canada by February 2025."))
        doc_button = gr.Button("Translate Document")
        doc_output = gr.Textbox(label="Document-level Translation Output", lines=8)
        doc_button.click(fn=document_level_translation,
                         inputs=[src_lang_doc, tgt_lang_doc, doc_text],
                         outputs=doc_output)
    with gr.Tab("4. Named-Entity Recognition"):
        gr.Markdown("### Named-Entity Recognition (NER)")
        text_ner = gr.Textbox(
            label="Tokenized Text (space-separated tokens)",
            lines=2,
            value="La defensa del antiguo responsable de la RFEF confirma que interpondrá un recurso."
        )
        ner_button = gr.Button("Run NER")
        ner_output = gr.Textbox(label="NER Output", lines=6)
        ner_button.click(fn=named_entity_recognition,
                         inputs=[text_ner],
                         outputs=ner_output)
    with gr.Tab("5. Grammar Checker"):
        gr.Markdown("### Grammar Checker")
        src_lang_gc = gr.Textbox(label="Source Language", value="Catalan")
        text_gc = gr.Textbox(label="Sentence to Check",
                             lines=2,
                             value="Entonses, el meu jefe m’ha dit que he de treballar els fins de setmana.")
        gc_button = gr.Button("Check Grammar")
        gc_output = gr.Textbox(label="Corrected Sentence", lines=2)
        gc_button.click(fn=grammar_checker,
                        inputs=[src_lang_gc, text_gc],
                        outputs=gc_output)
# Launch the app (Spaces entry point). Fix: the original line ended with a
# stray " |" (web-scrape table residue) which made it a SyntaxError.
demo.launch()