Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import torch | |
| import spaces | |
| import psycopg2 | |
| import gradio as gr | |
| from threading import Thread | |
| from collections.abc import Iterator | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| import gc | |
# Constants
# Token budgets: the "Max New Tokens" slider is capped at MAX_MAX_NEW_TOKENS
# and starts at DEFAULT_MAX_NEW_TOKENS; prompts longer than
# MAX_INPUT_TOKEN_LENGTH tokens are trimmed before generation.
MAX_MAX_NEW_TOKENS = 4096
MAX_INPUT_TOKEN_LENGTH = 4096
DEFAULT_MAX_NEW_TOKENS = 2048
# Hugging Face access token forwarded to the model downloads ("" when unset).
HF_TOKEN = os.getenv("HF_TOKEN", "")
# Language lists
# 22 Indic languages plus English (23 entries); every name is a single word,
# so a whitespace split yields exactly the dropdown choices in this order.
INDIC_LANGUAGES = (
    "Hindi Bengali Telugu Marathi Tamil Urdu Gujarati "
    "Kannada Odia Malayalam Punjabi Assamese Maithili "
    "Santali Kashmiri Nepali Sindhi Konkani Dogri "
    "Manipuri Bodo English Sanskrit"
).split()
# Both tabs offer the same languages; this is an alias of the same list object.
SARVAM_LANGUAGES = INDIC_LANGUAGES
# Model configurations with optimizations
# bfloat16 weights on GPU; full float32 precision when running on CPU.
if torch.cuda.is_available():
    TORCH_DTYPE, DEVICE_MAP = torch.bfloat16, "0"
else:
    TORCH_DTYPE, DEVICE_MAP = torch.float32, "cpu"


def _load_causal_lm(repo_id):
    """Load a causal-LM checkpoint with the app's shared dtype/device/auth settings."""
    return AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=TORCH_DTYPE,
        device_map=DEVICE_MAP,
        token=HF_TOKEN,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )


# Both models are loaded eagerly at import time.
indictrans_model = _load_causal_lm("ai4bharat/IndicTrans3-beta")
sarvam_model = _load_causal_lm("sarvamai/sarvam-translate")
# NOTE(review): translate_message uses this single IndicTrans3 tokenizer for
# sarvam_model as well — confirm both checkpoints share a vocabulary and chat
# template, otherwise load sarvam's own tokenizer.
tokenizer = AutoTokenizer.from_pretrained(
    "ai4bharat/IndicTrans3-beta",
    trust_remote_code=True,
)
def format_message_for_translation(message, target_lang):
    """Wrap ``message`` in a one-line instruction asking for ``target_lang``.

    Example: ("hi", "Hindi") -> "Translate the following text to Hindi: hi".
    """
    template = "Translate the following text to {}: {}"
    return template.format(target_lang, message)
