File size: 3,535 Bytes
61be320
 
 
41c03cf
 
a295d73
ba9faee
a295d73
61be320
 
a295d73
61be320
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a295d73
61be320
 
a295d73
61be320
 
 
 
 
 
 
a295d73
61be320
 
 
 
 
a295d73
61be320
 
 
 
 
 
 
 
 
 
a295d73
61be320
 
 
a295d73
61be320
 
 
 
 
 
 
a295d73
61be320
a295d73
61be320
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a295d73
61be320
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import tempfile

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from ultralytics import YOLO

# Detector weights, loaded once when the module is imported; detect_ball()
# reuses this single Ultralytics YOLO instance for every call.
# NOTE(review): the comment/UI text say "YOLOv5" — confirm which YOLO
# version produced best.pt.
_WEIGHTS_PATH = "best.pt"
model = YOLO(_WEIGHTS_PATH)

_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
_VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov")
_BOX_COLOR = (0, 255, 0)  # pure green: same tuple in BGR and RGB order
_FALLBACK_FPS = 30.0  # used when the container does not report a frame rate


def _media_path(input_media):
    """Return a filesystem path for ``input_media``.

    Accepts a ``str`` / ``os.PathLike`` or a file-like wrapper exposing
    ``.name`` (presumably what some Gradio ``gr.File`` versions pass —
    confirm for the installed version).
    """
    if isinstance(input_media, (str, os.PathLike)):
        return os.fspath(input_media)
    return input_media.name


def _draw_detections(image, result):
    """Draw every box in ``result.boxes`` with its confidence onto ``image``.

    ``image`` is an OpenCV-style array and is modified in place; it is also
    returned for convenience.
    """
    for box in result.boxes:
        x1, y1, x2, y2 = (int(v) for v in box.xyxy[0])
        label = f"Ball: {float(box.conf[0]):.2f}"
        cv2.rectangle(image, (x1, y1), (x2, y2), _BOX_COLOR, 2)
        cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, _BOX_COLOR, 2)
    return image


def _annotate_image(image_path, conf_threshold, iou_threshold):
    """Detect on one image file; return an annotated RGB PIL image or an error string."""
    img_bgr = cv2.imread(image_path)
    if img_bgr is None:  # cv2.imread signals failure with None, not an exception
        return f"Could not read image file: {os.path.basename(image_path)}"
    # Ultralytics treats numpy inputs as OpenCV BGR, so predict on the raw array
    # and convert to RGB only for the PIL/Gradio output.
    results = model.predict(img_bgr, conf=conf_threshold, iou=iou_threshold, verbose=False)
    _draw_detections(img_bgr, results[0])
    return Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))


def _annotate_video(video_path, conf_threshold, iou_threshold):
    """Detect frame-by-frame; return the path of an annotated .mp4 or an error string."""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        return f"Could not open video file: {os.path.basename(video_path)}"

    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:  # 0 or NaN when the container lacks frame-rate metadata
        fps = _FALLBACK_FPS
    size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))

    # A unique file per call, so concurrent requests do not overwrite each other.
    # NOTE(review): these files are not deleted here; rely on OS temp cleanup.
    fd, output_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    # NOTE(review): 'mp4v' may not play in every browser; consider 'avc1' if supported.
    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, size)
    try:
        if not writer.isOpened():
            return "Could not create the output video."
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            results = model.predict(frame, conf=conf_threshold, iou=iou_threshold, verbose=False)
            writer.write(_draw_detections(frame, results[0]))
    finally:
        # Always release the capture/writer, even if prediction raises.
        cap.release()
        writer.release()
    return output_path


def detect_ball(input_media, conf_threshold=0.5, iou_threshold=0.5):
    """
    Perform ball detection on image or video input.

    Args:
        input_media: Uploaded image or video, as a path (``str``/``os.PathLike``)
            or a file-like object exposing ``.name``. ``None`` means no upload.
        conf_threshold: Confidence threshold for detection.
        iou_threshold: IoU threshold for non-max suppression.

    Returns:
        For .jpg/.jpeg/.png: an annotated PIL RGB image.
        For .mp4/.avi/.mov: the path of an annotated .mp4 temp file.
        Otherwise (no upload, unsupported extension, unreadable file): a
        human-readable message string.
    """
    if input_media is None:
        return "No file uploaded. Please upload an image (.jpg, .png) or video (.mp4, .avi, .mov)."

    media_path = _media_path(input_media)
    file_extension = os.path.splitext(media_path)[1].lower()

    if file_extension in _IMAGE_EXTENSIONS:
        return _annotate_image(media_path, conf_threshold, iou_threshold)
    if file_extension in _VIDEO_EXTENSIONS:
        return _annotate_video(media_path, conf_threshold, iou_threshold)
    return "Unsupported file format. Please upload an image (.jpg, .png) or video (.mp4, .avi, .mov)."

def _route_detection(input_media, conf_threshold, iou_threshold):
    """Run ``detect_ball`` and route its result to the matching output component.

    ``detect_ball`` returns a PIL image for images, a file path for videos, and
    a plain message string otherwise. A single ``gr.Image`` cannot display the
    latter two, so each kind gets its own output slot.

    Returns:
        Tuple ``(image, video_path, status)``. Unused slots are ``None``, which
        clears that component; ``status`` is a short text message.
    """
    if input_media is None:
        return None, None, "Please upload an image or video first."
    result = detect_ball(input_media, conf_threshold, iou_threshold)
    if isinstance(result, str):
        # A string is either the annotated video's path or an error message.
        if os.path.isfile(result):
            return None, result, "Video processed."
        return None, None, result
    return result, None, "Image processed."


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Decision Review System (DRS) for Ball Detection")
    gr.Markdown("Upload an image or video to detect the ball using a trained YOLOv5 model. Adjust confidence and IoU thresholds for detection.")

    input_media = gr.File(
        label="Upload Image or Video",
        file_types=[".jpg", ".jpeg", ".png", ".mp4", ".avi", ".mov"],
    )
    conf_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Confidence Threshold")
    iou_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="IoU Threshold")
    # Images and videos need different display components; _route_detection
    # fills whichever applies and clears the other.
    output = gr.Image(label="Output Image")
    output_video = gr.Video(label="Output Video")
    status = gr.Textbox(label="Status", interactive=False)
    submit_button = gr.Button("Detect Ball")

    submit_button.click(
        fn=_route_detection,
        inputs=[input_media, conf_slider, iou_slider],
        outputs=[output, output_video, status],
    )

demo.launch()