Commit
b9f3278
ยท
verified ยท
1 Parent(s): 370b6fe

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
+
6
+ # Load Hugging Face token from the environment variable
7
+ HF_TOKEN = os.getenv("HF_TOKEN")
8
+ if HF_TOKEN is None:
9
+ raise ValueError("HF_TOKEN environment variable is not set. Please set it before running the script.")
10
+
11
+ # Check for GPU support and configure appropriately
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ print(f"Device being used: {device}")
14
+
15
+ # Model configurations
16
+ MSA_TO_SYRIAN_MODEL = "Omartificial-Intelligence-Space/Shami-MT"
17
+ SYRIAN_TO_MSA_MODEL = "Omartificial-Intelligence-Space/SHAMI-MT-2MSA"
18
+
19
+ # Load models and tokenizers
20
+ print("Loading MSA to Syrian model...")
21
+ msa_to_syrian_tokenizer = AutoTokenizer.from_pretrained(MSA_TO_SYRIAN_MODEL)
22
+ msa_to_syrian_model = AutoModelForSeq2SeqLM.from_pretrained(MSA_TO_SYRIAN_MODEL).to(device)
23
+
24
+ print("Loading Syrian to MSA model...")
25
+ syrian_to_msa_tokenizer = AutoTokenizer.from_pretrained(SYRIAN_TO_MSA_MODEL)
26
+ syrian_to_msa_model = AutoModelForSeq2SeqLM.from_pretrained(SYRIAN_TO_MSA_MODEL).to(device)
27
+
28
+ print("Models loaded successfully!")
29
+
30
+ def translate_msa_to_syrian(text):
31
+ """Translate from Modern Standard Arabic to Syrian dialect"""
32
+ if not text.strip():
33
+ return ""
34
+
35
+ try:
36
+ input_ids = msa_to_syrian_tokenizer(text, return_tensors="pt").input_ids.to(device)
37
+ outputs = msa_to_syrian_model.generate(input_ids, max_length=512, num_beams=5, early_stopping=True)
38
+ translated_text = msa_to_syrian_tokenizer.decode(outputs[0], skip_special_tokens=True)
39
+ return translated_text
40
+ except Exception as e:
41
+ return f"Translation error: {str(e)}"
42
+
43
+ def translate_syrian_to_msa(text):
44
+ """Translate from Syrian dialect to Modern Standard Arabic"""
45
+ if not text.strip():
46
+ return ""
47
+
48
+ try:
49
+ input_ids = syrian_to_msa_tokenizer(text, return_tensors="pt").input_ids.to(device)
50
+ outputs = syrian_to_msa_model.generate(input_ids, max_length=512, num_beams=5, early_stopping=True)
51
+ translated_text = syrian_to_msa_tokenizer.decode(outputs[0], skip_special_tokens=True)
52
+ return translated_text
53
+ except Exception as e:
54
+ return f"Translation error: {str(e)}"
55
+
56
+ def bidirectional_translate(text, direction):
57
+ """Handle bidirectional translation based on user selection"""
58
+ if direction == "MSA โ†’ Syrian":
59
+ return translate_msa_to_syrian(text)
60
+ elif direction == "Syrian โ†’ MSA":
61
+ return translate_syrian_to_msa(text)
62
+ else:
63
+ return "Please select a translation direction"
64
+
65
+ # Create Gradio interface
66
+ with gr.Blocks(title="SHAMI-MT: Bidirectional Arabic Translation") as demo:
67
+
68
+ gr.HTML("""
69
+ <div style="text-align: center; margin-bottom: 2rem;">
70
+ <h1>๐ŸŒ SHAMI-MT: Bidirectional Arabic Translation</h1>
71
+ <p>Translate between Modern Standard Arabic (MSA) and Syrian Dialect</p>
72
+ <p><strong>Built on AraT5v2-base-1024 architecture</strong></p>
73
+ </div>
74
+ """)
75
+
76
+ with gr.Row():
77
+ with gr.Column(scale=1):
78
+ gr.HTML("""
79
+ <div style="background: #f8f9fa; padding: 1rem; border-radius: 8px; margin: 1rem 0;">
80
+ <h3>๐Ÿ“š Model Information</h3>
81
+ <ul>
82
+ <li><strong>Model Type:</strong> Sequence-to-Sequence Translation</li>
83
+ <li><strong>Base Model:</strong> UBC-NLP/AraT5v2-base-1024</li>
84
+ <li><strong>Languages:</strong> Arabic (MSA โ†” Syrian Dialect)</li>
85
+ <li><strong>Device:</strong> GPU/CPU Auto-detection</li>
86
+ </ul>
87
+ </div>
88
+ """)
89
+
90
+ with gr.Column(scale=2):
91
+ direction = gr.Dropdown(
92
+ choices=["MSA โ†’ Syrian", "Syrian โ†’ MSA"],
93
+ value="MSA โ†’ Syrian",
94
+ label="Translation Direction"
95
+ )
96
+
97
+ input_text = gr.Textbox(
98
+ label="Input Text",
99
+ placeholder="Enter Arabic text here...",
100
+ lines=5
101
+ )
102
+
103
+ translate_btn = gr.Button("๐Ÿš€ Translate", variant="primary")
104
+
105
+ output_text = gr.Textbox(
106
+ label="Translation",
107
+ lines=5
108
+ )
109
+
110
+ # Connect the interface
111
+ translate_btn.click(
112
+ fn=bidirectional_translate,
113
+ inputs=[input_text, direction],
114
+ outputs=output_text
115
+ )
116
+
117
+ # Add example inputs
118
+ gr.Examples(
119
+ examples=[
120
+ ["ุฃู†ุง ู„ุง ุฃุนุฑู ุฅุฐุง ูƒุงู† ุณูŠุชู…ูƒู† ู…ู† ุงู„ุญุถูˆุฑ ุงู„ูŠูˆู… ุฃู… ู„ุง.", "MSA โ†’ Syrian"],
121
+ ["ูƒูŠู ุญุงู„ูƒุŸ", "MSA โ†’ Syrian"],
122
+ ["ู…ุง ุจุนุฑู ุฅุฐุง ุฑุญ ูŠู‚ุฏุฑ ูŠุฌูŠ ุงู„ูŠูˆู… ูˆู„ุง ู„ุฃ.", "Syrian โ†’ MSA"],
123
+ ["ุดู„ูˆู†ูƒุŸ", "Syrian โ†’ MSA"]
124
+ ],
125
+ inputs=[input_text, direction],
126
+ outputs=output_text,
127
+ fn=bidirectional_translate
128
+ )
129
+
130
+ # Launch the app
131
+ if __name__ == "__main__":
132
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)