import os, json import gradio as gr import huggingface_hub, numpy as np, onnxruntime as rt, pandas as pd from PIL import Image from huggingface_hub import login from translator import translate_texts # ------------------------------------------------------------------ # 模型配置 # ------------------------------------------------------------------ MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3" MODEL_FILENAME = "model.onnx" LABEL_FILENAME = "selected_tags.csv" HF_TOKEN = os.environ.get("HF_TOKEN", "") if HF_TOKEN: login(token=HF_TOKEN) else: print("⚠️ 未检测到 HF_TOKEN,私有模型可能下载失败") # ------------------------------------------------------------------ # Tagger 类 # ------------------------------------------------------------------ class Tagger: def __init__(self): self.hf_token = HF_TOKEN self._load_model_and_labels() def _load_model_and_labels(self): label_path = huggingface_hub.hf_hub_download( MODEL_REPO, LABEL_FILENAME, token=self.hf_token ) model_path = huggingface_hub.hf_hub_download( MODEL_REPO, MODEL_FILENAME, token=self.hf_token ) tags_df = pd.read_csv(label_path) self.tag_names = tags_df["name"].tolist() self.categories = { "rating": np.where(tags_df["category"] == 9)[0], "general": np.where(tags_df["category"] == 0)[0], "character": np.where(tags_df["category"] == 4)[0], } self.model = rt.InferenceSession(model_path) self.input_size = self.model.get_inputs()[0].shape[1] # ------------------------- preprocess ------------------------- def _preprocess(self, img: Image.Image) -> np.ndarray: if img.mode != "RGB": img = img.convert("RGB") size = max(img.size) canvas = Image.new("RGB", (size, size), (255, 255, 255)) canvas.paste(img, ((size - img.width)//2, (size - img.height)//2)) if size != self.input_size: canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC) return np.array(canvas)[:, :, ::-1].astype(np.float32) # to BGR # --------------------------- predict -------------------------- def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85): inp_name = self.model.get_inputs()[0].name outputs = self.model.run(None, {inp_name: self._preprocess(img)[None, ...]})[0][0] res = {"ratings": {}, "general": {}, "characters": {}} for idx in self.categories["rating"]: res["ratings"][self.tag_names[idx].replace("_", " ")] = float(outputs[idx]) for idx in self.categories["general"]: if outputs[idx] > gen_th: res["general"][self.tag_names[idx].replace("_", " ")] = float(outputs[idx]) for idx in self.categories["character"]: if outputs[idx] > char_th: res["characters"][self.tag_names[idx].replace("_", " ")] = float(outputs[idx]) res["general"] = dict(sorted(res["general"].items(), key=lambda kv: kv[1], reverse=True)) return res # ------------------------------------------------------------------ # Gradio UI # ------------------------------------------------------------------ custom_css = """ .label-container { max-height: 300px; overflow-y: auto; border: 1px solid #ddd; padding: 10px; border-radius: 5px; background-color: #f9f9f9; } .tag-item { display: flex; justify-content: space-between; align-items: center; margin: 2px 0; padding: 2px 5px; border-radius: 3px; background-color: #fff; } .tag-en { font-weight: bold; color: #333; } .tag-zh { color: #666; margin-left: 10px; } .tag-score { color: #999; font-size: 0.9em; } .btn-container { margin-top: 20px; } """ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css) as demo: gr.Markdown("# 🖼️ AI 图像标签分析器") gr.Markdown("上传图片自动识别标签,并可一键翻译成中文") with gr.Row(): with gr.Column(scale=1): img_in = gr.Image(type="pil", label="上传图片") with gr.Accordion("⚙️ 高级设置", open=True): gen_slider = gr.Slider(0, 1, 0.35, label="通用标签阈值", info="越高→标签更少更准") char_slider = gr.Slider(0, 1, 0.85, label="角色标签阈值", info="推荐保持较高阈值") show_zh = gr.Checkbox(True, label="显示中文翻译") gr.Markdown("### 汇总设置") with gr.Row(): sum_general = gr.Checkbox(True, label="通用标签") sum_char = gr.Checkbox(True, label="角色标签") sum_rating = gr.Checkbox(False, label="评分标签") sum_sep = gr.Dropdown(["逗号", "换行", "空格"], value="逗号", label="分隔符") btn = gr.Button("开始分析", variant="primary", elem_classes=["btn-container"]) processing_info = gr.Markdown("", visible=False) with gr.Column(scale=2): with gr.Tabs(): with gr.TabItem("🏷️ 通用标签"): out_general = gr.HTML(label="General Tags") with gr.TabItem("👤 角色标签"): out_char = gr.HTML(label="Character Tags") with gr.TabItem("⭐ 评分标签"): out_rating = gr.HTML(label="Rating Tags") gr.Markdown("### 标签汇总") out_summary = gr.Textbox(label="标签汇总", placeholder="选择需要汇总的标签类别...", lines=3) # ----------------- 处理回调 ----------------- def format_tags_html(tags_dict, translations, show_translation=True): """格式化标签为HTML格式""" if not tags_dict: return "
暂无标签
" html = '