File size: 7,627 Bytes
6c6eb37 c3b0ec6 6c6eb37 ae7a5a8 6c6eb37 ae7a5a8 6c6eb37 b911230 3372d56 6c6eb37 c3b0ec6 6c6eb37 3372d56 6c6eb37 ae7a5a8 4d2ad30 3372d56 6c6eb37 4d2ad30 3372d56 ae7a5a8 3372d56 4d2ad30 b911230 6c6eb37 c3b0ec6 6c6eb37 c3b0ec6 6c6eb37 c3b0ec6 6c6eb37 c3b0ec6 ae7a5a8 c3b0ec6 ae7a5a8 c3b0ec6 ae7a5a8 c3b0ec6 ae7a5a8 c3b0ec6 ae7a5a8 c3b0ec6 ae7a5a8 c3b0ec6 ae7a5a8 c3b0ec6 6c6eb37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 |
import asyncio
import hmac
import os
import threading
from pathlib import Path

import socketio
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from loguru import logger

import iopaint.api as iopaint_api
from iopaint.api import Api, api_middleware
from iopaint.schema import ApiConfig, Device, InteractiveSegModel, RealESRGANModel
from iopaint.runtime import setup_model_dir, check_device, dump_environment_info
# Module-level Socket.IO server handle; assigned when ApiNoFrontend is constructed.
global_sio = None

# Server configuration, read from environment variables.
host = os.environ.get("IOPAINT_HOST", "0.0.0.0")
port = int(os.environ.get("IOPAINT_PORT", "7860"))
# Default model is cv2 because lama could not be loaded in this deployment.
model = os.environ.get("IOPAINT_MODEL", "cv2")
# Model directory lives under /app by default, with a /tmp fallback below.
model_dir_str = os.environ.get("IOPAINT_MODEL_DIR", "/app/models")
device_str = os.environ.get("IOPAINT_DEVICE", "cpu")

# API key: an unset or whitespace-only value disables key validation (api_key = None).
api_key = os.environ.get("IOPAINT_API_KEY", "")
if not api_key.strip():
    api_key = None
    logger.info("未设置API密钥,禁用API密钥验证")
else:
    logger.info("已设置API密钥,启用API密钥验证")

