File size: 10,069 Bytes
41c03cf
 
c3429f6
462fddf
c3429f6
462fddf
 
ba9faee
c3429f6
7e29aa0
 
c3429f6
58199b3
 
 
1213ff3
 
7e44cd3
1213ff3
 
 
885c61f
c3429f6
 
462fddf
7e44cd3
42d2b87
462fddf
 
db7efde
462fddf
7e29aa0
58199b3
 
42d2b87
 
 
a653421
58199b3
 
db7efde
 
 
 
 
 
7e29aa0
db7efde
42d2b87
c3429f6
462fddf
db7efde
462fddf
db7efde
c9d4715
7e44cd3
0a9dc78
1213ff3
462fddf
db7efde
58199b3
1213ff3
58199b3
462fddf
 
1213ff3
7e44cd3
db7efde
1213ff3
d41a272
1213ff3
 
d41a272
885c61f
 
1213ff3
 
58199b3
1213ff3
 
 
58199b3
 
 
1213ff3
58199b3
1213ff3
0a9dc78
885c61f
58199b3
 
 
 
c3429f6
462fddf
 
 
db7efde
462fddf
1213ff3
 
 
 
58199b3
 
 
1213ff3
0a9dc78
1213ff3
 
 
 
0a9dc78
1213ff3
0a9dc78
 
1213ff3
db7efde
0a9dc78
 
 
1213ff3
0a9dc78
 
 
 
 
1213ff3
0a9dc78
1213ff3
0a9dc78
1213ff3
0a9dc78
1213ff3
0a9dc78
 
1213ff3
0a9dc78
1213ff3
 
462fddf
1213ff3
462fddf
 
1213ff3
885c61f
 
 
 
c9d4715
1213ff3
 
7e44cd3
1213ff3
 
d41a272
7e44cd3
885c61f
 
 
 
 
 
 
 
 
1213ff3
7e44cd3
1213ff3
 
 
c9d4715
1213ff3
 
462fddf
d41a272
 
 
 
1213ff3
 
462fddf
d41a272
 
 
 
7e29aa0
a653421
689fb64
c9d4715
462fddf
 
7e44cd3
462fddf
1213ff3
 
 
c9d4715
462fddf
1213ff3
c9d4715
1213ff3
 
61be320
 
462fddf
 
 
 
1213ff3
885c61f
462fddf
 
885c61f
462fddf
a295d73
c3429f6
1213ff3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import cv2
import numpy as np
import torch
from ultralytics import YOLO
import gradio as gr
from scipy.interpolate import interp1d
import uuid
import os

# Detector weights (YOLOv8n, trained for cricket-ball detection), loaded once at import time.
model = YOLO("best.pt")

# --- Physical dimensions (metres) --------------------------------------------
STUMPS_WIDTH = 0.2286  # overall width of the stumps
STUMPS_HEIGHT = 0.711  # height of the stumps
BALL_DIAMETER = 0.073  # approximate cricket-ball diameter (not referenced by the functions below)

# --- Video / replay settings --------------------------------------------------
FRAME_RATE = 20  # assumed frame rate of the input clip (not read from the file)
SLOW_MOTION_FACTOR = 3  # replay writes each frame this many times at FRAME_RATE / this fps

# --- Detection and event heuristics -------------------------------------------
CONF_THRESHOLD = 0.25  # minimum detector confidence passed to model.predict
PITCH_ZONE_Y = 0.75  # fraction of frame height; detections lower than this count as pitched
IMPACT_ZONE_Y = 0.85  # fraction of frame height; detections lower than this count as impact
IMPACT_DELTA_Y = 50  # vertical jump (pixels) between consecutive detections that flags impact

