File size: 5,166 Bytes
41c03cf
 
c3429f6
6e725f6
c3429f6
6e725f6
 
ba9faee
c3429f6
6e725f6
 
c3429f6
6e725f6
 
 
 
c3429f6
 
6e725f6
42d2b87
6e725f6
 
c3429f6
42d2b87
 
 
 
6e725f6
 
 
 
 
 
 
 
 
 
42d2b87
c3429f6
6e725f6
c3429f6
6e725f6
 
 
 
 
 
 
 
c3429f6
6e725f6
c3429f6
6e725f6
 
c3429f6
6e725f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13e82ee
6e725f6
13e82ee
6e725f6
13e82ee
6e725f6
 
 
 
13e82ee
6e725f6
 
 
 
 
 
 
 
 
c3429f6
6e725f6
 
 
c3429f6
6e725f6
61be320
 
6e725f6
 
 
 
 
 
 
 
 
 
a295d73
c3429f6
6e725f6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import cv2
import numpy as np
import torch
from ultralytics import YOLO
import gradio as gr
from scipy.interpolate import interp1d
import uuid
import os

# Physical constants used by the LBW heuristics (SI units).
STUMPS_WIDTH = 0.2286  # metres: full width of the set of stumps
BALL_DIAMETER = 0.073  # metres: approximate cricket-ball diameter (not referenced by the functions below)
# Nominal frame rate (frames per second) assumed for uploaded videos.
FRAME_RATE = 30

# Trained YOLOv8n weights, loaded once at import time. "best.pt" is a relative
# path -- presumably resolved against the current working directory, so launch
# the app from the directory that contains it (TODO confirm for deployment).
model = YOLO("best.pt")

def process_video(video_path):
    """Run ball detection on every frame of a video.

    Each frame is passed to the module-level YOLO ``model``. Every class-0
    box (assumed to be the ball) is drawn in green on the frame. For
    trajectory fitting, at most one ball centre is kept per frame: the
    highest-confidence class-0 detection. This keeps ``ball_positions`` at
    one point per frame in which the ball was seen.

    NOTE(review): frames with no ball detection contribute no point, so
    consecutive entries of ``ball_positions`` are not guaranteed to come from
    consecutive frames, although downstream timing assumes they are.

    Args:
        video_path: Path to a video file readable by ``cv2.VideoCapture``.

    Returns:
        A tuple ``(frames, ball_positions)``. ``frames`` is the list of decoded
        frames with detection boxes drawn on them. ``ball_positions`` is a
        list of ``[x_center, y_center]`` pixel coordinates in frame order.
        Both lists are empty if the video cannot be opened or has no frames.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    ball_positions = []
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # Detect the ball; adjust the confidence threshold if needed.
            results = model.predict(frame, conf=0.5)
            best_center = None
            best_conf = -1.0
            for detection in results[0].boxes:
                if int(detection.cls) != 0:  # assuming class 0 is the ball
                    continue
                x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
                # Draw every ball box for visualization...
                cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
                # ...but keep only the most confident centre, so the trajectory
                # gets one sample per frame.
                conf = float(detection.conf)
                if conf > best_conf:
                    best_conf = conf
                    best_center = [(x1 + x2) / 2, (y1 + y2) / 2]
            if best_center is not None:
                ball_positions.append(best_center)
            frames.append(frame)
    finally:
        # Release the capture even if detection raises part-way through.
        cap.release()

    return frames, ball_positions

def estimate_trajectory(ball_positions, frames, frame_rate=None, horizon=0.5, num_points=10):
    """Project the ball's path forward from its detected positions.

    Positions are assumed to be sampled one per frame at ``frame_rate``
    frames per second. The x coordinate is interpolated linearly. The y
    coordinate uses a quadratic spline, or a linear fit when only two points
    are available, because a quadratic spline needs at least three. Both are
    extrapolated ``horizon`` seconds past the last detection.

    Args:
        ball_positions: Sequence of ``(x, y)`` pixel coordinates in time order.
        frames: Unused; kept for interface compatibility.
        frame_rate: Frames per second used to turn sample index into time.
            Defaults to the module-level ``FRAME_RATE``.
        horizon: Seconds beyond the last detection to project.
        num_points: Number of evenly spaced projected samples, starting at the
            time of the last detection.

    Returns:
        ``(trajectory, t_future)``, where ``trajectory`` is a list of
        ``(x, y)`` tuples and ``t_future`` holds the matching times in
        seconds. Returns ``(None, None)`` if there are fewer than two
        positions or the interpolants cannot be built.
    """
    if len(ball_positions) < 2:
        return None, None
    if frame_rate is None:
        frame_rate = FRAME_RATE

    x_coords = [pos[0] for pos in ball_positions]
    y_coords = [pos[1] for pos in ball_positions]
    times = np.arange(len(ball_positions)) / frame_rate

    # A quadratic spline needs >= 3 samples; with exactly two, fall back to a
    # linear fit instead of failing outright.
    y_kind = 'quadratic' if len(ball_positions) >= 3 else 'linear'
    try:
        fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
        fy = interp1d(times, y_coords, kind=y_kind, fill_value="extrapolate")
    except ValueError:
        # Degenerate input (e.g. non-finite coordinates) -- no projection.
        return None, None

    # Project the trajectory forward from the last detection.
    t_future = np.linspace(times[-1], times[-1] + horizon, num_points)
    x_future = fx(t_future)
    y_future = fy(t_future)

    return list(zip(x_future, y_future)), t_future

def lbw_decision(ball_positions, trajectory, frames):
    """Make a simplified LBW decision from detected positions and a projection.

    The stumps are not detected. They are assumed to sit at the horizontal
    centre of the frame, 90% of the way down, and their pixel width assumes
    the frame spans a 3 m wide area. Both are calibration assumptions.

    The checks run in this order:
      1. The first detected position (treated as the pitching point) must lie
         within the stumps' horizontal span.
      2. The last detected position (treated as the impact point) must lie
         within that span.
      3. Some projected point must be within the span horizontally and within
         10% of the frame height of the stumps vertically.

    Args:
        ball_positions: ``(x, y)`` pixel coordinates in time order.
        trajectory: Projected ``(x, y)`` points, or ``None``.
        frames: Decoded video frames; only the first frame's size is used.

    Returns:
        ``(decision, trajectory)``: the decision string, plus the trajectory
        to overlay. The trajectory is ``None`` when the decision is reached
        before the projection is consulted.
    """
    if not trajectory or len(ball_positions) < 2 or not frames:
        return "Not enough data", None

    # Assume stumps are at the bottom centre of the frame (calibration needed).
    frame_height, frame_width = frames[0].shape[:2]
    stumps_x = frame_width / 2
    stumps_y = frame_height * 0.9
    # Pixel width of the stumps, assuming a 3 m wide field of view.
    stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
    half_width = stumps_width_pixels / 2

    def in_line(x):
        """Return True if x lies within the stumps' horizontal span (inclusive)."""
        return stumps_x - half_width <= x <= stumps_x + half_width

    # NOTE(review): the Laws only rule out balls pitching outside *leg* stump;
    # this heuristic rejects either side since it does not know the batter's
    # orientation.
    pitch_x, _ = ball_positions[0]
    if not in_line(pitch_x):
        return "Not Out (Pitched outside line)", None

    impact_x, _ = ball_positions[-1]
    if not in_line(impact_x):
        return "Not Out (Impact outside line)", None

    hits_stumps = any(
        abs(x - stumps_x) < half_width and abs(y - stumps_y) < frame_height * 0.1
        for x, y in trajectory
    )
    if hits_stumps:
        return "Out", trajectory
    return "Not Out (Missing stumps)", trajectory

