CatmanJr committed on
Commit
b059d1f
·
verified ·
1 Parent(s): 903ab5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -33
app.py CHANGED
@@ -3,10 +3,10 @@ import cv2
3
  import numpy as np
4
  from ultralytics import YOLO
5
 
6
- # Load YOLO model
7
- model = YOLO('yolo11s-earth.pt') # Load your model
8
 
9
- # Default classes
10
  default_classes = [
11
  'airplane', 'airport', 'baseballfield', 'basketballcourt', 'bridge',
12
  'chimney', 'dam', 'Expressway-Service-area', 'Expressway-toll-station',
@@ -16,29 +16,24 @@ default_classes = [
16
  ]
17
 
18
  def process_frame(frame, classes_input):
19
- # Process user input classes
20
- if classes_input and classes_input.strip():
21
  classes_list = [cls.strip() for cls in classes_input.split(',')]
22
- model.set_classes(classes_list) # Set model classes
23
  else:
24
- # Use default classes if no input or input is empty
25
  model.set_classes(default_classes)
26
 
27
- # Copy frame to a writable array
28
  frame = frame.copy()
29
 
30
- # Resize image to speed up processing (optional)
31
- h, w = frame.shape[:2]
32
- new_size = (640, int(h * (640 / w))) if w > h else (int(w * (640 / h)), 640)
33
- resized_frame = cv2.resize(frame, new_size)
34
 
35
- # Convert image format
36
- rgb_frame = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2RGB)
37
 
38
- # Use model for detection
39
- results = model.predict(rgb_frame)
40
-
41
- # Draw detection results
42
  for result in results:
43
  boxes = result.boxes
44
  for box in boxes:
@@ -47,26 +42,20 @@ def process_frame(frame, classes_input):
47
  cls = box.cls[0]
48
  class_name = model.names[int(cls)]
49
 
50
- # Adjust coordinates to original image size
51
- x1 = int(x1 * w / new_size[0])
52
- y1 = int(y1 * h / new_size[1])
53
- x2 = int(x2 * w / new_size[0])
54
- y2 = int(y2 * h / new_size[1])
55
-
56
- # Draw bounding box and label
57
- cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
58
- cv2.putText(frame, f'{class_name}:{conf:.2f}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
59
 
60
  return frame
61
 
62
  def main():
63
- # Create Gradio interface
64
  with gr.Blocks() as demo:
65
- gr.Markdown("# YOLO11s-Earth open vocabulary detection (DIOR finetuning)")
66
  with gr.Row():
67
- cam_input = gr.Image(type="numpy", sources=["webcam"], streaming=True, label="Webcam")
68
- classes_input = gr.Textbox(label="New classes (comma-separated)", placeholder="e.g.: airplane, airport, tennis court")
69
- output = gr.Image(label="Results", type="numpy", height=480) # Set height to 480
70
 
71
  cam_input.stream(
72
  process_frame,
@@ -74,7 +63,7 @@ def main():
74
  outputs=output
75
  )
76
 
77
- # Launch Gradio app
78
  demo.launch()
79
 
80
  if __name__ == "__main__":
 
3
  import numpy as np
4
  from ultralytics import YOLO
5
 
6
+ # 加载YOLO模型
7
+ model = YOLO('yolo11s-earth.pt') # 加载你的模型
8
 
9
+ # 默认类别
10
  default_classes = [
11
  'airplane', 'airport', 'baseballfield', 'basketballcourt', 'bridge',
12
  'chimney', 'dam', 'Expressway-Service-area', 'Expressway-toll-station',
 
16
  ]
17
 
18
  def process_frame(frame, classes_input):
19
+ # 将输入的类别字符串转为列表
20
+ if classes_input:
21
  classes_list = [cls.strip() for cls in classes_input.split(',')]
22
+ model.set_classes(classes_list) # 设置模型的类别
23
  else:
24
+ # 如果没有输入,则使用默认类别
25
  model.set_classes(default_classes)
26
 
27
+ # 复制帧为可写数组
28
  frame = frame.copy()
29
 
30
+ # 转换图像格式
31
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
 
 
32
 
33
+ # 使用模型进行检测
34
+ results = model.predict(rgb_frame, imgsz=800)
35
 
36
+ # 绘制检测结果
 
 
 
37
  for result in results:
38
  boxes = result.boxes
39
  for box in boxes:
 
42
  cls = box.cls[0]
43
  class_name = model.names[int(cls)]
44
 
45
+ # 绘制边界框和标签
46
+ cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
47
+ cv2.putText(frame, f'{class_name}:{conf:.2f}', (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
 
 
 
 
 
 
48
 
49
  return frame
50
 
51
  def main():
52
+ # 创建Gradio界面
53
  with gr.Blocks() as demo:
54
+ gr.Markdown("# YOLOv11s-Earth 实时检测")
55
  with gr.Row():
56
+ cam_input = gr.Image(type="numpy", sources=["webcam"], streaming=True, label="摄像头输入")
57
+ classes_input = gr.Textbox(label="输入类别(逗号分隔)", placeholder="例如:0,1,2")
58
+ output = gr.Image(label="检测结果", type="numpy")
59
 
60
  cam_input.stream(
61
  process_frame,
 
63
  outputs=output
64
  )
65
 
66
+ # 启动Gradio应用
67
  demo.launch()
68
 
69
  if __name__ == "__main__":