# Listing captured from a web file viewer (header: "Spaces", status "Running";
# file size 6,674 bytes; revision ids shown by the viewer: d5894b1, 4412065).
# The viewer's line-number gutter (1-203) is omitted.
import os
import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from PIL import Image
from huggingface_hub import login
# Model configuration
MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3"  # default model
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Log in once at import time when a token is configured; otherwise only warn.
# The token is read from the environment a single time (HF_TOKEN above) so the
# check and the login call cannot disagree.
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    print("⚠️ 警告:未检测到HF_TOKEN,部分模型可能需要认证")
# Tag-processing configuration
# Emoticon-style tag names; in these the underscore is part of the emoticon.
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.",
    "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3",
    "6_9", ">_o", "@_@", "^_^", "o_o",
    "u_u", "x_x", "|_|", "||_||",
]
class Tagger:
    """Image tagger that runs the ONNX model from ``MODEL_REPO``.

    Creating an instance fetches the label CSV and the model file through
    ``huggingface_hub`` and opens an ``onnxruntime`` inference session.
    """

    def __init__(self):
        # onnxruntime session; created in _init_model().
        self.model = None
        # Tag names from the label CSV; predict() uses the same index for a
        # name and for its score in the model output.
        self.tag_names = []
        # Group name ("rating" / "general" / "character") -> row indices.
        self.categories = {}
        # Second dimension of the model's first input shape; used as the side
        # length of the square image fed to the model.
        self.model_size = None
        # Token read from the environment.
        self.hf_token = os.environ.get("HF_TOKEN", "")
        self._init_model()

    def _init_model(self):
        """Download the label file and the model, then load both.

        Raises:
            RuntimeError: when the Hub error text contains "401"; the original
                error is attached as the cause.
            huggingface_hub.utils.HfHubHTTPError: any other Hub HTTP error is
                re-raised unchanged.
        """
        # An empty string is not a usable token; pass None instead so that
        # huggingface_hub applies its own default token handling.
        token = self.hf_token or None
        try:
            label_path = huggingface_hub.hf_hub_download(
                MODEL_REPO,
                LABEL_FILENAME,
                token=token,
            )
            model_path = huggingface_hub.hf_hub_download(
                MODEL_REPO,
                MODEL_FILENAME,
                token=token,
            )
        except huggingface_hub.utils.HfHubHTTPError as e:
            if "401" in str(e):
                raise RuntimeError(
                    "模型下载认证失败,请:\n"
                    f"1. 访问https://huggingface.co/{MODEL_REPO}\n"
                    "2. 点击Agree and continue\n"
                    "3. 确保HF_TOKEN已正确设置"
                ) from e
            raise

        # Load the labels and split row indices by the CSV "category" column.
        tags_df = pd.read_csv(label_path)
        self.tag_names = tags_df["name"].tolist()
        self.categories = {
            "rating": np.where(tags_df["category"] == 9)[0],
            "general": np.where(tags_df["category"] == 0)[0],
            "character": np.where(tags_df["category"] == 4)[0],
        }

        # Load the ONNX model.
        self.model = rt.InferenceSession(model_path)
        self.model_size = self.model.get_inputs()[0].shape[1]

    def _preprocess(self, img):
        """Turn a PIL image into the float32 array fed to the model.

        The image is flattened to RGB, centred on a white square canvas,
        resized to ``model_size`` x ``model_size`` and returned with its
        channel order reversed (RGB -> BGR).
        """
        has_alpha = img.mode in ("RGBA", "LA", "PA") or (
            img.mode == "P" and "transparency" in img.info
        )
        if has_alpha:
            # Composite onto white so transparent areas match the white
            # padding below instead of whatever colour sits under the alpha.
            rgba = img.convert("RGBA")
            background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, rgba).convert("RGB")
        elif img.mode != "RGB":
            img = img.convert("RGB")

        # Pad to a square.
        size = max(img.size)
        padded = Image.new("RGB", (size, size), (255, 255, 255))
        padded.paste(img, ((size - img.width) // 2, (size - img.height) // 2))

        # Resize to the model input size.
        if size != self.model_size:
            padded = padded.resize(
                (self.model_size, self.model_size), Image.BICUBIC
            )

        # Reverse the channel axis (RGB -> BGR).
        return np.array(padded)[:, :, ::-1].astype(np.float32)

    def _display_name(self, idx):
        """Return the tag name at ``idx`` with underscores shown as spaces.

        Names listed in ``kaomojis`` are returned untouched, because their
        underscores are part of the emoticon.
        """
        name = self.tag_names[idx]
        if name in kaomojis:
            return name
        return name.replace("_", " ")

    def _collect_results(self, outputs, general_thresh, character_thresh):
        """Group per-tag scores into rating / general / character dicts."""
        ratings = {
            self._display_name(idx): float(outputs[idx])
            for idx in self.categories["rating"]
        }
        general = {
            self._display_name(idx): float(outputs[idx])
            for idx in self.categories["general"]
            if outputs[idx] > general_thresh
        }
        characters = {
            self._display_name(idx): float(outputs[idx])
            for idx in self.categories["character"]
            if outputs[idx] > character_thresh
        }
        # Only the general tags are ordered, highest score first.
        general = dict(
            sorted(general.items(), key=lambda item: item[1], reverse=True)
        )
        return {
            "ratings": ratings,
            "general": general,
            "characters": characters,
        }

    def predict(self, img, general_thresh=0.35, character_thresh=0.85):
        """Run the model on ``img`` and return its tags.

        Args:
            img: PIL image to tag.
            general_thresh: general tags are kept when their score is
                strictly above this value.
            character_thresh: character tags are kept when their score is
                strictly above this value.

        Returns:
            A dict with keys ``"ratings"`` (every rating tag), ``"general"``
            (sorted by descending score) and ``"characters"``; each maps a
            display name to a float score.
        """
        # Preprocess and add a leading batch axis.
        img_data = self._preprocess(img)[np.newaxis]

        # Run the model; take the first output of the first batch item.
        input_name = self.model.get_inputs()[0].name
        outputs = self.model.run(None, {input_name: img_data})[0][0]

        return self._collect_results(outputs, general_thresh, character_thresh)
# Build the Gradio interface

# Single Tagger reused by every request, so the model is loaded only once.
_shared_tagger = None


def _get_tagger():
    """Return the shared Tagger, creating it on first use."""
    global _shared_tagger
    if _shared_tagger is None:
        _shared_tagger = Tagger()
    return _shared_tagger


# NOTE(review): the nesting below (button outside the accordion, textbox
# outside the tabs) is the presumed layout - confirm against the running app.
with gr.Blocks(theme=gr.themes.Soft(), title="AI图像标签分析器") as demo:
    gr.Markdown("# 🖼️ AI图像标签分析器")
    gr.Markdown("上传图片自动分析图像内容标签")

    with gr.Row():
        with gr.Column(scale=1):
            img_input = gr.Image(type="pil", label="上传图片")
            with gr.Accordion("高级设置", open=False):
                general_slider = gr.Slider(
                    0, 1, 0.35,
                    label="通用标签阈值",
                    info="值越高标签越少但更准确",
                )
                char_slider = gr.Slider(
                    0, 1, 0.85,
                    label="角色标签阈值",
                    info="推荐保持较高阈值",
                )
            analyze_btn = gr.Button("开始分析", variant="primary")

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签"):
                    general_tags = gr.Label(label="检测到的通用标签")
                with gr.TabItem("👤 角色标签"):
                    char_tags = gr.Label(label="检测到的角色标签")
                with gr.TabItem("⭐ 评分标签"):
                    rating_tags = gr.Label(label="图像评级标签")
            output_text = gr.Textbox(
                label="标签文本",
                placeholder="生成的标签文本将显示在这里...",
            )

    # Event handling
    def process_image(img, gen_thresh, char_thresh):
        """Tag ``img`` and return the value for each output component.

        Args:
            img: PIL image from the upload widget, or None when nothing
                was uploaded.
            gen_thresh: threshold passed on for general tags.
            char_thresh: threshold passed on for character tags.

        Returns:
            A dict mapping each output component to its new value.

        Raises:
            gr.Error: when no image was provided.
        """
        if img is None:
            # Report a readable message instead of failing inside predict().
            raise gr.Error("请先上传图片")

        results = _get_tagger().predict(img, gen_thresh, char_thresh)

        # One comma-separated line: general tags first, then characters.
        # A single join avoids a leading ", " when there are no general tags.
        tag_text = ", ".join([*results["general"], *results["characters"]])

        return {
            general_tags: results["general"],
            char_tags: results["characters"],
            rating_tags: results["ratings"],
            output_text: tag_text,
        }

    analyze_btn.click(
        process_image,
        inputs=[img_input, general_slider, char_slider],
        outputs=[general_tags, char_tags, rating_tags, output_text],
    )
# Launch the app when run as a script: bind to 0.0.0.0 on port 7860.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)