# Captured from a Hugging Face Spaces page; the Space status read: "Runtime error".
import gradio as gr | |
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from transformers import pipeline | |
import langdetect | |
import logging | |
import os | |
from typing import Optional | |
import re | |
from functools import lru_cache | |
import asyncio | |
import threading | |
import time | |
# Local working directories for the model cache and for logs (no-op if present).
os.makedirs("./cache", exist_ok=True)
os.makedirs("./logs", exist_ok=True)

# Point the Hugging Face caches at ./cache.
# NOTE(review): `transformers` is imported at the top of this file, before these
# variables are set; confirm they still take effect for the installed library
# versions (cache locations may be resolved at import time).
os.environ.update({"HF_HOME": "./cache", "TRANSFORMERS_CACHE": "./cache"})

# Pipeline device index; -1 means CPU (this Space always runs on CPU).
DEVICE = -1
# Longest input (in characters) accepted by `translate`; overridable via env.
MAX_TEXT_LENGTH = int(os.environ.get("MAX_TEXT_LENGTH", "5000"))

# Process-wide logging: INFO and above, with timestamp / logger / level prefix.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Source-language code -> Hugging Face model id of its "<lang> to English" model.
# Only these codes are ever translated; everything else is treated as English.
MODEL_MAP = dict(
    th="Helsinki-NLP/opus-mt-th-en",
    ja="Helsinki-NLP/opus-mt-ja-en",
    zh="Helsinki-NLP/opus-mt-zh-en",
    vi="Helsinki-NLP/opus-mt-vi-en",
)

# Terms swapped for placeholders before translation and put back afterwards,
# so the model cannot translate or mangle them.
PROTECTED_TERMS = [
    "2030 Aspirations",
    "Griffith",
]

# Lazily filled cache: language code -> loaded translation pipeline.
translators: dict = {}
# Pydantic models
class TranslationRequest(BaseModel):
    """JSON body accepted by the translation endpoint."""

    # The text to translate into English.
    text: str
    # Optional source-language code; it is only honoured when it is one of the
    # MODEL_MAP keys, otherwise the language is auto-detected.
    source_lang_override: Optional[str] = None


class TranslationResponse(BaseModel):
    """Result of a translation request."""

    # The English output (or the unchanged input when it was already English).
    translated_text: str
    # The language code that was used as the source.
    source_language: Optional[str] = None


# FastAPI app
app = FastAPI(title="Translation Service API")
# Serializes first-time model loads so that concurrent callers (e.g. requests
# handled in worker threads) do not build the same pipeline twice. Loads for
# different languages are serialized too, which also caps peak memory.
_translators_lock = threading.Lock()


def get_translator(lang: str):
    """Return the translation pipeline for *lang*, loading and caching it on first use.

    Args:
        lang: A key of MODEL_MAP.

    Returns:
        The cached `transformers` translation pipeline for that language.

    Raises:
        KeyError: if *lang* has no entry in MODEL_MAP (logged, then re-raised).
        Exception: whatever `pipeline(...)` raises when the model cannot be
            loaded (logged with traceback, then re-raised).
    """
    # Fast path: no locking once the model is cached.
    cached = translators.get(lang)
    if cached is not None:
        return cached

    with _translators_lock:
        # Re-check under the lock: another thread may have finished loading
        # this language while we were waiting.
        if lang not in translators:
            logger.info(f"Loading model for {lang}...")
            try:
                translators[lang] = pipeline(
                    "translation",
                    model=MODEL_MAP[lang],
                    device=DEVICE,  # shared module setting (CPU) instead of a literal
                )
            except Exception:
                logger.exception(f"Failed to load model for {lang}")
                raise
            logger.info(f"Model for {lang} loaded successfully.")
        return translators[lang]
@lru_cache(maxsize=256)
def detect_language(text: str) -> str:
    """Detect the source language of *text*; results are cached per distinct input.

    Returns:
        "zh" for any detected code starting with "zh", the detected code when it
        is a key of MODEL_MAP, and "en" for everything else, including when
        detection raises (the failure is logged as a warning).
    """
    # langdetect is randomized unless its factory seed is fixed. Pinning it makes
    # the same text always map to the same language, which also keeps the cache
    # consistent. Re-assigning the same value on every call is harmless.
    langdetect.DetectorFactory.seed = 0
    try:
        detected_lang = langdetect.detect(text)
    except Exception as e:
        logger.warning(f"Language detection failed: {str(e)}")
        return "en"
    # Collapse Chinese variants onto the single Chinese model.
    if detected_lang.startswith('zh'):
        return 'zh'
    return detected_lang if detected_lang in MODEL_MAP else "en"
