Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
import torch | |
import os | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
import spaces | |
# Load the Hugging Face access token from the environment. It is passed
# explicitly to every from_pretrained() call below, so fail fast with a clear
# message when it is missing or empty.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable is not set. Please set it before running the script.")

# Use CUDA when torch reports it as available, otherwise CPU.
# NOTE(review): on ZeroGPU hardware the `spaces` package manages real GPU
# placement inside @spaces.GPU functions — confirm if changing this.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Startup sanity check: move a tiny tensor to `device` and report where it landed.
zero = torch.Tensor([0]).to(device)
print(f"Device being used: {zero.device}")

# Model configurations (Hub repository ids)
MSA_TO_SYRIAN_MODEL = "Omartificial-Intelligence-Space/Shami-MT"
SYRIAN_TO_MSA_MODEL = "Omartificial-Intelligence-Space/SHAMI-MT-2MSA"


def _load_translation_model(repo_id):
    """Return ``(tokenizer, model)`` for *repo_id*, with the model moved to ``device``."""
    tokenizer = AutoTokenizer.from_pretrained(repo_id, token=HF_TOKEN)
    model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, token=HF_TOKEN).to(device)
    return tokenizer, model


# Load models and tokenizers
print("Loading MSA to Syrian model...")
msa_to_syrian_tokenizer, msa_to_syrian_model = _load_translation_model(MSA_TO_SYRIAN_MODEL)
print("Loading Syrian to MSA model...")
syrian_to_msa_tokenizer, syrian_to_msa_model = _load_translation_model(SYRIAN_TO_MSA_MODEL)
print("Models loaded successfully!")
@spaces.GPU
def translate_msa_to_syrian(text, max_length=128, num_beams=5):
    """Translate from Modern Standard Arabic to Syrian dialect.

    Decorated with ``@spaces.GPU`` so that, on ZeroGPU hardware, a GPU is
    requested for the duration of the call.

    Args:
        text: Source text in MSA. ``None``, empty or whitespace-only input
            returns an empty string without running the model.
        max_length: Maximum length (in tokens) of the generated output.
        num_beams: Beam-search width passed to ``generate``.

    Returns:
        The decoded translation, or a ``"Translation error: ..."`` message if
        tokenization or generation raises.
    """
    if not text or not text.strip():
        return ""
    try:
        encoded = msa_to_syrian_tokenizer(text, return_tensors="pt")
        # Forward the attention mask too, so generate() doesn't have to infer it.
        outputs = msa_to_syrian_model.generate(
            input_ids=encoded.input_ids.to(device),
            attention_mask=encoded.attention_mask.to(device),
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )
        return msa_to_syrian_tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Best-effort: show the failure in the output textbox instead of crashing the request.
        return f"Translation error: {e}"
@spaces.GPU
def translate_syrian_to_msa(text, max_length=128, num_beams=5):
    """Translate from Syrian dialect to Modern Standard Arabic.

    Decorated with ``@spaces.GPU`` so that, on ZeroGPU hardware, a GPU is
    requested for the duration of the call.

    Args:
        text: Source text in Syrian dialect. ``None``, empty or
            whitespace-only input returns an empty string without running
            the model.
        max_length: Maximum length (in tokens) of the generated output.
        num_beams: Beam-search width passed to ``generate``.

    Returns:
        The decoded translation, or a ``"Translation error: ..."`` message if
        tokenization or generation raises.
    """
    if not text or not text.strip():
        return ""
    try:
        encoded = syrian_to_msa_tokenizer(text, return_tensors="pt")
        # Forward the attention mask too, so generate() doesn't have to infer it.
        outputs = syrian_to_msa_model.generate(
            input_ids=encoded.input_ids.to(device),
            attention_mask=encoded.attention_mask.to(device),
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )
        return syrian_to_msa_tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Best-effort: show the failure in the output textbox instead of crashing the request.
        return f"Translation error: {e}"
