# Hugging Face Space — status: Running
import html
import json
import os
import traceback

import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from huggingface_hub import login
from PIL import Image

from translator import translate_texts
# ------------------------------------------------------------------
# Model configuration
# ------------------------------------------------------------------
MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3"  # Hub repository holding the tagger
MODEL_FILENAME = "model.onnx"                   # ONNX weights inside MODEL_REPO
LABEL_FILENAME = "selected_tags.csv"            # label table inside MODEL_REPO

# Access token taken from the environment; an empty string when it is not set.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
if not HF_TOKEN:
    # Only a warning: downloads are still attempted without a token.
    print("⚠️ 未检测到 HF_TOKEN,私有模型可能下载失败")
else:
    login(token=HF_TOKEN)
# ------------------------------------------------------------------
# Tagger class
# ------------------------------------------------------------------
class Tagger:
    """ONNX image tagger for the model stored in ``MODEL_REPO``.

    Construction downloads the label CSV and the ONNX weights and builds an
    onnxruntime session; :meth:`predict` scores a single image.
    """

    # Emoticon tags whose underscores are part of the tag itself, so they are
    # kept verbatim instead of being turned into spaces (same list as the
    # reference wd-tagger demo by SmilingWolf).
    _KAOMOJIS = frozenset({
        "0_0", "(o)_(o)", "+_+", "+_-", "._.", "<o>_<o>", "<|>_<|>", "=_=",
        ">_<", "3_3", "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x",
        "|_|", "||_||",
    })

    def __init__(self):
        # Treat an empty HF_TOKEN as "no token" so huggingface_hub falls back to
        # its default token handling instead of receiving an empty string.
        self.hf_token = HF_TOKEN or None
        self._load_model_and_labels()

    def _load_model_and_labels(self):
        """Download labels and model, then build the inference session."""
        label_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, LABEL_FILENAME, token=self.hf_token
        )
        model_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, MODEL_FILENAME, token=self.hf_token
        )
        tags_df = pd.read_csv(label_path)
        self.tag_names = tags_df["name"].tolist()
        # Row indices per category code in the label CSV:
        # 9 -> rating, 0 -> general, 4 -> character.
        self.categories = {
            "rating": np.where(tags_df["category"] == 9)[0],
            "general": np.where(tags_df["category"] == 0)[0],
            "character": np.where(tags_df["category"] == 4)[0],
        }
        self.model = rt.InferenceSession(model_path)
        # _preprocess feeds (1, size, size, 3) arrays, so dimension 1 of the
        # model input is taken as the square side length.
        self.input_size = self.model.get_inputs()[0].shape[1]

    # ------------------------- preprocess -------------------------
    def _preprocess(self, img: "Image.Image") -> np.ndarray:
        """Pad ``img`` to a white square, resize it, and return HxWx3 BGR float32."""
        if "A" in img.getbands() or "transparency" in img.info:
            # Flatten transparency onto white so transparent regions match the
            # white padding; a plain convert("RGB") just drops the alpha channel
            # and exposes whatever colour is stored under transparent pixels.
            rgba = img.convert("RGBA")
            background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            background.alpha_composite(rgba)
            img = background.convert("RGB")
        elif img.mode != "RGB":
            img = img.convert("RGB")
        size = max(img.size)
        canvas = Image.new("RGB", (size, size), (255, 255, 255))
        canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
        if size != self.input_size:
            canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
        # Reverse channels RGB -> BGR; astype also copies the reversed view.
        return np.array(canvas)[:, :, ::-1].astype(np.float32)

    @classmethod
    def _display_name(cls, raw_name):
        """Return the user-facing form of a label: underscores -> spaces, except kaomojis."""
        return raw_name if raw_name in cls._KAOMOJIS else raw_name.replace("_", " ")

    # --------------------------- predict --------------------------
    def predict(self, img: "Image.Image",
                gen_th: float = 0.35,
                char_th: float = 0.85):
        """Score ``img`` and group the tags by category.

        Args:
            img: Image to tag.
            gen_th: General tags are kept only when their score is > ``gen_th``.
            char_th: Character tags are kept only when their score is > ``char_th``.

        Returns:
            ``{"ratings": {...}, "general": {...}, "characters": {...}}`` mapping
            display name -> float score. Every rating tag is included; general and
            character tags are thresholded and sorted by descending score.

        Raises:
            ValueError: If ``img`` is None.
        """
        if img is None:
            raise ValueError("未提供图片")
        inp_name = self.model.get_inputs()[0].name
        scores = self.model.run(None, {inp_name: self._preprocess(img)[None, ...]})[0][0]

        def pick(indices, threshold=None):
            # Build {display name: score}, optionally keeping only scores above threshold.
            return {
                self._display_name(self.tag_names[idx]): float(scores[idx])
                for idx in indices
                if threshold is None or scores[idx] > threshold
            }

        def by_score(tags):
            return dict(sorted(tags.items(), key=lambda kv: kv[1], reverse=True))

        return {
            "ratings": pick(self.categories["rating"]),
            "general": by_score(pick(self.categories["general"], gen_th)),
            "characters": by_score(pick(self.categories["character"], char_th)),
        }
