IdlecloudX commited on
Commit
fcde2f2
·
verified ·
1 Parent(s): 96c8569

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -176
app.py CHANGED
@@ -1,203 +1,145 @@
1
- import os
2
  import gradio as gr
3
- import huggingface_hub
4
- import numpy as np
5
- import onnxruntime as rt
6
- import pandas as pd
7
  from PIL import Image
8
  from huggingface_hub import login
9
 
 
 
 
10
  # 模型配置
11
- MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3" # 默认模型
12
- MODEL_FILENAME = "model.onnx"
13
- LABEL_FILENAME = "selected_tags.csv"
14
- HF_TOKEN = os.environ.get("HF_TOKEN", "")
15
 
16
- if not os.environ.get("HF_TOKEN"):
17
- print("⚠️ 警告:未检测到HF_TOKEN,部分模型可能需要认证")
 
18
  else:
19
- login(token=os.environ.get("HF_TOKEN"))
20
-
21
- # 标签处理配置
22
- kaomojis = [
23
- "0_0",
24
- "(o)_(o)",
25
- "+_+",
26
- "+_-",
27
- "._.",
28
- "<o>_<o>",
29
- "<|>_<|>",
30
- "=_=",
31
- ">_<",
32
- "3_3",
33
- "6_9",
34
- ">_o",
35
- "@_@",
36
- "^_^",
37
- "o_o",
38
- "u_u",
39
- "x_x",
40
- "|_|",
41
- "||_||",
42
- ]
43
 
 
 
 
44
  class Tagger:
45
  def __init__(self):
46
- self.model = None
47
- self.tag_names = []
48
- self.model_size = None
49
- self.hf_token = os.environ.get("HF_TOKEN", "") # 从环境变量获取
50
- self._init_model()
51
-
52
- def _init_model(self):
53
- """初始化模型和标签"""
54
- try:
55
- label_path = huggingface_hub.hf_hub_download(
56
- MODEL_REPO,
57
- LABEL_FILENAME,
58
- token=self.hf_token
59
- )
60
- model_path = huggingface_hub.hf_hub_download(
61
- MODEL_REPO,
62
- MODEL_FILENAME,
63
- token=self.hf_token
64
- )
65
-
66
- # 加载标签
67
- tags_df = pd.read_csv(label_path)
68
- self.tag_names = tags_df["name"].tolist()
69
- self.categories = {
70
- "rating": np.where(tags_df["category"] == 9)[0],
71
- "general": np.where(tags_df["category"] == 0)[0],
72
- "character": np.where(tags_df["category"] == 4)[0]
73
- }
74
-
75
- # 加载ONNX模型
76
- self.model = rt.InferenceSession(model_path)
77
- self.model_size = self.model.get_inputs()[0].shape[1]
78
- except huggingface_hub.utils.HfHubHTTPError as e:
79
- if "401" in str(e):
80
- raise RuntimeError(
81
- "模型下载认证失败,请:\n"
82
- "1. 访问https://huggingface.co/SmilingWolf/wd-swinv2-tagger-v3\n"
83
- "2. 点击Agree and continue\n"
84
- "3. 确保HF_TOKEN已正确设置"
85
- )
86
- else:
87
- raise
88
-
89
- def _preprocess(self, img):
90
- """图像预处理"""
91
- # 转换为RGB
92
  if img.mode != "RGB":
93
  img = img.convert("RGB")
