File size: 8,479 Bytes
41c03cf
 
c3429f6
462fddf
c3429f6
462fddf
 
ba9faee
c3429f6
7e29aa0
 
c3429f6
58199b3
 
 
462fddf
58199b3
0a9dc78
58199b3
c3429f6
 
462fddf
 
42d2b87
462fddf
 
 
7e29aa0
58199b3
 
42d2b87
 
 
a653421
58199b3
 
462fddf
 
58199b3
462fddf
 
58199b3
462fddf
7e29aa0
462fddf
42d2b87
c3429f6
462fddf
 
 
 
c9d4715
0a9dc78
 
 
 
462fddf
c3429f6
0a9dc78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
462fddf
 
58199b3
 
0a9dc78
58199b3
462fddf
 
 
c9d4715
58199b3
 
 
 
 
 
 
 
 
0a9dc78
58199b3
0a9dc78
58199b3
 
 
 
 
c3429f6
462fddf
 
 
58199b3
462fddf
58199b3
 
 
 
0a9dc78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
462fddf
 
58199b3
462fddf
 
13e82ee
462fddf
c9d4715
58199b3
 
 
 
 
c9d4715
58199b3
 
462fddf
58199b3
462fddf
 
c9d4715
58199b3
 
462fddf
58199b3
462fddf
 
c9d4715
7e29aa0
a653421
689fb64
c9d4715
462fddf
 
 
 
 
58199b3
 
c9d4715
462fddf
58199b3
c9d4715
 
462fddf
61be320
 
462fddf
 
 
 
 
c9d4715
462fddf
 
58199b3
462fddf
a295d73
c3429f6
462fddf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import cv2
import numpy as np
import torch
from ultralytics import YOLO
import gradio as gr
from scipy.interpolate import interp1d
import uuid
import os

# Load the trained YOLOv8n ball-detection model from local weights.
# NOTE(review): loading happens at import time; "best.pt" must exist in the CWD.
model = YOLO("best.pt")

# Constants for LBW decision and video processing
STUMPS_WIDTH = 0.2286  # meters (regulation width of the three stumps)
BALL_DIAMETER = 0.073  # meters (approx. cricket ball diameter) — NOTE(review): currently unused in this file
FRAME_RATE = 30  # Input video frame rate (assumed, not probed from the file — TODO confirm)
SLOW_MOTION_FACTOR = 6  # For very slow motion (6x slower)
CONF_THRESHOLD = 0.25  # Lowered confidence threshold to improve detection
IMPACT_ZONE_Y = 0.85  # Fraction of frame height where impact is likely (near stumps)

def process_video(video_path):
    """Run the YOLO ball detector over every frame of a video.

    Args:
        video_path: Path to the input video file.

    Returns:
        Tuple of (frames, ball_positions, debug_log):
        - frames: list of BGR frames with green boxes drawn around detections
        - ball_positions: list of [cx, cy] ball centers (one per detection,
          possibly interpolated to cover undetected frames)
        - debug_log: newline-joined per-frame detection summary, or an error string
    """
    if not os.path.exists(video_path):
        return [], [], "Error: Video file not found"
    cap = cv2.VideoCapture(video_path)
    frames = []
    ball_positions = []
    debug_log = []

    frame_count = 0
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            results = model.predict(frame, conf=CONF_THRESHOLD)
            detections = 0
            for detection in results[0].boxes:
                if detection.cls == 0:  # Assuming class 0 is the ball
                    detections += 1
                    x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
                    # Store the box center as the ball position.
                    ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
                    cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
            # Append once, after annotation. The original appended frame.copy()
            # and then immediately replaced it with the annotated frame,
            # wasting one full-frame allocation per frame.
            frames.append(frame)
            debug_log.append(f"Frame {frame_count}: {detections} ball detections")
    finally:
        # Release the capture even if prediction raises mid-loop.
        cap.release()

    if not ball_positions:
        debug_log.append("No balls detected in any frame")
    else:
        debug_log.append(f"Total ball detections: {len(ball_positions)}")

    # Interpolate missing detections
    if ball_positions:
        ball_positions = interpolate_missing_positions(ball_positions, frame_count)

    return frames, ball_positions, "\n".join(debug_log)

