"""Headless IOPaint API server.

Builds a FastAPI application exposing the IOPaint REST endpoints (without the
bundled web frontend) plus a Socket.IO endpoint mounted at ``/ws``.  All
settings are read from ``IOPAINT_*`` / ``ALLOWED_ORIGINS`` environment
variables at import time.
"""
import os
import secrets
from pathlib import Path
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from loguru import logger
import socketio
import threading
import asyncio
from iopaint.api import Api, api_middleware
from iopaint.schema import ApiConfig, Device, InteractiveSegModel, RealESRGANModel
from iopaint.runtime import setup_model_dir, check_device, dump_environment_info


def _env_flag(name, default):
    """Return True when environment variable *name* (or *default*) is "true", case-insensitively."""
    return os.environ.get(name, default).lower() == "true"


# Global reference to the Socket.IO server; set by ApiNoFrontend.__init__.
global_sio = None

# Read configuration from environment variables.
host = os.environ.get("IOPAINT_HOST", "0.0.0.0")
port = int(os.environ.get("IOPAINT_PORT", "7860"))
# Default model is cv2 because lama could not be loaded in this deployment.
model = os.environ.get("IOPAINT_MODEL", "cv2")
# Model directory; expected to live under /app (with a /tmp fallback below).
model_dir_str = os.environ.get("IOPAINT_MODEL_DIR", "/app/models")
device_str = os.environ.get("IOPAINT_DEVICE", "cpu")

# Read the API key; a missing or blank value disables key validation.
api_key = os.environ.get("IOPAINT_API_KEY", "")
if not api_key.strip():
    api_key = None
    logger.info("未设置API密钥,禁用API密钥验证")
else:
    logger.info("已设置API密钥,启用API密钥验证")