def bidirectional_translate(text, direction):
    """Route *text* to the translator selected by the *direction* dropdown.

    Args:
        text: The text to translate.
        direction: One of the dropdown labels; any other value (including
            ``None``) is treated as "no direction selected".

    Returns:
        The translated text, or a prompt asking the user to pick a direction.
    """
    msa_to_syrian_label = "MSA โ Syrian"
    syrian_to_msa_label = "Syrian โ MSA"

    # Guard clause: unknown selections never reach a model.
    if direction not in (msa_to_syrian_label, syrian_to_msa_label):
        return "Please select a translation direction"
    if direction == msa_to_syrian_label:
        return translate_msa_to_syrian(text)
    return translate_syrian_to_msa(text)
# Create Gradio interface.
# Static HTML fragments and option lists live in private constants so the
# component tree below reads top-to-bottom without large inline literals.
_HEADER_HTML = """
<div style="text-align: center; margin-bottom: 2rem;">
<h1>๐ SHAMI-MT: Bidirectional Arabic Translation</h1>
<p>Translate between Modern Standard Arabic (MSA) and Syrian Dialect</p>
<p><strong>Built on AraT5v2-base-1024 architecture</strong></p>
</div>
"""

_MODEL_INFO_HTML = """
<div style="background: #f8f9fa; padding: 1rem; border-radius: 8px; margin: 1rem 0;">
<h3>๐ Model Information</h3>
<ul>
<li><strong>Model Type:</strong> Sequence-to-Sequence Translation</li>
<li><strong>Base Model:</strong> UBC-NLP/AraT5v2-base-1024</li>
<li><strong>Languages:</strong> Arabic (MSA โ Syrian Dialect)</li>
<li><strong>Device:</strong> GPU/CPU Auto-detection</li>
</ul>
</div>
"""

# These labels must match the strings bidirectional_translate() compares against.
_DIRECTION_CHOICES = ["MSA โ Syrian", "Syrian โ MSA"]

# Sample [input text, direction] rows shown under the translator.
_EXAMPLE_ROWS = [
    ["ุฃูุง ูุง ุฃุนุฑู ุฅุฐุง ูุงู ุณูุชู ูู ู ู ุงูุญุถูุฑ ุงูููู ุฃู ูุง.", "MSA โ Syrian"],
    ["ููู ุญุงููุ", "MSA โ Syrian"],
    ["ู ุง ุจุนุฑู ุฅุฐุง ุฑุญ ููุฏุฑ ูุฌู ุงูููู ููุง ูุฃ.", "Syrian โ MSA"],
    ["ุดููููุ", "Syrian โ MSA"],
]

with gr.Blocks(title="SHAMI-MT: Bidirectional Syria Arabic Dialect MT Framework") as demo:
    gr.HTML(_HEADER_HTML)

    with gr.Row():
        # Left column: static model card.
        with gr.Column(scale=1):
            gr.HTML(_MODEL_INFO_HTML)
        # Right column: the interactive translator.
        with gr.Column(scale=2):
            direction = gr.Dropdown(
                choices=_DIRECTION_CHOICES,
                value=_DIRECTION_CHOICES[0],
                label="Translation Direction",
            )
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter Arabic text here...",
                lines=5,
            )
            translate_btn = gr.Button("๐ Translate", variant="primary")
            output_text = gr.Textbox(label="Translation", lines=5)

    # Clicking the button runs the dispatcher on (text, direction).
    translate_btn.click(
        fn=bidirectional_translate,
        inputs=[input_text, direction],
        outputs=output_text,
    )

    # Clickable sample inputs, routed through the same dispatcher.
    gr.Examples(
        examples=_EXAMPLE_ROWS,
        inputs=[input_text, direction],
        outputs=output_text,
        fn=bidirectional_translate,
    )
# Launch the app when run as a script.
if __name__ == "__main__":
    # Listen on all interfaces, port 7860. No public share tunnel is requested:
    # Gradio ignores share=True on Hugging Face Spaces, and elsewhere it would
    # expose this app through a public gradio.live URL by default.
    demo.launch(server_name="0.0.0.0", server_port=7860)