IdlecloudX commited on
Commit
7c7be00
·
verified ·
1 Parent(s): fcde2f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +174 -35
app.py CHANGED
@@ -84,62 +84,201 @@ class Tagger:
84
  # ------------------------------------------------------------------
85
  # Gradio UI
86
  # ------------------------------------------------------------------
87
- with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器 + 翻译") as demo:
88
- gr.Markdown("# 🖼️ AI 图像标签分析器")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  gr.Markdown("上传图片自动识别标签,并可一键翻译成中文")
90
 
91
  with gr.Row():
92
  with gr.Column(scale=1):
93
  img_in = gr.Image(type="pil", label="上传图片")
94
- with gr.Accordion("⚙️ 高级设置", open=False):
95
  gen_slider = gr.Slider(0, 1, 0.35,
96
  label="通用标签阈值", info="越高→标签更少更准")
97
  char_slider = gr.Slider(0, 1, 0.85,
98
  label="角色标签阈值", info="推荐保持较高阈值")
99
- lang_drop = gr.Dropdown(["zh", "en"], value="zh",
100
- label="翻译目标语言",
101
- info="当前仅内置中 / 英")
 
 
 
 
 
102
 
103
- btn = gr.Button("开始分析", variant="primary")
 
104
 
105
  with gr.Column(scale=2):
106
  with gr.Tabs():
107
- with gr.TabItem("🏷️ 通用标签 (英文)"):
108
- out_general = gr.Label(label="General Tags")
109
- with gr.TabItem("👤 角色标签 (英文)"):
110
- out_char = gr.Label(label="Character Tags")
111
- with gr.TabItem("⭐ 评分标签 (英文)"):
112
- out_rating = gr.Label(label="Rating Tags")
113
- with gr.TabItem("🌐 翻译结果"):
114
- out_trans = gr.Textbox(label="翻译后的标签",
115
- placeholder="翻译结果显示在此处")
 
 
116
 
117
  # ----------------- 处理回调 -----------------
118
- def process(img, g_th, c_th, tgt_lang):
119
- tagger = Tagger()
120
- res = tagger.predict(img, g_th, c_th)
121
-
122
- # =========== 组织翻译 ===========
123
- tags_to_translate = list(res["general"].keys()) + list(res["characters"].keys())
124
- translations = translate_texts(tags_to_translate, src_lang="auto", tgt_lang=tgt_lang)
125
- # 拼接字符串
126
- trans_str = ", ".join(translations)
127
-
128
- return {
129
- out_general: res["general"],
130
- out_char: res["characters"],
131
- out_rating: res["ratings"],
132
- out_trans: trans_str
133
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
 
135
  btn.click(
136
  process,
137
- inputs=[img_in, gen_slider, char_slider, lang_drop],
138
- outputs=[out_general, out_char, out_rating, out_trans]
 
139
  )
140
 
141
  # ------------------------------------------------------------------
142
  # 启动
143
  # ------------------------------------------------------------------
144
  if __name__ == "__main__":
145
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
84
  # ------------------------------------------------------------------
85
  # Gradio UI
86
  # ------------------------------------------------------------------
87
+ custom_css = """
88
+ .label-container {
89
+ max-height: 300px;
90
+ overflow-y: auto;
91
+ border: 1px solid #ddd;
92
+ padding: 10px;
93
+ border-radius: 5px;
94
+ background-color: #f9f9f9;
95
+ }
96
+ .tag-item {
97
+ display: flex;
98
+ justify-content: space-between;
99
+ align-items: center;
100
+ margin: 2px 0;
101
+ padding: 2px 5px;
102
+ border-radius: 3px;
103
+ background-color: #fff;
104
+ }
105
+ .tag-en {
106
+ font-weight: bold;
107
+ color: #333;
108
+ }
109
+ .tag-zh {
110
+ color: #666;
111
+ margin-left: 10px;
112
+ }
113
+ .tag-score {
114
+ color: #999;
115
+ font-size: 0.9em;
116
+ }
117
+ .btn-container {
118
+ margin-top: 20px;
119
+ }
120
+ """
121
+
122
+ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css) as demo:
123
+ gr.Markdown("# 🖼️ AI 图像标签分析器")
124
  gr.Markdown("上传图片自动识别标签,并可一键翻译成中文")
125
 
126
  with gr.Row():
127
  with gr.Column(scale=1):
128
  img_in = gr.Image(type="pil", label="上传图片")
129
+ with gr.Accordion("⚙️ 高级设置", open=True):
130
  gen_slider = gr.Slider(0, 1, 0.35,
131
  label="通用标签阈值", info="越高→标签更少更准")
132
  char_slider = gr.Slider(0, 1, 0.85,
133
  label="角色标签阈值", info="推荐保持较高阈值")
134
+ show_zh = gr.Checkbox(True, label="显示中文翻译")
135
+
136
+ gr.Markdown("### 汇总设置")
137
+ with gr.Row():
138
+ sum_general = gr.Checkbox(True, label="通用标签")
139
+ sum_char = gr.Checkbox(True, label="角色标签")
140
+ sum_rating = gr.Checkbox(False, label="评分标签")
141
+ sum_sep = gr.Dropdown(["逗号", "换行", "空格"], value="逗号", label="分隔符")
142
 
143
+ btn = gr.Button("开始分析", variant="primary", elem_classes=["btn-container"])
144
+ processing_info = gr.Markdown("", visible=False)
145
 
146
  with gr.Column(scale=2):
147
  with gr.Tabs():
