"""Standalone launcher for the IOPaint API server, configured via environment variables.

Environment variables read by this module:

- ``IOPAINT_HOST`` / ``IOPAINT_PORT``: bind address (defaults ``0.0.0.0`` / ``7860``).
- ``IOPAINT_MODEL``: model name passed to ``ApiConfig`` (default ``lama``).
- ``IOPAINT_MODEL_DIR``: model directory (default ``/app/models``); if it cannot be
  created, ``/tmp/iopaint/models`` is used instead.
- ``IOPAINT_DEVICE``: device name, converted with ``Device(...)`` (default ``cpu``).
- ``IOPAINT_API_KEY``: when set, every non-OPTIONS request must carry this value in
  the ``X-API-Key`` header, otherwise it is rejected with HTTP 401.
- ``ALLOWED_ORIGINS``: comma-separated list of CORS origins (default ``*``).
- ``IOPAINT_LOW_MEM`` (default true), ``IOPAINT_NO_HALF``, ``IOPAINT_CPU_OFFLOAD``,
  ``IOPAINT_DISABLE_NSFW`` (default false): boolean flags; ``1``/``true``/``yes``/``on``
  (case-insensitive) mean true, anything else means false.
"""

import hmac
import os
from pathlib import Path
from typing import List

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from loguru import logger

from iopaint.api import Api
from iopaint.schema import ApiConfig, Device
from iopaint.runtime import setup_model_dir, check_device, dump_environment_info
from iopaint.const import DEFAULT_MODEL_DIR

# Strings (after strip + lower-casing) that enable a boolean env flag.
_TRUTHY = frozenset({"1", "true", "yes", "on"})


def _env_flag(name: str, default: str) -> bool:
    """Return True if env var *name* (or *default* when unset) is a truthy string."""
    return os.environ.get(name, default).strip().lower() in _TRUTHY


def _parse_origins(raw: str) -> List[str]:
    """Split a comma-separated origin list, dropping surrounding whitespace and empties.

    Browsers send the ``Origin`` header without padding, so an unstripped entry
    such as ``" https://b.example"`` could never match.
    """
    stripped = (part.strip() for part in raw.split(","))
    return [origin for origin in stripped if origin]


def _prepare_model_dir(path: Path) -> Path:
    """Create *path* (with parents) and return it.

    If creation fails with an OSError, fall back to ``/tmp/iopaint/models``
    and return that directory instead. A failure to create the fallback
    propagates to the caller.
    """
    try:
        path.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        logger.error(f"Failed to create model directory: {e}")
    else:
        logger.info(f"Successfully created model directory: {path}")
        return path
    # The configured directory could not be created; try /tmp instead.
    fallback = Path("/tmp/iopaint/models")
    fallback.mkdir(parents=True, exist_ok=True)
    logger.info(f"Using alternative model directory: {fallback}")
    return fallback


# Read configuration from environment variables.
host = os.environ.get("IOPAINT_HOST", "0.0.0.0")
port = int(os.environ.get("IOPAINT_PORT", "7860"))
model = os.environ.get("IOPAINT_MODEL", "lama")
# Model directory defaults to /app/models; /tmp is the fallback (see _prepare_model_dir).
model_dir_str = os.environ.get("IOPAINT_MODEL_DIR", "/app/models")
device_str = os.environ.get("IOPAINT_DEVICE", "cpu")
api_key = os.environ.get("IOPAINT_API_KEY", None)
allowed_origins = _parse_origins(os.environ.get("ALLOWED_ORIGINS", "*"))

# Initialize directories and environment.
model_dir = _prepare_model_dir(Path(model_dir_str))
# NOTE(review): setup_model_dir and DEFAULT_MODEL_DIR are imported but never used.
# If setup_model_dir is what points IOPaint's model download location at a
# directory, model_dir may currently have no effect on downloads -- confirm
# whether setup_model_dir(model_dir) should be called here.

device = check_device(Device(device_str))
dump_environment_info()

logger.info(f"Starting API server with model: {model} on device: {device}")
logger.info(f"Model directory: {model_dir}")
logger.info(f"Allowed origins: {allowed_origins}")

# Initialize FastAPI.
app = FastAPI(title="IOPaint API")

# Configure the API.
# NOTE(review): ApiConfig lives in iopaint.schema (not visible here); confirm it
# defines the `model_dir` and `disable_nsfw` fields used below.
config = ApiConfig(
    host=host,
    port=port,
    model=model,
    device=device,
    model_dir=model_dir,
    input=None,
    output_dir=None,
    low_mem=_env_flag("IOPAINT_LOW_MEM", "true"),
    no_half=_env_flag("IOPAINT_NO_HALF", "false"),
    cpu_offload=_env_flag("IOPAINT_CPU_OFFLOAD", "false"),
    disable_nsfw=_env_flag("IOPAINT_DISABLE_NSFW", "false"),
)

# API key validation (only if a key is configured).
# Registered BEFORE CORSMiddleware on purpose: each middleware added later wraps
# the ones added earlier, so adding CORS afterwards makes it the outer layer and
# its headers are applied to the 401 responses below as well. Without that,
# browsers hide the rejection behind a generic CORS error.
if api_key:
    # Compare bytes in constant time. "surrogatepass" makes the encoding
    # injective, so the comparison agrees with plain string equality.
    _expected_api_key = api_key.encode("utf-8", "surrogatepass")

    @app.middleware("http")
    async def api_key_validation(request: Request, call_next):
        """Reject non-OPTIONS requests whose X-API-Key header does not match."""
        # Let OPTIONS (CORS preflight) requests through without a key.
        if request.method == "OPTIONS":
            return await call_next(request)
        req_api_key = request.headers.get("X-API-Key")
        if not req_api_key or not hmac.compare_digest(
            req_api_key.encode("utf-8", "surrogatepass"), _expected_api_key
        ):
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid API key"},
            )
        return await call_next(request)

# Configure CORS.
# FastAPI/Starlette document that allow_credentials must not be combined with a
# wildcard origin. The API key is sent as a header, which does not need
# credentials mode, so credentials are only enabled for an explicit origin list.
cors_options = {
    "allow_methods": ["*"],
    "allow_headers": ["*", "X-API-Key"],
    "allow_origins": allowed_origins,
    "allow_credentials": "*" not in allowed_origins,
}
app.add_middleware(CORSMiddleware, **cors_options)

# Initialize the IOPaint API on the app (constructed last, as in the original flow).
api = Api(app, config)

# Start the server directly when run as a script.
if __name__ == "__main__":
    uvicorn.run(app, host=host, port=port)