File size: 3,620 Bytes
6cb50ff
a03d512
 
04f4d0b
fa66a0f
a03d512
a668c53
fa66a0f
04f4d0b
5099bc6
04f4d0b
6cb50ff
fa66a0f
 
 
 
a03d512
 
33ab799
 
6cb50ff
a03d512
 
 
 
6cb50ff
04f4d0b
 
 
6cb50ff
fa66a0f
 
 
 
a03d512
 
 
 
 
6cb50ff
04f4d0b
 
 
a03d512
04f4d0b
 
 
 
fa66a0f
 
 
 
 
 
 
 
 
 
 
 
 
04f4d0b
fa66a0f
 
6cb50ff
fa66a0f
 
 
 
 
 
2010b2d
a03d512
 
6cb50ff
fa66a0f
 
 
 
 
 
 
 
 
f72b9bb
fa66a0f
 
 
 
 
 
 
 
6cb50ff
a03d512
fa66a0f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import cv2
import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from ultralytics import YOLO

# Choose where inference runs once, at import time: CUDA when PyTorch reports
# it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Trained YOLOv8 weights; a relative path, so it resolves against the
# process's current working directory.
model = YOLO('./data/best.pt')
model.to(device)

# Module-level result buffers written by process_video.
frames_with_detections = []  # annotated RGB frames that contained new detections
detection_counts = []  # running detection totals, kept parallel to the frames list

# Define the function that processes the uploaded video
def process_video(video):
    # video is now the file path string, not a file object
    input_video = cv2.VideoCapture(video)  # Directly pass the path to cv2.VideoCapture

    # Get frame width, height, and fps from input video
    frame_width = int(input_video.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(input_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = input_video.get(cv2.CAP_PROP_FPS)

    # Resize to reduce computation (optional)
    new_width, new_height = 640, 480  # Resize to 640x480 resolution
    frame_width, frame_height = new_width, new_height

    # Track detected objects by their bounding box coordinates
    detected_boxes = set()
    total_detections = 0

    while True:
        # Read a frame from the video
        ret, frame = input_video.read()
        if not ret:
            break  # End of video

        # Resize the frame to reduce computational load
        frame = cv2.resize(frame, (new_width, new_height))

        # Perform inference on the frame
        results = model(frame)  # Automatically uses GPU if available

        # Check if any object was detected
        if len(results[0].boxes) > 0:  # If there are detected objects
            # Get the bounding boxes for each detected object
            boxes = results[0].boxes.xyxy.cpu().numpy()  # Get xyxy coordinates

            # Loop through each detection and only show the frame for new objects
            for box in boxes:
                x1, y1, x2, y2 = box
                detection_box = (x1, y1, x2, y2)

                # Check if this box was already processed
                if detection_box not in detected_boxes:
                    # Add the box to the set to avoid repeating the detection
                    detected_boxes.add(detection_box)
                    total_detections += 1

                    # Annotate the frame with bounding boxes
                    annotated_frame = results[0].plot()  # Plot the frame with bounding boxes

                    # Convert the annotated frame to RGB format for displaying
                    annotated_frame_rgb = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)

                    # Add this frame to the list of frames with detections
                    frames_with_detections.append(annotated_frame_rgb)
                    detection_counts.append(total_detections)

    # Release resources
    input_video.release()

    # Return the frames with detections for display
    return frames_with_detections

# Create a Gradio Blocks interface
with gr.Blocks() as demo:
    # Define a file input for video upload
    video_input = gr.Video(label="Upload Video")
    
    # Define the output area to show processed frames
    gallery_output = gr.Gallery(label="Detection Album", show_label=True).style(columns=3)  # Display images in a row (album)

    # Define the function to update frames in the album
    def update_gallery(video):
        detected_frames = process_video(video)
        return detected_frames  # Return all frames with detections

    # Connect the video input to the gallery update
    video_input.change(update_gallery, inputs=video_input, outputs=gallery_output)

# Launch the interface
demo.launch()