# Hugging Face Spaces status: Running
import functools
import html
import json
import os

import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from huggingface_hub import login
from PIL import Image

from translator import translate_texts
# ------------------------------------------------------------------
# Model configuration
# ------------------------------------------------------------------
# Hub repository and the two files downloaded from it by Tagger.
MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Access token read from the environment. Whitespace is stripped so that a
# secret pasted with a trailing newline still authenticates, and a value made
# only of blanks is treated as "not set".
HF_TOKEN = os.environ.get("HF_TOKEN", "").strip()
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    print("⚠️ 未检测到 HF_TOKEN,私有模型可能下载失败")
# ------------------------------------------------------------------
# Tagger class
# ------------------------------------------------------------------
class Tagger:
    """Image tagger built on the ONNX model and label CSV from ``MODEL_REPO``.

    Constructing an instance downloads (or reuses the local cache of) the
    model and label files and opens an ONNX Runtime session, so instances
    are expensive and should be reused across predictions.
    """

    def __init__(self):
        # An empty string is not a usable token; pass None instead so that
        # huggingface_hub applies its own default token resolution.
        self.hf_token = HF_TOKEN or None
        self._load_model_and_labels()

    def _load_model_and_labels(self):
        """Download the label CSV and ONNX model and prepare lookup tables."""
        label_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, LABEL_FILENAME, token=self.hf_token
        )
        model_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, MODEL_FILENAME, token=self.hf_token
        )

        tags_df = pd.read_csv(label_path)
        self.tag_names = tags_df["name"].tolist()
        # Row indices of each tag family, keyed by the CSV "category" id.
        self.categories = {
            "rating": np.where(tags_df["category"] == 9)[0],
            "general": np.where(tags_df["category"] == 0)[0],
            "character": np.where(tags_df["category"] == 4)[0],
        }

        self.model = rt.InferenceSession(model_path)
        model_input = self.model.get_inputs()[0]
        # Looked up once here rather than on every predict() call.
        self.input_name = model_input.name
        # Dimension 1 is used as the square side length; this matches the
        # (batch, height, width, channel) array that predict() feeds in.
        self.input_size = model_input.shape[1]

    # ------------------------- preprocess -------------------------
    def _preprocess(self, img: "Image.Image") -> np.ndarray:
        """Convert a PIL image into the float32 BGR array fed to the model.

        The image is flattened onto white, padded to a centred square on a
        white canvas, resized to ``self.input_size`` and channel-reversed.

        Returns:
            Array of shape ``(input_size, input_size, 3)``, dtype float32.
        """
        has_alpha = img.mode in ("RGBA", "LA", "PA") or (
            img.mode == "P" and "transparency" in img.info
        )
        if has_alpha:
            # A plain convert("RGB") discards the alpha channel and exposes
            # whatever colour sits under transparent pixels. Composite onto
            # white instead so transparency matches the white padding below.
            rgba = img.convert("RGBA")
            backdrop = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            backdrop.alpha_composite(rgba)
            img = backdrop.convert("RGB")
        elif img.mode != "RGB":
            img = img.convert("RGB")

        side = max(img.size)
        canvas = Image.new("RGB", (side, side), (255, 255, 255))
        canvas.paste(img, ((side - img.width) // 2, (side - img.height) // 2))
        if side != self.input_size:
            canvas = canvas.resize(
                (self.input_size, self.input_size), Image.BICUBIC
            )
        return np.array(canvas)[:, :, ::-1].astype(np.float32)  # RGB -> BGR

    # --------------------------- predict --------------------------
    def _select_tags(self, scores, indices, threshold=None):
        """Map readable tag names to scores for the given label indices.

        With ``threshold=None`` every index is kept in label order. Otherwise
        only scores strictly above the threshold are kept, sorted from the
        highest score to the lowest.
        """
        picked = {
            self.tag_names[idx].replace("_", " "): float(scores[idx])
            for idx in indices
            if threshold is None or scores[idx] > threshold
        }
        if threshold is None:
            return picked
        return dict(sorted(picked.items(), key=lambda kv: kv[1], reverse=True))

    def predict(self, img: "Image.Image",
                gen_th: float = 0.35,
                char_th: float = 0.85):
        """Tag a single image.

        Args:
            img: PIL image to analyse.
            gen_th: General tags are kept when their score is above this.
            char_th: Character tags are kept when their score is above this.

        Returns:
            Dict with keys ``"ratings"``, ``"general"`` and ``"characters"``,
            each mapping tag name (underscores replaced by spaces) to a float
            score. Ratings are unfiltered; the other two are filtered and
            sorted by descending score.

        Raises:
            ValueError: If ``img`` is None.
        """
        if img is None:
            raise ValueError("img must be a PIL image, got None")

        batch = self._preprocess(img)[None, ...]  # add the batch dimension
        scores = self.model.run(None, {self.input_name: batch})[0][0]

        return {
            "ratings": self._select_tags(scores, self.categories["rating"]),
            "general": self._select_tags(
                scores, self.categories["general"], gen_th
            ),
            "characters": self._select_tags(
                scores, self.categories["character"], char_th
            ),
        }
# ------------------------------------------------------------------
# Gradio UI
# ------------------------------------------------------------------
# Stylesheet passed to gr.Blocks; the class names are the ones emitted by
# format_tags_html and the one attached to the analyse button.
custom_css = """
.label-container {
    max-height: 300px;
    overflow-y: auto;
    border: 1px solid #ddd;
    padding: 10px;
    border-radius: 5px;
    background-color: #f9f9f9;
}
.tag-item {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin: 2px 0;
    padding: 2px 5px;
    border-radius: 3px;
    background-color: #fff;
}
.tag-en {
    font-weight: bold;
    color: #333;
}
.tag-zh {
    color: #666;
    margin-left: 10px;
}
.tag-score {
    color: #999;
    font-size: 0.9em;
}
.btn-container {
    margin-top: 20px;
}
"""
# NOTE(review): the original indentation of this layout was lost; the nesting
# below is reconstructed from the order of the statements — confirm that the
# summary settings and the button belong to the left column as shown.
with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css) as demo:
    gr.Markdown("# 🖼️ AI 图像标签分析器")
    gr.Markdown("上传图片自动识别标签,并可一键翻译成中文")

    with gr.Row():
        # Left column: image input, thresholds and summary options.
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", label="上传图片")

            with gr.Accordion("⚙️ 高级设置", open=True):
                gen_slider = gr.Slider(0, 1, 0.35,
                                       label="通用标签阈值", info="越高→标签更少更准")
                char_slider = gr.Slider(0, 1, 0.85,
                                        label="角色标签阈值", info="推荐保持较高阈值")
                show_zh = gr.Checkbox(True, label="显示中文翻译")

            gr.Markdown("### 汇总设置")
            with gr.Row():
                sum_general = gr.Checkbox(True, label="通用标签")
                sum_char = gr.Checkbox(True, label="角色标签")
                sum_rating = gr.Checkbox(False, label="评分标签")
                sum_sep = gr.Dropdown(["逗号", "换行", "空格"], value="逗号", label="分隔符")

            btn = gr.Button("开始分析", variant="primary", elem_classes=["btn-container"])
            # Hidden status line; the callback shows it while working and on errors.
            processing_info = gr.Markdown("", visible=False)

        # Right column: one tab per tag family plus the text summary.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签"):
                    out_general = gr.HTML(label="General Tags")
                with gr.TabItem("👤 角色标签"):
                    out_char = gr.HTML(label="Character Tags")
                with gr.TabItem("⭐ 评分标签"):
                    out_rating = gr.HTML(label="Rating Tags")

            gr.Markdown("### 标签汇总")
            out_summary = gr.Textbox(label="标签汇总",
                                     placeholder="选择需要汇总的标签类别...",
                                     lines=3)
# ----------------- Callback helpers -----------------
def format_tags_html(tags_dict, translations, show_translation=True):
    """Render a tag -> score mapping as an HTML fragment.

    Args:
        tags_dict: Mapping of tag name to score; its iteration order is the
            display order.
        translations: Translated names aligned by position with
            ``tags_dict``. It may be shorter (or None); tags without a
            matching entry are rendered without a translation.
        show_translation: When false, translations are never rendered.

    Returns:
        HTML string: a "no tags" paragraph for an empty mapping, otherwise a
        ``label-container`` div holding one ``tag-item`` row per tag with the
        score formatted to three decimals.
    """
    if not tags_dict:
        return "<p>暂无标签</p>"

    translations = translations or []
    rows = []
    for i, (tag, score) in enumerate(tags_dict.items()):
        # Tag names (e.g. ">_<") and translator output may contain HTML
        # metacharacters, so escape both before embedding them in markup.
        cells = [f'<span class="tag-en">{html.escape(str(tag))}</span>']
        if show_translation and i < len(translations):
            zh = html.escape(str(translations[i]))
            cells.append(f'<span class="tag-zh">({zh})</span>')
        rows.append(
            '<div class="tag-item">'
            f'<div>{"".join(cells)}</div>'
            f'<span class="tag-score">{score:.3f}</span>'
            '</div>'
        )
    return '<div class="label-container">' + "".join(rows) + '</div>'
@functools.lru_cache(maxsize=1)
def _get_tagger():
    """Return the shared Tagger, building it on first use.

    Building a Tagger opens a new inference session, so the instance is
    reused across requests. A failed construction is not cached, so the next
    request retries.
    """
    return Tagger()


def _annotate_tags(tags, translations):
    """Return ``tag(translation)`` strings, or the bare tag when the
    translation at that position is missing or empty."""
    annotated = []
    for i, tag in enumerate(tags):
        zh = translations[i] if i < len(translations) else ""
        annotated.append(f"{tag}({zh})" if zh else tag)
    return annotated


def process(img, g_th, c_th, show_zh, sum_gen, sum_char, sum_rat, sep_type):
    """Gradio callback: tag the image, optionally translate, build outputs.

    This is a generator. It first yields a "busy" state, then yields either
    the final result or an error state. Every yielded tuple has six items:
    button update, status-line update, general-tag HTML, character-tag HTML,
    rating-tag HTML and the summary text.

    Args:
        img: Uploaded PIL image, or None when nothing was uploaded.
        g_th: Score threshold for general tags.
        c_th: Score threshold for character tags.
        show_zh: Whether to translate tags and show the translations.
        sum_gen: Include general tags in the summary.
        sum_char: Include character tags in the summary.
        sum_rat: Include rating tags in the summary.
        sep_type: Separator choice ("逗号", "换行" or "空格"); anything else
            falls back to a comma.
    """
    # Busy state: lock the button, show the status line, clear old results.
    yield (
        gr.update(interactive=False, value="处理中..."),
        gr.update(visible=True, value="🔄 正在分析图像..."),
        "", "", "", "",
    )
    try:
        if img is None:
            raise ValueError("请先上传图片")

        res = _get_tagger().predict(img, g_th, c_th)

        # Tag names per family, in the order the HTML and summary show them.
        tag_categories = {
            key: list(res[key]) for key in ("general", "characters", "ratings")
        }
        translations_dict = {key: [] for key in tag_categories}

        if show_zh:
            # Translate everything in one batch, then slice the flat result
            # back into per-family lists using the same iteration order.
            all_tags = [tag for tags in tag_categories.values() for tag in tags]
            if all_tags:
                translated = list(
                    translate_texts(all_tags, src_lang="auto", tgt_lang="zh")
                )
                offset = 0
                for key, tags in tag_categories.items():
                    translations_dict[key] = translated[offset:offset + len(tags)]
                    offset += len(tags)

        general_html = format_tags_html(
            res["general"], translations_dict["general"], show_zh
        )
        char_html = format_tags_html(
            res["characters"], translations_dict["characters"], show_zh
        )
        rating_html = format_tags_html(
            res["ratings"], translations_dict["ratings"], show_zh
        )

        separator = {"逗号": ", ", "换行": "\n", "空格": " "}.get(sep_type, ", ")
        sections = (
            (sum_gen, "通用标签: ", "general"),
            (sum_char, "角色标签: ", "characters"),
            (sum_rat, "评分标签: ", "ratings"),
        )
        # Every tag of an enabled family appears in the summary even when
        # fewer translations came back than tags were sent.
        summary_parts = [
            title + separator.join(
                _annotate_tags(tag_categories[key], translations_dict[key])
            )
            for enabled, title, key in sections
            if enabled and tag_categories[key]
        ]
        summary_text = (
            "\n\n".join(summary_parts) if summary_parts else "请选择要汇总的标签类别"
        )

        # Done: restore the button, hide the status line, publish results.
        yield (
            gr.update(interactive=True, value="开始分析"),
            gr.update(visible=False),
            general_html,
            char_html,
            rating_html,
            summary_text,
        )
    except Exception as e:
        # UI boundary: report any failure in the status line rather than
        # letting the event crash, and re-enable the button.
        yield (
            gr.update(interactive=True, value="开始分析"),
            gr.update(visible=True, value=f"❌ 处理失败: {str(e)}"),
            "", "", "", "",
        )
# Bind the analyse button to the streaming `process` callback. The order of
# `inputs` matches the callback's parameters and the order of `outputs`
# matches the six-item tuples it yields.
# NOTE(review): the file's indentation was lost. Gradio event listeners are
# normally registered inside the `with gr.Blocks(...)` body, so this call
# likely needs to be re-indented under that block — confirm when merging.
btn.click(
    process,
    inputs=[
        img_in,
        gen_slider,
        char_slider,
        show_zh,
        sum_general,
        sum_char,
        sum_rating,
        sum_sep,
    ],
    outputs=[
        btn,
        processing_info,
        out_general,
        out_char,
        out_rating,
        out_summary,
    ],
    show_progress=True,
)
# ------------------------------------------------------------------
# Entry point
# ------------------------------------------------------------------
if __name__ == "__main__":
    # Serve on all network interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)