# ------------------------------------------------------------------
# Gradio UI
# ------------------------------------------------------------------
# Stylesheet handed to gr.Blocks(css=...). Class groups:
#   .label-container                 scrollable box (max 300px tall) around a tag list
#   .tag-item / .tag-content /
#   .tag-text / .tag-score           one clickable tag row: name left, score right
#   .copy-container / .copy-button   not referenced by the code visible in this file
#   .toast / .toast.show             fixed top-right notification; opacity 0 until
#                                    the "show" class is present
custom_css = """
.label-container {
max-height: 300px;
overflow-y: auto;
border: 1px solid #ddd;
padding: 10px;
border-radius: 5px;
background-color: #f9f9f9;
}
.tag-item {
display: flex;
justify-content: space-between;
align-items: center;
margin: 2px 0;
padding: 2px 5px;
border-radius: 3px;
background-color: #fff;
cursor: pointer;
transition: background-color 0.2s;
}
.tag-item:hover {
background-color: #e8f4ff;
}
.tag-item:active {
background-color: #bde0ff;
}
.tag-content {
display: flex;
align-items: center;
gap: 10px;
flex: 1;
}
.tag-text {
font-weight: bold;
color: #333;
}
.tag-score {
color: #999;
font-size: 0.9em;
}
.copy-container {
position: relative;
margin-bottom: 5px;
}
.copy-button {
position: absolute;
top: 5px;
right: 5px;
padding: 4px 8px;
font-size: 12px;
background-color: #f0f0f0;
border: 1px solid #ddd;
border-radius: 4px;
cursor: pointer;
transition: all 0.2s;
}
.copy-button:hover {
background-color: #e0e0e0;
}
.copy-button:active {
background-color: #d0d0d0;
}
.toast {
position: fixed;
top: 20px;
right: 20px;
padding: 10px 20px;
background-color: #4CAF50;
color: white;
border-radius: 4px;
opacity: 0;
transition: opacity 0.3s;
z-index: 1000;
}
.toast.show {
opacity: 1;
}
"""
# Page layout. The components created here are module-level names that the
# callbacks and event bindings refer to.
with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css) as demo:
    gr.Markdown("# 🖼️ AI 图像标签分析器")
    gr.Markdown("上传图片自动识别标签,并可一键翻译成中文")
    with gr.Row():
        # Left column: image upload and settings.
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", label="上传图片")
            with gr.Accordion("⚙️ 高级设置", open=True):
                # Score thresholds for general / character tags (range 0-1).
                gen_slider = gr.Slider(0, 1, 0.35,
                                       label="通用标签阈值", info="越高→标签更少更准")
                char_slider = gr.Slider(0, 1, 0.85,
                                        label="角色标签阈值", info="推荐保持较高阈值")
                # Summary options: which categories to include and the separator.
                gr.Markdown("### 汇总设置")
                with gr.Row():
                    sum_general = gr.Checkbox(True, label="通用标签")
                    sum_char = gr.Checkbox(True, label="角色标签")
                    sum_rating = gr.Checkbox(False, label="评分标签")
                sum_sep = gr.Dropdown(["逗号", "换行", "空格"], value="逗号", label="分隔符")
        # Right column: per-category results, summary box and action buttons.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签"):
                    out_general = gr.HTML(label="General Tags")
                with gr.TabItem("👤 角色标签"):
                    out_char = gr.HTML(label="Character Tags")
                with gr.TabItem("⭐ 评分标签"):
                    out_rating = gr.HTML(label="Rating Tags")
            gr.Markdown("### 标签汇总")
            with gr.Row():
                lang_btn = gr.Button("中/EN", variant="secondary", scale=0)
                copy_btn = gr.Button("📋 复制", variant="secondary", scale=0)
            out_summary = gr.Textbox(label="标签汇总",
                                     placeholder="选择需要汇总的标签类别...",
                                     lines=3,
                                     interactive=False)
            with gr.Row():
                # Status line; hidden until a run starts.
                processing_info = gr.Markdown("", visible=False)
                btn = gr.Button("开始分析", variant="primary", scale=0)
    # Hidden per-session state components.
    lang_state = gr.State("en")  # current display language; English by default
    tags_data = gr.State({})  # stores the tag data
    translations_data = gr.State({})  # stores the translation data
# ----------------- Event callbacks -----------------

