Spaces:
Running
Running
import os | |
import gradio as gr | |
import huggingface_hub | |
import numpy as np | |
import onnxruntime as rt | |
import pandas as pd | |
from PIL import Image | |
from huggingface_hub import login | |
# Model configuration.
MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3"  # default model
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Hugging Face access token, read from the environment exactly once.
# Surrounding whitespace is stripped so a secret pasted with a trailing
# newline still works; an empty value means "no token configured".
HF_TOKEN = os.environ.get("HF_TOKEN", "").strip()
if HF_TOKEN:
    # Register the token with huggingface_hub for this environment.
    login(token=HF_TOKEN)
else:
    print("⚠️ 警告:未检测到HF_TOKEN,部分模型可能需要认证")
# Tag-processing configuration.
# Kaomoji tag names: here the underscore is part of the face itself, so these
# names should be displayed verbatim rather than having "_" turned into a space.
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.",
    "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3",
    "6_9", ">_o", "@_@", "^_^", "o_o",
    "u_u", "x_x", "|_|", "||_||",
]
class Tagger:
    """Image tagger backed by an ONNX WD-tagger model from the Hugging Face Hub.

    Creating an instance downloads ``MODEL_FILENAME`` and ``LABEL_FILENAME``
    from ``MODEL_REPO`` and loads them; ``predict`` then scores images.
    """

    def __init__(self):
        self.model = None  # onnxruntime.InferenceSession once loaded
        self.tag_names = []  # raw tag names, in label-file order
        self.categories = {}  # "rating"/"general"/"character" -> tag indices
        self.model_size = None  # edge length images are resized to
        # Tag names whose underscores are kept when formatting for display.
        self.keep_underscores = frozenset(kaomojis)
        # An unset/blank token becomes None rather than "": a str token is
        # used as-is by huggingface_hub, whereas None lets it fall back to the
        # token cached by login() or to anonymous access.
        self.hf_token = os.environ.get("HF_TOKEN", "").strip() or None
        self._init_model()

    def _init_model(self):
        """Download the label file and ONNX model, then load both.

        Raises:
            RuntimeError: if the Hub rejects a download with HTTP 401/403
                (missing authentication or gated-repo access).
            huggingface_hub.utils.HfHubHTTPError: for any other Hub HTTP error.
        """
        try:
            label_path = huggingface_hub.hf_hub_download(
                MODEL_REPO,
                LABEL_FILENAME,
                token=self.hf_token,
            )
            model_path = huggingface_hub.hf_hub_download(
                MODEL_REPO,
                MODEL_FILENAME,
                token=self.hf_token,
            )
        except huggingface_hub.utils.HfHubHTTPError as e:
            status = getattr(getattr(e, "response", None), "status_code", None)
            if status in (401, 403) or "401" in str(e):
                raise RuntimeError(
                    "模型下载认证失败,请:\n"
                    "1. 访问https://huggingface.co/SmilingWolf/wd-swinv2-tagger-v3\n"
                    "2. 点击Agree and continue\n"
                    "3. 确保HF_TOKEN已正确设置"
                ) from e
            raise

        # Load tag names and group their indices by the label file's
        # category codes: 9 = rating, 0 = general, 4 = character.
        tags_df = pd.read_csv(label_path)
        self.tag_names = tags_df["name"].tolist()
        self.categories = {
            "rating": np.where(tags_df["category"] == 9)[0],
            "general": np.where(tags_df["category"] == 0)[0],
            "character": np.where(tags_df["category"] == 4)[0],
        }

        # Load the ONNX model; dimension 1 of its first input is used as the
        # square edge length (assumes an NHWC input layout).
        self.model = rt.InferenceSession(model_path)
        self.model_size = self.model.get_inputs()[0].shape[1]

    def _preprocess(self, img):
        """Convert a PIL image into the model's input array.

        The image is flattened onto white if it carries transparency, padded
        to a centred white square, resized to ``model_size`` with bicubic
        filtering, and returned as a float32 ``(H, W, 3)`` array in BGR
        channel order with raw 0-255 values (no normalisation).
        """
        if img.mode in ("RGBA", "LA", "PA") or (
            img.mode == "P" and "transparency" in img.info
        ):
            # A plain convert("RGB") just drops alpha, leaving transparent
            # pixels with whatever colour they happen to store. Composite onto
            # the same white used for padding instead.
            rgba = img.convert("RGBA")
            canvas = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            img = Image.alpha_composite(canvas, rgba).convert("RGB")
        elif img.mode != "RGB":
            img = img.convert("RGB")

        # Pad to a square so resizing does not distort the aspect ratio.
        size = max(img.size)
        padded = Image.new("RGB", (size, size), (255, 255, 255))
        padded.paste(img, ((size - img.width) // 2, (size - img.height) // 2))

        if size != self.model_size:
            padded = padded.resize((self.model_size, self.model_size), Image.BICUBIC)

        # Reverse the channel axis: RGB -> BGR.
        return np.array(padded)[:, :, ::-1].astype(np.float32)

    @staticmethod
    def _format_tag(name, keep_underscores=frozenset()):
        """Return ``name`` for display.

        Underscores become spaces unless ``name`` is in ``keep_underscores``
        (e.g. kaomoji such as ``^_^``, where the underscore is the mouth).
        """
        return name if name in keep_underscores else name.replace("_", " ")

    def _collect(self, scores, indices, threshold=None):
        """Map display names to float scores for the tags at ``indices``.

        With ``threshold=None`` every tag is kept, in label-file order.
        Otherwise only tags scoring strictly above ``threshold`` are kept,
        ordered by descending score.
        """
        pairs = [
            (self._format_tag(self.tag_names[i], self.keep_underscores), float(scores[i]))
            for i in indices
            if threshold is None or scores[i] > threshold
        ]
        if threshold is not None:
            pairs.sort(key=lambda pair: pair[1], reverse=True)
        return dict(pairs)

    def predict(self, img, general_thresh=0.35, character_thresh=0.85):
        """Tag ``img`` and return scores grouped by category.

        Args:
            img: PIL image to tag.
            general_thresh: keep general tags scoring strictly above this.
            character_thresh: keep character tags scoring strictly above this.

        Returns:
            dict with keys ``"ratings"`` (every rating tag, label-file order),
            ``"general"`` and ``"characters"`` (tags above their threshold,
            highest score first); each maps display name -> float score.
        """
        batch = self._preprocess(img)[np.newaxis]  # add batch dimension
        input_name = self.model.get_inputs()[0].name
        scores = self.model.run(None, {input_name: batch})[0][0]

        return {
            "ratings": self._collect(scores, self.categories["rating"]),
            "general": self._collect(scores, self.categories["general"], general_thresh),
            "characters": self._collect(
                scores, self.categories["character"], character_thresh
            ),
        }
# Build the Gradio interface.

_tagger = None


def _get_tagger():
    """Return the shared Tagger, creating it on first use.

    Constructing a Tagger downloads the model files and builds an ONNX
    session, so it is done once and reused across requests. If construction
    raises, nothing is cached and the next call retries.
    """
    global _tagger
    if _tagger is None:
        _tagger = Tagger()
    return _tagger


with gr.Blocks(theme=gr.themes.Soft(), title="AI图像标签分析器") as demo:
    gr.Markdown("# 🖼️ AI图像标签分析器")
    gr.Markdown("上传图片自动分析图像内容标签")

    with gr.Row():
        with gr.Column(scale=1):
            img_input = gr.Image(type="pil", label="上传图片")
            with gr.Accordion("高级设置", open=False):
                general_slider = gr.Slider(
                    0, 1, 0.35,
                    label="通用标签阈值",
                    info="值越高标签越少但更准确",
                )
                char_slider = gr.Slider(
                    0, 1, 0.85,
                    label="角色标签阈值",
                    info="推荐保持较高阈值",
                )
            analyze_btn = gr.Button("开始分析", variant="primary")

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签"):
                    general_tags = gr.Label(label="检测到的通用标签")
                with gr.TabItem("👤 角色标签"):
                    char_tags = gr.Label(label="检测到的角色标签")
                with gr.TabItem("⭐ 评分标签"):
                    rating_tags = gr.Label(label="图像评级标签")
            output_text = gr.Textbox(
                label="标签文本",
                placeholder="生成的标签文本将显示在这里...",
            )

    def process_image(img, gen_thresh, char_thresh):
        """Tag the uploaded image and route the results to the output widgets.

        Raises:
            gr.Error: if no image has been uploaded.
        """
        if img is None:
            # Clicking before uploading passes None; report it in the UI
            # instead of crashing inside the tagger.
            raise gr.Error("请先上传图片")

        results = _get_tagger().predict(img, gen_thresh, char_thresh)

        # General tags first, then characters. Joining a single list avoids a
        # leading ", " when no general tag passes its threshold.
        tag_text = ", ".join([*results["general"], *results["characters"]])

        return {
            general_tags: results["general"],
            char_tags: results["characters"],
            rating_tags: results["ratings"],
            output_text: tag_text,
        }

    analyze_btn.click(
        process_image,
        inputs=[img_input, general_slider, char_slider],
        outputs=[general_tags, char_tags, rating_tags, output_text],
    )
# Start the app when run as a script: listen on all interfaces, port 7860.
if __name__ == "__main__":
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)