# axxam's picture
# Update app.py
# 7582e40 verified
import gradio as gr
import aiohttp
import asyncio
import html
import os
import json
import difflib
import re
from functools import wraps
import time
from typing import List, Tuple, Optional
import numpy as np
# --- Translation Service Configuration ---
# Endpoint of the LibreTranslate-style service used for the "kab*" model codes
# (see get_api_translation_async, which posts form data here).
API_URL = "https://imsidag-community-libretranslate-kabyle.hf.space/translate"
# API key read from the environment; "dummy_key" is used when API_KEY is unset.
API_KEY = os.environ.get("API_KEY", "dummy_key")

# Translation Configuration
TRANSLATION_CONFIG = dict(
    request_timeout=15,    # seconds, used as aiohttp ClientTimeout(total=...)
    max_retries=3,         # attempts made by retry_with_backoff
    retry_delay=1,         # base delay in seconds, doubled on each retry
    max_text_length=5000,  # maximum input length in characters (validate_input)
)
# UI label -> backend code. Used for the "From"/"To" dropdowns; the "To" dropdown is
# fixed to "Taqbaylit (All models)" because every backend in `models` is queried.
langs = {
    "English": "en",
    "French": "fr",
    "Italian": "it",
    "Occitan (26000)": "oc_comp2",
    "Taqbaylit (LTCommunity)": "kab",
    "Taqbaylit (Tasenselkimt)": "kab_comp",
    "Taqbaylit (51000)": "kab_comp2",
    "Taqbaylit (OS)": "kab_os",
    "Taqbaylit (num40000)": "kab_num",
    "Taqbaylit (Google)": "google",
    "Tasuqilt (Tamazight - Taqbaylit)": "ka",
    "Taqbaylit (Wikimedia NLLB200)": "wikimedia",
    "Taqbaylit (All models)": "kab_all",
}

# Ordered (label, code) pairs for every backend queried in parallel. Order matters: the
# first entry is the baseline for diffs and the default TTS source, and translation lists
# are indexed by position in this list. Codes come straight from `langs` so the two
# tables cannot drift apart.
models = [
    (label, langs[label])
    for label in (
        "Taqbaylit (LTCommunity)",
        "Taqbaylit (Tasenselkimt)",
        "Taqbaylit (51000)",
        "Taqbaylit (OS)",
        "Taqbaylit (num40000)",
        "Taqbaylit (Google)",
        "Tasuqilt (Tamazight - Taqbaylit)",
        "Taqbaylit (Wikimedia NLLB200)",
    )
]
class TranslationError(Exception):
    """Raised by a translation backend when its response cannot be used.

    Attributes:
        service: display name of the backend that failed.
        message: human-readable detail (e.g. "HTTP 502").
        error_type: category key matching format_error_message's labels
            ("timeout", "connection", "http", "json", "general").

    ``str(err)`` is ``"<service>: <message>"``.
    """

    def __init__(self, service: str, message: str, error_type: str = "general"):
        super().__init__(f"{service}: {message}")
        self.service, self.message, self.error_type = service, message, error_type
def format_error_message(service: str, error: Exception, error_type: str = "general") -> str:
    """Render a failed backend call as ``"<Label> <service>: <detail>"``.

    Args:
        service: backend display name.
        error: the exception (its ``str()`` is used as the detail).
        error_type: one of "timeout", "connection", "http", "json", "general";
            unknown values fall back to the generic "Error" label.

    Returns:
        The formatted message; the detail is cut to 100 characters plus "..." when longer.
    """
    label = {
        "timeout": "Timeout",
        "connection": "Connection Error",
        "http": "HTTP Error",
        "json": "JSON Error",
        "general": "Error",
    }.get(error_type, "Error")
    detail = str(error)
    if len(detail) > 100:
        detail = f"{detail[:100]}..."
    return f"{label} {service}: {detail}"
