Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import torch.nn as nn | |
from torchvision import transforms | |
import pickle | |
from resnest.torch import resnest50 | |
from rembg import remove | |
from PIL import Image | |
import io | |
import json | |
import time | |
import threading | |
import concurrent.futures | |
# Load the label vocabulary. Entry i names the model's i-th output logit, and
# its length sizes the classifier head built below.
# NOTE(review): pickle.load can execute arbitrary code — only ship a trusted file.
with open('class_names.pkl', 'rb') as label_file:
    class_names = pickle.load(label_file)
# Initialise the classifier: run on the GPU when torch can see one, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ResNeSt-50 backbone built without pretrained weights; its parameters are
# filled from the local checkpoint loaded below.
model = resnest50(pretrained=False)
# Replace the default classification head with dropout + a linear layer that
# has one output per entry in class_names (must match the checkpoint layout).
model.fc = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(model.fc.in_features, len(class_names)),
)
# map_location places the stored tensors directly on `device`.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
# Move to the target device and switch to inference mode (disables dropout).
model = model.to(device).eval()
# Per-image preprocessing applied before inference: resize to a fixed
# 100x100, convert to a float tensor, then normalise each RGB channel.
# NOTE(review): these are the usual ImageNet mean/std values; presumably the
# same pipeline was used during training — confirm if accuracy looks off.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(100, 100)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Single-worker thread pool: webcam frames handed over by the streaming
# callback are classified one at a time, off the callback itself.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)


class RealtimeState:
    """Shared bookkeeping for webcam (streaming) recognition.

    One module-level instance is read by the streaming callback and written
    by the executor's worker thread; every field is accessed under ``lock``.
    """

    def __init__(self):
        # Serialises access to the fields below across threads.
        self.lock = threading.Lock()
        # True while a frame has been handed to the executor and not finished.
        self.is_processing = False
        # Latest 5-tuple produced by predict_image(), or None before the first.
        self.last_result = None
        # time.time() at which last_result was stored; 0 means "never".
        self.last_update_time = 0


realtime_state = RealtimeState()
def remove_background(img):
    """Cut out the foreground with rembg and flatten it onto a white canvas.

    Args:
        img: PIL image to process.

    Returns:
        An RGB PIL image in which regions rembg made transparent are white
        (partially transparent pixels are blended with white).
    """
    # rembg is given encoded PNG bytes and hands back encoded image bytes.
    png_buffer = io.BytesIO()
    img.save(png_buffer, format='PNG')
    cutout_bytes = remove(png_buffer.getvalue())

    cutout = Image.open(io.BytesIO(cutout_bytes)).convert('RGBA')
    # Composite the cut-out over an opaque white canvas, then drop alpha.
    canvas = Image.new('RGBA', cutout.size, (255, 255, 255, 255))
    return Image.alpha_composite(canvas, cutout).convert('RGB')
def predict_image(img, remove_bg=False):
    """Classify one image and append the prediction to prediction_results.txt.

    Args:
        img: PIL image to classify, or None (e.g. nothing uploaded yet).
        remove_bg: when True, the background is replaced with white via
            remove_background() before classification.

    Returns:
        A 5-tuple matching the UI outputs:
        (prediction_id, processed_img, best_class, confidence_text, results).
        prediction_id is currently always "". processed_img is the RGB image
        actually fed to the model. confidence_text is the top-1 probability as
        a percentage string such as "97.31%". results maps up to three class
        names to their probabilities rounded to 4 decimals. When img is None,
        nothing is classified and every element except prediction_id is None.
    """
    # Nothing to classify (button pressed with no upload): blank the outputs
    # instead of crashing on None.convert().
    if img is None:
        return "", None, None, None, None

    processed_img = remove_background(img) if remove_bg else img.convert('RGB')

    input_batch = preprocess(processed_img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_batch)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # torch.topk raises if k exceeds the number of classes.
    k = min(3, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)
    # Convert once to plain Python floats/ints for indexing and formatting.
    top_probs = top_probs.tolist()
    top_indices = top_indices.tolist()

    results = {
        class_names[i]: round(p, 4)
        for p, i in zip(top_probs, top_indices)
    }
    best_class = class_names[top_indices[0]]
    best_conf = top_probs[0] * 100

    # Explicit UTF-8: class names may be non-ASCII, and the platform default
    # encoding is not guaranteed to handle them.
    with open('prediction_results.txt', 'a', encoding='utf-8') as log_file:
        log_file.write(f"Remove BG: {remove_bg}\n")
        log_file.write(f"Predicted: {best_class} ({best_conf:.2f}%)\n")
        log_file.write(f"Top 3: {results}\n\n")

    # Placeholder: no prediction ID is generated yet.
    prediction_id = ""
    return prediction_id, processed_img, best_class, f"{best_conf:.2f}%", results