94
-
95
- # 填充为正方形
96
- size = max(img.size)
97
- padded = Image.new("RGB", (size, size), (255, 255, 255))
98
- padded.paste(img, ((size - img.width)//2, (size - img.height)//2))
99
-
100
- # 调整尺寸
101
- if size != self.model_size:
102
- padded = padded.resize((self.model_size, self.model_size), Image.BICUBIC)
103
-
104
- # 转换为BGR格式
105
- return np.array(padded)[:, :, ::-1].astype(np.float32)
106
-
107
- def predict(self, img, general_thresh=0.35, character_thresh=0.85):
108
- """执行预测"""
109
- # 预处理
110
- img_data = self._preprocess(img)[np.newaxis]
111
-
112
- # 运行模型
113
- input_name = self.model.get_inputs()[0].name
114
- outputs = self.model.run(None, {input_name: img_data})[0][0]
115
-
116
- # 组织结果
117
- results = {
118
- "ratings": {},
119
- "general": {},
120
- "characters": {}
121
- }
122
-
123
- # 处理评分标签
124
  for idx in self.categories["rating"]:
125
- tag = self.tag_names[idx].replace("_", " ")
126
- results["ratings"][tag] = float(outputs[idx])
127
-
128
- # 处理通用标签
129
  for idx in self.categories["general"]:
130
- if outputs[idx] > general_thresh:
131
- tag = self.tag_names[idx].replace("_", " ")
132
- results["general"][tag] = float(outputs[idx])
133
-
134
- # 处理角色标签
135
  for idx in self.categories["character"]:
136
- if outputs[idx] > character_thresh:
137
- tag = self.tag_names[idx].replace("_", " ")
138
- results["characters"][tag] = float(outputs[idx])
139
-
140
- # 排序结果
141
- results["general"] = dict(sorted(
142
- results["general"].items(),
143
- key=lambda x: x[1],
144
- reverse=True
145
- ))
146
-
147
- return results
148
-
149
- # 创建Gradio界面
150
- with gr.Blocks(theme=gr.themes.Soft(), title="AI图像标签分析器") as demo:
151
- gr.Markdown("# 🖼️ AI图像标签分析器")
152
- gr.Markdown("上传图片自动分析图像内容标签")
153
-
154
  with gr.Row():
155
  with gr.Column(scale=1):
156
- img_input = gr.Image(type="pil", label="上传图片")
157
- with gr.Accordion("高级设置", open=False):
158
- general_slider = gr.Slider(0, 1, 0.35,
159
- label="通用标签阈值",
160
- info="值越高标签越少但更准确")
161
  char_slider = gr.Slider(0, 1, 0.85,
162
- label="角色标签阈值",
163
- info="推荐保持较高阈值")
164
- analyze_btn = gr.Button("开始分析", variant="primary")
 
 
 
165
 
166
  with gr.Column(scale=2):
167
  with gr.Tabs():
168
- with gr.TabItem("🏷️ 通用标签"):
169
- general_tags = gr.Label(label="检测到的通用标签")
170
- with gr.TabItem("👤 角色标签"):
171
- char_tags = gr.Label(label="检测到的角色标签")
172
- with gr.TabItem("⭐ 评分标签"):
173
- rating_tags = gr.Label(label="图像评级标签")
174
-
175
- output_text = gr.Textbox(label="标签文本",
176
- placeholder="生成的标签文本将显示在这里...")
177
-
178
- # 处理逻辑
179
- def process_image(img, gen_thresh, char_thresh):
180
- tagger = Tagger()
181
- results = tagger.predict(img, gen_thresh, char_thresh)
182
-
183
- # 格式化文本输出
184
- tag_text = ", ".join(results["general"].keys())
185
- if results["characters"]:
186
- tag_text += ", " + ", ".join(results["characters"].keys())
187
-
 
188
  return {
189
- general_tags: results["general"],
190
- char_tags: results["characters"],
191
- rating_tags: results["ratings"],
192
- output_text: tag_text
193
  }
194
 
195
- analyze_btn.click(
196
- process_image,
197
- inputs=[img_input, general_slider, char_slider],
198
- outputs=[general_tags, char_tags, rating_tags, output_text]
199
  )
200
 
201
- # 启动应用
 
 
202
  if __name__ == "__main__":
203
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ import os, json
2
  import gradio as gr
3
+ import huggingface_hub, numpy as np, onnxruntime as rt, pandas as pd
 
 
 
4
  from PIL import Image
5
  from huggingface_hub import login
6
 
7
+ from translator import translate_texts
8
+
9
+ # ------------------------------------------------------------------
10
  # 模型配置
11
+ # ------------------------------------------------------------------
12
+ MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
13
+ MODEL_FILENAME = "model.onnx"
14
+ LABEL_FILENAME = "selected_tags.csv"
15
 
16
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
17
+ if HF_TOKEN:
18
+ login(token=HF_TOKEN)
19
  else:
20
+ print("⚠️ 未检测到 HF_TOKEN,私有模型可能下载失败")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ # ------------------------------------------------------------------
23
+ # Tagger 类
24
+ # ------------------------------------------------------------------
25
  class Tagger:
26
  def __init__(self):
27
+ self.hf_token = HF_TOKEN
28
+ self._load_model_and_labels()
29
+
30
+ def _load_model_and_labels(self):
31
+ label_path = huggingface_hub.hf_hub_download(
32
+ MODEL_REPO, LABEL_FILENAME, token=self.hf_token
33
+ )
34
+ model_path = huggingface_hub.hf_hub_download(
35
+ MODEL_REPO, MODEL_FILENAME, token=self.hf_token
36
+ )
37
+
38
+ tags_df = pd.read_csv(label_path)
39
+ self.tag_names = tags_df["name"].tolist()
40
+ self.categories = {
41
+ "rating": np.where(tags_df["category"] == 9)[0],
42
+ "general": np.where(tags_df["category"] == 0)[0],
43
+ "character": np.where(tags_df["category"] == 4)[0],
44
+ }
45
+ self.model = rt.InferenceSession(model_path)
46
+ self.input_size = self.model.get_inputs()[0].shape[1]
47
+
48
+ # ------------------------- preprocess -------------------------
49
+ def _preprocess(self, img: Image.Image) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  if img.mode != "RGB":
51
  img = img.convert("RGB")
52
+ size = max(img.size)
53
+ canvas = Image.new("RGB", (size, size), (255, 255, 255))
54
+ canvas.paste(img, ((size - img.width)//2, (size - img.height)//2))
55
+ if size != self.input_size:
56
+ canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
57
+ return np.array(canvas)[:, :, ::-1].astype(np.float32) # to BGR
58
+
59
+ # --------------------------- predict --------------------------
60
+ def predict(self, img: Image.Image,
61
+ gen_th: float = 0.35,
62
+ char_th: float = 0.85):
63
+ inp_name = self.model.get_inputs()[0].name
64
+ outputs = self.model.run(None, {inp_name: self._preprocess(img)[None, ...]})[0][0]
65
+
66
+ res = {"ratings": {}, "general": {}, "characters": {}}
67
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  for idx in self.categories["rating"]:
69
+ res["ratings"][self.tag_names[idx].replace("_", " ")] = float(outputs[idx])
70
+
 
 
71
  for idx in self.categories["general"]:
72
+ if outputs[idx] > gen_th:
73
+ res["general"][self.tag_names[idx].replace("_", " ")] = float(outputs[idx])
74
+
 
 
75
  for idx in self.categories["character"]:
76
+ if outputs[idx] > char_th:
77
+ res["characters"][self.tag_names[idx].replace("_", " ")] = float(outputs[idx])
78
+
79
+ res["general"] = dict(sorted(res["general"].items(),
80
+ key=lambda kv: kv[1],
81
+ reverse=True))
82
+ return res
83
+
84
+ # ------------------------------------------------------------------
85
+ # Gradio UI
86
+ # ------------------------------------------------------------------
87
+ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器 + 翻译") as demo:
88
+ gr.Markdown("# 🖼️ AI 图像标签分析器")
89
+ gr.Markdown("上传图片自动识别标签,并可一键翻译成中文")
90
+
 
 
 
91
  with gr.Row():
92
  with gr.Column(scale=1):
93
+ img_in = gr.Image(type="pil", label="上传图片")
94
+ with gr.Accordion("⚙️ 高级设置", open=False):
95
+ gen_slider = gr.Slider(0, 1, 0.35,
96
+ label="通用标签阈值", info="越高→标签更少更准")
 
97
  char_slider = gr.Slider(0, 1, 0.85,
98
+ label="角色标签阈值", info="推荐保持较高阈值")
99
+ lang_drop = gr.Dropdown(["zh", "en"], value="zh",
100
+ label="翻译目标语言",
101
+ info="当前仅内置中 / 英")
102
+
103
+ btn = gr.Button("开始分析", variant="primary")
104
 
105
  with gr.Column(scale=2):
106
  with gr.Tabs():
107
+ with gr.TabItem("🏷️ 通用标签 (英文)"):
108
+ out_general = gr.Label(label="General Tags")
109
+ with gr.TabItem("👤 角色标签 (英文)"):
110
+ out_char = gr.Label(label="Character Tags")
111
+ with gr.TabItem("⭐ 评分标签 (英文)"):
112
+ out_rating = gr.Label(label="Rating Tags")
113
+ with gr.TabItem("🌐 翻译结果"):
114
+ out_trans = gr.Textbox(label="翻译后的标签",
115
+ placeholder="翻译结果显示在此处")
116
+
117
+ # ----------------- 处理回调 -----------------
118
+ def process(img, g_th, c_th, tgt_lang):
119
+ tagger = Tagger()
120
+ res = tagger.predict(img, g_th, c_th)
121
+
122
+ # =========== 组织翻译 ===========
123
+ tags_to_translate = list(res["general"].keys()) + list(res["characters"].keys())
124
+ translations = translate_texts(tags_to_translate, src_lang="auto", tgt_lang=tgt_lang)
125
+ # 拼接字符串
126
+ trans_str = ", ".join(translations)
127
+
128
  return {
129
+ out_general: res["general"],
130
+ out_char: res["characters"],
131
+ out_rating: res["ratings"],
132
+ out_trans: trans_str
133
  }
134
 
135
+ btn.click(
136
+ process,
137
+ inputs=[img_in, gen_slider, char_slider, lang_drop],
138
+ outputs=[out_general, out_char, out_rating, out_trans]
139
  )
140
 
141
+ # ------------------------------------------------------------------
142
+ # 启动
143
+ # ------------------------------------------------------------------
144
  if __name__ == "__main__":
145
+ demo.launch(server_name="0.0.0.0", server_port=7860)