def retry_with_backoff(max_retries: int = 3):
    """Retry an async function with exponential backoff; return an error string on final failure.

    The wrapped coroutine is awaited up to ``max_retries`` times (always at least once).
    Between attempts it sleeps ``TRANSLATION_CONFIG["retry_delay"] * 2**attempt`` seconds.
    If every attempt raises, the wrapper does NOT re-raise: it returns a message built by
    ``format_error_message`` so callers can show it in place of a translation.

    The message uses the wrapped call's ``service`` argument (positional, keyword or default).
    It is labelled "Timeout" for asyncio timeouts, "Connection Error" for aiohttp connection
    errors, the ``error_type`` carried by a ``TranslationError``, and "Error" otherwise.

    Args:
        max_retries: total number of attempts; values below 1 are treated as 1.

    Returns:
        A decorator for ``async def`` functions.
    """
    import inspect  # local import: only needed to locate the ``service`` argument

    attempts = max(1, max_retries)

    def decorator(func):
        signature = inspect.signature(func)

        def _service_name(args, kwargs) -> str:
            """Resolve the ``service`` argument of one call, falling back to "Unknown"."""
            try:
                bound = signature.bind(*args, **kwargs)
            except TypeError:
                return "Unknown"
            bound.apply_defaults()
            return str(bound.arguments.get("service", "Unknown"))

        @wraps(func)
        async def wrapper(*args, **kwargs):
            last_exception = None
            for attempt in range(attempts):
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    last_exception = e
                    if attempt < attempts - 1:
                        await asyncio.sleep(TRANSLATION_CONFIG["retry_delay"] * (2 ** attempt))

            # Callers pass ``service`` positionally, so it must be resolved via the signature
            # rather than looked up in kwargs.
            service = _service_name(args, kwargs)
            # aiohttp reports timeouts with asyncio.TimeoutError (and subclasses such as
            # ServerTimeoutError, which is ALSO a ClientConnectionError) -> check timeouts first.
            if isinstance(last_exception, asyncio.TimeoutError):
                return format_error_message(service, last_exception, "timeout")
            if isinstance(last_exception, aiohttp.ClientConnectionError):
                return format_error_message(service, last_exception, "connection")
            if isinstance(last_exception, TranslationError):
                # str(TranslationError) already starts with "<service>: "; pass only the detail
                # so the service name is not repeated.
                return format_error_message(
                    service, Exception(last_exception.message), last_exception.error_type
                )
            return format_error_message(service, last_exception, "general")

        return wrapper

    return decorator
def validate_input(text: str) -> Tuple[bool, str]:
    """Check that *text* is non-blank and within the configured length limit.

    Args:
        text: raw user input; ``None`` is treated as empty.

    Returns:
        ``(True, "")`` when valid, otherwise ``(False, <user-facing message>)``.
    """
    # Gradio can deliver None (e.g. API calls with a null value); treat it like blank input
    # instead of crashing on .strip().
    if text is None or not text.strip():
        return False, "Please enter text to translate"
    max_length = TRANSLATION_CONFIG["max_text_length"]
    if len(text) > max_length:
        return False, f"Text too long (max {max_length} characters)"
    return True, ""
@retry_with_backoff(max_retries=TRANSLATION_CONFIG["max_retries"])
async def get_tasuqilt_translation_async(session: aiohttp.ClientSession, text: str, service: str = "Tasuqilt") -> str:
    """Translate *text* through the Tasuqilt CloudFront endpoint and tidy the result.

    The request body is JSON; the response JSON carries a ``body`` field that is itself a
    JSON string containing ``translation`` ("No result" when absent).

    Cleanup removes:
    - any one of ``( * + - x`` immediately followed by ``)``, with surrounding whitespace;
    - a trailing ``.``, ``)``, ``(`` or whitespace;
    - then any trailing ``(``.

    NOTE(review): the direction is always "en-ka"; the UI's selected source language is
    not forwarded to this backend — confirm whether other directions are supported.

    Raises:
        TranslationError: on a non-200 status (the retry decorator converts final failures
            into an error string).
    """
    request_headers = {
        "User-Agent": "Mozilla/5.0",
        "Content-Type": "application/json; charset=utf-8",
        "Origin": "https://www.tasuqilt.com",
        "Referer": "https://www.tasuqilt.com/",
    }
    request_body = {"translate": "/translate/", "sentence": text, "direction": "en-ka"}
    request_timeout = aiohttp.ClientTimeout(total=TRANSLATION_CONFIG["request_timeout"])

    async with session.post(
        "https://d2sjol2amz2ojp.cloudfront.net/translate",
        json=request_body,
        headers=request_headers,
        timeout=request_timeout,
    ) as resp:
        if resp.status != 200:
            raise TranslationError(service, f"HTTP {resp.status}", "http")
        envelope = await resp.json()

    # The outer JSON wraps the real payload as a string in "body".
    raw_translation = json.loads(envelope["body"]).get("translation", "No result")
    cleaned = re.sub(r'\s*[\(\*\+\-x]\)\s*|\.$|\)$|\($|\s+$', '', raw_translation).strip()
    return re.sub(r'\s*\(\s*$', '', cleaned)
