|
import os |
|
from pathlib import Path |
|
from fastapi import FastAPI, Request |
|
from fastapi.responses import JSONResponse |
|
from fastapi.middleware.cors import CORSMiddleware |
|
import uvicorn |
|
from loguru import logger |
|
import socketio |
|
import threading |
|
import asyncio |
|
|
|
from iopaint.api import Api, api_middleware |
|
from iopaint.schema import ApiConfig, Device, InteractiveSegModel, RealESRGANModel |
|
from iopaint.runtime import setup_model_dir, check_device, dump_environment_info |
|
|
|
|
|
# Socket.IO server instance; assigned by ApiNoFrontend.__init__ once built.
global_sio = None

# ---- Settings read from the environment ----
host = os.environ.get("IOPAINT_HOST", "0.0.0.0")
port = int(os.environ.get("IOPAINT_PORT", "7860"))

# Inpainting model name passed to ApiConfig.
model = os.environ.get("IOPAINT_MODEL", "cv2")

# Requested model directory; replaced by a fallback if it cannot be created.
model_dir_str = os.environ.get("IOPAINT_MODEL_DIR", "/app/models")

# Device name; converted to a Device enum further down.
device_str = os.environ.get("IOPAINT_DEVICE", "cpu")

# Optional shared secret that clients must send in the X-API-Key header.
# Surrounding whitespace (e.g. the trailing newline of a mounted secret file)
# is stripped so the stored key matches what clients send. An empty or
# whitespace-only value disables key validation (api_key is None).
api_key = os.environ.get("IOPAINT_API_KEY", "").strip() or None
if api_key is None:
    # "API key not set; API key validation disabled"
    logger.info("未设置API密钥,禁用API密钥验证")
else:
    # "API key set; API key validation enabled"
    logger.info("已设置API密钥,启用API密钥验证")