def protect_terms(text: str, protected_terms: list) -> tuple[str, dict]:
    """Replace protected terms with placeholders so the model leaves them alone.

    Each term is matched case-insensitively, and any run of whitespace is
    accepted where the term contains a space. A match must not be directly
    preceded or followed by an ASCII letter or digit. So "Griffith" is not found
    inside "Griffiths", but it is found in "Griffith大学", where a regex word
    boundary would fail because CJK characters count as word characters.

    Args:
        text: Text about to be translated.
        protected_terms: Terms to protect; blank entries are skipped.

    Returns:
        (modified_text, replacements):
        - modified_text: the text with every match replaced by
          "PROTECTEDTERM{i}PLACEHOLDER", where i is the term's index in
          *protected_terms*.
        - replacements: a placeholder -> original term mapping containing only
          the terms that were actually found.
    """
    log = logging.getLogger(__name__)  # same logger object as the module-level `logger`
    modified_text = text
    replacements = {}
    for i, term in enumerate(protected_terms):
        words = term.split()
        if not words:
            # An empty pattern would match at every position.
            continue
        placeholder = f"PROTECTEDTERM{i}PLACEHOLDER"
        core = r"\s+".join(re.escape(word) for word in words)
        pattern = re.compile(
            r"(?<![A-Za-z0-9])" + core + r"(?![A-Za-z0-9])",
            re.IGNORECASE,
        )
        modified_text, count = pattern.subn(placeholder, modified_text)
        if count:
            replacements[placeholder] = term
            log.debug(f"Protected term '{term}' replaced with '{placeholder}'")
    return modified_text, replacements
def restore_terms(text: str, replacements: dict) -> str:
    """Put protected terms back into translated text.

    For each placeholder -> original term pair in *replacements*:
      1. Exact occurrences of the placeholder are replaced.
      2. Otherwise, a tolerant match is tried. It ignores letter case and
         allows whitespace, dots or commas between the placeholder's
         characters, since models sometimes re-case or split an unknown token.
         Only that term's own placeholder is matched, and punctuation around
         it is kept.

    Afterwards, runs of four or more dots (optionally whitespace-separated,
    together with any whitespace before them) are removed. Whitespace is then
    collapsed to single spaces and the result is stripped.

    Args:
        text: Model output that may contain placeholders.
        replacements: Mapping of placeholder -> original term, as produced by
            protect_terms.

    Returns:
        The text with placeholders restored and whitespace normalized.
    """
    log = logging.getLogger(__name__)  # same logger object as the module-level `logger`
    restored_text = text
    for placeholder, original_term in replacements.items():
        if not placeholder:
            continue
        if placeholder in restored_text:
            restored_text = restored_text.replace(placeholder, original_term)
            log.debug(f"Restored '{placeholder}' to '{original_term}'")
            continue
        # Tolerant match, specific to THIS placeholder: because the index digit
        # is part of the pattern, another term's placeholder is never consumed.
        fuzzy = re.compile(r"[\s.,]*".join(map(re.escape, placeholder)), re.IGNORECASE)
        # A function replacement keeps backslashes in the term from being
        # interpreted as group references.
        restored_text, count = fuzzy.subn(lambda _m: original_term, restored_text)
        if count:
            log.debug(f"Fuzzy restored corrupted '{placeholder}' to '{original_term}'")
    # Clean up leftover artifacts: long dot runs, then extra whitespace.
    restored_text = re.sub(r'\s*\.\s*\.\s*\.\s*\.+', '', restored_text)
    restored_text = re.sub(r'\s+', ' ', restored_text)
    return restored_text.strip()
# FastAPI endpoints
# The route decorators register these handlers on `app`. Without them none of
# the endpoints (including the documented GET /health and POST /translate)
# would be served.
@app.get("/")
async def root():
    """Return a fixed message confirming the API is up."""
    return {"message": "Translation Service API is running"}


@app.get("/health")
async def health_check():
    """Report service health and the language codes that have a model."""
    return {"status": "healthy", "supported_languages": list(MODEL_MAP.keys())}


@app.post("/translate", response_model=TranslationResponse)
async def translate_api(request: TranslationRequest):
    """API endpoint for translation: accepts a TranslationRequest JSON body."""
    return await translate(request.text, request.source_lang_override)
