Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,6 +5,7 @@ import numpy as np
|
|
| 5 |
import onnxruntime as rt
|
| 6 |
import pandas as pd
|
| 7 |
from PIL import Image
|
|
|
|
| 8 |
|
| 9 |
# 模型配置
|
| 10 |
MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3" # 默认模型
|
|
@@ -12,6 +13,11 @@ MODEL_FILENAME = "model.onnx"
|
|
| 12 |
LABEL_FILENAME = "selected_tags.csv"
|
| 13 |
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
# 标签处理配置
|
| 16 |
kaomojis = [
|
| 17 |
"0_0",
|
|
@@ -40,34 +46,45 @@ class Tagger:
|
|
| 40 |
self.model = None
|
| 41 |
self.tag_names = []
|
| 42 |
self.model_size = None
|
|
|
|
| 43 |
self._init_model()
|
| 44 |
|
| 45 |
def _init_model(self):
|
| 46 |
"""初始化模型和标签"""
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
def _preprocess(self, img):
|
| 73 |
"""图像预处理"""
|
|
|
|
| 5 |
import onnxruntime as rt
|
| 6 |
import pandas as pd
|
| 7 |
from PIL import Image
|
| 8 |
+
from huggingface_hub import login
|
| 9 |
|
| 10 |
# 模型配置
|
| 11 |
MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3" # 默认模型
|
|
|
|
| 13 |
LABEL_FILENAME = "selected_tags.csv"
|
| 14 |
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
| 15 |
|
| 16 |
+
if not os.environ.get("HF_TOKEN"):
|
| 17 |
+
print("⚠️ 警告:未检测到HF_TOKEN,部分模型可能需要认证")
|
| 18 |
+
else:
|
| 19 |
+
login(token=os.environ.get("HF_TOKEN"))
|
| 20 |
+
|
| 21 |
# 标签处理配置
|
| 22 |
kaomojis = [
|
| 23 |
"0_0",
|
|
|
|
| 46 |
self.model = None
|
| 47 |
self.tag_names = []
|
| 48 |
self.model_size = None
|
| 49 |
+
self.hf_token = os.environ.get("HF_TOKEN", "") # 从环境变量获取
|
| 50 |
self._init_model()
|
| 51 |
|
| 52 |
def _init_model(self):
|
| 53 |
"""初始化模型和标签"""
|
| 54 |
+
try:
|
| 55 |
+
label_path = huggingface_hub.hf_hub_download(
|
| 56 |
+
MODEL_REPO,
|
| 57 |
+
LABEL_FILENAME,
|
| 58 |
+
token=self.hf_token
|
| 59 |
+
)
|
| 60 |
+
model_path = huggingface_hub.hf_hub_download(
|
| 61 |
+
MODEL_REPO,
|
| 62 |
+
MODEL_FILENAME,
|
| 63 |
+
token=self.hf_token
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
# 加载标签
|
| 67 |
+
tags_df = pd.read_csv(label_path)
|
| 68 |
+
self.tag_names = tags_df["name"].tolist()
|
| 69 |
+
self.categories = {
|
| 70 |
+
"rating": np.where(tags_df["category"] == 9)[0],
|
| 71 |
+
"general": np.where(tags_df["category"] == 0)[0],
|
| 72 |
+
"character": np.where(tags_df["category"] == 4)[0]
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
# 加载ONNX模型
|
| 76 |
+
self.model = rt.InferenceSession(model_path)
|
| 77 |
+
self.model_size = self.model.get_inputs()[0].shape[1]
|
| 78 |
+
except huggingface_hub.utils.HfHubHTTPError as e:
|
| 79 |
+
if "401" in str(e):
|
| 80 |
+
raise RuntimeError(
|
| 81 |
+
"模型下载认证失败,请:\n"
|
| 82 |
+
"1. 访问https://huggingface.co/SmilingWolf/wd-swinv2-tagger-v3\n"
|
| 83 |
+
"2. 点击Agree and continue\n"
|
| 84 |
+
"3. 确保HF_TOKEN已正确设置"
|
| 85 |
+
)
|
| 86 |
+
else:
|
| 87 |
+
raise
|
| 88 |
|
| 89 |
def _preprocess(self, img):
|
| 90 |
"""图像预处理"""
|