CatmanJr committed on
Commit
6aee1db
·
verified ·
1 Parent(s): 4982ed5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -4
app.py CHANGED
@@ -3,6 +3,11 @@ import cv2
3
  import numpy as np
4
  from ultralytics import YOLO
5
 
 
 
 
 
 
6
  # 加载YOLO模型
7
  model = YOLO('yolo11s-earth.pt') # 加载你的模型
8
 
@@ -15,6 +20,81 @@ default_classes = [
15
  'windmill'
16
  ]
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def process_frame(frame, classes_input):
19
  # 将输入的类别字符串转为列表
20
  if classes_input:
@@ -51,11 +131,11 @@ def process_frame(frame, classes_input):
51
  def main():
52
  # 创建Gradio界面
53
  with gr.Blocks() as demo:
54
- gr.Markdown("# YOLOv11s-Earth 实时检测")
55
  with gr.Row():
56
- cam_input = gr.Image(type="numpy", sources=["webcam"], streaming=True, label="摄像头输入")
57
- classes_input = gr.Textbox(label="输入类别(逗号分隔)", placeholder="例如:0,1,2")
58
- output = gr.Image(label="检测结果", type="numpy")
59
 
60
  cam_input.stream(
61
  process_frame,
 
3
  import numpy as np
4
  from ultralytics import YOLO
5
 
6
+ # 加载YOLO模型import gradio as gr
7
+ import cv2
8
+ import numpy as np
9
+ from ultralytics import YOLO
10
+
11
  # 加载YOLO模型
12
  model = YOLO('yolo11s-earth.pt') # 加载你的模型
13
 
 
20
  'windmill'
21
  ]
22
 
23
def process_frame(frame, classes_input):
    """Run open-vocabulary detection on one webcam frame and draw the results.

    Args:
        frame: Image array of shape (H, W, C) as delivered by the Gradio
            ``Image(type="numpy")`` component, or ``None`` when the stream
            fires before a frame is available.
        classes_input: Comma-separated class names typed by the user. Blank
            input (or input made only of separators) selects
            ``default_classes``.

    Returns:
        A copy of ``frame`` with a bounding box and ``name:confidence`` label
        drawn for every detection, or ``None`` when no frame was supplied.
    """
    # The stream callback can be invoked before the webcam has produced an
    # image; there is nothing to annotate in that case.
    if frame is None:
        return None

    # Drop empty entries so input such as "airplane,,ship," does not hand
    # blank prompts to the model; fall back to the defaults if none remain.
    classes_list = [
        name.strip()
        for name in (classes_input or "").split(",")
        if name.strip()
    ]
    model.set_classes(classes_list or default_classes)

    # Work on a writable copy so the caller's array is never modified.
    frame = frame.copy()

    h, w = frame.shape[:2]
    if h == 0 or w == 0:
        return frame

    # Scale the longer side to 1024 px while keeping the aspect ratio; the
    # shorter side is clamped to at least 1 px so cv2.resize never receives 0.
    if w > h:
        new_size = (1024, max(1, int(h * (1024 / w))))
    else:
        new_size = (max(1, int(w * (1024 / h))), 1024)
    resized_frame = cv2.resize(frame, new_size)

    # Gradio supplies RGB while Ultralytics treats NumPy input as BGR, so the
    # channel order is swapped here. COLOR_RGB2BGR and COLOR_BGR2RGB perform
    # the same swap; this name states the actual direction.
    bgr_frame = cv2.cvtColor(resized_frame, cv2.COLOR_RGB2BGR)

    results = model.predict(bgr_frame)

    # Factors that map detector coordinates back onto the full-size frame.
    scale_x = w / new_size[0]
    scale_y = h / new_size[1]

    for result in results:
        for box in result.boxes:
            x1, y1, x2, y2 = box.xyxy[0].tolist()
            conf = float(box.conf[0])
            class_name = model.names[int(box.cls[0])]

            left, top = int(x1 * scale_x), int(y1 * scale_y)
            right, bottom = int(x2 * scale_x), int(y2 * scale_y)

            cv2.rectangle(frame, (left, top), (right, bottom), (0, 255, 0), 2)
            cv2.putText(
                frame,
                f'{class_name}:{conf:.2f}',
                (left, top - 10),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.9,
                (36, 255, 12),
                2,
            )

    return frame
66
+
67
+ def main():
68
+ # Build the Gradio interface: webcam input, class-name textbox and result image
69
+ with gr.Blocks() as demo:
70
+ gr.Markdown("# YOLO11s-Earth open vocabulary detection(DIOR finetuning)")
71
+ with gr.Row():
72
+ cam_input = gr.Image(type="numpy", sources=["webcam"], streaming=True, label="Webcam")
73
+ classes_input = gr.Textbox(label="new classes(逗号分隔)", placeholder="exp:airplane, airport, tennis court")
74
+ output = gr.Image(label="results", type="numpy")
75
+
76
+ cam_input.stream(
77
+ process_frame,
78
+ inputs=[cam_input, classes_input],
79
+ outputs=output
80
+ )
81
+
82
+ # Launch the Gradio app
83
+ demo.launch()
84
+
85
+ if __name__ == "__main__":
86
+ main()
87
+ model = YOLO('yolo11s-earth.pt') # Load your model
88
+
89
+ # Default classes used when the user supplies none
90
+ default_classes = [
91
+ 'airplane', 'airport', 'baseballfield', 'basketballcourt', 'bridge',
92
+ 'chimney', 'dam', 'Expressway-Service-area', 'Expressway-toll-station',
93
+ 'golffield', 'groundtrackfield', 'harbor', 'overpass', 'ship',
94
+ 'stadium', 'storagetank', 'tenniscourt', 'trainstation', 'vehicle',
95
+ 'windmill'
96
+ ]
97
+
98
  def process_frame(frame, classes_input):
99
  # 将输入的类别字符串转为列表
100
  if classes_input:
 
131
  def main():
132
  # 创建Gradio界面
133
  with gr.Blocks() as demo:
134
+ gr.Markdown("# YOLO11s-Earth open vocabulary detection(DIOR finetuning)")
135
  with gr.Row():
136
+ cam_input = gr.Image(type="numpy", sources=["webcam"], streaming=True, label="Webcam")
137
+ classes_input = gr.Textbox(label="new classes(逗号分隔)", placeholder="exp:airplane, airport, tennis court")
138
+ output = gr.Image(label="results", type="numpy", height=800)
139
 
140
  cam_input.stream(
141
  process_frame,