pagling commited on
Commit
a770113
·
verified ·
1 Parent(s): 97af053

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -72
app.py CHANGED
@@ -6,13 +6,13 @@ from pathlib import Path
6
  import tempfile
7
  import os
8
 
9
- # 初始化空模型
10
  model = None
11
 
12
  def load_custom_model(model_file):
13
  global model
14
  try:
15
- # 保存上传文件到临时目录
16
  save_path = f"/tmp/{model_file.name}"
17
  with open(save_path, "wb") as f:
18
  f.write(model_file.read())
@@ -23,101 +23,61 @@ def load_custom_model(model_file):
23
  except Exception as e:
24
  return f"❌ 加载失败:{str(e)}"
25
 
26
- def process_frame(frame, input_type):
27
- global model
28
  if model is None:
29
- raise gr.Error("请先上传并加载模型")
30
 
31
  results = model.predict(
32
  source=frame,
33
- verbose=False,
34
  device="cpu",
35
- conf=0.5
 
36
  )
37
  return cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)
38
 
39
- def video_pipeline(input_video):
40
- if model is None:
41
- raise gr.Error("请先上传并加载模型")
42
-
43
- # 创建临时文件
44
- temp_dir = tempfile.mkdtemp()
45
- output_path = os.path.join(temp_dir, "output.mp4")
46
-
47
- # 处理视频
48
- cap = cv2.VideoCapture(input_video)
49
- fps = cap.get(cv2.CAP_PROP_FPS)
50
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
51
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
52
-
53
- writer = cv2.VideoWriter(
54
- output_path,
55
- cv2.VideoWriter_fourcc(*'mp4v'),
56
- fps,
57
- (width, height)
58
- )
59
-
60
- try:
61
- while cap.isOpened():
62
- ret, frame = cap.read()
63
- if not ret:
64
- break
65
- writer.write(process_frame(frame, "video"))
66
- finally:
67
- cap.release()
68
- writer.release()
69
-
70
- return output_path
71
-
72
- with gr.Blocks(title="自定义模型检测系统") as demo:
73
- gr.Markdown("# 🛠️ 自定义YOLOv8模型检测系统")
74
 
75
  with gr.Tab("🔧 模型管理"):
76
- with gr.Row():
77
- upload_btn = gr.UploadButton(
78
- "上传模型文件 (.pt)",
79
- file_types=[".pt"],
80
- variant="primary"
81
- )
82
- model_status = gr.Textbox(label="状态", interactive=False)
83
-
84
- upload_btn.upload(
85
- fn=load_custom_model,
86
- inputs=upload_btn,
87
- outputs=model_status
88
  )
 
 
89
 
90
  with gr.Tab("📷 实时检测"):
91
- webcam = gr.Webcam(label="摄像头")
92
- cam_output = gr.Image(label="检测结果")
 
 
 
 
93
  webcam.stream(
94
- fn=lambda x: process_frame(x, "camera"),
95
  inputs=webcam,
96
- outputs=cam_output,
97
- show_progress="hidden"
98
  )
99
 
100
  with gr.Tab("🖼️ 图片检测"):
101
- img_input = gr.Image(type="filepath", label="上传图片")
102
- img_output = gr.Image(label="检测结果")
103
- img_input.upload(
104
- fn=lambda x: process_frame(cv2.imread(x), "image"),
105
  inputs=img_input,
106
  outputs=img_output
107
  )
108
 
109
  with gr.Tab("🎥 视频检测"):
110
- vid_input = gr.Video(label="上传视频")
111
- vid_output = gr.Video(label="处理结果")
112
- vid_input.upload(
113
- fn=video_pipeline,
114
  inputs=vid_input,
115
  outputs=vid_output
116
  )
117
 
118
  if __name__ == "__main__":
119
- demo.launch(
120
- server_name="0.0.0.0",
121
- server_port=7860,
122
- show_error=True
123
- )
 
6
  import tempfile
7
  import os
8
 
9
+ # 初始化模型占位符
10
  model = None
11
 
12
  def load_custom_model(model_file):
13
  global model
14
  try:
15
+ # 保存上传文件
16
  save_path = f"/tmp/{model_file.name}"
17
  with open(save_path, "wb") as f:
18
  f.write(model_file.read())
 
23
  except Exception as e:
24
  return f"❌ 加载失败:{str(e)}"
25
 
26
+ def process_frame(frame):
 
27
  if model is None:
28
+ raise gr.Error("请先上传模型文件")
29
 
30
  results = model.predict(
31
  source=frame,
 
32
  device="cpu",
33
+ conf=0.5,
34
+ verbose=False
35
  )
36
  return cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)
37
 
38
+ # Gradio界面(兼容旧版)
39
+ with gr.Blocks(title="YOLO检测系统") as demo:
40
+ gr.Markdown("# 🔍 自定义YOLO检测系统")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  with gr.Tab("🔧 模型管理"):
43
+ upload_btn = gr.UploadButton(
44
+ "上传.pt模型文件",
45
+ file_types=[".pt"],
46
+ variant="primary"
 
 
 
 
 
 
 
 
47
  )
48
+ status = gr.Textbox(label="状态")
49
+ upload_btn.upload(load_custom_model, upload_btn, status)
50
 
51
  with gr.Tab("📷 实时检测"):
52
+ webcam = gr.Image(
53
+ source="webcam",
54
+ streaming=True,
55
+ label="摄像头画面"
56
+ )
57
+ output = gr.Image(label="检测结果")
58
  webcam.stream(
59
+ fn=lambda x: process_frame(x),
60
  inputs=webcam,
61
+ outputs=output
 
62
  )
63
 
64
  with gr.Tab("🖼️ 图片检测"):
65
+ img_input = gr.Image(type="filepath")
66
+ img_output = gr.Image()
67
+ img_input.change(
68
+ fn=lambda x: process_frame(cv2.imread(x)),
69
  inputs=img_input,
70
  outputs=img_output
71
  )
72
 
73
  with gr.Tab("🎥 视频检测"):
74
+ vid_input = gr.Video()
75
+ vid_output = gr.Video()
76
+ vid_input.change(
77
+ fn=lambda x: video_pipeline(x),
78
  inputs=vid_input,
79
  outputs=vid_output
80
  )
81
 
82
  if __name__ == "__main__":
83
+ demo.launch(server_name="0.0.0.0")