@retry_with_backoff(max_retries=TRANSLATION_CONFIG["max_retries"])
async def get_wikimedia_translation_async(session: aiohttp.ClientSession, text: str, service: str = "Wikimedia") -> str:
    """Translate English *text* to Kabyle with the Wikimedia MinT API (model nllb200-600M).

    The payload is serialised manually and sent with browser-like headers, mirroring a
    captured curl request. Source/target languages are fixed to "en" -> "kab".

    Returns:
        The stripped ``translation`` field, or "No result" when it is missing or empty.

    Raises:
        TranslationError: "http" on a non-200 status, "json" when the body is not JSON.
    """
    endpoint = "https://translate.wmcloud.org/api/translate"
    serialized_body = json.dumps({
        "source_language": "en",
        "target_language": "kab",
        "format": "text",
        "model": "nllb200-600M",
        "content": f"{text}\n ",  # trailing newline + space replicate the curl payload
    })
    # NOTE(review): "br, zstd" are advertised explicitly; aiohttp can only decode them when
    # the matching optional decoders are installed — confirm responses decode in deployment.
    browser_headers = {
        "User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:140.0) Gecko/20100101 Firefox/140.0",
        "Accept": "*/*",
        "Accept-Language": "kab-DZ,kab;q=0.7,en-US;q=0.3",
        "Accept-Encoding": "gzip, deflate, br, zstd",
        "Referer": "https://translate.wmcloud.org/text",
        "Content-Type": "application/json",
        "Origin": "https://translate.wmcloud.org",
        "Sec-GPC": "1",
        "Connection": "keep-alive",
        "Sec-Fetch-Dest": "empty",
        "Sec-Fetch-Mode": "cors",
        "Sec-Fetch-Site": "same-origin",
        "Priority": "u=0",
        "TE": "trailers",
    }
    limit = aiohttp.ClientTimeout(total=TRANSLATION_CONFIG["request_timeout"])

    async with session.post(endpoint, data=serialized_body, headers=browser_headers, timeout=limit) as resp:
        body_text = await resp.text()
        if resp.status != 200:
            raise TranslationError(service, f"HTTP {resp.status}", "http")

    try:
        decoded = json.loads(body_text)
    except json.JSONDecodeError as e:
        raise TranslationError(service, f"JSON parsing error: {e}", "json")
    translated = decoded.get("translation", "No result")
    if not translated:
        return "No result"
    return translated.strip()
@retry_with_backoff(max_retries=TRANSLATION_CONFIG["max_retries"])
async def get_google_translation_async(session: aiohttp.ClientSession, text: str, source_code: str, service: str = "Google") -> str:
    """Translate *text* from *source_code* to "ber-Latn" via the mozhi.pussthecat.org API
    with ``engine=google``.

    Returns:
        The ``translated-text`` field of the JSON response, or "No result" when absent.

    Raises:
        TranslationError: on a non-200 status.
    """
    mozhi_endpoint = "https://mozhi.pussthecat.org/api/translate"
    query = {"engine": "google", "from": source_code, "to": "ber-Latn", "text": text}
    limit = aiohttp.ClientTimeout(total=TRANSLATION_CONFIG["request_timeout"])

    async with session.get(mozhi_endpoint, params=query, timeout=limit) as resp:
        if resp.status != 200:
            raise TranslationError(service, f"HTTP {resp.status}", "http")
        decoded = await resp.json()
    return decoded.get("translated-text", "No result")
@retry_with_backoff(max_retries=TRANSLATION_CONFIG["max_retries"])
async def get_api_translation_async(session: aiohttp.ClientSession, text: str, source_code: str, target_code: str, service: str = "API") -> str:
    """Translate *text* with the LibreTranslate-style service at ``API_URL``.

    Sends form-encoded ``q/source/target/format/alternatives/api_key`` fields; only the
    primary ``translatedText`` of the JSON reply is used.

    Returns:
        ``translatedText``, or "No result" when absent.

    Raises:
        TranslationError: on a non-200 status.
    """
    form_fields = {
        "q": text,
        "source": source_code,
        "target": target_code,
        "format": "text",
        "alternatives": 3,
        "api_key": API_KEY,
    }
    form_headers = {"Content-Type": "application/x-www-form-urlencoded"}
    limit = aiohttp.ClientTimeout(total=TRANSLATION_CONFIG["request_timeout"])

    async with session.post(API_URL, data=form_fields, headers=form_headers, timeout=limit) as resp:
        if resp.status != 200:
            raise TranslationError(service, f"HTTP {resp.status}", "http")
        decoded = await resp.json()
    return decoded.get("translatedText", "No result")
