File size: 9,216 Bytes
87d8688
 
 
 
 
 
 
 
7a6a245
 
 
 
 
 
 
87d8688
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import gradio as gr
from huggingface_hub import snapshot_download
from vllm import LLM, SamplingParams

# ------------------------
# 1) Load the Model
# ------------------------
# Fetch only the quantized GGUF weights plus the tokenizer/config files
# from the Hub; everything else in the repo is skipped.
_ALLOW_PATTERNS = [
    "salamandrata_7b_inst_q4.gguf",
    "*tokenizer*",
    "tokenizer_config.json",
    "tokenizer.model",
    "config.json",
]
model_dir = snapshot_download(
    repo_id="BSC-LT/salamandraTA-7B-instruct-GGUF",
    revision="main",
    allow_patterns=_ALLOW_PATTERNS,
)
model_name = "salamandrata_7b_inst_q4.gguf"

# Create the vLLM engine: model points at the GGUF file, tokenizer at the
# snapshot directory containing the tokenizer files downloaded above.
llm = LLM(model=f"{model_dir}/{model_name}", tokenizer=model_dir)

# Single helper used by every task-specific function below to query the model.
def call_model(prompt: str, temperature: float = 0.1, max_tokens: int = 256,
               stop_token_ids=None):
    """
    Send *prompt* to the LLM via vLLM's chat interface and return the reply.

    Args:
        prompt: User message, sent as a single-turn chat.
        temperature: Sampling temperature (0.1 keeps output near-deterministic).
        max_tokens: Upper bound on the number of generated tokens.
        stop_token_ids: Token IDs that terminate generation. Defaults to
            ``[5]`` — presumably this model's end-of-turn token; confirm
            against the tokenizer if you change models.

    Returns:
        The text of the first candidate of the first generation, or an
        empty string if the model produced no output at all.
    """
    if stop_token_ids is None:
        stop_token_ids = [5]
    messages = [{'role': 'user', 'content': prompt}]
    outputs = llm.chat(
        messages,
        sampling_params=SamplingParams(
            temperature=temperature,
            stop_token_ids=stop_token_ids,
            max_tokens=max_tokens,
        ),
    )
    # llm.chat returns a list of generation objects, each carrying .outputs;
    # we only ever submit one request, so take its first candidate.
    return outputs[0].outputs[0].text if outputs else ""

# ------------------------
# 2) Task-specific functions
# ------------------------

def general_translation(source_lang, target_lang, text):
    """
    Translate *text* from *source_lang* into *target_lang*.

    Builds the standard translation prompt and forwards it to the model.
    """
    prompt_lines = [
        f"Translate the following text from {source_lang} into {target_lang}.",
        f"{source_lang}: {text}",
        f"{target_lang}:",
    ]
    return call_model("\n".join(prompt_lines), temperature=0.1)

def post_editing(source_lang, target_lang, source_text, machine_translation):
    """
    Post-edit a machine translation.

    Asks the model to correct mistakes in *machine_translation* of
    *source_text*, or leave it unchanged when it is already correct.
    """
    prompt_lines = [
        f"Please fix any mistakes in the following {source_lang}-{target_lang} machine translation or keep it unedited if it's correct.",
        f"Source: {source_text}",
        f"MT: {machine_translation}",
        "Corrected:",
    ]
    return call_model("\n".join(prompt_lines), temperature=0.1)

def document_level_translation(source_lang, target_lang, document_text):
    """
    Translate a multi-paragraph document from *source_lang* into *target_lang*.
    """
    header = f"Please translate this text from {source_lang} into {target_lang}."
    body = f"{source_lang}: {document_text}"
    cue = f"{target_lang}:"
    return call_model(f"{header}\n{body}\n{cue}", temperature=0.1)

def named_entity_recognition(tokenized_text):
    """
    Run BIO-style named-entity tagging on whitespace-tokenized text.

    The input string is split on whitespace and the resulting Python list
    is embedded verbatim (list repr) into the prompt, which instructs the
    model to label each token as B-/I- ORG, PER, LOC, MISC, or O.
    """
    # Assume a plain space-separated string; no attempt is made to parse
    # a Python-list-shaped input.
    tokens = tokenized_text.strip().split()

    prompt_lines = [
        "Analyse the following tokenized text and mark the tokens containing named entities.",
        "Use the following annotation guidelines with these tags for named entities:",
        "- ORG (Refers to named groups or organizations)",
        "- PER (Refers to individual people or named groups of people)",
        "- LOC (Refers to physical places or natural landmarks)",
        "- MISC (Refers to entities that don't fit into standard categories).",
        "Prepend B- to the first token of a given entity and I- to the remaining ones if they exist.",
        "If a token is not a named entity, label it as O.",
        f"Input: {tokens}",
        "Marked:",
    ]
    return call_model("\n".join(prompt_lines), temperature=0.1)

