# mms-transcription/server/translations_blueprint.py
# Author: EC2 Default User
# Initial Transcription Commit (38818c3)
from flask import Blueprint, jsonify, request
import json
import logging
import torch
import os
import tempfile
from env_vars import API_LOG_LEVEL
# Blueprint grouping all transcription/alignment/translation endpoints.
translations_blueprint = Blueprint(
    "translations_blueprint",
    __name__,
)

logger = logging.getLogger(__name__)
# Use setLevel() rather than assigning logger.level directly: setLevel()
# also invalidates logging's internal effective-level cache.
logger.setLevel(API_LOG_LEVEL)
# Quiet the noisy AWS SDK loggers down to the same configured level.
logging.getLogger("boto3").setLevel(API_LOG_LEVEL)
logging.getLogger("botocore").setLevel(API_LOG_LEVEL)
def get_model():
    """Return the shared MMS model held by the server module.

    The import is deferred into the function body to avoid a circular
    import between this blueprint and the server module at load time.
    """
    from server import get_model as _server_get_model

    return _server_get_model()
def get_text_decoder():
    """Return the shared text decoder held by the server module.

    Deferred import avoids a circular dependency with the server module.
    """
    from server import get_text_decoder as _server_get_text_decoder

    return _server_get_text_decoder()
def get_device():
    """Return the compute device selected by the server module.

    Deferred import avoids a circular dependency with the server module.
    """
    from server import get_device as _server_get_device

    return _server_get_device()
@translations_blueprint.route("/health")
def health():
    """Liveness/readiness probe: reports model and device status."""
    current_model = get_model()
    current_device = get_device()
    device_label = str(current_device) if current_device else "unknown"
    return {
        "status": "healthy",
        "service": "translations",
        "model_loaded": current_model is not None,
        "device": device_label,
        "cuda_available": torch.cuda.is_available(),
    }
@translations_blueprint.route("/hello")
def hello():
    """Static greeting endpoint, useful as a trivial smoke test."""
    greeting = {"message": "Hello from Translations API!"}
    return greeting
@translations_blueprint.route("/transcribe", methods=["POST"])
def transcribe_audio():
    """Transcribe an uploaded audio file with the MMS model.

    Expects a multipart/form-data POST with the audio under the "audio"
    field. Runs the full transcription + forced-alignment pipeline from
    model.py and returns the transcript plus alignment data when the
    alignment step produced segments.

    Returns:
        200 with transcription JSON on success,
        400 when no audio file is supplied,
        503 when the model/decoder is not loaded,
        500 on any unexpected failure.
    """
    try:
        # Model, decoder and device objects are owned by the server module.
        model = get_model()
        text_decoder = get_text_decoder()
        device = get_device()
        if model is None or text_decoder is None:
            return (
                jsonify({"error": "Model not loaded. Please check server logs."}),
                503,
            )

        # Validate the multipart upload.
        if "audio" not in request.files:
            return jsonify({"error": "No audio file provided"}), 400
        audio_file = request.files["audio"]
        if audio_file.filename == "":
            return jsonify({"error": "No audio file selected"}), 400

        # Reserve a temp .wav path first, then do ALL work — including the
        # save itself — inside try/finally so the file is always removed,
        # even when save() fails (the original leaked the file in that case).
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        temp_path = tmp_file.name
        tmp_file.close()
        try:
            audio_file.save(temp_path)

            # Deferred import: model.py is heavy and pulls in the model stack.
            from model import transcribe_audio_with_alignment

            # Complete pipeline: transcription plus forced alignment.
            results = transcribe_audio_with_alignment(
                wav_path=temp_path,
                max_duration_seconds=10,
            )
            # Lazy %-style args avoid formatting when INFO is disabled.
            logger.info("Transcription with alignment completed: %s", results)

            response = {
                "transcription": results.get("transcription", ""),
                "model": "fairseq2-MMS",
                "device": str(device),
                "status": "success",
                "total_duration": results.get("total_duration", 0.0),
                "num_segments": results.get("num_segments", 0),
            }
            # Alignment is best-effort: expose segments when present,
            # otherwise report why alignment was unavailable.
            if results.get("aligned_segments"):
                response["aligned_segments"] = results["aligned_segments"]
                response["alignment_available"] = True
            else:
                response["alignment_available"] = False
                if "alignment_error" in results:
                    response["alignment_error"] = results["alignment_error"]
            return jsonify(response)
        finally:
            # Clean up the temporary file no matter how the pipeline exited.
            if os.path.exists(temp_path):
                os.unlink(temp_path)
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Transcription error: %s", e)
        return jsonify({"error": f"Transcription failed: {str(e)}"}), 500
