# Image_Inversion / app.py
# IdlecloudX's picture
# Update app.py
# 1eb8a26 verified
# raw
# history blame
# 11.4 kB
import os, json
import gradio as gr
import huggingface_hub, numpy as np, onnxruntime as rt, pandas as pd
from PIL import Image
from huggingface_hub import login
from translator import translate_texts
# ------------------------------------------------------------------
# Model configuration
# ------------------------------------------------------------------
MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Hub access token read from the environment; "" when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN", "")

# Log in only when a token is available; otherwise just print a warning.
if not HF_TOKEN:
    print("⚠️ 未检测到 HF_TOKEN,私有模型可能下载失败")
else:
    login(token=HF_TOKEN)
# ------------------------------------------------------------------
# Tagger class
# ------------------------------------------------------------------
class Tagger:
    """Image tagger that runs an ONNX model fetched from the Hugging Face Hub.

    Construction downloads ``LABEL_FILENAME`` and ``MODEL_FILENAME`` from
    ``MODEL_REPO`` and opens an ONNX Runtime session on the model file.

    Attributes:
        hf_token: Token passed to ``hf_hub_download`` (``None`` when unset).
        tag_names: Tag names taken from the ``name`` column of the label CSV.
        categories: Row indices of the label CSV grouped by its ``category``
            column (9 -> "rating", 0 -> "general", 4 -> "character").
        model: The ``onnxruntime.InferenceSession`` used for inference.
        input_size: Edge length (pixels) images are resized to.
    """

    def __init__(self):
        # Store an unset token as None rather than "" so the hub client is
        # not handed a blank credential and can use its own default instead.
        self.hf_token = HF_TOKEN or None
        self._load_model_and_labels()

    def _load_model_and_labels(self):
        """Download the label CSV and the model, then build the lookup tables."""
        label_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, LABEL_FILENAME, token=self.hf_token
        )
        model_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, MODEL_FILENAME, token=self.hf_token
        )

        tags_df = pd.read_csv(label_path)
        self.tag_names = tags_df["name"].tolist()
        self.categories = {
            "rating": np.where(tags_df["category"] == 9)[0],
            "general": np.where(tags_df["category"] == 0)[0],
            "character": np.where(tags_df["category"] == 4)[0],
        }

        self.model = rt.InferenceSession(model_path)
        # NOTE(review): taking dimension 1 presumes a square, channels-last
        # (N, H, W, C) model input - confirm against the model if it changes.
        self.input_size = self.model.get_inputs()[0].shape[1]

    # ------------------------- preprocess -------------------------
    def _preprocess(self, img: "Image.Image") -> np.ndarray:
        """Convert an image into the float32 BGR array fed to the model.

        Transparency is flattened onto white, the image is centred on a white
        square canvas, resized to ``input_size`` and its channels reversed.

        Args:
            img: Source PIL image.

        Returns:
            Array of shape ``(input_size, input_size, 3)`` and dtype
            ``float32`` in BGR channel order, values left in the 0-255 range.
        """
        has_alpha = img.mode in ("RGBA", "LA") or (
            img.mode == "P" and "transparency" in img.info
        )
        if has_alpha:
            # A plain convert("RGB") only drops the alpha channel, exposing
            # whatever colour is stored under transparent pixels. Composite
            # onto white instead so it matches the padding colour used below.
            rgba = img.convert("RGBA")
            backdrop = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            backdrop.alpha_composite(rgba)
            img = backdrop.convert("RGB")
        elif img.mode != "RGB":
            img = img.convert("RGB")

        # Pad to a square so the resize below keeps the aspect ratio.
        size = max(img.size)
        canvas = Image.new("RGB", (size, size), (255, 255, 255))
        canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
        if size != self.input_size:
            canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)

        # Reverse the last axis: RGB -> BGR.
        return np.array(canvas)[:, :, ::-1].astype(np.float32)

    # --------------------------- predict --------------------------
    def predict(self, img: "Image.Image",
                gen_th: float = 0.35,
                char_th: float = 0.85):
        """Tag an image.

        Args:
            img: Image to analyse.
            gen_th: A general tag is kept when its score is strictly above this.
            char_th: A character tag is kept when its score is strictly above this.

        Returns:
            ``{"ratings": {...}, "general": {...}, "characters": {...}}`` where
            each inner dict maps a tag name (underscores replaced by spaces)
            to its score as a Python float. All rating tags are returned;
            ``general`` is sorted by descending score.
        """
        inp_name = self.model.get_inputs()[0].name
        batch = self._preprocess(img)[None, ...]  # add the batch dimension
        scores = self.model.run(None, {inp_name: batch})[0][0]

        def label(idx):
            return self.tag_names[idx].replace("_", " ")

        ratings = {label(i): float(scores[i]) for i in self.categories["rating"]}
        general = {
            label(i): float(scores[i])
            for i in self.categories["general"]
            if scores[i] > gen_th
        }
        characters = {
            label(i): float(scores[i])
            for i in self.categories["character"]
            if scores[i] > char_th
        }
        general = dict(sorted(general.items(), key=lambda kv: kv[1], reverse=True))

        return {"ratings": ratings, "general": general, "characters": characters}
