pagling commited on
Commit
2764f0d
·
verified ·
1 Parent(s): dd303b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -57
app.py CHANGED
@@ -4,84 +4,118 @@ import cv2
4
  from ultralytics import YOLO
5
  from pathlib import Path
6
  import tempfile
 
7
 
8
- # 初始化模型占位符
9
- model = YOLO('yolov8n.pt') # 默认加载官方模型
10
 
11
  def load_custom_model(model_path):
12
  global model
13
- model = YOLO(model_path)
14
- return "模型加载成功!"
 
 
 
15
 
16
- def detect_media(input_type, input_data):
17
- if input_type == "camera":
18
- # 摄像头实时检测
19
- frame = input_data
20
- results = model.predict(source=frame, stream=True)
21
- annotated_frame = results[0].plot()
22
- return annotated_frame[:, :, ::-1] # BGR转RGB
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- elif input_type == "image":
25
- # 图片检测
26
- results = model.predict(source=input_data)
27
- return results[0].plot()[:, :, ::-1]
28
 
29
- elif input_type == "video":
30
- # 视频检测处理
31
- cap = cv2.VideoCapture(input_data)
32
- output_frames = []
33
-
34
- while cap.isOpened():
35
- ret, frame = cap.read()
36
- if not ret: break
37
-
38
- results = model.predict(source=frame)
39
- annotated_frame = results[0].plot()
40
- output_frames.append(annotated_frame)
41
-
42
- # 生成临时输出视频
43
- output_path = str(Path(tempfile.gettempdir()) / "output.mp4")
44
- out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'),
45
- 30, (annotated_frame.shape[1], annotated_frame.shape[0]))
46
- for f in output_frames:
47
- out.write(f)
48
- out.release()
49
-
50
- return output_path
51
 
52
  # Gradio界面布局
53
- with gr.Blocks(title="YOLOv8检测系统") as demo:
54
- gr.Markdown("# 🚀 YOLOv8多功能检测系统")
 
55
 
56
  with gr.Tab("⚙️ 模型管理"):
57
- model_upload = gr.File(label="上传模型文件(.pt)", file_types=[".pt"])
58
- load_btn = gr.Button("加载模型", variant="primary")
59
- model_status = gr.Textbox(label="模型状态")
 
 
 
 
 
 
 
 
 
60
 
61
- with gr.Tab("📷 实时摄像头"):
62
- webcam = gr.Image(source="webcam", streaming=True, label="摄像头画面")
63
- cam_output = gr.Image(label="检测结果", interactive=False)
64
- webcam.stream(fn=lambda x: detect_media("camera", x),
65
- inputs=webcam, outputs=cam_output)
 
 
 
 
 
66
 
67
  with gr.Tab("🖼️ 图片检测"):
68
- img_input = gr.Image(type="filepath", label="上传图片")
 
 
69
  img_btn = gr.Button("开始检测", variant="primary")
70
- img_output = gr.Image(label="检测结果")
 
 
 
 
71
 
72
  with gr.Tab("🎥 视频检测"):
73
- vid_input = gr.Video(label="上传视频")
 
 
74
  vid_btn = gr.Button("处理视频", variant="primary")
75
- vid_output = gr.Video(label="处理结果")
 
 
 
 
76
 
77
- # 事件绑定
78
- load_btn.click(fn=load_custom_model, inputs=model_upload, outputs=model_status)
79
- img_btn.click(fn=lambda x: detect_media("image", x), inputs=img_input, outputs=img_output)
80
- vid_btn.click(fn=lambda x: detect_media("video", x), inputs=vid_input, outputs=vid_output)
 
 
81
 
82
  if __name__ == "__main__":
83
- demo.queue(concurrency_count=3).launch(
84
  server_name="0.0.0.0",
85
  server_port=7860,
86
- share=True
87
  )
 
4
  from ultralytics import YOLO
5
  from pathlib import Path
6
  import tempfile
7
+ import os
8
 
9
+ # 初始化默认模型
10
+ model = YOLO('yolov8n.pt') # 自动下载基础模型
11
 
12
  def load_custom_model(model_path):