def interpolate_missing_positions(ball_positions, total_frames, frame_rate=None):
    """Linearly resample detected ball positions onto the video's time base.

    Args:
        ball_positions: list of [x, y] detected centers.
        total_frames: number of frames in the video.
        frame_rate: frames per second; defaults to the module FRAME_RATE.

    Returns:
        List of (x, y) positions sampled at frame times up to the last
        detection, or the input unchanged if interpolation is impossible.

    NOTE(review): this assumes detection i occurred at frame i, which is only
    true when the ball was detected in every frame from the start — TODO pass
    the real frame indices alongside the detections.
    """
    if len(ball_positions) < 2:
        return ball_positions
    rate = FRAME_RATE if frame_rate is None else frame_rate
    times = np.linspace(0, total_frames / rate, total_frames)
    detected_times = [i / rate for i in range(len(ball_positions))]
    x_coords = [pos[0] for pos in ball_positions]
    y_coords = [pos[1] for pos in ball_positions]

    try:
        fx = interp1d(detected_times, x_coords, kind='linear', fill_value="extrapolate")
        fy = interp1d(detected_times, y_coords, kind='linear', fill_value="extrapolate")
        # Only resample within the observed span; no extrapolation past the
        # last detection even though fill_value would allow it.
        return [(fx(t), fy(t)) for t in times if t <= detected_times[-1]]
    except ValueError:
        # interp1d raises ValueError on degenerate inputs (e.g. duplicate
        # times); fall back to the raw detections instead of crashing.
        # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
        return ball_positions

def estimate_trajectory(ball_positions, frames):
    """Fit and extrapolate the ball's path from its detected positions.

    Args:
        ball_positions: list of [x, y] ball centers in pixel coordinates.
        frames: list of video frames (only frames[0].shape is used).

    Returns:
        (trajectory, pitch_point, impact_point, log_message); the first three
        are None when estimation fails.
    """
    if len(ball_positions) < 2:
        return None, None, None, "Error: Fewer than 2 ball detections for trajectory"
    frame_height = frames[0].shape[0]

    # Extract x, y coordinates and their (assumed) timestamps.
    x_coords = [pos[0] for pos in ball_positions]
    y_coords = [pos[1] for pos in ball_positions]
    times = np.arange(len(ball_positions)) / FRAME_RATE

    # Impact point: first detection low enough in the frame to be near the
    # stumps/batsman; fall back to the last detection.
    impact_idx = None
    for i, y in enumerate(y_coords):
        if y > frame_height * IMPACT_ZONE_Y:
            impact_idx = i
            break
    if impact_idx is None:
        impact_idx = len(ball_positions) - 1

    # NOTE(review): "pitch point" is simply the first detection, not the
    # actual bounce — TODO detect the bounce (local y maximum) instead.
    pitch_point = ball_positions[0]
    impact_point = ball_positions[impact_idx]

    # Use positions up to (and including) impact for the fit.
    x_coords = x_coords[:impact_idx + 1]
    y_coords = y_coords[:impact_idx + 1]
    times = times[:impact_idx + 1]

    # BUG FIX: kind='quadratic' needs at least 3 points, so the original
    # always failed when exactly 2 detections passed the guard above.
    y_kind = 'quadratic' if len(times) >= 3 else 'linear'
    try:
        fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
        fy = interp1d(times, y_coords, kind=y_kind, fill_value="extrapolate")
    except Exception as e:
        return None, None, None, f"Error in trajectory interpolation: {str(e)}"

    # Project the trajectory over the observed span plus 0.5 s of future.
    t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 10)
    x_full = fx(t_full)
    y_full = fy(t_full)
    trajectory = list(zip(x_full, y_full))

    return trajectory, pitch_point, impact_point, "Trajectory estimated successfully"

