Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -7,8 +7,13 @@ from resnest.torch import resnest50
|
|
7 |
from rembg import remove
|
8 |
from PIL import Image
|
9 |
import io
|
|
|
|
|
|
|
|
|
10 |
|
11 |
# 加载类别名称
|
|
|
12 |
with open('class_names.pkl', 'rb') as f:
|
13 |
class_names = pickle.load(f)
|
14 |
|
@@ -30,21 +35,30 @@ preprocess = transforms.Compose([
|
|
30 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
31 |
])
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
def remove_background(img):
|
35 |
"""使用rembg去除背景并添加白色背景"""
|
36 |
-
# 转换图像为字节流
|
37 |
img_byte_arr = io.BytesIO()
|
38 |
img.save(img_byte_arr, format='PNG')
|
39 |
img_bytes = img_byte_arr.getvalue()
|
40 |
|
41 |
-
# 去除背景
|
42 |
removed_bg_bytes = remove(img_bytes)
|
43 |
-
|
44 |
-
# 转换为PIL图像并处理透明度
|
45 |
removed_bg_img = Image.open(io.BytesIO(removed_bg_bytes)).convert('RGBA')
|
46 |
|
47 |
-
# 创建白色背景
|
48 |
white_bg = Image.new('RGBA', removed_bg_img.size, (255, 255, 255, 255))
|
49 |
combined = Image.alpha_composite(white_bg, removed_bg_img)
|
50 |
return combined.convert('RGB')
|
@@ -52,17 +66,14 @@ def remove_background(img):
|
|
52 |
|
53 |
def predict_image(img, remove_bg=False):
|
54 |
"""分类预测主函数"""
|
55 |
-
# 根据选择处理图像
|
56 |
if remove_bg:
|
57 |
processed_img = remove_background(img)
|
58 |
else:
|
59 |
-
processed_img = img.convert('RGB')
|
60 |
|
61 |
-
# 预处理
|
62 |
input_tensor = preprocess(processed_img)
|
63 |
input_batch = input_tensor.unsqueeze(0).to(device)
|
64 |
|
65 |
-
# 预测
|
66 |
with torch.no_grad():
|
67 |
output = model(input_batch)
|
68 |
|
@@ -70,15 +81,68 @@ def predict_image(img, remove_bg=False):
|
|
70 |
top3_probs, top3_indices = torch.topk(probabilities, 3)
|
71 |
|
72 |
results = {
|
73 |
-
class_names[i]: p.item()
|
74 |
for p, i in zip(top3_probs, top3_indices)
|
75 |
}
|
76 |
|
77 |
-
# 记录结果
|
78 |
best_class = class_names[top3_indices[0]]
|
79 |
best_conf = top3_probs[0].item() * 100
|
80 |
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
|
84 |
def create_interface():
|
@@ -86,55 +150,58 @@ def create_interface():
|
|
86 |
"r0_0_100.jpg",
|
87 |
"r0_18_100.jpg",
|
88 |
"9_100.jpg",
|
89 |
-
"
|
90 |
-
"
|
91 |
-
"
|
92 |
]
|
93 |
|
94 |
with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as demo:
|
95 |
-
gr.Markdown("""
|
96 |
-
# 🍎 智能水果识别系统
|
97 |
-
""")
|
98 |
|
99 |
-
# 新增:模式选择卡片(视觉强化)
|
100 |
with gr.Row():
|
101 |
with gr.Column(scale=3):
|
102 |
with gr.Group():
|
103 |
-
gr.Markdown("
|
104 |
with gr.Row():
|
105 |
-
bg_removal = gr.Checkbox(
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
)
|
110 |
|
111 |
-
# 主操作区域
|
112 |
-
with gr.Row():
|
113 |
-
with gr.Column():
|
114 |
-
original_image = gr.Image(label="📤 上传图片", type="pil")
|
115 |
-
gr.Examples(examples=examples, inputs=original_image)
|
116 |
submit_btn = gr.Button("🚀 开始识别", variant="primary")
|
117 |
|
118 |
-
|
119 |
-
gr.
|
120 |
-
<div style="background: #f3f4f6; padding: 15px; border-radius: 8px; margin-top: 10px">
|
121 |
-
<b>💡 使用建议:</b><br>
|
122 |
-
• 上传图片:选择一张图片,点击'开始识别'按钮<br>
|
123 |
-
• 勾选背景去除:适合杂乱背景的图片(识别更准确)<br>
|
124 |
-
• 不勾选:适合纯色背景的图片(速度更快)
|
125 |
-
</div>
|
126 |
-
""")
|
127 |
|
128 |
with gr.Column():
|
|
|
129 |
processed_image = gr.Image(label="🖼️ 处理后图片", interactive=False)
|
130 |
best_pred = gr.Textbox(label="🔍 识别结果")
|
131 |
confidence = gr.Textbox(label="📊 置信度")
|
132 |
full_results = gr.Label(label="🏆 Top 3 可能结果", num_top_classes=3)
|
133 |
|
|
|
|
|
|
|
|
|
|
|
134 |
submit_btn.click(
|
135 |
fn=predict_image,
|
136 |
inputs=[original_image, bg_removal],
|
137 |
-
outputs=[processed_image, best_pred, confidence, full_results]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
)
|
139 |
|
140 |
return demo
|
@@ -142,4 +209,4 @@ def create_interface():
|
|
142 |
|
143 |
if __name__ == "__main__":
|
144 |
interface = create_interface()
|
145 |
-
interface.launch(share=True)
|
|
|
7 |
from rembg import remove
|
8 |
from PIL import Image
|
9 |
import io
|
10 |
+
import json
|
11 |
+
import time
|
12 |
+
import threading
|
13 |
+
import concurrent.futures
|
14 |
|
15 |
# 加载类别名称
|
16 |
+
|
17 |
with open('class_names.pkl', 'rb') as f:
|
18 |
class_names = pickle.load(f)
|
19 |
|
|
|
35 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
36 |
])
|
37 |
|
38 |
+
# 创建线程池
|
39 |
+
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
40 |
+
|
41 |
+
|
42 |
+
class RealtimeState:
    """Mutable state shared by the streaming (webcam) prediction path.

    Stores the most recent prediction, the time it was stored, and a flag
    telling whether a frame is currently being processed. ``lock`` is the
    ``threading.Lock`` callers are expected to hold while reading or writing
    the other attributes.
    """

    def __init__(self) -> None:
        # Most recent prediction result; None until one has been stored.
        self.last_result = None
        # time.time() value recorded when last_result was stored (0 = never).
        self.last_update_time: float = 0
        # True while a frame is being processed in the background.
        self.is_processing: bool = False
        # Guards the three attributes above.
        self.lock = threading.Lock()
|
48 |
+
|
49 |
+
|
50 |
+
realtime_state = RealtimeState()
|
51 |
+
|
52 |
|
53 |
def remove_background(img):
    """Strip the background from *img* with rembg and place it on white.

    Args:
        img: PIL image to process.

    Returns:
        A new RGB PIL image of the same size in which the removed
        (transparent) background has been replaced by solid white.
    """
    # rembg is given PNG-encoded bytes here, so serialise the image first.
    png_buffer = io.BytesIO()
    img.save(png_buffer, format='PNG')
    cutout_bytes = remove(png_buffer.getvalue())

    # Force RGBA so that both images handed to alpha_composite share a mode.
    cutout = Image.open(io.BytesIO(cutout_bytes)).convert('RGBA')

    # Opaque white canvas; the cutout's alpha decides where white shows.
    white_canvas = Image.new('RGBA', cutout.size, (255, 255, 255, 255))
    return Image.alpha_composite(white_canvas, cutout).convert('RGB')
|
|
|
66 |
|
67 |
def predict_image(img, remove_bg=False):
    """Classify *img* and return the values for the five UI outputs.

    Args:
        img: PIL image to classify, or None when no image was supplied.
        remove_bg: When true, the background is removed (and replaced by
            white) before classification; otherwise the image is only
            converted to RGB.

    Returns:
        A 5-tuple ``(prediction_id, processed_img, best_class,
        confidence_text, results)`` where ``prediction_id`` is currently
        always an empty string, ``confidence_text`` is the top probability
        formatted as a percentage, and ``results`` maps the top-3 class
        names to probabilities rounded to 4 decimals. Five ``None`` values
        are returned when *img* is None.
    """
    import os  # local import: only needed for the log directory below

    # Without an image there is nothing to classify; return empty outputs
    # (same convention as predict_realtime) instead of failing on None.
    if img is None:
        return None, None, None, None, None

    if remove_bg:
        processed_img = remove_background(img)
    else:
        processed_img = img.convert('RGB')

    input_tensor = preprocess(processed_img)
    input_batch = input_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_batch)

    # NOTE(review): this statement falls in a diff-hunk gap and is not visible
    # in SOURCE; a softmax over the single sample's logits is assumed from how
    # `probabilities` is used below -- confirm against the real app.py.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    top3_probs, top3_indices = torch.topk(probabilities, 3)

    # .item() turns the 0-d tensors into plain Python numbers so the class
    # lookup works with an int index/key rather than a tensor.
    results = {
        class_names[i.item()]: round(p.item(), 4)
        for p, i in zip(top3_probs, top3_indices)
    }

    best_class = class_names[top3_indices[0].item()]
    best_conf = top3_probs[0].item() * 100

    # Appending fails if the directory is missing, so create it on demand.
    os.makedirs('output', exist_ok=True)
    # Explicit UTF-8 so non-ASCII class names are written the same way on
    # every platform.
    with open('output/prediction_results.txt', 'a', encoding='utf-8') as f:
        f.write(f"Remove BG: {remove_bg}\n")
        f.write(f"Predicted: {best_class} ({best_conf:.2f}%)\n")
        f.write(f"Top 3: {results}\n\n")

    # Placeholder: an empty string is used as the prediction_id for now.
    prediction_id = ""

    return prediction_id, processed_img, best_class, f"{best_conf:.2f}%", results
|
100 |
+
|
101 |
+
|
102 |
+
def predict_realtime(video_frame, remove_bg):
    """Classify a streamed frame without blocking the caller.

    A finished prediction is served from the shared ``realtime_state`` for
    2 seconds. When no fresh result exists, the frame is submitted to the
    module-level ``executor`` and five ``None`` values are returned in the
    meantime. Frames that arrive while another frame is still being
    processed are dropped.

    Args:
        video_frame: The current frame, or None when no frame is available.
        remove_bg: Forwarded to ``predict_image``.

    Returns:
        The cached 5-tuple produced by ``predict_image`` while it is younger
        than 2 seconds, otherwise ``(None, None, None, None, None)``.

    Raises:
        RuntimeError: If the executor refuses the job (for example after it
            has been shut down).
    """
    no_result = (None, None, None, None, None)
    result_ttl_seconds = 2

    if video_frame is None:
        return no_result

    now = time.time()

    # Check the cache and claim the "processing" slot in one critical section
    # so that two concurrent callers cannot both submit a frame.
    with realtime_state.lock:
        cached = realtime_state.last_result
        if cached is not None and now - realtime_state.last_update_time < result_ttl_seconds:
            return cached

        if realtime_state.is_processing:
            return no_result

        realtime_state.is_processing = True

    def process_frame():
        """Run the prediction in the worker thread and publish the result."""
        try:
            result = predict_image(video_frame, remove_bg)
        except Exception as exc:
            # Worker-thread boundary: report the failure and keep streaming.
            print(f"处理帧时出错: {exc}")
        else:
            with realtime_state.lock:
                realtime_state.last_result = result
                realtime_state.last_update_time = time.time()
        finally:
            # Always release the slot, whatever happened above.
            with realtime_state.lock:
                realtime_state.is_processing = False

    try:
        executor.submit(process_frame)
    except RuntimeError:
        # The job never started, so nobody else would clear the flag and the
        # stream would return empty results forever.
        with realtime_state.lock:
            realtime_state.is_processing = False
        raise

    return no_result
|
140 |
+
|
141 |
+
|
142 |
+
def add_feedback(prediction_id, feedback):
    """Placeholder feedback sink: print the feedback and report success.

    Nothing is persisted; the feedback text and the prediction id are only
    written to standard output.

    Args:
        prediction_id: Identifier of the prediction the feedback refers to.
        feedback: Feedback text entered by the user.

    Returns:
        Always ``True``.
    """
    print(f"收到反馈: {feedback} 对预测ID: {prediction_id}")
    return True
|
146 |
|
147 |
|
148 |
def create_interface():
    """Build the Gradio Blocks UI for the fruit classifier.

    The page offers a background-removal checkbox, an image upload with
    example files, a streaming camera input, the prediction outputs and a
    feedback form. The submit button runs ``predict_image``, the camera
    stream runs ``predict_realtime`` and the feedback button passes the
    feedback to ``add_feedback``.

    Returns:
        The assembled ``gr.Blocks`` app (not launched).
    """
    # NOTE(review): the statement opening this list sits in a diff-hunk gap
    # and is not visible in SOURCE; the name is taken from its use in
    # gr.Examples(examples=examples, ...) below.
    examples = [
        "r0_0_100.jpg",
        "r0_18_100.jpg",
        "9_100.jpg",
        "127_100.jpg",
        "5ecc819f1a579f513e0a1500fabb3f0.png",
        "1105.jpg",
    ]

    def submit_feedback(prediction_id, feedback):
        """Record the feedback, then return the status text and a cleared input.

        Returns exactly one value per output component wired to the feedback
        button, so nothing is written into the hidden prediction-id box.
        """
        add_feedback(prediction_id, feedback)
        return "反馈成功!", gr.update(value="")

    # NOTE(review): indentation is lost in SOURCE, so the nesting of the
    # layout containers below is reconstructed -- confirm against app.py.
    with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""# 🍎 智能水果识别系统""")

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Group():
                    gr.Markdown("## ⚙️ 处理模式选择")
                    with gr.Row():
                        bg_removal = gr.Checkbox(label="背景去除", value=False, interactive=True)
                with gr.Column():
                    original_image = gr.Image(label="📤 上传图片", type="pil")
                    gr.Examples(examples=examples, inputs=original_image)

                submit_btn = gr.Button("🚀 开始识别", variant="primary")

                gr.Markdown("""## ⚡ 实时识别""")
                # NOTE(review): Gradio normally requires a streaming Image to
                # be restricted to the webcam source (source="webcam" in 3.x,
                # sources=["webcam"] in 4.x). The keyword depends on the
                # installed version, so it is not added here -- confirm.
                camera = gr.Image(label="📷 摄像头捕获", type="pil", streaming=True)

            with gr.Column():
                prediction_id_output = gr.Textbox(label="🔍 预测ID", interactive=False, visible=False)
                processed_image = gr.Image(label="🖼️ 处理后图片", interactive=False)
                best_pred = gr.Textbox(label="🔍 识别结果")
                confidence = gr.Textbox(label="📊 置信度")
                full_results = gr.Label(label="🏆 Top 3 可能结果", num_top_classes=3)

                with gr.Row():
                    feedback_input = gr.Textbox(label="📝 输入反馈信息")
                with gr.Row():
                    feedback_btn = gr.Button("📢 提交反馈", variant="secondary")
                # Visible target for the feedback confirmation message.
                feedback_status = gr.Markdown()

        # Both prediction paths fill the same five output components.
        prediction_outputs = [prediction_id_output, processed_image, best_pred, confidence, full_results]

        submit_btn.click(
            fn=predict_image,
            inputs=[original_image, bg_removal],
            outputs=prediction_outputs,
        )

        camera.stream(
            fn=predict_realtime,
            inputs=[camera, bg_removal],
            outputs=prediction_outputs,
        )

        feedback_btn.click(
            fn=submit_feedback,
            inputs=[prediction_id_output, feedback_input],
            outputs=[feedback_status, feedback_input],
        )

    return demo
|
|
|
209 |
|
210 |
if __name__ == "__main__":
|
211 |
interface = create_interface()
|
212 |
+
interface.launch(share=True)
|