CatmanJr commited on
Commit
39aad74
·
verified ·
1 Parent(s): 263cc8c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+ from ultralytics import YOLO
5
+
6
+ # 加载YOLO模型
7
+ model = YOLO('yolo11s-earth.pt') # 加载你的模型
8
+
9
+ # 默认类别
10
+ default_classes = [
11
+ 'airplane', 'airport', 'baseballfield', 'basketballcourt', 'bridge',
12
+ 'chimney', 'dam', 'Expressway-Service-area', 'Expressway-toll-station',
13
+ 'golffield', 'groundtrackfield', 'harbor', 'overpass', 'ship',
14
+ 'stadium', 'storagetank', 'tenniscourt', 'trainstation', 'vehicle',
15
+ 'windmill'
16
+ ]
17
+
18
+ def process_frame(frame, classes_input):
19
+ # 将输入的类别字符串转为列表
20
+ if classes_input:
21
+ classes_list = [cls.strip() for cls in classes_input.split(',')]
22
+ model.set_classes(classes_list) # 设置模型的类别
23
+ else:
24
+ # 如果没有输入,则使用默认类别
25
+ model.set_classes(default_classes)
26
+
27
+ # 复制帧为可写数组
28
+ frame = frame.copy()
29
+
30
+ # 转换图像格式
31
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
32
+
33
+ # 使用模型进行检测
34
+ results = model.predict(rgb_frame)
35
+
36
+ # 绘制检测结果
37
+ for result in results:
38
+ boxes = result.boxes
39
+ for box in boxes:
40
+ x1, y1, x2, y2 = box.xyxy[0]
41
+ conf = box.conf[0]
42
+ cls = box.cls[0]
43
+ class_name = model.names[int(cls)]
44
+
45
+ # 绘制边界框和标签
46
+ cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
47
+ cv2.putText(frame, f'{class_name}:{conf:.2f}', (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
48
+
49
+ return frame
50
+
51
+ def main():
52
+ # 创建Gradio界面
53
+ with gr.Blocks() as demo:
54
+ gr.Markdown("# YOLOv11s-Earth 实时检测")
55
+ with gr.Row():
56
+ cam_input = gr.Image(type="numpy", sources=["webcam"], streaming=True, label="摄像头输入")
57
+ classes_input = gr.Textbox(label="输入类别(逗号分隔)", placeholder="例如:0,1,2")
58
+ output = gr.Image(label="检测结果", type="numpy")
59
+
60
+ cam_input.stream(
61
+ process_frame,
62
+ inputs=[cam_input, classes_input],
63
+ outputs=output
64
+ )
65
+
66
+ # 启动Gradio应用
67
+ demo.launch()
68
+
69
+ if __name__ == "__main__":
70
+ main()