import argparse
import logging
import os
import sys
import threading

# Set up basic logging first so messages emitted while importing the project
# modules below are not lost.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)

# Using direct imports
from env_vars import API_LOG_LEVEL
from flask import Flask, jsonify, send_file, send_from_directory
from flask_cors import CORS
from translations_blueprint import translations_blueprint

# Configure logging with imported level.
# A second logging.basicConfig() call would be a silent no-op here, because the
# root logger already has a handler from the call above. Set the level directly.
logging.getLogger().setLevel(API_LOG_LEVEL)

# Global variables to store the loaded model components.
# Module globals are per-process, so each worker process holds its own copy.
global_model = None
global_text_decoder = None
global_device = None
_model_loaded = False
_model_loading = False
# Serializes load_model() so concurrent callers in this process wait for a
# single load instead of racing on the flags above.
_model_lock = threading.Lock()


def load_model():
    """Load the MMS model once and cache its components in module globals.

    The checkpoint (``mms_XRI.pt``) and tokenizer
    (``mms_1143_langs_tokenizer_spm.model``) are read from the directory named
    by the ``MODELS_DIR`` environment variable, defaulting to
    ``/home/user/app/models``. Loading is delegated to
    ``model.load_mms_model``, whose result is unpacked into
    ``(global_model, global_text_decoder, global_device)``.

    If another thread in this process is already loading, this call blocks
    until that load finishes and then reuses its result.

    Returns:
        The loaded model, or ``None`` if loading failed. Failures are logged
        with a traceback, and all model globals are reset to ``None``.
    """
    global global_model, global_text_decoder, global_device, _model_loaded, _model_loading

    # Fast path: already loaded, no need to touch the lock.
    if _model_loaded and global_model is not None:
        logger.info("Model already loaded, returning existing instance")
        return global_model

    if _model_loading:
        logger.info("Model is currently being loaded, waiting...")

    with _model_lock:
        # Re-check: the thread we waited on may have finished the load.
        if _model_loaded and global_model is not None:
            logger.info("Model already loaded, returning existing instance")
            return global_model

        try:
            _model_loading = True
            logger.info("Loading MMS model...")

            # Get models directory from environment, with fallback
            models_dir = os.environ.get("MODELS_DIR", "/home/user/app/models")
            ckpt_path = os.path.join(models_dir, "mms_XRI.pt")
            tokenizer_path = os.path.join(
                models_dir, "mms_1143_langs_tokenizer_spm.model"
            )

            # Imported lazily so the heavy model stack is only pulled in when
            # a load is actually attempted.
            from model import load_mms_model

            global_model, global_text_decoder, global_device = load_mms_model(
                ckpt_path, tokenizer_path
            )
            _model_loaded = True
            logger.info(f"MMS model loaded successfully on device: {global_device}")
            return global_model
        except Exception as e:
            # Best-effort: the app keeps serving (e.g. the frontend) without a model.
            logger.exception(f"Failed to load model: {str(e)}")
            global_model = None
            global_text_decoder = None
            global_device = None
            _model_loaded = False
            return None
        finally:
            _model_loading = False


def get_model():
    """Get the loaded model instance - loads if not already loaded.

    Returns ``None`` if loading fails.
    """
    if not _model_loaded or global_model is None:
        logger.info("Model not loaded, attempting to load...")
        load_model()
    return global_model


def get_text_decoder():
    """Get the loaded text decoder, loading the model first if needed.

    Returns ``None`` if loading fails.
    """
    if not _model_loaded or global_text_decoder is None:
        logger.info("Text decoder not loaded, attempting to load model...")
        load_model()
    return global_text_decoder


def get_device():
    """Get the device the model is loaded on, loading the model first if needed.

    Returns ``None`` if loading fails.
    """
    if not _model_loaded or global_device is None:
        logger.info("Device not set, attempting to load model...")
        load_model()
    return global_device


app = Flask(__name__)
app.register_blueprint(translations_blueprint)

# NOTE(review): flask-cors documents that supports_credentials cannot be
# combined with a "*" origin; confirm the intended origin policy.
cors = CORS(
    app,
    resources={
        r"/*": {
            "origins": "*",
            "allow_headers": "*",
            "expose_headers": "*",
            "supports_credentials": True,
        }
    },
)

# Route Flask's own logger through gunicorn's handlers and level when running
# under gunicorn.
gunicorn_logger = logging.getLogger("gunicorn.error")
app.logger.handlers = gunicorn_logger.handlers
app.logger.setLevel(gunicorn_logger.level)

# Load model on startup - only once during app initialization
logger.info("Initializing application and loading model...")
if not _model_loaded:
    load_model()
else:
    logger.info("Model already loaded, skipping initialization")

# Built frontend: <parent of this file's directory>/frontend/dist.
# abspath() guards against a relative __file__, whose dirname() would be "".
_FRONTEND_DIST = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "frontend", "dist"
)


# Frontend static file serving
@app.route("/")
def serve_frontend():
    """Serve the frontend index.html"""
    return send_file(os.path.join(_FRONTEND_DIST, "index.html"))


@app.route("/assets/<path:filename>")
def serve_assets(filename):
    """Serve a built frontend asset from frontend/dist/assets.

    send_from_directory refuses paths that resolve outside the assets
    directory, so ``filename`` cannot be used for path traversal.
    """
    return send_from_directory(os.path.join(_FRONTEND_DIST, "assets"), filename)


@app.route("/api/health")
def health_check():
    """API health check endpoint"""
    return {"message": "Translations API is running", "version": "1.0.0"}
# Catch-all route for SPA routing - must be last
@app.route("/<path:path>")
def serve_spa(path):
    """Serve index.html for any unmatched routes (SPA routing).

    ``/`` itself is handled by ``serve_frontend``. More specific rules such as
    ``/api/health`` and ``/assets/<path:filename>`` take precedence over this one.
    """
    # If the path starts with 'api/', return 404 for API routes
    if path.startswith("api/"):
        return jsonify({"error": "API endpoint not found"}), 404

    # For all other paths, serve the frontend index.html.
    # abspath() guards against a relative __file__, whose dirname() would be "".
    frontend_dist = os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "frontend", "dist"
    )
    return send_file(os.path.join(frontend_dist, "index.html"))


@app.errorhandler(404)
def handle_404(e):
    """Return a JSON body for any 404."""
    return jsonify({"error": "Endpoint not found"}), 404


@app.errorhandler(500)
def handle_500(e):
    """Log the error and return a generic JSON body for any 500."""
    logger.error(f"Internal server error: {str(e)}")
    return jsonify({"error": "Internal server error"}), 500


def _str2bool(value):
    """Parse a command-line boolean such as "true"/"false", "1"/"0" or "yes"/"no".

    ``type=bool`` cannot be used for this, because ``bool("False")`` is True.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", default=5000, type=int)
    # NOTE(review): debug defaults to True while host defaults to 0.0.0.0.
    # This exposes the Werkzeug interactive debugger on all interfaces;
    # consider defaulting to False.
    parser.add_argument("--debug", default=True, type=_str2bool)
    args = parser.parse_args()
    logger.info(f"Starting Translations API on {args.host}:{args.port}")
    app.run(host=args.host, port=args.port, debug=args.debug)