def generate_slow_motion(frames, trajectory, output_path):
    """Write a slowed-down replay with the projected trajectory overlaid.

    Every projected point is drawn as a filled blue dot (BGR ``(255, 0, 0)``)
    on a copy of each frame, so the caller's frames are left unmodified. The
    video is written at ``FRAME_RATE / 2`` fps with each frame written twice.

    NOTE(review): halving the fps and also duplicating frames shows each
    frame for 4x its nominal ``1 / FRAME_RATE`` duration. Confirm whether 2x
    was intended.

    NOTE(review): browsers often cannot play MPEG-4 Part 2 (``'mp4v'``)
    streams. H.264 (``'avc1'``) may be needed for in-browser playback, if the
    OpenCV build supports it.

    Args:
        frames: Non-empty list of equally sized frames.
        trajectory: ``(x, y)`` pixel points to overlay, or ``None``/empty.
        output_path: Destination ``.mp4`` path.

    Returns:
        ``output_path``.

    Raises:
        ValueError: If ``frames`` is empty.
        IOError: If the video writer cannot be opened.
    """
    if not frames:
        raise ValueError("generate_slow_motion requires at least one frame")

    height, width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / 2, (width, height))
    if not out.isOpened():
        raise IOError(f"Could not open video writer for {output_path}")

    # Convert once: the same points are drawn on every frame.
    points = [(int(x), int(y)) for x, y in trajectory] if trajectory else []
    try:
        for frame in frames:
            # Draw on a copy so the caller's frames are not mutated.
            canvas = frame.copy() if points else frame
            for point in points:
                cv2.circle(canvas, point, 5, (255, 0, 0), -1)
            out.write(canvas)
            out.write(canvas)  # duplicate frames for the slow-motion effect
    finally:
        out.release()
    return output_path

def drs_review(video):
    """Gradio handler: analyse an uploaded delivery and return the DRS result.

    Args:
        video: Filesystem path of the uploaded clip from the ``gr.Video``
            input. This is presumably ``None`` when nothing was uploaded.

    Returns:
        ``(decision, replay_path)``: the decision text and the path of the
        generated slow-motion replay. On failure, an error message and
        ``None`` are returned instead.
    """
    if not video or not os.path.exists(video):
        return "Error: Video file not found", None

    frames, ball_positions = process_video(video)
    if not frames:
        # Unreadable or empty video: nothing to analyse or replay.
        return "Error: Could not read any frames from the video", None

    trajectory, _ = estimate_trajectory(ball_positions, frames)
    decision, trajectory = lbw_decision(ball_positions, trajectory, frames)

    # A unique name keeps concurrent requests from overwriting each other's replay.
    output_path = f"output_{uuid.uuid4()}.mp4"
    slow_motion_path = generate_slow_motion(frames, trajectory, output_path)

    return decision, slow_motion_path

# Gradio UI: one uploaded clip in; decision text and annotated replay out.
_clip_input = gr.Video(label="Upload Video Clip")
_result_outputs = [
    gr.Textbox(label="DRS Decision"),
    gr.Video(label="Slow-Motion Replay with Ball Detection and Trajectory"),
]
iface = gr.Interface(
    fn=drs_review,
    inputs=_clip_input,
    outputs=_result_outputs,
    title="AI-Powered DRS for LBW in Local Cricket",
    description=(
        "Upload a video clip of a cricket delivery to get an LBW decision "
        "and slow-motion replay showing ball detection and trajectory."
    ),
)

if __name__ == "__main__":
    # Serve the app only when executed as a script, not when imported.
    iface.launch()