# Core translation function
async def translate(text: str, source_lang_override: Optional[str] = None):
    """Translate *text* into English; shared by the REST API and the Gradio UI.

    Args:
        text: Input text; must be non-blank and at most MAX_TEXT_LENGTH characters.
        source_lang_override: Optional language code. Used only when it is a key
            of MODEL_MAP; otherwise the language is auto-detected.

    Returns:
        TranslationResponse with the translated text and the source language
        used. When the language resolves to "en", the input is returned unchanged.

    Raises:
        HTTPException:
            - 400 for blank input.
            - 413 for input longer than MAX_TEXT_LENGTH.
            - 500 for any failure while detecting, loading the model, or
              translating.
    """
    if not text or not text.strip():
        raise HTTPException(status_code=400, detail="Text input is required.")
    if len(text) > MAX_TEXT_LENGTH:
        raise HTTPException(
            status_code=413,
            detail=f"Text too long. Max allowed length: {MAX_TEXT_LENGTH}."
        )
    try:
        # Determine source language
        if source_lang_override and source_lang_override in MODEL_MAP:
            source_lang = source_lang_override
        else:
            source_lang = detect_language(text)

        # Nothing to do for English input.
        if source_lang == "en":
            return TranslationResponse(
                translated_text=text,
                source_language=source_lang
            )

        # Model loading and inference are blocking, CPU-heavy calls. Run them
        # in a worker thread so the event loop (and other requests such as
        # /health) is not stalled while a translation is in progress.
        translator = await asyncio.to_thread(get_translator, source_lang)

        # Protect terms before translation
        modified_text, replacements = protect_terms(text, PROTECTED_TERMS)
        logger.debug(f"Original text: '{text}'")
        logger.debug(f"Modified text: '{modified_text}'")
        logger.debug(f"Replacements: {replacements}")

        # Conservative decoding: small beam, no sampling, no repeated bigrams.
        # NOTE(review): max_length=512 caps the output length while
        # MAX_TEXT_LENGTH allows 5000 input characters; long inputs may be
        # truncated or rejected by the model. Consider chunking — TODO confirm.
        result = await asyncio.to_thread(
            translator,
            modified_text,
            max_length=512,
            num_beams=2,
            do_sample=False,
            early_stopping=True,
            no_repeat_ngram_size=2,
        )
        translated_text = result[0]["translation_text"]
        logger.debug(f"Raw translation: '{translated_text}'")

        # Restore protected terms
        final_text = restore_terms(translated_text, replacements)
        logger.debug(f"Final text after restoration: '{final_text}'")
        return TranslationResponse(
            translated_text=final_text,
            source_language=source_lang
        )
    except Exception as e:
        # logger.exception keeps the traceback; `from e` preserves the cause.
        logger.exception(f"Translation error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}") from e
# Gradio interface functions
def translate_gradio(text: str, source_lang: str = "auto"):
    """Gradio callback: translate *text* and return two display strings.

    Errors are reported as text in the output boxes instead of being raised, so
    the UI always receives a (translation, language) pair.

    Args:
        text: Contents of the input textbox (empty or None is handled).
        source_lang: Dropdown value; "auto" means detect the language.

    Returns:
        (translated_text, source_language) on success;
        ("Error: ...", "Error") on failure;
        ("Please enter some text to translate.", "N/A") for blank input.
    """
    if not text or not text.strip():
        return "Please enter some text to translate.", "N/A"
    source_lang_param = None if source_lang == "auto" else source_lang
    try:
        # asyncio.run creates a fresh event loop, runs the coroutine, and closes
        # the loop (and its default executor) afterwards, so nothing leaks per
        # call. Like run_until_complete, it assumes no loop is already running
        # in the calling thread.
        result = asyncio.run(translate(text, source_lang_param))
    except HTTPException as e:
        return f"Error: {e.detail}", "Error"
    except Exception as e:
        return f"Error: {str(e)}", "Error"
    return result.translated_text, result.source_language or "Unknown"
