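# Scraped from a Hugging Face Space file view ("raw" / "history blame" page links omitted).
# File: app.py - last commit 3a650f2 "Update app.py" (verified) by AshwinSankar - 23.8 kB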
import os
import torch
import spaces
import psycopg2
import gradio as gr
from threading import Thread
from collections.abc import Iterator
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gc
# Generation limits
MAX_MAX_NEW_TOKENS = 4096      # upper bound of the "max new tokens" slider
MAX_INPUT_TOKEN_LENGTH = 4096  # longer prompts are trimmed before generation
DEFAULT_MAX_NEW_TOKENS = 2048  # initial position of the "max new tokens" slider

# Hugging Face access token; an empty string when the variable is not set
HF_TOKEN = os.getenv("HF_TOKEN", "")

# Target languages offered in the language dropdowns
INDIC_LANGUAGES = [
    "Hindi",
    "Bengali",
    "Telugu",
    "Marathi",
    "Tamil",
    "Urdu",
    "Gujarati",
    "Kannada",
    "Odia",
    "Malayalam",
    "Punjabi",
    "Assamese",
    "Maithili",
    "Santali",
    "Kashmiri",
    "Nepali",
    "Sindhi",
    "Konkani",
    "Dogri",
    "Manipuri",
    "Bodo",
    "English",
    "Sanskrit",
]
# Same list object as INDIC_LANGUAGES (an alias, not a copy)
SARVAM_LANGUAGES = INDIC_LANGUAGES

# Model loading options: half precision and automatic device placement are
# used only when CUDA is available; otherwise float32 with no device map.
if torch.cuda.is_available():
    TORCH_DTYPE = torch.float16
    DEVICE_MAP = "auto"
else:
    TORCH_DTYPE = torch.float32
    DEVICE_MAP = None
class ModelManager:
    """Lazily load and cache the two translation models and their tokenizers.

    Nothing is loaded at construction time. A model/tokenizer pair is loaded
    the first time it is requested and then kept on the instance. A pair is
    only ever stored complete: if either half fails to load, both attributes
    stay ``None`` and the next request retries the load.
    """

    def __init__(self):
        self.indictrans_model = None
        self.indictrans_tokenizer = None
        self.sarvam_model = None
        self.sarvam_tokenizer = None
        # Not read or written by any method of this class; kept so the
        # attribute set of the instance is unchanged.
        self.current_model = None

    @staticmethod
    def _load_pair(repo_id, label):
        """Load the model and tokenizer stored at ``repo_id``.

        Args:
            repo_id: Hugging Face repository identifier.
            label: Human-readable model name used in the error message.

        Returns:
            ``(model, tokenizer)`` on success, ``(None, None)`` if anything
            raised while loading (the error is printed, not propagated).
        """
        # HF_TOKEN is "" when the environment variable is unset; pass None in
        # that case so no empty credential is handed to the hub client.
        token = HF_TOKEN or None
        try:
            # trust_remote_code=True allows the repository's own model code
            # to be executed while loading.
            model = AutoModelForCausalLM.from_pretrained(
                repo_id,
                torch_dtype=TORCH_DTYPE,
                device_map=DEVICE_MAP,
                token=token,
                use_cache=True,  # Enable KV cache
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            # Same token as the model so both downloads are authorised alike.
            tokenizer = AutoTokenizer.from_pretrained(
                repo_id,
                token=token,
                trust_remote_code=True,
            )
            # Switch to inference mode when the model exposes eval().
            if hasattr(model, 'eval'):
                model.eval()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as e:
            print(f"Error loading {label} model: {e}")
            return None, None
        return model, tokenizer

    def load_indictrans_model(self):
        """Load the IndicTrans pair unless a complete pair is already cached."""
        if self.indictrans_model is None or self.indictrans_tokenizer is None:
            self.indictrans_model, self.indictrans_tokenizer = self._load_pair(
                "ai4bharat/IndicTrans3-beta", "IndicTrans"
            )

    def load_sarvam_model(self):
        """Load the Sarvam pair unless a complete pair is already cached."""
        if self.sarvam_model is None or self.sarvam_tokenizer is None:
            self.sarvam_model, self.sarvam_tokenizer = self._load_pair(
                "sarvamai/sarvam-translate", "Sarvam"
            )

    def get_model_and_tokenizer(self, model_type):
        """Return ``(model, tokenizer)`` for ``model_type``, loading on demand.

        ``"indictrans"`` selects the IndicTrans pair; any other value selects
        the Sarvam pair. Both elements are ``None`` when loading failed.
        """
        if model_type == "indictrans":
            self.load_indictrans_model()
            return self.indictrans_model, self.indictrans_tokenizer
        # Every other model_type falls through to Sarvam.
        self.load_sarvam_model()
        return self.sarvam_model, self.sarvam_tokenizer
# Global model manager: the single ModelManager instance used by
# translate_message to obtain a model/tokenizer pair. Constructing it loads
# nothing; models are loaded on first request.
model_manager = ModelManager()
def format_message_for_translation(message, target_lang):
    """Build the instruction prompt asking for ``message`` in ``target_lang``."""
    prompt_template = "Translate the following text to {}: {}"
    return prompt_template.format(target_lang, message)
def store_feedback(rating, feedback_text, chat_history, tgt_lang, model_type):
    """Validate a feedback submission and insert it into the ``feedback`` table.

    A warning toast is shown and nothing is stored when the rating, the
    feedback text or the chat history is missing. Otherwise one row is
    inserted using the connection settings from the ``DB_*`` environment
    variables, and a thank-you toast is shown.

    Args:
        rating: Selected rating; converted with ``int()`` before storing.
        feedback_text: Free-text feedback from the user.
        chat_history: Current chatbot history, stored as passed in.
        tgt_lang: Target language selected in the tab.
        model_type: Identifier of the model the feedback refers to.

    Returns:
        None.

    Raises:
        gr.Error: If connecting to or writing to the database fails. The
            underlying error is printed and attached as the cause.
    """
    if not rating:
        gr.Warning("Please select a rating before submitting feedback.", duration=5)
        return None
    if not feedback_text or feedback_text.strip() == "":
        gr.Warning("Please provide some feedback before submitting.", duration=5)
        return None
    if not chat_history:
        gr.Warning("Please provide the input text before submitting feedback.", duration=5)
        return None
    if len(chat_history[0]) < 2:
        gr.Warning("Please translate the input text before submitting feedback.", duration=5)
        return None

    insert_query = """
        INSERT INTO feedback
        (tgt_lang, rating, feedback_txt, chat_history, model_type)
        VALUES (%s, %s, %s, %s, %s)
    """
    conn = None
    try:
        conn = psycopg2.connect(
            host=os.getenv("DB_HOST"),
            database=os.getenv("DB_NAME"),
            user=os.getenv("DB_USER"),
            password=os.getenv("DB_PASSWORD"),
            port=os.getenv("DB_PORT"),
        )
        # The cursor context manager closes the cursor even when execute fails.
        with conn.cursor() as cursor:
            cursor.execute(
                insert_query,
                (tgt_lang, int(rating), feedback_text, chat_history, model_type),
            )
        conn.commit()
    except Exception as e:
        print(f"Database error: {e}")
        # gr.Error is an exception: it has to be raised for the message to
        # reach the user; merely constructing it has no effect.
        raise gr.Error(
            "An error occurred while storing feedback. Please try again later.",
            duration=5,
        ) from e
    else:
        gr.Info("Thank you for your feedback! ๐Ÿ™", duration=5)
    finally:
        # Always release the connection, including on failure paths.
        if conn is not None:
            conn.close()
def store_output(tgt_lang, input_text, output_text, model_type):
    """Insert one translation record into the ``translation`` table.

    Storage is best effort: any error is printed and swallowed so a database
    problem never interrupts the caller. Connection settings come from the
    ``DB_*`` environment variables.

    Args:
        tgt_lang: Target language of the translation.
        input_text: Text submitted for translation.
        output_text: Generated translation.
        model_type: Identifier of the model that produced the output.

    Returns:
        None.
    """
    insert_query = """
        INSERT INTO translation
        (input_txt, output_txt, tgt_lang, model_type)
        VALUES (%s, %s, %s, %s)
    """
    conn = None
    try:
        conn = psycopg2.connect(
            host=os.getenv("DB_HOST"),
            database=os.getenv("DB_NAME"),
            user=os.getenv("DB_USER"),
            password=os.getenv("DB_PASSWORD"),
            port=os.getenv("DB_PORT"),
        )
        # The cursor context manager closes the cursor even when execute fails.
        with conn.cursor() as cursor:
            cursor.execute(insert_query, (input_text, output_text, tgt_lang, model_type))
        conn.commit()
    except Exception as e:
        print(f"Database error: {e}")
    finally:
        # Release the connection on every path; previously it leaked whenever
        # execute or commit raised.
        if conn is not None:
            conn.close()
@spaces.GPU
def translate_message(
    message: str,
    chat_history: list[dict],
    target_language: str = "Hindi",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    model_type: str = "indictrans"
) -> Iterator[str]:
    """Stream a translation of ``message`` into ``target_language``.

    The model runs ``generate`` on a background thread while this generator
    yields the text accumulated so far each time the streamer delivers a new
    chunk. When generation finishes, the input and the full output are
    recorded with ``store_output``.

    Args:
        message: Text to translate.
        chat_history: Accepted for interface compatibility; not used.
        target_language: Language name inserted into the prompt.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        top_k: Top-k sampling cut-off.
        repetition_penalty: Repetition penalty passed to ``generate``.
        model_type: Which model pair to request from ``model_manager``.

    Yields:
        The translation produced so far (cumulative, not incremental). A
        single error string is yielded instead when the model is unavailable
        or when an exception is raised during translation.
    """
    model, tokenizer = model_manager.get_model_and_tokenizer(model_type)
    if model is None or tokenizer is None:
        yield "Error: Model failed to load. Please try again."
        return

    conversation = [
        {
            "role": "user",
            "content": format_message_for_translation(message, target_language),
        }
    ]
    try:
        input_ids = tokenizer.apply_chat_template(
            conversation, return_tensors="pt", add_generation_prompt=True
        )
        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
            # Keep the trailing MAX_INPUT_TOKEN_LENGTH tokens of the prompt.
            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
            gr.Warning(f"Trimmed input as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
        input_ids = input_ids.to(model.device)

        streamer = TextIteratorStreamer(
            tokenizer, timeout=240.0, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = {
            "input_ids": input_ids,
            "streamer": streamer,
            # Values coming from UI sliders may arrive as floats; the token
            # count and top-k are integral quantities, so coerce them.
            "max_new_tokens": int(max_new_tokens),
            "do_sample": True,
            "top_p": top_p,
            "top_k": int(top_k),
            "temperature": temperature,
            "num_beams": 1,
            "repetition_penalty": repetition_penalty,
            "use_cache": True,  # Enable KV cache
            "pad_token_id": tokenizer.eos_token_id,
        }
        # generate() blocks until completion, so it runs on a worker thread
        # while this generator drains the streamer.
        worker = Thread(target=model.generate, kwargs=generate_kwargs)
        worker.start()

        translated = ""
        for text in streamer:
            translated += text
            yield translated

        # The streamer is exhausted; wait for the worker to finish before
        # recording the result.
        worker.join()
        store_output(target_language, message, translated, model_type)
    except Exception as e:
        yield f"Translation error: {str(e)}"
    finally:
        # Clean up on success and on failure alike.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
# Custom stylesheet handed to gr.Blocks(css=...). Most selectors correspond to
# the elem_classes values assigned to components in this file (for example
# "main-container", "translate-btn", "stats-card"); the string itself is plain
# CSS and is not interpolated.
css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
* {
font-family: 'Inter', sans-serif;
}
.gradio-container {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
min-height: 100vh;
}
.main-container {
background: rgba(255, 255, 255, 0.95);
backdrop-filter: blur(10px);
border-radius: 20px;
padding: 2rem;
margin: 1rem;
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
}
.title-container {
text-align: center;
margin-bottom: 2rem;
padding: 1rem;
background: linear-gradient(45deg, #667eea, #764ba2);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
.model-tab {
background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
border: none;
border-radius: 15px;
color: white;
font-weight: 600;
padding: 1rem 2rem;
transition: all 0.3s ease;
}
.model-tab:hover {
transform: translateY(-2px);
box-shadow: 0 10px 25px rgba(0, 0, 0, 0.2);
}
.language-dropdown {
background: white;
border: 2px solid #e2e8f0;
border-radius: 12px;
padding: 0.75rem;
font-size: 16px;
transition: all 0.3s ease;
}
.language-dropdown:focus {
border-color: #667eea;
box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1);
}
.chat-container {
background: white;
border-radius: 15px;
padding: 1rem;
box-shadow: 0 10px 30px rgba(0, 0, 0, 0.1);
margin: 1rem 0;
}
.message-input {
border: 2px solid #e2e8f0;
border-radius: 12px;
padding: 1rem;
font-size: 16px;
transition: all 0.3s ease;
background: white;
}
.message-input:focus {
border-color: #667eea;
box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1);
}
.translate-btn {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border: none;
border-radius: 12px;
color: white;
font-weight: 600;
padding: 1rem 2rem;
font-size: 16px;
cursor: pointer;
transition: all 0.3s ease;
}
.translate-btn:hover {
transform: translateY(-2px);
box-shadow: 0 10px 25px rgba(102, 126, 234, 0.3);
}
.examples-container {
background: linear-gradient(135deg, #ffecd2 0%, #fcb69f 100%);
border-radius: 15px;
padding: 1.5rem;
margin: 1rem 0;
}
.feedback-section {
background: linear-gradient(135deg, #a8edea 0%, #fed6e3 100%);
border-radius: 15px;
padding: 1.5rem;
margin: 1rem 0;
border: none;
}
.advanced-options {
background: linear-gradient(135deg, #d299c2 0%, #fef9d7 100%);
border-radius: 15px;
padding: 1.5rem;
margin: 1rem 0;
}
.slider-container .gr-slider {
background: linear-gradient(90deg, #667eea, #764ba2);
}
.rating-container {
display: flex;
gap: 1rem;
justify-content: center;
margin: 1rem 0;
}
.feedback-btn {
background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
border: none;
border-radius: 12px;
color: white;
font-weight: 600;
padding: 0.75rem 1.5rem;
cursor: pointer;
transition: all 0.3s ease;
}
.feedback-btn:hover {
transform: translateY(-2px);
box-shadow: 0 8px 20px rgba(240, 147, 251, 0.3);
}
.stats-card {
background: rgba(255, 255, 255, 0.8);
border-radius: 12px;
padding: 1rem;
text-align: center;
box-shadow: 0 5px 15px rgba(0, 0, 0, 0.1);
margin: 0.5rem;
}
.model-info {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border-radius: 12px;
padding: 1rem;
margin: 1rem 0;
}
.animate-pulse {
animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
}
@keyframes pulse {
0%, 100% {
opacity: 1;
}
50% {
opacity: .5;
}
}
.loading-spinner {
border: 4px solid #f3f3f3;
border-top: 4px solid #667eea;
border-radius: 50%;
width: 40px;
height: 40px;
animation: spin 2s linear infinite;
margin: 0 auto;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
"""
# Model descriptions: HTML snippets shown at the top of each model tab
# (passed as the `description` argument of create_chatbot_interface). They
# are styled by the ".model-info" rule of the stylesheet.
# NOTE(review): the leading symbols in the headings and list items look like
# mis-decoded emoji - confirm the file encoding before editing these strings.
INDICTRANS_DESCRIPTION = """
<div class="model-info">
<h3>๐ŸŒŸ IndicTrans3-Beta</h3>
<p><strong>Latest SOTA translation model from AI4Bharat</strong></p>
<ul>
<li>โœ… Supports <strong>22 Indic languages</strong></li>
<li>โœ… Document-level machine translation</li>
<li>โœ… Optimized for real-world applications</li>
<li>โœ… Enhanced with KV caching for faster inference</li>
</ul>
</div>
"""
SARVAM_DESCRIPTION = """
<div class="model-info">
<h3>๐Ÿš€ Sarvam Translate</h3>
<p><strong>Advanced multilingual translation model</strong></p>
<ul>
<li>โœ… Supports <strong>22 Indic languages</strong></li>
<li>โœ… High-quality translations</li>
<li>โœ… Document-level machine translation</li>
<li>โœ… Optimized for real-world applications</li>
<li>โœ… Optimized for production use</li>
<li>โœ… Enhanced with KV caching for faster inference</li>
</ul>
</div>
"""
def create_chatbot_interface(model_type, languages, description):
    """Create the widgets of one model tab and return them for event wiring.

    Must be called inside an active Gradio layout context. Components are
    created in this order: description, language dropdown, chatbot, input row,
    examples, feedback section, advanced settings.

    Args:
        model_type: ``"indictrans"`` selects the long example sentences; any
            other value selects the short generic ones.
        languages: Choices for the target-language dropdown; the first entry
            is preselected.
        description: Markdown/HTML shown at the top of the tab.

    Returns:
        A 12-tuple: ``(chatbot, msg, submit_btn, target_language, rating,
        feedback_text, feedback_submit, max_new_tokens, temperature, top_p,
        top_k, repetition_penalty)``.
    """
    # NOTE(review): the source had lost its indentation, so the nesting of the
    # `with` blocks below is reconstructed - confirm against the rendered UI.

    # Sample sentences are plain data, so they are chosen before any component
    # is created.
    if model_type == "indictrans":
        sample_texts = [
            "The Taj Mahal, an architectural marvel of white marble, stands majestically along the banks of the Yamuna River in Agra, India.",
            "Kumbh Mela, the world's largest spiritual gathering, is a significant Hindu festival held at four sacred riverbanks.",
            "India's classical dance forms, such as Bharatanatyam, Kathak, Odissi, are deeply rooted in tradition and storytelling.",
            "Ayurveda, India's ancient medical system, emphasizes a holistic approach to health by balancing mind, body, and spirit.",
            "Diwali, the festival of lights, symbolizes the victory of light over darkness and good over evil."
        ]
    else:
        sample_texts = [
            "Hello, how are you today?",
            "I love learning new languages and cultures.",
            "Technology is transforming the way we communicate.",
            "The weather is beautiful today.",
            "Thank you for your help and support."
        ]
    slider_class = "slider-container"

    with gr.Column(elem_classes="main-container"):
        gr.Markdown(description)

        language_picker = gr.Dropdown(
            choices=languages,
            value=languages[0],
            label="๐ŸŒ Select Target Language",
            elem_classes="language-dropdown",
        )

        chat_view = gr.Chatbot(
            height=500,
            elem_classes="chat-container",
            show_copy_button=True,
            avatar_images=["avatars/user_logo.png", "avatars/ai4bharat_logo.png"],
            bubble_full_width=False,
            show_label=False,
        )

        # Input row: wide text box plus a narrow translate button (9:1).
        with gr.Row():
            input_box = gr.Textbox(
                placeholder="โœ๏ธ Enter text to translate...",
                show_label=False,
                container=False,
                scale=9,
                elem_classes="message-input",
            )
            translate_button = gr.Button(
                "๐Ÿ”„ Translate",
                scale=1,
                elem_classes="translate-btn",
            )

        # Examples section: clicking an example fills the input box.
        with gr.Accordion("๐Ÿ“š Example Texts", open=False, elem_classes="examples-container"):
            gr.Examples(
                examples=sample_texts,
                inputs=input_box,
                label="Click on any example to try:",
            )

        # Feedback section
        with gr.Accordion("๐Ÿ’ญ Provide Feedback", open=False, elem_classes="feedback-section"):
            gr.Markdown("### ๐Ÿ“ Rate Translation & Share Feedback")
            gr.Markdown("Help us improve translation quality with your valuable feedback!")
            # NOTE(review): assumed the row holds only the rating control.
            with gr.Row():
                rating_choice = gr.Radio(
                    choices=["1", "2", "3", "4", "5"],
                    label="๐Ÿ† Translation Quality Rating",
                    value=None,
                )
            feedback_box = gr.Textbox(
                placeholder="๐Ÿ’ฌ Share your thoughts about the translation quality, accuracy, or suggestions for improvement...",
                label="๐Ÿ“ Your Feedback",
                lines=3,
            )
            feedback_button = gr.Button(
                "๐Ÿ“ค Submit Feedback",
                elem_classes="feedback-btn",
            )

        # Advanced options: generation parameters forwarded to the model.
        with gr.Accordion("โš™๏ธ Advanced Settings", open=False, elem_classes="advanced-options"):
            gr.Markdown("### ๐Ÿ”ง Fine-tune Translation Parameters")
            with gr.Row():
                max_tokens_slider = gr.Slider(
                    label="๐Ÿ“ Max New Tokens",
                    minimum=1,
                    maximum=MAX_MAX_NEW_TOKENS,
                    step=1,
                    value=DEFAULT_MAX_NEW_TOKENS,
                    elem_classes=slider_class,
                )
                temperature_slider = gr.Slider(
                    label="๐ŸŒก๏ธ Temperature",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.1,
                    value=0.1,
                    elem_classes=slider_class,
                )
            with gr.Row():
                top_p_slider = gr.Slider(
                    label="๐ŸŽฏ Top-p (Nucleus Sampling)",
                    minimum=0.05,
                    maximum=1.0,
                    step=0.05,
                    value=0.9,
                    elem_classes=slider_class,
                )
                top_k_slider = gr.Slider(
                    label="๐Ÿ” Top-k",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                    elem_classes=slider_class,
                )
            # NOTE(review): assumed this slider sits below the second row.
            repetition_slider = gr.Slider(
                label="๐Ÿ”„ Repetition Penalty",
                minimum=1.0,
                maximum=2.0,
                step=0.05,
                value=1.0,
                elem_classes=slider_class,
            )

    return (
        chat_view,
        input_box,
        translate_button,
        language_picker,
        rating_choice,
        feedback_box,
        feedback_button,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
        top_k_slider,
        repetition_slider,
    )
def user(user_message, history, target_lang):
    """Queue the submitted text as a new, unanswered chat turn.

    Returns an empty string (which clears the input box) together with a new
    history list ending in ``[user_message, None]``. The incoming ``history``
    is not modified and ``target_lang`` is not used.
    """
    pending_turn = [user_message, None]
    extended_history = history + [pending_turn]
    cleared_input = ""
    return cleared_input, extended_history
def bot(history, target_lang, max_tokens, temp, top_p_val, top_k_val, rep_penalty, model_type):
    """Fill in the reply of the last chat turn while the translation streams.

    The last entry of ``history`` holds the pending user text. Its reply slot
    is first set to an empty string and then overwritten with every partial
    result produced by ``translate_message``; the same (mutated) ``history``
    list is yielded after each update.
    """
    current_turn = history[-1]
    source_text = current_turn[0]
    current_turn[1] = ""
    partial_results = translate_message(
        source_text,
        history[:-1],
        target_lang,
        max_tokens,
        temp,
        top_p_val,
        top_k_val,
        rep_penalty,
        model_type,
    )
    for partial in partial_results:
        current_turn[1] = partial
        yield history
# Main Gradio interface
def _wire_tab_events(components, model_type):
    """Attach the translate and feedback handlers to one tab's components.

    Must be called inside the active ``gr.Blocks`` context.

    Args:
        components: The 12-tuple returned by ``create_chatbot_interface``.
        model_type: Model identifier forwarded to ``bot`` and
            ``store_feedback``.
    """
    (chatbot, msg, submit_btn, target_language, rating, feedback_text,
     feedback_submit, max_new_tokens, temperature, top_p, top_k,
     repetition_penalty) = components

    def stream_translation(*args):
        # Deliberately a generator function rather than a lambda returning
        # bot(...): Gradio decides whether to stream by inspecting the handler
        # itself, and a lambda that returns a generator object is not a
        # generator function.
        yield from bot(*args, model_type)

    def submit_feedback(*args):
        return store_feedback(*args, model_type)

    user_inputs = [msg, chatbot, target_language]
    bot_inputs = [chatbot, target_language, max_new_tokens, temperature,
                  top_p, top_k, repetition_penalty]

    # Pressing Enter in the text box and clicking the button behave the same:
    # first queue the user turn, then stream the translation into the chatbot.
    for trigger in (msg.submit, submit_btn.click):
        trigger(
            user, user_inputs, [msg, chatbot], queue=False
        ).then(
            stream_translation, bot_inputs, chatbot,
        )

    feedback_submit.click(
        submit_feedback,
        inputs=[rating, feedback_text, chatbot, target_language],
    )


with gr.Blocks(css=css, title="๐ŸŒ Advanced Multilingual Translation Hub", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
<div class="title-container">
<h1>๐ŸŒ Advanced Multilingual Translation Hub</h1>
<p style="font-size: 18px; margin-top: 10px;">
Experience state-of-the-art translation with multiple AI models
</p>
</div>
""",
        elem_classes="title-container"
    )
    # Statistics cards
    with gr.Row():
        gr.Markdown(
            '<div class="stats-card"><h3>๐ŸŽฏ</h3><p><strong>22+</strong><br>Languages</p></div>',
            elem_classes="stats-card"
        )
        gr.Markdown(
            '<div class="stats-card"><h3>๐Ÿš€</h3><p><strong>2</strong><br>AI Models</p></div>',
            elem_classes="stats-card"
        )
        gr.Markdown(
            '<div class="stats-card"><h3>โšก</h3><p><strong>Optimized</strong><br>Performance</p></div>',
            elem_classes="stats-card"
        )
        gr.Markdown(
            '<div class="stats-card"><h3>๐Ÿ”’</h3><p><strong>Secure</strong><br>Processing</p></div>',
            elem_classes="stats-card"
        )
    # One tab per model; each tab gets its own independent set of components.
    with gr.Tabs(elem_classes="model-tab") as tabs:
        with gr.TabItem("๐Ÿ‡ฎ๐Ÿ‡ณ IndicTrans3-Beta", elem_id="indictrans-tab"):
            indictrans_components = create_chatbot_interface("indictrans", INDIC_LANGUAGES, INDICTRANS_DESCRIPTION)
        with gr.TabItem("๐ŸŒ Sarvam Translate", elem_id="sarvam-tab"):
            sarvam_components = create_chatbot_interface("sarvam", SARVAM_LANGUAGES, SARVAM_DESCRIPTION)

    # Event handlers for IndicTrans (the individual names are kept at module
    # level, as before).
    (indictrans_chatbot, indictrans_msg, indictrans_submit, indictrans_lang,
     indictrans_rating, indictrans_feedback, indictrans_feedback_submit,
     indictrans_max_tokens, indictrans_temp, indictrans_top_p,
     indictrans_top_k, indictrans_rep_penalty) = indictrans_components
    _wire_tab_events(indictrans_components, "indictrans")

    # Event handlers for Sarvam
    (sarvam_chatbot, sarvam_msg, sarvam_submit, sarvam_lang,
     sarvam_rating, sarvam_feedback, sarvam_feedback_submit,
     sarvam_max_tokens, sarvam_temp, sarvam_top_p,
     sarvam_top_k, sarvam_rep_penalty) = sarvam_components
    _wire_tab_events(sarvam_components, "sarvam")

    # Footer
    gr.Markdown(
        """
<div style="text-align: center; margin-top: 2rem; padding: 1rem; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 15px; color: white;">
<p>๐Ÿš€ <strong>Powered by AI4Bharat & Sarvam AI</strong> |
Built with โค๏ธ using Gradio |
๐Ÿ”ง <strong>Optimized with KV Caching & Advanced Memory Management</strong></p>
</div>
"""
    )
if __name__ == "__main__":
    # Start the web server only when the file is executed as a script:
    # listen on all interfaces at port 7860, request a public share link,
    # surface handler errors in the UI and cap the worker threads at 10.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
        "show_error": True,
        "max_threads": 10,
    }
    demo.launch(**launch_options)