Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -3,6 +3,11 @@ import cv2
|
|
3 |
import numpy as np
|
4 |
from ultralytics import YOLO
|
5 |
|
|
|
|
|
|
|
|
|
|
|
6 |
# 加载YOLO模型
|
7 |
model = YOLO('yolo11s-earth.pt') # 加载你的模型
|
8 |
|
@@ -15,6 +20,81 @@ default_classes = [
|
|
15 |
'windmill'
|
16 |
]
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
def process_frame(frame, classes_input):
|
19 |
# 将输入的类别字符串转为列表
|
20 |
if classes_input:
|
@@ -51,11 +131,11 @@ def process_frame(frame, classes_input):
|
|
51 |
def main():
|
52 |
# 创建Gradio界面
|
53 |
with gr.Blocks() as demo:
|
54 |
-
gr.Markdown("#
|
55 |
with gr.Row():
|
56 |
-
cam_input = gr.Image(type="numpy", sources=["webcam"], streaming=True, label="
|
57 |
-
classes_input = gr.Textbox(label="
|
58 |
-
output = gr.Image(label="
|
59 |
|
60 |
cam_input.stream(
|
61 |
process_frame,
|
|
|
3 |
import numpy as np
|
4 |
from ultralytics import YOLO
|
5 |
|
6 |
+
# 加载YOLO模型import gradio as gr
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
from ultralytics import YOLO
|
10 |
+
|
11 |
# 加载YOLO模型
|
12 |
model = YOLO('yolo11s-earth.pt') # 加载你的模型
|
13 |
|
|
|
20 |
'windmill'
|
21 |
]
|
22 |
|
23 |
+
def process_frame(frame, classes_input):
|
24 |
+
# 处理用户输入的类别
|
25 |
+
if classes_input and classes_input.strip():
|
26 |
+
classes_list = [cls.strip() for cls in classes_input.split(',')]
|
27 |
+
model.set_classes(classes_list) # 设置模型的类别
|
28 |
+
else:
|
29 |
+
# 如果没有输入或输入为空,则使用默认类别
|
30 |
+
model.set_classes(default_classes)
|
31 |
+
|
32 |
+
# 复制帧为可写数组
|
33 |
+
frame = frame.copy()
|
34 |
+
|
35 |
+
# 调整图像大小以加快处理速度(可选)
|
36 |
+
h, w = frame.shape[:2]
|
37 |
+
new_size = (1024, int(h * (1024 / w))) if w > h else (int(w * (1024 / h)), 1024)
|
38 |
+
resized_frame = cv2.resize(frame, new_size)
|
39 |
+
|
40 |
+
# 转换图像格式
|
41 |
+
rgb_frame = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2RGB)
|
42 |
+
|
43 |
+
# 使用模型进行检测
|
44 |
+
results = model.predict(rgb_frame)
|
45 |
+
|
46 |
+
# 绘制检测结果
|
47 |
+
for result in results:
|
48 |
+
boxes = result.boxes
|
49 |
+
for box in boxes:
|
50 |
+
x1, y1, x2, y2 = box.xyxy[0]
|
51 |
+
conf = box.conf[0]
|
52 |
+
cls = box.cls[0]
|
53 |
+
class_name = model.names[int(cls)]
|
54 |
+
|
55 |
+
# 调整坐标到原始图像大小
|
56 |
+
x1 = int(x1 * w / new_size[0])
|
57 |
+
y1 = int(y1 * h / new_size[1])
|
58 |
+
x2 = int(x2 * w / new_size[0])
|
59 |
+
y2 = int(y2 * h / new_size[1])
|
60 |
+
|
61 |
+
# 绘制边界框和标签
|
62 |
+
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
63 |
+
cv2.putText(frame, f'{class_name}:{conf:.2f}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
|
64 |
+
|
65 |
+
return frame
|
66 |
+
|
67 |
+
def main():
|
68 |
+
# 创建Gradio界面
|
69 |
+
with gr.Blocks() as demo:
|
70 |
+
gr.Markdown("# YOLO11s-Earth open vocabulary detection(DIOR finetuning)")
|
71 |
+
with gr.Row():
|
72 |
+
cam_input = gr.Image(type="numpy", sources=["webcam"], streaming=True, label="Webcam")
|
73 |
+
classes_input = gr.Textbox(label="new classes(逗号分隔)", placeholder="exp:airplane, airport, tennis court")
|
74 |
+
output = gr.Image(label="results", type="numpy")
|
75 |
+
|
76 |
+
cam_input.stream(
|
77 |
+
process_frame,
|
78 |
+
inputs=[cam_input, classes_input],
|
79 |
+
outputs=output
|
80 |
+
)
|
81 |
+
|
82 |
+
# 启动Gradio应用
|
83 |
+
demo.launch()
|
84 |
+
|
85 |
+
if __name__ == "__main__":
|
86 |
+
main()
|
87 |
+
model = YOLO('yolo11s-earth.pt') # 加载你的模型
|
88 |
+
|
89 |
+
# 默认类别
|
90 |
+
default_classes = [
|
91 |
+
'airplane', 'airport', 'baseballfield', 'basketballcourt', 'bridge',
|
92 |
+
'chimney', 'dam', 'Expressway-Service-area', 'Expressway-toll-station',
|
93 |
+
'golffield', 'groundtrackfield', 'harbor', 'overpass', 'ship',
|
94 |
+
'stadium', 'storagetank', 'tenniscourt', 'trainstation', 'vehicle',
|
95 |
+
'windmill'
|
96 |
+
]
|
97 |
+
|
98 |
def process_frame(frame, classes_input):
|
99 |
# 将输入的类别字符串转为列表
|
100 |
if classes_input:
|
|
|
131 |
def main():
|
132 |
# 创建Gradio界面
|
133 |
with gr.Blocks() as demo:
|
134 |
+
gr.Markdown("# YOLO11s-Earth open vocabulary detection(DIOR finetuning)")
|
135 |
with gr.Row():
|
136 |
+
cam_input = gr.Image(type="numpy", sources=["webcam"], streaming=True, label="Webcam")
|
137 |
+
classes_input = gr.Textbox(label="new classes(逗号分隔)", placeholder="exp:airplane, airport, tennis court")
|
138 |
+
output = gr.Image(label="results", type="numpy", height=800)
|
139 |
|
140 |
cam_input.stream(
|
141 |
process_frame,
|