# Create Gradio interface
def create_gradio_interface():
    """Build and return the Gradio Blocks UI for the translation service.

    Layout:
    - An input column: text box, source-language dropdown, Translate button.
    - An output column: translation box and detected-language box.
    - A row of clickable examples.
    - A collapsible "About" section.

    Both the button click and the text box's submit event call
    `translate_gradio`. The returned Blocks object is not queued or launched.
    """
    # (label, value) pairs for the source-language dropdown.
    source_language_choices = [
        ("🔍 Auto-detect", "auto"),
        ("🇹🇭 Thai", "th"),
        ("🇯🇵 Japanese", "ja"),
        ("🇨🇳 Chinese", "zh"),
        ("🇻🇳 Vietnamese", "vi")
    ]
    # Example rows: [input text, source-language value].
    sample_inputs = [
        ["สวัสดีครับ ยินดีที่ได้รู้จัก การพัฒนา 2030 Aspirations เป็นเป้าหมายสำคัญ", "th"],
        ["ฉันเลือกทานอาหารที่ดีต่อสุขภาพร่างกายเพื่อเป็นส่วนหนึ่งในการสนับสนุน 2030 Aspirations", "th"],
        ["こんにちは、はじめまして。Griffith大学での研究が進んでいます。", "ja"],
        ["你好,很高兴认识你。我们正在为2030 Aspirations制定计划。", "zh"],
        ["Xin chào, rất vui được gặp bạn. Griffith là trường đại học tuyệt vời.", "vi"],
    ]

    with gr.Blocks(
        title="Multi-Language Translation Service",
        theme=gr.themes.Soft(),
        css="""
.gradio-container {
max-width: 1200px !important;
}
"""
    ) as interface:
        gr.Markdown("""
# 🌐 Multi-Language Translation Service
Translate text from **Thai**, **Japanese**, **Chinese**, or **Vietnamese** to **English**
✨ Features: Automatic language detection • Protected terms preservation • Fast Helsinki-NLP models
""")

        with gr.Row():
            # Left column: inputs.
            with gr.Column(scale=1):
                text_input = gr.Textbox(
                    label="📝 Input Text",
                    placeholder="Enter text to translate...",
                    lines=6,
                    max_lines=10
                )
                with gr.Row():
                    lang_dropdown = gr.Dropdown(
                        choices=source_language_choices,
                        value="auto",
                        label="Source Language"
                    )
                translate_btn = gr.Button(
                    "🚀 Translate",
                    variant="primary",
                    size="lg"
                )
            # Right column: read-only outputs.
            with gr.Column(scale=1):
                output_text = gr.Textbox(
                    label="🎯 Translation Result",
                    lines=6,
                    max_lines=10,
                    interactive=False
                )
                detected_lang = gr.Textbox(
                    label="🔍 Detected Language",
                    interactive=False,
                    max_lines=1
                )

        # Examples section
        with gr.Row():
            gr.Examples(
                examples=sample_inputs,
                inputs=[text_input, lang_dropdown],
                outputs=[output_text, detected_lang],
                fn=translate_gradio,
                cache_examples=False,
                label="📋 Try these examples:"
            )

        # Event handlers: the button and text-box submit share one callback
        # (registered in that order).
        for register_listener in (translate_btn.click, text_input.submit):
            register_listener(
                fn=translate_gradio,
                inputs=[text_input, lang_dropdown],
                outputs=[output_text, detected_lang]
            )

        # Information accordion
        with gr.Accordion("ℹ️ About this service", open=False):
            gr.Markdown("""
### 🎯 Supported Languages:
- **Thai (th)** → English
- **Japanese (ja)** → English
- **Chinese (zh)** → English
- **Vietnamese (vi)** → English
### 🛡️ Special Features:
- **Protected Terms**: Certain terms like "2030 Aspirations" and "Griffith" are preserved during translation
- **Auto Detection**: Automatically detects the source language if not specified
- **Fast Processing**: Uses optimized Helsinki-NLP translation models
### 🚀 How to use:
1. Paste or type your text in the input box
2. Choose source language or leave as 'Auto-detect'
3. Click 'Translate' or press Enter
4. Get your English translation instantly!
### 🔧 API Access:
This service also provides REST API endpoints:
- `GET /health` - Check service status
- `POST /translate` - Translate text (JSON payload required)
""")

    return interface
# Start FastAPI in a separate thread
def start_fastapi():
    """Serve the FastAPI `app` with uvicorn on 0.0.0.0:7860 at log level "info".

    Used as the target of the background API thread. uvicorn is imported here
    rather than at module top, so it is only loaded when the server starts.
    """
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=7860, log_level="info")
# Main execution
if __name__ == "__main__":
    # The REST API (start_fastapi, port 7860) runs in a daemon thread so it
    # ends together with the main process.
    fastapi_thread = threading.Thread(target=start_fastapi, daemon=True)
    fastapi_thread.start()

    # Crude start-up wait: give uvicorn a moment before Gradio starts.
    time.sleep(2)

    # NOTE(review): the Gradio UI listens on 7861 while the API holds 7860. If
    # the hosting platform exposes only one port, confirm the UI is reachable.
    demo = create_gradio_interface()
    demo.queue(max_size=10)  # at most 10 pending requests
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7861,
        share=False,
        show_error=True,
    )
    demo.launch(**launch_options)