# Page header captured with this listing (not part of the program): "Spaces: Sleeping"
import cv2 | |
import torch | |
import gradio as gr | |
import numpy as np | |
from ultralytics import YOLO | |
import matplotlib.pyplot as plt | |
# Choose the compute device: CUDA when PyTorch reports it, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the YOLOv8 weights from the local checkpoint and move the model onto
# the chosen device.
model = YOLO("./data/best.pt")  # Path to your model
model.to(device)

# Module-level list for frames with detections.
frames_with_detections = []
def _render_count_chart(count):
    """Return an RGB image array of a one-bar chart showing ``count`` detections."""
    fig, ax = plt.subplots()
    try:
        ax.bar([1], [count], color='blue')  # Bar for the current frame detection
        ax.set_xlabel('Frame')
        ax.set_ylabel('Number of Detections')
        ax.set_title('Detection Count per Frame')
        fig.tight_layout()

        # print_to_buffer() returns raw RGBA pixels plus (width, height); it is
        # not an encoded image, so reshape it instead of calling cv2.imdecode.
        rgba, (width, height) = fig.canvas.print_to_buffer()
        pixels = np.frombuffer(rgba, dtype=np.uint8).reshape(height, width, 4)
        # Drop alpha and copy so the result does not alias the read-only buffer.
        return pixels[..., :3].copy()
    finally:
        # Always free the figure; otherwise pyplot keeps one figure per frame.
        plt.close(fig)


# Define the function to process the video
def process_video(video):
    """Run detection on each frame of a video and stream the results.

    This is a generator: after every frame that contains at least one
    detection it yields the updated outputs, so the UI can refresh while the
    video is still being processed.

    Args:
        video: Path of the video file to open with ``cv2.VideoCapture``.
            A falsy value (e.g. ``None`` when the input is cleared) yields
            empty outputs and stops.

    Yields:
        tuple: ``(frames, chart)`` where ``frames`` is a list of all annotated
        RGB frames with detections seen so far in this call, and ``chart`` is
        an RGB array of a bar chart with the detection count of the current
        frame.
    """
    if not video:
        # Nothing to process; clear any stale outputs.
        yield [], None
        return

    # Resize frames to 640x480 (optional, to reduce computational load)
    new_width, new_height = 640, 480

    # Kept local to this call so frames never leak between runs or sessions.
    # NOTE: the whole list is re-sent on every yield, so very long videos with
    # many detections grow both memory use and update cost.
    detected_frames = []

    capture = cv2.VideoCapture(video)
    try:
        while True:
            ok, frame = capture.read()
            if not ok:
                break  # End of video (or the file could not be read)

            frame = cv2.resize(frame, (new_width, new_height))

            # Perform inference on the frame (model was moved to the device at load time)
            result = model(frame)[0]
            detection_count = len(result.boxes)
            if detection_count == 0:
                continue

            # Annotate the frame with bounding boxes, then convert BGR -> RGB
            # for display.
            annotated_rgb = cv2.cvtColor(result.plot(), cv2.COLOR_BGR2RGB)
            detected_frames.append(annotated_rgb)

            # Yield a copy of the list so later appends cannot mutate a value
            # that has already been handed to the caller.
            yield list(detected_frames), _render_count_chart(detection_count)
    finally:
        # Release the capture even if the consumer stops early or inference fails.
        capture.release()
# Gradio interface: a video upload feeds process_video, whose streamed outputs
# populate a gallery of annotated frames and a detection-count chart.
with gr.Blocks() as demo:
    with gr.Row():
        video_input = gr.Video(label="Upload Video")
        # `columns` is passed to the constructor; the older `.style(columns=...)`
        # call is deprecated and no longer available in current Gradio releases.
        gallery_output = gr.Gallery(label="Detection Album", columns=3)  # Display images in a row
        graph_output = gr.Image(label="Detection Counts Graph", type="numpy")  # For displaying graph
    video_input.change(process_video, inputs=video_input, outputs=[gallery_output, graph_output])

# process_video is a generator; older Gradio versions require the queue to be
# enabled for streaming outputs (it is a no-op where queuing is the default).
demo.queue()

# Launch the interface
demo.launch()