Image_Inversion / app.py
IdlecloudX's picture
Update app.py
fcaf260 verified
raw
history blame
16.4 kB
import os, json
import gradio as gr
import huggingface_hub, numpy as np, onnxruntime as rt, pandas as pd
from PIL import Image
from huggingface_hub import login
from translator import translate_texts
# ------------------------------------------------------------------
# Model configuration
# ------------------------------------------------------------------
# Hugging Face repo and the two files downloaded from it.
MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Optional access token read from the environment; "" means "not provided".
HF_TOKEN = os.environ.get("HF_TOKEN", "")
if not HF_TOKEN:
    print("⚠️ 未检测到 HF_TOKEN,私有模型可能下载失败")
else:
    login(token=HF_TOKEN)
# ------------------------------------------------------------------
# Tagger
# ------------------------------------------------------------------
class Tagger:
    """ONNX wrapper around the WD SwinV2 v3 image tagger.

    Construction downloads ``LABEL_FILENAME`` and ``MODEL_FILENAME`` from
    ``MODEL_REPO`` and opens an onnxruntime session; :meth:`predict` maps a
    PIL image to rating / general / character tag scores.
    """

    # Kaomoji tags whose underscores are part of the emoticon itself, so they
    # must not be turned into spaces (list from SmilingWolf's reference
    # wd-tagger demo).
    _KAOMOJIS = frozenset({
        "0_0", "(o)_(o)", "+_+", "+_-", "._.", "<o>_<o>", "<|>_<|>", "=_=",
        ">_<", "3_3", "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x",
        "|_|", "||_||",
    })

    def __init__(self):
        self.hf_token = HF_TOKEN
        self._load_model_and_labels()

    def _load_model_and_labels(self):
        """Download the label CSV and ONNX model and build the session.

        Sets ``tag_names`` (raw CSV names), ``categories`` (index arrays per
        category), ``model`` (the inference session) and ``input_size``.
        """
        label_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, LABEL_FILENAME, token=self.hf_token
        )
        model_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, MODEL_FILENAME, token=self.hf_token
        )
        tags_df = pd.read_csv(label_path)
        self.tag_names = tags_df["name"].tolist()
        # Category codes in the label CSV: 9 = rating, 0 = general, 4 = character.
        self.categories = {
            "rating": np.where(tags_df["category"] == 9)[0],
            "general": np.where(tags_df["category"] == 0)[0],
            "character": np.where(tags_df["category"] == 4)[0],
        }
        self.model = rt.InferenceSession(model_path)
        # Assumes an NHWC input (batch, H, W, C) with H == W -- matches the
        # (1, size, size, 3) batch that predict() feeds; confirm for other models.
        self.input_size = self.model.get_inputs()[0].shape[1]

    # ------------------------- preprocess -------------------------
    def _preprocess(self, img: "Image.Image") -> np.ndarray:
        """Convert *img* into the model's input layout.

        Transparency is flattened onto white, the image is centred on a white
        square canvas, resized (bicubic) to ``input_size`` when needed, and
        returned as a float32 (input_size, input_size, 3) array in BGR order
        with values in 0-255.
        """
        # convert("RGB") alone just drops alpha and exposes whatever colour is
        # stored under transparent pixels; composite onto white instead so
        # transparent regions match the white padding.
        if "A" in img.getbands() or "transparency" in img.info:
            rgba = img.convert("RGBA")
            white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            img = Image.alpha_composite(white, rgba).convert("RGB")
        elif img.mode != "RGB":
            img = img.convert("RGB")

        side = max(img.size)
        canvas = Image.new("RGB", (side, side), (255, 255, 255))
        canvas.paste(img, ((side - img.width) // 2, (side - img.height) // 2))
        if side != self.input_size:
            canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
        return np.array(canvas)[:, :, ::-1].astype(np.float32)  # RGB -> BGR

    # --------------------------- predict --------------------------
    @classmethod
    def _format_tag(cls, name):
        """Return the display form of a raw tag: underscores become spaces,
        except for kaomoji where the underscore is meaningful."""
        return name if name in cls._KAOMOJIS else name.replace("_", " ")

    def _collect(self, scores, gen_th, char_th):
        """Turn the per-label score vector into the dict returned by predict.

        Ratings are always included in label order; general and character
        tags are kept only when strictly above their threshold and are both
        sorted by descending score.
        """
        def pick(category, threshold=None):
            picked = {}
            for idx in self.categories[category]:
                if threshold is None or scores[idx] > threshold:
                    picked[self._format_tag(self.tag_names[idx])] = float(scores[idx])
            return picked

        def by_score(tags):
            return dict(sorted(tags.items(), key=lambda kv: kv[1], reverse=True))

        return {
            "ratings": pick("rating"),
            "general": by_score(pick("general", gen_th)),
            "characters": by_score(pick("character", char_th)),
        }

    def predict(self, img: "Image.Image",
                gen_th: float = 0.35,
                char_th: float = 0.85) -> dict:
        """Tag *img*.

        Returns a dict with keys "ratings", "general" and "characters", each
        mapping display tag -> score (float).
        """
        inp_name = self.model.get_inputs()[0].name
        batch = self._preprocess(img)[None, ...]
        scores = self.model.run(None, {inp_name: batch})[0][0]
        return self._collect(scores, gen_th, char_th)
# ------------------------------------------------------------------
# Gradio UI
# ------------------------------------------------------------------
# Stylesheet passed to gr.Blocks(css=...). It styles the tag list markup
# emitted by format_tags_html (.label-container, .tag-item, .tag-content,
# .tag-text, .tag-score) and the transient copy notification (.toast, made
# visible by the extra .show class). The .copy-container / .copy-button rules
# are not referenced by any markup generated in the visible part of this file.
custom_css = """
.label-container {
max-height: 300px;
overflow-y: auto;
border: 1px solid #ddd;
padding: 10px;
border-radius: 5px;
background-color: #f9f9f9;
}
.tag-item {
display: flex;
justify-content: space-between;
align-items: center;
margin: 2px 0;
padding: 2px 5px;
border-radius: 3px;
background-color: #fff;
cursor: pointer;
transition: background-color 0.2s;
}
.tag-item:hover {
background-color: #e8f4ff;
}
.tag-item:active {
background-color: #bde0ff;
}
.tag-content {
display: flex;
align-items: center;
gap: 10px;
flex: 1;
}
.tag-text {
font-weight: bold;
color: #333;
}
.tag-score {
color: #999;
font-size: 0.9em;
}
.copy-container {
position: relative;
margin-bottom: 5px;
}
.copy-button {
position: absolute;
top: 5px;
right: 5px;
padding: 4px 8px;
font-size: 12px;
background-color: #f0f0f0;
border: 1px solid #ddd;
border-radius: 4px;
cursor: pointer;
transition: all 0.2s;
}
.copy-button:hover {
background-color: #e0e0e0;
}
.copy-button:active {
background-color: #d0d0d0;
}
.toast {
position: fixed;
top: 20px;
right: 20px;
padding: 10px 20px;
background-color: #4CAF50;
color: white;
border-radius: 4px;
opacity: 0;
transition: opacity 0.3s;
z-index: 1000;
}
.toast.show {
opacity: 1;
}
"""
with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css) as demo:
    # Page header.
    gr.Markdown(value="# 🖼️ AI 图像标签分析器")
    gr.Markdown(value="上传图片自动识别标签,并可一键翻译成中文")

    with gr.Row():
        # Left column: input image, thresholds and summary options.
        with gr.Column(scale=1):
            img_in = gr.Image(label="上传图片", type="pil")
            with gr.Accordion("⚙️ 高级设置", open=True):
                gen_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.35,
                    label="通用标签阈值",
                    info="越高→标签更少更准",
                )
                char_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.85,
                    label="角色标签阈值",
                    info="推荐保持较高阈值",
                )
            gr.Markdown(value="### 汇总设置")
            with gr.Row():
                sum_general = gr.Checkbox(value=True, label="通用标签")
                sum_char = gr.Checkbox(value=True, label="角色标签")
                sum_rating = gr.Checkbox(value=False, label="评分标签")
            sum_sep = gr.Dropdown(
                choices=["逗号", "换行", "空格"],
                value="逗号",
                label="分隔符",
            )

        # Right column: per-category tag panels and the summary box.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签"):
                    out_general = gr.HTML(label="General Tags")
                with gr.TabItem("👤 角色标签"):
                    out_char = gr.HTML(label="Character Tags")
                with gr.TabItem("⭐ 评分标签"):
                    out_rating = gr.HTML(label="Rating Tags")
            gr.Markdown(value="### 标签汇总")
            with gr.Row():
                lang_btn = gr.Button("中/EN", variant="secondary", scale=0)
                copy_btn = gr.Button("📋 复制", variant="secondary", scale=0)
            out_summary = gr.Textbox(
                label="标签汇总",
                placeholder="选择需要汇总的标签类别...",
                lines=3,
                interactive=False,
            )

    # Status line and the main action button.
    with gr.Row():
        processing_info = gr.Markdown(value="", visible=False)
        btn = gr.Button("开始分析", variant="primary", scale=0)

    # Per-session hidden state.
    lang_state = gr.State(value="en")        # display language, English first
    tags_data = gr.State(value={})           # last predict() result
    translations_data = gr.State(value={})   # translations per category
