# Source captured from a Hugging Face Space page (Space status at capture time: Sleeping).
import gradio as gr | |
import cv2 | |
import tempfile | |
from ultralytics import YOLO | |
from pathlib import Path | |
# Module-level slot holding the YOLO model currently used by every detector;
# stays None until a weights file has been uploaded and loaded successfully.
current_model = None


def load_model(model_path):
    """Load YOLO weights from *model_path* and make them the active model.

    Args:
        model_path: Filesystem path of the uploaded weights file, as delivered
            by the ``gr.File(type="filepath")`` upload component.

    Returns:
        A status string for the UI: a success message, or a failure message
        that embeds the exception text. On failure the previously active
        model (if any) is left in place.
    """
    global current_model
    try:
        loaded = YOLO(model_path)
    except Exception as err:  # UI boundary: report any load error as text
        return f"模型加载失败:{str(err)}"
    current_model = loaded
    return "模型加载成功!"
def detect_image(input_image, conf_threshold):
    """Run the active model on a single image and return the annotated result.

    Args:
        input_image: Path of the uploaded image (the input component is
            configured with ``type="filepath"``), or ``None`` when the user
            has not uploaded anything.
        conf_threshold: Minimum confidence forwarded to the model as ``conf``.

    Returns:
        The annotated image with its channels flipped from BGR to RGB for
        display, or ``None`` when no image was supplied.

    Raises:
        gr.Error: If no model has been loaded yet.
    """
    if current_model is None:
        raise gr.Error("请先上传模型文件")
    if input_image is None:
        # Without this guard Ultralytics' predict() substitutes its bundled
        # sample assets when ``source`` is None, so the user would get
        # detections on a picture they never uploaded.
        return None
    results = current_model(input_image, conf=conf_threshold)
    plotted = results[0].plot()  # annotated image, BGR (OpenCV) order
    return plotted[:, :, ::-1]  # BGR -> RGB for Gradio
def detect_video(input_video, conf_threshold):
    """Run the active model on every frame of a video.

    Each annotated frame is written to a new temporary ``.mp4`` file. The file
    is left on disk (``delete=False``) so Gradio can serve it.

    Args:
        input_video: Path of the uploaded video from the ``gr.Video`` input,
            or ``None`` when nothing has been uploaded.
        conf_threshold: Minimum confidence forwarded to the model as ``conf``.

    Returns:
        Path of the annotated ``.mp4`` file, or ``None`` when no video was
        supplied.

    Raises:
        gr.Error: If no model has been loaded, or if no frame could be decoded
            from the video.
    """
    if current_model is None:
        raise gr.Error("请先上传模型文件")
    if input_video is None:
        return None

    cap = cv2.VideoCapture(input_video)
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not (fps > 0):
        # Some containers (or an unopened capture) report 0 or NaN, which is
        # not a usable frame rate for VideoWriter; fall back to a default.
        fps = 30.0
    # NOTE(review): browsers generally cannot play mp4v-encoded MP4 natively;
    # confirm the installed Gradio version re-encodes it for playback.
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")

    out = None
    output_path = None
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            # verbose=False suppresses Ultralytics' per-frame console log.
            results = current_model(frame, conf=conf_threshold, verbose=False)
            plotted = results[0].plot()
            if out is None:
                # Create the output lazily and size it from the first annotated
                # frame, not the container metadata, so every frame matches the
                # writer's dimensions. Only the name is needed; the handle is
                # closed immediately so it is not leaked and VideoWriter can
                # open the path (an open NamedTemporaryFile cannot be reopened
                # on Windows).
                with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
                    output_path = tmp.name
                height, width = plotted.shape[:2]
                out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
            out.write(plotted)  # plot() output is BGR, which VideoWriter expects
    finally:
        # Release native handles even if inference raises mid-video.
        cap.release()
        if out is not None:
            out.release()

    if output_path is None:
        # The capture could not be opened or yielded no frames.
        raise gr.Error("无法读取视频文件")
    return output_path
def detect_webcam(camera_input, conf_threshold):
    """Run the active model on one webcam snapshot.

    Args:
        camera_input: Frame captured by the ``gr.Webcam`` component, or
            ``None`` if no snapshot has been taken. Assumes the component's
            default ``type="numpy"``, i.e. an RGB uint8 array of shape
            (H, W, 3). TODO confirm if the component configuration changes.
        conf_threshold: Minimum confidence forwarded to the model as ``conf``.

    Returns:
        The annotated frame in RGB order for display, or ``None`` when no
        frame was supplied.

    Raises:
        gr.Error: If no model has been loaded yet.
    """
    if current_model is None:
        raise gr.Error("请先上传模型文件")
    if camera_input is None:
        return None
    # Ultralytics interprets numpy inputs as BGR (OpenCV order), while Gradio
    # delivers RGB. Convert first so the model sees correct colours and plot()
    # returns a BGR image that the flip below turns back into RGB.
    frame_bgr = cv2.cvtColor(camera_input, cv2.COLOR_RGB2BGR)
    results = current_model(frame_bgr, conf=conf_threshold)
    plotted = results[0].plot()  # annotated frame, BGR order
    return plotted[:, :, ::-1]  # BGR -> RGB for Gradio
with gr.Blocks() as demo:
    gr.Markdown("# YOLOv8 自定义模型检测系统")

    # Model upload: loading happens as soon as a file is uploaded.
    with gr.Row():
        model_input = gr.File(label="上传模型文件 (.pt)", type="filepath")
        model_status = gr.Textbox(label="模型状态", interactive=False)
    model_input.upload(fn=load_model, inputs=model_input, outputs=model_status)

    with gr.Tabs():
        with gr.TabItem("图片检测"):
            with gr.Row():
                img_input = gr.Image(label="输入图片", type="filepath")
                img_output = gr.Image(label="检测结果")
            img_conf = gr.Slider(0, 1, value=0.5, label="置信度阈值")
            img_button = gr.Button("执行检测")

        with gr.TabItem("视频检测"):
            with gr.Row():
                video_input = gr.Video(label="输入视频")
                video_output = gr.Video(label="检测结果")
            video_conf = gr.Slider(0, 1, value=0.5, label="置信度阈值")
            video_button = gr.Button("执行检测")

        with gr.TabItem("实时摄像头"):
            webcam_input = gr.Webcam(label="摄像头画面")  # built-in Webcam component
            webcam_output = gr.Image(label="检测结果")
            webcam_conf = gr.Slider(0, 1, value=0.5, label="置信度阈值")
            webcam_button = gr.Button("开始检测")

    # Wire up event handlers. Each button is bound exactly once; the webcam
    # button was previously bound twice, so every click ran detection twice.
    img_button.click(fn=detect_image, inputs=[img_input, img_conf], outputs=img_output)
    video_button.click(fn=detect_video, inputs=[video_input, video_conf], outputs=video_output)
    webcam_button.click(fn=detect_webcam, inputs=[webcam_input, webcam_conf], outputs=webcam_output)

if __name__ == "__main__":
    demo.launch()