CatmanJr committed on
Commit
903ab5f
·
verified ·
1 Parent(s): 6aee1db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -88
app.py CHANGED
@@ -3,15 +3,10 @@ import cv2
3
  import numpy as np
4
  from ultralytics import YOLO
5
 
6
- # 加载YOLO模型import gradio as gr
7
- import cv2
8
- import numpy as np
9
- from ultralytics import YOLO
10
 
11
- # 加载YOLO模型
12
- model = YOLO('yolo11s-earth.pt') # 加载你的模型
13
-
14
- # 默认类别
15
  default_classes = [
16
  'airplane', 'airport', 'baseballfield', 'basketballcourt', 'bridge',
17
  'chimney', 'dam', 'Expressway-Service-area', 'Expressway-toll-station',
@@ -21,29 +16,29 @@ default_classes = [
21
  ]
22
 
23
  def process_frame(frame, classes_input):
24
- # 处理用户输入的类别
25
  if classes_input and classes_input.strip():
26
  classes_list = [cls.strip() for cls in classes_input.split(',')]
27
- model.set_classes(classes_list) # 设置模型的类别
28
  else:
29
- # 如果没有输入或输入为空,则使用默认类别
30
  model.set_classes(default_classes)
31
 
32
- # 复制帧为可写数组
33
  frame = frame.copy()
34
 
35
- # 调整图像大小以加快处理速度(可选)
36
  h, w = frame.shape[:2]
37
- new_size = (1024, int(h * (1024 / w))) if w > h else (int(w * (1024 / h)), 1024)
38
  resized_frame = cv2.resize(frame, new_size)
39
 
40
- # 转换图像格式
41
  rgb_frame = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2RGB)
42
 
43
- # 使用模型进行检测
44
  results = model.predict(rgb_frame)
45
 
46
- # 绘制检测结果
47
  for result in results:
48
  boxes = result.boxes
49
  for box in boxes:
@@ -52,90 +47,26 @@ def process_frame(frame, classes_input):
52
  cls = box.cls[0]
53
  class_name = model.names[int(cls)]
54
 
55
- # 调整坐标到原始图像大小
56
  x1 = int(x1 * w / new_size[0])
57
  y1 = int(y1 * h / new_size[1])
58
  x2 = int(x2 * w / new_size[0])
59
  y2 = int(y2 * h / new_size[1])
60
 
61
- # 绘制边界框和标签
62
  cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
