File size: 7,627 Bytes
6c6eb37
 
 
 
 
 
 
c3b0ec6
 
 
6c6eb37
ae7a5a8
 
6c6eb37
ae7a5a8
 
 
6c6eb37
 
 
 
b911230
 
3372d56
 
 
 
6c6eb37
c3b0ec6
 
 
 
 
 
 
 
6c6eb37
 
 
 
3372d56
 
 
 
 
 
 
 
 
 
6c6eb37
 
 
 
 
 
 
 
 
 
ae7a5a8
 
 
4d2ad30
 
 
3372d56
6c6eb37
 
 
 
 
 
 
 
 
 
 
4d2ad30
 
3372d56
 
 
 
 
ae7a5a8
3372d56
 
 
 
 
 
 
 
4d2ad30
 
 
 
 
 
 
b911230
 
 
6c6eb37
 
c3b0ec6
 
 
 
 
 
 
 
6c6eb37
c3b0ec6
6c6eb37
 
 
c3b0ec6
 
6c6eb37
 
 
 
 
 
 
 
 
 
c3b0ec6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae7a5a8
c3b0ec6
ae7a5a8
 
 
c3b0ec6
 
 
ae7a5a8
 
 
 
c3b0ec6
ae7a5a8
c3b0ec6
 
 
 
ae7a5a8
 
 
 
 
 
 
 
 
 
c3b0ec6
ae7a5a8
 
 
 
 
 
c3b0ec6
 
ae7a5a8
c3b0ec6
 
6c6eb37
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import os
from pathlib import Path
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from loguru import logger
import socketio
import threading
import asyncio

from iopaint.api import Api, api_middleware
from iopaint.schema import ApiConfig, Device, InteractiveSegModel, RealESRGANModel
from iopaint.runtime import setup_model_dir, check_device, dump_environment_info

# Module-level handle for the Socket.IO server; assigned when the API object
# is constructed further down in this file.
global_sio = None

# Read the server configuration from environment variables.
host = os.environ.get("IOPAINT_HOST", "0.0.0.0")
port = int(os.environ.get("IOPAINT_PORT", "7860"))
# The default model is cv2 because lama could not be loaded.
model = os.environ.get("IOPAINT_MODEL", "cv2")

# Model directory; a location under /tmp is used as a fallback below.
model_dir_str = os.environ.get("IOPAINT_MODEL_DIR", "/app/models")

device_str = os.environ.get("IOPAINT_DEVICE", "cpu")

# API key; None disables key validation. The value is stripped because HTTP
# header values arrive without surrounding whitespace, so a key stored with a
# stray space or trailing newline could never match a request header.
api_key = os.environ.get("IOPAINT_API_KEY", "").strip() or None
if api_key is None:
    logger.info("未设置API密钥,禁用API密钥验证")
else:
    logger.info("已设置API密钥,启用API密钥验证")