def process_video(video_path):
    """Run the ball detector over every frame of a video.

    Each frame is passed through ``model``. Detections of class 0 (cricket
    ball) are kept only when a frame contains exactly one of them, so that
    ambiguous frames do not pollute the trajectory.

    Args:
        video_path: Path to the input video file. ``None`` or an empty value
            (e.g. nothing uploaded) is treated as a missing file.

    Returns:
        A 4-tuple ``(frames, ball_positions, detection_frames, debug_log)``:
        ``frames`` holds every decoded frame, with a green box drawn around
        each accepted detection; ``ball_positions`` holds ``[x, y]`` box
        centres in pixels; ``detection_frames`` holds the 0-based frame index
        of each position; ``debug_log`` is a newline-joined text log. For a
        missing file the first three elements are empty lists and the log is
        an error message.
    """
    # Reject None/empty paths up front: os.path.exists(None) raises TypeError
    # instead of returning False.
    if not video_path or not os.path.exists(video_path):
        return [], [], [], "Error: Video file not found"

    frames = []
    ball_positions = []
    detection_frames = []  # 0-based indices of frames with exactly one detection
    debug_log = []

    cap = cv2.VideoCapture(video_path)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_idx = len(frames)  # 0-based index of the current frame
            results = model.predict(frame, conf=CONF_THRESHOLD)
            detections = [det for det in results[0].boxes if det.cls == 0]  # Class 0 is cricketBall
            if len(detections) == 1:  # Only consider frames with exactly one detection
                x1, y1, x2, y2 = detections[0].xyxy[0].cpu().numpy()
                ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
                detection_frames.append(frame_idx)
                cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
            # Store the (possibly annotated) frame itself; a separate copy is unnecessary.
            frames.append(frame)
            debug_log.append(f"Frame {frame_idx + 1}: {len(detections)} ball detections")
    finally:
        # Release the capture even if the detector raises part-way through.
        cap.release()

    if not ball_positions:
        debug_log.append("No valid single-ball detections in any frame")
    else:
        debug_log.append(f"Total valid single-ball detections: {len(ball_positions)}")

    return frames, ball_positions, detection_frames, "\n".join(debug_log)

def estimate_trajectory(ball_positions, detection_frames, frames):
    """Find the pitch and impact points and fit a trajectory through the detections.

    Args:
        ball_positions: ``[x, y]`` pixel centres of the single-ball detections,
            in frame order.
        detection_frames: 0-based frame index of each entry in ``ball_positions``.
        frames: The decoded video frames; only ``frames[0]`` is used (for its height).

    Returns:
        A 7-tuple ``(full_trajectory, vis_trajectory, pitch_point, pitch_frame,
        impact_point, impact_frame, debug_log)``:

        * ``full_trajectory``: ``(x, y)`` samples of the fitted path from the
          first detection to 0.5 s past the impact detection (extrapolated).
        * ``vis_trajectory``: the detected ``(x, y)`` positions up to and
          including the impact detection.
        * ``pitch_point`` / ``pitch_frame`` and ``impact_point`` /
          ``impact_frame``: the chosen positions and their 0-based frame indices.
        * ``debug_log``: a human-readable summary.

        On failure the first six elements are ``None`` and ``debug_log`` holds
        the error message.
    """
    if len(ball_positions) < 2:
        return None, None, None, None, None, None, "Error: Fewer than 2 valid single-ball detections for trajectory"
    frame_height = frames[0].shape[0]

    x_coords = [pos[0] for pos in ball_positions]
    y_coords = [pos[1] for pos in ball_positions]
    # Time axis in seconds, derived from each detection's frame index.
    times = np.array(detection_frames) / FRAME_RATE

    # Pitch point: first detection below PITCH_ZONE_Y of the frame height
    # (image y grows downward); falls back to the first detection.
    pitch_idx = next(
        (i for i, y in enumerate(y_coords) if y > frame_height * PITCH_ZONE_Y),
        0,
    )
    pitch_point = ball_positions[pitch_idx]
    pitch_frame = detection_frames[pitch_idx]

    # Impact point: first detection after the first that is either below
    # IMPACT_ZONE_Y or jumps by more than IMPACT_DELTA_Y pixels vertically;
    # falls back to the last detection.
    impact_idx = next(
        (
            i
            for i in range(1, len(y_coords))
            if y_coords[i] > frame_height * IMPACT_ZONE_Y
            or abs(y_coords[i] - y_coords[i - 1]) > IMPACT_DELTA_Y
        ),
        len(ball_positions) - 1,
    )
    impact_point = ball_positions[impact_idx]
    impact_frame = detection_frames[impact_idx]

    # Fit only the detections up to and including the impact.
    x_coords = x_coords[:impact_idx + 1]
    y_coords = y_coords[:impact_idx + 1]
    times = times[:impact_idx + 1]

    # A quadratic spline needs at least 3 samples. With only 2 (two detections,
    # or impact flagged on the second one) fall back to a linear fit instead of
    # failing.
    y_kind = 'quadratic' if len(times) >= 3 else 'linear'
    try:
        fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
        fy = interp1d(times, y_coords, kind=y_kind, fill_value="extrapolate")
    except Exception as e:
        return None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"

    # Trajectory for visualization (detected frames only)
    vis_trajectory = list(zip(x_coords, y_coords))

    # Full trajectory for LBW: extends 0.5 s beyond the impact detection.
    t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 10)
    full_trajectory = list(zip(fx(t_full), fy(t_full)))

    debug_log = (f"Trajectory estimated successfully\n"
                 f"Pitch point at frame {pitch_frame + 1}: ({pitch_point[0]:.1f}, {pitch_point[1]:.1f})\n"
                 f"Impact point at frame {impact_frame + 1}: ({impact_point[0]:.1f}, {impact_point[1]:.1f})")
    return full_trajectory, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, debug_log

