# app.py — Hugging Face Space by sudoping01 (commit 1b7efab, "Update app.py", 5.49 kB)
import gradio as gr
from transformers import pipeline
import torch
import logging
import spaces
from typing import Union, List
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _select_device() -> str:
    """Return the torch device string to run inference on, logging the choice.

    Preference order is CUDA, then Apple MPS, then CPU.
    """
    if torch.cuda.is_available():
        logger.info("Using CUDA for inference.")
        return "cuda"
    if torch.backends.mps.is_available():
        logger.info("Using MPS for inference.")
        return "mps"
    logger.info("Using CPU for inference.")
    return "cpu"


# Resolved once at import time; the translation pipeline is built on this device.
device = _select_device()
class BambaraTranslator:
    """Translate text between Bambara, French, and English with an NLLB model.

    Wraps a Hugging Face ``transformers`` translation pipeline and maps the
    human-readable language names used by the UI to FLORES-200 codes.
    """

    def __init__(self, model_name: str = "sudoping01/nllb-bambara-v2"):
        """Build the translation pipeline.

        Args:
            model_name: Hugging Face model id handed to ``pipeline``.
        """
        # `device` is the module-level backend chosen at import time.
        self.translator = pipeline(
            "translation",
            model=model_name,
            device=device,
            max_length=512,
            truncation=True,
        )
        # UI language name -> FLORES-200 code passed as src_lang / tgt_lang.
        self.flores_codes = {
            "French": "fra_Latn",
            "English": "eng_Latn",
            "Bambara": "bam_Latn",
        }
        logger.info("Translation pipeline initialized successfully.")

    def translate(
        self,
        text: Union[str, List[str]],
        src_lang: str,
        tgt_lang: str,
        num_beams: int = 2,
    ) -> Union[str, List[str]]:
        """Translate a single string or a batch of strings.

        Args:
            text: One string, or a list of strings translated as a batch.
            src_lang: Source language name; a key of ``flores_codes``.
            tgt_lang: Target language name; a key of ``flores_codes``.
            num_beams: Beam width forwarded to the pipeline (default 2).

        Returns:
            A translated ``str`` for ``str`` input, or a list of translated
            strings for list input. If the pipeline call fails, the error
            message is returned in the same shape (one copy per input item
            for a batch), so callers can rely on the declared return type.

        Raises:
            KeyError: If ``src_lang`` or ``tgt_lang`` is not a supported
                language name.
        """
        source_lang = self.flores_codes[src_lang]
        target_lang = self.flores_codes[tgt_lang]
        logger.info("Translating text from %s to %s.", source_lang, target_lang)

        is_single = isinstance(text, str)
        # Materialise non-str iterables once so the error path can size its
        # result even if `text` was a one-shot iterator.
        batch = text if is_single else list(text)
        if not is_single and not batch:
            # Nothing to translate; skip the model call entirely.
            return []

        try:
            outputs = self.translator(
                batch,
                src_lang=source_lang,
                tgt_lang=target_lang,
                num_beams=num_beams,
            )
            if is_single:
                return str(outputs[0]["translation_text"])
            return [str(item["translation_text"]) for item in outputs]
        except Exception as e:
            # Best-effort: keep the UI responsive, but keep the traceback in logs.
            logger.exception("Translation failed: %s", e)
            error_message = "An error occurred during translation."
            return error_message if is_single else [error_message] * len(batch)
# Build the model once at import time; translate_text() reuses this instance
# for every request.
translator = BambaraTranslator()