# Comma-separated origin list. Entries are stripped so "a, b" yields usable
# origins, empty entries are dropped, and an empty result falls back to "*"
# (the same value used when the variable is unset).
allowed_origins = [
    origin.strip()
    for origin in os.environ.get("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# Prepare the model directory and runtime environment.
model_dir = Path(model_dir_str)
try:
    model_dir.mkdir(parents=True, exist_ok=True)
except OSError as e:
    logger.error(f"Failed to create model directory: {e}")
    # The configured location is unusable; fall back to a directory under /tmp.
    # A failure here is not caught and aborts start-up.
    model_dir = Path("/tmp/iopaint/models")
    model_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Using alternative model directory: {model_dir}")
else:
    logger.info(f"Successfully created model directory: {model_dir}")

device = check_device(Device(device_str))
dump_environment_info()

logger.info(f"Starting API server with model: {model} on device: {device}")
logger.info(f"Model directory: {model_dir}")
logger.info(f"Allowed origins: {allowed_origins}")

def _env_flag(name, default):
    """Return True when env var *name* (or *default* if unset) is "true", ignoring case."""
    return os.environ.get(name, default).lower() == "true"


# Create the FastAPI application and install the iopaint API middleware on it.
app = FastAPI(title="IOPaint API")
api_middleware(app)

# A single value feeds both NSFW-related config fields so they always agree.
disable_nsfw_value = _env_flag("IOPAINT_DISABLE_NSFW", "false")

# Build the API configuration, supplying every field explicitly.
config = ApiConfig(
    # Server / model basics.
    host=host,
    port=port,
    model=model,
    device=device,
    model_dir=model_dir,
    input=None,
    output_dir=None,
    mask_dir=None,
    quality=100,
    inbrowser=False,
    local_files_only=False,
    # Memory / precision switches, overridable through the environment.
    low_mem=_env_flag("IOPAINT_LOW_MEM", "true"),
    no_half=_env_flag("IOPAINT_NO_HALF", "false"),
    cpu_offload=_env_flag("IOPAINT_CPU_OFFLOAD", "false"),
    cpu_textencoder=False,
    # NSFW handling: both fields share the same value.
    disable_nsfw=disable_nsfw_value,
    disable_nsfw_checker=disable_nsfw_value,
    # Plugins: all disabled, pinned to the CPU device.
    enable_interactive_seg=False,
    interactive_seg_model=InteractiveSegModel.sam2_1_tiny,
    interactive_seg_device=Device.cpu,
    enable_remove_bg=False,
    remove_bg_device=Device.cpu,
    remove_bg_model="briaai/RMBG-1.4",  # a plain string rather than an enum member
    enable_anime_seg=False,
    enable_realesrgan=False,
    realesrgan_device=Device.cpu,
    realesrgan_model=RealESRGANModel.realesr_general_x4v3,
    enable_gfpgan=False,
    gfpgan_device=Device.cpu,
    enable_restoreformer=False,
    restoreformer_device=Device.cpu,
    # Parameters specific to the cv2 model.
    cv2_radius=5,
    cv2_flag="INPAINT_NS",
)

# Configure CORS. Origins come from the ALLOWED_ORIGINS environment variable
# (read into `allowed_origins`, default "*" = every origin) instead of being
# hard-coded, so the setting that is logged at start-up is the one enforced.
# Entries are stripped and blanks dropped here as well, falling back to "*"
# when nothing usable remains. All methods and headers stay allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[o.strip() for o in allowed_origins if o.strip()] or ["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# API key validation (installed only when a key is configured); health-check
# style paths are exempt.
# NOTE(review): "/debug" is exempt too, and that endpoint returns the route
# table and model directory without authentication - confirm this is intended.
# NOTE(review): this middleware is registered after CORSMiddleware, so it
# presumably wraps it and the 401 responses built here would not carry CORS
# headers - verify if browser clients need to read them.
if api_key:
    import hmac  # local import: only needed when key validation is enabled

    # Encode once up front; bytes are compared so compare_digest accepts
    # non-ASCII keys as well.
    _expected_api_key = api_key.encode("utf-8", "surrogatepass")
    _unauthenticated_paths = frozenset({"/", "/health", "/debug"})

    @app.middleware("http")
    async def api_key_validation(request: Request, call_next):
        """Reject requests whose X-API-Key header does not match the configured key.

        OPTIONS requests and the exempt paths are passed through unchecked.
        Any other request with a missing or wrong key gets a 401 JSON response.
        """
        if request.method == "OPTIONS" or request.url.path in _unauthenticated_paths:
            return await call_next(request)

        supplied_key = request.headers.get("X-API-Key") or ""
        # Constant-time comparison so response timing does not reveal how much
        # of a guessed key was correct. A missing header compares as "" and is
        # rejected because the configured key is never empty.
        if not hmac.compare_digest(
            supplied_key.encode("utf-8", "surrogatepass"), _expected_api_key
        ):
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid API key"}
            )
        return await call_next(request)

# Root and health-check routes, handy for verifying that the API is up.
# Reports a fixed "ok" status plus the configured model and device.
@app.get("/")
@app.get("/health")
def health_check():
    payload = dict(
        status="ok",
        message="IOPaint API服务器运行正常",
        model=model,
        device=str(device),
    )
    return payload

# Debug endpoint: lists the registered routes and a summary of the config.
@app.get("/debug")
def debug_info():
    # Not every route object has all three attributes, hence getattr with None.
    route_table = [
        {
            "path": getattr(entry, "path", None),
            "methods": getattr(entry, "methods", None),
            "name": getattr(entry, "name", None),
        }
        for entry in app.routes
    ]

    config_summary = {
        "host": config.host,
        "port": config.port,
        "model": config.model,
        "device": str(config.device),
        "model_dir": str(config.model_dir),
    }

    return {
        "status": "ok",
        "api_key_required": api_key is not None,
        "routes": route_table,
        "config": config_summary,
    }

# Subclass of Api whose __init__ is overridden so that no static (frontend)
# files are mounted.
class ApiNoFrontend(Api):
    """Variant of ``Api`` that registers the API routes and Socket.IO only."""

    def __init__(self, app, config):
        """Set up state, components, routes and Socket.IO on *app*.

        ``Api.__init__`` is intentionally not called: this method performs the
        initialisation itself and skips mounting static files.

        Args:
            app: The FastAPI application to register routes on.
            config: The ``ApiConfig`` describing the server.
        """
        # Basic state.
        self.app = app
        self.config = config
        self.router = None
        self.queue_lock = threading.Lock()

        # Build the components (builders are inherited from Api).
        self.file_manager = self._build_file_manager()
        self.plugins = self._build_plugins()
        self.model_manager = self._build_model_manager()

        # Register the API routes.
        # fmt: off
        self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"])
        self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"])
        self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"])
        self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"])
        self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
        self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
        self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
        self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
        self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
        self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
        # fmt: on

        # Set up Socket.IO and mount it under /ws.
        global global_sio
        self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
        self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
        self.app.mount("/ws", self.combined_asgi_app)
        global_sio = self.sio

        # `global global_sio` above only rebinds the variable of THIS module.
        # NOTE(review): upstream iopaint.api appears to keep its own
        # module-level `global_sio`, read by its diffusion progress callback;
        # publish the server there too so that callback does not see None.
        # Confirm against the installed iopaint version.
        import iopaint.api as _upstream_api
        _upstream_api.global_sio = self.sio

        # Log how many routes ended up registered on the app.
        logger.info(f"API初始化完成,注册了{len(app.routes)}个路由")

# Build the API with the frontend-less subclass; this runs at import time so
# the routes exist on `app` whenever this module is loaded.
api = ApiNoFrontend(app, config)

# Serve directly with uvicorn when executed as a script.
if __name__ == "__main__":
    uvicorn.run(app, port=port, host=host)