# ----------------- Callbacks -----------------

# Inline click handler for a tag row. The raw tag is read from the row's
# data-tag attribute rather than spliced into JS source, so quotes in tags
# cannot break it. Gradio's HTML component typically inserts its value as
# markup, and browsers do not execute <script> elements inserted that way, so
# copyToClipboard (defined in _COPY_SCRIPT) may not exist -- fall back to the
# Clipboard API directly. Contains no double quotes, so it is safe inside a
# double-quoted attribute.
_TAG_ONCLICK = (
    "var t=this.dataset.tag;"
    "if(window.copyToClipboard){copyToClipboard(t,this.dataset.itemId);}"
    "else if(navigator.clipboard){navigator.clipboard.writeText(t);}"
)

# Helper functions appended to every rendered panel (used when they do run).
_COPY_SCRIPT = '''
<script>
function copyToClipboard(text, itemId) {
    navigator.clipboard.writeText(text).then(function() {
        showToast('已复制: ' + text);
    });
}
function showToast(message) {
    var toast = document.createElement('div');
    toast.className = 'toast show';
    toast.textContent = message;
    document.body.appendChild(toast);
    setTimeout(function() {
        toast.classList.remove('show');
        setTimeout(function() {
            document.body.removeChild(toast);
        }, 300);
    }, 1500);
}
</script>
'''


