import os, json import gradio as gr import huggingface_hub, numpy as np, onnxruntime as rt, pandas as pd from PIL import Image from huggingface_hub import login from translator import translate_texts # ------------------------------------------------------------------ # 模型配置 # ------------------------------------------------------------------ MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3" MODEL_FILENAME = "model.onnx" LABEL_FILENAME = "selected_tags.csv" HF_TOKEN = os.environ.get("HF_TOKEN", "") if HF_TOKEN: login(token=HF_TOKEN) else: print("⚠️ 未检测到 HF_TOKEN,私有模型可能下载失败") # ------------------------------------------------------------------ # Tagger 类 # ------------------------------------------------------------------ class Tagger: def __init__(self): self.hf_token = HF_TOKEN self._load_model_and_labels() def _load_model_and_labels(self): label_path = huggingface_hub.hf_hub_download( MODEL_REPO, LABEL_FILENAME, token=self.hf_token ) model_path = huggingface_hub.hf_hub_download( MODEL_REPO, MODEL_FILENAME, token=self.hf_token ) tags_df = pd.read_csv(label_path) self.tag_names = tags_df["name"].tolist() self.categories = { "rating": np.where(tags_df["category"] == 9)[0], "general": np.where(tags_df["category"] == 0)[0], "character": np.where(tags_df["category"] == 4)[0], } self.model = rt.InferenceSession(model_path) self.input_size = self.model.get_inputs()[0].shape[1] # ------------------------- preprocess ------------------------- def _preprocess(self, img: Image.Image) -> np.ndarray: if img.mode != "RGB": img = img.convert("RGB") size = max(img.size) canvas = Image.new("RGB", (size, size), (255, 255, 255)) canvas.paste(img, ((size - img.width)//2, (size - img.height)//2)) if size != self.input_size: canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC) return np.array(canvas)[:, :, ::-1].astype(np.float32) # to BGR # --------------------------- predict -------------------------- def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85): inp_name = self.model.get_inputs()[0].name outputs = self.model.run(None, {inp_name: self._preprocess(img)[None, ...]})[0][0] res = {"ratings": {}, "general": {}, "characters": {}} for idx in self.categories["rating"]: res["ratings"][self.tag_names[idx].replace("_", " ")] = float(outputs[idx]) for idx in self.categories["general"]: if outputs[idx] > gen_th: res["general"][self.tag_names[idx].replace("_", " ")] = float(outputs[idx]) for idx in self.categories["character"]: if outputs[idx] > char_th: res["characters"][self.tag_names[idx].replace("_", " ")] = float(outputs[idx]) res["general"] = dict(sorted(res["general"].items(), key=lambda kv: kv[1], reverse=True)) return res # ------------------------------------------------------------------ # Gradio UI # ------------------------------------------------------------------ custom_css = """ .label-container { max-height: 300px; overflow-y: auto; border: 1px solid #ddd; padding: 10px; border-radius: 5px; background-color: #f9f9f9; } .tag-item { display: flex; justify-content: space-between; align-items: center; margin: 2px 0; padding: 2px 5px; border-radius: 3px; background-color: #fff; cursor: pointer; } .tag-item:hover { background-color: #f0f0f0; } .tag-en { font-weight: bold; color: #333; } .tag-zh { color: #666; margin-left: 10px; } .tag-score { color: #999; font-size: 0.9em; } .btn-container { margin-top: 20px; } .copy-btn { margin-top: 10px; padding: 5px 10px; background-color: #f0f0f0; border: 1px solid #ddd; border-radius: 4px; cursor: pointer; display: inline-flex; align-items: center; font-size: 0.9em; } .copy-btn:hover { background-color: #e0e0e0; } .copy-icon { margin-right: 5px; width: 16px; height: 16px; } .copied-message { display: none; color: #4CAF50; margin-left: 10px; font-size: 0.9em; } .note-text { color: #ff6b6b; font-size: 0.9em; padding: 5px; border-left: 3px solid #ff6b6b; margin-top: 15px; background-color: #fff5f5; } """ js_code = """ function setupCopyFunctions() { // 为标签项设置点击复制 document.querySelectorAll('.tag-item').forEach(item => { item.addEventListener('click', function() { const tagText = this.querySelector('.tag-en').textContent; navigator.clipboard.writeText(tagText).then(() => { // 显示临时复制成功提示 const msg = document.createElement('span'); msg.textContent = '已复制!'; msg.style.color = '#4CAF50'; msg.style.marginLeft = '5px'; msg.style.fontSize = '0.8em'; this.appendChild(msg); setTimeout(() => msg.remove(), 1000); }); }); }); // 为汇总区域的复制按钮设置功能 document.getElementById('copy-tags-btn').addEventListener('click', function() { const tagsText = document.getElementById('summary-text').value; navigator.clipboard.writeText(tagsText).then(() => { const copiedMsg = document.getElementById('copied-message'); copiedMsg.style.display = 'inline'; setTimeout(() => { copiedMsg.style.display = 'none'; }, 2000); }); }); } // 在DOM加载完成或更新后调用设置函数 function onUiUpdate() { setupCopyFunctions(); } document.addEventListener('DOMContentLoaded', onUiUpdate); """ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css, js=js_code) as demo: gr.Markdown("# 🖼️ AI 图像标签分析器") gr.Markdown("上传图片自动识别标签,并可一键翻译成中文") gr.Markdown("
暂无标签
" html = '