Spaces:
Runtime error
Runtime error
from flask import Blueprint, jsonify, request | |
import json | |
import logging | |
import torch | |
import os | |
import tempfile | |
from env_vars import API_LOG_LEVEL | |
# Blueprint grouping all translation/transcription endpoints.
translations_blueprint = Blueprint(
    "translations_blueprint",
    __name__,
)

logger = logging.getLogger(__name__)
# Use setLevel() instead of assigning .level directly: setLevel() also
# clears the logging module's internal level cache, so the new level is
# honored immediately by isEnabledFor() checks.
logger.setLevel(API_LOG_LEVEL)
# Quiet the chatty AWS SDK loggers down to the same configured level.
logging.getLogger("boto3").setLevel(API_LOG_LEVEL)
logging.getLogger("botocore").setLevel(API_LOG_LEVEL)
def get_model():
    """Return the shared model instance owned by the server module.

    The import happens at call time (presumably to dodge a circular
    import with ``server`` — confirm against server.py).
    """
    import server
    return server.get_model()
def get_text_decoder():
    """Return the shared text decoder owned by the server module.

    Deferred import mirrors get_model(); see note there.
    """
    import server
    return server.get_text_decoder()
def get_device():
    """Return the compute device selected by the server module.

    Deferred import mirrors get_model(); see note there.
    """
    import server
    return server.get_device()
def health():
    """Health check endpoint.

    Reports whether the model is loaded, which device is in use, and
    whether CUDA is available. Always reports status "healthy"; callers
    must inspect ``model_loaded`` to know if inference is possible.
    """
    loaded = get_model() is not None
    dev = get_device()
    device_name = str(dev) if dev else "unknown"
    return {
        "status": "healthy",
        "service": "translations",
        "model_loaded": loaded,
        "device": device_name,
        "cuda_available": torch.cuda.is_available(),
    }
def hello():
    """Simple hello world endpoint."""
    greeting = {"message": "Hello from Translations API!"}
    return greeting
def transcribe_audio():
    """Transcribe an uploaded audio file using the MMS model.

    Expects a multipart/form-data request containing an ``audio`` file
    part. The upload is spooled to a temporary .wav on disk so the
    model pipeline can read it, then the temp file is removed.

    Returns:
        200 -- JSON with transcription text, duration, segment count and
               (when available) forced-alignment segments.
        400 -- no audio file supplied / empty filename.
        503 -- model or text decoder not loaded.
        500 -- any unexpected failure during transcription.
    """
    try:
        # Check if model is loaded
        model = get_model()
        text_decoder = get_text_decoder()
        device = get_device()
        if model is None or text_decoder is None:
            return (
                jsonify({"error": "Model not loaded. Please check server logs."}),
                503,
            )
        # Check if audio file is provided
        if "audio" not in request.files:
            return jsonify({"error": "No audio file provided"}), 400
        audio_file = request.files["audio"]
        if audio_file.filename == "":
            return jsonify({"error": "No audio file selected"}), 400
        # Save uploaded file temporarily (delete=False: the model reads
        # the path after the context manager closes the handle).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            audio_file.save(tmp_file.name)
            temp_path = tmp_file.name
        try:
            # Imported lazily so the heavy model module is only loaded on demand.
            from model import transcribe_audio_with_alignment

            # Complete pipeline: transcription plus forced alignment.
            results = transcribe_audio_with_alignment(
                wav_path=temp_path,
                max_duration_seconds=10,
            )
            # Lazy %-args avoid formatting the (potentially large) results
            # dict when INFO logging is disabled.
            logger.info("Transcription with alignment completed: %s", results)
            # Format response with alignment data
            response = {
                "transcription": results.get("transcription", ""),
                "model": "fairseq2-MMS",
                "device": str(device),
                "status": "success",
                "total_duration": results.get("total_duration", 0.0),
                "num_segments": results.get("num_segments", 0),
            }
            # Add alignment information if available
            if results.get("aligned_segments"):
                response["aligned_segments"] = results["aligned_segments"]
                response["alignment_available"] = True
            else:
                response["alignment_available"] = False
                if "alignment_error" in results:
                    response["alignment_error"] = results["alignment_error"]
            return jsonify(response)
        finally:
            # Clean up temporary file
            if os.path.exists(temp_path):
                os.unlink(temp_path)
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception("Transcription error: %s", e)
        return jsonify({"error": f"Transcription failed: {str(e)}"}), 500