def format_tags_html(tags_dict, translations, category_key, current_lang):
    """Render one tag category as a list of clickable HTML rows.

    Args:
        tags_dict: mapping tag -> score, rendered in its iteration order.
        translations: translated labels aligned by position with tags_dict.
        category_key: prefix for each row's item id (e.g. "general").
        current_lang: "zh" shows the translation where one exists and is
            non-empty; any other value shows the original tag.

    Returns:
        An HTML string, or "<p>暂无标签</p>" when tags_dict is empty.

    Tags and translations are HTML-escaped so tags containing quotes or
    markup (e.g. "jack-o'-lantern") render and copy correctly.
    """
    from html import escape

    if not tags_dict:
        return "<p>暂无标签</p>"

    rows = []
    for i, (tag, score) in enumerate(tags_dict.items()):
        label = tag
        if current_lang == "zh" and i < len(translations) and translations[i]:
            label = translations[i]
        item_id = f"{category_key}_{i}"
        rows.append(
            f'<div class="tag-item" data-tag="{escape(str(tag))}" '
            f'data-item-id="{escape(item_id)}" onclick="{_TAG_ONCLICK}">'
            '<div class="tag-content">'
            f'<span class="tag-text">{escape(str(label))}</span>'
            '</div>'
            f'<span class="tag-score">{score:.3f}</span>'
            '</div>'
        )
    return '<div class="label-container">' + "".join(rows) + "</div>" + _COPY_SCRIPT