# Inline onclick handler for a tag row: copies the row's data-tag value and shows
# a short confirmation. It is an attribute handler rather than a <script>
# function because <script> elements inserted via innerHTML (which is how
# gr.HTML puts markup on the page — NOTE(review): verify for the deployed Gradio
# version) are never executed. The toast is styled inline because it is appended
# to <body>, where the app's custom CSS may not apply. It must not contain double
# quotes, since it is embedded in a double-quoted attribute.
_TAG_COPY_ONCLICK = (
    "var t=this.dataset.tag;"
    "navigator.clipboard.writeText(t).then(function(){"
    "var d=document.createElement('div');"
    "d.textContent='已复制: '+t;"
    "d.style.cssText='position:fixed;top:20px;right:20px;padding:10px 20px;"
    "background-color:#4CAF50;color:white;border-radius:4px;z-index:1000';"
    "document.body.appendChild(d);"
    "setTimeout(function(){d.remove();},1500);"
    "});"
)


def format_tags_html(tags_dict, translations, category_key, current_lang):
    """Render one tag category as a list of clickable HTML rows.

    Args:
        tags_dict: Mapping of tag name -> score, rendered in iteration order.
        translations: Translations aligned by position with ``tags_dict``. Missing
            or empty entries fall back to the original tag.
        category_key: Category identifier, exposed as a ``data-category`` attribute.
        current_lang: ``"zh"`` displays translations; any other value displays tags.

    Returns:
        HTML markup, or ``"<p>暂无标签</p>"`` when ``tags_dict`` is empty. Clicking a
        row copies the original (untranslated) tag to the clipboard.
    """
    if not tags_dict:
        return "<p>暂无标签</p>"
    category_attr = html.escape(str(category_key))
    rows = []
    for i, (tag, score) in enumerate(tags_dict.items()):
        display_text = tag
        if current_lang == "zh" and i < len(translations) and translations[i]:
            display_text = translations[i]
        # Escape everything interpolated into markup: tags can contain quotes or
        # '<', and translations come from an external service.
        rows.append(
            f'<div class="tag-item" data-category="{category_attr}" '
            f'data-tag="{html.escape(str(tag))}" onclick="{_TAG_COPY_ONCLICK}">'
            '<div class="tag-content">'
            f'<span class="tag-text">{html.escape(str(display_text))}</span>'
            '</div>'
            f'<span class="tag-score">{score:.3f}</span>'
            '</div>'
        )
    return '<div class="label-container">' + "".join(rows) + "</div>"
# Separator choices offered by the sum_sep dropdown, mapped to the joiner string.
_SEPARATORS = {"逗号": ", ", "换行": "\n", "空格": " "}
# Summary text shown when no tags are selected for the summary.
_EMPTY_SUMMARY_HINT = "请选择要汇总的标签类别"
# Category keys of a Tagger.predict result, in translation / rendering order.
_RESULT_KEYS = ("general", "characters", "ratings")
# Shared Tagger, created on first use by _get_tagger().
_tagger = None


def _get_tagger():
    """Return the shared Tagger, creating it on first use.

    Building a Tagger downloads files and creates an ONNX session, so it is done
    once instead of on every click. If construction raises, nothing is cached
    and the next call retries.
    """
    global _tagger
    if _tagger is None:
        _tagger = Tagger()
    return _tagger


def _translate_results(res):
    """Translate all tags in ``res`` in one batch and split the result by category.

    Returns a dict keyed by ``_RESULT_KEYS`` whose lists are aligned by position
    with the corresponding category's tags.
    """
    names = {key: list(res[key]) for key in _RESULT_KEYS}
    flat = [name for key in _RESULT_KEYS for name in names[key]]
    translated = list(translate_texts(flat, src_lang="auto", tgt_lang="zh")) if flat else []
    per_category = {}
    offset = 0
    for key in _RESULT_KEYS:
        count = len(names[key])
        per_category[key] = translated[offset:offset + count]
        offset += count
    return per_category


def _build_summary(tags, translations, lang, sum_gen, sum_char, sum_rat, sep_type):
    """Join the selected categories' tags into the summary text.

    Categories are emitted in the order characters, general, ratings. In "zh"
    mode each tag is replaced by its translation when one exists; otherwise the
    English tag is used. Returns ``_EMPTY_SUMMARY_HINT`` when nothing is selected.
    """
    separator = _SEPARATORS.get(sep_type, ", ")
    parts = []
    for enabled, key in ((sum_char, "characters"), (sum_gen, "general"), (sum_rat, "ratings")):
        if not enabled:
            continue
        translated = translations.get(key, [])
        for i, name in enumerate(tags.get(key, {})):
            if lang == "zh" and i < len(translated) and translated[i]:
                parts.append(str(translated[i]))
            else:
                parts.append(name)
    return separator.join(parts) if parts else _EMPTY_SUMMARY_HINT


