assentian1970 committed on
Commit
815427b
·
verified ·
1 Parent(s): b08dda1

Create yolo_detection.py

Browse files
Files changed (1) hide show
  1. yolo_detection.py +223 -0
yolo_detection.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Third-party dependencies: ultralytics (YOLO) for detection, OpenCV for
# image/video I/O and drawing, NumPy for PIL-to-array conversion.
from ultralytics import YOLO
import cv2
import numpy as np
import tempfile
import os

# Initialize YOLO model
# Loaded once at import time; every helper in this module shares this instance.
# NOTE(review): relative path — assumes the process CWD contains
# 'best_yolov11.pt'; confirm against the deployment layout.
YOLO_MODEL = YOLO('./best_yolov11.pt')
10
def detect_people_and_machinery(media_path):
    """Detect people and machinery using YOLOv11 for both images and videos.

    Parameters
    ----------
    media_path : str or PIL.Image.Image
        Path to an image or video file, or an in-memory PIL image.

    Returns
    -------
    tuple[int, int, dict[str, int]]
        (max people seen in any sampled frame,
         total machinery count across detected categories,
         per-category machinery counts with zero entries removed).
        Returns (0, 0, {}) on any error.
    """
    try:
        max_people_count = 0
        # Per-category maxima; categories appear lazily, so there is no need
        # to duplicate the full category list from process_yolo_results here.
        max_machine_types = {}

        if isinstance(media_path, str) and is_video(media_path):
            cap = cv2.VideoCapture(media_path)
            try:
                fps = cap.get(cv2.CAP_PROP_FPS)
                # Sample roughly one frame per second; guard against fps == 0
                # (missing metadata) which would make the modulus invalid.
                sample_rate = max(1, int(fps))
                frame_count = 0

                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break

                    # Process every nth frame based on the sample rate.
                    if frame_count % sample_rate == 0:
                        results = YOLO_MODEL(frame)
                        people, _, machine_types = process_yolo_results(results)

                        # Track the maximum seen in any single sampled frame.
                        max_people_count = max(max_people_count, people)
                        for name, count in machine_types.items():
                            max_machine_types[name] = max(
                                max_machine_types.get(name, 0), count)

                    frame_count += 1
            finally:
                # Release the capture even if inference raises mid-video.
                cap.release()

        else:
            if isinstance(media_path, str):
                img = cv2.imread(media_path)
            else:
                # PIL images are RGB; OpenCV expects BGR ndarrays.
                img = cv2.cvtColor(np.array(media_path), cv2.COLOR_RGB2BGR)

            results = YOLO_MODEL(img)
            max_people_count, _, max_machine_types = process_yolo_results(results)

        # Drop categories that were never observed.
        max_machine_types = {k: v for k, v in max_machine_types.items() if v > 0}
        total_machinery_count = sum(max_machine_types.values())

        return max_people_count, total_machinery_count, max_machine_types

    except Exception as e:
        # Best-effort: any detection failure degrades to "nothing found".
        print(f"Error in YOLO detection: {str(e)}")
        return 0, 0, {}
77
# Maps substrings of YOLO class names to reporting categories. ORDER MATTERS:
# matching is done by substring in insertion order, so specific names
# ('concrete_mixer_truck', 'pump_truck', 'dump_truck') must precede the
# generic 'truck' fallback or they would be misclassified as "Dump Truck".
_MACHINERY_MAPPING = {
    'tower_crane': "Tower Crane",
    'mobile_crane': "Mobile Crane",
    'compactor': "Compactor/Roller",
    'roller': "Compactor/Roller",
    'bulldozer': "Bulldozer",
    'dozer': "Bulldozer",
    'excavator': "Excavator",
    'concrete_mixer_truck': "Concrete Mixer",
    'pump_truck': "Pump Truck",
    'dump_truck': "Dump Truck",
    'truck': "Dump Truck",
    'loader': "Loader",
    'pile_driver': "Pile Driver",
    'grader': "Grader",
    'other_vehicle': "Other Vehicle",
}

