pagling commited on
Commit
d42ac84
·
verified ·
1 Parent(s): 05ec12c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -72
app.py CHANGED
@@ -1,104 +1,95 @@
1
  import gradio as gr
2
- import numpy as np
3
  import cv2
 
4
  from ultralytics import YOLO
5
  from pathlib import Path
6
  import tempfile
7
  import os
8
- import requests
9
- from pathlib import Path
10
- def download_file(url):
11
- filename = url.split("/")[-1]
12
- if not Path(filename).exists():
13
- print(f"正在下载 {filename}...")
14
- r = requests.get(url)
15
- with open(filename, "wb") as f:
16
- f.write(r.content)
17
- return filename
18
 
19
- with gr.Blocks() as demo:
20
- with gr.Tab("⚙️ 模型管理"):
21
- model_upload = gr.File(file_types=[".pt"])
22
- gr.Examples(
23
- examples=[
24
- ["https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt"],
25
- ["https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8s.pt"]
26
- ],
27
- inputs=model_upload,
28
- fn=download_file,
29
- run_on_click=True # 点击示例自动触发下载
30
- )
31
- # 初始化默认模型
32
- model = YOLO('yolov8n.pt') # 自动下载基础模型
33
 
34
- def load_custom_model(model_path):
35
  global model
36
  try:
37
- model = YOLO(model_path.name) # 适配HuggingFace的文件对象
 
 
 
 
 
 
38
  return "✅ 模型加载成功!"
39
  except Exception as e:
40
  return f"❌ 加载失败:{str(e)}"
41
 
42
  def process_frame(frame, input_type):
43
- # 统一处理函数
 
 
 
44
  results = model.predict(
45
  source=frame,
46
- verbose=False, # 关闭控制台输出
47
- device="cpu", # 适配免费Space环境
48
- conf=0.5 # 置信度阈值
49
  )
50
- annotated = results[0].plot()
51
- return cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
52
 
53
  def video_pipeline(input_video):
54
- # 视频处理优化
 
 
 
55
  temp_dir = tempfile.mkdtemp()
56
  output_path = os.path.join(temp_dir, "output.mp4")
57
 
 
58
  cap = cv2.VideoCapture(input_video)
59
  fps = cap.get(cv2.CAP_PROP_FPS)
60
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
61
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
62
 
63
  writer = cv2.VideoWriter(
64
  output_path,
65
  cv2.VideoWriter_fourcc(*'mp4v'),
66
- fps,
67
  (width, height)
68
  )
69
 
70
- while cap.isOpened():
71
- ret, frame = cap.read()
72
- if not ret: break
73
- writer.write(process_frame(frame, "video"))
 
 
 
 
 
74
 
75
- cap.release()
76
- writer.release()
77
  return output_path
78
 
79
- # Gradio界面布局
80
- with gr.Blocks(title="YOLOv8检测系统", css=".gradio-container {background: #f0f0f0}") as demo:
81
- gr.Markdown("""# 🚀 YOLOv8多功能检测系统
82
- *欢迎上传自定义模型或使用默认模型进行检测*""")
83
 
84
- with gr.Tab("⚙️ 模型管理"):
85
  with gr.Row():
86
- model_upload = gr.UploadButton(
87
  "上传模型文件 (.pt)",
88
  file_types=[".pt"],
89
  variant="primary"
90
  )
91
  model_status = gr.Textbox(label="状态", interactive=False)
92
- gr.Examples(
93
- examples=["yolov8n.pt", "yolov8s.pt"],
94
- inputs=model_upload,
95
- label="示例模型"
 
96
  )
97
 
98
  with gr.Tab("📷 实时检测"):
