AshwinSankar commited on
Commit
3a650f2
·
verified ·
1 Parent(s): 3a45ced

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +613 -276
app.py CHANGED
@@ -6,51 +6,102 @@ import gradio as gr
6
  from threading import Thread
7
  from collections.abc import Iterator
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
9
 
 
10
  MAX_MAX_NEW_TOKENS = 4096
11
  MAX_INPUT_TOKEN_LENGTH = 4096
12
  DEFAULT_MAX_NEW_TOKENS = 2048
13
- HF_TOKEN = os.environ["HF_TOKEN"]
14
-
15
- model_id = "ai4bharat/IndicTrans3-beta"
16
- model = AutoModelForCausalLM.from_pretrained(
17
- model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN
18
- )
19
- tokenizer = AutoTokenizer.from_pretrained("ai4bharat/IndicTrans3-beta")
20
-
21
-
22
- LANGUAGES = [
23
- "Hindi",
24
- "Bengali",
25
- "Telugu",
26
- "Marathi",
27
- "Tamil",
28
- "Urdu",
29
- "Gujarati",
30
- "Kannada",
31
- "Odia",
32
- "Malayalam",
33
- "Punjabi",
34
- "Assamese",
35
- "Maithili",
36
- "Santali",
37
- "Kashmiri",
38
- "Nepali",
39
- "Sindhi",
40
- "Konkani",
41
- "Dogri",
42
- "Manipuri",
43
- "Bodo",
44
  ]
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  def format_message_for_translation(message, target_lang):
48
  return f"Translate the following text to {target_lang}: {message}"
49
 
50
-
51
- def store_feedback(rating, feedback_text, chat_history, tgt_lang):
52
  try:
53
-
54
  if not rating:
55
  gr.Warning("Please select a rating before submitting feedback.", duration=5)
56
  return None
@@ -60,16 +111,11 @@ def store_feedback(rating, feedback_text, chat_history, tgt_lang):
60
  return None
61
 
62
  if not chat_history:
63
- gr.Warning(
64
- "Please provide the input text before submitting feedback.", duration=5
65
- )
66
  return None
67
 
68
  if len(chat_history[0]) < 2:
69
- gr.Warning(
70
- "Please translate the input text before submitting feedback.",
71
- duration=5,
72
- )
73
  return None
74
 
75
  conn = psycopg2.connect(
@@ -81,54 +127,42 @@ def store_feedback(rating, feedback_text, chat_history, tgt_lang):
81
  )
82
 
83
  cursor = conn.cursor()
84
-
85
  insert_query = """
86
  INSERT INTO feedback
87
- (tgt_lang, rating, feedback_txt, chat_history)
88
- VALUES (%s, %s, %s, %s)
89
  """
90
-
91
- cursor.execute(
92
- insert_query, (tgt_lang, int(rating), feedback_text, chat_history)
93
- )
94
-
95
  conn.commit()
96
-
97
  cursor.close()
98
  conn.close()
99
-
100
  gr.Info("Thank you for your feedback! ๐Ÿ™", duration=5)
101
 
