Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,170 +1,78 @@
|
|
1 |
import gradio as gr
|
2 |
-
|
3 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
4 |
import torch
|
|
|
|
|
5 |
|
6 |
-
# ------------------------
|
7 |
-
# 1) Load the Model
|
8 |
-
# ------------------------
|
9 |
model_id = "BSC-LT/salamandraTA-7b-instruct"
|
|
|
|
|
10 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
|
11 |
model = AutoModelForCausalLM.from_pretrained(
|
12 |
model_id,
|
13 |
device_map="auto",
|
14 |
-
torch_dtype=torch.bfloat16
|
15 |
)
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
)
|
50 |
-
return call_model(prompt)
|
51 |
-
|
52 |
-
def post_editing(source_lang, target_lang, source_text, machine_translation):
|
53 |
-
prompt = (
|
54 |
-
f"Please fix any mistakes in the following {source_lang}-{target_lang} machine translation or keep it unedited if it's correct.\n"
|
55 |
-
f"Source: {source_text}\n"
|
56 |
-
f"MT: {machine_translation}\n"
|
57 |
-
f"Corrected:"
|
58 |
-
)
|
59 |
-
return call_model(prompt, temperature=0.1)
|
60 |
-
|
61 |
-
def document_level_translation(source_lang, target_lang, document_text):
|
62 |
-
prompt = (
|
63 |
-
f"Please translate this text from {source_lang} into {target_lang}.\n"
|
64 |
-
f"{source_lang}: {document_text}\n"
|
65 |
-
f"{target_lang}:"
|
66 |
-
)
|
67 |
-
return call_model(prompt)
|
68 |
-
|
69 |
-
def named_entity_recognition(tokenized_text):
|
70 |
-
tokens = tokenized_text.strip().split()
|
71 |
-
prompt = (
|
72 |
-
"Analyse the following tokenized text and mark the tokens containing named entities.\n"
|
73 |
-
"Use the following annotation guidelines with these tags for named entities:\n"
|
74 |
-
"- ORG (Refers to named groups or organizations)\n"
|
75 |
-
"- PER (Refers to individual people or named groups of people)\n"
|
76 |
-
"- LOC (Refers to physical places or natural landmarks)\n"
|
77 |
-
"- MISC (Refers to entities that don't fit into standard categories).\n"
|
78 |
-
"Prepend B- to the first token of a given entity and I- to the remaining ones if they exist.\n"
|
79 |
-
"If a token is not a named entity, label it as O.\n"
|
80 |
-
f"Input: {tokens}\n"
|
81 |
-
"Marked:"
|
82 |
-
)
|
83 |
-
return call_model(prompt)
|
84 |
-
|
85 |
-
def grammar_checker(source_lang, sentence):
|
86 |
-
prompt = (
|
87 |
-
f"Please fix any mistakes in the following {source_lang} sentence or keep it unedited if it's correct.\n"
|
88 |
-
f"Sentence: {sentence}\n"
|
89 |
-
f"Corrected:"
|
90 |
-
)
|
91 |
-
return call_model(prompt)
|
92 |
-
|
93 |
-
# ------------------------
|
94 |
-
# 3) Gradio UI
|
95 |
-
# ------------------------
|
96 |
-
with gr.Blocks() as demo:
|
97 |
-
gr.Markdown("## SalamandraTA-7B-Instruct Demo")
|
98 |
-
gr.Markdown(
|
99 |
-
"This Gradio app demonstrates various use-cases for the **SalamandraTA-7B-Instruct** model, including:\n"
|
100 |
-
"1. General Translation\n"
|
101 |
-
"2. Post-editing\n"
|
102 |
-
"3. Document-level Translation\n"
|
103 |
-
"4. Named-Entity Recognition (NER)\n"
|
104 |
-
"5. Grammar Checking"
|
105 |
-
)
|
106 |
-
|
107 |
-
with gr.Tab("1. General Translation"):
|
108 |
-
gr.Markdown("### General Translation")
|
109 |
-
src_lang_gt = gr.Textbox(label="Source Language", value="Spanish")
|
110 |
-
tgt_lang_gt = gr.Textbox(label="Target Language", value="English")
|
111 |
-
text_gt = gr.Textbox(label="Text to Translate", lines=4, value="Ayer se fue, tomó sus cosas y se puso a navegar.")
|
112 |
-
translate_button = gr.Button("Translate")
|
113 |
-
output_gt = gr.Textbox(label="Translation Output", lines=4)
|
114 |
-
translate_button.click(fn=general_translation,
|
115 |
-
inputs=[src_lang_gt, tgt_lang_gt, text_gt],
|
116 |
-
outputs=output_gt)
|
117 |
-
|
118 |
-
with gr.Tab("2. Post-editing"):
|
119 |
-
gr.Markdown("### Post-editing (Source → Target)")
|
120 |
-
src_lang_pe = gr.Textbox(label="Source Language", value="Catalan")
|
121 |
-
tgt_lang_pe = gr.Textbox(label="Target Language", value="English")
|
122 |
-
source_text_pe = gr.Textbox(label="Source Text", lines=2, value="Rafael Nadal i Maria Magdalena van inspirar a una generació sencera.")
|
123 |
-
mt_text_pe = gr.Textbox(label="Machine Translation", lines=2, value="Rafael Christmas and Maria the Muffin inspired an entire generation each in their own way.")
|
124 |
-
post_edit_button = gr.Button("Post-edit")
|
125 |
-
output_pe = gr.Textbox(label="Post-edited Text", lines=4)
|
126 |
-
post_edit_button.click(fn=post_editing,
|
127 |
-
inputs=[src_lang_pe, tgt_lang_pe, source_text_pe, mt_text_pe],
|
128 |
-
outputs=output_pe)
|
129 |
-
|
130 |
-
with gr.Tab("3. Document-level Translation"):
|
131 |
-
gr.Markdown("### Document-level Translation")
|
132 |
-
src_lang_doc = gr.Textbox(label="Source Language", value="English")
|
133 |
-
tgt_lang_doc = gr.Textbox(label="Target Language", value="Asturian")
|
134 |
-
doc_text = gr.Textbox(label="Document Text (multiple paragraphs allowed)",
|
135 |
-
lines=8,
|
136 |
-
value=("President Donald Trump, who campaigned on promises to crack down on illegal immigration, "
|
137 |
-
"has raised alarms in the U.S. dairy industry with his threat to impose 25% tariffs on Mexico "
|
138 |
-
"and Canada by February 2025."))
|
139 |
-
doc_button = gr.Button("Translate Document")
|
140 |
-
doc_output = gr.Textbox(label="Document-level Translation Output", lines=8)
|
141 |
-
doc_button.click(fn=document_level_translation,
|
142 |
-
inputs=[src_lang_doc, tgt_lang_doc, doc_text],
|
143 |
-
outputs=doc_output)
|
144 |
-
|
145 |
-
with gr.Tab("4. Named-Entity Recognition"):
|
146 |
-
gr.Markdown("### Named-Entity Recognition (NER)")
|
147 |
-
text_ner = gr.Textbox(
|
148 |
-
label="Tokenized Text (space-separated tokens)",
|
149 |
-
lines=2,
|
150 |
-
value="La defensa del antiguo responsable de la RFEF confirma que interpondrá un recurso."
|
151 |
)
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import spaces
|
|
|
3 |
import torch
|
4 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
5 |
+
from datetime import datetime
|
6 |
|
|
|
|
|
|
|
7 |
model_id = "BSC-LT/salamandraTA-7b-instruct"
|
8 |
+
|
9 |
+
# Load tokenizer and model
|
10 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
11 |
+
|
12 |
model = AutoModelForCausalLM.from_pretrained(
|
13 |
model_id,
|
14 |
device_map="auto",
|
15 |
+
torch_dtype=torch.bfloat16 # Usa bf16 como en el ejemplo original
|
16 |
)
|
17 |
|
18 |
+
languages = [ "Spanish", "Catalan", "English", "French", "German", "Italian", "Portuguese", "Euskera", "Galician",
|
19 |
+
"Bulgarian", "Czech", "Lithuanian", "Croatian", "Dutch", "Romanian", "Danish", "Greek", "Finnish",
|
20 |
+
"Hungarian", "Slovak", "Slovenian", "Estonian", "Polish", "Latvian", "Swedish", "Maltese",
|
21 |
+
"Irish", "Aranese", "Aragonese", "Asturian" ]
|
22 |
+
|
23 |
+
example_sentence = ["Ahir se'n va anar, va agafar les seves coses i es va posar a navegar."]
|
24 |
+
|
25 |
+
@spaces.GPU(duration=120)
|
26 |
+
def translate(input_text, source, target):
|
27 |
+
sentences = [s for s in input_text.strip().split('\n') if s.strip()]
|
28 |
+
translated_sentences = []
|
29 |
+
|
30 |
+
for sentence in sentences:
|
31 |
+
prompt_text = f"Translate the following text from {source} into {target}.\n{source}: {sentence} \n{target}:"
|
32 |
+
messages = [{"role": "user", "content": prompt_text}]
|
33 |
+
date_string = datetime.today().strftime('%Y-%m-%d')
|
34 |
+
|
35 |
+
prompt = tokenizer.apply_chat_template(
|
36 |
+
messages,
|
37 |
+
tokenize=False,
|
38 |
+
add_generation_prompt=True,
|
39 |
+
date_string=date_string
|
40 |
+
)
|
41 |
+
|
42 |
+
inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
|
43 |
+
input_length = inputs.input_ids.shape[1]
|
44 |
+
|
45 |
+
output = model.generate(
|
46 |
+
input_ids=inputs.input_ids,
|
47 |
+
max_new_tokens=400,
|
48 |
+
early_stopping=True,
|
49 |
+
num_beams=5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
)
|
51 |
+
|
52 |
+
decoded = tokenizer.decode(output[0, input_length:], skip_special_tokens=True).strip()
|
53 |
+
translated_sentences.append(decoded)
|
54 |
+
|
55 |
+
return '\n'.join(translated_sentences), ""
|
56 |
+
|
57 |
+
with gr.Blocks() as demo:
|
58 |
+
gr.HTML("""<html>
|
59 |
+
<head><style>h1 { text-align: center; }</style></head>
|
60 |
+
<body><h1>SalamandraTA 7B Translate</h1></body>
|
61 |
+
</html>""")
|
62 |
+
|
63 |
+
with gr.Row():
|
64 |
+
with gr.Column():
|
65 |
+
source_language_dropdown = gr.Dropdown(choices=languages, value="Catalan", label="Source Language")
|
66 |
+
input_textbox = gr.Textbox(lines=5, placeholder="Enter text to translate", label="Input Text")
|
67 |
+
with gr.Column():
|
68 |
+
target_language_dropdown = gr.Dropdown(choices=languages, value="English", label="Target Language")
|
69 |
+
translated_textbox = gr.Textbox(lines=5, placeholder="", label="Translated Text")
|
70 |
+
|
71 |
+
info_label = gr.HTML("")
|
72 |
+
btn = gr.Button("Translate")
|
73 |
+
btn.click(translate, inputs=[input_textbox, source_language_dropdown, target_language_dropdown],
|
74 |
+
outputs=[translated_textbox, info_label])
|
75 |
+
gr.Examples(example_sentence, inputs=[input_textbox])
|
76 |
+
|
77 |
+
if __name__ == "__main__":
|
78 |
+
demo.launch()
|