# SHAMI-MT-App / app.py
# Omartificial-Intelligence-Space
# Commit 8d1d3f4 ("Update app.py", verified)
import gradio as gr
import torch
import os
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import spaces
# Load Hugging Face token from the environment variable.
# Fail fast at import time: the app refuses to start when the variable is absent
# (an empty-but-set value is accepted, exactly as before).
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("HF_TOKEN environment variable is not set. Please set it before running the script.")
# Check for GPU support: use CUDA when torch can see a GPU, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# Tiny probe tensor placed on the chosen device; its .device is what gets reported.
zero = torch.Tensor([0]).to(device)
print(f"Device being used: {zero.device}")
# Model configurations: Hugging Face Hub repo ids, one per translation direction.
MSA_TO_SYRIAN_MODEL, SYRIAN_TO_MSA_MODEL = (
    "Omartificial-Intelligence-Space/Shami-MT",
    "Omartificial-Intelligence-Space/SHAMI-MT-2MSA",
)
# Load models and tokenizers.
# NOTE(review): HF_TOKEN is not passed explicitly here; huggingface_hub picks up the
# HF_TOKEN environment variable on its own — confirm that is the intended auth path.
def _load_seq2seq_pair(model_id, announcement):
    """Print ``announcement``, then load the tokenizer and seq2seq model for ``model_id``.

    The model is moved to the module-level ``device`` before being returned.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    print(announcement)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)
    return tokenizer, model


msa_to_syrian_tokenizer, msa_to_syrian_model = _load_seq2seq_pair(
    MSA_TO_SYRIAN_MODEL, "Loading MSA to Syrian model..."
)
syrian_to_msa_tokenizer, syrian_to_msa_model = _load_seq2seq_pair(
    SYRIAN_TO_MSA_MODEL, "Loading Syrian to MSA model..."
)
print("Models loaded successfully!")
def _run_seq2seq_translation(text, tokenizer, model, max_length=128, num_beams=5):
    """Translate ``text`` with a seq2seq ``tokenizer``/``model`` pair.

    Shared by both translation directions so their pre/post-processing stays
    identical. Blank or missing input yields ``""`` without touching the model.
    Any failure during tokenization, generation or decoding is returned as a
    ``"Translation error: ..."`` string (so the UI shows it) and is also printed
    so it appears in the server log.

    Args:
        text: Source text; ``None`` is treated like empty input.
        tokenizer: Hugging Face tokenizer matching ``model``.
        model: Hugging Face seq2seq model, expected to already live on ``device``.
        max_length: ``max_length`` forwarded to ``model.generate``.
        num_beams: Beam width forwarded to ``model.generate``.

    Returns:
        The decoded first generated sequence, ``""`` for blank input, or an
        error message string.
    """
    # Guard before the try: ``None.strip()`` would otherwise raise uncaught.
    if not text or not text.strip():
        return ""
    try:
        input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
        outputs = model.generate(
            input_ids,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:  # UI boundary: report instead of crashing the request
        print(f"Translation failed: {e!r}")
        return f"Translation error: {str(e)}"


@spaces.GPU(duration=120)
def translate_msa_to_syrian(text):
    """Translate from Modern Standard Arabic to Syrian dialect.

    Returns the translation, ``""`` for blank input, or a
    ``"Translation error: ..."`` message on failure.
    """
    return _run_seq2seq_translation(text, msa_to_syrian_tokenizer, msa_to_syrian_model)


@spaces.GPU(duration=120)
def translate_syrian_to_msa(text):
    """Translate from Syrian dialect to Modern Standard Arabic.

    Returns the translation, ``""`` for blank input, or a
    ``"Translation error: ..."`` message on failure.
    """
    return _run_seq2seq_translation(text, syrian_to_msa_tokenizer, syrian_to_msa_model)
# Translation-direction labels. They are shown in the dropdown AND used as the
# dispatch keys below, so they are defined once to keep the UI and logic in sync.
DIRECTION_MSA_TO_SYRIAN = "MSA → Syrian"
DIRECTION_SYRIAN_TO_MSA = "Syrian → MSA"


def bidirectional_translate(text, direction):
    """Handle bidirectional translation based on user selection.

    Args:
        text: Arabic input text from the textbox.
        direction: ``DIRECTION_MSA_TO_SYRIAN`` or ``DIRECTION_SYRIAN_TO_MSA``.

    Returns:
        The translated text, or a prompt to pick a direction when ``direction``
        matches neither label.
    """
    if direction == DIRECTION_MSA_TO_SYRIAN:
        return translate_msa_to_syrian(text)
    if direction == DIRECTION_SYRIAN_TO_MSA:
        return translate_syrian_to_msa(text)
    return "Please select a translation direction"


# Create Gradio interface
with gr.Blocks(title="SHAMI-MT: Bidirectional Syria Arabic Dialect MT Framework") as demo:
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 2rem;">
        <h1>🌍 SHAMI-MT: Bidirectional Arabic Translation</h1>
        <p>Translate between Modern Standard Arabic (MSA) and Syrian Dialect</p>
        <p><strong>Built on AraT5v2-base-1024 architecture</strong></p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("""
            <div style="background: #f8f9fa; padding: 1rem; border-radius: 8px; margin: 1rem 0;">
                <h3>📚 Model Information</h3>
                <ul>
                    <li><strong>Model Type:</strong> Sequence-to-Sequence Translation</li>
                    <li><strong>Base Model:</strong> UBC-NLP/AraT5v2-base-1024</li>
                    <li><strong>Languages:</strong> Arabic (MSA ↔ Syrian Dialect)</li>
                    <li><strong>Device:</strong> GPU/CPU Auto-detection</li>
                </ul>
            </div>
            """)

        with gr.Column(scale=2):
            direction = gr.Dropdown(
                choices=[DIRECTION_MSA_TO_SYRIAN, DIRECTION_SYRIAN_TO_MSA],
                value=DIRECTION_MSA_TO_SYRIAN,
                label="Translation Direction",
            )
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter Arabic text here...",
                lines=5,
            )
            translate_btn = gr.Button("🚀 Translate", variant="primary")
            output_text = gr.Textbox(
                label="Translation",
                lines=5,
            )

    # Connect the interface (event listeners must be registered inside the Blocks context).
    translate_btn.click(
        fn=bidirectional_translate,
        inputs=[input_text, direction],
        outputs=output_text,
    )

    # Add example inputs: two MSA sentences and two Syrian-dialect sentences.
    gr.Examples(
        examples=[
            ["أنا لا أعرف إذا كان سيتمكن من الحضور اليوم أم لا.", DIRECTION_MSA_TO_SYRIAN],
            ["كيف حالك؟", DIRECTION_MSA_TO_SYRIAN],
            ["ما بعرف إذا رح يقدر يجي اليوم ولا لأ.", DIRECTION_SYRIAN_TO_MSA],
            ["شلونك؟", DIRECTION_SYRIAN_TO_MSA],
        ],
        inputs=[input_text, direction],
        outputs=output_text,
        fn=bidirectional_translate,
    )
# Launch the app only when executed as a script (not on import).
if __name__ == "__main__":
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": True}
    demo.launch(**launch_options)