File size: 4,173 Bytes
2704b9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
import cv2
import numpy as np
from ultralytics import YOLO
from collections import defaultdict
import tempfile
import os

class PersonCounter:
    """Count tracked people crossing a horizontal line from top to bottom.

    Every frame is run through a YOLO tracker restricted to class id 0
    (``classes=[0]``; "person" for the COCO-pretrained default weights).
    For each tracked box the bottom-centre point ("feet") is compared with
    the position recorded for the same track on its previous sighting.  A
    track is counted once, the first time that point moves from above the
    line to on or below it.

    Attributes:
        model: The loaded ``YOLO`` model used for detection and tracking.
        tracker: Maps a track id to a one-element list holding the last
            feet y-coordinate seen for that track.
        crossed_ids: Track ids that have already been counted.
        line_position: Vertical position of the counting line as a fraction
            of the frame height.
        count: Number of crossings counted so far.
    """

    def __init__(self, line_position=0.5, model_path="yolov8n.pt"):
        """Load the model and reset all counting state.

        Args:
            line_position: Fraction of the frame height at which the
                counting line is placed.
            model_path: Weights file handed to ``YOLO``.  Defaults to the
                weights the class previously hard-coded.
        """
        self.model = YOLO(model_path)
        self.tracker = defaultdict(list)
        self.crossed_ids = set()
        self.line_position = line_position
        self.count = 0

    def _register_position(self, track_id, feet_y, line_y):
        """Record a track's latest feet position and count a new crossing.

        Args:
            track_id: Identifier of the track.
            feet_y: Current y-coordinate of the bottom of the track's box.
            line_y: y-coordinate of the counting line, in pixels.

        Returns:
            True if this update was counted as a crossing, otherwise False.
        """
        # ``.get`` avoids creating an empty entry in the defaultdict.
        history = self.tracker.get(track_id)
        crossed = bool(
            history
            and history[-1] < line_y <= feet_y
            and track_id not in self.crossed_ids
        )
        if crossed:
            self.crossed_ids.add(track_id)
            self.count += 1

        # Only the most recent position is needed for the next comparison.
        self.tracker[track_id] = [feet_y]
        return crossed

    def process_frame(self, frame):
        """Track people in ``frame``, update the count and draw the overlay.

        Args:
            frame: Image array whose first two dimensions are height and
                width.  It is annotated in place.

        Returns:
            The same ``frame`` object with the counting line, bounding
            boxes, tracking points and the running count drawn on it.
        """
        height, width = frame.shape[:2]
        line_y = int(height * self.line_position)

        # Run detection and tracking on the untouched frame.  Drawing the
        # overlay first would feed the counting line into the detector.
        results = self.model.track(frame, persist=True, classes=[0])

        # Draw counting line
        cv2.line(frame, (0, line_y), (width, line_y), (0, 255, 0), 2)

        detections = results[0].boxes
        # ``id`` is None when the tracker assigned no track ids on this frame.
        if detections.id is not None:
            boxes = detections.xyxy.cpu().numpy()
            track_ids = detections.id.cpu().numpy().astype(int)

            for box, track_id in zip(boxes, track_ids):
                # Draw bounding box
                cv2.rectangle(
                    frame,
                    (int(box[0]), int(box[1])),
                    (int(box[2]), int(box[3])),
                    (255, 0, 0),
                    2,
                )

                # Feet position: horizontal centre of the box, bottom edge.
                center_x = int((box[0] + box[2]) / 2)
                feet_y = box[3]

                # Draw tracking point
                cv2.circle(frame, (center_x, int(feet_y)), 5, (0, 255, 255), -1)

                if self._register_position(int(track_id), float(feet_y), line_y):
                    # Mark on the line where the crossing was counted.
                    cv2.circle(frame, (center_x, line_y), 8, (0, 0, 255), -1)

        # Draw count with a filled background so it stays readable.
        count_text = f"Count: {self.count}"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 1.5
        thickness = 3
        (text_width, text_height), _ = cv2.getTextSize(
            count_text, font, font_scale, thickness
        )

        cv2.rectangle(
            frame, (10, 10), (20 + text_width, 20 + text_height), (0, 0, 0), -1
        )
        cv2.putText(
            frame,
            count_text,
            (15, 15 + text_height),
            font,
            font_scale,
            (0, 255, 0),
            thickness,
        )

        return frame

def process_video(video_path, progress=gr.Progress()):
    """Annotate a video with person tracking and count line crossings.

    Every frame of the input is passed through a fresh ``PersonCounter``
    (counting line at half the frame height) and written to ``result.mp4``
    inside a newly created temporary directory.

    Args:
        video_path: Path of the video file to process.
        progress: Gradio progress tracker used to report per-frame progress.

    Returns:
        A tuple ``(output_path, summary)`` with the path of the annotated
        video and a human-readable final count.

    Raises:
        ValueError: If the input video cannot be opened.
        RuntimeError: If the output video writer cannot be opened.
    """
    # Local import: only needed for the unknown-frame-count fallback below.
    import itertools

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError("Could not open video file")

    writer = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Keep fractional rates such as 29.97 so the output keeps its
        # duration.  ``not fps > 0`` also catches 0, negative values and NaN,
        # which some inputs report; 30.0 is an arbitrary fallback.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0

        # The frame count is container metadata and may be missing (<= 0).
        reported_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        total_frames = int(reported_frames) if reported_frames > 0 else 0

        # Create the temp directory only once the input is known to be readable.
        temp_dir = tempfile.mkdtemp()
        output_path = os.path.join(temp_dir, "result.mp4")

        writer = cv2.VideoWriter(
            output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)
        )
        if not writer.isOpened():
            raise RuntimeError("Could not open video writer")

        counter = PersonCounter(line_position=0.5)

        if total_frames > 0:
            frame_indices = progress.tqdm(range(total_frames))
        else:
            # Unknown length: read until the capture runs out of frames.
            frame_indices = itertools.count()

        for _ in frame_indices:
            ret, frame = cap.read()
            if not ret:
                break
            writer.write(counter.process_frame(frame))
    finally:
        # Release the handles even if processing fails part-way through.
        cap.release()
        if writer is not None:
            writer.release()

    return output_path, f"Final count: {counter.count} people entered"

# Assemble the Gradio UI: one uploaded video in, the annotated video and a
# text summary out.
_video_input = gr.Video(label="Upload a video file")
_result_outputs = [
    gr.Video(label="Processed Video"),
    gr.Textbox(label="Results"),
]
_app_description = (
    "Upload a video to count the number of people entering through a line. "
    "The green line represents the counting threshold, blue boxes show "
    "detected people, and the counter increases when someone crosses the "
    "line from top to bottom."
)

demo = gr.Interface(
    fn=process_video,
    inputs=_video_input,
    outputs=_result_outputs,
    title="Store Entry People Counter",
    description=_app_description,
    examples=[],
    cache_examples=False,
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()