tom12112 commited on
Commit
ae7a5a8
·
verified ·
1 Parent(s): 6c8a214

Upload api_only.py

Browse files
Files changed (1) hide show
  1. iopaint/api_only.py +58 -5
iopaint/api_only.py CHANGED
@@ -6,10 +6,12 @@ from fastapi.middleware.cors import CORSMiddleware
6
  import uvicorn
7
  from loguru import logger
8
 
9
- from iopaint.api import Api
10
- from iopaint.schema import ApiConfig, Device, InteractiveSegModel, RealESRGANModel, RemoveBGModel
11
  from iopaint.runtime import setup_model_dir, check_device, dump_environment_info
12
- from iopaint.const import DEFAULT_MODEL_DIR
 
 
13
 
14
  # 从环境变量读取配置
15
  host = os.environ.get("IOPAINT_HOST", "0.0.0.0")
@@ -46,6 +48,9 @@ logger.info(f"Allowed origins: {allowed_origins}")
46
  # 初始化FastAPI
47
  app = FastAPI(title="IOPaint API")
48
 
 
 
 
49
  # 读取disable_nsfw环境变量,确保两个字段使用相同的值
50
  disable_nsfw_value = os.environ.get("IOPAINT_DISABLE_NSFW", "false").lower() == "true"
51
 
@@ -68,7 +73,7 @@ config = ApiConfig(
68
  interactive_seg_device=Device.cpu,
69
  enable_remove_bg=False,
70
  remove_bg_device=Device.cpu,
71
- remove_bg_model="briaai/RMBG-1.4", # 修正为字符串类型而不是枚举
72
  enable_anime_seg=False,
73
  enable_realesrgan=False,
74
  realesrgan_device=Device.cpu,
@@ -114,8 +119,56 @@ if api_key:
114
  )
115
  return await call_next(request)
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  # 初始化API
118
- api = Api(app, config)
119
 
120
  # 直接启动服务
121
  if __name__ == "__main__":
 
6
  import uvicorn
7
  from loguru import logger
8
 
9
+ from iopaint.api import Api, api_middleware
10
+ from iopaint.schema import ApiConfig, Device, InteractiveSegModel, RealESRGANModel
11
  from iopaint.runtime import setup_model_dir, check_device, dump_environment_info
12
+
13
+ # 全局变量
14
+ global_sio = None
15
 
16
  # 从环境变量读取配置
17
  host = os.environ.get("IOPAINT_HOST", "0.0.0.0")
 
48
  # 初始化FastAPI
49
  app = FastAPI(title="IOPaint API")
50
 
51
+ # 添加API中间件
52
+ api_middleware(app)
53
+
54
  # 读取disable_nsfw环境变量,确保两个字段使用相同的值
55
  disable_nsfw_value = os.environ.get("IOPAINT_DISABLE_NSFW", "false").lower() == "true"
56
 
 
73
  interactive_seg_device=Device.cpu,
74
  enable_remove_bg=False,
75
  remove_bg_device=Device.cpu,
76
+ remove_bg_model="briaai/RMBG-1.4", # 字符串类型而不是枚举
77
  enable_anime_seg=False,
78
  enable_realesrgan=False,
79
  realesrgan_device=Device.cpu,
 
119
  )
120
  return await call_next(request)
121
 
122
+ # 创建自定义API类,不尝试加载前端文件
123
+ class ApiOnly(Api):
124
+ def __init__(self, app, config):
125
+ # 初始化大部分父类功能但跳过挂载静态文件
126
+ self.app = app
127
+ self.config = config
128
+ self.router = None
129
+ self.queue_lock = None
130
+
131
+ # 初始化必要的组件
132
+ self.file_manager = self._build_file_manager()
133
+ self.plugins = self._build_plugins()
134
+ self.model_manager = self._build_model_manager()
135
+
136
+ # 注册所有API端点
137
+ # fmt: off
138
+ self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=None)
139
+ self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=None)
140
+ self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=None)
141
+ self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=None)
142
+ self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
143
+ self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
144
+ self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
145
+ self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
146
+ self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
147
+ self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
148
+ self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
149
+ self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
150
+ # fmt: on
151
+
152
+ # 设置WebSocket,但跳过挂载静态文件
153
+ self.setup_socketio()
154
+
155
+ def setup_socketio(self):
156
+ # 设置socketio但不挂载到/ws路径
157
+ import socketio
158
+ import threading
159
+ import asyncio
160
+
161
+ global global_sio
162
+ self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
163
+ self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
164
+ self.app.mount("/ws", self.combined_asgi_app)
165
+ global_sio = self.sio
166
+
167
+ # 确保线程锁存在
168
+ self.queue_lock = threading.Lock()
169
+
170
  # 初始化API
171
+ api = ApiOnly(app, config)
172
 
173
  # 直接启动服务
174
  if __name__ == "__main__":