|
import gradio as gr |
|
import torch |
|
from ultralytics import YOLO |
|
import cv2 |
|
import numpy as np |
|
from PIL import Image |
|
import os |
|
|
|
|
|
# Trained YOLO detector, loaded once at import time and shared by every request.
# The weights path defaults to "best.pt" (resolved relative to the working
# directory); set BALL_MODEL_WEIGHTS to point at a different checkpoint
# without editing the code.
model = YOLO(os.environ.get("BALL_MODEL_WEIGHTS", "best.pt"))
|
|
|
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
_VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov")
# Box/label colour; (0, 255, 0) is green in both BGR and RGB channel order.
_BOX_COLOR = (0, 255, 0)


def _draw_detections(frame, result):
    """Draw a box and a "Ball: <confidence>" label for each detection, in place.

    Args:
        frame: OpenCV image array to annotate; it is modified in place.
        result: One element of the list returned by ``model.predict``; only
            its ``boxes`` attribute is read.
    """
    for box in result.boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        label = f"Ball: {float(box.conf[0]):.2f}"
        cv2.rectangle(frame, (x1, y1), (x2, y2), _BOX_COLOR, 2)
        # Put the label just above the box, but keep it from being drawn
        # above the top edge of the frame (where it would be clipped).
        cv2.putText(
            frame,
            label,
            (x1, max(y1 - 10, 20)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.9,
            _BOX_COLOR,
            2,
        )


def _annotate_image(path, conf_threshold, iou_threshold):
    """Run detection on an image file and return the annotated image as an RGB PIL image."""
    frame = cv2.imread(path)
    if frame is None:
        # cv2.imread returns None instead of raising for unreadable files.
        raise ValueError(f"Could not read image file: {path}")
    # Ultralytics treats numpy input as BGR (OpenCV order), so predict on the
    # array exactly as cv2 loaded it rather than on an RGB-converted copy.
    results = model.predict(frame, conf=conf_threshold, iou=iou_threshold)
    _draw_detections(frame, results[0])
    # PIL expects RGB channel order.
    return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))


def _annotate_video(path, conf_threshold, iou_threshold, output_path):
    """Run detection on every frame of a video file and write an annotated copy to ``output_path``."""
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video file: {path}")

    # Reuse the source frame rate so the output plays at the original speed;
    # fall back to 30 fps when the container reports 0.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    # NOTE(review): 'mp4v' may not play natively in every browser; confirm the
    # output component can display it in the target deployment.
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, size)
    try:
        if not out.isOpened():
            raise ValueError(f"Could not create output video: {output_path}")
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            # Frames from cv2 are BGR, which is what Ultralytics expects.
            results = model.predict(frame, conf=conf_threshold, iou=iou_threshold)
            _draw_detections(frame, results[0])
            out.write(frame)
    finally:
        # Release both handles even if inference fails part-way through.
        cap.release()
        out.release()
    return output_path


def detect_ball(input_media, conf_threshold=0.5, iou_threshold=0.5, *, output_path="output_video.mp4"):
    """Perform ball detection on an image or video file and draw the detections.

    Args:
        input_media: Path to the uploaded file, or a file-like wrapper that
            exposes the path as ``.name`` (some Gradio versions pass one).
        conf_threshold: Confidence threshold for detection.
        iou_threshold: IoU threshold for non-max suppression.
        output_path: Where the annotated video is written (video input only).

    Returns:
        A ``PIL.Image.Image`` for image input, ``output_path`` for video input,
        or a human-readable message string for unsupported file extensions.

    Raises:
        ValueError: If no file is given or the file cannot be read/written.
    """
    if input_media is None:
        raise ValueError("No file was uploaded.")
    if not isinstance(input_media, (str, os.PathLike)):
        input_media = getattr(input_media, "name", input_media)
    path = os.fspath(input_media)
    extension = os.path.splitext(path)[1].lower()

    if extension in _IMAGE_EXTENSIONS:
        return _annotate_image(path, conf_threshold, iou_threshold)
    if extension in _VIDEO_EXTENSIONS:
        return _annotate_video(path, conf_threshold, iou_threshold, output_path)
    return "Unsupported file format. Please upload an image (.jpg, .png) or video (.mp4, .avi, .mov)."
|
|
|
|
|
def _run_detection(media, conf_threshold, iou_threshold):
    """Run ``detect_ball`` and route its result to the image or the video output.

    Returns an ``(image, video_path)`` pair in which exactly one item is set,
    so the unused output component is cleared. Missing uploads, unreadable
    files and unsupported formats are reported in the UI via ``gr.Error``
    instead of being handed to an output component.
    """
    if media is None:
        raise gr.Error("Please upload an image or video file.")
    # Some Gradio versions pass a tempfile wrapper instead of a path string.
    if not isinstance(media, (str, os.PathLike)):
        media = getattr(media, "name", media)
    try:
        result = detect_ball(media, conf_threshold, iou_threshold)
    except ValueError as exc:
        raise gr.Error(str(exc)) from exc
    if isinstance(result, Image.Image):
        return result, None
    if isinstance(result, str) and os.path.isfile(result):
        return None, result
    # Anything else is the message string detect_ball returns for unsupported input.
    raise gr.Error(str(result))


with gr.Blocks() as demo:
    gr.Markdown("# Decision Review System (DRS) for Ball Detection")
    gr.Markdown("Upload an image or video to detect the ball using a trained YOLOv5 model. Adjust confidence and IoU thresholds for detection.")

    input_media = gr.File(
        label="Upload Image or Video",
        file_types=[".jpg", ".jpeg", ".png", ".mp4", ".avi", ".mov"],
    )
    conf_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Confidence Threshold")
    iou_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="IoU Threshold")
    # gr.Image cannot render video, so annotated videos get their own component.
    output_image = gr.Image(label="Output Image")
    output_video = gr.Video(label="Output Video")
    submit_button = gr.Button("Detect Ball")

    submit_button.click(
        fn=_run_detection,
        inputs=[input_media, conf_slider, iou_slider],
        outputs=[output_image, output_video],
    )

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()