99
- with gr.Row():
100
- webcam = gr.Webcam(label="摄像头输入", mirror=True)
101
- cam_output = gr.Image(label="检测结果")
102
  webcam.stream(
103
  fn=lambda x: process_frame(x, "camera"),
104
  inputs=webcam,
@@ -107,33 +98,22 @@ with gr.Blocks(title="YOLOv8检测系统", css=".gradio-container {background: #
107
  )
108
 
109
  with gr.Tab("🖼️ 图片检测"):
110
- with gr.Row():
111
- img_input = gr.Image(type="filepath", sources=["upload"], label="输入图片")
112
- img_output = gr.Image(label="检测结果")
113
- img_btn = gr.Button("开始检测", variant="primary")
114
- img_btn.click(
115
  fn=lambda x: process_frame(cv2.imread(x), "image"),
116
  inputs=img_input,
117
  outputs=img_output
118
  )
119
 
120
  with gr.Tab("🎥 视频检测"):
121
- with gr.Row():
122
- vid_input = gr.Video(label="输入视频", sources=["upload"])
123
- vid_output = gr.Video(label="处理结果")
124
- vid_btn = gr.Button("处理视频", variant="primary")
125
- vid_btn.click(
126
  fn=video_pipeline,
127
  inputs=vid_input,
128
  outputs=vid_output
129
  )
130
-
131
- # 模型加载事件
132
- model_upload.upload(
133
- fn=load_custom_model,
134
- inputs=model_upload,
135
- outputs=model_status
136
- )
137
 
138
  if __name__ == "__main__":
139
  demo.launch(
 
1
  import gradio as gr
 
2
  import cv2
3
+ import numpy as np
4
  from ultralytics import YOLO
5
  from pathlib import Path
6
  import tempfile
7
  import os
 
 
 
 
 
 
 
 
 
 
8
 
9
+ # 初始化空模型
10
+ model = None
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ def load_custom_model(model_file):
13
  global model
14
  try:
15
+ # 保存上传文件到临时目录
16
+ save_path = f"/tmp/{model_file.name}"
17
+ with open(save_path, "wb") as f:
18
+ f.write(model_file.read())
19
+
20
+ # 加载模型
21
+ model = YOLO(save_path)
22
  return "✅ 模型加载成功!"
23
  except Exception as e:
24
  return f"❌ 加载失败:{str(e)}"
25
 
26
  def process_frame(frame, input_type):
27
+ global model
28
+ if model is None:
29
+ raise gr.Error("请先上传并加载模型")
30
+
31
  results = model.predict(
32
  source=frame,
33
+ verbose=False,
34
+ device="cpu",
35
+ conf=0.5
36
  )
37
+ return cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)
 
38
 
39
  def video_pipeline(input_video):
40
+ if model is None:
41
+ raise gr.Error("请先上传并加载模型")
42
+
43
+ # 创建临时文件
44
  temp_dir = tempfile.mkdtemp()
45
  output_path = os.path.join(temp_dir, "output.mp4")
46
 
47
+ # 处理视频
48
  cap = cv2.VideoCapture(input_video)
49
  fps = cap.get(cv2.CAP_PROP_FPS)
50
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
51
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
52
 
53
  writer = cv2.VideoWriter(
54
  output_path,
55
  cv2.VideoWriter_fourcc(*'mp4v'),
56
+ fps,
57
  (width, height)
58
  )
59
 
60
+ try:
61
+ while cap.isOpened():
62
+ ret, frame = cap.read()
63
+ if not ret:
64
+ break
65
+ writer.write(process_frame(frame, "video"))
66
+ finally:
67
+ cap.release()
68
+ writer.release()
69
 
 
 
70
  return output_path
71
 
72
+ with gr.Blocks(title="自定义模型检测系统") as demo:
73
+ gr.Markdown("# 🛠️ 自定义YOLOv8模型检测系统")
 
 
74
 
75
+ with gr.Tab("🔧 模型管理"):
76
  with gr.Row():
77
+ upload_btn = gr.UploadButton(
78
  "上传模型文件 (.pt)",
79
  file_types=[".pt"],
80
  variant="primary"
81
  )
82
  model_status = gr.Textbox(label="状态", interactive=False)
83
+
84
+ upload_btn.upload(
85
+ fn=load_custom_model,
86
+ inputs=upload_btn,
87
+ outputs=model_status
88
  )
89
 
90
  with gr.Tab("📷 实时检测"):
91
+ webcam = gr.Webcam(label="摄像头")
92
+ cam_output = gr.Image(label="检测结果")
 
93
  webcam.stream(
94
  fn=lambda x: process_frame(x, "camera"),
95
  inputs=webcam,
 
98
  )
99
 
100
  with gr.Tab("🖼️ 图片检测"):
101
+ img_input = gr.Image(type="filepath", label="上传图片")
102
+ img_output = gr.Image(label="检测结果")
103
+ img_input.upload(
 
 
104
  fn=lambda x: process_frame(cv2.imread(x), "image"),
105
  inputs=img_input,
106
  outputs=img_output
107
  )
108
 
109
  with gr.Tab("🎥 视频检测"):
110
+ vid_input = gr.Video(label="上传视频")
111
+ vid_output = gr.Video(label="处理结果")
112
+ vid_input.upload(
 
 
113
  fn=video_pipeline,
114
  inputs=vid_input,
115
  outputs=vid_output
116
  )
 
 
 
 
 
 
 
117
 
118
  if __name__ == "__main__":
119
  demo.launch(