EC2 Default User
Added basic frontend, dockerfile
0f60365
import argparse
import logging
import os
import sys
# Bootstrap stdout logging at INFO before the project modules below are
# imported, so INFO records emitted while they import already have a handler.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
logger = logging.getLogger(__name__)
# Using direct imports
from env_vars import API_LOG_LEVEL
from flask import Flask, jsonify, send_file, send_from_directory
from flask_cors import CORS
from translations_blueprint import translations_blueprint
# Apply the configured log level. logging.basicConfig() is a no-op once the
# root logger has a handler, and the bootstrap basicConfig() call at the top of
# the file already installed one, so a second basicConfig() would never apply
# API_LOG_LEVEL. Set the level on the root logger directly instead.
# None keeps the current level (mirroring basicConfig's level=None contract);
# upper-casing accepts "info"-style names, which setLevel() would reject.
if API_LOG_LEVEL is not None:
    logging.getLogger().setLevel(
        API_LOG_LEVEL.upper() if isinstance(API_LOG_LEVEL, str) else API_LOG_LEVEL
    )

# Module-level model state shared by load_model() and the get_* accessors.
global_model = None  # model object returned by load_mms_model(), or None
global_text_decoder = None  # text decoder returned alongside the model
global_device = None  # device reported by load_mms_model()
_model_loaded = False  # True only after a successful load_model()
_model_loading = False  # True while load_model() is inside its try block
def load_model():
    """Load the MMS model into the module-level globals and return it.

    Called at import time and again by the ``get_*`` accessors whenever a
    previous attempt left the model unloaded.

    Returns:
        The loaded model object, or None if loading failed or another call to
        this function is currently in progress.

    Side effects:
        On success sets ``global_model``, ``global_text_decoder``,
        ``global_device`` and ``_model_loaded``; on failure resets them to
        None/False. ``_model_loading`` is always cleared on exit from the
        loading path.
    """
    global global_model, global_text_decoder, global_device, _model_loaded, _model_loading

    # Fast path: reuse the instance from a previous successful load.
    if _model_loaded and global_model is not None:
        logger.info("Model already loaded, returning existing instance")
        return global_model

    # Another call is mid-load. This does NOT block until that load finishes;
    # the caller gets None immediately and may retry later. The flag is a plain
    # per-process global (no lock), so this check-then-set is not atomic
    # across threads.
    if _model_loading:
        logger.info("Model is currently being loaded, waiting...")
        return None

    try:
        _model_loading = True
        logger.info("Loading MMS model...")

        # MODELS_DIR overrides the default location of checkpoint + tokenizer.
        models_dir = os.environ.get("MODELS_DIR", "/home/user/app/models")
        ckpt_path = os.path.join(models_dir, "mms_XRI.pt")
        tokenizer_path = os.path.join(models_dir, "mms_1143_langs_tokenizer_spm.model")

        # Deferred import: `model` is only imported when a load is attempted,
        # so an ImportError is handled by the except clause below.
        from model import load_mms_model

        global_model, global_text_decoder, global_device = load_mms_model(
            ckpt_path, tokenizer_path
        )
        _model_loaded = True
        logger.info("MMS model loaded successfully on device: %s", global_device)
        return global_model
    except Exception as e:
        # Best-effort: the service keeps running without a model. Log with the
        # traceback so the root cause is not reduced to a one-line message.
        logger.exception("Failed to load model: %s", e)
        global_model = None
        global_text_decoder = None
        global_device = None
        _model_loaded = False
        return None
    finally:
        _model_loading = False
def get_model():
    """Return the shared model, calling ``load_model()`` first if needed.

    The result may still be None when that load attempt does not succeed.
    """
    needs_load = not _model_loaded or global_model is None
    if needs_load:
        logger.info("Model not loaded, attempting to load...")
        load_model()
    return global_model
def get_text_decoder():
    """Return the text decoder produced alongside the model.

    Triggers ``load_model()`` when nothing usable is loaded yet; the result
    may still be None if that load does not succeed.
    """
    if _model_loaded and global_text_decoder is not None:
        return global_text_decoder
    logger.info("Text decoder not loaded, attempting to load model...")
    load_model()
    return global_text_decoder
def get_device():
    """Return the device recorded by the last successful ``load_model()``.

    Attempts a load first when no device is recorded; may still return None.
    """
    have_device = _model_loaded and global_device is not None
    if not have_device:
        logger.info("Device not set, attempting to load model...")
        load_model()
    return global_device
app = Flask(__name__)
app.register_blueprint(translations_blueprint)

# NOTE(review): with "origins": "*" plus supports_credentials=True, flask-cors
# (with its default send_wildcard=False) reflects the caller's Origin header,
# so any site can make credentialed cross-origin requests. Restrict "origins"
# if the API ever relies on cookies or other credentials — confirm for the
# pinned flask-cors version.
cors = CORS(
    app,
    resources={
        r"/*": {
            "origins": "*",
            "allow_headers": "*",
            "expose_headers": "*",
            "supports_credentials": True,
        }
    },
)