def align_transcription():
    """Perform forced alignment of an audio file against given text.

    Expects multipart/form-data with an ``audio`` file part and a
    ``transcription`` form field holding the reference text. The text is
    whitespace-tokenized and aligned against the audio.

    Returns:
        200 -- JSON with aligned segments, total duration and counts.
        400 -- missing audio file or transcription text.
        503 -- model or device not available.
        500 -- any unexpected failure during alignment.
    """
    try:
        # Fetch model/device once up front (previously the model was
        # fetched and None-checked twice in this function).
        model = get_model()
        device = get_device()
        if model is None:
            return (
                jsonify({"error": "Model not loaded. Please check server logs."}),
                503,
            )
        if device is None:
            # Checked before the temp file is created so we never spool
            # an upload we cannot process.
            return jsonify({"error": "Model not available for alignment"}), 503
        # Check if audio file and transcription are provided
        if "audio" not in request.files:
            return jsonify({"error": "No audio file provided"}), 400
        audio_file = request.files["audio"]
        if audio_file.filename == "":
            return jsonify({"error": "No audio file selected"}), 400
        # Get transcription text from form data
        transcription = request.form.get("transcription", "").strip()
        if not transcription:
            return jsonify({"error": "No transcription text provided"}), 400
        # Save uploaded file temporarily (delete=False: the alignment
        # code reads the path after the handle is closed).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            audio_file.save(tmp_file.name)
            temp_path = tmp_file.name
        try:
            # Imported lazily so the heavy model module is only loaded on demand.
            from model import perform_forced_alignment

            # Whitespace tokenization of the reference text.
            tokens = transcription.split()
            # Perform forced alignment
            aligned_segments = perform_forced_alignment(
                temp_path, tokens, model, device
            )
            # Total duration is the end time of the last segment.
            total_duration = aligned_segments[-1]["end"] if aligned_segments else 0.0
            logger.info(
                "Forced alignment completed: %d segments", len(aligned_segments)
            )
            return jsonify(
                {
                    "transcription": transcription,
                    "aligned_segments": aligned_segments,
                    "total_duration": total_duration,
                    "num_segments": len(aligned_segments),
                    "status": "success",
                }
            )
        finally:
            # Clean up temporary file
            if os.path.exists(temp_path):
                os.unlink(temp_path)
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception("Alignment error: %s", e)
        return jsonify({"error": f"Alignment failed: {str(e)}"}), 500
def translate():
    """Main translation endpoint using fairseq2.

    Expects a JSON body with a required ``text`` field and optional
    ``source_lang`` / ``target_lang`` (defaulting to "en" / "es").

    Returns:
        200 -- JSON with the (currently placeholder) translation.
        400 -- missing/invalid JSON body or missing ``text`` field.
        503 -- model not loaded.
        500 -- any unexpected failure.
    """
    try:
        # Check if model is loaded
        model = get_model()
        if model is None:
            return (
                jsonify({"error": "Model not loaded. Please check server logs."}),
                503,
            )
        # silent=True: malformed/absent JSON yields None and is handled
        # by the 400 branch below instead of raising BadRequest, which
        # the broad except would otherwise convert into a 500.
        data = request.get_json(silent=True)
        # Validate input
        if not data or "text" not in data:
            return jsonify({"error": "Missing 'text' field in request"}), 400
        text = data["text"]
        source_lang = data.get("source_lang", "en")
        target_lang = data.get("target_lang", "es")
        logger.info(
            "Translation request: %s -> %s: '%s'", source_lang, target_lang, text
        )
        logger.info("Model loaded: %s", model is not None)
        # Use fairseq2 model for translation
        # TODO: Implement actual model inference for translation here
        translation = f"[fairseq2-MMS] Translation of '{text}' from {source_lang} to {target_lang}"
        result = {
            "original_text": text,
            "translated_text": translation,
            "source_language": source_lang,
            "target_language": target_lang,
            "model": "fairseq2-MMS",
            "model_loaded": model is not None,
        }
        logger.info("Translation completed: %s", result)
        return jsonify(result)
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception("Translation error: %s", e)
        return jsonify({"error": f"Translation failed: {str(e)}"}), 500
def list_models():
    """List available translation models.

    Currently a single-entry catalog; both ``available`` and ``loaded``
    reflect whether the shared model instance exists.
    """
    is_loaded = get_model() is not None
    catalog = {
        "available_models": [
            {
                "name": "fairseq2-MMS",
                "description": "fairseq2 MMS (Massively Multilingual Speech) speech-to-text model",
                "available": is_loaded,
                "loaded": is_loaded,
                "capabilities": ["speech-to-text", "transcription"],
            }
        ],
        "default_model": "fairseq2-MMS",
    }
    return jsonify(catalog)
def supported_languages():
    """Get list of supported languages.

    Returns a fixed sample of common languages; the underlying MMS
    model claims support for far more (see the ``note`` field).
    """
    language_table = [
        ("en", "English"),
        ("es", "Spanish"),
        ("fr", "French"),
        ("de", "German"),
        ("it", "Italian"),
        ("pt", "Portuguese"),
        ("ar", "Arabic"),
        ("zh", "Chinese"),
        ("ja", "Japanese"),
        ("ko", "Korean"),
    ]
    return jsonify(
        {
            "supported_languages": [
                {"code": code, "name": name} for code, name in language_table
            ],
            "note": "MMS model supports 1143+ languages",
        }
    )
def test_model():
    """Test endpoint to verify model functionality.

    Reports model/decoder/device readiness without running inference.
    """
    try:
        mdl = get_model()
        decoder = get_text_decoder()
        dev = get_device()
        # Guard clause: nothing to report if the model never loaded.
        if mdl is None:
            return jsonify({"error": "Model not loaded"}), 503
        status_payload = {
            "model_loaded": True,
            "device": str(dev),
            "model_type": str(type(mdl).__name__),
            "text_decoder_available": decoder is not None,
            "cuda_available": torch.cuda.is_available(),
            "status": "Model is ready for inference",
        }
        return jsonify(status_payload)
    except Exception as e:
        logger.error(f"Model test error: {str(e)}")
        return jsonify({"error": f"Model test failed: {str(e)}"}), 500