Spaces:
Runtime error
Runtime error
import argparse | |
import logging | |
import os | |
import sys | |
# Bootstrap logging before the project imports that follow, so a stdout
# handler at INFO level is already configured when those modules are imported.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
# Using direct imports | |
from env_vars import API_LOG_LEVEL | |
from flask import Flask, jsonify, send_file, send_from_directory | |
from flask_cors import CORS | |
from translations_blueprint import translations_blueprint | |
# Apply the configured log level. Calling logging.basicConfig() a second time
# does nothing once the root logger has handlers, so the level is set on the
# root logger directly instead.
logging.getLogger().setLevel(API_LOG_LEVEL)
# Module-level cache for the loaded model components; filled in by load_model().
global_model = None
global_text_decoder = None
global_device = None

# Loader state flags: ``_model_loaded`` is set once a load has succeeded and
# ``_model_loading`` is set for the duration of a load attempt.
_model_loaded = False
_model_loading = False
def load_model():
    """Load the MMS model components into the module-level cache.

    The checkpoint (``mms_XRI.pt``) and tokenizer
    (``mms_1143_langs_tokenizer_spm.model``) are looked up in the directory
    named by the ``MODELS_DIR`` environment variable, falling back to
    ``/home/user/app/models``.

    Returns:
        The cached model when it was already loaded or the load succeeded.
        ``None`` when the load failed, or when ``_model_loading`` shows that
        a load is already in progress. In the latter case the call returns
        immediately; it does not block until that load finishes.
    """
    global global_model, global_text_decoder, global_device
    global _model_loaded, _model_loading

    # Fast path: reuse the instance that is already cached.
    if _model_loaded and global_model is not None:
        logger.info("Model already loaded, returning existing instance")
        return global_model

    # A plain flag, not a lock: the caller is told nothing is available yet
    # and has to retry later.
    if _model_loading:
        logger.info("Model is currently being loaded, waiting...")
        return None

    try:
        _model_loading = True
        logger.info("Loading MMS model...")

        models_dir = os.environ.get("MODELS_DIR", "/home/user/app/models")
        ckpt_path = os.path.join(models_dir, "mms_XRI.pt")
        tokenizer_path = os.path.join(
            models_dir, "mms_1143_langs_tokenizer_spm.model"
        )

        # Imported at call time rather than at module import time.
        from model import load_mms_model

        global_model, global_text_decoder, global_device = load_mms_model(
            ckpt_path, tokenizer_path
        )
        _model_loaded = True
        logger.info(f"MMS model loaded successfully on device: {global_device}")
        return global_model
    except Exception as e:
        # logger.exception records the traceback along with the message, so
        # the cause of a failed load is not lost.
        logger.exception("Failed to load model: %s", e)
        # Leave the cache in a consistent "nothing loaded" state.
        global_model = None
        global_text_decoder = None
        global_device = None
        _model_loaded = False
        return None
    finally:
        _model_loading = False
def get_model():
    """Return the cached model, attempting a load first if none is cached.

    Returns:
        The value of ``global_model`` after the optional load attempt; this
        is ``None`` when the attempt did not produce a model.
    """
    # The globals are only read here, so no ``global`` declaration is needed.
    if not _model_loaded or global_model is None:
        logger.info("Model not loaded, attempting to load...")
        load_model()
    return global_model
def get_text_decoder():
    """Return the cached text decoder, attempting a model load if it is unset.

    Returns:
        The value of ``global_text_decoder`` after the optional load attempt;
        this is ``None`` when the attempt did not produce a decoder.
    """
    # The globals are only read here, so no ``global`` declaration is needed.
    if not _model_loaded or global_text_decoder is None:
        logger.info("Text decoder not loaded, attempting to load model...")
        load_model()
    return global_text_decoder
def get_device():
    """Return the device the model was loaded on, loading the model if unset.

    Returns:
        The value of ``global_device`` after the optional load attempt; this
        is ``None`` when the attempt did not set a device.
    """
    # The globals are only read here, so no ``global`` declaration is needed.
    if not _model_loaded or global_device is None:
        logger.info("Device not set, attempting to load model...")
        load_model()
    return global_device
# Create the Flask application and attach the translations blueprint.
app = Flask(__name__)
app.register_blueprint(translations_blueprint)