async def get_single_translation_async(session: aiohttp.ClientSession, text: str, source_code: str, target_code: str, service_name: str) -> str:
"""Get translation from a single service asynchronously"""
try:
if target_code == "ka":
return await get_tasuqilt_translation_async(session, text, service_name)
elif target_code == "google":
return await get_google_translation_async(session, text, source_code, service_name)
elif target_code == "wikimedia":
return await get_wikimedia_translation_async(session, text, service_name)
else:
return await get_api_translation_async(session, text, source_code, target_code, service_name)
except Exception as e:
return format_error_message(service_name, e)
async def get_all_translations_async(text: str, source_code: str, progress_callback=None) -> List[str]:
    """Query every backend in ``models`` concurrently over one aiohttp session.

    Args:
        text: text to translate.
        source_code: source-language code forwarded to backends that accept one.
        progress_callback: optional ``callback(fraction_done, description)``, invoked once per
            backend as soon as that backend finishes (fractions rise from 1/N to 1.0).

    Returns:
        One string per entry of ``models``, in ``models`` order: the translation, or a
        message from ``format_error_message`` when that backend failed.
    """
    timeout = aiohttp.ClientTimeout(total=TRANSLATION_CONFIG["request_timeout"])
    total = len(models)
    completed = 0

    async with aiohttp.ClientSession(timeout=timeout) as session:

        async def _translate_one(name: str, code: str) -> str:
            """Run one backend, turn failures into an error string, and report progress."""
            nonlocal completed
            try:
                return await get_single_translation_async(session, text, source_code, code, name)
            except Exception as e:
                return format_error_message(name, e)
            finally:
                # Report on completion of THIS backend instead of after the whole batch,
                # so the progress bar actually advances while requests are in flight.
                completed += 1
                if progress_callback:
                    progress_callback(completed / total, f"Completed {completed}/{total} translations")

        # gather preserves argument order, so results line up with `models`.
        return list(await asyncio.gather(*(_translate_one(name, code) for name, code in models)))
def diff_two_strings(text1: str, text2: str) -> str:
    """Return an HTML fragment highlighting word-level differences from *text1* to *text2*.

    Words (``str.split()``) are aligned with ``difflib.SequenceMatcher``:
    - unchanged words are kept;
    - insertions are shown green;
    - deletions are shown red and struck through;
    - replacements are shown yellow as "old -> new".

    All words are HTML-escaped because the inputs come from remote services and the result
    is rendered as HTML.

    If either input contains "Error" or "Timeout" (labels produced by format_error_message
    for failed backends), a red "diff unavailable" notice is returned instead.
    """
    if any(marker in text for text in (text1, text2) for marker in ("Error", "Timeout")):
        return "<span style='color:red'>Error - diff unavailable</span>"
    text1_words = text1.split()
    text2_words = text2.split()
    matcher = difflib.SequenceMatcher(None, text1_words, text2_words)
    highlighted = []
    for op, i1, i2, j1, j2 in matcher.get_opcodes():
        if op == "equal":
            # Escape unchanged words too: they were previously emitted raw into HTML.
            highlighted.extend(html.escape(word) for word in text1_words[i1:i2])
        elif op == "insert":
            part = " ".join(text2_words[j1:j2])
            highlighted.append(f"<span style='background-color: #d4edda; color: #155724; border-radius: 3px; padding: 1px 3px;'>{html.escape(part)}</span>")  # Green for inserted
        elif op == "delete":
            part = " ".join(text1_words[i1:i2])
            highlighted.append(f"<span style='background-color: #f8d7da; color: #721c24; text-decoration: line-through; border-radius: 3px; padding: 1px 3px;'>{html.escape(part)}</span>")  # Red for deleted
        elif op == "replace":
            deleted_part = " ".join(text1_words[i1:i2])
            inserted_part = " ".join(text2_words[j1:j2])
            highlighted.append(f"<span style='background-color: #fff3cd; color: #856404; border-radius: 3px; padding: 1px 3px;'>{html.escape(deleted_part)} &#8594; {html.escape(inserted_part)}</span>")  # Yellow for replaced
    return " ".join(highlighted)
# --- TTS Model Integration ---
from transformers import VitsModel, AutoTokenizer
import torch
# Configuration for TTS
TTS_CONFIG = {"tts_text_limit": 500}  # longer input is truncated before synthesis

# Hugging Face model id for Kabyle speech synthesis. The model/tokenizer globals below
# start empty and are filled in (once, on success) by load_tts_model().
TTS_MODEL_ID = "facebook/mms-tts-kab"
tts_model = None
tts_tokenizer = None
def load_tts_model():
    """Return the cached ``(model, tokenizer)`` pair, loading ``TTS_MODEL_ID`` on first use.

    The pair lives in the module globals ``tts_model`` / ``tts_tokenizer``. After loading,
    the model is moved to CUDA when ``torch.cuda.is_available()``.

    On any failure both globals are reset to ``None``, a Gradio warning is emitted, and
    ``(None, None)`` is returned. Because the guard only checks ``tts_model is None``, a
    later call will try to load again.
    """
    global tts_model, tts_tokenizer
    if tts_model is not None:
        return tts_model, tts_tokenizer

    print(f"Loading TTS model: {TTS_MODEL_ID}...")
    try:
        tts_model = VitsModel.from_pretrained(TTS_MODEL_ID)
        tts_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_ID)
        has_gpu = torch.cuda.is_available()
        if has_gpu:
            tts_model = tts_model.to("cuda")
        print("TTS model moved to GPU." if has_gpu else "Running TTS model on CPU.")
        print("TTS model loaded successfully.")
    except Exception as exc:
        print(f"Failed to load TTS model: {exc}")
        tts_model = None
        tts_tokenizer = None
        gr.Warning(f"Failed to load TTS model: {exc}. TTS functionality will be disabled.")
    return tts_model, tts_tokenizer