13
  global model
14
+ try:
15
+ model = YOLO(model_path.name) # 适配HuggingFace的文件对象
16
+ return "✅ 模型加载成功!"
17
+ except Exception as e:
18
+ return f"❌ 加载失败:{str(e)}"
19
 
20
+ def process_frame(frame, input_type):
21
+ # 统一处理函数
22
+ results = model.predict(
23
+ source=frame,
24
+ verbose=False, # 关闭控制台输出
25
+ device="cpu", # 适配免费Space环境
26
+ conf=0.5 # 置信度阈值
27
+ )
28
+ annotated = results[0].plot()
29
+ return cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
30
+
31
+ def video_pipeline(input_video):
32
+ # 视频处理优化
33
+ temp_dir = tempfile.mkdtemp()
34
+ output_path = os.path.join(temp_dir, "output.mp4")
35
+
36
+ cap = cv2.VideoCapture(input_video)
37
+ fps = cap.get(cv2.CAP_PROP_FPS)
38
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
39
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
40
+
41
+ writer = cv2.VideoWriter(
42
+ output_path,
43
+ cv2.VideoWriter_fourcc(*'mp4v'),
44
+ fps,
45
+ (width, height)
46
+ )
47
 
48
+ while cap.isOpened():
49
+ ret, frame = cap.read()
50
+ if not ret: break
51
+ writer.write(process_frame(frame, "video"))
52
 
53
+ cap.release()
54
+ writer.release()
55
+ return output_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  # Gradio界面布局
58
+ with gr.Blocks(title="YOLOv8检测系统", css=".gradio-container {background: #f0f0f0}") as demo:
59
+ gr.Markdown("""# 🚀 YOLOv8多功能检测系统
60
+ *欢迎上传自定义模型或使用默认模型进行检测*""")
61
 
62
  with gr.Tab("⚙️ 模型管理"):
63
+ with gr.Row():
64
+ model_upload = gr.UploadButton(
65
+ "上传模型文件 (.pt)",
66
+ file_types=[".pt"],
67
+ variant="primary"
68
+ )
69
+ model_status = gr.Textbox(label="状态", interactive=False)
70
+ gr.Examples(
71
+ examples=["yolov8n.pt", "yolov8s.pt"],
72
+ inputs=model_upload,
73
+ label="示例模型"
74
+ )
75
 
76
+ with gr.Tab("📷 实时检测"):
77
+ with gr.Row():
78
+ webcam = gr.Webcam(label="摄像头输入", mirror=True)
79
+ cam_output = gr.Image(label="检测结果")
80
+ webcam.stream(
81
+ fn=lambda x: process_frame(x, "camera"),
82
+ inputs=webcam,
83
+ outputs=cam_output,
84
+ show_progress="hidden"
85
+ )
86
 
87
  with gr.Tab("🖼️ 图片检测"):
88
+ with gr.Row():
89
+ img_input = gr.Image(type="filepath", sources=["upload"], label="输入图片")
90
+ img_output = gr.Image(label="检测结果")
91
  img_btn = gr.Button("开始检测", variant="primary")
92
+ img_btn.click(
93
+ fn=lambda x: process_frame(cv2.imread(x), "image"),
94
+ inputs=img_input,
95
+ outputs=img_output
96
+ )
97
 
98
  with gr.Tab("🎥 视频检测"):
99
+ with gr.Row():
100
+ vid_input = gr.Video(label="输入视频", sources=["upload"])
101
+ vid_output = gr.Video(label="处理结果")
102
  vid_btn = gr.Button("处理视频", variant="primary")
103
+ vid_btn.click(
104
+ fn=video_pipeline,
105
+ inputs=vid_input,
106
+ outputs=vid_output
107
+ )
108
 
109
+ # 模型加载事件
110
+ model_upload.upload(
111
+ fn=load_custom_model,
112
+ inputs=model_upload,
113
+ outputs=model_status
114
+ )
115
 
116
  if __name__ == "__main__":
117
+ demo.launch(
118
  server_name="0.0.0.0",
119
  server_port=7860,
120
+ show_error=True
121
  )