# Separators offered by the sum_sep dropdown.
_SEPARATORS = {"逗号": ", ", "换行": "\n", "空格": " "}
# Summary text shown when no selected category has tags.
_EMPTY_SUMMARY = "请选择要汇总的标签类别"
# Panel order: general, characters, ratings (keys of the predict() result).
_PANEL_KEYS = ("general", "characters", "ratings")

# Browser-side handler for the copy button; Gradio calls it with the current
# summary text. Clipboard access has to happen in the browser.
_COPY_SUMMARY_JS = """
(text) => {
    if (!text || !navigator.clipboard) {
        return;
    }
    navigator.clipboard.writeText(text).then(() => {
        const toast = document.createElement('div');
        toast.className = 'toast show';
        toast.textContent = '标签已复制到剪贴板';
        document.body.appendChild(toast);
        setTimeout(() => toast.remove(), 1500);
    });
}
"""

# Lazily created shared Tagger (see _get_tagger).
_TAGGER = None


def _get_tagger():
    """Return the shared Tagger, creating it (download + ONNX session) once.

    If construction raises, the cache stays empty and the next call retries.
    """
    global _TAGGER
    if _TAGGER is None:
        _TAGGER = Tagger()
    return _TAGGER


def _localized(tags, translated, lang):
    """Return display labels for *tags* (an iterable of tag strings).

    For lang == "zh", each tag is replaced by the translation at the same
    position when it exists and is non-empty; otherwise the tag is kept.
    """
    labels = list(tags)
    if lang != "zh":
        return labels
    return [
        str(translated[i]) if i < len(translated) and translated[i] else tag
        for i, tag in enumerate(labels)
    ]


def _build_summary(tags, translations, lang, use_gen, use_char, use_rat, sep_type):
    """Join the selected tag categories into the summary text.

    Order is characters, general, ratings. Returns _EMPTY_SUMMARY when no
    selected category contains tags.
    """
    wanted = (("characters", use_char), ("general", use_gen), ("ratings", use_rat))
    labels = []
    for category, enabled in wanted:
        category_tags = tags.get(category) or {}
        if enabled and category_tags:
            labels.extend(
                _localized(category_tags, translations.get(category) or [], lang)
            )
    if not labels:
        return _EMPTY_SUMMARY
    return _SEPARATORS.get(sep_type, ", ").join(labels)


def _translate_categories(tag_categories):
    """Translate all tags in one batch and split the result back per category.

    Translation is best-effort: if translate_texts raises, every category
    gets an empty list and the UI falls back to the original tags.
    """
    all_tags = [tag for tags in tag_categories.values() for tag in tags]
    translated = []
    if all_tags:
        try:
            translated = list(translate_texts(all_tags, src_lang="auto", tgt_lang="zh"))
        except Exception as exc:  # external translator; keep the tagging result
            print(f"⚠️ 翻译失败,将显示原始标签: {exc}")
            translated = []
    result, offset = {}, 0
    for category, tags in tag_categories.items():
        result[category] = translated[offset:offset + len(tags)]
        offset += len(tags)
    return result


def process(img, g_th, c_th, sum_gen, sum_char, sum_rat, sep_type,
            current_lang, prev_tags, prev_translations):
    """Tag *img*, translate the tags and fill every output panel.

    Generator event handler. Each yielded tuple matches the outputs bound
    to ``btn``: (button, status, general HTML, character HTML, rating HTML,
    summary, lang_state, tags_data, translations_data). The first yield
    disables the button and shows progress; the last one shows results or
    an error. prev_tags / prev_translations are accepted to match the
    bound inputs but are not used.
    """
    idle_button = gr.update(interactive=True, value="开始分析")

    if img is None:
        yield (idle_button, gr.update(visible=True, value="⚠️ 请先上传图片"),
               "", "", "", "", current_lang, {}, {})
        return

    yield (
        gr.update(interactive=False, value="处理中..."),
        gr.update(visible=True, value="🔄 正在分析图像..."),
        "", "", "", "", current_lang, {}, {},
    )
    try:
        res = _get_tagger().predict(img, g_th, c_th)
        translations_dict = _translate_categories({
            "general": list(res["general"]),
            "characters": list(res["characters"]),
            "ratings": list(res["ratings"]),
        })
        panels = [
            format_tags_html(res[key], translations_dict[key], key, current_lang)
            for key in _PANEL_KEYS
        ]
        summary_text = _build_summary(res, translations_dict, current_lang,
                                      sum_gen, sum_char, sum_rat, sep_type)
    except Exception as e:  # UI boundary: report the failure instead of crashing
        import traceback
        traceback.print_exc()
        yield (idle_button, gr.update(visible=True, value=f"❌ 处理失败: {str(e)}"),
               "", "", "", "", current_lang, {}, {})
        return

    yield (idle_button, gr.update(visible=False), *panels, summary_text,
           current_lang, res, translations_dict)