def predict_realtime(video_frame, remove_bg):
    """Classify webcam frames in the background, reusing a result for 2 seconds.

    Called by the Gradio stream for each frame. While a result younger than
    2 seconds exists it is returned unchanged. Otherwise, if no frame is being
    classified, this frame is handed to ``executor`` and the outputs are
    blanked until the worker stores a new result.

    Args:
        video_frame: PIL frame from the webcam, or None.
        remove_bg: forwarded to predict_image().

    Returns:
        The cached 5-tuple from predict_image(), or a tuple of five Nones.
    """
    empty = (None, None, None, None, None)
    if video_frame is None:
        return empty

    current_time = time.time()
    with realtime_state.lock:
        # Serve the cached result while it is still fresh.
        if realtime_state.last_result and current_time - realtime_state.last_update_time < 2:
            return realtime_state.last_result
        # A frame is already being classified; don't queue another.
        if realtime_state.is_processing:
            return empty
        realtime_state.is_processing = True

    def process_frame():
        # Runs on the executor's worker thread.
        result = None
        try:
            result = predict_image(video_frame, remove_bg)
        except Exception as e:
            print(f"处理帧时出错: {e}")
        finally:
            # Publish the result (if any) and release the busy flag in one
            # critical section; `finally` guarantees the flag never stays set.
            with realtime_state.lock:
                if result is not None:
                    realtime_state.last_result = result
                    realtime_state.last_update_time = time.time()
                realtime_state.is_processing = False

    try:
        executor.submit(process_frame)
    except RuntimeError as e:
        # submit() raises RuntimeError once the executor has been shut down;
        # without this reset, every later frame would be ignored.
        print(f"处理帧时出错: {e}")
        with realtime_state.lock:
            realtime_state.is_processing = False
    return empty
def add_feedback(prediction_id, feedback):
    """Accept user feedback about a prediction.

    This is a stub: the feedback is only echoed to stdout; nothing is stored.

    Args:
        prediction_id: identifier of the prediction the feedback refers to.
        feedback: free-text feedback entered by the user.

    Returns:
        Always True.
    """
    message = f"收到反馈: {feedback} 对预测ID: {prediction_id}"
    print(message)
    return True
def create_interface():
    """Build the Gradio Blocks app for fruit classification.

    The left column holds the background-removal toggle, an image upload with
    example files, a submit button and a streaming webcam input. The right
    column shows the processed image, best class, confidence, top-3 label
    and a feedback form.

    Returns:
        The configured ``gr.Blocks`` app (not yet launched).
    """
    examples = [
        "r0_0_100.jpg",
        "r0_18_100.jpg",
        "9_100.jpg",
        "5ecc819f1a579f513e0a1500fabb3f0.png",
        "1105.jpg",
    ]

    def submit_feedback(prediction_id, feedback):
        # Gradio assigns return values to `outputs` positionally, so return
        # exactly one per output: a status message, then an update that
        # clears the input box. The prediction ID is only read, never written.
        add_feedback(prediction_id, feedback)
        return "反馈成功!", gr.update(value="")

    with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""# 🍎 智能水果识别系统""")
        with gr.Row():
            with gr.Column(scale=3):
                with gr.Group():
                    gr.Markdown("## ⚙️ 处理模式选择")
                    with gr.Row():
                        bg_removal = gr.Checkbox(label="背景去除", value=False, interactive=True)
                with gr.Column():
                    original_image = gr.Image(label="📤 上传图片", type="pil")
                    gr.Examples(examples=examples, inputs=original_image)
                    submit_btn = gr.Button("🚀 开始识别", variant="primary")
                    gr.Markdown("""## ⚡ 实时识别""")
                    # NOTE(review): some Gradio versions only accept
                    # streaming=True with a webcam-only source — confirm
                    # against the deployed version.
                    camera = gr.Image(label="📷 摄像头捕获", type="pil", streaming=True)
            with gr.Column():
                # Hidden holder for the ID returned by the prediction functions.
                prediction_id_output = gr.Textbox(label="🔍 预测ID", interactive=False, visible=False)
                processed_image = gr.Image(label="🖼️ 处理后图片", interactive=False)
                best_pred = gr.Textbox(label="🔍 识别结果")
                confidence = gr.Textbox(label="📊 置信度")
                full_results = gr.Label(label="🏆 Top 3 可能结果", num_top_classes=3)
                with gr.Row():
                    feedback_input = gr.Textbox(label="📝 输入反馈信息")
                with gr.Row():
                    feedback_btn = gr.Button("📢 提交反馈", variant="secondary")
                # Destination for the feedback confirmation message.
                feedback_status = gr.Textbox(label="✅ 反馈状态", interactive=False)

        # Both prediction paths return the same 5-tuple, in this order.
        prediction_outputs = [prediction_id_output, processed_image, best_pred, confidence, full_results]

        submit_btn.click(
            fn=predict_image,
            inputs=[original_image, bg_removal],
            outputs=prediction_outputs,
        )
        camera.stream(
            fn=predict_realtime,
            inputs=[camera, bg_removal],
            outputs=prediction_outputs,
        )
        feedback_btn.click(
            fn=submit_feedback,
            inputs=[prediction_id_output, feedback_input],
            outputs=[feedback_status, feedback_input],
        )
    return demo
if __name__ == "__main__":
    # Build the Gradio app and serve it with share=True.
    create_interface().launch(share=True)