# Image_Inversion / app.py
import os
import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from PIL import Image
from huggingface_hub import login
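# Runtime dependencies (PyPI package names): gradio, huggingface_hub, numpy,
# onnxruntime, pandas, pillow. On a Hugging Face Space these would normally be
# pinned in requirements.txt.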
# Model configuration
MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3"  # default model
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

HF_TOKEN = os.environ.get("HF_TOKEN", "")
if not HF_TOKEN:
    print("⚠️ Warning: HF_TOKEN not found; some models may require authentication")
else:
    login(token=HF_TOKEN)
# Tag formatting configuration: kaomoji tags whose underscores are part of the
# tag itself and should not be replaced with spaces
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]
class Tagger:
    def __init__(self):
        self.model = None
        self.tag_names = []
        self.model_size = None
        self.hf_token = os.environ.get("HF_TOKEN", "")  # read from the environment
        self._init_model()

    def _init_model(self):
        """Download the model and label file, then load both."""
        try:
            label_path = huggingface_hub.hf_hub_download(
                MODEL_REPO,
                LABEL_FILENAME,
                token=self.hf_token or None  # fall back to anonymous access
            )
            model_path = huggingface_hub.hf_hub_download(
                MODEL_REPO,
                MODEL_FILENAME,
                token=self.hf_token or None
            )

            # Load the labels; selected_tags.csv marks each tag with a category id
            # (9 = rating, 0 = general, 4 = character)
            tags_df = pd.read_csv(label_path)
            self.tag_names = tags_df["name"].tolist()
            self.categories = {
                "rating": np.where(tags_df["category"] == 9)[0],
                "general": np.where(tags_df["category"] == 0)[0],
                "character": np.where(tags_df["category"] == 4)[0]
            }

            # Load the ONNX model; the input shape gives the expected image size
            self.model = rt.InferenceSession(model_path)
            self.model_size = self.model.get_inputs()[0].shape[1]
        except huggingface_hub.utils.HfHubHTTPError as e:
            if "401" in str(e):
                raise RuntimeError(
                    "Authentication failed while downloading the model. Please:\n"
                    "1. Visit https://huggingface.co/SmilingWolf/wd-swinv2-tagger-v3\n"
                    "2. Click 'Agree and continue'\n"
                    "3. Make sure HF_TOKEN is set correctly"
                )
            else:
                raise
    def _preprocess(self, img):
        """Image preprocessing: RGB conversion, square padding, resize, BGR float32."""
        # Convert to RGB
        if img.mode != "RGB":
            img = img.convert("RGB")

        # Pad to a square on a white background
        size = max(img.size)
        padded = Image.new("RGB", (size, size), (255, 255, 255))
        padded.paste(img, ((size - img.width) // 2, (size - img.height) // 2))

        # Resize to the model's input resolution
        if size != self.model_size:
            padded = padded.resize((self.model_size, self.model_size), Image.BICUBIC)

        # Convert RGB to BGR channel order, as the WD tagger models expect
        return np.array(padded)[:, :, ::-1].astype(np.float32)
    def predict(self, img, general_thresh=0.35, character_thresh=0.85):
        """Run inference and group the predicted tags by category."""
        # Preprocess and add a batch dimension
        img_data = self._preprocess(img)[np.newaxis]

        # Run the model
        input_name = self.model.get_inputs()[0].name
        outputs = self.model.run(None, {input_name: img_data})[0][0]

        def fmt(name):
            # Kaomoji tags keep their underscores; other tags read better with spaces
            return name if name in kaomojis else name.replace("_", " ")

        # Collect the results
        results = {
            "ratings": {},
            "general": {},
            "characters": {}
        }

        # Rating tags (always reported, no threshold)
        for idx in self.categories["rating"]:
            results["ratings"][fmt(self.tag_names[idx])] = float(outputs[idx])

        # General tags
        for idx in self.categories["general"]:
            if outputs[idx] > general_thresh:
                results["general"][fmt(self.tag_names[idx])] = float(outputs[idx])

        # Character tags
        for idx in self.categories["character"]:
            if outputs[idx] > character_thresh:
                results["characters"][fmt(self.tag_names[idx])] = float(outputs[idx])

        # Sort the general tags by confidence, highest first
        results["general"] = dict(sorted(
            results["general"].items(),
            key=lambda x: x[1],
            reverse=True
        ))
        return results
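# Optional sanity check, not called anywhere by the app: a minimal sketch of using
# the Tagger directly from a script. The file name "example.jpg" is a hypothetical
# placeholder, not something shipped with this Space.
def _example_run(path="example.jpg"):
    """Tag a single local image and print the top general tags."""
    t = Tagger()  # downloads model.onnx and selected_tags.csv on first use
    results = t.predict(Image.open(path), general_thresh=0.35, character_thresh=0.85)
    for tag, score in list(results["general"].items())[:10]:
        print(f"{tag}: {score:.3f}")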
# Load the model once at startup so it is not re-downloaded for every request
tagger = Tagger()

# Build the Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="AI Image Tag Analyzer") as demo:
    gr.Markdown("# 🖼️ AI Image Tag Analyzer")
    gr.Markdown("Upload an image to automatically analyze its content tags")

    with gr.Row():
        with gr.Column(scale=1):
            img_input = gr.Image(type="pil", label="Upload image")
            with gr.Accordion("Advanced settings", open=False):
                general_slider = gr.Slider(0, 1, 0.35,
                                           label="General tag threshold",
                                           info="Higher values give fewer but more accurate tags")
                char_slider = gr.Slider(0, 1, 0.85,
                                        label="Character tag threshold",
                                        info="Keeping this threshold high is recommended")
            analyze_btn = gr.Button("Analyze", variant="primary")

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ General tags"):
                    general_tags = gr.Label(label="Detected general tags")
                with gr.TabItem("👤 Character tags"):
                    char_tags = gr.Label(label="Detected character tags")
                with gr.TabItem("⭐ Rating tags"):
                    rating_tags = gr.Label(label="Image rating tags")
            output_text = gr.Textbox(label="Tag text",
                                     placeholder="The generated tag text will appear here...")
    # Click handler: run the tagger and format the outputs
    def process_image(img, gen_thresh, char_thresh):
        if img is None:
            raise gr.Error("Please upload an image first")
        results = tagger.predict(img, gen_thresh, char_thresh)

        # Build a comma-separated text version of the general and character tags
        tag_text = ", ".join(
            list(results["general"].keys()) + list(results["characters"].keys())
        )

        return {
            general_tags: results["general"],
            char_tags: results["characters"],
            rating_tags: results["ratings"],
            output_text: tag_text
        }
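    # Returning a dict keyed by the output components (rather than a tuple) lets
    # Gradio match each value to its component explicitly; the components still
    # need to be listed in `outputs` below.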
    analyze_btn.click(
        process_image,
        inputs=[img_input, general_slider, char_slider],
        outputs=[general_tags, char_tags, rating_tags, output_text]
    )
# Launch the app
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)