Spaces:
Runtime error
Runtime error
Create yolo_detection.py
Browse files- yolo_detection.py +223 -0
yolo_detection.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ultralytics import YOLO
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import tempfile
|
5 |
+
import os
|
6 |
+
|
7 |
+
# Initialize YOLO model
|
8 |
+
YOLO_MODEL = YOLO('./best_yolov11.pt')
|
9 |
+
|
10 |
+
def detect_people_and_machinery(media_path):
|
11 |
+
"""Detect people and machinery using YOLOv11 for both images and videos"""
|
12 |
+
try:
|
13 |
+
# Initialize counters with maximum values
|
14 |
+
max_people_count = 0
|
15 |
+
max_machine_types = {
|
16 |
+
"Tower Crane": 0,
|
17 |
+
"Mobile Crane": 0,
|
18 |
+
"Compactor/Roller": 0,
|
19 |
+
"Bulldozer": 0,
|
20 |
+
"Excavator": 0,
|
21 |
+
"Dump Truck": 0,
|
22 |
+
"Concrete Mixer": 0,
|
23 |
+
"Loader": 0,
|
24 |
+
"Pump Truck": 0,
|
25 |
+
"Pile Driver": 0,
|
26 |
+
"Grader": 0,
|
27 |
+
"Other Vehicle": 0
|
28 |
+
}
|
29 |
+
|
30 |
+
# Check if input is video
|
31 |
+
if isinstance(media_path, str) and is_video(media_path):
|
32 |
+
cap = cv2.VideoCapture(media_path)
|
33 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
34 |
+
sample_rate = max(1, int(fps)) # Sample 1 frame per second
|
35 |
+
frame_count = 0 # Initialize frame counter
|
36 |
+
|
37 |
+
while cap.isOpened():
|
38 |
+
ret, frame = cap.read()
|
39 |
+
if not ret:
|
40 |
+
break
|
41 |
+
|
42 |
+
# Process every nth frame based on sample rate
|
43 |
+
if frame_count % sample_rate == 0:
|
44 |
+
results = YOLO_MODEL(frame)
|
45 |
+
people, _, machine_types = process_yolo_results(results)
|
46 |
+
|
47 |
+
# Update maximum counts
|
48 |
+
max_people_count = max(max_people_count, people)
|
49 |
+
for k, v in machine_types.items():
|
50 |
+
max_machine_types[k] = max(max_machine_types[k], v)
|
51 |
+
|
52 |
+
frame_count += 1
|
53 |
+
|
54 |
+
cap.release()
|
55 |
+
|
56 |
+
else:
|
57 |
+
# Handle single image
|
58 |
+
if isinstance(media_path, str):
|
59 |
+
img = cv2.imread(media_path)
|
60 |
+
else:
|
61 |
+
# Handle PIL Image
|
62 |
+
img = cv2.cvtColor(np.array(media_path), cv2.COLOR_RGB2BGR)
|
63 |
+
|
64 |
+
results = YOLO_MODEL(img)
|
65 |
+
max_people_count, _, max_machine_types = process_yolo_results(results)
|
66 |
+
|
67 |
+
# Filter out machinery types with zero count
|
68 |
+
max_machine_types = {k: v for k, v in max_machine_types.items() if v > 0}
|
69 |
+
total_machinery_count = sum(max_machine_types.values())
|
70 |
+
|
71 |
+
return max_people_count, total_machinery_count, max_machine_types
|
72 |
+
|
73 |
+
except Exception as e:
|
74 |
+
print(f"Error in YOLO detection: {str(e)}")
|
75 |
+
return 0, 0, {}
|
76 |
+
|
77 |
+
def process_yolo_results(results):
|
78 |
+
"""Process YOLO detection results and count people and machinery"""
|
79 |
+
people_count = 0
|
80 |
+
machine_types = {
|
81 |
+
"Tower Crane": 0,
|
82 |
+
"Mobile Crane": 0,
|
83 |
+
"Compactor/Roller": 0,
|
84 |
+
"Bulldozer": 0,
|
85 |
+
"Excavator": 0,
|
86 |
+
"Dump Truck": 0,
|
87 |
+
"Concrete Mixer": 0,
|
88 |
+
"Loader": 0,
|
89 |
+
"Pump Truck": 0,
|
90 |
+
"Pile Driver": 0,
|
91 |
+
"Grader": 0,
|
92 |
+
"Other Vehicle": 0
|
93 |
+
}
|
94 |
+
|
95 |
+
# Process detection results
|
96 |
+
for r in results:
|
97 |
+
boxes = r.boxes
|
98 |
+
for box in boxes:
|
99 |
+
cls = int(box.cls[0])
|
100 |
+
conf = float(box.conf[0])
|
101 |
+
class_name = YOLO_MODEL.names[cls]
|
102 |
+
|
103 |
+
# Count people (Worker class)
|
104 |
+
if class_name.lower() == 'worker' and conf > 0.5:
|
105 |
+
people_count += 1
|
106 |
+
|
107 |
+
# Map YOLO classes to machinery types
|
108 |
+
machinery_mapping = {
|
109 |
+
'tower_crane': "Tower Crane",
|
110 |
+
'mobile_crane': "Mobile Crane",
|
111 |
+
'compactor': "Compactor/Roller",
|
112 |
+
'roller': "Compactor/Roller",
|
113 |
+
'bulldozer': "Bulldozer",
|
114 |
+
'dozer': "Bulldozer",
|
115 |
+
'excavator': "Excavator",
|
116 |
+
'dump_truck': "Dump Truck",
|
117 |
+
'truck': "Dump Truck",
|
118 |
+
'concrete_mixer_truck': "Concrete Mixer",
|
119 |
+
'loader': "Loader",
|
120 |
+
'pump_truck': "Pump Truck",
|
121 |
+
'pile_driver': "Pile Driver",
|
122 |
+
'grader': "Grader",
|
123 |
+
'other_vehicle': "Other Vehicle"
|
124 |
+
}
|
125 |
+
|
126 |
+
# Count machinery
|
127 |
+
if conf > 0.5:
|
128 |
+
class_lower = class_name.lower()
|
129 |
+
for key, value in machinery_mapping.items():
|
130 |
+
if key in class_lower:
|
131 |
+
machine_types[value] += 1
|
132 |
+
break
|
133 |
+
|
134 |
+
total_machinery = sum(machine_types.values())
|
135 |
+
return people_count, total_machinery, machine_types
|
136 |
+
|
137 |
+
def annotate_video_with_bboxes(video_path):
|
138 |
+
"""
|
139 |
+
Reads the entire video frame-by-frame, runs YOLO, draws bounding boxes,
|
140 |
+
writes a per-frame summary of detected classes on the frame, and saves
|
141 |
+
as a new annotated video. Returns: annotated_video_path
|
142 |
+
"""
|
143 |
+
cap = cv2.VideoCapture(video_path)
|
144 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
145 |
+
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
146 |
+
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
147 |
+
|
148 |
+
# Create a temp file for output
|
149 |
+
out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
150 |
+
annotated_video_path = out_file.name
|
151 |
+
out_file.close()
|
152 |
+
|
153 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
154 |
+
writer = cv2.VideoWriter(annotated_video_path, fourcc, fps, (w, h))
|
155 |
+
|
156 |
+
while True:
|
157 |
+
ret, frame = cap.read()
|
158 |
+
if not ret:
|
159 |
+
break
|
160 |
+
|
161 |
+
results = YOLO_MODEL(frame)
|
162 |
+
|
163 |
+
# Dictionary to hold per-frame counts of each class
|
164 |
+
frame_counts = {}
|
165 |
+
|
166 |
+
for r in results:
|
167 |
+
boxes = r.boxes
|
168 |
+
for box in boxes:
|
169 |
+
cls_id = int(box.cls[0])
|
170 |
+
conf = float(box.conf[0])
|
171 |
+
if conf < 0.5:
|
172 |
+
continue # Skip low-confidence
|
173 |
+
|
174 |
+
x1, y1, x2, y2 = box.xyxy[0]
|
175 |
+
class_name = YOLO_MODEL.names[cls_id]
|
176 |
+
|
177 |
+
# Convert to int
|
178 |
+
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
179 |
+
|
180 |
+
# Draw bounding box
|
181 |
+
color = (0, 255, 0)
|
182 |
+
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
|
183 |
+
|
184 |
+
label_text = f"{class_name} {conf:.2f}"
|
185 |
+
cv2.putText(frame, label_text, (x1, y1 - 6),
|
186 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 1)
|
187 |
+
|
188 |
+
# Increment per-frame class count
|
189 |
+
frame_counts[class_name] = frame_counts.get(class_name, 0) + 1
|
190 |
+
|
191 |
+
# Build a summary line, e.g. "Worker: 2, Excavator: 1, ..."
|
192 |
+
summary_str = ", ".join(f"{cls_name}: {count}"
|
193 |
+
for cls_name, count in frame_counts.items())
|
194 |
+
|
195 |
+
# Put the summary text in the top-left
|
196 |
+
cv2.putText(
|
197 |
+
frame,
|
198 |
+
summary_str,
|
199 |
+
(15, 30), # position
|
200 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
201 |
+
1.0,
|
202 |
+
(255, 255, 0),
|
203 |
+
2
|
204 |
+
)
|
205 |
+
|
206 |
+
writer.write(frame)
|
207 |
+
|
208 |
+
cap.release()
|
209 |
+
writer.release()
|
210 |
+
return annotated_video_path
|
211 |
+
|
212 |
+
# File type validation
|
213 |
+
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
|
214 |
+
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}
|
215 |
+
|
216 |
+
def get_file_extension(filename):
|
217 |
+
return os.path.splitext(filename)[1].lower()
|
218 |
+
|
219 |
+
def is_image(filename):
|
220 |
+
return get_file_extension(filename) in IMAGE_EXTENSIONS
|
221 |
+
|
222 |
+
def is_video(filename):
|
223 |
+
return get_file_extension(filename) in VIDEO_EXTENSIONS
|