# Comma-separated CORS origins. Entries are trimmed so "a.com, b.com" yields
# "b.com" rather than " b.com", and empty entries are dropped. If nothing
# remains, this falls back to "*" (the same as the unset default).
allowed_origins = [
    origin.strip()
    for origin in os.environ.get("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
|
|
|
|
|
# Resolve the model directory. Try the configured path first; if it cannot be
# created (e.g. a read-only or permission-restricted mount), fall back to a
# directory under /tmp. A failure on the fallback path propagates and aborts
# start-up.
model_dir = Path(model_dir_str)
try:
    model_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Successfully created model directory: {model_dir}")
except OSError as e:
    # Path.mkdir reports filesystem failures as OSError subclasses
    # (PermissionError, FileExistsError for an existing non-directory, ...).
    # Other exceptions indicate a programming error and are not masked.
    logger.error(f"Failed to create model directory: {e}")
    model_dir = Path("/tmp/iopaint/models")
    model_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Using alternative model directory: {model_dir}")

# setup_model_dir is imported from iopaint.runtime but was previously never
# called, so model_dir only reached ApiConfig.
# NOTE(review): presumably it points iopaint's model download/cache locations
# at model_dir; confirm against iopaint.runtime. It should run before the model
# manager is built (ApiNoFrontend(...) near the end of this file).
setup_model_dir(model_dir)

# Parse the device name leniently (" CUDA" -> "cuda"). This assumes Device
# values are lowercase like the "cpu" default; TODO confirm. check_device may
# substitute a different device; confirm its fallback rules in iopaint.runtime.
device = check_device(Device(device_str.strip().lower()))
dump_environment_info()

logger.info(f"Starting API server with model: {model} on device: {device}")
logger.info(f"Model directory: {model_dir}")
logger.info(f"Allowed origins: {allowed_origins}")
|
|
|
|
|
# ---- FastAPI application ----
app = FastAPI(title="IOPaint API")

# Install iopaint's shared HTTP middleware on the app. ApiNoFrontend does not
# call Api.__init__, which presumably would otherwise do this.
api_middleware(app)


def _env_flag(name, default):
    """Return True when env var *name* (or *default* if unset) equals "true", ignoring case."""
    return os.environ.get(name, default).lower() == "true"


# One NSFW switch, forwarded under both `disable_nsfw` and
# `disable_nsfw_checker` below.
disable_nsfw_value = _env_flag("IOPAINT_DISABLE_NSFW", "false")

# Server configuration consumed by ApiNoFrontend. Only the model, device and
# performance switches are environment-driven; every optional plugin is off.
# NOTE(review): disable_nsfw, model_dir, cv2_radius and cv2_flag may not be
# declared ApiConfig fields. If the model ignores unknown keys, they have no
# effect; confirm against iopaint.schema.
config = ApiConfig(
    # Network, model selection and I/O locations
    host=host,
    port=port,
    model=model,
    device=device,
    model_dir=model_dir,
    input=None,
    output_dir=None,
    # Performance switches (IOPAINT_LOW_MEM defaults to on)
    low_mem=_env_flag("IOPAINT_LOW_MEM", "true"),
    no_half=_env_flag("IOPAINT_NO_HALF", "false"),
    cpu_offload=_env_flag("IOPAINT_CPU_OFFLOAD", "false"),
    disable_nsfw=disable_nsfw_value,
    # Plugins: all disabled, with devices/models pinned to CPU defaults
    enable_interactive_seg=False,
    interactive_seg_model=InteractiveSegModel.sam2_1_tiny,
    interactive_seg_device=Device.cpu,
    enable_remove_bg=False,
    remove_bg_device=Device.cpu,
    remove_bg_model="briaai/RMBG-1.4",
    enable_anime_seg=False,
    enable_realesrgan=False,
    realesrgan_device=Device.cpu,
    realesrgan_model=RealESRGANModel.realesr_general_x4v3,
    enable_gfpgan=False,
    gfpgan_device=Device.cpu,
    enable_restoreformer=False,
    restoreformer_device=Device.cpu,
    # Miscellaneous server behaviour
    inbrowser=False,
    disable_nsfw_checker=disable_nsfw_value,
    local_files_only=False,
    cpu_textencoder=False,
    mask_dir=None,
    quality=100,
    # Parameters for the cv2 inpainting model
    cv2_radius=5,
    cv2_flag="INPAINT_NS",
)
|
|
|
|
|
# CORS policy built from ALLOWED_ORIGINS instead of a hard-coded "*".
# Entries are trimmed and empty ones dropped, so "a.com, b.com" matches
# "b.com"; if nothing remains, "*" is used.
_cors_origins = [origin.strip() for origin in allowed_origins if origin.strip()] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # The CORS spec forbids a wildcard origin on credentialed responses.
    # Allowing credentials alongside "*" would effectively let any site make
    # credentialed requests. The API key travels in a header, not a cookie, so
    # credentials are enabled only for an explicit origin allow-list.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# NOTE(review): api_middleware(app) may also install its own CORSMiddleware;
# if so, the two stack. Confirm in iopaint.api.
|
|
|
|
|
# Key validation is installed only when an API key is configured.
if api_key:
    import secrets

    # Paths reachable without a key: the health endpoints and the debug page.
    # NOTE(review): /debug returns the route list and server config; consider
    # requiring the key for it too.
    _PUBLIC_PATHS = frozenset({"/", "/health", "/debug"})

    @app.middleware("http")
    async def api_key_validation(request: Request, call_next):
        """Reject HTTP requests lacking a valid ``X-API-Key`` header with a 401.

        OPTIONS requests (CORS preflights, which carry no custom header values)
        and the paths in ``_PUBLIC_PATHS`` are passed through unchecked.
        """
        if request.method == "OPTIONS" or request.url.path in _PUBLIC_PATHS:
            return await call_next(request)

        req_api_key = request.headers.get("X-API-Key")
        # Constant-time comparison so response timing does not reveal how much
        # of the key matched. Compare bytes: compare_digest raises TypeError on
        # non-ASCII str, and header values are client-controlled.
        if not req_api_key or not secrets.compare_digest(
            req_api_key.encode("utf-8"), api_key.encode("utf-8")
        ):
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid API key"},
            )
        return await call_next(request)
|
|
|
|
|
@app.get("/")
@app.get("/health")
def health_check():
    """Liveness probe served at both ``/`` and ``/health``.

    Reports the configured model name and the resolved device as a string.
    """
    payload = dict(
        status="ok",
        # "IOPaint API server is running normally"
        message="IOPaint API服务器运行正常",
        model=model,
        device=str(device),
    )
    return payload
|
|
|
|
|
@app.get("/debug")
def debug_info():
    """Return diagnostics: every registered route, whether a key is required, and core config.

    Route attributes are read with ``getattr`` defaults because mounted
    sub-applications may lack ``methods``, so missing values appear as None.
    """
    route_summaries = [
        {
            "path": getattr(entry, "path", None),
            "methods": getattr(entry, "methods", None),
            "name": getattr(entry, "name", None),
        }
        for entry in app.routes
    ]

    config_summary = {
        "host": config.host,
        "port": config.port,
        "model": config.model,
        "device": str(config.device),
        "model_dir": str(config.model_dir),
    }

    return {
        "status": "ok",
        "api_key_required": api_key is not None,
        "routes": route_summaries,
        "config": config_summary,
    }
|
|
|
|
|
class ApiNoFrontend(Api):
    """API-only variant of iopaint's ``Api``.

    ``Api.__init__`` is intentionally not called. This constructor repeats the
    set-up steps it needs (file manager, plugins, model manager, REST routes,
    Socket.IO) and, judging by the class name, omits the bundled web frontend.
    NOTE(review): re-check against upstream ``Api.__init__`` when upgrading
    iopaint, since any new attribute it initialises will be missing here.

    Args:
        app: FastAPI application that routes and the Socket.IO mount are added to.
        config: ``ApiConfig`` presumably read by the inherited ``_build_*`` helpers.
    """

    def __init__(self, app, config):
        self.app = app
        self.config = config
        # No dedicated router: endpoints are registered via the inherited
        # add_api_route calls below.
        self.router = None
        # Presumably serialises inference requests inside Api; confirm there.
        self.queue_lock = threading.Lock()

        # Inherited builders. They presumably read self.config, which is why it
        # is assigned first.
        self.file_manager = self._build_file_manager()
        self.plugins = self._build_plugins()
        self.model_manager = self._build_model_manager()

        # REST endpoints (handlers inherited from Api).
        self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"])
        self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"])
        self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"])
        self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"])
        self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
        self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
        self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
        self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
        self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
        self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])

        # Socket.IO server (CORS open to any origin), wrapped around the app
        # and mounted at /ws.
        global global_sio
        self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
        self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
        self.app.mount("/ws", self.combined_asgi_app)
        global_sio = self.sio

        # `global global_sio` rebinds only THIS module's variable. iopaint.api
        # presumably keeps its own module-level global_sio, which Api.__init__
        # would set and which its progress-event code (e.g. a diffusion step
        # callback) reads; confirm in iopaint.api. Because Api.__init__ is
        # bypassed, publish the server there too so those emits do not hit None.
        import iopaint.api as iopaint_api

        iopaint_api.global_sio = self.sio

        # "API initialisation complete, registered N routes"
        logger.info(f"API初始化完成,注册了{len(self.app.routes)}个路由")
|
|
|
|
|
# Built at import time (outside the __main__ guard), so importing this module
# from an external ASGI server also registers the routes and /ws mount on `app`.
api = ApiNoFrontend(app, config)

if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on the configured address.
    uvicorn.run(app, port=port, host=host)