102
- except:
103
- gr.Error(
104
- "An error occurred while storing feedback. Please try again later.",
105
- duration=5,
106
- )
107
-
108
-
109
- def store_output(tgt_lang, input_text, output_text):
110
-
111
- conn = psycopg2.connect(
112
- host=os.getenv("DB_HOST"),
113
- database=os.getenv("DB_NAME"),
114
- user=os.getenv("DB_USER"),
115
- password=os.getenv("DB_PASSWORD"),
116
- port=os.getenv("DB_PORT"),
117
- )
118
-
119
- cursor = conn.cursor()
120
-
121
- insert_query = """
122
- INSERT INTO translation
123
- (input_txt, output_txt, tgt_lang)
124
- VALUES (%s, %s, %s)
125
- """
126
-
127
- cursor.execute(insert_query, (input_text, output_text, tgt_lang))
128
-
129
- conn.commit()
130
- cursor.close()
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  @spaces.GPU
134
  def translate_message(
@@ -140,254 +174,557 @@ def translate_message(
140
  top_p: float = 0.9,
141
  top_k: int = 50,
142
  repetition_penalty: float = 1.2,
 
143
  ) -> Iterator[str]:
 
 
 
 
 
 
 
144
  conversation = []
145
-
146
  translation_request = format_message_for_translation(message, target_language)
147
-
148
  conversation.append({"role": "user", "content": translation_request})
149
 
150
- input_ids = tokenizer.apply_chat_template(
151
- conversation, return_tensors="pt", add_generation_prompt=True
152
- )
153
- if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
154
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
155
- gr.Warning(
156
- f"Trimmed input as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
157
  )
158
- input_ids = input_ids.to(model.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
- streamer = TextIteratorStreamer(
161
- tokenizer, timeout=240.0, skip_prompt=True, skip_special_tokens=True
162
- )
163
- generate_kwargs = dict(
164
- {"input_ids": input_ids},
165
- streamer=streamer,
166
- max_new_tokens=max_new_tokens,
167
- do_sample=True,
168
- top_p=top_p,
169
- top_k=top_k,
170
- temperature=temperature,
171
- num_beams=1,
172
- repetition_penalty=repetition_penalty,
173
- )
174
- t = Thread(target=model.generate, kwargs=generate_kwargs)
175
- t.start()
176
 
177
- outputs = []
178
- for text in streamer:
179
- outputs.append(text)
180
- yield "".join(outputs)
181
 
182
- store_output(target_language, message, "".join(outputs))
 
 
 
 
 
 
 
183
 
 
 
 
 
 
 
 
 
 
184
 
185
- css = """
186
- # body {
187
- # background-color: #f7f7f7;
188
- # }
189
- .feedback-section {
190
- margin-top: 30px;
191
- border-top: 1px solid #ddd;
192
- padding-top: 20px;
193
  }
194
- .container {
195
- max-width: 90%;
196
- margin: 0 auto;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  }
198
- .language-selector {
199
- margin-bottom: 20px;
200
- padding: 10px;
201
- background-color: #ffffff;
202
- border-radius: 8px;
203
- box-shadow: 0 2px 5px rgba(0,0,0,0.1);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  }
 
205
  .advanced-options {
206
- margin-top: 20px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  }
208
  """
209
 
210
- DESCRIPTION = """\
211
- IndicTrans3 is the latest state-of-the-art (SOTA) translation model from AI4Bharat, designed to handle translations across <b>22 Indic languages</b> with high accuracy. It supports <b>document-level machine translation (MT)</b> and is built to match the performance of other leading SOTA models. <br>
212
- ๐Ÿ“ข <b>Training data will be released soon!</b>
213
- <h3>๐Ÿ”น Features</h3>
214
- โœ… Supports <b>22 Indic languages</b><br>
215
- โœ… Enables <b>document-level translation</b><br>
216
- โœ… Achieves <b>SOTA performance</b> in Indic MT<br>
217
- โœ… Optimized for <b>real-world applications</b><br>
218
- <h3>๐Ÿš€ Try It Out!</h3>
219
- 1๏ธโƒฃ Enter text in any supported language<br>
220
- 2๏ธโƒฃ Select the target language<br>
221
- 3๏ธโƒฃ Click <b>Translate</b> and get high-quality results!<br>
222
- Built for <b>linguistic diversity and accessibility</b>, IndicTrans3 is a major step forward in <b>Indic language AI</b>.
223
- ๐Ÿ’ก <b>Source:</b> AI4Bharat | Powered by Hugging Face
224
  """
225
 
226
- with gr.Blocks(css=css) as demo:
227
- with gr.Column(elem_classes="container"):
228
- gr.Markdown(
229
- "# ๐ŸŒ IndicTrans3-beta ๐Ÿš€: Multilingual Translation for 22 Indic Languages </center>"
230
- )
231
- gr.Markdown(DESCRIPTION)
 
 
 
 
 
 
 
 
232
 
 
 
 
 
233
  target_language = gr.Dropdown(
234
- LANGUAGES,
235
- value="Hindi",
236
- label="Which language would you like to translate to?",
237
- elem_id="language-dropdown",
238
  )
239
 
240
  chatbot = gr.Chatbot(
241
- height=400,
242
- elem_id="chatbot",
243
  show_copy_button=True,
244
- avatar_images=["avatars/user_logo.png", "avatars/ai4bharat_logo.png"]
245
- )
 
 
246
 
247
  with gr.Row():
248
  msg = gr.Textbox(
249
- placeholder="Enter a long paragraph to translate...",
250
  show_label=False,
251
  container=False,
252
  scale=9,
 
 
 
 
 
 
253
  )
254
- submit_btn = gr.Button("Translate", scale=1)
255
-
256
- gr.Examples(
257
- examples=[
258
- "The Taj Mahal, an architectural marvel of white marble, stands majestically along the banks of the Yamuna River in Agra, India. Built by Mughal Emperor Shah Jahan in memory of his beloved wife, Mumtaz Mahal, it symbolizes eternal love and devotion. The monument, a UNESCO World Heritage site, attracts millions of visitors each year, who admire its intricate carvings, calligraphy, and symmetrical gardens. At sunrise and sunset, the marble dome glows in hues of pink and golden, creating a breathtaking spectacle. The Taj Mahal is not only a masterpiece of Mughal architecture but also a timeless representation of romance and artistry.",
259
- "Kumbh Mela, the world's largest spiritual gathering, is a significant Hindu festival held at four sacred riverbanks—Prayagraj, Haridwar, Nashik, and Ujjain—at intervals of 12 years. Millions of devotees, including sadhus, ascetics, and pilgrims, gather to take a holy dip in the river, believing it washes away sins and grants salvation. The festival is marked by grand processions, religious discourses, and vibrant cultural events. With its rich traditions, ancient rituals, and immense scale, Kumbh Mela is not just a religious event but also a profound representation of India's spiritual and cultural heritage, fostering faith and unity among millions worldwide.",
260
- "India's classical dance forms, such as Bharatanatyam, Kathak, Odissi, Kuchipudi, and Kathakali, are deeply rooted in tradition and storytelling. These dance styles blend intricate footwork, graceful hand gestures, and expressive facial expressions to narrate mythological tales and historical legends. Bharatanatyam, originating from Tamil Nadu, is known for its rhythmic precision, while Kathak, from North India, features rapid spins and foot-tapping movements. Odissi, from Odisha, showcases fluid postures inspired by temple sculptures. Each form carries a distinct cultural essence, preserving centuries-old traditions while continuing to evolve in contemporary performances, keeping Indiaโ€™s rich artistic heritage alive and thriving.",
261
- "Ayurveda, Indiaโ€™s ancient medical system, emphasizes a holistic approach to health by balancing the mind, body, and spirit. Rooted in nature, it promotes well-being through herbal medicines, dietary guidelines, yoga, and meditation. Ayurveda classifies individuals based on three doshasโ€”Vata, Pitta, and Kaphaโ€”determining their physical and mental constitution. Remedies include plant-based treatments, detox therapies, and rejuvenation practices to prevent and heal ailments. Unlike modern medicine, Ayurveda focuses on personalized healing and long-term wellness. With growing global interest in alternative medicine, Ayurveda continues to gain recognition for its effectiveness in promoting natural healing and overall health optimization.",
262
- "Diwali, the festival of lights, is one of Indiaโ€™s most celebrated festivals, symbolizing the victory of light over darkness and good over evil. Families clean and decorate their homes with colorful rangoli, oil lamps, and twinkling fairy lights. The festival marks the return of Lord Rama to Ayodhya after defeating Ravana, and it also honors Goddess Lakshmi, the deity of wealth and prosperity. Fireworks illuminate the night sky, while families exchange sweets and gifts, spreading joy and togetherness. Beyond its religious significance, Diwali fosters unity, strengthens relationships, and brings communities together in a spirit of happiness and renewal.",
263
- ],
264
- example_labels=[
265
- "The Taj Mahal, an architectural marvel of white marble, stands majestically along the banks of the Yamuna River in Agra...",
266
- "Kumbh Mela, the worldโ€™s largest spiritual gathering, is a significant Hindu festival held at four sacred riverbanks...",
267
- "India's classical dance forms, such as Bharatanatyam, Kathak, Odissi, Kuchipudi, and Kathakali, are deeply rooted in tradition...",
268
- "Ayurveda, Indiaโ€™s ancient medical system, emphasizes a holistic approach to health by balancing the mind, body, and spirit...",
269
- "Diwali, the festival of lights, is one of Indiaโ€™s most celebrated festivals, symbolizing the victory of light over darkness...",
270
- ],
271
- inputs=msg,
272
- )
273
 
274
- with gr.Accordion("Provide Feedback", open=True):
275
- gr.Markdown("## Rate Translation & Provide Feedback ๐Ÿ“")
276
- gr.Markdown(
277
- "Help us improve the translation quality by providing your feedback."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  )
 
 
 
 
 
 
279
  with gr.Row():
280
  rating = gr.Radio(
281
- ["1", "2", "3", "4", "5"], label="Translation Rating (1-5)"
 
 
282
  )
283
 
284
  feedback_text = gr.Textbox(
285
- placeholder="Share your feedback about the translation...",
286
- label="Feedback",
287
  lines=3,
288
  )
289
 
290
- feedback_submit = gr.Button("Submit Feedback")
291
- feedback_result = gr.Textbox(label="", visible=False)
292
-
293
- with gr.Accordion(
294
- "Advanced Options", open=False, elem_classes="advanced-options"
295
- ):
296
- max_new_tokens = gr.Slider(
297
- label="Max new tokens",
298
- minimum=1,
299
- maximum=MAX_MAX_NEW_TOKENS,
300
- step=1,
301
- value=DEFAULT_MAX_NEW_TOKENS,
302
- )
303
- temperature = gr.Slider(
304
- label="Temperature",
305
- minimum=0.1,
306
- maximum=1.0,
307
- step=0.1,
308
- value=0.1,
309
- )
310
- top_p = gr.Slider(
311
- label="Top-p (nucleus sampling)",
312
- minimum=0.05,
313
- maximum=1.0,
314
- step=0.05,
315
- value=0.9,
316
- )
317
- top_k = gr.Slider(
318
- label="Top-k",
319
- minimum=1,
320
- maximum=100,
321
- step=1,
322
- value=50,
323
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  repetition_penalty = gr.Slider(
325
- label="Repetition penalty",
326
  minimum=1.0,
327
  maximum=2.0,
328
  step=0.05,
329
  value=1.0,
 
330
  )
331
 
332
- chat_state = gr.State([])
333
-
334
- def user(user_message, history, target_lang):
335
- return "", history + [[user_message, None]]
336
-
337
- def bot(
338
- history, target_lang, max_tokens, temp, top_p_val, top_k_val, rep_penalty
339
- ):
340
- user_message = history[-1][0]
341
- history[-1][1] = ""
342
-
343
- for chunk in translate_message(
344
- user_message,
345
- history[:-1],
346
- target_lang,
347
- max_tokens,
348
- temp,
349
- top_p_val,
350
- top_k_val,
351
- rep_penalty,
352
- ):
353
- history[-1][1] = chunk
354
- yield history
355
-
356
- msg.submit(
357
- user, [msg, chatbot, target_language], [msg, chatbot], queue=False
358
- ).then(
359
- bot,
360
- [
361
- chatbot,
362
- target_language,
363
- max_new_tokens,
364
- temperature,
365
- top_p,
366
- top_k,
367
- repetition_penalty,
368
- ],
369
- chatbot,
370
- )
371
 
372
- submit_btn.click(
373
- user, [msg, chatbot, target_language], [msg, chatbot], queue=False
374
- ).then(
375
- bot,
376
- [
377
- chatbot,
378
- target_language,
379
- max_new_tokens,
380
- temperature,
381
- top_p,
382
- top_k,
383
- repetition_penalty,
384
- ],
385
- chatbot,
386
- )
387
 
388
- feedback_submit.click(
389
- fn=store_feedback,
390
- inputs=[rating, feedback_text, chatbot, target_language],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
  if __name__ == "__main__":
393
- demo.launch()
 
 
 
 
 
 
 
6
  from threading import Thread
7
  from collections.abc import Iterator
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
+ import gc
10
 
11
+ # Constants
12
  MAX_MAX_NEW_TOKENS = 4096
13
  MAX_INPUT_TOKEN_LENGTH = 4096
14
  DEFAULT_MAX_NEW_TOKENS = 2048
15
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
16
+
17
+ # Language lists
18
+ INDIC_LANGUAGES = [
19
+ "Hindi", "Bengali", "Telugu", "Marathi", "Tamil", "Urdu", "Gujarati",
20
+ "Kannada", "Odia", "Malayalam", "Punjabi", "Assamese", "Maithili",
21
+ "Santali", "Kashmiri", "Nepali", "Sindhi", "Konkani", "Dogri",
22
+ "Manipuri", "Bodo", "English", "Sanskrit"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  ]
24
 
25
+ SARVAM_LANGUAGES = INDIC_LANGUAGES
26
+
27
+ # Model configurations with optimizations
28
+ TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
29
+ DEVICE_MAP = "auto" if torch.cuda.is_available() else None
30
+
31
class ModelManager:
    """Lazily load and cache the two translation models.

    Models are loaded on first use rather than at import time, so the app
    starts quickly and only pays the memory cost for models actually used.
    """

    # model key -> Hugging Face hub repo id (shared by both loaders)
    _REPOS = {
        "indictrans": "ai4bharat/IndicTrans3-beta",
        "sarvam": "sarvamai/sarvam-translate",
    }

    def __init__(self):
        self.indictrans_model = None
        self.indictrans_tokenizer = None
        self.sarvam_model = None
        self.sarvam_tokenizer = None
        self.current_model = None

    def _load(self, repo_id, label):
        """Load (model, tokenizer) for *repo_id*; return (None, None) on failure.

        Consolidates the previously duplicated per-model loading code.
        Failures are printed and swallowed, matching the original behavior
        of leaving the model unloaded.
        """
        try:
            model = AutoModelForCausalLM.from_pretrained(
                repo_id,
                torch_dtype=TORCH_DTYPE,
                device_map=DEVICE_MAP,
                token=HF_TOKEN,
                use_cache=True,  # enable KV cache for faster generation
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
            model.eval()  # inference mode (disables dropout etc.)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return model, tokenizer
        except Exception as e:
            print(f"Error loading {label} model: {e}")
            return None, None

    def load_indictrans_model(self):
        """Load the IndicTrans model/tokenizer if not already loaded."""
        if self.indictrans_model is None:
            self.indictrans_model, self.indictrans_tokenizer = self._load(
                self._REPOS["indictrans"], "IndicTrans"
            )

    def load_sarvam_model(self):
        """Load the Sarvam model/tokenizer if not already loaded."""
        if self.sarvam_model is None:
            self.sarvam_model, self.sarvam_tokenizer = self._load(
                self._REPOS["sarvam"], "Sarvam"
            )

    def get_model_and_tokenizer(self, model_type):
        """Return (model, tokenizer) for *model_type*; either may be None on load failure."""
        if model_type == "indictrans":
            self.load_indictrans_model()
            return self.indictrans_model, self.indictrans_tokenizer
        # Any other value falls through to the Sarvam model (original behavior).
        self.load_sarvam_model()
        return self.sarvam_model, self.sarvam_tokenizer


# Single global instance shared by all request handlers.
model_manager = ModelManager()
99
 
100
def format_message_for_translation(message, target_lang):
    """Build the instruction prompt asking the model to translate *message*."""
    prompt = "Translate the following text to {}: {}".format(target_lang, message)
    return prompt
102
 
103
def store_feedback(rating, feedback_text, chat_history, tgt_lang, model_type):
    """Validate and persist one feedback record to the `feedback` table.

    Shows a Gradio warning and returns None when the submission is incomplete;
    on any other failure an error toast is shown and the error is printed.
    """
    try:
        if not rating:
            gr.Warning("Please select a rating before submitting feedback.", duration=5)
            return None

        # NOTE(review): the upstream diff elides a validation branch here
        # (between the rating and chat-history checks) — confirm against the
        # full source before relying on this reconstruction.

        if not chat_history:
            gr.Warning("Please provide the input text before submitting feedback.", duration=5)
            return None

        # chat_history entries are [user_message, bot_reply] pairs; a first
        # entry shorter than 2 means no translation has been produced yet.
        if len(chat_history[0]) < 2:
            gr.Warning("Please translate the input text before submitting feedback.", duration=5)
            return None

        conn = psycopg2.connect(
            host=os.getenv("DB_HOST"),
            database=os.getenv("DB_NAME"),
            user=os.getenv("DB_USER"),
            password=os.getenv("DB_PASSWORD"),
            port=os.getenv("DB_PORT"),
        )
        try:
            cursor = conn.cursor()
            insert_query = """
            INSERT INTO feedback
            (tgt_lang, rating, feedback_txt, chat_history, model_type)
            VALUES (%s, %s, %s, %s, %s)
            """
            cursor.execute(
                insert_query,
                (tgt_lang, int(rating), feedback_text, chat_history, model_type),
            )
            conn.commit()
            cursor.close()
        finally:
            # Fix: previously the connection leaked when execute/commit raised.
            conn.close()

        gr.Info("Thank you for your feedback! 🙏", duration=5)

    except Exception as e:
        print(f"Database error: {e}")
        gr.Error("An error occurred while storing feedback. Please try again later.", duration=5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
def store_output(tgt_lang, input_text, output_text, model_type):
    """Best-effort logging of one translation into the `translation` table.

    Any failure is printed and swallowed so the translation flow is never
    interrupted by database problems.
    """
    sql = """
    INSERT INTO translation
    (input_txt, output_txt, tgt_lang, model_type)
    VALUES (%s, %s, %s, %s)
    """
    try:
        conn = psycopg2.connect(
            host=os.getenv("DB_HOST"),
            database=os.getenv("DB_NAME"),
            user=os.getenv("DB_USER"),
            password=os.getenv("DB_PASSWORD"),
            port=os.getenv("DB_PORT"),
        )
        cursor = conn.cursor()
        cursor.execute(sql, (input_text, output_text, tgt_lang, model_type))
        conn.commit()
        cursor.close()
        conn.close()
    except Exception as e:
        print(f"Database error: {e}")
166
 
167
@spaces.GPU
def translate_message(
    message: str,
    chat_history: list,
    target_language: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.1,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    model_type: str = "indictrans",
) -> Iterator[str]:
    """Stream a translation of *message* into *target_language*.

    Yields the accumulated partial translation as tokens arrive; on failure
    a single error string is yielded instead.

    NOTE(review): the first five parameters and their defaults are
    reconstructed from the `bot()` call site — the upstream diff elides
    them; confirm against the full source.
    """
    model, tokenizer = model_manager.get_model_and_tokenizer(model_type)

    if model is None or tokenizer is None:
        yield "Error: Model failed to load. Please try again."
        return

    conversation = [
        {"role": "user", "content": format_message_for_translation(message, target_language)}
    ]

    try:
        input_ids = tokenizer.apply_chat_template(
            conversation, return_tensors="pt", add_generation_prompt=True
        )

        # Keep only the most recent tokens when the prompt is too long.
        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
            gr.Warning(f"Trimmed input as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

        input_ids = input_ids.to(model.device)

        streamer = TextIteratorStreamer(
            tokenizer, timeout=240.0, skip_prompt=True, skip_special_tokens=True
        )

        generate_kwargs = {
            "input_ids": input_ids,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "do_sample": True,
            "top_p": top_p,
            "top_k": top_k,
            "temperature": temperature,
            "num_beams": 1,
            "repetition_penalty": repetition_penalty,
            "use_cache": True,  # KV cache for faster autoregressive decoding
            "pad_token_id": tokenizer.eos_token_id,
        }

        worker = Thread(target=model.generate, kwargs=generate_kwargs)
        worker.start()

        outputs = []
        for text in streamer:
            outputs.append(text)
            yield "".join(outputs)

        # Fix: join the worker so generation has fully finished before
        # cleanup/logging; previously the thread was left unjoined.
        worker.join()

        full_output = "".join(outputs)  # fix: join once instead of twice

        # Release scratch memory between requests.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        store_output(target_language, message, full_output, model_type)

    except Exception as e:
        yield f"Translation error: {str(e)}"
236
+
237
+ # Enhanced CSS with beautiful styling
238
+ css = """
239
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
240
 
241
+ * {
242
+ font-family: 'Inter', sans-serif;
243
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
+ .gradio-container {
246
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
247
+ min-height: 100vh;
248
+ }
249
 
250
+ .main-container {
251
+ background: rgba(255, 255, 255, 0.95);
252
+ backdrop-filter: blur(10px);
253
+ border-radius: 20px;
254
+ padding: 2rem;
255
+ margin: 1rem;
256
+ box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
257
+ }
258
 
259
+ .title-container {
260
+ text-align: center;
261
+ margin-bottom: 2rem;
262
+ padding: 1rem;
263
+ background: linear-gradient(45deg, #667eea, #764ba2);
264
+ -webkit-background-clip: text;
265
+ -webkit-text-fill-color: transparent;
266
+ background-clip: text;
267
+ }
268
 
269
+ .model-tab {
270
+ background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
271
+ border: none;
272
+ border-radius: 15px;
273
+ color: white;
274
+ font-weight: 600;
275
+ padding: 1rem 2rem;
276
+ transition: all 0.3s ease;
277
  }
278
+
279
+ .model-tab:hover {
280
+ transform: translateY(-2px);
281
+ box-shadow: 0 10px 25px rgba(0, 0, 0, 0.2);
282
+ }
283
+
284
+ .language-dropdown {
285
+ background: white;
286
+ border: 2px solid #e2e8f0;
287
+ border-radius: 12px;
288
+ padding: 0.75rem;
289
+ font-size: 16px;
290
+ transition: all 0.3s ease;
291
+ }
292
+
293
+ .language-dropdown:focus {
294
+ border-color: #667eea;
295
+ box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1);
296
+ }
297
+
298
+ .chat-container {
299
+ background: white;
300
+ border-radius: 15px;
301
+ padding: 1rem;
302
+ box-shadow: 0 10px 30px rgba(0, 0, 0, 0.1);
303
+ margin: 1rem 0;
304
+ }
305
+
306
+ .message-input {
307
+ border: 2px solid #e2e8f0;
308
+ border-radius: 12px;
309
+ padding: 1rem;
310
+ font-size: 16px;
311
+ transition: all 0.3s ease;
312
+ background: white;
313
  }
314
+
315
+ .message-input:focus {
316
+ border-color: #667eea;
317
+ box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1);
318
+ }
319
+
320
+ .translate-btn {
321
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
322
+ border: none;
323
+ border-radius: 12px;
324
+ color: white;
325
+ font-weight: 600;
326
+ padding: 1rem 2rem;
327
+ font-size: 16px;
328
+ cursor: pointer;
329
+ transition: all 0.3s ease;
330
+ }
331
+
332
+ .translate-btn:hover {
333
+ transform: translateY(-2px);
334
+ box-shadow: 0 10px 25px rgba(102, 126, 234, 0.3);
335
+ }
336
+
337
+ .examples-container {
338
+ background: linear-gradient(135deg, #ffecd2 0%, #fcb69f 100%);
339
+ border-radius: 15px;
340
+ padding: 1.5rem;
341
+ margin: 1rem 0;
342
+ }
343
+
344
+ .feedback-section {
345
+ background: linear-gradient(135deg, #a8edea 0%, #fed6e3 100%);
346
+ border-radius: 15px;
347
+ padding: 1.5rem;
348
+ margin: 1rem 0;
349
+ border: none;
350
  }
351
+
352
  .advanced-options {
353
+ background: linear-gradient(135deg, #d299c2 0%, #fef9d7 100%);
354
+ border-radius: 15px;
355
+ padding: 1.5rem;
356
+ margin: 1rem 0;
357
+ }
358
+
359
+ .slider-container .gr-slider {
360
+ background: linear-gradient(90deg, #667eea, #764ba2);
361
+ }
362
+
363
+ .rating-container {
364
+ display: flex;
365
+ gap: 1rem;
366
+ justify-content: center;
367
+ margin: 1rem 0;
368
+ }
369
+
370
+ .feedback-btn {
371
+ background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
372
+ border: none;
373
+ border-radius: 12px;
374
+ color: white;
375
+ font-weight: 600;
376
+ padding: 0.75rem 1.5rem;
377
+ cursor: pointer;
378
+ transition: all 0.3s ease;
379
+ }
380
+
381
+ .feedback-btn:hover {
382
+ transform: translateY(-2px);
383
+ box-shadow: 0 8px 20px rgba(240, 147, 251, 0.3);
384
+ }
385
+
386
+ .stats-card {
387
+ background: rgba(255, 255, 255, 0.8);
388
+ border-radius: 12px;
389
+ padding: 1rem;
390
+ text-align: center;
391
+ box-shadow: 0 5px 15px rgba(0, 0, 0, 0.1);
392
+ margin: 0.5rem;
393
+ }
394
+
395
+ .model-info {
396
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
397
+ color: white;
398
+ border-radius: 12px;
399
+ padding: 1rem;
400
+ margin: 1rem 0;
401
+ }
402
+
403
+ .animate-pulse {
404
+ animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
405
+ }
406
+
407
+ @keyframes pulse {
408
+ 0%, 100% {
409
+ opacity: 1;
410
+ }
411
+ 50% {
412
+ opacity: .5;
413
+ }
414
+ }
415
+
416
+ .loading-spinner {
417
+ border: 4px solid #f3f3f3;
418
+ border-top: 4px solid #667eea;
419
+ border-radius: 50%;
420
+ width: 40px;
421
+ height: 40px;
422
+ animation: spin 2s linear infinite;
423
+ margin: 0 auto;
424
+ }
425
+
426
+ @keyframes spin {
427
+ 0% { transform: rotate(0deg); }
428
+ 100% { transform: rotate(360deg); }
429
  }
430
  """
431
 
432
+ # Model descriptions
433
+ INDICTRANS_DESCRIPTION = """
434
+ <div class="model-info">
435
+ <h3>๐ŸŒŸ IndicTrans3-Beta</h3>
436
+ <p><strong>Latest SOTA translation model from AI4Bharat</strong></p>
437
+ <ul>
438
+ <li>โœ… Supports <strong>22 Indic languages</strong></li>
439
+ <li>โœ… Document-level machine translation</li>
440
+ <li>โœ… Optimized for real-world applications</li>
441
+ <li>โœ… Enhanced with KV caching for faster inference</li>
442
+ </ul>
443
+ </div>
 
 
444
  """
445
 
446
+ SARVAM_DESCRIPTION = """
447
+ <div class="model-info">
448
+ <h3>๐Ÿš€ Sarvam Translate</h3>
449
+ <p><strong>Advanced multilingual translation model</strong></p>
450
+ <ul>
451
+ <li>โœ… Supports <strong>22 Indic languages</strong></li>
452
+ <li>โœ… High-quality translations</li>
453
+ <li>โœ… Document-level machine translation</li>
454
+ <li>โœ… Optimized for real-world applications</li>
455
+ <li>โœ… Optimized for production use</li>
456
+ <li>โœ… Enhanced with KV caching for faster inference</li>
457
+ </ul>
458
+ </div>
459
+ """
460
 
461
def create_chatbot_interface(model_type, languages, description):
    """Build one model tab's UI and return its interactive components.

    Returns a 12-tuple: (chatbot, msg, submit_btn, target_language, rating,
    feedback_text, feedback_submit, max_new_tokens, temperature, top_p,
    top_k, repetition_penalty).

    NOTE(review): emoji in the labels below are reconstructed from a
    mojibake'd source — confirm against the original file.
    """
    with gr.Column(elem_classes="main-container"):
        gr.Markdown(description)

        target_language = gr.Dropdown(
            languages,
            value=languages[0],
            label="🌐 Select Target Language",
            elem_classes="language-dropdown",
        )

        chatbot = gr.Chatbot(
            height=500,
            elem_classes="chat-container",
            show_copy_button=True,
            avatar_images=["avatars/user_logo.png", "avatars/ai4bharat_logo.png"],
            bubble_full_width=False,
            show_label=False,
        )

        with gr.Row():
            msg = gr.Textbox(
                placeholder="✍️ Enter text to translate...",
                show_label=False,
                container=False,
                scale=9,
                elem_classes="message-input",
            )
            submit_btn = gr.Button("🔄 Translate", scale=1, elem_classes="translate-btn")

        # Per-model example texts: longer passages for IndicTrans, short
        # sentences for Sarvam.
        if model_type == "indictrans":
            examples_data = [
                "The Taj Mahal, an architectural marvel of white marble, stands majestically along the banks of the Yamuna River in Agra, India.",
                "Kumbh Mela, the world's largest spiritual gathering, is a significant Hindu festival held at four sacred riverbanks.",
                "India's classical dance forms, such as Bharatanatyam, Kathak, Odissi, are deeply rooted in tradition and storytelling.",
                "Ayurveda, India's ancient medical system, emphasizes a holistic approach to health by balancing mind, body, and spirit.",
                "Diwali, the festival of lights, symbolizes the victory of light over darkness and good over evil.",
            ]
        else:
            examples_data = [
                "Hello, how are you today?",
                "I love learning new languages and cultures.",
                "Technology is transforming the way we communicate.",
                "The weather is beautiful today.",
                "Thank you for your help and support.",
            ]

        with gr.Accordion("📚 Example Texts", open=False, elem_classes="examples-container"):
            gr.Examples(
                examples=examples_data,
                inputs=msg,
                label="Click on any example to try:",
            )

        with gr.Accordion("💭 Provide Feedback", open=False, elem_classes="feedback-section"):
            gr.Markdown("### 📝 Rate Translation & Share Feedback")
            gr.Markdown("Help us improve translation quality with your valuable feedback!")

            with gr.Row():
                rating = gr.Radio(
                    ["1", "2", "3", "4", "5"],
                    label="🏆 Translation Quality Rating",
                    value=None,
                )

            feedback_text = gr.Textbox(
                placeholder="💬 Share your thoughts about the translation quality, accuracy, or suggestions for improvement...",
                label="📝 Your Feedback",
                lines=3,
            )

            feedback_submit = gr.Button("📤 Submit Feedback", elem_classes="feedback-btn")

        with gr.Accordion("⚙️ Advanced Settings", open=False, elem_classes="advanced-options"):
            gr.Markdown("### 🔧 Fine-tune Translation Parameters")

            with gr.Row():
                max_new_tokens = gr.Slider(
                    label="📏 Max New Tokens",
                    minimum=1,
                    maximum=MAX_MAX_NEW_TOKENS,
                    step=1,
                    value=DEFAULT_MAX_NEW_TOKENS,
                    elem_classes="slider-container",
                )
                temperature = gr.Slider(
                    label="🌡️ Temperature",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.1,
                    value=0.1,
                    elem_classes="slider-container",
                )

            with gr.Row():
                top_p = gr.Slider(
                    label="🎯 Top-p (Nucleus Sampling)",
                    minimum=0.05,
                    maximum=1.0,
                    step=0.05,
                    value=0.9,
                    elem_classes="slider-container",
                )
                top_k = gr.Slider(
                    label="🔍 Top-k",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                    elem_classes="slider-container",
                )

            repetition_penalty = gr.Slider(
                label="🔄 Repetition Penalty",
                minimum=1.0,
                maximum=2.0,
                step=0.05,
                value=1.0,
                elem_classes="slider-container",
            )

    return (chatbot, msg, submit_btn, target_language, rating, feedback_text,
            feedback_submit, max_new_tokens, temperature, top_p, top_k, repetition_penalty)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
594
 
595
def user(user_message, history, target_lang):
    """Record the user's turn in the chat history and clear the input box.

    Returns a 2-tuple: an empty string (resets the message textbox) and a
    new history list with ``[user_message, None]`` appended — the ``None``
    assistant slot is filled in later by the streaming ``bot`` handler.
    ``target_lang`` is accepted because the Gradio event passes it, but it
    is not used here.
    """
    updated_history = list(history)
    updated_history.append([user_message, None])
    return "", updated_history
597
+
598
def bot(history, target_lang, max_tokens, temp, top_p_val, top_k_val, rep_penalty, model_type):
    """Stream the translation of the latest user message into the chat.

    Resets the assistant slot of the last turn to an empty string, then
    yields the whole history each time ``translate_message`` produces a new
    chunk, so the chatbot component updates incrementally.
    """
    latest_message = history[-1][0]
    history[-1][1] = ""

    chunk_stream = translate_message(
        latest_message,
        history[:-1],
        target_lang,
        max_tokens,
        temp,
        top_p_val,
        top_k_val,
        rep_penalty,
        model_type,
    )
    for partial_text in chunk_stream:
        history[-1][1] = partial_text
        yield history
608
+
609
# Main Gradio interface.
#
# NOTE: Gradio decides whether an event handler streams by checking
# inspect.isgeneratorfunction(fn). Wrapping the generator `bot` in a
# lambda (lambda *args: bot(*args, "indictrans")) defeats that check —
# Gradio would receive the generator object itself instead of streaming
# its yields. Each wrapper below is therefore a real generator function
# (`yield from`), which preserves streaming while binding the model type.
def _indictrans_bot(*args):
    """Generator wrapper binding model_type='indictrans' for event wiring."""
    yield from bot(*args, "indictrans")


def _sarvam_bot(*args):
    """Generator wrapper binding model_type='sarvam' for event wiring."""
    yield from bot(*args, "sarvam")


with gr.Blocks(css=css, title="🌍 Advanced Multilingual Translation Hub", theme=gr.themes.Soft()) as demo:

    # Page header.
    gr.Markdown(
        """
        <div class="title-container">
            <h1>🌍 Advanced Multilingual Translation Hub</h1>
            <p style="font-size: 18px; margin-top: 10px;">
                Experience state-of-the-art translation with multiple AI models
            </p>
        </div>
        """,
        elem_classes="title-container",
    )

    # Statistics cards.
    with gr.Row():
        gr.Markdown(
            '<div class="stats-card"><h3>🎯</h3><p><strong>22+</strong><br>Languages</p></div>',
            elem_classes="stats-card",
        )
        gr.Markdown(
            '<div class="stats-card"><h3>🚀</h3><p><strong>2</strong><br>AI Models</p></div>',
            elem_classes="stats-card",
        )
        gr.Markdown(
            '<div class="stats-card"><h3>⚡</h3><p><strong>Optimized</strong><br>Performance</p></div>',
            elem_classes="stats-card",
        )
        gr.Markdown(
            '<div class="stats-card"><h3>🔒</h3><p><strong>Secure</strong><br>Processing</p></div>',
            elem_classes="stats-card",
        )

    # One tab per backing model; each tab builds its own full chat UI.
    with gr.Tabs(elem_classes="model-tab") as tabs:
        with gr.TabItem("🇮🇳 IndicTrans3-Beta", elem_id="indictrans-tab"):
            indictrans_components = create_chatbot_interface("indictrans", INDIC_LANGUAGES, INDICTRANS_DESCRIPTION)

        with gr.TabItem("🌐 Sarvam Translate", elem_id="sarvam-tab"):
            sarvam_components = create_chatbot_interface("sarvam", SARVAM_LANGUAGES, SARVAM_DESCRIPTION)

    # ---- IndicTrans event wiring -------------------------------------
    (indictrans_chatbot, indictrans_msg, indictrans_submit, indictrans_lang,
     indictrans_rating, indictrans_feedback, indictrans_feedback_submit,
     indictrans_max_tokens, indictrans_temp, indictrans_top_p,
     indictrans_top_k, indictrans_rep_penalty) = indictrans_components

    # Shared input list for both submit paths (Enter key and button click).
    indictrans_bot_inputs = [
        indictrans_chatbot, indictrans_lang, indictrans_max_tokens,
        indictrans_temp, indictrans_top_p, indictrans_top_k,
        indictrans_rep_penalty,
    ]

    indictrans_msg.submit(
        user, [indictrans_msg, indictrans_chatbot, indictrans_lang],
        [indictrans_msg, indictrans_chatbot], queue=False,
    ).then(
        _indictrans_bot, indictrans_bot_inputs, indictrans_chatbot,
    )

    indictrans_submit.click(
        user, [indictrans_msg, indictrans_chatbot, indictrans_lang],
        [indictrans_msg, indictrans_chatbot], queue=False,
    ).then(
        _indictrans_bot, indictrans_bot_inputs, indictrans_chatbot,
    )

    # store_feedback is a plain function, so a lambda wrapper is safe here.
    indictrans_feedback_submit.click(
        lambda *args: store_feedback(*args, "indictrans"),
        inputs=[indictrans_rating, indictrans_feedback, indictrans_chatbot, indictrans_lang],
    )

    # ---- Sarvam event wiring -----------------------------------------
    (sarvam_chatbot, sarvam_msg, sarvam_submit, sarvam_lang,
     sarvam_rating, sarvam_feedback, sarvam_feedback_submit,
     sarvam_max_tokens, sarvam_temp, sarvam_top_p,
     sarvam_top_k, sarvam_rep_penalty) = sarvam_components

    sarvam_bot_inputs = [
        sarvam_chatbot, sarvam_lang, sarvam_max_tokens,
        sarvam_temp, sarvam_top_p, sarvam_top_k, sarvam_rep_penalty,
    ]

    sarvam_msg.submit(
        user, [sarvam_msg, sarvam_chatbot, sarvam_lang],
        [sarvam_msg, sarvam_chatbot], queue=False,
    ).then(
        _sarvam_bot, sarvam_bot_inputs, sarvam_chatbot,
    )

    sarvam_submit.click(
        user, [sarvam_msg, sarvam_chatbot, sarvam_lang],
        [sarvam_msg, sarvam_chatbot], queue=False,
    ).then(
        _sarvam_bot, sarvam_bot_inputs, sarvam_chatbot,
    )

    sarvam_feedback_submit.click(
        lambda *args: store_feedback(*args, "sarvam"),
        inputs=[sarvam_rating, sarvam_feedback, sarvam_chatbot, sarvam_lang],
    )

    # Footer.
    gr.Markdown(
        """
        <div style="text-align: center; margin-top: 2rem; padding: 1rem; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 15px; color: white;">
            <p>🚀 <strong>Powered by AI4Bharat & Sarvam AI</strong> |
            Built with ❤️ using Gradio |
            🔧 <strong>Optimized with KV Caching & Advanced Memory Management</strong></p>
        </div>
        """
    )
722
+
723
if __name__ == "__main__":
    # Launch configuration: bind to all interfaces on the standard Gradio
    # port, expose a public share link, surface handler errors in the UI,
    # and cap the worker thread pool at 10.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
        "show_error": True,
        "max_threads": 10,
    }
    demo.launch(**launch_options)