148
+ with gr.TabItem("🏷️ 通用标签"):
149
+ out_general = gr.HTML(label="General Tags")
150
+ with gr.TabItem("👤 角色标签"):
151
+ out_char = gr.HTML(label="Character Tags")
152
+ with gr.TabItem("⭐ 评分标签"):
153
+ out_rating = gr.HTML(label="Rating Tags")
154
+
155
+ gr.Markdown("### 标签汇总")
156
+ out_summary = gr.Textbox(label="标签汇总",
157
+ placeholder="选择需要汇总的标签类别...",
158
+ lines=3)
159
 
160
  # ----------------- 处理回调 -----------------
161
+ def format_tags_html(tags_dict, translations, show_translation=True):
162
+ """格式化标签为HTML格式"""
163
+ if not tags_dict:
164
+ return "<p>暂无标签</p>"
165
+
166
+ html = '<div class="label-container">'
167
+ for i, (tag, score) in enumerate(tags_dict.items()):
168
+ tag_html = f'<div class="tag-item">'
169
+ tag_html += f'<div><span class="tag-en">{tag}</span>'
170
+ if show_translation and i < len(translations):
171
+ tag_html += f'<span class="tag-zh">({translations[i]})</span>'
172
+ tag_html += '</div>'
173
+ tag_html += f'<span class="tag-score">{score:.3f}</span>'
174
+ tag_html += '</div>'
175
+ html += tag_html
176
+ html += '</div>'
177
+ return html
178
+
179
+ def process(img, g_th, c_th, show_zh, sum_gen, sum_char, sum_rat, sep_type):
180
+ # 开始处理,返回更新
181
+ yield (
182
+ gr.update(interactive=False, value="处理中..."),
183
+ gr.update(visible=True, value="🔄 正在分析图像..."),
184
+ "", "", "", ""
185
+ )
186
+
187
+ try:
188
+ tagger = Tagger()
189
+ res = tagger.predict(img, g_th, c_th)
190
+
191
+ # 收集所有需要翻译的标签
192
+ all_tags = []
193
+ tag_categories = {
194
+ "general": list(res["general"].keys()),
195
+ "characters": list(res["characters"].keys()),
196
+ "ratings": list(res["ratings"].keys())
197
+ }
198
+
199
+ if show_zh:
200
+ for tags in tag_categories.values():
201
+ all_tags.extend(tags)
202
+
203
+ # 批量翻译
204
+ if all_tags:
205
+ translations = translate_texts(all_tags, src_lang="auto", tgt_lang="zh")
206
+ else:
207
+ translations = []
208
+ else:
209
+ translations = []
210
+
211
+ # 分配翻译结果
212
+ translations_dict = {}
213
+ offset = 0
214
+ for category, tags in tag_categories.items():
215
+ if show_zh and tags:
216
+ translations_dict[category] = translations[offset:offset+len(tags)]
217
+ offset += len(tags)
218
+ else:
219
+ translations_dict[category] = []
220
+
221
+ # 生成HTML输出
222
+ general_html = format_tags_html(res["general"], translations_dict["general"], show_zh)
223
+ char_html = format_tags_html(res["characters"], translations_dict["characters"], show_zh)
224
+ rating_html = format_tags_html(res["ratings"], translations_dict["ratings"], show_zh)
225
+
226
+ # 生成汇总文本
227
+ summary_parts = []
228
+ separators = {"逗号": ", ", "换行": "\n", "空格": " "}
229
+ separator = separators[sep_type]
230
+
231
+ if sum_gen and res["general"]:
232
+ if show_zh and translations_dict["general"]:
233
+ gen_tags = [f"{en}({zh})" for en, zh in zip(res["general"].keys(), translations_dict["general"])]
234
+ else:
235
+ gen_tags = list(res["general"].keys())
236
+ summary_parts.append("通用标签: " + separator.join(gen_tags))
237
+
238
+ if sum_char and res["characters"]:
239
+ if show_zh and translations_dict["characters"]:
240
+ char_tags = [f"{en}({zh})" for en, zh in zip(res["characters"].keys(), translations_dict["characters"])]
241
+ else:
242
+ char_tags = list(res["characters"].keys())
243
+ summary_parts.append("角色标签: " + separator.join(char_tags))
244
+
245
+ if sum_rat and res["ratings"]:
246
+ if show_zh and translations_dict["ratings"]:
247
+ rat_tags = [f"{en}({zh})" for en, zh in zip(res["ratings"].keys(), translations_dict["ratings"])]
248
+ else:
249
+ rat_tags = list(res["ratings"].keys())
250
+ summary_parts.append("评分标签: " + separator.join(rat_tags))
251
+
252
+ summary_text = "\n\n".join(summary_parts) if summary_parts else "请选择要汇总的标签类别"
253
+
254
+ # 完成处理,返回最终结果
255
+ yield (
256
+ gr.update(interactive=True, value="开始分析"),
257
+ gr.update(visible=False),
258
+ general_html,
259
+ char_html,
260
+ rating_html,
261
+ summary_text
262
+ )
263
+
264
+ except Exception as e:
265
+ # 出错时的处理
266
+ yield (
267
+ gr.update(interactive=True, value="开始分析"),
268
+ gr.update(visible=True, value=f"❌ 处理失败: {str(e)}"),
269
+ "", "", "", ""
270
+ )
271
 
272
+ # 绑定事件
273
  btn.click(
274
  process,
275
+ inputs=[img_in, gen_slider, char_slider, show_zh, sum_general, sum_char, sum_rating, sum_sep],
276
+ outputs=[btn, processing_info, out_general, out_char, out_rating, out_summary],
277
+ show_progress=True
278
  )
279
 
280
  # ------------------------------------------------------------------
281
  # 启动
282
  # ------------------------------------------------------------------
283
  if __name__ == "__main__":
284
+ demo.launch(server_name="0.0.0.0", server_port=7860)