# NOTE(review): every origin is allowed ("*") while supports_credentials is
# True. Browsers refuse credentialed responses that carry a literal wildcard
# origin, so this combination is either ineffective or, if the library echoes
# the caller's origin, grants credentialed access to any site. Confirm the
# intent and consider an explicit origin allow-list.
cors = CORS(
    app,
    resources={
        r"/*": {
            "origins": "*",
            "allow_headers": "*",
            "expose_headers": "*",
            "supports_credentials": True,
        }
    },
)

logger = logging.getLogger(__name__)

# Point Flask's logger at the handlers and level of gunicorn's error logger.
gunicorn_logger = logging.getLogger("gunicorn.error")
app.logger.handlers = gunicorn_logger.handlers
app.logger.setLevel(gunicorn_logger.level)
# Load the model once while the module is being imported, so it is attempted
# before the application starts serving.
logger.info("Initializing application and loading model...")
if not _model_loaded:
    load_model()
else:
    logger.info("Model already loaded, skipping initialization")
# Frontend static file serving
# NOTE(review): no route decorator is visible on this view; it may have been
# lost or the view may be registered elsewhere -- confirm.
def serve_frontend():
    """Serve the frontend's ``index.html``.

    The file is resolved under ``frontend/dist`` in the parent of the
    directory that contains this module.
    """
    frontend_dist = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), "frontend", "dist"
    )
    return send_file(os.path.join(frontend_dist, "index.html"))
# NOTE(review): no route decorator is visible on this view; it may have been
# lost or the view may be registered elsewhere -- confirm.
def serve_assets(filename):
    """Serve a frontend static asset.

    Args:
        filename: Path of the requested file relative to the
            ``frontend/dist/assets`` directory.

    Returns:
        The response produced by ``send_from_directory`` for that file.
    """
    frontend_dist = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), "frontend", "dist"
    )
    assets_dir = os.path.join(frontend_dist, "assets")
    return send_from_directory(assets_dir, filename)
# NOTE(review): no route decorator is visible on this view; it may have been
# lost or the view may be registered elsewhere -- confirm.
def health_check():
    """API health check endpoint.

    Returns:
        A dict with a fixed status ``message`` and the API ``version``.
    """
    return {"message": "Translations API is running", "version": "1.0.0"}
# Catch-all route for SPA routing - must be last
# NOTE(review): no route decorator is visible on this view; it may have been
# lost or the view may be registered elsewhere -- confirm.
def serve_spa(path):
    """Serve ``index.html`` for any unmatched route (SPA routing).

    Args:
        path: The requested path, without a leading slash.

    Returns:
        A JSON error body with status 404 when ``path`` starts with
        ``api/``; otherwise the frontend's ``index.html``.
    """
    # Unknown API paths must not fall through to the frontend.
    if path.startswith("api/"):
        return jsonify({"error": "API endpoint not found"}), 404

    # Everything else is handled by the single-page app.
    frontend_dist = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), "frontend", "dist"
    )
    return send_file(os.path.join(frontend_dist, "index.html"))
# NOTE(review): no error-handler registration is visible for this function;
# it may have been lost or may happen elsewhere -- confirm.
def handle_404(e):
    """Return a JSON "not found" body with status 404.

    Args:
        e: The error being handled; it is not inspected.
    """
    return jsonify({"error": "Endpoint not found"}), 404
# NOTE(review): no error-handler registration is visible for this function;
# it may have been lost or may happen elsewhere -- confirm.
def handle_500(e):
    """Log the error and return a generic JSON body with status 500.

    Args:
        e: The error being handled; its string form is written to the log
            but is not included in the response.
    """
    logger.error("Internal server error: %s", e)
    return jsonify({"error": "Internal server error"}), 500
def str_to_bool(value):
    """Parse a command-line boolean value.

    Args:
        value: A bool, or text such as ``true``/``false``, ``yes``/``no``,
            ``on``/``off`` or ``1``/``0`` (case-insensitive). An empty string
            counts as false.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value, got {value!r}"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", default=5000, type=int)
    # ``type=bool`` would turn every non-empty string, including "False",
    # into True, so the value is parsed explicitly. A bare ``--debug`` also
    # enables debug mode.
    # NOTE(review): debug defaults to True while the host defaults to
    # 0.0.0.0; Flask's debug mode is intended for local development only --
    # confirm this default is wanted.
    parser.add_argument(
        "--debug", default=True, type=str_to_bool, nargs="?", const=True
    )
    args = parser.parse_args()

    logger.info(f"Starting Translations API on {args.host}:{args.port}")
    app.run(host=args.host, port=args.port, debug=args.debug)