def toggle_language(current_lang, tags, translations,
                    sum_gen=True, sum_char=True, sum_rat=False, sep_type="逗号"):
    """Switch the display language between "en" and "zh" and re-render.

    Returns (new_lang, general HTML, character HTML, rating HTML, summary).
    The summary is rebuilt from the stored tags using the given summary
    settings; they default to the UI's initial values so a three-argument
    call still works. With no stored tags the summary is returned empty.
    """
    new_lang = "zh" if current_lang == "en" else "en"
    tags = tags or {}
    translations = translations or {}
    panels = [
        format_tags_html(tags.get(key, {}), translations.get(key, []), key, new_lang)
        for key in _PANEL_KEYS
    ]
    if tags:
        summary_text = _build_summary(tags, translations, new_lang,
                                      sum_gen, sum_char, sum_rat, sep_type)
    else:
        summary_text = ""
    return (new_lang, *panels, summary_text)


def copy_summary(text):
    """Return a gr.update whose value is a <script> that copies *text*.

    Kept for compatibility. The copy button is wired to the browser-side
    _COPY_SUMMARY_JS handler instead, because <script> tags placed in an
    HTML component's value are generally not executed by browsers.
    """
    # json.dumps yields a valid JS string literal (quotes, backslashes,
    # newlines, backticks and ${...} are all inert); escaping "</" keeps the
    # text from terminating the <script> element early.
    literal = json.dumps(text or "").replace("</", "<\\/")
    copy_js = f'''
<script>
navigator.clipboard.writeText({literal}).then(function() {{
    showCopyToast('标签已复制到剪贴板');
}});
function showCopyToast(message) {{
    var toast = document.createElement('div');
    toast.className = 'toast show';
    toast.textContent = message;
    toast.style.position = 'fixed';
    toast.style.top = '20px';
    toast.style.right = '20px';
    toast.style.padding = '10px 20px';
    toast.style.backgroundColor = '#4CAF50';
    toast.style.color = 'white';
    toast.style.borderRadius = '4px';
    toast.style.zIndex = '1000';
    document.body.appendChild(toast);
    setTimeout(function() {{
        toast.remove();
    }}, 1500);
}}
</script>
'''
    return gr.update(value=copy_js)


# Event wiring: listeners must be registered inside a Blocks context, so
# (re-)enter the app's context here.
with demo:
    btn.click(
        process,
        inputs=[img_in, gen_slider, char_slider, sum_general, sum_char, sum_rating,
                sum_sep, lang_state, tags_data, translations_data],
        outputs=[btn, processing_info, out_general, out_char, out_rating,
                 out_summary, lang_state, tags_data, translations_data],
        show_progress=True,
    )
    # Pass the summary settings so the rebuilt summary honours them.
    lang_btn.click(
        toggle_language,
        inputs=[lang_state, tags_data, translations_data,
                sum_general, sum_char, sum_rating, sum_sep],
        outputs=[lang_state, out_general, out_char, out_rating, out_summary],
    )
    # Pure client-side copy; no server round-trip.
    copy_btn.click(None, inputs=[out_summary], outputs=None, js=_COPY_SUMMARY_JS)
# ------------------------------------------------------------------
# Entry point
# ------------------------------------------------------------------
if __name__ == "__main__":
    # Bind to all network interfaces on port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)