def lbw_decision(ball_positions, full_trajectory, frames, pitch_point, impact_point):
    """Return an LBW verdict from the pitch point, impact point and projected path.

    The stumps are modelled as a band centred horizontally in the frame, with
    a reference line at 90% of the frame height. The band's pixel width is
    ``frame_width * STUMPS_WIDTH / 3.0``. NOTE(review): this presumably assumes
    roughly 3 m of ground spans the frame width; confirm the calibration.

    Checks, in order:
      1. pitch point horizontally outside the band -> Not Out;
      2. impact point horizontally outside the band -> Not Out;
      3. any ``full_trajectory`` sample within half the band width of the
         centre and within 10% of frame height of the reference line -> Out;
         otherwise Not Out.

    Returns:
        ``(decision, full_trajectory, pitch_point, impact_point)``. When there
        are no frames or too little data, the last three elements are ``None``.
    """
    if not frames:
        return "Error: No frames processed", None, None, None
    if not full_trajectory or len(ball_positions) < 2:
        return "Not enough data (insufficient valid single-ball detections)", None, None, None

    height, width = frames[0].shape[:2]
    centre_x = width / 2
    base_y = height * 0.9
    band_width = width * (STUMPS_WIDTH / 3.0)
    half_band = band_width / 2
    left_edge = centre_x - half_band
    right_edge = centre_x + half_band

    def outside_line(x):
        # Strict comparisons: a point exactly on an edge still counts as in line.
        return x < left_edge or x > right_edge

    pitch_x, pitch_y = pitch_point
    impact_x, impact_y = impact_point
    tail = (full_trajectory, pitch_point, impact_point)

    if outside_line(pitch_x):
        return (f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", *tail)
    if outside_line(impact_x):
        return (f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", *tail)

    where = (f"Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, "
             f"Impact at x: {impact_x:.1f}, y: {impact_y:.1f}")
    vertical_tolerance = height * 0.1
    # NOTE(review): every sample of full_trajectory is tested, including the
    # part of the path before impact; confirm that is intended.
    hits_stumps = any(
        abs(x - centre_x) < half_band and abs(y - base_y) < vertical_tolerance
        for x, y in full_trajectory
    )
    if hits_stumps:
        return (f"Out (Ball hits stumps, {where})", *tail)
    return (f"Not Out (Missing stumps, {where})", *tail)

def generate_slow_motion(frames, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, detection_frames, output_path):
    """Render an annotated slow-motion replay of ``frames`` to ``output_path``.

    The frames are annotated in place before being written:

    * stumps: three white vertical lines centred horizontally at 90% of frame height;
    * trajectory: blue polyline of the detected path so far, drawn on detection frames;
    * pitch point: red dot on ``pitch_frame``;
    * impact point: yellow dot on ``impact_frame``.

    Each frame is written SLOW_MOTION_FACTOR times to an ``mp4v`` video at
    FRAME_RATE / SLOW_MOTION_FACTOR fps.

    Args:
        frames: Decoded video frames (modified in place).
        vis_trajectory: Detected ``(x, y)`` positions, or ``None`` when no
            trajectory is available (no line is drawn then).
        pitch_point, impact_point: ``(x, y)`` positions, or ``None``.
        pitch_frame, impact_frame: 0-based frame indices, or ``None``.
        detection_frames: 0-based frame index of each trajectory point.
        output_path: Destination video path.

    Returns:
        ``output_path`` on success; ``None`` if there are no frames or the
        video writer could not be opened.
    """
    if not frames:
        return None
    frame_height, frame_width = frames[0].shape[:2]
    stumps_x = frame_width / 2
    stumps_y = frame_height * 0.9
    stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
    stumps_height_pixels = frame_height * (STUMPS_HEIGHT / 3.0)

    # Stump geometry is identical for every frame; compute the line endpoints once.
    stump_lines = [
        ((int(x), int(stumps_y)), (int(x), int(stumps_y - stumps_height_pixels)))
        for x in (stumps_x - stumps_width_pixels / 2, stumps_x, stumps_x + stumps_width_pixels / 2)
    ]

    # vis_trajectory is None when trajectory estimation failed; np.array(None, dtype=int32)
    # would raise, so treat it as an empty trajectory.
    points = [] if vis_trajectory is None else vis_trajectory
    trajectory_points = np.array(points, dtype=np.int32).reshape((-1, 1, 2))

    # frame index -> number of trajectory points to draw on that frame
    # (first occurrence wins, matching list.index semantics; O(1) lookup per frame).
    points_upto = {}
    for n, det_frame in enumerate(detection_frames or []):
        points_upto.setdefault(det_frame, n + 1)

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frame_width, frame_height))
    try:
        if not out.isOpened():
            return None
        for i, frame in enumerate(frames):
            # Draw stumps (three white vertical lines)
            for top, bottom in stump_lines:
                cv2.line(frame, top, bottom, (255, 255, 255), 2)

            # Draw trajectory (blue line) only for detected frames
            n_points = points_upto.get(i)
            if n_points is not None and trajectory_points.size > 0 and n_points <= len(trajectory_points):
                cv2.polylines(frame, [trajectory_points[:n_points]], False, (255, 0, 0), 2)

            # Draw pitch point (red circle) only in pitch frame
            if pitch_point is not None and i == pitch_frame:
                x, y = pitch_point
                cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
                cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

            # Draw impact point (yellow circle) only in impact frame
            if impact_point is not None and i == impact_frame:
                x, y = impact_point
                cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)
                cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)

            for _ in range(SLOW_MOTION_FACTOR):
                out.write(frame)
    finally:
        # Always finalize/close the output file, even on error.
        out.release()
    return output_path