# Comma-separated CORS origins. Whitespace around entries and empty entries are
# dropped (so "a.com, b.com" works); a blank value falls back to the "*" default,
# mirroring how a blank API key is treated as unset.
allowed_origins = [
    origin.strip()
    for origin in os.environ.get("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# Create the model directory, falling back to /tmp if the configured path is not writable.
model_dir = Path(model_dir_str)
try:
    model_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Successfully created model directory: {model_dir}")
except OSError as e:
    # mkdir failures (permissions, read-only FS, ...) surface as OSError.
    logger.error(f"Failed to create model directory: {e}")
    model_dir = Path("/tmp/iopaint/models")
    model_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Using alternative model directory: {model_dir}")
# NOTE(review): setup_model_dir is imported but never called, so model_dir is only
# passed to ApiConfig below; confirm whether model downloads should also be
# redirected here via setup_model_dir(model_dir).

device = check_device(Device(device_str))
dump_environment_info()
logger.info(f"Starting API server with model: {model} on device: {device}")
logger.info(f"Model directory: {model_dir}")
logger.info(f"Allowed origins: {allowed_origins}")

# The FastAPI application every route and middleware below attaches to.
app = FastAPI(title="IOPaint API")
# ApiNoFrontend.__init__ does not call api_middleware itself, so apply it here.
api_middleware(app)
def _env_flag(name: str, default: str) -> bool:
    """Return True iff env var *name* (or *default* when unset) equals "true", ignoring case."""
    return os.environ.get(name, default).lower() == "true"


# A single IOPAINT_DISABLE_NSFW reading feeds both disable_nsfw and
# disable_nsfw_checker so the two fields can never disagree.
disable_nsfw_value = _env_flag("IOPAINT_DISABLE_NSFW", "false")

# Complete ApiConfig. Every optional plugin is disabled and pinned to the CPU.
config = ApiConfig(
    # Server and model basics.
    host=host,
    port=port,
    model=model,
    device=device,
    model_dir=model_dir,
    input=None,
    output_dir=None,
    # Memory / precision switches driven by environment variables.
    low_mem=_env_flag("IOPAINT_LOW_MEM", "true"),
    no_half=_env_flag("IOPAINT_NO_HALF", "false"),
    cpu_offload=_env_flag("IOPAINT_CPU_OFFLOAD", "false"),
    disable_nsfw=disable_nsfw_value,
    # Plugin settings (all disabled).
    enable_interactive_seg=False,
    interactive_seg_model=InteractiveSegModel.sam2_1_tiny,
    interactive_seg_device=Device.cpu,
    enable_remove_bg=False,
    remove_bg_device=Device.cpu,
    remove_bg_model="briaai/RMBG-1.4",  # passed as a plain string, not an enum member
    enable_anime_seg=False,
    enable_realesrgan=False,
    realesrgan_device=Device.cpu,
    realesrgan_model=RealESRGANModel.realesr_general_x4v3,
    enable_gfpgan=False,
    gfpgan_device=Device.cpu,
    enable_restoreformer=False,
    restoreformer_device=Device.cpu,
    # Remaining fields ApiConfig requires.
    inbrowser=False,
    disable_nsfw_checker=disable_nsfw_value,
    local_files_only=False,
    cpu_textencoder=False,
    mask_dir=None,
    quality=100,
    # Parameters specific to the cv2 inpainting model.
    cv2_radius=5,
    cv2_flag="INPAINT_NS",
)
# CORS: honor the ALLOWED_ORIGINS list instead of a hard-coded "*". ALLOWED_ORIGINS
# defaults to "*", so behavior without the env var is unchanged.
# NOTE(review): api_middleware(app) may install its own CORSMiddleware; if that one
# allows "*", simple (non-preflighted) responses could still carry a permissive
# Access-Control-Allow-Origin header. Confirm against the installed iopaint version.
app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Paths reachable without an API key: root/health checks and the debug endpoint.
_PUBLIC_PATHS = frozenset({"/", "/health", "/debug"})

# API key validation, installed only when a key is configured.
if api_key:

    @app.middleware("http")
    async def api_key_validation(request: Request, call_next):
        """Reject requests whose ``X-API-Key`` header does not match with HTTP 401.

        OPTIONS requests (e.g. CORS preflight) and the paths in ``_PUBLIC_PATHS``
        are passed through without a check.
        """
        if request.method == "OPTIONS" or request.url.path in _PUBLIC_PATHS:
            return await call_next(request)

        req_api_key = request.headers.get("X-API-Key")
        # Constant-time comparison so response timing does not reveal how much of
        # the key matched. Compare bytes: compare_digest raises on non-ASCII str.
        if not req_api_key or not hmac.compare_digest(
            req_api_key.encode("utf-8"), api_key.encode("utf-8")
        ):
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid API key"},
            )
        return await call_next(request)
# Root and /health routes, for quickly checking that the API is up.
@app.get("/")
@app.get("/health")
def health_check():
    """Liveness endpoint served at both "/" and "/health".

    Returns a static OK status along with the configured model name and the
    resolved device.
    """
    payload = dict(
        status="ok",
        message="IOPaint API服务器运行正常",
        model=model,
        device=str(device),
    )
    return payload
# Debug endpoint listing the registered routes and key configuration values.
@app.get("/debug")
def debug_info():
    """Describe every route on ``app`` plus a summary of the active config.

    Route attributes that a route object lacks are reported as None.
    """
    route_table = [
        {
            "path": getattr(entry, "path", None),
            "methods": getattr(entry, "methods", None),
            "name": getattr(entry, "name", None),
        }
        for entry in app.routes
    ]
    config_summary = {
        "host": config.host,
        "port": config.port,
        "model": config.model,
        "device": str(config.device),
        "model_dir": str(config.model_dir),
    }
    return {
        "status": "ok",
        "api_key_required": api_key is not None,
        "routes": route_table,
        "config": config_summary,
    }
class ApiNoFrontend(Api):
    """IOPaint ``Api`` variant that exposes only the REST and Socket.IO endpoints.

    ``__init__`` replaces the inherited one without calling it. It builds the
    file manager, plugins and model manager, registers the ``/api/v1/*`` routes
    and mounts Socket.IO at ``/ws``, but mounts no static frontend files.
    """

    def __init__(self, app, config):
        """Wire the API onto *app*.

        Args:
            app: FastAPI application to register routes and mounts on.
            config: ApiConfig instance, stored as ``self.config``.
        """
        # Basic state, set by hand because Api.__init__ is not called.
        self.app = app
        self.config = config
        self.router = None
        self.queue_lock = threading.Lock()

        # Core components (builder methods inherited from Api).
        self.file_manager = self._build_file_manager()
        self.plugins = self._build_plugins()
        self.model_manager = self._build_model_manager()

        # REST routes.
        # fmt: off
        self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"])
        self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"])
        self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"])
        self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"])
        self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
        self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
        self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
        self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
        self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
        self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
        # fmt: on

        # Socket.IO server mounted at /ws. It reuses the HTTP CORS origin list;
        # a wildcard is passed as the bare string "*", which python-socketio
        # treats as "allow all".
        socketio_origins = "*" if "*" in allowed_origins else allowed_origins
        global global_sio
        self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins=socketio_origins)
        self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
        self.app.mount("/ws", self.combined_asgi_app)
        global_sio = self.sio
        # `global global_sio` above rebinds only THIS module's name. IOPaint's own
        # progress callback lives in iopaint.api and reads that module's
        # global_sio, so publish the server there as well.
        # NOTE(review): verify against the installed iopaint version.
        iopaint_api.global_sio = self.sio

        logger.info(f"API初始化完成,注册了{len(app.routes)}个路由")
# Build the frontend-less API; constructing it registers routes and mounts on `app`.
api = ApiNoFrontend(app, config)

# Start the server only when run as a script, not when this module is imported.
if __name__ == "__main__":
    uvicorn.run(app=app, host=host, port=port)