pagling commited on
Commit
12b5f76
·
verified ·
1 Parent(s): 2bc7c28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -96
app.py CHANGED
@@ -1,107 +1,79 @@
1
  import gradio as gr
2
  import cv2
3
- import tempfile
 
4
  from ultralytics import YOLO
5
- from pathlib import Path
6
 
7
- # 全局变量存储当前模型
8
- current_model = None
 
 
 
9
 
10
- def load_model(model_path):
11
- global current_model
12
- try:
13
- current_model = YOLO(model_path)
14
- return "模型加载成功!"
15
- except Exception as e:
16
- return f"模型加载失败:{str(e)}"
17
 
18
- def detect_image(input_image, conf_threshold):
19
- if current_model is None:
20
- raise gr.Error("请先上传模型文件")
21
-
22
- results = current_model(input_image, conf=conf_threshold)
23
- plotted = results[0].plot()
24
- return plotted[:, :, ::-1] # BGR转RGB
25
 
26
- def detect_video(input_video, conf_threshold):
27
- if current_model is None:
28
- raise gr.Error("请先上传模型文件")
29
-
30
- cap = cv2.VideoCapture(input_video)
31
- fps = cap.get(cv2.CAP_PROP_FPS)
32
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
33
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
34
-
35
- # 创建临时输出文件
36
- temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
37
- out = cv2.VideoWriter(temp_file.name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
38
-
39
- while cap.isOpened():
40
- ret, frame = cap.read()
41
- if not ret:
42
- break
43
-
44
- results = current_model(frame, conf=conf_threshold)
45
- plotted = results[0].plot()
46
- out.write(plotted)
47
-
48
- cap.release()
49
- out.release()
50
- return temp_file.name
51
 
52
- def detect_webcam(camera_input, conf_threshold):
53
- if current_model is None:
54
- raise gr.Error("请先上传模型文件")
55
-
56
- if camera_input is None:
57
- return None
58
-
59
- results = current_model(camera_input, conf=conf_threshold)
60
- plotted = results[0].plot()
61
- return plotted[:, :, ::-1] # BGR转RGB
62
 
63
- with gr.Blocks() as demo:
64
- gr.Markdown("# YOLOv8 自定义模型检测系统")
65
-
66
- with gr.Row():
67
- model_input = gr.File(label="上传模型文件 (.pt)", type="filepath")
68
- model_status = gr.Textbox(label="模型状态", interactive=False)
69
-
70
- model_input.upload(fn=load_model, inputs=model_input, outputs=model_status)
71
-
72
- with gr.Tabs():
73
- with gr.TabItem("图片检测"):
74
- with gr.Row():
75
- img_input = gr.Image(label="输入图片", type="filepath")
76
- img_output = gr.Image(label="检测结果")
77
- img_conf = gr.Slider(0, 1, value=0.5, label="置信度阈值")
78
- img_button = gr.Button("执行检测")
79
- # 正确缩进位置 ↓
80
- img_button.click(fn=detect_image, inputs=[img_input, img_conf], outputs=img_output)
81
-
82
- with gr.TabItem("视频检测"):
83
- with gr.Row():
84
- video_input = gr.Video(label="输入视频")
85
- video_output = gr.Video(label="检测结果")
86
- video_conf = gr.Slider(0, 1, value=0.5, label="置信度阈值")
87
- video_button = gr.Button("执行检测")
88
- # 正确缩进位置 ↓
89
- video_button.click(fn=detect_video, inputs=[video_input, video_conf], outputs=video_output)
90
-
91
- with gr.TabItem("实时摄像头"):
92
- webcam_input = gr.Webcam(label="摄像头画面")
93
- webcam_output = gr.Image(label="检测结果")
94
- webcam_conf = gr.Slider(0, 1, value=0.5, label="置信度阈值")
95
- webcam_button = gr.Button("开始检测")
96
- # 正确缩进位置 ↓
97
- webcam_button.click(fn=detect_webcam, inputs=[webcam_input, webcam_conf], outputs=webcam_output)
98
 
99
- if __name__ == "__main__":
100
- demo.launch()
 
 
 
101
 
102
- webcam_button.click(fn=detect_webcam, inputs=[webcam_input, webcam_conf], outputs=webcam_output)
103
-
104
- # 绑定事件处理
105
- img_button.click(fn=detect_image, inputs=[img_input, img_conf], outputs=img_output)
106
- video_button.click(fn=detect_video, inputs=[video_input, video_conf], outputs=video_output)
107
- webcam_button.click(fn=detect_webcam, inputs=[webcam_input, webcam_conf], outputs=webcam_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import cv2
3
+ import numpy as np
4
+ import os
5
  from ultralytics import YOLO
 
6
 
7
+ # 设置上传和结果文件夹
8
+ UPLOAD_FOLDER = 'uploads'
9
+ RESULT_FOLDER = 'results'
10
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
11
+ os.makedirs(RESULT_FOLDER, exist_ok=True)
12
 
13
+ # 加载模型
14
+ model = YOLO('yolov8n.pt')
 
 
 
 
 
15
 
16
+ def process_image(image):
17
+ # 保存上传的图像
18
+ filename = 'uploaded_image.jpg'
19
+ file_path = os.path.join(UPLOAD_FOLDER, filename)
20
+ cv2.imwrite(file_path, image)
 
 
21
 
22
+ # 处理图像
23
+ results = model(image)
24
+ detection_results = []
25
+ class_counts = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ for result in results:
28
+ boxes = result.boxes
29
+ for box in boxes:
30
+ x1, y1, x2, y2 = box.xyxy[0]
31
+ conf = box.conf[0]
32
+ cls = box.cls[0]
33
+ class_name = model.names[int(cls)]
 
 
 
34
 
35
+ # 计算每种类别的数量
36
+ if class_name in class_counts:
37
+ class_counts[class_name] += 1
38
+ else:
39
+ class_counts[class_name] = 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ # 绘制边界框和标签
42
+ cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
43
+ cv2.putText(image, f'{class_name}:{conf:.2f}', (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9,
44
+ (36, 255, 12), 2)
45
+ detection_results.append(f'Class: {class_name}, Confidence: {conf:.2f}, Box: ({x1}, {y1}), ({x2}, {y2})')
46
 
47
+ # 在图像上显示检测到的物体数量信息
48
+ y_offset = 30
49
+ for class_name, count in class_counts.items():
50
+ cv2.putText(image, f'{class_name}: {count}', (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
51
+ y_offset += 30
52
+
53
+ # 保存处理后的图像
54
+ result_filename = 'result_image.jpg'
55
+ result_path = os.path.join(RESULT_FOLDER, result_filename)
56
+ cv2.imwrite(result_path, image)
57
+
58
+ return image, '\n'.join(detection_results)
59
+
60
+ # 创建Gradio界面
61
+ iface = gr.Interface(
62
+ fn=process_image,
63
+ inputs=gr.Image(type="numpy", label="上传图像"),
64
+ outputs=[gr.Image(type="numpy", label="处理后的图像"), gr.Textbox(label="检测结果")],
65
+ title="YOLOv8 图像检测",
66
+ description="上传图像并使用YOLOv8模型进行检测"
67
+ )
68
+
69
+ # 启动Gradio应用
70
+ iface.launch(
71
+ share=True,
72
+ server_name="0.0.0.0",
73
+ server_port=7860,
74
+ debug=True,
75
+ #auth=("username", "password"),
76
+ auth=("test", "12345"),
77
+ auth_message="Please enter your credentials to access the app.",
78
+ inbrowser=True
79
+ )