# ------------------------------------------------------------------
# Gradio UI
# ------------------------------------------------------------------
# Stylesheet passed to gr.Blocks(css=...). The .label-container, .tag-item,
# .tag-en, .tag-zh and .tag-score classes style the markup produced by
# format_tags_html; .btn-container is attached to the analyse button.
custom_css = """
.label-container {
max-height: 300px;
overflow-y: auto;
border: 1px solid #ddd;
padding: 10px;
border-radius: 5px;
background-color: #f9f9f9;
}
.tag-item {
display: flex;
justify-content: space-between;
align-items: center;
margin: 2px 0;
padding: 2px 5px;
border-radius: 3px;
background-color: #fff;
}
.tag-en {
font-weight: bold;
color: #333;
}
.tag-zh {
color: #666;
margin-left: 10px;
}
.tag-score {
color: #999;
font-size: 0.9em;
}
.btn-container {
margin-top: 20px;
}
"""
with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css) as demo:
    # NOTE(review): the nesting below was reconstructed from the component
    # order because the source had lost its indentation - confirm the exact
    # container boundaries (accordion / row) against the deployed layout.
    gr.Markdown("# 🖼️ AI 图像标签分析器")
    gr.Markdown("上传图片自动识别标签,并可一键翻译成中文")

    with gr.Row():
        # Left column: image input, settings and the trigger button.
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", label="上传图片")

            with gr.Accordion("⚙️ 高级设置", open=True):
                gen_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.35,
                    label="通用标签阈值",
                    info="越高→标签更少更准",
                )
                char_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.85,
                    label="角色标签阈值",
                    info="推荐保持较高阈值",
                )
                show_zh = gr.Checkbox(value=True, label="显示中文翻译")

            gr.Markdown("### 汇总设置")
            with gr.Row():
                sum_general = gr.Checkbox(value=True, label="通用标签")
                sum_char = gr.Checkbox(value=True, label="角色标签")
                sum_rating = gr.Checkbox(value=False, label="评分标签")
                sum_sep = gr.Dropdown(
                    choices=["逗号", "换行", "空格"],
                    value="逗号",
                    label="分隔符",
                )

            btn = gr.Button("开始分析", variant="primary", elem_classes=["btn-container"])
            # Hidden status line; the callback toggles it while working.
            processing_info = gr.Markdown("", visible=False)

        # Right column: one tab per tag category plus the text summary.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签"):
                    out_general = gr.HTML(label="General Tags")
                with gr.TabItem("👤 角色标签"):
                    out_char = gr.HTML(label="Character Tags")
                with gr.TabItem("⭐ 评分标签"):
                    out_rating = gr.HTML(label="Rating Tags")

            gr.Markdown("### 标签汇总")
            out_summary = gr.Textbox(
                label="标签汇总",
                placeholder="选择需要汇总的标签类别...",
                lines=3,
            )
# ----------------- Callbacks -----------------
def format_tags_html(tags_dict, translations, show_translation=True):
    """Render tags as an HTML block with one row per tag.

    Each row shows the tag, optionally its translation in parentheses, and
    the score formatted to three decimals. Tag and translation text are
    HTML-escaped because they come from the label file and from an external
    translation call, and may contain characters such as ``<`` or ``&``.

    Args:
        tags_dict: Mapping of tag name -> score; iteration order is the
            display order.
        translations: Translations aligned by position with ``tags_dict``.
            May be shorter than ``tags_dict`` (the remaining tags are shown
            without a translation) or ``None``.
        show_translation: Whether to show translations at all.

    Returns:
        The HTML string, or a "no tags" paragraph when ``tags_dict`` is empty.
    """
    from html import escape  # local import keeps this block self-contained

    if not tags_dict:
        return "<p>暂无标签</p>"

    translations = translations or []
    rows = []
    for i, (tag, score) in enumerate(tags_dict.items()):
        zh_span = ""
        if show_translation and i < len(translations):
            zh_span = f'<span class="tag-zh">({escape(str(translations[i]))})</span>'
        rows.append(
            '<div class="tag-item">'
            f'<div><span class="tag-en">{escape(str(tag))}</span>{zh_span}</div>'
            f'<span class="tag-score">{score:.3f}</span>'
            '</div>'
        )
    return '<div class="label-container">' + "".join(rows) + '</div>'