# Route app.logger through gunicorn's handlers and level. When not running
# under gunicorn, "gunicorn.error" has no handlers and level NOTSET, so
# app.logger ends up handler-less and defers to the root logger.
gunicorn_logger = logging.getLogger("gunicorn.error")
app.logger.handlers = gunicorn_logger.handlers
app.logger.setLevel(gunicorn_logger.level)

# Load the model at import time. load_model() already returns the existing
# instance if one is loaded and logs its own failures; the get_* accessors
# retry loading when next called.
logger.info("Initializing application and loading model...")
if load_model() is None:
    logger.warning(
        "Model unavailable at startup; get_model()/get_text_decoder()/"
        "get_device() will retry loading when next called"
    )
# Frontend static file serving
@app.route("/")
def serve_frontend():
    """Return the built frontend's entry document, frontend/dist/index.html.

    The frontend directory is located one level above this file's directory.
    """
    parent_dir = os.path.dirname(os.path.dirname(__file__))
    index_html = os.path.join(parent_dir, "frontend", "dist", "index.html")
    return send_file(index_html)
@app.route("/assets/<path:filename>")
def serve_assets(filename):
    """Serve ``filename`` from frontend/dist/assets.

    ``send_from_directory`` keeps the lookup confined to that directory.
    """
    assets_dir = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), "frontend", "dist", "assets"
    )
    return send_from_directory(assets_dir, filename)
@app.route("/api/health")
def health_check():
    """Liveness probe: report that the API process is up, plus its version."""
    payload = {"message": "Translations API is running", "version": "1.0.0"}
    return payload
# Catch-all route for SPA routing - must be last
@app.route("/<path:path>")
def serve_spa(path):
    """Catch-all handler for paths that no other route matched.

    * ``api/...`` paths get a JSON 404 instead of the HTML shell.
    * A path naming an existing file inside ``frontend/dist`` is served as
      that file, so files placed at the build root (presumably favicon-style
      files emitted next to index.html — confirm) are not masked.
    * Everything else receives ``index.html`` so client-side routing can
      resolve the URL.
    """
    # Unknown API paths must not fall through to the SPA shell.
    if path.startswith("api/"):
        return jsonify({"error": "API endpoint not found"}), 404

    frontend_dist = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), "frontend", "dist"
    )

    # Resolve the requested file and confine it to the build directory:
    # realpath collapses "../" segments and symlinks, then commonpath rejects
    # anything that resolved outside dist.
    try:
        dist_root = os.path.realpath(frontend_dist)
        candidate = os.path.realpath(os.path.join(dist_root, path))
        inside_dist = os.path.commonpath([dist_root, candidate]) == dist_root
    except (OSError, ValueError):
        # e.g. an embedded NUL byte in the decoded URL path.
        inside_dist = False

    if inside_dist and os.path.isfile(candidate):
        return send_from_directory(
            dist_root, os.path.relpath(candidate, dist_root)
        )

    # For all other paths, serve the frontend index.html
    return send_file(os.path.join(frontend_dist, "index.html"))
@app.errorhandler(404)
def handle_404(e):
    """Render every 404 as a JSON error body."""
    body = jsonify({"error": "Endpoint not found"})
    return body, 404
@app.errorhandler(500)
def handle_500(e):
    """Return a generic JSON 500 body and log the underlying cause.

    For unhandled exceptions Flask passes an ``InternalServerError`` whose
    ``original_exception`` attribute holds the real error; ``str(e)`` of the
    wrapper is only the generic 500 description. Log the original when it is
    present, otherwise ``e`` itself.
    """
    cause = getattr(e, "original_exception", None) or e
    logger.error("Internal server error: %s", cause)
    return jsonify({"error": "Internal server error"}), 500
def _parse_bool_flag(value):
    """Convert a command-line string such as "true"/"false" to a bool.

    ``type=bool`` cannot be used for this: ``bool("False")`` is True, so the
    flag could never be switched off from the command line. An empty string
    maps to False, matching what ``bool("")`` used to give.

    Raises:
        argparse.ArgumentTypeError: for unrecognised values, so argparse
            reports a usage error instead of silently accepting them.
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", default=5000, type=int)
    # Accepts "--debug" (on), "--debug true" and "--debug false"; defaults to on.
    # NOTE(review): debug mode enables the Werkzeug debugger and reloader; with
    # the 0.0.0.0 default host the debugger is reachable on all interfaces, and
    # the reloader re-runs this module (and therefore load_model()) in a child
    # process. Consider defaulting to off outside local development.
    parser.add_argument(
        "--debug", default=True, type=_parse_bool_flag, nargs="?", const=True
    )
    args = parser.parse_args()
    logger.info(f"Starting Translations API on {args.host}:{args.port}")
    app.run(host=args.host, port=args.port, debug=args.debug)