def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
    """Apply simplified LBW rules to the estimated ball path.

    Checks, in order: pitched in line, impact in line, trajectory hitting the
    stumps. Returns (verdict_string, trajectory, pitch_point, impact_point);
    the last three are None on the early error paths.
    """
    if not frames:
        return "Error: No frames processed", None, None, None
    if not trajectory or len(ball_positions) < 2:
        return "Not enough data (insufficient ball detections)", None, None, None

    height, width = frames[0].shape[:2]
    stumps_x = width / 2    # assumes stumps are horizontally centred — TODO confirm
    stumps_y = height * 0.9  # assumes stumps sit near the bottom of the frame
    # Pixel width of the stumps corridor; assumes ~3 m of pitch width is
    # visible across the frame — TODO confirm against camera setup.
    stumps_width_pixels = width * (STUMPS_WIDTH / 3.0)
    half_corridor = stumps_width_pixels / 2

    pitch_x, pitch_y = pitch_point
    impact_x, impact_y = impact_point

    def outside_line(x):
        # True when x lies outside the stumps corridor.
        return x < stumps_x - half_corridor or x > stumps_x + half_corridor

    # Rule 1: ball must pitch in line with the stumps.
    if outside_line(pitch_x):
        return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", trajectory, pitch_point, impact_point

    # Rule 2: impact on the batsman must be in line with the stumps.
    if outside_line(impact_x):
        return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point

    # Rule 3: projected path must pass through the stump zone.
    hits_stumps = any(
        abs(x - stumps_x) < half_corridor and abs(y - stumps_y) < height * 0.1
        for x, y in trajectory
    )
    if hits_stumps:
        return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
    return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point

def generate_slow_motion(frames, trajectory, pitch_point, impact_point, output_path):
    """Write an annotated slow-motion replay video.

    Draws the trajectory (blue dots), pitch point (red) and impact point
    (yellow) onto every frame, then writes each frame SLOW_MOTION_FACTOR
    times at the native frame rate to achieve the documented slowdown.

    Args:
        frames: list of BGR frames (mutated in place by the drawing calls).
        trajectory: iterable of (x, y) points, or None.
        pitch_point: (x, y) or None.
        impact_point: (x, y) or None.
        output_path: target .mp4 path.

    Returns:
        output_path on success, None when there are no frames.
    """
    if not frames:
        return None
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    height, width = frames[0].shape[:2]
    # BUG FIX: the original both divided the fps by SLOW_MOTION_FACTOR and
    # duplicated each frame SLOW_MOTION_FACTOR times, compounding to a 36x
    # slowdown instead of the intended 6x. Write at the native rate and let
    # frame duplication provide the slow motion.
    out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE, (width, height))

    for frame in frames:
        # Draw full trajectory (blue dots)
        if trajectory:
            for x, y in trajectory:
                cv2.circle(frame, (int(x), int(y)), 5, (255, 0, 0), -1)

        # Draw pitch point (red circle with label)
        if pitch_point:
            x, y = pitch_point
            cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
            cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
                       cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

        # Draw impact point (yellow circle with label)
        if impact_point:
            x, y = impact_point
            cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)
            cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
                       cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)

        # Duplicate each frame to stretch playback time.
        for _ in range(SLOW_MOTION_FACTOR):
            out.write(frame)
    out.release()
    return output_path

def drs_review(video):
    """End-to-end DRS pipeline for the Gradio UI.

    Runs detection, trajectory estimation, the LBW decision, and replay
    rendering. Returns (decision-plus-debug text, replay video path or None).
    """
    frames, positions, detect_log = process_video(video)
    if not frames:
        return f"Error: Failed to process video\nDebug Log:\n{detect_log}", None

    trajectory, pitch_pt, impact_pt, traj_log = estimate_trajectory(positions, frames)
    decision, trajectory, pitch_pt, impact_pt = lbw_decision(
        positions, trajectory, frames, pitch_pt, impact_pt
    )

    # Unique filename per request so concurrent reviews don't collide.
    replay_path = generate_slow_motion(
        frames, trajectory, pitch_pt, impact_pt, f"output_{uuid.uuid4()}.mp4"
    )

    return (
        f"DRS Decision: {decision}\nDebug Log:\n{detect_log}\n{traj_log}",
        replay_path,
    )

# Gradio interface: one uploaded video in, (decision text, replay video) out.
iface = gr.Interface(
    fn=drs_review,
    inputs=gr.Video(label="Upload Video Clip"),
    outputs=[
        gr.Textbox(label="DRS Decision and Debug Log"),
        gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue), Pitch Point (Red), Impact Point (Yellow)")
    ],
    title="AI-Powered DRS for LBW in Local Cricket",
    description="Upload a video clip of a cricket delivery to get an LBW decision and very slow-motion replay showing ball detection (green boxes), trajectory (blue dots), pitch point (red circle), and impact point (yellow circle)."
)

# Launch the web UI only when executed as a script (not on import).
if __name__ == "__main__":
    iface.launch()