tom12112 commited on
Commit
bc1a883
·
verified ·
1 Parent(s): 007c288

Delete api_only.py

Browse files
Files changed (1) hide show
  1. api_only.py +0 -93
api_only.py DELETED
@@ -1,93 +0,0 @@
1
- import os
2
- from pathlib import Path
3
- from fastapi import FastAPI, Request
4
- from fastapi.responses import JSONResponse
5
- from fastapi.middleware.cors import CORSMiddleware
6
- import uvicorn
7
- from loguru import logger
8
-
9
- from iopaint.api import Api
10
- from iopaint.schema import ApiConfig, Device
11
- from iopaint.runtime import setup_model_dir, check_device, dump_environment_info
12
- from iopaint.const import DEFAULT_MODEL_DIR
13
-
14
- # 从环境变量读取配置
15
- host = os.environ.get("IOPAINT_HOST", "0.0.0.0")
16
- port = int(os.environ.get("IOPAINT_PORT", "7860"))
17
- model = os.environ.get("IOPAINT_MODEL", "lama")
18
-
19
- # 修改模型目录路径,使用/app或/tmp目录
20
- model_dir_str = os.environ.get("IOPAINT_MODEL_DIR", "/app/models")
21
-
22
- device_str = os.environ.get("IOPAINT_DEVICE", "cpu")
23
- api_key = os.environ.get("IOPAINT_API_KEY", None)
24
- allowed_origins = os.environ.get("ALLOWED_ORIGINS", "*").split(",")
25
-
26
- # 初始化目录和环境
27
- model_dir = Path(model_dir_str)
28
- try:
29
- model_dir.mkdir(parents=True, exist_ok=True)
30
- logger.info(f"Successfully created model directory: {model_dir}")
31
- except Exception as e:
32
- logger.error(f"Failed to create model directory: {e}")
33
- # 如果失败,尝试使用/tmp目录
34
- model_dir = Path("/tmp/iopaint/models")
35
- model_dir.mkdir(parents=True, exist_ok=True)
36
- logger.info(f"Using alternative model directory: {model_dir}")
37
-
38
- device = check_device(Device(device_str))
39
- dump_environment_info()
40
-
41
- logger.info(f"Starting API server with model: {model} on device: {device}")
42
- logger.info(f"Model directory: {model_dir}")
43
- logger.info(f"Allowed origins: {allowed_origins}")
44
-
45
- # 初始化FastAPI
46
- app = FastAPI(title="IOPaint API")
47
-
48
- # 配置API
49
- config = ApiConfig(
50
- host=host,
51
- port=port,
52
- model=model,
53
- device=device,
54
- model_dir=model_dir,
55
- input=None,
56
- output_dir=None,
57
- low_mem=os.environ.get("IOPAINT_LOW_MEM", "true").lower() == "true",
58
- no_half=os.environ.get("IOPAINT_NO_HALF", "false").lower() == "true",
59
- cpu_offload=os.environ.get("IOPAINT_CPU_OFFLOAD", "false").lower() == "true",
60
- disable_nsfw=os.environ.get("IOPAINT_DISABLE_NSFW", "false").lower() == "true",
61
- )
62
-
63
- # 配置CORS
64
- cors_options = {
65
- "allow_methods": ["*"],
66
- "allow_headers": ["*", "X-API-Key"],
67
- "allow_origins": allowed_origins,
68
- "allow_credentials": True,
69
- }
70
- app.add_middleware(CORSMiddleware, **cors_options)
71
-
72
- # API密钥验证(如果设置了)
73
- if api_key:
74
- @app.middleware("http")
75
- async def api_key_validation(request: Request, call_next):
76
- # 如果是预检请求(OPTIONS),直接放行
77
- if request.method == "OPTIONS":
78
- return await call_next(request)
79
-
80
- req_api_key = request.headers.get("X-API-Key")
81
- if not req_api_key or req_api_key != api_key:
82
- return JSONResponse(
83
- status_code=401,
84
- content={"detail": "Invalid API key"}
85
- )
86
- return await call_next(request)
87
-
88
- # 初始化API
89
- api = Api(app, config)
90
-
91
- # 直接启动服务
92
- if __name__ == "__main__":
93
- uvicorn.run(app, host=host, port=port)