# Comma-separated origin list.  Surrounding whitespace and empty entries are
# dropped so "a.com, b.com" behaves as expected; an empty result means "any".
allowed_origins = [
    origin.strip()
    for origin in os.environ.get("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# Prepare the model directory and the runtime environment.
model_dir = Path(model_dir_str)
try:
    model_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Successfully created model directory: {model_dir}")
except OSError as e:
    logger.error(f"Failed to create model directory: {e}")
    # The configured location is unusable; fall back to a directory under /tmp.
    model_dir = Path("/tmp/iopaint/models")
    model_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Using alternative model directory: {model_dir}")

device = check_device(Device(device_str))
dump_environment_info()

logger.info(f"Starting API server with model: {model} on device: {device}")
logger.info(f"Model directory: {model_dir}")
logger.info(f"Allowed origins: {allowed_origins}")

# Create the FastAPI application.
app = FastAPI(title="IOPaint API")

# Install the IOPaint API middleware.
api_middleware(app)

# Read the NSFW switch once so both related config fields get the same value.
disable_nsfw_value = _env_flag("IOPAINT_DISABLE_NSFW", "false")

# Build the API configuration, supplying every required field.
config = ApiConfig(
    host=host,
    port=port,
    model=model,
    device=device,
    model_dir=model_dir,
    input=None,
    output_dir=None,
    low_mem=_env_flag("IOPAINT_LOW_MEM", "true"),
    no_half=_env_flag("IOPAINT_NO_HALF", "false"),
    cpu_offload=_env_flag("IOPAINT_CPU_OFFLOAD", "false"),
    disable_nsfw=disable_nsfw_value,
    # Required fields that were previously missing.
    enable_interactive_seg=False,
    interactive_seg_model=InteractiveSegModel.sam2_1_tiny,
    interactive_seg_device=Device.cpu,
    enable_remove_bg=False,
    remove_bg_device=Device.cpu,
    remove_bg_model="briaai/RMBG-1.4",  # a plain string, not an enum member
    enable_anime_seg=False,
    enable_realesrgan=False,
    realesrgan_device=Device.cpu,
    realesrgan_model=RealESRGANModel.realesr_general_x4v3,
    enable_gfpgan=False,
    gfpgan_device=Device.cpu,
    enable_restoreformer=False,
    restoreformer_device=Device.cpu,
    # Further required fields discovered later.
    inbrowser=False,
    disable_nsfw_checker=disable_nsfw_value,  # same value as disable_nsfw
    local_files_only=False,
    cpu_textencoder=False,
    mask_dir=None,
    quality=100,
    # Parameters specific to the cv2 model.
    cv2_radius=5,
    cv2_flag="INPAINT_NS",
)

# API key validation (only when a key is configured).
# Registered BEFORE the CORS middleware on purpose: the middleware added last
# is the outermost layer, so CORS ends up wrapping this one and the 401
# responses produced here still receive CORS headers.
if api_key:
    # Compare as bytes: secrets.compare_digest rejects non-ASCII str values.
    # "surrogateescape" mirrors how os.environ decodes undecodable bytes.
    _api_key_bytes = api_key.encode("utf-8", "surrogateescape")

    @app.middleware("http")
    async def api_key_validation(request: Request, call_next):
        """Reject requests whose ``X-API-Key`` header does not match the configured key."""
        # Health/debug paths and OPTIONS (CORS preflight) requests need no key.
        if request.method == "OPTIONS" or request.url.path in ["/", "/health", "/debug"]:
            return await call_next(request)

        req_api_key = request.headers.get("X-API-Key")
        # Constant-time comparison so response timing does not leak the key.
        if not req_api_key or not secrets.compare_digest(
            req_api_key.encode("utf-8", "surrogateescape"), _api_key_bytes
        ):
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid API key"}
            )
        return await call_next(request)

# Configure CORS with the origins from ALLOWED_ORIGINS (default "*").
# NOTE(review): credentials stay enabled to preserve existing client behavior;
# combined with a wildcard origin this is very permissive - set
# ALLOWED_ORIGINS explicitly for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Root and health-check routes, handy for verifying that the API is up.
@app.get("/")
@app.get("/health")
def health_check():
    """Return a small status payload with the active model and device."""
    return {
        "status": "ok",
        "message": "IOPaint API服务器运行正常",
        "model": model,
        "device": str(device)
    }


# Debug endpoint listing the registered routes and the core configuration.
@app.get("/debug")
def debug_info():
    """Return the registered routes, whether an API key is required, and key config values."""
    routes = [
        {
            "path": getattr(route, "path", None),
            "methods": getattr(route, "methods", None),
            "name": getattr(route, "name", None)
        }
        for route in app.routes
    ]
    return {
        "status": "ok",
        "api_key_required": api_key is not None,
        "routes": routes,
        "config": {
            "host": config.host,
            "port": config.port,
            "model": config.model,
            "device": str(config.device),
            "model_dir": str(config.model_dir)
        }
    }


class ApiNoFrontend(Api):
    """Variant of ``Api`` whose initializer does not mount static frontend files."""

    def __init__(self, app, config):
        """Wire up components, REST routes and Socket.IO on *app*.

        ``Api.__init__`` is intentionally not called so that the static file
        mount is skipped; the setup steps are performed here instead.
        """
        self.app = app
        self.config = config
        self.router = None
        self.queue_lock = threading.Lock()

        # Build the components used by the route handlers.
        self.file_manager = self._build_file_manager()
        self.plugins = self._build_plugins()
        self.model_manager = self._build_model_manager()

        # Register the REST API routes.
        # fmt: off
        self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"])
        self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"])
        self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"])
        self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"])
        self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
        self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
        self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
        self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
        self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
        self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
        # fmt: on

        # Set up Socket.IO using the same origin policy as the HTTP API.
        global global_sio
        sio_origins = "*" if "*" in allowed_origins else allowed_origins
        self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins=sio_origins)
        self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
        self.app.mount("/ws", self.combined_asgi_app)
        global_sio = self.sio

        # Log how many routes ended up registered.
        logger.info(f"API初始化完成,注册了{len(app.routes)}个路由")


# Instantiate the frontend-less API on the application.
api = ApiNoFrontend(app, config)

# Start the server directly when executed as a script.
if __name__ == "__main__":
    uvicorn.run(app, host=host, port=port)