import os
import torch
import spaces
import gradio as gr
from threading import Thread
from collections.abc import Iterator
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Generation / prompt-length limits, in tokens.
# Presumably the upper bound for the UI's "max new tokens" control (not read in the code shown here).
MAX_MAX_NEW_TOKENS: int = 4_096
# Hard cap on the tokenized prompt length; translate_message shortens longer inputs.
MAX_INPUT_TOKEN_LENGTH: int = 4_096
# Presumably the UI's default for "max new tokens" (not read in the code shown here).
DEFAULT_MAX_NEW_TOKENS: int = 2_048
# Hugging Face access token, passed to every Hub download below (presumably the
# model repo is gated/private). Indexing rather than .get() makes a missing token
# fail fast at start-up with KeyError instead of failing later inside a download.
HF_TOKEN = os.environ['HF_TOKEN']
model_id = "ai4bharat/IndicTrans3-beta"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN)
# The checkpoint is paired with the Llama 3.2 Instruct tokenizer and chat template.
# meta-llama repositories are gated on the Hub, so pass the same token explicitly
# rather than relying on ambient credentials being picked up.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct", token=HF_TOKEN)
# Supported languages: display name -> "<language>_<Script>" tag.
# The display name (key) is what ends up in the translation prompt.
# NOTE(review): 21 entries, while DESCRIPTION advertises 22 Indic languages —
# Sanskrit is absent; confirm whether the model supports it.
# NOTE(review): "Mith" is not a standard ISO 15924 script code (Tirhuta is
# "Tirh"; IndicTrans2 used "mai_Deva") — confirm the intended Maithili tag.
LANGUAGES = dict(
    Hindi="hin_Deva",
    Bengali="ben_Beng",
    Telugu="tel_Telu",
    Marathi="mar_Deva",
    Tamil="tam_Taml",
    Urdu="urd_Arab",
    Gujarati="guj_Gujr",
    Kannada="kan_Knda",
    Odia="ori_Orya",
    Malayalam="mal_Mlym",
    Punjabi="pan_Guru",
    Assamese="asm_Beng",
    Maithili="mai_Mith",
    Santali="sat_Olck",
    Kashmiri="kas_Arab",
    Nepali="nep_Deva",
    Sindhi="snd_Arab",
    Konkani="kok_Deva",
    Dogri="dgo_Deva",
    Manipuri="mni_Beng",
    Bodo="brx_Deva",
)
def format_message_for_translation(message, target_lang):
    """Return the instruction prompt asking the model to translate ``message``.

    Both arguments are inserted verbatim; ``target_lang`` is the human-readable
    language name (e.g. ``"Hindi"``) as supplied by the caller.
    """
    template = "Translate the following text to {lang}: {text}"
    return template.format(lang=target_lang, text=message)
def _build_translation_input_ids(message: str, target_language: str) -> torch.Tensor:
    """Tokenize a single-turn translation prompt capped at MAX_INPUT_TOKEN_LENGTH tokens.

    When the prompt is too long, only the user's text is shortened (its leading
    part is kept). The chat-template header, the translation instruction and the
    trailing assistant generation prompt are always preserved, so the model still
    receives a well-formed request. A warning is shown in the UI when trimming.
    """

    def encode(text: str) -> torch.Tensor:
        # Single user turn; earlier chat turns are intentionally not included.
        conversation = [{"role": "user", "content": format_message_for_translation(text, target_language)}]
        return tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)

    input_ids = encode(message)
    if input_ids.shape[1] <= MAX_INPUT_TOKEN_LENGTH:
        return input_ids

    # Token budget left for the user's text once the template and instruction are counted.
    budget = max(MAX_INPUT_TOKEN_LENGTH - encode("").shape[1], 0)
    message_ids = tokenizer(message, add_special_tokens=False)["input_ids"]
    while True:
        # A token-level cut can split a multi-byte (e.g. Indic) character; drop the
        # resulting U+FFFD replacement character instead of sending it to the model.
        truncated = tokenizer.decode(message_ids[:budget]).rstrip("\ufffd")
        input_ids = encode(truncated)
        excess = input_ids.shape[1] - MAX_INPUT_TOKEN_LENGTH
        if excess <= 0 or budget == 0:
            break
        # Decode -> re-encode need not round-trip to exactly `budget` tokens;
        # shrink by the overshoot and retry (budget strictly decreases, so this ends).
        budget = max(budget - excess, 0)

    gr.Warning(f"Trimmed input as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    return input_ids


@spaces.GPU
def translate_message(
    message: str,
    chat_history: list[dict],
    target_language: str = "Hindi",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Translate ``message`` into ``target_language``, streaming partial output.

    Args:
        message: Text to translate.
        chat_history: Ignored. Every request is translated independently. The
            parameter presumably exists to match the Gradio chat callback signature.
        target_language: Human-readable language name placed in the prompt.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Sampling settings forwarded to ``model.generate``.

    Yields:
        The cumulative decoded translation so far, growing as tokens stream in.
    """
    print(f"Translation request: {format_message_for_translation(message, target_language)}")
    input_ids = _build_translation_input_ids(message, target_language).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=240.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        # Single unpadded sequence: every position is real input.
        attention_mask=torch.ones_like(input_ids),
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # generate() blocks until done, so run it in a worker thread and consume
    # decoded text chunks from the streamer on this one.
    generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
    generation_thread.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer is exhausted once generation has finished; reap the worker.
    generation_thread.join()
def store_feedback(rating, feedback_text):
    """Validate a feedback submission and acknowledge it in the UI.

    The rating is checked first. Then the text must be non-empty and not just
    whitespace. On a validation failure a 5-second warning toast is shown and
    ``None`` is returned. Otherwise a success toast is shown and a thank-you
    string is returned.

    NOTE(review): despite its name, this function does not persist the rating or
    text anywhere — confirm whether storage is handled elsewhere.
    """
    if not rating:
        gr.Warning("Please select a rating before submitting feedback.", duration=5)
        return None

    has_text = feedback_text and feedback_text.strip()
    if not has_text:
        gr.Warning("Please provide some feedback before submitting.", duration=5)
        return None

    gr.Info("Feedback submitted successfully!")
    return "Thank you for your feedback!"
# Custom stylesheet for the UI (presumably passed as `css=` to the Gradio app
# built further down this file — confirm). CSS only supports /* ... */ comments;
# a leading `#` would be parsed as part of a selector, not as a comment.
css = """
/* body {
    background-color: #f7f7f7;
} */
.feedback-section {
    margin-top: 30px;
    border-top: 1px solid #ddd;
    padding-top: 20px;
}
.container {
    max-width: 90%;
    margin: 0 auto;
}
.language-selector {
    margin-bottom: 20px;
    padding: 10px;
    background-color: #ffffff;
    border-radius: 8px;
    box-shadow: 0 2px 5px rgba(0,0,0,0.1);
}
.advanced-options {
    margin-top: 20px;
}
"""
DESCRIPTION = """\
IndicTrans3 is the latest state-of-the-art (SOTA) translation model from AI4Bharat, designed to handle translations across 22 Indic languages with high accuracy. It supports document-level machine translation (MT) and is built to match the performance of other leading SOTA models.
📢 Training data will be released soon!