import io

import cv2
import torch
import gradio as gr
import numpy as np
from ultralytics import YOLO
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend: figures are rendered off-screen
import matplotlib.pyplot as plt

# Load YOLOv8 model
device = "cuda" if torch.cuda.is_available() else "cpu"
model = YOLO('./data/best.pt')  # Path to your model
model.to(device)
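# Note: './data/best.pt' is assumed to be a custom-trained YOLOv8 checkpoint.
# To smoke-test the app without one, a stock model such as YOLO('yolov8n.pt')
# can be substituted; ultralytics downloads those weights automatically.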

# Define the generator that processes the video and streams results to Gradio
def process_video(video):
    # Per-call state, so each uploaded video starts with a clean slate
    frames_with_detections = []  # Annotated frames shown in the gallery
    detection_counts = []        # Number of detections in each annotated frame

    # Open the video file (Gradio passes the uploaded file's path)
    input_video = cv2.VideoCapture(video)

    # Resize frames to 640x480 (optional, to reduce computational load)
    new_width, new_height = 640, 480
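    # Note: a fixed 640x480 resize does not preserve the source aspect ratio,
    # so objects can look stretched; scaling width and height by a common
    # factor would avoid that if detection quality suffers.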

    while True:
        # Read a frame from the video
        ret, frame = input_video.read()
        if not ret:
            break  # End of video

        # Resize the frame
        frame = cv2.resize(frame, (new_width, new_height))

        # Run inference on the frame (executes on the device the model was moved to)
        results = model(frame)

        # If there are detections
        if len(results[0].boxes) > 0: 
            boxes = results[0].boxes.xyxy.cpu().numpy()  # Get the bounding boxes

            # Annotate the frame with bounding boxes
            annotated_frame = results[0].plot()  

            # Convert the frame to RGB
            annotated_frame_rgb = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)

            # Append the frame with detection to list
            frames_with_detections.append(annotated_frame_rgb)
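
            # Note: for a long video this list grows without bound; if memory
            # becomes a concern, one option is to keep only the most recent N
            # frames, e.g. frames_with_detections[-50:] (50 is illustrative).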

            detection_counts.append(len(boxes))

            # Bar chart of the detection count for every annotated frame so far
            fig, ax = plt.subplots()
            ax.bar(range(1, len(detection_counts) + 1), detection_counts, color='blue')
            ax.set_xlabel('Frame')
            ax.set_ylabel('Number of Detections')
            ax.set_title('Detection Count per Frame')

            # Render the plot to an in-memory PNG and decode it to a numpy image
            plt.tight_layout()
            png_buffer = io.BytesIO()
            fig.savefig(png_buffer, format='png')
            plt.close(fig)  # Close only after saving, while the canvas is still valid

            img = cv2.imdecode(np.frombuffer(png_buffer.getvalue(), dtype=np.uint8), cv2.IMREAD_COLOR)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # gr.Image expects RGB

            # Stream the gallery of annotated frames and the updated graph together
            yield frames_with_detections, img

    # Release resources
    input_video.release()

# Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        video_input = gr.Video(label="Upload Video")
        gallery_output = gr.Gallery(label="Detection Album", columns=3)  # Annotated frames in a grid
        graph_output = gr.Image(label="Detection Counts Graph", type="numpy")  # Per-frame counts

    # process_video is a generator, so both outputs update as frames are processed
    video_input.change(process_video, inputs=video_input, outputs=[gallery_output, graph_output])

# Launch the interface
demo.launch()
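
# A minimal way to run this script (assuming it is saved as app.py and the
# dependencies are installed):
#
#   pip install ultralytics gradio opencv-python matplotlib
#   python app.py
#
# demo.launch() serves the UI locally (http://127.0.0.1:7860 by default);
# pass share=True to launch() for a temporary public link.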