def drs_review(video):
    """Run the full DRS pipeline on one clip (the Gradio callback).

    Args:
        video: Filesystem path of the uploaded clip; may be ``None``/empty
            when nothing was uploaded, which yields an error report.

    Returns:
        ``(report, replay_path)``. ``report`` is the decision followed by the
        debug log. ``replay_path`` is the rendered replay file, or ``None``
        if the video could not be processed or rendered.
    """
    if not video:
        return "Error: Failed to process video\nDebug Log:\nError: No video uploaded", None
    frames, ball_positions, detection_frames, debug_log = process_video(video)
    if not frames:
        return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None
    (full_trajectory, vis_trajectory, pitch_point, pitch_frame,
     impact_point, impact_frame, trajectory_log) = estimate_trajectory(ball_positions, detection_frames, frames)
    decision, full_trajectory, pitch_point, impact_point = lbw_decision(
        ball_positions, full_trajectory, frames, pitch_point, impact_point)

    # Relative path: the replay lands in the current working directory and is
    # never deleted. NOTE(review): consider a temp dir and cleanup.
    output_path = f"output_{uuid.uuid4()}.mp4"
    # estimate_trajectory returns vis_trajectory=None on failure; pass an empty
    # trajectory so the replay is still rendered instead of crashing.
    slow_motion_path = generate_slow_motion(
        frames,
        vis_trajectory if vis_trajectory is not None else [],
        pitch_point, pitch_frame, impact_point, impact_frame,
        detection_frames, output_path)

    debug_output = f"{debug_log}\n{trajectory_log}"
    return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}", slow_motion_path

def _build_interface():
    """Assemble the Gradio UI around drs_review: one video in; report text and replay video out."""
    # Components are created in the same order as before: input first, then outputs.
    upload = gr.Video(label="Upload Video Clip")
    report_box = gr.Textbox(label="DRS Decision and Debug Log")
    replay = gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow), Stumps (White)")
    return gr.Interface(
        fn=drs_review,
        inputs=upload,
        outputs=[report_box, replay],
        title="AI-Powered DRS for LBW in Local Cricket",
        description="Upload a video clip of a cricket delivery to get an LBW decision and slow-motion replay showing ball detection (green boxes), trajectory (blue line), pitch point (red circle), impact point (yellow circle), and stumps (white lines).",
    )


# Gradio interface, built at import time.
iface = _build_interface()

if __name__ == "__main__":
    iface.launch()