def grammar_checker(source_lang, sentence):
    """
    Grammar-check a single *source_lang* sentence.

    The model either fixes mistakes or returns the sentence unedited when
    it is already correct.
    """
    prompt_lines = [
        f"Please fix any mistakes in the following {source_lang} sentence or keep it unedited if it's correct.",
        f"Sentence: {sentence}",
        "Corrected:",
    ]
    return call_model("\n".join(prompt_lines), temperature=0.1)

# ------------------------
# 3) Gradio UI
# ------------------------
# One tab per task; each tab wires its inputs through the matching
# task-specific function defined above and shows the model output.
with gr.Blocks() as demo:
    gr.Markdown("## SalamandraTA-7B-Instruct Demo")
    gr.Markdown(
        "This Gradio app demonstrates various use-cases for the **SalamandraTA-7B-Instruct** model, including:\n"
        "1. General Translation\n"
        "2. Post-editing\n"
        "3. Document-level Translation\n"
        "4. Named-Entity Recognition (NER)\n"
        "5. Grammar Checking"
    )

    # Tab 1: free-text translation between a user-specified language pair.
    with gr.Tab("1. General Translation"):
        gr.Markdown("### General Translation")
        src_lang_gt = gr.Textbox(label="Source Language", value="Spanish")
        tgt_lang_gt = gr.Textbox(label="Target Language", value="English")
        text_gt = gr.Textbox(label="Text to Translate", lines=4, value="Ayer se fue, tomó sus cosas y se puso a navegar.")
        translate_button = gr.Button("Translate")
        output_gt = gr.Textbox(label="Translation Output", lines=4)
        translate_button.click(fn=general_translation, 
                               inputs=[src_lang_gt, tgt_lang_gt, text_gt], 
                               outputs=output_gt)

    # Tab 2: correct an existing machine translation given its source text.
    with gr.Tab("2. Post-editing"):
        gr.Markdown("### Post-editing (Source → Target)")
        src_lang_pe = gr.Textbox(label="Source Language", value="Catalan")
        tgt_lang_pe = gr.Textbox(label="Target Language", value="English")
        source_text_pe = gr.Textbox(label="Source Text", lines=2, value="Rafael Nadal i Maria Magdalena van inspirar a una generació sencera.")
        mt_text_pe = gr.Textbox(label="Machine Translation", lines=2, value="Rafael Christmas and Maria the Muffin inspired an entire generation each in their own way.")
        post_edit_button = gr.Button("Post-edit")
        output_pe = gr.Textbox(label="Post-edited Text", lines=4)
        post_edit_button.click(fn=post_editing, 
                               inputs=[src_lang_pe, tgt_lang_pe, source_text_pe, mt_text_pe], 
                               outputs=output_pe)

    # Tab 3: same as translation but intended for longer, multi-paragraph input.
    with gr.Tab("3. Document-level Translation"):
        gr.Markdown("### Document-level Translation")
        src_lang_doc = gr.Textbox(label="Source Language", value="English")
        tgt_lang_doc = gr.Textbox(label="Target Language", value="Asturian")
        doc_text = gr.Textbox(label="Document Text (multiple paragraphs allowed)", 
                              lines=8,
                              value=("President Donald Trump, who campaigned on promises to crack down on illegal immigration, "
                                     "has raised alarms in the U.S. dairy industry with his threat to impose 25% tariffs on Mexico "
                                     "and Canada by February 2025."))
        doc_button = gr.Button("Translate Document")
        doc_output = gr.Textbox(label="Document-level Translation Output", lines=8)
        doc_button.click(fn=document_level_translation, 
                         inputs=[src_lang_doc, tgt_lang_doc, doc_text], 
                         outputs=doc_output)

    # Tab 4: BIO-style NER over a space-separated token string.
    with gr.Tab("4. Named-Entity Recognition"):
        gr.Markdown("### Named-Entity Recognition (NER)")
        text_ner = gr.Textbox(
            label="Tokenized Text (space-separated tokens)",
            lines=2,
            value="La defensa del antiguo responsable de la RFEF confirma que interpondrá un recurso."
        )
        ner_button = gr.Button("Run NER")
        ner_output = gr.Textbox(label="NER Output", lines=6)
        ner_button.click(fn=named_entity_recognition, 
                         inputs=[text_ner], 
                         outputs=ner_output)

    # Tab 5: single-sentence grammar correction in a given language.
    with gr.Tab("5. Grammar Checker"):
        gr.Markdown("### Grammar Checker")
        src_lang_gc = gr.Textbox(label="Source Language", value="Catalan")
        text_gc = gr.Textbox(label="Sentence to Check", 
                             lines=2, 
                             value="Entonses, el meu jefe m’ha dit que he de treballar els fins de setmana.")
        gc_button = gr.Button("Check Grammar")
        gc_output = gr.Textbox(label="Corrected Sentence", lines=2)
        gc_button.click(fn=grammar_checker, 
                        inputs=[src_lang_gc, text_gc], 
                        outputs=gc_output)

# Start the Gradio server (blocking call).
demo.launch()