CatmanJr committed on
Commit
78efd15
·
verified ·
1 Parent(s): 5a6ed3e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -19
app.py CHANGED
@@ -3,10 +3,10 @@ import cv2
3
  import numpy as np
4
  from ultralytics import YOLO
5
 
6
- # 加载YOLO模型
7
- model = YOLO('yolo11s-earth.pt') # 加载你的模型
8
 
9
- # 默认类别
10
  default_classes = [
11
  'airplane', 'airport', 'baseballfield', 'basketballcourt', 'bridge',
12
  'chimney', 'dam', 'Expressway-Service-area', 'Expressway-toll-station',
@@ -16,41 +16,62 @@ default_classes = [
16
  ]
17
 
18
  def process_frame(frame, classes_input):
19
- # 将输入的类别字符串转为列表
20
- if classes_input:
21
  classes_list = [cls.strip() for cls in classes_input.split(',')]
22
- model.set_classes(classes_list) # 设置模型的类别
 
 
 
 
 
23
  else:
24
- # 如果没有输入,则使用默认类别
25
  model.set_classes(default_classes)
26
 
27
- # 复制帧为可写数组
28
  frame = frame.copy()
29
 
30
- # 转换图像格式
31
- rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
 
 
32
 
33
- # 使用模型进行检测
34
- results = model.predict(rgb_frame, imgsz=800)
35
 
36
- # 绘制检测结果
 
 
 
37
  for result in results:
38
  boxes = result.boxes
39
  for box in boxes:
40
  x1, y1, x2, y2 = box.xyxy[0]
41
  conf = box.conf[0]
42
  cls = box.cls[0]
43
- class_name = model.names[int(cls)]
 
 
 
 
 
 
 
 
 
 
44
 
45
- # 绘制边界框和标签
46
- cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
47
- cv2.putText(frame, f'{class_name}:{conf:.2f}', (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
48
 
49
  return frame
50
 
51
  def main():
52
- # 创建Gradio界面
53
  with gr.Blocks() as demo:
 
54
  with gr.Row():
55
  cam_input = gr.Image(type="numpy", sources=["webcam"], streaming=True, label="Webcam")
56
  classes_input = gr.Textbox(label="New classes (comma-separated)", placeholder="e.g.: airplane, airport, tennis court")
@@ -62,7 +83,7 @@ def main():
62
  outputs=output
63
  )
64
 
65
- # 启动Gradio应用
66
  demo.launch()
67
 
68
  if __name__ == "__main__":
 
3
  import numpy as np
4
  from ultralytics import YOLO
5
 
6
+ # Load YOLO model
7
+ model = YOLO('yolo11s-earth.pt') # Load your model
8
 
9
+ # Default classes
10
  default_classes = [
11
  'airplane', 'airport', 'baseballfield', 'basketballcourt', 'bridge',
12
  'chimney', 'dam', 'Expressway-Service-area', 'Expressway-toll-station',
 
16
  ]
17
 
18
  def process_frame(frame, classes_input):
19
+ # Process user input classes
20
+ if classes_input and classes_input.strip():
21
  classes_list = [cls.strip() for cls in classes_input.split(',')]
22
+ # Validate classes_list
23
+ for cls in classes_list:
24
+ if not isinstance(cls, str):
25
+ print("Invalid class name:", cls)
26
+ continue
27
+ model.set_classes(classes_list) # Set model classes
28
  else:
29
+ # Use default classes if no input or input is empty
30
  model.set_classes(default_classes)
31
 
32
+ # Copy frame to a writable array
33
  frame = frame.copy()
34
 
35
+ # Resize image to speed up processing (optional)
36
+ h, w = frame.shape[:2]
37
+ new_size = (640, int(h * (640 / w))) if w > h else (int(w * (640 / h)), 640)
38
+ resized_frame = cv2.resize(frame, new_size)
39
 
40
+ # Convert image format
41
+ rgb_frame = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2RGB)
42
 
43
+ # Use model for detection
44
+ results = model.predict(rgb_frame)
45
+
46
+ # Draw detection results
47
  for result in results:
48
  boxes = result.boxes
49
  for box in boxes:
50
  x1, y1, x2, y2 = box.xyxy[0]
51
  conf = box.conf[0]
52
  cls = box.cls[0]
53
+ try:
54
+ class_name = model.names[int(cls)]
55
+ except (IndexError, TypeError) as e:
56
+ print(f"Error accessing model.names: {e}")
57
+ class_name = "Unknown" # Provide a default value
58
+
59
+ # Adjust coordinates to original image size
60
+ x1 = int(x1 * w / new_size[0])
61
+ y1 = int(y1 * h / new_size[1])
62
+ x2 = int(x2 * w / new_size[0])
63
+ y2 = int(y2 * h / new_size[1])
64
 
65
+ # Draw bounding box and label
66
+ cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
67
+ cv2.putText(frame, f'{class_name}:{conf:.2f}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
68
 
69
  return frame
70
 
71
  def main():
72
+ # Create Gradio interface
73
  with gr.Blocks() as demo:
74
+ gr.Markdown("# YOLO11s-Earth open vocabulary detection (DIOR finetuning)")
75
  with gr.Row():
76
  cam_input = gr.Image(type="numpy", sources=["webcam"], streaming=True, label="Webcam")
77
  classes_input = gr.Textbox(label="New classes (comma-separated)", placeholder="e.g.: airplane, airport, tennis court")
 
83
  outputs=output
84
  )
85
 
86
+ # Launch Gradio app
87
  demo.launch()
88
 
89
  if __name__ == "__main__":