def _get_tagger():
    """Return the shared Tagger, loading the model only on first use."""
    tagger = getattr(_get_tagger, "_instance", None)
    if tagger is None:
        tagger = Tagger()
        _get_tagger._instance = tagger
    return tagger


def process(img, g_th, c_th, show_zh, sum_gen, sum_char, sum_rat, sep_type):
    """Gradio callback: tag an image, optionally translate, and render results.

    This is a generator. It first yields a "busy" state, then yields either
    the final results or an error state. Every yielded tuple has six items:
    (button update, status update, general HTML, character HTML, rating HTML,
    summary text).

    Args:
        img: The uploaded image, or ``None`` when nothing was uploaded.
        g_th: Score threshold for general tags.
        c_th: Score threshold for character tags.
        show_zh: Whether to translate tags and show the translations.
        sum_gen: Include general tags in the summary.
        sum_char: Include character tags in the summary.
        sum_rat: Include rating tags in the summary.
        sep_type: Separator choice: "逗号", "换行" or "空格".
    """
    import traceback  # local import keeps this block self-contained

    # Lock the button and show the status line while working.
    yield (
        gr.update(interactive=False, value="处理中..."),
        gr.update(visible=True, value="🔄 正在分析图像..."),
        "", "", "", ""
    )
    try:
        if img is None:
            # Surface a readable message instead of an AttributeError.
            raise ValueError("请先上传图片")

        # Reuse one Tagger: building it reloads the ONNX model every time.
        res = _get_tagger().predict(img, g_th, c_th)

        tag_categories = {
            "general": list(res["general"]),
            "characters": list(res["characters"]),
            "ratings": list(res["ratings"]),
        }

        # Translate every tag in a single batch call, in category order.
        all_tags = []
        if show_zh:
            for tags in tag_categories.values():
                all_tags.extend(tags)
        if all_tags:
            translations = translate_texts(all_tags, src_lang="auto", tgt_lang="zh") or []
        else:
            translations = []

        # Slice the flat translation list back into per-category lists.
        translations_dict = {}
        offset = 0
        for category, tags in tag_categories.items():
            translations_dict[category] = list(translations[offset:offset + len(tags)])
            offset += len(tags)

        general_html = format_tags_html(res["general"], translations_dict["general"], show_zh)
        char_html = format_tags_html(res["characters"], translations_dict["characters"], show_zh)
        rating_html = format_tags_html(res["ratings"], translations_dict["ratings"], show_zh)

        # Build the plain-text summary.
        separators = {"逗号": ", ", "换行": "\n", "空格": " "}
        separator = separators.get(sep_type, ", ")
        sections = (
            ("通用标签: ", sum_gen, "general"),
            ("角色标签: ", sum_char, "characters"),
            ("评分标签: ", sum_rat, "ratings"),
        )
        summary_parts = []
        for title, enabled, key in sections:
            tags = tag_categories[key]
            if not (enabled and tags):
                continue
            zh = translations_dict[key]
            if show_zh and zh:
                # Keep tags that have no translation instead of dropping
                # them when the translator returned fewer items than asked.
                labels = [
                    f"{en}({zh[i]})" if i < len(zh) else en
                    for i, en in enumerate(tags)
                ]
            else:
                labels = tags
            summary_parts.append(title + separator.join(labels))
        summary_text = "\n\n".join(summary_parts) if summary_parts else "请选择要汇总的标签类别"

        # Done: unlock the button, hide the status line, show the results.
        yield (
            gr.update(interactive=True, value="开始分析"),
            gr.update(visible=False),
            general_html,
            char_html,
            rating_html,
            summary_text
        )
    except Exception as e:
        # UI boundary: report the failure to the user but keep the
        # traceback in the server log so it can be diagnosed.
        traceback.print_exc()
        yield (
            gr.update(interactive=True, value="开始分析"),
            gr.update(visible=True, value=f"❌ 处理失败: {e}"),
            "", "", "", ""
        )
# Bind the click event.
# Event listeners have to be registered while a Blocks context is active;
# re-entering `demo` here guarantees that regardless of where this statement
# sits relative to the layout block.
with demo:
    btn.click(
        process,
        inputs=[img_in, gen_slider, char_slider, show_zh, sum_general, sum_char, sum_rating, sum_sep],
        outputs=[btn, processing_info, out_general, out_char, out_rating, out_summary],
        show_progress=True
    )
# ------------------------------------------------------------------
# Launch
# ------------------------------------------------------------------
if __name__ == "__main__":
    # Serve on host 0.0.0.0, port 7860 when run as a script.
    demo.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )