yqcyqc commited on
Commit
7cbe0e2
·
verified ·
1 Parent(s): 23776fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -41
app.py CHANGED
@@ -7,8 +7,13 @@ from resnest.torch import resnest50
7
  from rembg import remove
8
  from PIL import Image
9
  import io
 
 
 
 
10
 
11
  # 加载类别名称
 
12
  with open('class_names.pkl', 'rb') as f:
13
  class_names = pickle.load(f)
14
 
@@ -30,21 +35,30 @@ preprocess = transforms.Compose([
30
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
31
  ])
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  def remove_background(img):
35
  """使用rembg去除背景并添加白色背景"""
36
- # 转换图像为字节流
37
  img_byte_arr = io.BytesIO()
38
  img.save(img_byte_arr, format='PNG')
39
  img_bytes = img_byte_arr.getvalue()
40
 
41
- # 去除背景
42
  removed_bg_bytes = remove(img_bytes)
43
-
44
- # 转换为PIL图像并处理透明度
45
  removed_bg_img = Image.open(io.BytesIO(removed_bg_bytes)).convert('RGBA')
46
 
47
- # 创建白色背景
48
  white_bg = Image.new('RGBA', removed_bg_img.size, (255, 255, 255, 255))
49
  combined = Image.alpha_composite(white_bg, removed_bg_img)
50
  return combined.convert('RGB')
@@ -52,17 +66,14 @@ def remove_background(img):
52
 
53
  def predict_image(img, remove_bg=False):
54
  """分类预测主函数"""
55
- # 根据选择处理图像
56
  if remove_bg:
57
  processed_img = remove_background(img)
58
  else:
59
- processed_img = img.convert('RGB') # 确保为RGB格式
60
 
61
- # 预处理
62
  input_tensor = preprocess(processed_img)
63
  input_batch = input_tensor.unsqueeze(0).to(device)
64
 
65
- # 预测
66
  with torch.no_grad():
67
  output = model(input_batch)
68
 
@@ -70,15 +81,68 @@ def predict_image(img, remove_bg=False):
70
  top3_probs, top3_indices = torch.topk(probabilities, 3)
71
 
72
  results = {
73
- class_names[i]: p.item()
74
  for p, i in zip(top3_probs, top3_indices)
75
  }
76
 
77
- # 记录结果
78
  best_class = class_names[top3_indices[0]]
79
  best_conf = top3_probs[0].item() * 100
80
 
81
- return processed_img, best_class, f"{best_conf:.2f}%", results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
 
84
  def create_interface():
@@ -86,55 +150,58 @@ def create_interface():
86
  "r0_0_100.jpg",
87
  "r0_18_100.jpg",
88
  "9_100.jpg",
89
- "100_100.jpg",
90
- "1105.jpg",
91
- "5ecc819f1a579f513e0a1500fabb3f0.png"
92
  ]
93
 
94
  with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as demo:
95
- gr.Markdown("""
96
- # 🍎 智能水果识别系统
97
- """)
98
 
99
- # 新增:模式选择卡片(视觉强化)
100
  with gr.Row():
101
  with gr.Column(scale=3):
102
  with gr.Group():
103
- gr.Markdown("### ⚙️ 处理模式选择")
104
  with gr.Row():
105
- bg_removal = gr.Checkbox(
106
- label="背景去除",
107
- value=False,
108
- interactive=True
109
- )
110
 
111
- # 主操作区域
112
- with gr.Row():
113
- with gr.Column():
114
- original_image = gr.Image(label="📤 上传图片", type="pil")
115
- gr.Examples(examples=examples, inputs=original_image)
116
  submit_btn = gr.Button("🚀 开始识别", variant="primary")
117
 
118
- # 添加模式说明提示
119
- gr.Markdown("""
120
- <div style="background: #f3f4f6; padding: 15px; border-radius: 8px; margin-top: 10px">
121
- <b>💡 使用建议:</b><br>
122
- • 上传图片:选择一张图片,点击'开始识别'按钮<br>
123
- • 勾选背景去除:适合杂乱背景的图片(识别更准确)<br>
124
- • 不勾选:适合纯色背景的图片(速度更快)
125
- </div>
126
- """)
127
 
128
  with gr.Column():
 
129
  processed_image = gr.Image(label="🖼️ 处理后图片", interactive=False)
130
  best_pred = gr.Textbox(label="🔍 识别结果")
131
  confidence = gr.Textbox(label="📊 置信度")
132
  full_results = gr.Label(label="🏆 Top 3 可能结果", num_top_classes=3)
133
 
 
 
 
 
 
134
  submit_btn.click(
135
  fn=predict_image,
136
  inputs=[original_image, bg_removal],
137
- outputs=[processed_image, best_pred, confidence, full_results]
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  )
139
 
140
  return demo
@@ -142,4 +209,4 @@ def create_interface():
142
 
143
  if __name__ == "__main__":
144
  interface = create_interface()
145
- interface.launch(share=True)
 
7
  from rembg import remove
8
  from PIL import Image
9
  import io
10
+ import json
11
+ import time
12
+ import threading
13
+ import concurrent.futures
14
 
15
  # 加载类别名称
16
+
17
  with open('class_names.pkl', 'rb') as f:
18
  class_names = pickle.load(f)
19
 
 
35
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
36
  ])
37
 
38
+ # 创建线程池
39
+ executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
40
+
41
+
42
+ class RealtimeState:
43
+ def __init__(self):
44
+ self.last_result = None
45
+ self.last_update_time = 0
46
+ self.is_processing = False
47
+ self.lock = threading.Lock()
48
+
49
+
50
+ realtime_state = RealtimeState()
51
+
52
 
53
  def remove_background(img):
54
  """使用rembg去除背景并添加白色背景"""
 
55
  img_byte_arr = io.BytesIO()
56
  img.save(img_byte_arr, format='PNG')
57
  img_bytes = img_byte_arr.getvalue()
58
 
 
59
  removed_bg_bytes = remove(img_bytes)
 
 
60
  removed_bg_img = Image.open(io.BytesIO(removed_bg_bytes)).convert('RGBA')
61
 
 
62
  white_bg = Image.new('RGBA', removed_bg_img.size, (255, 255, 255, 255))
63
  combined = Image.alpha_composite(white_bg, removed_bg_img)
64
  return combined.convert('RGB')
 
66
 
67
  def predict_image(img, remove_bg=False):
68
  """分类预测主函数"""
 
69
  if remove_bg:
70
  processed_img = remove_background(img)
71
  else:
72
+ processed_img = img.convert('RGB')
73
 
 
74
  input_tensor = preprocess(processed_img)
75
  input_batch = input_tensor.unsqueeze(0).to(device)
76
 
 
77
  with torch.no_grad():
78
  output = model(input_batch)
79
 
 
81
  top3_probs, top3_indices = torch.topk(probabilities, 3)
82
 
83
  results = {
84
+ class_names[i]: round(p.item(), 4)
85
  for p, i in zip(top3_probs, top3_indices)
86
  }
87
 
 
88
  best_class = class_names[top3_indices[0]]
89
  best_conf = top3_probs[0].item() * 100
90
 
91
+ with open('output/prediction_results.txt', 'a') as f:
92
+ f.write(f"Remove BG: {remove_bg}\n")
93
+ f.write(f"Predicted: {best_class} ({best_conf:.2f}%)\n")
94
+ f.write(f"Top 3: {results}\n\n")
95
+
96
+ # 添加一个空字符串作为 prediction_id
97
+ prediction_id = ""
98
+
99
+ return prediction_id, processed_img, best_class, f"{best_conf:.2f}%", results
100
+
101
+
102
+ def predict_realtime(video_frame, remove_bg):
103
+ """实时预测主函数,结果保留2秒"""
104
+ global realtime_state
105
+
106
+ if video_frame is None:
107
+ return None, None, None, None, None
108
+
109
+ current_time = time.time()
110
+
111
+ # 检查是否有未过期的结果
112
+ with realtime_state.lock:
113
+ if realtime_state.last_result and current_time - realtime_state.last_update_time < 2:
114
+ return realtime_state.last_result
115
+
116
+ # 如果正在处理中,返回None
117
+ if realtime_state.is_processing:
118
+ return None, None, None, None, None
119
+
120
+ # 标记为正在处理
121
+ realtime_state.is_processing = True
122
+
123
+ # 异步处理帧
124
+ def process_frame():
125
+ try:
126
+ result = predict_image(video_frame, remove_bg)
127
+ with realtime_state.lock:
128
+ realtime_state.last_result = result
129
+ realtime_state.last_update_time = time.time()
130
+ realtime_state.is_processing = False
131
+ except Exception as e:
132
+ print(f"处理帧时出错: {e}")
133
+ with realtime_state.lock:
134
+ realtime_state.is_processing = False
135
+
136
+ # 提交到线程池处理
137
+ executor.submit(process_frame)
138
+
139
+ return None, None, None, None, None
140
+
141
+
142
+ def add_feedback(prediction_id, feedback):
143
+ """模拟将反馈信息保存,实际上不做任何操作"""
144
+ print(f"收到反馈: {feedback} 对预测ID: {prediction_id}")
145
+ return True
146
 
147
 
148
  def create_interface():
 
150
  "r0_0_100.jpg",
151
  "r0_18_100.jpg",
152
  "9_100.jpg",
153
+ "127_100.jpg",
154
+ "5ecc819f1a579f513e0a1500fabb3f0.png",
155
+ "1105.jpg"
156
  ]
157
 
158
  with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as demo:
159
+ gr.Markdown("""# 🍎 智能水果识别系统""")
 
 
160
 
 
161
  with gr.Row():
162
  with gr.Column(scale=3):
163
  with gr.Group():
164
+ gr.Markdown("## ⚙️ 处理模式选择")
165
  with gr.Row():
166
+ bg_removal = gr.Checkbox(label="背景去除", value=False, interactive=True)
167
+ with gr.Column():
168
+ original_image = gr.Image(label="📤 上传图片", type="pil")
169
+ gr.Examples(examples=examples, inputs=original_image)
 
170
 
 
 
 
 
 
171
  submit_btn = gr.Button("🚀 开始识别", variant="primary")
172
 
173
+ gr.Markdown("""## ⚡ 实时识别""")
174
+ camera = gr.Image(label="📷 摄像头捕获", type="pil", streaming=True)
 
 
 
 
 
 
 
175
 
176
  with gr.Column():
177
+ prediction_id_output = gr.Textbox(label="🔍 预测ID", interactive=False, visible=False)
178
  processed_image = gr.Image(label="🖼️ 处理后图片", interactive=False)
179
  best_pred = gr.Textbox(label="🔍 识别结果")
180
  confidence = gr.Textbox(label="📊 置信度")
181
  full_results = gr.Label(label="🏆 Top 3 可能结果", num_top_classes=3)
182
 
183
+ with gr.Row():
184
+ feedback_input = gr.Textbox(label="📝 输入反馈信息")
185
+ with gr.Row():
186
+ feedback_btn = gr.Button("📢 提交反馈", variant="secondary")
187
+
188
  submit_btn.click(
189
  fn=predict_image,
190
  inputs=[original_image, bg_removal],
191
+ outputs=[prediction_id_output, processed_image, best_pred, confidence, full_results]
192
+ )
193
+
194
+ camera.stream(
195
+ fn=predict_realtime,
196
+ inputs=[camera, bg_removal],
197
+ outputs=[prediction_id_output, processed_image, best_pred, confidence, full_results]
198
+ )
199
+
200
+ feedback_btn.click(
201
+ fn=lambda prediction_id, feedback: (
202
+ add_feedback(prediction_id, feedback), "反馈成功!", gr.update(value="")),
203
+ inputs=[prediction_id_output, feedback_input],
204
+ outputs=[prediction_id_output, feedback_input]
205
  )
206
 
207
  return demo
 
209
 
210
  if __name__ == "__main__":
211
  interface = create_interface()
212
+ interface.launch(share=True)