def store_feedback(rating, feedback_text, chat_history, tgt_lang, model_type):
    """Validate user feedback on a translation and insert it into ``feedback``.

    Args:
        rating: Selected rating from the radio group ("1".."5"), or None.
        feedback_text: Free-text feedback from the user.
        chat_history: Chatbot history as a list of ``[user_text, reply]``
            pairs (``user`` appends ``[message, None]`` before the reply
            streams in).
        tgt_lang: Target language selected in the tab.
        model_type: Which model produced the translation ("indictrans"/"sarvam").

    Returns:
        None. Validation problems are reported with ``gr.Warning`` toasts.

    Raises:
        gr.Error: When writing to the database fails, so the UI shows an error.
    """
    try:
        if not rating:
            gr.Warning("Please select a rating before submitting feedback.", duration=5)
            return None
        if not feedback_text or feedback_text.strip() == "":
            gr.Warning("Please provide some feedback before submitting.", duration=5)
            return None
        if not chat_history:
            gr.Warning("Please provide the input text before submitting feedback.", duration=5)
            return None
        # Every history entry is a 2-item [user, reply] pair, so a length check
        # alone never fires; require that the newest turn actually has a reply.
        latest_turn = chat_history[-1]
        if len(latest_turn) < 2 or not latest_turn[1]:
            gr.Warning("Please translate the input text before submitting feedback.", duration=5)
            return None
        conn = psycopg2.connect(
            host=os.getenv("DB_HOST"),
            database=os.getenv("DB_NAME"),
            user=os.getenv("DB_USER"),
            password=os.getenv("DB_PASSWORD"),
            port=os.getenv("DB_PORT"),
        )
        insert_query = """
INSERT INTO feedback
(tgt_lang, rating, feedback_txt, chat_history, model_type)
VALUES (%s, %s, %s, %s, %s)
"""
        try:
            # `with conn` commits on success / rolls back on error; the cursor
            # context closes the cursor. The connection itself is closed below
            # even when execute() raises, so it is never leaked.
            with conn, conn.cursor() as cursor:
                cursor.execute(insert_query, (tgt_lang, int(rating), feedback_text, chat_history, model_type))
        finally:
            conn.close()
        gr.Info("Thank you for your feedback! ๐", duration=5)
    except Exception as e:
        print(f"Database error: {e}")
        # gr.Error only reaches the user when raised; merely constructing it
        # (as before) silently discarded the message.
        raise gr.Error("An error occurred while storing feedback. Please try again later.", duration=5) from e
def store_output(tgt_lang, input_text, output_text, model_type):
    """Best-effort insert of one finished translation into ``translation``.

    Connection settings come from the DB_HOST/DB_NAME/DB_USER/DB_PASSWORD/
    DB_PORT environment variables. Any failure is printed and swallowed so
    that logging problems never break the translation response.

    Args:
        tgt_lang: Target language name.
        input_text: Source text the user submitted.
        output_text: Full generated translation.
        model_type: Which model produced it ("indictrans"/"sarvam").

    Returns:
        None, always.
    """
    try:
        conn = psycopg2.connect(
            host=os.getenv("DB_HOST"),
            database=os.getenv("DB_NAME"),
            user=os.getenv("DB_USER"),
            password=os.getenv("DB_PASSWORD"),
            port=os.getenv("DB_PORT"),
        )
        insert_query = """
INSERT INTO translation
(input_txt, output_txt, tgt_lang, model_type)
VALUES (%s, %s, %s, %s)
"""
        try:
            # Commit on success / roll back on error, close the cursor, and
            # always close the connection — previously an execute() failure
            # skipped cursor.close()/conn.close() and leaked the connection.
            with conn, conn.cursor() as cursor:
                cursor.execute(insert_query, (input_text, output_text, tgt_lang, model_type))
        finally:
            conn.close()
    except Exception as e:
        print(f"Database error: {e}")
