"""Gradio app that tags images with an ONNX tagger model from the Hugging Face Hub."""

import os

import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from huggingface_hub import login
from PIL import Image

# Model configuration.
MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3"  # default model
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Log in once at import time when a token is configured; otherwise only warn.
if not HF_TOKEN:
    print("⚠️ 警告:未检测到HF_TOKEN,部分模型可能需要认证")
else:
    login(token=HF_TOKEN)

# Tag-name handling: these tags are emoticons whose underscores are part of
# the tag itself, so they must not be turned into spaces for display.
# NOTE(review): the bare "_" entry looks like a mangled "<o>_<o>" (angle
# brackets stripped somewhere upstream) — confirm against the tag list.
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "_",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]
_KAOMOJI_SET = frozenset(kaomojis)


def _display_tag(name):
    """Return the human-readable form of a raw tag name.

    Underscores become spaces, except for kaomoji tags, which are returned
    unchanged.
    """
    if name in _KAOMOJI_SET:
        return name
    return name.replace("_", " ")


class Tagger:
    """Loads the ONNX tagger model and its label table, and scores images.

    Construction downloads (or reuses the local cache of) the model and label
    files, so it may perform network I/O and may raise.
    """

    def __init__(self):
        self.model = None
        self.tag_names = []
        self.categories = {}
        self.model_size = None
        # Read from the environment. An unset or empty variable becomes None
        # so that no empty-string token is handed to huggingface_hub.
        self.hf_token = os.environ.get("HF_TOKEN") or None
        self._init_model()

    def _init_model(self):
        """Download the label table and model, then build the ONNX session.

        Raises:
            RuntimeError: if the Hub rejects the download with HTTP 401/403.
            huggingface_hub.utils.HfHubHTTPError: for any other Hub HTTP error.
        """
        try:
            label_path = huggingface_hub.hf_hub_download(
                MODEL_REPO, LABEL_FILENAME, token=self.hf_token
            )
            model_path = huggingface_hub.hf_hub_download(
                MODEL_REPO, MODEL_FILENAME, token=self.hf_token
            )
        except huggingface_hub.utils.HfHubHTTPError as e:
            status = getattr(getattr(e, "response", None), "status_code", None)
            # Prefer the real status code; fall back to the message text only
            # when no response object is attached to the exception.
            auth_failed = status in (401, 403) or (
                status is None and "401" in str(e)
            )
            if not auth_failed:
                raise
            raise RuntimeError(
                "模型下载认证失败,请:\n"
                f"1. 访问https://huggingface.co/{MODEL_REPO}\n"
                "2. 点击Agree and continue\n"
                "3. 确保HF_TOKEN已正确设置"
            ) from e

        # Load the labels and index them by category code.
        tags_df = pd.read_csv(label_path)
        self.tag_names = tags_df["name"].tolist()
        self.categories = {
            "rating": np.where(tags_df["category"] == 9)[0],
            "general": np.where(tags_df["category"] == 0)[0],
            "character": np.where(tags_df["category"] == 4)[0],
        }

        # Load the ONNX model; the second dimension of its first input is
        # used as the square input size.
        self.model = rt.InferenceSession(model_path)
        self.model_size = self.model.get_inputs()[0].shape[1]

    def _preprocess(self, img):
        """Convert a PIL image into the float32 array fed to the model.

        The image is flattened onto white if it carries transparency,
        converted to RGB, padded to a white square, resized to the model's
        input size, and returned with its channel order reversed (RGB -> BGR).
        """
        if img.mode != "RGB":
            has_alpha = (
                img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info
            )
            if has_alpha:
                # Composite onto white so transparent regions match the white
                # padding instead of exposing whatever colour lies underneath.
                rgba = img.convert("RGBA")
                canvas = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
                canvas.alpha_composite(rgba)
                img = canvas
            img = img.convert("RGB")

        # Pad to a centred square.
        side = max(img.size)
        padded = Image.new("RGB", (side, side), (255, 255, 255))
        padded.paste(img, ((side - img.width) // 2, (side - img.height) // 2))

        # Resize to the model's input size.
        if side != self.model_size:
            padded = padded.resize(
                (self.model_size, self.model_size), Image.BICUBIC
            )

        # Reverse the channel axis (RGB -> BGR).
        return np.array(padded)[:, :, ::-1].astype(np.float32)

    def _collect(self, category, scores, threshold=None):
        """Map display tag -> score for one category, optionally thresholded."""
        picked = {}
        for idx in self.categories[category]:
            if threshold is None or scores[idx] > threshold:
                picked[_display_tag(self.tag_names[idx])] = float(scores[idx])
        return picked

    def predict(self, img, general_thresh=0.35, character_thresh=0.85):
        """Score an image and return its tags grouped by category.

        Args:
            img: PIL image to analyse.
            general_thresh: general tags must score strictly above this.
            character_thresh: character tags must score strictly above this.

        Returns:
            A dict with keys "ratings", "general" and "characters", each
            mapping a display tag to its float score. All rating tags are
            included; "general" is sorted by descending score.
        """
        # Preprocess and add the batch dimension.
        img_data = self._preprocess(img)[np.newaxis]

        # Run the model; take the first output of the single batch item.
        input_name = self.model.get_inputs()[0].name
        outputs = self.model.run(None, {input_name: img_data})[0][0]

        general = self._collect("general", outputs, general_thresh)
        return {
            "ratings": self._collect("rating", outputs),
            "general": dict(
                sorted(general.items(), key=lambda kv: kv[1], reverse=True)
            ),
            "characters": self._collect("character", outputs, character_thresh),
        }


# Shared Tagger, created on first use so importing the module does not
# download the model, and so later requests do not rebuild the ONNX session.
_tagger = None


def _get_tagger():
    """Return the shared Tagger, creating it on first call."""
    global _tagger
    if _tagger is None:
        _tagger = Tagger()
    return _tagger


# Build the Gradio interface.
with gr.Blocks(theme=gr.themes.Soft(), title="AI图像标签分析器") as demo:
    gr.Markdown("# 🖼️ AI图像标签分析器")
    gr.Markdown("上传图片自动分析图像内容标签")

    with gr.Row():
        with gr.Column(scale=1):
            img_input = gr.Image(type="pil", label="上传图片")
            with gr.Accordion("高级设置", open=False):
                general_slider = gr.Slider(
                    0, 1, 0.35,
                    label="通用标签阈值",
                    info="值越高标签越少但更准确",
                )
                char_slider = gr.Slider(
                    0, 1, 0.85,
                    label="角色标签阈值",
                    info="推荐保持较高阈值",
                )
            analyze_btn = gr.Button("开始分析", variant="primary")

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签"):
                    general_tags = gr.Label(label="检测到的通用标签")
                with gr.TabItem("👤 角色标签"):
                    char_tags = gr.Label(label="检测到的角色标签")
                with gr.TabItem("⭐ 评分标签"):
                    rating_tags = gr.Label(label="图像评级标签")
            output_text = gr.Textbox(
                label="标签文本",
                placeholder="生成的标签文本将显示在这里...",
            )

    # Click handler.
    def process_image(img, gen_thresh, char_thresh):
        """Tag the uploaded image and fill the three label views and the text box."""
        if img is None:
            # Clicking without an upload passes None; report it to the user
            # instead of failing inside preprocessing.
            raise gr.Error("请先上传图片")

        try:
            tagger = _get_tagger()
        except RuntimeError as exc:
            # Surface the authentication guidance in the UI.
            raise gr.Error(str(exc)) from exc

        results = tagger.predict(img, gen_thresh, char_thresh)

        # Join both groups from one list so an empty general group cannot
        # leave a dangling leading separator.
        tag_text = ", ".join([*results["general"], *results["characters"]])

        return {
            general_tags: results["general"],
            char_tags: results["characters"],
            rating_tags: results["ratings"],
            output_text: tag_text,
        }

    analyze_btn.click(
        process_image,
        inputs=[img_input, general_slider, char_slider],
        outputs=[general_tags, char_tags, rating_tags, output_text],
    )

# Launch the app.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)