Upload api_only.py
Browse files- iopaint/api_only.py +58 -5
iopaint/api_only.py
CHANGED
@@ -6,10 +6,12 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
6 |
import uvicorn
|
7 |
from loguru import logger
|
8 |
|
9 |
-
from iopaint.api import Api
|
10 |
-
from iopaint.schema import ApiConfig, Device, InteractiveSegModel, RealESRGANModel
|
11 |
from iopaint.runtime import setup_model_dir, check_device, dump_environment_info
|
12 |
-
|
|
|
|
|
13 |
|
14 |
# 从环境变量读取配置
|
15 |
host = os.environ.get("IOPAINT_HOST", "0.0.0.0")
|
@@ -46,6 +48,9 @@ logger.info(f"Allowed origins: {allowed_origins}")
|
|
46 |
# 初始化FastAPI
|
47 |
app = FastAPI(title="IOPaint API")
|
48 |
|
|
|
|
|
|
|
49 |
# 读取disable_nsfw环境变量,确保两个字段使用相同的值
|
50 |
disable_nsfw_value = os.environ.get("IOPAINT_DISABLE_NSFW", "false").lower() == "true"
|
51 |
|
@@ -68,7 +73,7 @@ config = ApiConfig(
|
|
68 |
interactive_seg_device=Device.cpu,
|
69 |
enable_remove_bg=False,
|
70 |
remove_bg_device=Device.cpu,
|
71 |
-
remove_bg_model="briaai/RMBG-1.4", #
|
72 |
enable_anime_seg=False,
|
73 |
enable_realesrgan=False,
|
74 |
realesrgan_device=Device.cpu,
|
@@ -114,8 +119,56 @@ if api_key:
|
|
114 |
)
|
115 |
return await call_next(request)
|
116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
# 初始化API
|
118 |
-
api =
|
119 |
|
120 |
# 直接启动服务
|
121 |
if __name__ == "__main__":
|
|
|
6 |
import uvicorn
|
7 |
from loguru import logger
|
8 |
|
9 |
+
from iopaint.api import Api, api_middleware
|
10 |
+
from iopaint.schema import ApiConfig, Device, InteractiveSegModel, RealESRGANModel
|
11 |
from iopaint.runtime import setup_model_dir, check_device, dump_environment_info
|
12 |
+
|
13 |
+
# 全局变量
|
14 |
+
global_sio = None
|
15 |
|
16 |
# 从环境变量读取配置
|
17 |
host = os.environ.get("IOPAINT_HOST", "0.0.0.0")
|
|
|
48 |
# 初始化FastAPI
|
49 |
app = FastAPI(title="IOPaint API")
|
50 |
|
51 |
+
# 添加API中间件
|
52 |
+
api_middleware(app)
|
53 |
+
|
54 |
# 读取disable_nsfw环境变量,确保两个字段使用相同的值
|
55 |
disable_nsfw_value = os.environ.get("IOPAINT_DISABLE_NSFW", "false").lower() == "true"
|
56 |
|
|
|
73 |
interactive_seg_device=Device.cpu,
|
74 |
enable_remove_bg=False,
|
75 |
remove_bg_device=Device.cpu,
|
76 |
+
remove_bg_model="briaai/RMBG-1.4", # 字符串类型而不是枚举
|
77 |
enable_anime_seg=False,
|
78 |
enable_realesrgan=False,
|
79 |
realesrgan_device=Device.cpu,
|
|
|
119 |
)
|
120 |
return await call_next(request)
|
121 |
|
122 |
+
# 创建自定义API类,不尝试加载前端文件
|
123 |
+
class ApiOnly(Api):
|
124 |
+
def __init__(self, app, config):
|
125 |
+
# 初始化大部分父类功能但跳过挂载静态文件
|
126 |
+
self.app = app
|
127 |
+
self.config = config
|
128 |
+
self.router = None
|
129 |
+
self.queue_lock = None
|
130 |
+
|
131 |
+
# 初始化必要的组件
|
132 |
+
self.file_manager = self._build_file_manager()
|
133 |
+
self.plugins = self._build_plugins()
|
134 |
+
self.model_manager = self._build_model_manager()
|
135 |
+
|
136 |
+
# 注册所有API端点
|
137 |
+
# fmt: off
|
138 |
+
self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=None)
|
139 |
+
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=None)
|
140 |
+
self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=None)
|
141 |
+
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=None)
|
142 |
+
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
|
143 |
+
self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
|
144 |
+
self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
|
145 |
+
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
|
146 |
+
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
|
147 |
+
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
|
148 |
+
self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
|
149 |
+
self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
|
150 |
+
# fmt: on
|
151 |
+
|
152 |
+
# 设置WebSocket,但跳过挂载静态文件
|
153 |
+
self.setup_socketio()
|
154 |
+
|
155 |
+
def setup_socketio(self):
|
156 |
+
# 设置socketio但不挂载到/ws路径
|
157 |
+
import socketio
|
158 |
+
import threading
|
159 |
+
import asyncio
|
160 |
+
|
161 |
+
global global_sio
|
162 |
+
self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
163 |
+
self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
|
164 |
+
self.app.mount("/ws", self.combined_asgi_app)
|
165 |
+
global_sio = self.sio
|
166 |
+
|
167 |
+
# 确保线程锁存在
|
168 |
+
self.queue_lock = threading.Lock()
|
169 |
+
|
170 |
# 初始化API
|
171 |
+
api = ApiOnly(app, config)
|
172 |
|
173 |
# 直接启动服务
|
174 |
if __name__ == "__main__":
|