async def generate_kabyle_audio_async(text: str) -> tuple[int, np.ndarray] | None:
    """Synthesize Kabyle speech for *text* with the MMS-TTS model from load_tts_model().

    Text longer than ``TTS_CONFIG["tts_text_limit"]`` characters is truncated. Inputs are
    moved to CUDA when the model lives there.

    Returns:
        ``(sample_rate, waveform)`` where the waveform is a squeezed float32 NumPy array,
        or ``None`` when the model is unavailable, the text is blank, or inference fails.
        Each of those cases prints a diagnostic.
    """
    model, tokenizer = load_tts_model()
    if model is None or tokenizer is None:
        print("TTS model not loaded, cannot generate audio.")
        return None
    if not text.strip():
        print("Empty text for TTS generation.")
        return None

    limit = TTS_CONFIG["tts_text_limit"]
    if len(text) > limit:
        print(f"TTS text too long (max {limit} characters). Truncating text.")
        text = text[:limit]

    encoded = tokenizer(text, return_tensors="pt")
    if torch.cuda.is_available() and model.device.type == 'cuda':
        encoded = {key: tensor.to("cuda") for key, tensor in encoded.items()}

    try:
        with torch.no_grad():
            waveform = model(**encoded).waveform
        samples = waveform.cpu().numpy().squeeze().astype(np.float32)
        return int(model.config.sampling_rate), samples
    except Exception as exc:
        print(f"Error during Kabyle TTS generation: {exc}")
        return None
def generate_kabyle_audio_sync(text: str) -> tuple[int, np.ndarray] | None:
"""Synchronous wrapper for Gradio to call the async TTS function."""
if not text or "Error" in text: # Don't try to synthesize error messages
return None
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
audio_data = loop.run_until_complete(generate_kabyle_audio_async(text))
loop.close()
return audio_data
except Exception as e:
print(f"Sync wrapper error for TTS: {e}")
gr.Error(f"TTS Error: {e}") # Show Gradio error for direct TTS generation
return None
# --- Combined Gradio UI and Logic ---
def _is_error_translation(text: str) -> bool:
    """Return True when *text* carries one of the labels used for failed translations."""
    return any(marker in text for marker in ("Error", "Timeout", "Connection Error"))


def _error_outputs(message: str) -> tuple:
    """Build the 8 UI outputs used when translation cannot run.

    The outputs are:
    - an escaped red error banner;
    - an empty suggestion box;
    - empty translation state;
    - reset comparison dropdowns;
    - a cleared comparison view;
    - no audio;
    - a reset TTS selector.
    """
    model_names = [name for name, _ in models]
    return (
        f"<div style='color:red; padding:10px;'>{html.escape(message)}</div>",
        "",
        [],
        gr.update(choices=model_names, value=None),
        gr.update(choices=model_names, value=None),
        "",
        None,
        gr.update(choices=model_names, value=None),
    )


def _build_comparison_table(translations: List[str]) -> str:
    """Render all translations as an HTML table, diffing each one against the first model's."""
    baseline_name = models[0][0]
    baseline = translations[0]
    th_style = "padding: 8px; border: 1px solid var(--border-color-primary); text-align: left;"
    td_style = "padding: 8px; border: 1px solid #ddd;"
    lines = [
        "<div style='overflow-x: auto;'>",
        "<table style='width: 100%; border-collapse: collapse; color: var(--body-text-color);'>",
        "<thead>",
        "<tr style='background-color: var(--background-fill-secondary);'>",
        f"<th style='{th_style} min-width: 120px;'>Model</th>",
        f"<th style='{th_style} min-width: 300px;'>Translation</th>",
        f"<th style='{th_style} min-width: 300px;'>Differences from {html.escape(baseline_name)}</th>",
        "</tr>",
        "</thead>",
        "<tbody>",
    ]
    for index, ((name, _), translation) in enumerate(zip(models, translations)):
        status = "Error" if _is_error_translation(translation) else "Success"
        if index == 0:
            # The first model is the baseline, so it has nothing to diff against.
            diff_cell = f"<td style='{td_style} color: #666; font-style: italic;'>(This is the baseline for comparison)</td>"
        else:
            diff_cell = f"<td style='{td_style}'>{diff_two_strings(baseline, translation)}</td>"
        lines += [
            "<tr>",
            f"<td style='{td_style}'>{status} {html.escape(name)}</td>",
            f"<td style='{td_style}'>{html.escape(translation)}</td>",
            diff_cell,
            "</tr>",
        ]
    lines.append("</tbody></table></div>")
    return "\n".join(lines)


