asvs's picture
somewhat working commit of people counter
d1424b3
raw
history blame
5.17 kB
import cv2
import numpy as np
from ultralytics import YOLO
from collections import defaultdict
import argparse
class PersonCounter:
    """Count people whose feet cross a horizontal virtual line moving downward.

    Detection and tracking are delegated to an Ultralytics YOLO model via
    ``model.track(..., persist=True)``. Each track ID is counted at most once,
    the first time its bottom-center point moves from above the line to on or
    below it.
    """

    def __init__(self, line_position=0.5, model_path="yolov8n.pt"):
        """Initialize person counter.

        Args:
            line_position (float): Virtual line position as a fraction of frame
                height, in the closed range [0, 1] (0 = top edge, 1 = bottom).
            model_path (str): Weights file / model name passed to ``YOLO``.
                Defaults to the pretrained ``yolov8n.pt``.

        Raises:
            ValueError: If ``line_position`` is outside [0, 1] (or is NaN).
        """
        # Validate before loading the model so a bad value fails fast instead of
        # silently drawing the line off-frame and never counting anyone.
        if not 0.0 <= line_position <= 1.0:
            raise ValueError(
                f"line_position must be between 0 and 1, got {line_position!r}"
            )
        self.model = YOLO(model_path)
        # track_id -> [last feet y]; only the most recent position is stored.
        self.tracker = defaultdict(list)
        self.crossed_ids = set()  # IDs that have already been counted
        self.line_position = line_position
        self.count = 0

    def _calculate_center(self, box):
        """Return the (x, y) center of an ``(x1, y1, x2, y2)`` box."""
        x1, y1, x2, y2 = box
        return (x1 + x2) / 2, (y1 + y2) / 2

    @staticmethod
    def _feet_point(box):
        """Return the bottom-center ``(x, y)`` of an ``(x1, y1, x2, y2)`` box.

        The bottom edge of a person box approximates the feet position, which
        is the point compared against the counting line.
        """
        x1, _, x2, y2 = box
        return (x1 + x2) / 2, y2

    @staticmethod
    def _crossed_downward(prev_y, curr_y, line_y):
        """Return True if a point moved from above ``line_y`` to on/below it."""
        return prev_y < line_y <= curr_y

    def _update_track(self, frame, box, track_id, line_y):
        """Draw one tracked detection and count it if it just crossed the line.

        Args:
            frame: Image to annotate in place.
            box: ``(x1, y1, x2, y2)`` pixel coordinates of the detection.
            track_id (int): Tracker-assigned ID of the detection.
            line_y (int): Pixel row of the counting line.
        """
        x1, y1, x2, y2 = (int(v) for v in box)
        cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)

        feet_x, feet_y = self._feet_point(box)
        cv2.circle(frame, (int(feet_x), int(feet_y)), 5, (0, 255, 255), -1)

        # A crossing needs a previous position, so a track's first frame never counts.
        history = self.tracker.get(track_id)
        if (
            history
            and track_id not in self.crossed_ids
            and self._crossed_downward(history[-1], feet_y, line_y)
        ):
            self.crossed_ids.add(track_id)
            self.count += 1
            # Mark where the crossing happened on the line.
            cv2.circle(frame, (int(feet_x), line_y), 8, (0, 0, 255), -1)

        self.tracker[track_id] = [feet_y]  # Only store current position

    def _draw_count(self, frame):
        """Draw the running count in the top-left corner over a black box."""
        count_text = f"Count: {self.count}"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 1.5
        thickness = 3
        (text_width, text_height), _ = cv2.getTextSize(
            count_text, font, font_scale, thickness
        )
        cv2.rectangle(frame, (10, 10), (20 + text_width, 20 + text_height),
                      (0, 0, 0), -1)
        cv2.putText(frame, count_text, (15, 15 + text_height),
                    font, font_scale, (0, 255, 0), thickness)

    def process_frame(self, frame):
        """Process a single frame and update count.

        Args:
            frame: Input frame from video (annotated in place).

        Returns:
            tuple: ``(frame, count)`` where ``frame`` is the annotated frame
            and ``count`` is the number of people counted so far.
        """
        height, width = frame.shape[:2]
        line_y = int(height * self.line_position)
        cv2.line(frame, (0, line_y), (width, line_y), (0, 255, 0), 2)

        # classes=[0] restricts detection to class 0 (person); persist=True
        # keeps tracker state between calls so IDs stay stable across frames.
        results = self.model.track(frame, persist=True, classes=[0])
        detections = results[0].boxes
        if detections is not None and detections.id is not None:
            boxes = detections.xyxy.cpu().numpy()
            track_ids = detections.id.cpu().numpy().astype(int)
            for box, track_id in zip(boxes, track_ids):
                self._update_track(frame, box, int(track_id), line_y)

        self._draw_count(frame)
        return frame, self.count
def main():
    """Run the people counter on a video file from the command line.

    Reads frames from ``video_path``, annotates each one with
    ``PersonCounter.process_frame``, writes the result to ``--output`` and,
    unless ``--no-display`` is given, shows it in a window (press 'q' to stop
    early). Prints the final count when processing finishes.
    """
    import sys

    parser = argparse.ArgumentParser(description='Count people entering through a line in video.')
    parser.add_argument('video_path', help='Path to input video file')
    parser.add_argument('--line-position', type=float, default=0.5,
                        help='Position of counting line (0-1, fraction of frame height)')
    parser.add_argument('--output', default='result.mp4', help='Path to output video file (default: result.mp4)')
    parser.add_argument('--no-display', action='store_true',
                        help='Do not open a preview window (for headless machines)')
    args = parser.parse_args()
    if not 0.0 <= args.line_position <= 1.0:
        parser.error('--line-position must be between 0 and 1')

    # Initialize video capture
    cap = cv2.VideoCapture(args.video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video at {args.video_path}", file=sys.stderr)
        return

    writer = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fractional rates such as 29.97: truncating to int makes the
        # output play at the wrong speed. Some sources report 0 (or NaN).
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0

        counter = PersonCounter(line_position=args.line_position)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(args.output, fourcc, fps, (width, height))
        if not writer.isOpened():
            print(f"Error: Could not open output video at {args.output}", file=sys.stderr)
            return

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            processed_frame, _ = counter.process_frame(frame)
            writer.write(processed_frame)
            if not args.no_display:
                cv2.imshow('Frame', processed_frame)
                # Break on 'q' press
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

        print(f"Final count: {counter.count}")
    finally:
        # Always release handles so the output file is finalized even on error.
        cap.release()
        if writer is not None:
            writer.release()
        if not args.no_display:
            cv2.destroyAllWindows()
if __name__ == '__main__':
    # Only launch the CLI when run as a script, not when imported.
    main()