# Demo rows for gr.Examples, ordered like its inputs: (text, source, target).
# Each tuple is converted to the plain list form the original used.
examples = [
    list(row)
    for row in (
        (
            "Aw ni ce. Ne tɔgɔ ye Adama. Nka I bɛ se ka nwele Jarakɛ, Filakɛ, Marakakɛ walima Tarawelekɛ. Awɔ ne ye Malien de ye. Aw Sanbɛ, Sanbɛ. San min tɛ ɲinan ye, an bɛɛ ka jɛ ka o seli ɲɔgɔn fɛ, hɛɛrɛ ni lafiya la. Ala ka Mali suma. Ala ka Mali yiriwa. Ala ka Mali taa ɲɛ. Ala ka an ka seliw caya. Ala ka yafa an bɛɛ ma.",
            "Bambara",
            "French",
        ),
        (
            "Le Mali est un pays riche en culture mais confronté à de nombreux défis.",
            "French",
            "Bambara",
        ),
        (
            "The sun rises every morning to bring light to the world.",
            "English",
            "Bambara",
        ),
        ("Good morning", "English", "Bambara"),
    )
]
# Hugging Face `spaces` decorator; presumably requests GPU hardware for the
# duration of each call on ZeroGPU Spaces — confirm against the Space config.
@spaces.GPU()
def translate_text(text: str, src_lang: str, tgt_lang: str) -> str:
    """Translate text for the Gradio UI and always return a displayable string.

    Args:
        text: Text from the input textbox; may be empty or ``None``.
        src_lang: Source language name selected in the dropdown.
        tgt_lang: Target language name selected in the dropdown.

    Returns:
        The translation, or a short user-facing message when the input is
        empty, a language is not selected, both languages are the same, or
        translation raises.
    """
    # Guard against None as well as blank strings: `None.strip()` would raise.
    if not text or not text.strip():
        return "Please enter text to translate."
    # A cleared dropdown yields a falsy value, which would otherwise surface
    # as an opaque KeyError from the language-code lookup.
    if not src_lang or not tgt_lang:
        return "Please select both source and target languages."
    if src_lang == tgt_lang:
        return "Source and target languages must be different."
    try:
        result = translator.translate(text, src_lang, tgt_lang)
    except Exception as e:
        # UI boundary: report the error to the user and keep the traceback.
        logger.exception("Translation failed: %s", e)
        return f"Error: {str(e)}"
    logger.info("Translation successful.")
    return result
def build_interface():
    """Build the Gradio Blocks UI for translating between supported languages.

    Returns:
        The ``gr.Blocks`` app; call ``.launch()`` on it to serve it.
    """
    # Single source for both dropdowns; these names must match the keys of
    # BambaraTranslator.flores_codes.
    languages = ["Bambara", "French", "English"]

    with gr.Blocks(title="Bambara Translator") as demo:
        gr.Markdown(
            """
            # 🇲🇱 Bambara Translator

            Translate between Bambara, French, and English instantly using NLLB model.

            ## How to Use

            1. Select source and target languages from the dropdowns
            2. Enter your text or choose from examples
            3. Click "Translate" to see the result
            """
        )

        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    lines=5,
                    label="Text to Translate",
                    placeholder="Enter text here...",
                )
                with gr.Row():
                    src_lang = gr.Dropdown(
                        choices=languages,
                        label="Source Language",
                        value="Bambara",
                    )
                    tgt_lang = gr.Dropdown(
                        choices=languages,
                        label="Target Language",
                        value="French",
                    )
                translate_btn = gr.Button("Translate", variant="primary")
            with gr.Column():
                output = gr.Textbox(label="Translation", lines=5, interactive=False)

        # Examples section
        gr.Examples(
            examples=examples,
            inputs=[text_input, src_lang, tgt_lang],
            outputs=output,
            fn=translate_text,
            cache_examples=False,
        )

        # The blank line between the two entries keeps Markdown from merging
        # them into a single run-on paragraph.
        gr.Markdown(
            """
            **License:** CC BY-NC 4.0 (Non-commercial use)

            **Based on:** Meta's NLLB (No Language Left Behind)
            """
        )

        translate_btn.click(
            fn=translate_text,
            inputs=[text_input, src_lang, tgt_lang],
            outputs=output,
        )

    return demo
if __name__ == "__main__":
    logger.info("Starting the Gradio interface for the Bambara translator.")
    interface = build_interface()
    # When run as a script, launch() blocks until the server shuts down, so
    # the following log line is emitted only on exit.
    interface.launch()
    logger.info("Gradio interface stopped.")