def translate_and_speak_with_progress(text: str, source_lang: str, progress=gr.Progress()) -> Tuple[str, str, List[str], dict, dict, str, Optional[Tuple[int, np.ndarray]], dict]:
    """Translate *text* with every model, build the comparison table, and speak the first result.

    Args:
        text: user input.
        source_lang: a label from ``langs`` (the "From" dropdown).
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        The 8 outputs wired in the UI, in order:
        - table HTML;
        - suggestion-box text (first translation);
        - list of all translations (state);
        - comparison dropdown 1 update;
        - comparison dropdown 2 update;
        - comparison output (cleared);
        - initial audio (or None);
        - TTS selector dropdown update.
    """
    # 1. Validate input
    is_valid, error_msg = validate_input(text)
    if not is_valid:
        return _error_outputs(error_msg)
    source_code = langs.get(source_lang)
    if source_code is None:
        # Previously a KeyError crash when the dropdown was cleared or sent an unknown label.
        return _error_outputs(f"Unknown source language: {source_lang}")

    progress(0, desc="Starting translations...")

    def update_progress(completed_ratio, desc):
        # Map translation progress onto [0, 0.9] so the bar never moves backwards when the
        # table/TTS step below reports 0.9.
        progress(completed_ratio * 0.9, desc=desc)

    # 2. Get all translations on a private event loop (closed on every path).
    loop = asyncio.new_event_loop()
    try:
        asyncio.set_event_loop(loop)
        translations = loop.run_until_complete(
            get_all_translations_async(text, source_code, update_progress)
        )
    except Exception as e:
        return _error_outputs(f"Critical error during translation: {e}")
    finally:
        loop.close()
    if not translations:
        return _error_outputs("No translation models are configured.")

    progress(0.9, desc="Building comparison table and generating speech...")
    model_name_options = [name for name, _ in models]
    first_model_name = model_name_options[0]
    first_translation = translations[0]  # baseline for diffs and the initial TTS source
    table_html = _build_comparison_table(translations)

    # 3. Speak the first translation, unless that backend failed.
    initial_audio_output = (
        None if _is_error_translation(first_translation)
        else generate_kabyle_audio_sync(first_translation)
    )
    progress(1.0, desc="Translation and Speech complete!")

    return (
        table_html,
        first_translation,  # initial suggestion-box value
        translations,  # stored in state_translations
        gr.update(choices=model_name_options, value=model_name_options[0]),
        gr.update(choices=model_name_options, value=model_name_options[0]),
        "",  # clear comparison_output
        initial_audio_output,
        gr.update(choices=model_name_options, value=first_model_name),
    )
def update_suggestion(translations_list: List[str], idx: int) -> str:
    """Return the translation at position *idx*, or "" when the list is empty/None or *idx*
    is out of range (negative indices are rejected)."""
    if not translations_list or not (0 <= idx < len(translations_list)):
        return ""
    return translations_list[idx]
def compare_selected_translations(
    translations_list: List[str],
    model1_name: str,
    model2_name: str
) -> str:
    """Return an HTML diff between the translations of two models chosen by label.

    Labels are mapped to positions in ``models``, which index *translations_list*.
    Plain-text guidance is returned when no translation has been run, a label is unknown,
    or the list is too short.
    """
    if not translations_list:
        return "Please run a translation first."

    position_of = {}
    for position, (label, _code) in enumerate(models):
        position_of[label] = position
    first_pos = position_of.get(model1_name)
    second_pos = position_of.get(model2_name)
    if first_pos is None or second_pos is None:
        return "Please select valid models for comparison."
    if max(first_pos, second_pos) >= len(translations_list):
        return "Translations not available for selected models."

    if model1_name == model2_name:
        return f"<h3>Comparing {model1_name} with itself:</h3><p>No differences.</p>"
    left, right = translations_list[first_pos], translations_list[second_pos]
    return f"<h3>Differences between {model1_name} and {model2_name}:</h3><p>{diff_two_strings(left, right)}</p>"
