pagling commited on
Commit
c191c81
·
verified ·
1 Parent(s): e56e569

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -25
app.py CHANGED
@@ -1,36 +1,95 @@
1
  import gradio as gr
2
  import cv2
3
- import numpy as np
4
  from ultralytics import YOLO
5
- from importlib.metadata import version
6
 
7
- # 动态组件选择
8
- GRADIO_VERSION = version("gradio")
9
- if int(GRADIO_VERSION.split('.')[0]) >= 4 and int(GRADIO_VERSION.split('.')[1]) >= 12:
10
- Webcam = gr.Webcam
11
- else:
12
- class Webcam(gr.Image):
13
- def __init__(self, ​**kwargs):
14
- kwargs.update(source="webcam", streaming=True)
15
- super().__init__(**kwargs)
16
 
17
- # YOLO模型初始化
18
- model = None
 
 
 
 
 
19
 
20
- def load_model(model_file):
21
- global model
22
- model = YOLO(model_file.name)
23
- return "✅ 模型加载成功"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  with gr.Blocks() as demo:
26
- with gr.Tab("📷 实时检测"):
27
- webcam = Webcam(label="摄像头")
28
- output = gr.Image()
29
- webcam.stream(
30
- fn=lambda x: model.predict(x)[0].plot(),
31
- inputs=webcam,
32
- outputs=output
33
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  if __name__ == "__main__":
36
  demo.launch()
 
1
  import gradio as gr
2
  import cv2
3
+ import tempfile
4
  from ultralytics import YOLO
5
+ from pathlib import Path
6
 
7
+ # 全局变量存储当前模型
8
+ current_model = None
 
 
 
 
 
 
 
9
 
10
+ def load_model(model_path):
11
+ global current_model
12
+ try:
13
+ current_model = YOLO(model_path)
14
+ return "模型加载成功!"
15
+ except Exception as e:
16
+ return f"模型加载失败:{str(e)}"
17
 
18
+ def detect_image(input_image, conf_threshold):
19
+ if current_model is None:
20
+ raise gr.Error("请先上传模型文件")
21
+
22
+ results = current_model(input_image, conf=conf_threshold)
23
+ plotted = results[0].plot()
24
+ return plotted[:, :, ::-1] # BGR转RGB
25
+
26
+ def detect_video(input_video, conf_threshold):
27
+ if current_model is None:
28
+ raise gr.Error("请先上传模型文件")
29
+
30
+ cap = cv2.VideoCapture(input_video)
31
+ fps = cap.get(cv2.CAP_PROP_FPS)
32
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
33
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
34
+
35
+ # 创建临时输出文件
36
+ temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
37
+ out = cv2.VideoWriter(temp_file.name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
38
+
39
+ while cap.isOpened():
40
+ ret, frame = cap.read()
41
+ if not ret:
42
+ break
43
+
44
+ results = current_model(frame, conf=conf_threshold)
45
+ plotted = results[0].plot()
46
+ out.write(plotted)
47
+
48
+ cap.release()
49
+ out.release()
50
+ return temp_file.name
51
+
52
+ def detect_webcam(camera_input, conf_threshold):
53
+ if current_model is None:
54
+ raise gr.Error("请先上传模型文件")
55
+
56
+ results = current_model(camera_input, conf=conf_threshold)
57
+ plotted = results[0].plot()
58
+ return plotted[:, :, ::-1] # BGR转RGB
59
 
60
  with gr.Blocks() as demo:
61
+ gr.Markdown("# YOLOv8 自定义模型检测系统")
62
+
63
+ with gr.Row():
64
+ model_input = gr.File(label="上传模型文件 (.pt)", type="filepath")
65
+ model_status = gr.Textbox(label="模型状态", interactive=False)
66
+
67
+ model_input.upload(fn=load_model, inputs=model_input, outputs=model_status)
68
+
69
+ with gr.Tabs():
70
+ with gr.TabItem("图片检测"):
71
+ with gr.Row():
72
+ img_input = gr.Image(label="输入图片", type="filepath")
73
+ img_output = gr.Image(label="检测结果")
74
+ img_conf = gr.Slider(0, 1, value=0.5, label="置信度阈值")
75
+ img_button = gr.Button("执行检测")
76
+
77
+ with gr.TabItem("视频检测"):
78
+ with gr.Row():
79
+ video_input = gr.Video(label="输入视频")
80
+ video_output = gr.Video(label="检测结果")
81
+ video_conf = gr.Slider(0, 1, value=0.5, label="置信度阈值")
82
+ video_button = gr.Button("执行检测")
83
+
84
+ with gr.TabItem("实时摄像头"):
85
+ webcam_input = gr.Image(label="摄像头画面", source="webcam", streaming=True)
86
+ webcam_output = gr.Image(label="检测结果", streaming=True)
87
+ webcam_conf = gr.Slider(0, 1, value=0.5, label="置信度阈值")
88
+
89
+ # 绑定事件处理
90
+ img_button.click(fn=detect_image, inputs=[img_input, img_conf], outputs=img_output)
91
+ video_button.click(fn=detect_video, inputs=[video_input, video_conf], outputs=video_output)
92
+ webcam_input.stream(fn=detect_webcam, inputs=[webcam_input, webcam_conf], outputs=webcam_output)
93
 
94
  if __name__ == "__main__":
95
  demo.launch()