def process(img, g_th, c_th, sum_gen, sum_char, sum_rat, sep_type, current_lang, prev_tags, prev_translations):
    """Analyze ``img`` and stream UI updates for the "开始分析" button.

    Yields 9-tuples matching the ``outputs`` of the analyze-button binding:
    (button, status markdown, general HTML, character HTML, rating HTML,
    summary text, language state, tag-result state, translation state).
    The first yield disables the button; the last re-enables it and carries
    either the results or an error message.
    """
    if img is None:
        # Nothing uploaded: keep the previous results and just prompt the user.
        yield (
            gr.update(interactive=True, value="开始分析"),
            gr.update(visible=True, value="⚠️ 请先上传图片"),
            gr.update(), gr.update(), gr.update(), gr.update(),
            current_lang, prev_tags, prev_translations,
        )
        return

    yield (
        gr.update(interactive=False, value="处理中..."),
        gr.update(visible=True, value="🔄 正在分析图像..."),
        "", "", "", "", current_lang, {}, {},
    )
    try:
        res = _get_tagger().predict(img, g_th, c_th)
        translations = _translate_results(res)
        general_html, char_html, rating_html = (
            format_tags_html(res[key], translations[key], key, current_lang)
            for key in _RESULT_KEYS
        )
        summary_text = _build_summary(
            res, translations, current_lang, sum_gen, sum_char, sum_rat, sep_type
        )
        yield (
            gr.update(interactive=True, value="开始分析"),
            gr.update(visible=False),
            general_html,
            char_html,
            rating_html,
            summary_text,
            current_lang,
            res,
            translations,
        )
    except Exception as e:  # UI boundary: log, then report in the status line
        traceback.print_exc()
        yield (
            gr.update(interactive=True, value="开始分析"),
            gr.update(visible=True, value=f"❌ 处理失败: {e}"),
            "", "", "", "", current_lang, {}, {},
        )


def toggle_language(current_lang, tags, translations,
                    sum_gen=True, sum_char=True, sum_rat=False, sep_type="逗号"):
    """Switch between English and Chinese display and re-render all outputs.

    The ``sum_*`` / ``sep_type`` parameters mirror the summary controls. Their
    defaults equal those controls' initial values, so 3-argument calls still work.

    Returns:
        (new language, general HTML, character HTML, rating HTML, summary text)
    """
    new_lang = "zh" if current_lang == "en" else "en"
    tags = tags or {}
    translations = translations or {}
    general_html, char_html, rating_html = (
        format_tags_html(tags.get(key, {}), translations.get(key, []), key, new_lang)
        for key in _RESULT_KEYS
    )
    # Rebuild the summary from state. A component's server-side .value is only
    # its initial value, not what the browser currently shows. Before any
    # analysis there is nothing to summarize.
    summary_text = (
        _build_summary(tags, translations, new_lang, sum_gen, sum_char, sum_rat, sep_type)
        if tags else ""
    )
    return new_lang, general_html, char_html, rating_html, summary_text


# Browser-side half of the copy button. Scripts returned into a gr.HTML are not
# executed, so the clipboard write runs as the event's ``js`` hook. Its return
# value becomes copy_summary's input. Requires Gradio >= 4 (Gradio 3 used ``_js``).
_COPY_SUMMARY_JS = f"""
(text) => {{
    if (text && text !== {json.dumps(_EMPTY_SUMMARY_HINT)}) {{
        navigator.clipboard.writeText(text);
    }}
    return text;
}}
"""


def copy_summary(text):
    """Server-side feedback for the copy button (the copy itself happens in the browser)."""
    if text and text != _EMPTY_SUMMARY_HINT:
        gr.Info("标签已复制到剪贴板")
    else:
        gr.Warning("暂无可复制的标签")


# Event wiring. Listeners must be registered inside a Blocks context, so `demo`
# is re-entered here.
with demo:
    btn.click(
        process,
        inputs=[img_in, gen_slider, char_slider, sum_general, sum_char, sum_rating,
                sum_sep, lang_state, tags_data, translations_data],
        outputs=[btn, processing_info, out_general, out_char, out_rating,
                 out_summary, lang_state, tags_data, translations_data],
        show_progress="full",
    )
    lang_btn.click(
        toggle_language,
        inputs=[lang_state, tags_data, translations_data,
                sum_general, sum_char, sum_rating, sum_sep],
        outputs=[lang_state, out_general, out_char, out_rating, out_summary],
    )
    copy_btn.click(
        copy_summary,
        inputs=[out_summary],
        outputs=None,
        js=_COPY_SUMMARY_JS,
    )
# ------------------------------------------------------------------
# Launch
# ------------------------------------------------------------------
if __name__ == "__main__":
    # Listen on every interface (0.0.0.0) at port 7860.
    launch_options = dict(server_name="0.0.0.0", server_port=7860)
    demo.launch(**launch_options)