def speak_specific_translation(selected_model_name: str, all_translations_list: List[str]) -> Optional[Tuple[int, np.ndarray]]:
    """Synthesize speech for the translation produced by the model labelled *selected_model_name*.

    Args:
        selected_model_name: a label from ``models`` (the TTS selector dropdown).
        all_translations_list: translations in ``models`` order (session state).

    Returns:
        ``(sample_rate, waveform)`` from generate_kabyle_audio_sync, or ``None``. When there
        is nothing to speak, a Gradio warning explains why.
    """
    if not all_translations_list or not selected_model_name:
        gr.Warning("No translations available or model not selected.")
        return None

    model_names = [name for name, _ in models]
    try:
        model_index = model_names.index(selected_model_name)
    except ValueError:
        # gr.Error(...) was previously constructed but never raised, so the user saw nothing.
        # A warning surfaces the message while keeping the Optional return contract.
        gr.Warning(f"Selected model '{selected_model_name}' not found.")
        return None
    if model_index >= len(all_translations_list):
        gr.Warning("Translation text not found for the selected model.")
        return None

    text_to_speak = all_translations_list[model_index]
    if any(marker in text_to_speak for marker in ("Error", "Timeout", "Connection Error")):
        gr.Warning(f"Cannot generate speech for an error translation from '{selected_model_name}'.")
        return None
    return generate_kabyle_audio_sync(text_to_speak)
def submit_suggestion_sync(text: str, suggestion: str, source_lang: str, target_lang: str) -> str:
    """Placeholder suggestion handler: prints the submission and returns a fixed status.

    Nothing is sent or stored yet. A real implementation would forward the suggestion to a
    service that collects corrections.
    """
    log_line = f"Received suggestion: Original='{text}', Suggested='{suggestion}' from {source_lang} to {target_lang}"
    print(log_line)
    return "Suggestion submitted (feature under development)!"
