tom12112 commited on
Commit
3372d56
·
verified ·
1 Parent(s): bc1a883

Upload api_only.py

Browse files
Files changed (1) hide show
  1. iopaint/api_only.py +31 -4
iopaint/api_only.py CHANGED
@@ -7,7 +7,7 @@ import uvicorn
7
  from loguru import logger
8
 
9
  from iopaint.api import Api
10
- from iopaint.schema import ApiConfig, Device
11
  from iopaint.runtime import setup_model_dir, check_device, dump_environment_info
12
  from iopaint.const import DEFAULT_MODEL_DIR
13
 
@@ -15,14 +15,26 @@ from iopaint.const import DEFAULT_MODEL_DIR
15
  host = os.environ.get("IOPAINT_HOST", "0.0.0.0")
16
  port = int(os.environ.get("IOPAINT_PORT", "7860"))
17
  model = os.environ.get("IOPAINT_MODEL", "lama")
18
- model_dir_str = os.environ.get("IOPAINT_MODEL_DIR", str(DEFAULT_MODEL_DIR))
 
 
 
19
  device_str = os.environ.get("IOPAINT_DEVICE", "cpu")
20
  api_key = os.environ.get("IOPAINT_API_KEY", None)
21
  allowed_origins = os.environ.get("ALLOWED_ORIGINS", "*").split(",")
22
 
23
  # 初始化目录和环境
24
  model_dir = Path(model_dir_str)
25
- model_dir.mkdir(parents=True, exist_ok=True)
 
 
 
 
 
 
 
 
 
26
  device = check_device(Device(device_str))
27
  dump_environment_info()
28
 
@@ -33,7 +45,7 @@ logger.info(f"Allowed origins: {allowed_origins}")
33
  # 初始化FastAPI
34
  app = FastAPI(title="IOPaint API")
35
 
36
- # 配置API
37
  config = ApiConfig(
38
  host=host,
39
  port=port,
@@ -46,6 +58,21 @@ config = ApiConfig(
46
  no_half=os.environ.get("IOPAINT_NO_HALF", "false").lower() == "true",
47
  cpu_offload=os.environ.get("IOPAINT_CPU_OFFLOAD", "false").lower() == "true",
48
  disable_nsfw=os.environ.get("IOPAINT_DISABLE_NSFW", "false").lower() == "true",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  )
50
 
51
  # 配置CORS
 
7
  from loguru import logger
8
 
9
  from iopaint.api import Api
10
+ from iopaint.schema import ApiConfig, Device, InteractiveSegModel, RealESRGANModel, RemoveBGModel
11
  from iopaint.runtime import setup_model_dir, check_device, dump_environment_info
12
  from iopaint.const import DEFAULT_MODEL_DIR
13
 
 
15
  host = os.environ.get("IOPAINT_HOST", "0.0.0.0")
16
  port = int(os.environ.get("IOPAINT_PORT", "7860"))
17
  model = os.environ.get("IOPAINT_MODEL", "lama")
18
+
19
+ # 修改模型目录路径,使用/app或/tmp目录
20
+ model_dir_str = os.environ.get("IOPAINT_MODEL_DIR", "/app/models")
21
+
22
  device_str = os.environ.get("IOPAINT_DEVICE", "cpu")
23
  api_key = os.environ.get("IOPAINT_API_KEY", None)
24
  allowed_origins = os.environ.get("ALLOWED_ORIGINS", "*").split(",")
25
 
26
  # 初始化目录和环境
27
  model_dir = Path(model_dir_str)
28
+ try:
29
+ model_dir.mkdir(parents=True, exist_ok=True)
30
+ logger.info(f"Successfully created model directory: {model_dir}")
31
+ except Exception as e:
32
+ logger.error(f"Failed to create model directory: {e}")
33
+ # 如果失败,尝试使用/tmp目录
34
+ model_dir = Path("/tmp/iopaint/models")
35
+ model_dir.mkdir(parents=True, exist_ok=True)
36
+ logger.info(f"Using alternative model directory: {model_dir}")
37
+
38
  device = check_device(Device(device_str))
39
  dump_environment_info()
40
 
 
45
  # 初始化FastAPI
46
  app = FastAPI(title="IOPaint API")
47
 
48
+ # 配置API,添加所有缺失的必填字段
49
  config = ApiConfig(
50
  host=host,
51
  port=port,
 
58
  no_half=os.environ.get("IOPAINT_NO_HALF", "false").lower() == "true",
59
  cpu_offload=os.environ.get("IOPAINT_CPU_OFFLOAD", "false").lower() == "true",
60
  disable_nsfw=os.environ.get("IOPAINT_DISABLE_NSFW", "false").lower() == "true",
61
+ # 添加缺失的必填字段
62
+ enable_interactive_seg=False,
63
+ interactive_seg_model=InteractiveSegModel.sam2_1_tiny,
64
+ interactive_seg_device=Device.cpu,
65
+ enable_remove_bg=False,
66
+ remove_bg_device=Device.cpu,
67
+ remove_bg_model=RemoveBGModel.briaai_rmbg_1_4,
68
+ enable_anime_seg=False,
69
+ enable_realesrgan=False,
70
+ realesrgan_device=Device.cpu,
71
+ realesrgan_model=RealESRGANModel.realesr_general_x4v3,
72
+ enable_gfpgan=False,
73
+ gfpgan_device=Device.cpu,
74
+ enable_restoreformer=False,
75
+ restoreformer_device=Device.cpu,
76
  )
77
 
78
  # 配置CORS