DRS_AI / app.py
AjaykumarPilla's picture
Update app.py
61be320 verified
raw
history blame
3.54 kB
import os
import tempfile

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from ultralytics import YOLO
# Load the detection weights from "best.pt" once, at import time, so that every
# call to detect_ball() reuses the same model instance.
# NOTE(review): the original comment calls this a YOLOv5 model; the weights are
# loaded through the ultralytics YOLO class -- confirm the exact model variant.
model = YOLO("best.pt")
# File extensions routed to the image and video pipelines respectively.
_IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png"})
_VIDEO_EXTENSIONS = frozenset({".mp4", ".avi", ".mov"})

# Frame rate used only when the source container does not report a usable one.
_FALLBACK_FPS = 30.0


def _resolve_media_path(input_media):
    """Return the filesystem path for an upload, or None if there is none.

    Accepts a plain path (str / os.PathLike) or an object exposing the path as
    a ``name`` attribute (temp-file wrappers handed over by some Gradio
    versions).
    """
    if isinstance(input_media, (str, os.PathLike)):
        return os.fspath(input_media)
    # Checked after the str/PathLike case on purpose: pathlib.Path also has a
    # ``name`` attribute, but it holds only the final path component.
    wrapped_name = getattr(input_media, "name", None)
    if isinstance(wrapped_name, (str, os.PathLike)):
        return os.fspath(wrapped_name)
    return None


def _annotate_detections(frame_bgr, conf_threshold, iou_threshold):
    """Run the model on a BGR frame and draw its detections on it in place."""
    # Ultralytics documents numpy inputs as HWC/BGR (the cv2.imread layout),
    # so the frame is passed through without a colour conversion.
    results = model.predict(frame_bgr, conf=conf_threshold, iou=iou_threshold)
    for box in results[0].boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
        label = f"Ball: {float(box.conf[0]):.2f}"
        cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Keep the label on-canvas when the box touches the top edge.
        label_origin = (x1, max(y1 - 10, 20))
        cv2.putText(frame_bgr, label, label_origin,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
    return frame_bgr


def _process_image(image_path, conf_threshold, iou_threshold):
    """Annotate a single image; return a PIL image or an error message."""
    img_bgr = cv2.imread(image_path)
    if img_bgr is None:
        # cv2.imread signals failure by returning None rather than raising.
        return "Could not read the uploaded image. The file may be missing or corrupt."
    _annotate_detections(img_bgr, conf_threshold, iou_threshold)
    # PIL expects RGB, so convert only for the final output.
    return Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))


def _process_video(video_path, conf_threshold, iou_threshold):
    """Annotate every frame of a video; return the output path or an error message."""
    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        if not cap.isOpened():
            return "Could not open the uploaded video. The file may be missing or corrupt."

        # Preserve the source frame rate so playback speed is unchanged;
        # `not fps > 0` also catches NaN and zero from containers without one.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = _FALLBACK_FPS
        frame_size = (
            int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        )

        # A unique temp file per call keeps concurrent requests from
        # overwriting each other's output.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as handle:
            output_path = handle.name

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(output_path, fourcc, fps, frame_size)
        if not out.isOpened():
            return "Could not create the annotated output video."

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            out.write(_annotate_detections(frame, conf_threshold, iou_threshold))
        return output_path
    finally:
        # Release native handles even if inference raises mid-stream.
        cap.release()
        if out is not None:
            out.release()


def detect_ball(input_media, conf_threshold=0.5, iou_threshold=0.5):
    """Perform ball detection on an image or video input.

    Args:
        input_media: Uploaded image or video file, given either as a path or
            as an object exposing the path through a ``name`` attribute.
        conf_threshold: Confidence threshold for detection.
        iou_threshold: IoU threshold for non-max suppression.

    Returns:
        A PIL image with the detections drawn for image input, the path of an
        annotated ``.mp4`` file for video input, or a human-readable message
        string when the input is missing, unreadable, or of an unsupported
        format.
    """
    media_path = _resolve_media_path(input_media)
    if media_path is None:
        return "No file provided. Please upload an image (.jpg, .png) or video (.mp4, .avi, .mov)."

    # Dispatch on the file extension (case-insensitive).
    file_extension = os.path.splitext(media_path)[1].lower()
    if file_extension in _IMAGE_EXTENSIONS:
        return _process_image(media_path, conf_threshold, iou_threshold)
    if file_extension in _VIDEO_EXTENSIONS:
        return _process_video(media_path, conf_threshold, iou_threshold)
    return "Unsupported file format. Please upload an image (.jpg, .png) or video (.mp4, .avi, .mov)."
# Gradio interface
def _run_detection(input_media, conf_threshold, iou_threshold):
    """Call detect_ball and route its result to the matching output widget.

    detect_ball returns one of three things: a PIL image, the path of an
    annotated video, or a plain message string. A single image component
    cannot display the last two, so the result is fanned out here.

    Returns:
        A ``(image, video_path, status_text)`` tuple in which the slots that
        do not apply are ``None`` / empty.
    """
    if input_media is None:
        return None, None, "Please upload an image or video first."

    # Depending on the Gradio version, gr.File passes either a path string or
    # a temp-file wrapper that exposes the path as ``.name``.
    if isinstance(input_media, str):
        media_path = input_media
    else:
        media_path = getattr(input_media, "name", input_media)

    result = detect_ball(media_path, conf_threshold, iou_threshold)
    if isinstance(result, str):
        # A string that names an existing file is the annotated video;
        # any other string is a message meant for the user.
        if os.path.isfile(result):
            return None, result, ""
        return None, None, result
    return result, None, ""


with gr.Blocks() as demo:
    gr.Markdown("# Decision Review System (DRS) for Ball Detection")
    gr.Markdown("Upload an image or video to detect the ball using a trained YOLOv5 model. Adjust confidence and IoU thresholds for detection.")

    input_media = gr.File(label="Upload Image or Video")
    conf_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Confidence Threshold")
    iou_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="IoU Threshold")

    # One widget per result kind, since image and video need different players.
    output = gr.Image(label="Output Image")
    output_video = gr.Video(label="Output Video")
    status = gr.Textbox(label="Status", interactive=False)

    submit_button = gr.Button("Detect Ball")
    submit_button.click(
        fn=_run_detection,
        inputs=[input_media, conf_slider, iou_slider],
        outputs=[output, output_video, status],
    )

demo.launch()