# Build UI
# Custom stylesheet passed to gr.Blocks(css=...). It:
# - caps the container width at 1000px;
# - sizes buttons and dropdowns;
# - styles the HTML comparison table and diff highlight spans;
# - adds a max-width 768px breakpoint that stacks rows/columns and wraps button rows.
# NOTE(review): several selectors (.gradio-html, .gr-row, .gr-column, .gr-button-row, ...)
# are class names that may not exist in the installed Gradio version's DOM — confirm they
# still match before relying on them.
css = """
.gradio-container {
max-width: 1000px !important;
padding: 10px !important;
}
.gr-textbox {
width: 100% !important;
}
.gradio-button {
font-size: 0.9em;
padding: 8px 12px;
}
.gr-dropdown {
width: 100% !important;
}
.gradio-html table {
min-width: 600px;
display: block;
overflow-x: auto;
white-space: nowrap;
}
.gradio-html table td, .gradio-html table th {
padding: 6px 4px !important;
font-size: 0.85em;
}
.gradio-html span[style*='color:orange'],
.gradio-html span[style*='color:red'],
.gradio-html span[style*='background-color'] { /* Added for new diff styles */
font-size: 0.85em;
}
/* Specific styles for diff highlights */
.gr-html > div > p > span {
display: inline-block; /* Allows padding and margin */
margin: 0 1px; /* Small space between highlighted words */
}
@media (max-width: 768px) { /* Adjust breakpoint for tablets and smaller */
.gradio-container {
padding: 5px !important;
}
.gradio-button {
font-size: 0.8em;
padding: 6px 10px;
}
.gradio-html table td, .gradio-html table th {
font-size: 0.75em !important;
}
.gr-row { /* Force rows to stack on small screens */
flex-direction: column !important;
}
.gr-column { /* Make columns take full width when stacked */
width: 100% !important;
}
/* Adjust dropdowns in rows to be full width when stacked */
.gr-dropdown {
width: 100% !important;
}
/* Ensure buttons in a row stack nicely */
.gr-button-row button {
flex-grow: 1; /* Make buttons expand */
margin-bottom: 5px; /* Add some space between stacked buttons */
}
.gr-button-row { /* Apply flex-wrap to rows of buttons */
flex-wrap: wrap;
}
}
"""
# Gradio layout and event wiring. Every component and listener is created inside this
# Blocks context; the resulting `app` is launched in the __main__ guard.
with gr.Blocks(
    title="Enhanced Kabyle Translator & TTS",
    css=css,
    theme=gr.themes.Default()
) as app:
    gr.Markdown("## Enhanced Kabyle Translator & Text-to-Speech")
    gr.Markdown("*Translate and listen to Kabyle translations from multiple models.*")
    # Per-session list of translation strings in `models` order. It is written by
    # translate_and_speak_with_progress and read by the suggestion buttons, the TTS selector
    # handler and the comparison handler.
    state_translations = gr.State(value=[]) # Stores the list of all translated strings
    with gr.Row():
        with gr.Column():
            # NOTE: the "(Max 5000 characters)" hint duplicates TRANSLATION_CONFIG["max_text_length"].
            input_text = gr.Textbox(
                label="Enter Text",
                lines=3,
                placeholder="Type text to translate here... (Max 5000 characters)",
                max_lines=10
            )
            with gr.Row():
                source_lang = gr.Dropdown(
                    list(langs.keys()),
                    label="From",
                    value="English",
                    scale=1
                )
                # Display-only target: every backend in `models` is always queried. Its value is
                # only forwarded to submit_suggestion_sync.
                target_lang = gr.Dropdown(
                    list(langs.keys()),
                    label="To",
                    value="Taqbaylit (All models)",
                    interactive=False,
                    scale=1
                )
            translate_btn = gr.Button("Translate & Speak (Default Model)", variant="primary", size="lg")
    # Receives the comparison table HTML built by translate_and_speak_with_progress.
    table_output = gr.HTML()
    # TTS Output section
    with gr.Accordion("Listen to Translation", open=True):
        # type="numpy" matches the (sample_rate, ndarray) tuples returned by the TTS helpers.
        tts_output_audio = gr.Audio(label="Kabyle Audio", autoplay=True, interactive=False, type="numpy")
        with gr.Row():
            tts_translation_selector = gr.Dropdown(
                label="Select Translation to Speak",
                choices=[], # Will be dynamically populated
                interactive=True,
                scale=2
            )
            speak_selected_btn = gr.Button("Speak Selected", variant="secondary", scale=1)
    with gr.Accordion("Edit and Submit Suggestions", open=False):
        suggestion_box = gr.Textbox(
            label="Suggested Translation (editable)",
            lines=3,
            placeholder="Edit translation here...",
            max_lines=5
        )
        with gr.Row():
            # One "Use <model>" button per backend. gr.State(i) is created per iteration, so each
            # button passes its own index and copies translations[i] into the suggestion box.
            for i, (name, _) in enumerate(models):
                button_label = name.replace("Taqbaylit", "Kabyle").replace("(", "").replace(")", "").strip()
                btn = gr.Button(f"Use {button_label}", size="sm")
                btn.click(
                    fn=update_suggestion,
                    inputs=[state_translations, gr.State(i)],
                    outputs=suggestion_box
                )
        with gr.Row():
            submit_btn = gr.Button("Submit Suggestion", variant="secondary")
            status = gr.Textbox(label="Status", interactive=False, lines=1)
    # New section for dynamic comparison
    gr.Markdown("### Compare Any Two Translations")
    with gr.Row():
        model_dropdown_1 = gr.Dropdown(
            label="Compare Model 1",
            choices=[], # Will be populated after translation
            interactive=True,
            scale=1
        )
        model_dropdown_2 = gr.Dropdown(
            label="Compare Model 2",
            choices=[], # Will be populated after translation
            interactive=True,
            scale=1
        )
    compare_btn = gr.Button("Show Comparison", variant="secondary")
    comparison_output = gr.HTML(label="Comparison Result")
    # Add some usage tips
    with gr.Accordion("Usage Tips", open=False):
        gr.Markdown("""
- **Translate & Speak (Default Model)**: Get translations from multiple models, and hear the first model's translation spoken aloud automatically.
""")
    # Event handlers
    # The order of `outputs` must match the 8-tuple returned by translate_and_speak_with_progress.
    translate_btn.click(
        translate_and_speak_with_progress,
        inputs=[input_text, source_lang],
        outputs=[
            table_output,
            suggestion_box,
            state_translations, # Updated state for all translations
            model_dropdown_1,
            model_dropdown_2,
            comparison_output,
            tts_output_audio, # Initial audio from default model
            tts_translation_selector # Update choices AND value for TTS selector with gr.update()
        ]
    )
    # New click event for speaking selected translation
    speak_selected_btn.click(
        speak_specific_translation,
        inputs=[
            tts_translation_selector, # The chosen model name from the dropdown
            state_translations # The list of all translated texts
        ],
        outputs=tts_output_audio # The audio player to update
    )
    submit_btn.click(
        submit_suggestion_sync,
        inputs=[input_text, suggestion_box, source_lang, target_lang],
        outputs=status
    )
    compare_btn.click(
        compare_selected_translations,
        inputs=[state_translations, model_dropdown_1, model_dropdown_2],
        outputs=comparison_output
    )
if __name__ == "__main__":
    # Warm the TTS cache before serving so the first request does not pay the model-load cost.
    load_tts_model()
    # Listen on all interfaces so the app is reachable from outside the host/container.
    launch_options = {"server_name": "0.0.0.0"}
    app.launch(**launch_options)