def process_yolo_results(results, names=None):
    """Process YOLO detection results and count people and machinery.

    Parameters
    ----------
    results : iterable
        YOLO prediction results; each item exposes ``.boxes`` and each box
        has ``cls`` and ``conf`` sequences.
    names : Mapping[int, str], optional
        Class-id -> class-name mapping. Defaults to the shared model's
        ``YOLO_MODEL.names`` (optional parameter keeps existing callers
        working unchanged).

    Returns
    -------
    tuple[int, int, dict[str, int]]
        (people count, total machinery count, per-category machinery counts
        including zero entries).
    """
    if names is None:
        names = YOLO_MODEL.names

    people_count = 0
    machine_types = {
        "Tower Crane": 0,
        "Mobile Crane": 0,
        "Compactor/Roller": 0,
        "Bulldozer": 0,
        "Excavator": 0,
        "Dump Truck": 0,
        "Concrete Mixer": 0,
        "Loader": 0,
        "Pump Truck": 0,
        "Pile Driver": 0,
        "Grader": 0,
        "Other Vehicle": 0
    }

    for r in results:
        for box in r.boxes:
            cls = int(box.cls[0])
            conf = float(box.conf[0])
            if conf <= 0.5:
                continue  # ignore low-confidence detections

            class_lower = names[cls].lower()

            # People are labeled 'worker' by this model; workers match no
            # machinery key, so it is safe to short-circuit here.
            if class_lower == 'worker':
                people_count += 1
                continue

            # First matching substring wins (mapping is hoisted out of the
            # loop — the original rebuilt it once per detected box).
            for key, category in _MACHINERY_MAPPING.items():
                if key in class_lower:
                    machine_types[category] += 1
                    break

    total_machinery = sum(machine_types.values())
    return people_count, total_machinery, machine_types
137
def annotate_video_with_bboxes(video_path):
    """
    Read the video frame-by-frame, run YOLO, draw bounding boxes, write a
    per-frame summary of detected classes on each frame, and save the result
    as a new annotated video.

    Parameters
    ----------
    video_path : str
        Path to the input video.

    Returns
    -------
    str
        Path of the annotated .mp4 (a temp file the caller owns and should
        delete when done).
    """
    cap = cv2.VideoCapture(video_path)
    # Fall back to 25 fps when the container reports 0 (missing metadata);
    # a zero-fps VideoWriter produces an unplayable file.
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Reserve a temp output path; VideoWriter re-opens it by name.
    out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    annotated_video_path = out_file.name
    out_file.close()

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(annotated_video_path, fourcc, fps, (w, h))

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            results = YOLO_MODEL(frame)

            # Per-frame count of each detected class, for the summary line.
            frame_counts = {}

            for r in results:
                for box in r.boxes:
                    cls_id = int(box.cls[0])
                    conf = float(box.conf[0])
                    if conf < 0.5:
                        continue  # Skip low-confidence

                    x1, y1, x2, y2 = (int(v) for v in box.xyxy[0])
                    class_name = YOLO_MODEL.names[cls_id]

                    # Draw bounding box and label.
                    color = (0, 255, 0)
                    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                    label_text = f"{class_name} {conf:.2f}"
                    cv2.putText(frame, label_text, (x1, y1 - 6),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

                    frame_counts[class_name] = frame_counts.get(class_name, 0) + 1

            # Build a summary line, e.g. "Worker: 2, Excavator: 1, ..."
            summary_str = ", ".join(f"{cls_name}: {count}"
                                    for cls_name, count in frame_counts.items())

            # Put the summary text in the top-left.
            cv2.putText(
                frame,
                summary_str,
                (15, 30),
                cv2.FONT_HERSHEY_SIMPLEX,
                1.0,
                (255, 255, 0),
                2
            )

            writer.write(frame)
    finally:
        # Always release handles — otherwise an inference error mid-video
        # leaves the output unflushed and the capture open.
        cap.release()
        writer.release()

    return annotated_video_path
212
# File type validation
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}

def get_file_extension(filename):
    """Return *filename*'s extension, lowercased (e.g. '.mp4'; '' if none)."""
    _, ext = os.path.splitext(filename)
    return ext.lower()

def is_image(filename):
    """True if *filename* carries a recognised image extension."""
    return get_file_extension(filename) in IMAGE_EXTENSIONS

def is_video(filename):
    """True if *filename* carries a recognised video extension."""
    return get_file_extension(filename) in VIDEO_EXTENSIONS