@translations_blueprint.route("/align", methods=["POST"])
def align_transcription():
    """Force-align a provided transcription against uploaded audio.

    Expects a multipart/form-data POST with the audio under "audio" and
    the reference text under the "transcription" form field. Tokenizes
    the text on whitespace and runs perform_forced_alignment from
    model.py.

    Returns:
        200 with per-token aligned segments on success,
        400 when audio or transcription text is missing,
        503 when the model/device is not available,
        500 on any unexpected failure.
    """
    try:
        # Fetch model and device ONCE up front (the original called
        # get_model() a second time inside the inner try).
        model = get_model()
        device = get_device()
        if model is None:
            return (
                jsonify({"error": "Model not loaded. Please check server logs."}),
                503,
            )
        if device is None:
            return jsonify({"error": "Model not available for alignment"}), 503

        # Validate the multipart upload.
        if "audio" not in request.files:
            return jsonify({"error": "No audio file provided"}), 400
        audio_file = request.files["audio"]
        if audio_file.filename == "":
            return jsonify({"error": "No audio file selected"}), 400

        # Reference text comes from the form data, not the JSON body.
        transcription = request.form.get("transcription", "").strip()
        if not transcription:
            return jsonify({"error": "No transcription text provided"}), 400

        # Reserve a temp .wav path, then do the save inside try/finally so
        # a failed save still triggers cleanup (the original leaked here).
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        temp_path = tmp_file.name
        tmp_file.close()
        try:
            audio_file.save(temp_path)

            # Deferred import: model.py pulls in the heavy model stack.
            from model import perform_forced_alignment

            # Whitespace tokenization matches what the aligner expects
            # per the original implementation.
            tokens = transcription.split()

            aligned_segments = perform_forced_alignment(
                temp_path, tokens, model, device
            )

            # Total duration = end time of the last aligned segment.
            total_duration = aligned_segments[-1]["end"] if aligned_segments else 0.0
            logger.info(
                "Forced alignment completed: %d segments", len(aligned_segments)
            )
            return jsonify(
                {
                    "transcription": transcription,
                    "aligned_segments": aligned_segments,
                    "total_duration": total_duration,
                    "num_segments": len(aligned_segments),
                    "status": "success",
                }
            )
        finally:
            # Clean up the temporary file no matter how alignment exited.
            if os.path.exists(temp_path):
                os.unlink(temp_path)
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Alignment error: %s", e)
        return jsonify({"error": f"Alignment failed: {str(e)}"}), 500
@translations_blueprint.route("/translate", methods=["POST"])
def translate():
    """Translate text (placeholder — model inference not yet implemented).

    Expects a JSON body with a required "text" field and optional
    "source_lang" (default "en") and "target_lang" (default "es").

    Returns:
        200 with a placeholder translation payload,
        400 when the body is missing/invalid JSON or lacks "text",
        503 when the model is not loaded,
        500 on any unexpected failure.
    """
    try:
        model = get_model()
        if model is None:
            return (
                jsonify({"error": "Model not loaded. Please check server logs."}),
                503,
            )

        # silent=True makes malformed JSON (or a wrong Content-Type) yield
        # None instead of raising BadRequest, which the broad handler below
        # would otherwise convert into a misleading 500.
        data = request.get_json(silent=True)

        # Validate input.
        if not data or "text" not in data:
            return jsonify({"error": "Missing 'text' field in request"}), 400

        text = data["text"]
        source_lang = data.get("source_lang", "en")
        target_lang = data.get("target_lang", "es")

        logger.info(
            "Translation request: %s -> %s: '%s'", source_lang, target_lang, text
        )
        logger.info("Model loaded: %s", model is not None)

        # Use fairseq2 model for translation
        # TODO: Implement actual model inference for translation here
        translation = f"[fairseq2-MMS] Translation of '{text}' from {source_lang} to {target_lang}"

        result = {
            "original_text": text,
            "translated_text": translation,
            "source_language": source_lang,
            "target_language": target_lang,
            "model": "fairseq2-MMS",
            "model_loaded": model is not None,
        }
        logger.info("Translation completed: %s", result)
        return jsonify(result)
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Translation error: %s", e)
        return jsonify({"error": f"Translation failed: {str(e)}"}), 500
@translations_blueprint.route("/models", methods=["GET"])
def list_models():
    """Describe the models this service exposes and their load state."""
    is_loaded = get_model() is not None
    catalog = {
        "available_models": [
            {
                "name": "fairseq2-MMS",
                "description": "fairseq2 MMS (Massively Multilingual Speech) speech-to-text model",
                "available": is_loaded,
                "loaded": is_loaded,
                "capabilities": ["speech-to-text", "transcription"],
            }
        ],
        "default_model": "fairseq2-MMS",
    }
    return jsonify(catalog)
@translations_blueprint.route("/languages", methods=["GET"])
def supported_languages():
    """Return a representative subset of languages supported by MMS."""
    language_pairs = [
        ("en", "English"),
        ("es", "Spanish"),
        ("fr", "French"),
        ("de", "German"),
        ("it", "Italian"),
        ("pt", "Portuguese"),
        ("ar", "Arabic"),
        ("zh", "Chinese"),
        ("ja", "Japanese"),
        ("ko", "Korean"),
    ]
    return jsonify(
        {
            "supported_languages": [
                {"code": code, "name": name} for code, name in language_pairs
            ],
            "note": "MMS model supports 1143+ languages",
        }
    )
@translations_blueprint.route("/test_model", methods=["GET"])
def test_model():
    """Report whether the model stack is loaded and ready for inference."""
    try:
        current_model = get_model()
        decoder = get_text_decoder()
        current_device = get_device()

        # Guard clause: nothing to report without a loaded model.
        if current_model is None:
            return jsonify({"error": "Model not loaded"}), 503

        status_payload = {
            "model_loaded": True,
            "device": str(current_device),
            "model_type": str(type(current_model).__name__),
            "text_decoder_available": decoder is not None,
            "cuda_available": torch.cuda.is_available(),
            "status": "Model is ready for inference",
        }
        return jsonify(status_payload)
    except Exception as e:
        logger.error(f"Model test error: {str(e)}")
        return jsonify({"error": f"Model test failed: {str(e)}"}), 500