Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -6,13 +6,13 @@ from pathlib import Path
|
|
6 |
import tempfile
|
7 |
import os
|
8 |
|
9 |
-
#
|
10 |
model = None
|
11 |
|
12 |
def load_custom_model(model_file):
|
13 |
global model
|
14 |
try:
|
15 |
-
#
|
16 |
save_path = f"/tmp/{model_file.name}"
|
17 |
with open(save_path, "wb") as f:
|
18 |
f.write(model_file.read())
|
@@ -23,101 +23,61 @@ def load_custom_model(model_file):
|
|
23 |
except Exception as e:
|
24 |
return f"❌ 加载失败:{str(e)}"
|
25 |
|
26 |
-
def process_frame(frame
|
27 |
-
global model
|
28 |
if model is None:
|
29 |
-
raise gr.Error("
|
30 |
|
31 |
results = model.predict(
|
32 |
source=frame,
|
33 |
-
verbose=False,
|
34 |
device="cpu",
|
35 |
-
conf=0.5
|
|
|
36 |
)
|
37 |
return cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)
|
38 |
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
# 创建临时文件
|
44 |
-
temp_dir = tempfile.mkdtemp()
|
45 |
-
output_path = os.path.join(temp_dir, "output.mp4")
|
46 |
-
|
47 |
-
# 处理视频
|
48 |
-
cap = cv2.VideoCapture(input_video)
|
49 |
-
fps = cap.get(cv2.CAP_PROP_FPS)
|
50 |
-
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
51 |
-
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
52 |
-
|
53 |
-
writer = cv2.VideoWriter(
|
54 |
-
output_path,
|
55 |
-
cv2.VideoWriter_fourcc(*'mp4v'),
|
56 |
-
fps,
|
57 |
-
(width, height)
|
58 |
-
)
|
59 |
-
|
60 |
-
try:
|
61 |
-
while cap.isOpened():
|
62 |
-
ret, frame = cap.read()
|
63 |
-
if not ret:
|
64 |
-
break
|
65 |
-
writer.write(process_frame(frame, "video"))
|
66 |
-
finally:
|
67 |
-
cap.release()
|
68 |
-
writer.release()
|
69 |
-
|
70 |
-
return output_path
|
71 |
-
|
72 |
-
with gr.Blocks(title="自定义模型检测系统") as demo:
|
73 |
-
gr.Markdown("# 🛠️ 自定义YOLOv8模型检测系统")
|
74 |
|
75 |
with gr.Tab("🔧 模型管理"):
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
variant="primary"
|
81 |
-
)
|
82 |
-
model_status = gr.Textbox(label="状态", interactive=False)
|
83 |
-
|
84 |
-
upload_btn.upload(
|
85 |
-
fn=load_custom_model,
|
86 |
-
inputs=upload_btn,
|
87 |
-
outputs=model_status
|
88 |
)
|
|
|
|
|
89 |
|
90 |
with gr.Tab("📷 实时检测"):
|
91 |
-
webcam = gr.
|
92 |
-
|
|
|
|
|
|
|
|
|
93 |
webcam.stream(
|
94 |
-
fn=lambda x: process_frame(x
|
95 |
inputs=webcam,
|
96 |
-
outputs=
|
97 |
-
show_progress="hidden"
|
98 |
)
|
99 |
|
100 |
with gr.Tab("🖼️ 图片检测"):
|
101 |
-
img_input = gr.Image(type="filepath"
|
102 |
-
img_output = gr.Image(
|
103 |
-
img_input.
|
104 |
-
fn=lambda x: process_frame(cv2.imread(x)
|
105 |
inputs=img_input,
|
106 |
outputs=img_output
|
107 |
)
|
108 |
|
109 |
with gr.Tab("🎥 视频检测"):
|
110 |
-
vid_input = gr.Video(
|
111 |
-
vid_output = gr.Video(
|
112 |
-
vid_input.
|
113 |
-
fn=video_pipeline,
|
114 |
inputs=vid_input,
|
115 |
outputs=vid_output
|
116 |
)
|
117 |
|
118 |
if __name__ == "__main__":
|
119 |
-
demo.launch(
|
120 |
-
server_name="0.0.0.0",
|
121 |
-
server_port=7860,
|
122 |
-
show_error=True
|
123 |
-
)
|
|
|
6 |
import tempfile
|
7 |
import os
|
8 |
|
9 |
+
# 初始化模型占位符
|
10 |
model = None
|
11 |
|
12 |
def load_custom_model(model_file):
|
13 |
global model
|
14 |
try:
|
15 |
+
# 保存上传文件
|
16 |
save_path = f"/tmp/{model_file.name}"
|
17 |
with open(save_path, "wb") as f:
|
18 |
f.write(model_file.read())
|
|
|
23 |
except Exception as e:
|
24 |
return f"❌ 加载失败:{str(e)}"
|
25 |
|
26 |
+
def process_frame(frame):
|
|
|
27 |
if model is None:
|
28 |
+
raise gr.Error("请先上传模型文件")
|
29 |
|
30 |
results = model.predict(
|
31 |
source=frame,
|
|
|
32 |
device="cpu",
|
33 |
+
conf=0.5,
|
34 |
+
verbose=False
|
35 |
)
|
36 |
return cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)
|
37 |
|
38 |
+
# Gradio界面(兼容旧版)
|
39 |
+
with gr.Blocks(title="YOLO检测系统") as demo:
|
40 |
+
gr.Markdown("# 🔍 自定义YOLO检测系统")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
with gr.Tab("🔧 模型管理"):
|
43 |
+
upload_btn = gr.UploadButton(
|
44 |
+
"上传.pt模型文件",
|
45 |
+
file_types=[".pt"],
|
46 |
+
variant="primary"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
)
|
48 |
+
status = gr.Textbox(label="状态")
|
49 |
+
upload_btn.upload(load_custom_model, upload_btn, status)
|
50 |
|
51 |
with gr.Tab("📷 实时检测"):
|
52 |
+
webcam = gr.Image(
|
53 |
+
source="webcam",
|
54 |
+
streaming=True,
|
55 |
+
label="摄像头画面"
|
56 |
+
)
|
57 |
+
output = gr.Image(label="检测结果")
|
58 |
webcam.stream(
|
59 |
+
fn=lambda x: process_frame(x),
|
60 |
inputs=webcam,
|
61 |
+
outputs=output
|
|
|
62 |
)
|
63 |
|
64 |
with gr.Tab("🖼️ 图片检测"):
|
65 |
+
img_input = gr.Image(type="filepath")
|
66 |
+
img_output = gr.Image()
|
67 |
+
img_input.change(
|
68 |
+
fn=lambda x: process_frame(cv2.imread(x)),
|
69 |
inputs=img_input,
|
70 |
outputs=img_output
|
71 |
)
|
72 |
|
73 |
with gr.Tab("🎥 视频检测"):
|
74 |
+
vid_input = gr.Video()
|
75 |
+
vid_output = gr.Video()
|
76 |
+
vid_input.change(
|
77 |
+
fn=lambda x: video_pipeline(x),
|
78 |
inputs=vid_input,
|
79 |
outputs=vid_output
|
80 |
)
|
81 |
|
82 |
if __name__ == "__main__":
|
83 |
+
demo.launch(server_name="0.0.0.0")
|
|
|
|
|
|
|
|