63
  cv2.putText(frame, f'{class_name}:{conf:.2f}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
64
 
65
  return frame
66
 
67
  def main():
68
- # 创建Gradio界面
69
- with gr.Blocks() as demo:
70
- gr.Markdown("# YOLO11s-Earth open vocabulary detection(DIOR finetuning)")
71
- with gr.Row():
72
- cam_input = gr.Image(type="numpy", sources=["webcam"], streaming=True, label="Webcam")
73
- classes_input = gr.Textbox(label="new classes(逗号分隔)", placeholder="exp:airplane, airport, tennis court")
74
- output = gr.Image(label="results", type="numpy")
75
-
76
- cam_input.stream(
77
- process_frame,
78
- inputs=[cam_input, classes_input],
79
- outputs=output
80
- )
81
-
82
- # 启动Gradio应用
83
- demo.launch()
84
-
85
- if __name__ == "__main__":
86
- main()
87
- model = YOLO('yolo11s-earth.pt') # 加载你的模型
88
-
89
- # 默认类别
90
- default_classes = [
91
- 'airplane', 'airport', 'baseballfield', 'basketballcourt', 'bridge',
92
- 'chimney', 'dam', 'Expressway-Service-area', 'Expressway-toll-station',
93
- 'golffield', 'groundtrackfield', 'harbor', 'overpass', 'ship',
94
- 'stadium', 'storagetank', 'tenniscourt', 'trainstation', 'vehicle',
95
- 'windmill'
96
- ]
97
-
98
- def process_frame(frame, classes_input):
99
- # 将输入的类别字符串转为列表
100
- if classes_input:
101
- classes_list = [cls.strip() for cls in classes_input.split(',')]
102
- model.set_classes(classes_list) # 设置模型的类别
103
- else:
104
- # 如果没有输入,则使用默认类别
105
- model.set_classes(default_classes)
106
-
107
- # 复制帧为可写数组
108
- frame = frame.copy()
109
-
110
- # 转换图像格式
111
- rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
112
-
113
- # 使用模型进行检测
114
- results = model.predict(rgb_frame)
115
-
116
- # 绘制检测结果
117
- for result in results:
118
- boxes = result.boxes
119
- for box in boxes:
120
- x1, y1, x2, y2 = box.xyxy[0]
121
- conf = box.conf[0]
122
- cls = box.cls[0]
123
- class_name = model.names[int(cls)]
124
-
125
- # 绘制边界框和标签
126
- cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
127
- cv2.putText(frame, f'{class_name}:{conf:.2f}', (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
128
-
129
- return frame
130
-
131
- def main():
132
- # 创建Gradio界面
133
  with gr.Blocks() as demo:
134
- gr.Markdown("# YOLO11s-Earth open vocabulary detectionDIOR finetuning")
135
  with gr.Row():
136
  cam_input = gr.Image(type="numpy", sources=["webcam"], streaming=True, label="Webcam")
137
- classes_input = gr.Textbox(label="new classes(逗号分隔)", placeholder="exp:airplane, airport, tennis court")
138
- output = gr.Image(label="results", type="numpy", height=800)
139
 
140
  cam_input.stream(
141
  process_frame,
@@ -143,7 +74,7 @@ def main():
143
  outputs=output
144
  )
145
 
146
- # 启动Gradio应用
147
  demo.launch()
148
 
149
  if __name__ == "__main__":
 
3
  import numpy as np
4
  from ultralytics import YOLO
5
 
6
+ # Load YOLO model
7
+ model = YOLO('yolo11s-earth.pt') # Load your model
 
 
8
 
9
+ # Default classes
 
 
 
10
  default_classes = [
11
  'airplane', 'airport', 'baseballfield', 'basketballcourt', 'bridge',
12
  'chimney', 'dam', 'Expressway-Service-area', 'Expressway-toll-station',
 
16
  ]
17
 
18
  def process_frame(frame, classes_input):
19
+ # Process user input classes
20
  if classes_input and classes_input.strip():
21
  classes_list = [cls.strip() for cls in classes_input.split(',')]
22
+ model.set_classes(classes_list) # Set model classes
23
  else:
24
+ # Use default classes if no input or input is empty
25
  model.set_classes(default_classes)
26
 
27
+ # Copy frame to a writable array
28
  frame = frame.copy()
29
 
30
+ # Resize image to speed up processing (optional)
31
  h, w = frame.shape[:2]
32
+ new_size = (640, int(h * (640 / w))) if w > h else (int(w * (640 / h)), 640)
33
  resized_frame = cv2.resize(frame, new_size)
34
 
35
+ # Convert image format
36
  rgb_frame = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2RGB)
37
 
38
+ # Use model for detection
39
  results = model.predict(rgb_frame)
40
 
41
+ # Draw detection results
42
  for result in results:
43
  boxes = result.boxes
44
  for box in boxes:
 
47
  cls = box.cls[0]
48
  class_name = model.names[int(cls)]
49
 
50
+ # Adjust coordinates to original image size
51
  x1 = int(x1 * w / new_size[0])
52
  y1 = int(y1 * h / new_size[1])
53
  x2 = int(x2 * w / new_size[0])
54
  y2 = int(y2 * h / new_size[1])
55
 
56
+ # Draw bounding box and label
57
  cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
58
  cv2.putText(frame, f'{class_name}:{conf:.2f}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
59
 
60
  return frame
61
 
62
  def main():
63
+ # Create Gradio interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  with gr.Blocks() as demo:
65
+ gr.Markdown("# YOLO11s-Earth open vocabulary detection (DIOR finetuning)")
66
  with gr.Row():
67
  cam_input = gr.Image(type="numpy", sources=["webcam"], streaming=True, label="Webcam")
68
+ classes_input = gr.Textbox(label="New classes (comma-separated)", placeholder="e.g.: airplane, airport, tennis court")
69
+ output = gr.Image(label="Results", type="numpy", height=480) # Set height to 480
70
 
71
  cam_input.stream(
72
  process_frame,
 
74
  outputs=output
75
  )
76
 
77
+ # Launch Gradio app
78
  demo.launch()
79
 
80
  if __name__ == "__main__":