def translate_message(
    message: str,
    chat_history: list,
    target_language: str = "Hindi",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    model_type: str = "indictrans"
) -> Iterator[str]:
    """Stream a translation of ``message`` into ``target_language``.

    Each yielded string is the whole translation generated so far (not only
    the newest fragment), so callers can display it directly.

    Args:
        message: Source text to translate.
        chat_history: Earlier chat turns (list of ``[user, reply]`` pairs).
            Currently unused: every request is a single-turn conversation.
        target_language: Language name inserted into the instruction prompt.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Sampling parameters forwarded to ``model.generate``.
        model_type: ``"indictrans"`` or ``"sarvam"``; selects the model.

    Yields:
        The accumulated translation, or a single ``"Error: ..."`` /
        ``"Translation error: ..."`` message when something goes wrong.
    """
    # Reject unknown model types up front; falling through the if/elif used to
    # leave `model` unbound and crash with UnboundLocalError.
    if model_type == "indictrans":
        model = indictrans_model
    elif model_type == "sarvam":
        model = sarvam_model
    else:
        yield f"Error: Unknown model type '{model_type}'."
        return
    if model is None or tokenizer is None:
        yield "Error: Model failed to load. Please try again."
        return

    translation_request = format_message_for_translation(message, target_language)
    conversation = [{"role": "user", "content": translation_request}]
    outputs: list[str] = []
    completed = False
    try:
        input_ids = tokenizer.apply_chat_template(
            conversation, return_tensors="pt", add_generation_prompt=True
        )
        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
            # NOTE(review): keeping only the *last* tokens drops the start of
            # the prompt, which is where the "Translate the following text to
            # ..." instruction sits; consider trimming `message` instead.
            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
            gr.Warning(f"Trimmed input as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
        input_ids = input_ids.to(model.device)
        streamer = TextIteratorStreamer(
            tokenizer, timeout=240.0, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = {
            "input_ids": input_ids,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "do_sample": True,
            "top_p": top_p,
            "top_k": top_k,
            "temperature": temperature,
            "num_beams": 1,
            "repetition_penalty": repetition_penalty,
            "use_cache": True,  # Enable KV cache
        }
        # generate() runs on a worker thread and pushes decoded text into the
        # streamer, which this generator drains incrementally.
        worker = Thread(target=model.generate, kwargs=generate_kwargs)
        worker.start()
        for text in streamer:
            outputs.append(text)
            yield "".join(outputs)
        worker.join()
        completed = True
    except Exception as e:
        yield f"Translation error: {str(e)}"
    finally:
        # Runs on success, on error, and when the consumer stops iterating
        # early (GeneratorExit), so cached GPU memory is released every time.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
    # Only log translations whose generation finished without an error.
    if completed:
        store_output(target_language, message, "".join(outputs), model_type)
# Enhanced CSS with beautiful styling
# Global stylesheet passed to gr.Blocks(css=css). Most selectors match the
# `elem_classes` assigned to components in this file (main-container,
# language-dropdown, chat-container, message-input, translate-btn, ...).
# .rating-container, .animate-pulse and .loading-spinner are not referenced by
# any elem_classes visible in this file.
# NOTE(review): `.slider-container .gr-slider` is a descendant selector, while
# "slider-container" is set on the sliders themselves — confirm it matches.
css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
* {
font-family: 'Inter', sans-serif;
box-sizing: border-box;
}
.gradio-container {
background: #1a1a1a !important;
color: #e0e0e0;
min-height: 100vh;
}
.main-container {
background: #2a2a2a;
border-radius: 12px;
padding: 1.5rem;
margin: 1rem;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
}
.title-container {
text-align: center;
margin-bottom: 1.5rem;
padding: 1rem;
color: #a0a0ff;
}
.model-tab {
background: #3333a0;
border: none;
border-radius: 8px;
color: #ffffff;
font-weight: 500;
padding: 0.75rem 1.5rem;
transition: all 0.2s ease;
}
.model-tab:hover {
background: #4444b0;
transform: translateY(-1px);
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.4);
}
.language-dropdown {
background: #333333;
border: 1px solid #444444;
border-radius: 8px;
padding: 0.5rem;
font-size: 14px;
color: #e0e0e0;
transition: all 0.2s ease;
}
.language-dropdown:focus {
border-color: #6666ff;
box-shadow: 0 0 0 2px rgba(102, 102, 255, 0.2);
}
.chat-container {
background: #222222;
border-radius: 8px;
padding: 1rem;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
margin: 1rem 0;
}
.message-input {
background: #333333;
border: 1px solid #444444;
border-radius: 8px;
padding: 0.75rem;
font-size: 14px;
color: #e0e0e0;
transition: all 0.2s ease;
}
.message-input:focus {
border-color: #6666ff;
box-shadow: 0 0 0 2px rgba(102, 102, 255, 0.2);
}
.translate-btn {
background: #3333a0;
border: none;
border-radius: 8px;
color: #ffffff;
font-weight: 500;
padding: 0.75rem 1.5rem;
font-size: 14px;
cursor: pointer;
transition: all 0.2s ease;
}
.translate-btn:hover {
background: #4444b0;
transform: translateY(-1px);
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.4);
}
.examples-container {
background: #2a2a2a;
border-radius: 8px;
padding: 1rem;
margin: 1rem 0;
}
.feedback-section {
background: #2a2a2a;
border-radius: 8px;
padding: 1rem;
margin: 1rem 0;
border: none;
}
.advanced-options {
background: #2a2a2a;
border-radius: 8px;
padding: 1rem;
margin: 1rem 0;
}
.slider-container .gr-slider {
background: #444444;
color: #e0e0e0;
}
.rating-container {
display: flex;
gap: 0.5rem;
justify-content: center;
margin: 0.5rem 0;
}
.feedback-btn {
background: #3333a0;
border: none;
border-radius: 8px;
color: #ffffff;
font-weight: 500;
padding: 0.5rem 1rem;
cursor: pointer;
transition: all 0.2s ease;
}
.feedback-btn:hover {
background: #4444b0;
transform: translateY(-1px);
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.4);
}
.stats-card {
background: #333333;
border-radius: 8px;
padding: 0.75rem;
text-align: center;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
margin: 0.5rem;
color: #e0e0e0;
}
.model-info {
background: #3333a0;
color: #ffffff;
border-radius: 8px;
padding: 1rem;
margin: 1rem 0;
}
.animate-pulse {
animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
}
@keyframes pulse {
0%, 100% {
opacity: 1;
}
50% {
opacity: 0.5;
}
}
.loading-spinner {
border: 3px solid #444444;
border-top: 3px solid #6666ff;
border-radius: 50%;
width: 30px;
height: 30px;
animation: spin 1.5s linear infinite;
margin: 0 auto;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
"""
# Model descriptions
# HTML banners rendered via gr.Markdown at the top of each model tab; the
# .model-info class is styled by the global stylesheet.
# NOTE(review): glyphs such as "๐" and "โ" look like UTF-8 emoji mis-decoded
# as a Thai code page — restore the intended characters from the original file.
# NOTE(review): "22 Indic languages" — the dropdown list has 22 Indic entries
# plus English; confirm the wording is what is intended.
INDICTRANS_DESCRIPTION = """
<div class="model-info">
<h3>๐ IndicTrans3-Beta</h3>
<p><strong>Latest SOTA translation model from AI4Bharat</strong></p>
<ul>
<li>โ Supports <strong>22 Indic languages</strong></li>
<li>โ Document-level machine translation</li>
<li>โ Optimized for real-world applications</li>
<li>โ Enhanced with KV caching for faster inference</li>
</ul>
</div>
"""
SARVAM_DESCRIPTION = """
<div class="model-info">
<h3>๐ Sarvam Translate</h3>
<p><strong>Advanced multilingual translation model</strong></p>
<ul>
<li>โ Supports <strong>22 Indic languages</strong></li>
<li>โ High-quality translations</li>
<li>โ Document-level machine translation</li>
<li>โ Optimized for real-world applications</li>
<li>โ Optimized for production use</li>
<li>โ Enhanced with KV caching for faster inference</li>
</ul>
</div>
"""
def create_chatbot_interface(model_type, languages, description):
    """Build the UI for one model tab inside the current Gradio Blocks context.

    Must be called inside a ``gr.Blocks`` (here: a ``gr.TabItem``) context,
    because components attach to whatever layout block is currently open.

    Args:
        model_type: "indictrans" selects the India-themed example texts; any
            other value selects the generic examples.
        languages: Dropdown choices; the first entry is the default target.
        description: HTML/Markdown banner shown at the top of the tab.

    Returns:
        A 12-tuple of components, in order: chatbot, msg, submit_btn,
        target_language, rating, feedback_text, feedback_submit,
        max_new_tokens, temperature, top_p, top_k, repetition_penalty.
        Callers unpack this tuple positionally to wire event handlers.
    """
    # NOTE(review): several emoji in the labels below appear mis-encoded
    # (e.g. "๐"); they are runtime strings and are left unchanged here.
    with gr.Column(elem_classes="main-container"):
        gr.Markdown(description)
        target_language = gr.Dropdown(
            languages,
            value=languages[0],
            label="๐ Select Target Language",
            elem_classes="language-dropdown",
        )
        # Chat history is kept as [user_text, reply] pairs (see `user`/`bot`).
        # NOTE(review): assumes avatars/ image files ship alongside the app.
        chatbot = gr.Chatbot(
            height=500,
            elem_classes="chat-container",
            show_copy_button=True,
            avatar_images=["avatars/user_logo.png", "avatars/ai4bharat_logo.png"],
            bubble_full_width=False,
            show_label=False
        )
        with gr.Row():
            msg = gr.Textbox(
                placeholder="โ๏ธ Enter text to translate...",
                show_label=False,
                container=False,
                scale=9,
                elem_classes="message-input",
            )
            submit_btn = gr.Button(
                "๐ Translate",
                scale=1,
                elem_classes="translate-btn"
            )
        # Examples section
        if model_type == "indictrans":
            examples_data = [
                "The Taj Mahal, an architectural marvel of white marble, stands majestically along the banks of the Yamuna River in Agra, India.",
                "Kumbh Mela, the world's largest spiritual gathering, is a significant Hindu festival held at four sacred riverbanks.",
                "India's classical dance forms, such as Bharatanatyam, Kathak, Odissi, are deeply rooted in tradition and storytelling.",
                "Ayurveda, India's ancient medical system, emphasizes a holistic approach to health by balancing mind, body, and spirit.",
                "Diwali, the festival of lights, symbolizes the victory of light over darkness and good over evil."
            ]
        else:
            examples_data = [
                "Hello, how are you today?",
                "I love learning new languages and cultures.",
                "Technology is transforming the way we communicate.",
                "The weather is beautiful today.",
                "Thank you for your help and support."
            ]
        with gr.Accordion("๐ Example Texts", open=False, elem_classes="examples-container"):
            # Clicking an example only fills the input box; it does not submit.
            gr.Examples(
                examples=examples_data,
                inputs=msg,
                label="Click on any example to try:"
            )
        # Feedback section
        with gr.Accordion("๐ญ Provide Feedback", open=False, elem_classes="feedback-section"):
            gr.Markdown("### ๐ Rate Translation & Share Feedback")
            gr.Markdown("Help us improve translation quality with your valuable feedback!")
            with gr.Row():
                # Ratings are strings; store_feedback converts them with int().
                rating = gr.Radio(
                    ["1", "2", "3", "4", "5"],
                    label="๐ Translation Quality Rating",
                    value=None
                )
            feedback_text = gr.Textbox(
                placeholder="๐ฌ Share your thoughts about the translation quality, accuracy, or suggestions for improvement...",
                label="๐ Your Feedback",
                lines=3,
            )
            feedback_submit = gr.Button(
                "๐ค Submit Feedback",
                elem_classes="feedback-btn"
            )
        # Advanced options
        # These slider defaults (e.g. temperature 0.1) are what `bot` passes to
        # translate_message, overriding that function's own defaults
        # (temperature 0.6, repetition_penalty 1.2).
        with gr.Accordion("โ๏ธ Advanced Settings", open=False, elem_classes="advanced-options"):
            gr.Markdown("### ๐ง Fine-tune Translation Parameters")
            with gr.Row():
                max_new_tokens = gr.Slider(
                    label="๐ Max New Tokens",
                    minimum=1,
                    maximum=MAX_MAX_NEW_TOKENS,
                    step=1,
                    value=DEFAULT_MAX_NEW_TOKENS,
                    elem_classes="slider-container"
                )
                temperature = gr.Slider(
                    label="๐ก๏ธ Temperature",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.1,
                    value=0.1,
                    elem_classes="slider-container"
                )
            with gr.Row():
                top_p = gr.Slider(
                    label="๐ฏ Top-p (Nucleus Sampling)",
                    minimum=0.05,
                    maximum=1.0,
                    step=0.05,
                    value=0.9,
                    elem_classes="slider-container"
                )
                top_k = gr.Slider(
                    label="๐ Top-k",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                    elem_classes="slider-container"
                )
            repetition_penalty = gr.Slider(
                label="๐ Repetition Penalty",
                minimum=1.0,
                maximum=2.0,
                step=0.05,
                value=1.0,
                elem_classes="slider-container"
            )
    return (chatbot, msg, submit_btn, target_language, rating, feedback_text,
            feedback_submit, max_new_tokens, temperature, top_p, top_k, repetition_penalty)
def user(user_message, history, target_lang):
    """Queue a new chat turn whose reply is still pending.

    Returns ``("", new_history)``: the empty string clears the input box and
    ``new_history`` is a fresh list (the caller's ``history`` is not mutated)
    ending with ``[user_message, None]``. ``target_lang`` is accepted only to
    match the event's input list and is not used.
    """
    pending_turn = [user_message, None]
    return "", [*history, pending_turn]
def bot(history, target_lang, max_tokens, temp, top_p_val, top_k_val, rep_penalty, model_type):
    """Fill in the reply of the newest chat turn, streaming it into ``history``.

    The last ``[user_text, reply]`` pair in ``history`` is updated in place:
    its reply starts as "" and is replaced with each partial translation from
    ``translate_message``. The same ``history`` list is yielded after every
    update so the chatbot re-renders. Earlier turns are passed along as
    context.
    """
    current_turn = history[-1]
    earlier_turns = history[:-1]
    current_turn[1] = ""
    stream = translate_message(
        current_turn[0], earlier_turns, target_lang, max_tokens,
        temp, top_p_val, top_k_val, rep_penalty, model_type,
    )
    for partial_translation in stream:
        current_turn[1] = partial_translation
        yield history
# Main Gradio interface


def _streaming_bot(model_type):
    """Return a generator-function event handler that streams ``bot`` output.

    Gradio treats an event handler as streaming only when the handler itself
    is a generator function. ``lambda *args: bot(*args, model_type)`` is an
    ordinary function that merely *returns* a generator, so its output would
    not be streamed chunk by chunk; delegating with ``yield from`` makes the
    handler a real generator function.
    """
    def handler(*args):
        yield from bot(*args, model_type)

    # Gives the endpoint a readable API name instead of "lambda".
    handler.__name__ = f"bot_{model_type}"
    return handler


with gr.Blocks(css=css, title="๐ Advanced Multilingual Translation Hub", theme=gr.themes.Soft()) as demo:
    # NOTE(review): emoji in these strings appear mis-encoded (e.g. "๐");
    # they are runtime text and are left unchanged here.
    gr.Markdown(
        """
<div class="title-container">
<h1>๐ Advanced Multilingual Translation Hub</h1>
<p style="font-size: 18px; margin-top: 10px;">
Experience state-of-the-art translation with multiple AI models
</p>
</div>
""",
        elem_classes="title-container"
    )
    # Statistics cards
    with gr.Row():
        gr.Markdown(
            '<div class="stats-card"><h3>๐ฏ</h3><p><strong>22+</strong><br>Languages</p></div>',
            elem_classes="stats-card"
        )
        gr.Markdown(
            '<div class="stats-card"><h3>๐</h3><p><strong>2</strong><br>AI Models</p></div>',
            elem_classes="stats-card"
        )
        gr.Markdown(
            '<div class="stats-card"><h3>โก</h3><p><strong>Optimized</strong><br>Performance</p></div>',
            elem_classes="stats-card"
        )
        gr.Markdown(
            '<div class="stats-card"><h3>๐</h3><p><strong>Secure</strong><br>Processing</p></div>',
            elem_classes="stats-card"
        )
    # One tab per model; each returns its components as a 12-tuple.
    with gr.Tabs(elem_classes="model-tab") as tabs:
        with gr.TabItem("๐ฎ๐ณ IndicTrans3-Beta", elem_id="indictrans-tab"):
            indictrans_components = create_chatbot_interface("indictrans", INDIC_LANGUAGES, INDICTRANS_DESCRIPTION)
        with gr.TabItem("๐ Sarvam Translate", elem_id="sarvam-tab"):
            sarvam_components = create_chatbot_interface("sarvam", SARVAM_LANGUAGES, SARVAM_DESCRIPTION)
    # Event handlers for IndicTrans
    (indictrans_chatbot, indictrans_msg, indictrans_submit, indictrans_lang,
     indictrans_rating, indictrans_feedback, indictrans_feedback_submit,
     indictrans_max_tokens, indictrans_temp, indictrans_top_p,
     indictrans_top_k, indictrans_rep_penalty) = indictrans_components
    # Two-step chain: `user` appends the turn immediately (unqueued), then the
    # streaming handler fills in the translation.
    indictrans_msg.submit(
        user, [indictrans_msg, indictrans_chatbot, indictrans_lang],
        [indictrans_msg, indictrans_chatbot], queue=False
    ).then(
        _streaming_bot("indictrans"),
        [indictrans_chatbot, indictrans_lang, indictrans_max_tokens,
         indictrans_temp, indictrans_top_p, indictrans_top_k, indictrans_rep_penalty],
        indictrans_chatbot,
    )
    indictrans_submit.click(
        user, [indictrans_msg, indictrans_chatbot, indictrans_lang],
        [indictrans_msg, indictrans_chatbot], queue=False
    ).then(
        _streaming_bot("indictrans"),
        [indictrans_chatbot, indictrans_lang, indictrans_max_tokens,
         indictrans_temp, indictrans_top_p, indictrans_top_k, indictrans_rep_penalty],
        indictrans_chatbot,
    )
    indictrans_feedback_submit.click(
        lambda *args: store_feedback(*args, "indictrans"),
        inputs=[indictrans_rating, indictrans_feedback, indictrans_chatbot, indictrans_lang],
    )
    # Event handlers for Sarvam
    (sarvam_chatbot, sarvam_msg, sarvam_submit, sarvam_lang,
     sarvam_rating, sarvam_feedback, sarvam_feedback_submit,
     sarvam_max_tokens, sarvam_temp, sarvam_top_p,
     sarvam_top_k, sarvam_rep_penalty) = sarvam_components
    sarvam_msg.submit(
        user, [sarvam_msg, sarvam_chatbot, sarvam_lang],
        [sarvam_msg, sarvam_chatbot], queue=False
    ).then(
        _streaming_bot("sarvam"),
        [sarvam_chatbot, sarvam_lang, sarvam_max_tokens,
         sarvam_temp, sarvam_top_p, sarvam_top_k, sarvam_rep_penalty],
        sarvam_chatbot,
    )
    sarvam_submit.click(
        user, [sarvam_msg, sarvam_chatbot, sarvam_lang],
        [sarvam_msg, sarvam_chatbot], queue=False
    ).then(
        _streaming_bot("sarvam"),
        [sarvam_chatbot, sarvam_lang, sarvam_max_tokens,
         sarvam_temp, sarvam_top_p, sarvam_top_k, sarvam_rep_penalty],
        sarvam_chatbot,
    )
    sarvam_feedback_submit.click(
        lambda *args: store_feedback(*args, "sarvam"),
        inputs=[sarvam_rating, sarvam_feedback, sarvam_chatbot, sarvam_lang],
    )
    # Footer
    gr.Markdown(
        """
<div style="text-align: center; margin-top: 2rem; padding: 1rem; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 15px; color: white;">
<p>๐ <strong>Powered by AI4Bharat & Sarvam AI</strong> |
Built with โค๏ธ using Gradio |
๐ง <strong>Optimized with KV Caching & Advanced Memory Management</strong></p>
</div>
"""
    )
if __name__ == "__main__":
    # Listen on every interface at port 7860 and surface handler errors in
    # the browser.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)