IdlecloudX committed on
Commit
d5894b1
·
verified ·
1 Parent(s): 59168b4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +186 -0
app.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import huggingface_hub
4
+ import numpy as np
5
+ import onnxruntime as rt
6
+ import pandas as pd
7
+ from PIL import Image
8
+
9
# Model configuration
MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3"  # default model
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
# Optional Hub access token. An unset or blank variable becomes None so that
# callers never forward an empty string as if it were a real token.
HF_TOKEN = os.environ.get("HF_TOKEN") or None

# Tag-processing configuration
# Emoticon-style tag names. Their underscores are part of the emoticon, so
# they are presumably meant to be exempt from the underscore-to-space
# replacement applied to ordinary tag names.
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]
37
+
38
class Tagger:
    """Image tagger that runs an ONNX model fetched from the Hugging Face Hub.

    Instantiating the class fetches ``LABEL_FILENAME`` and ``MODEL_FILENAME``
    from ``MODEL_REPO`` and opens an ONNX Runtime inference session, so
    construction is expensive and an instance should be reused across calls.

    Attributes:
        model: The ONNX Runtime inference session (``None`` until loaded).
        tag_names: Tag names read from the ``name`` column of the label CSV.
        model_size: Dimension 1 of the model's first input; used as the side
            length of the square image fed to the model.
        categories: Maps ``"rating"``, ``"general"`` and ``"character"`` to
            arrays of row indices of the label CSV with category 9, 0 and 4.
    """

    def __init__(self):
        self.model = None
        self.tag_names = []
        self.model_size = None
        self.categories = {}
        self._init_model()

    def _init_model(self):
        """Download the label and model files, then load both."""
        # A blank token is treated as "no token" rather than being sent to the
        # Hub as an explicit (empty) credential.
        token = HF_TOKEN or None
        label_path = huggingface_hub.hf_hub_download(
            MODEL_REPO,
            LABEL_FILENAME,
            token=token,
        )
        model_path = huggingface_hub.hf_hub_download(
            MODEL_REPO,
            MODEL_FILENAME,
            token=token,
        )

        # Load the labels and group their row indices by category id.
        tags_df = pd.read_csv(label_path)
        self.tag_names = tags_df["name"].tolist()
        self.categories = {
            "rating": np.where(tags_df["category"] == 9)[0],
            "general": np.where(tags_df["category"] == 0)[0],
            "character": np.where(tags_df["category"] == 4)[0],
        }

        # Load the ONNX model and read the expected input side length.
        self.model = rt.InferenceSession(model_path)
        self.model_size = self.model.get_inputs()[0].shape[1]

    def _preprocess(self, img):
        """Turn a PIL image into the float32 array fed to the model.

        The image is flattened onto white, padded to a centred square on a
        white canvas, resized to ``model_size`` and returned with its channel
        axis reversed (RGB -> BGR). Pixel values are not rescaled.

        Args:
            img: A ``PIL.Image.Image`` in any mode.

        Returns:
            A ``numpy.ndarray`` of dtype ``float32`` with the colour channels
            in BGR order.
        """
        if img.mode != "RGB":
            # Composite onto white instead of a plain convert("RGB"), so that
            # transparent regions match the white padding added below.
            rgba = img.convert("RGBA")
            white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            img = Image.alpha_composite(white, rgba).convert("RGB")

        # Pad to a square, keeping the picture centred.
        size = max(img.size)
        padded = Image.new("RGB", (size, size), (255, 255, 255))
        padded.paste(img, ((size - img.width) // 2, (size - img.height) // 2))

        # Resize only when the padded square is not already the model size.
        if size != self.model_size:
            padded = padded.resize(
                (self.model_size, self.model_size), Image.BICUBIC
            )

        # Reverse the channel axis: RGB -> BGR.
        return np.array(padded)[:, :, ::-1].astype(np.float32)

    @staticmethod
    def _display_name(name):
        """Return the human-readable form of a raw tag name.

        Underscores become spaces, except for names listed in ``kaomojis``,
        whose underscores are part of the emoticon itself.
        """
        if name in kaomojis:
            return name
        return name.replace("_", " ")

    def _collect_results(self, scores, general_thresh, character_thresh):
        """Group per-tag scores into rating / general / character results.

        Args:
            scores: Indexable sequence of scores, one per entry of
                ``tag_names``.
            general_thresh: A general tag is kept only if its score is
                strictly greater than this value.
            character_thresh: Same, for character tags.

        Returns:
            A dict with keys ``"ratings"``, ``"general"`` and ``"characters"``,
            each mapping display tag names to ``float`` scores. All rating
            tags are included; ``"general"`` is ordered by descending score.
        """
        results = {"ratings": {}, "general": {}, "characters": {}}

        # Rating tags are always reported, regardless of score.
        for idx in self.categories["rating"]:
            tag = self._display_name(self.tag_names[idx])
            results["ratings"][tag] = float(scores[idx])

        for idx in self.categories["general"]:
            if scores[idx] > general_thresh:
                tag = self._display_name(self.tag_names[idx])
                results["general"][tag] = float(scores[idx])

        for idx in self.categories["character"]:
            if scores[idx] > character_thresh:
                tag = self._display_name(self.tag_names[idx])
                results["characters"][tag] = float(scores[idx])

        # Highest-scoring general tags first.
        results["general"] = dict(sorted(
            results["general"].items(),
            key=lambda item: item[1],
            reverse=True,
        ))
        return results

    def predict(self, img, general_thresh=0.35, character_thresh=0.85):
        """Tag an image.

        Args:
            img: A ``PIL.Image.Image`` to analyse.
            general_thresh: Minimum (exclusive) score for general tags.
            character_thresh: Minimum (exclusive) score for character tags.

        Returns:
            A dict with keys ``"ratings"``, ``"general"`` and ``"characters"``,
            each mapping display tag names to ``float`` scores.
        """
        # Add a leading batch axis of length 1.
        img_data = self._preprocess(img)[np.newaxis]

        input_name = self.model.get_inputs()[0].name
        # First output, first (only) batch element.
        scores = self.model.run(None, {input_name: img_data})[0][0]

        return self._collect_results(scores, general_thresh, character_thresh)
131
+
132
# Shared tagger, created on first use so the model files are resolved and the
# ONNX session is built once instead of on every button click.
_shared_tagger = None


def _get_tagger():
    """Return the shared ``Tagger``, constructing it on the first call."""
    global _shared_tagger
    if _shared_tagger is None:
        _shared_tagger = Tagger()
    return _shared_tagger


# Build the Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="AI图像标签分析器") as demo:
    gr.Markdown("# 🖼️ AI图像标签分析器")
    gr.Markdown("上传图片自动分析图像内容标签")

    with gr.Row():
        # Left column: image upload, thresholds and the run button.
        with gr.Column(scale=1):
            img_input = gr.Image(type="pil", label="上传图片")
            with gr.Accordion("高级设置", open=False):
                general_slider = gr.Slider(0, 1, 0.35,
                                           label="通用标签阈值",
                                           info="值越高标签越少但更准确")
                char_slider = gr.Slider(0, 1, 0.85,
                                        label="角色标签阈值",
                                        info="推荐保持较高阈值")
            analyze_btn = gr.Button("开始分析", variant="primary")

        # Right column: one tab per tag category plus the plain-text output.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签"):
                    general_tags = gr.Label(label="检测到的通用标签")
                with gr.TabItem("👤 角色标签"):
                    char_tags = gr.Label(label="检测到的角色标签")
                with gr.TabItem("⭐ 评分标签"):
                    rating_tags = gr.Label(label="图像评级标签")

            output_text = gr.Textbox(label="标签文本",
                                     placeholder="生成的标签文本将显示在这里...")

    # Event handling
    def process_image(img, gen_thresh, char_thresh):
        """Tag the uploaded image and fill the four output components.

        Args:
            img: The uploaded PIL image, or ``None`` when nothing was uploaded.
            gen_thresh: Threshold for general tags, from the slider.
            char_thresh: Threshold for character tags, from the slider.

        Returns:
            A dict mapping each output component to its new value.

        Raises:
            gr.Error: If no image has been uploaded.
        """
        if img is None:
            # Report a readable message instead of failing inside
            # preprocessing with an AttributeError.
            raise gr.Error("请先上传图片")

        results = _get_tagger().predict(img, gen_thresh, char_thresh)

        # Join general and character tags in one pass so that no stray
        # leading separator appears when only character tags were detected.
        tag_text = ", ".join([*results["general"], *results["characters"]])

        return {
            general_tags: results["general"],
            char_tags: results["characters"],
            rating_tags: results["ratings"],
            output_text: tag_text,
        }

    analyze_btn.click(
        process_image,
        inputs=[img_input, general_slider, char_slider],
        outputs=[general_tags, char_tags, rating_tags, output_text],
    )
183
+
184
# Launch the application when executed as a script
if __name__ == "__main__":
    # Listen on every network interface, on port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)