File size: 8,535 Bytes
41c03cf
 
c3429f6
462fddf
c3429f6
462fddf
 
ba9faee
c3429f6
7e29aa0
 
c3429f6
7e29aa0
 
 
462fddf
c9d4715
7e29aa0
c9d4715
c3429f6
 
462fddf
 
42d2b87
462fddf
 
 
 
c9d4715
7e29aa0
c9d4715
42d2b87
 
 
a653421
c9d4715
 
462fddf
7e29aa0
c9d4715
462fddf
c9d4715
462fddf
7e29aa0
462fddf
 
7e29aa0
c9d4715
 
 
 
462fddf
7e29aa0
462fddf
42d2b87
c3429f6
462fddf
 
 
 
c9d4715
462fddf
c3429f6
462fddf
 
c9d4715
462fddf
 
 
c9d4715
c3429f6
462fddf
 
 
c9d4715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
462fddf
 
 
 
c9d4715
462fddf
c9d4715
462fddf
 
7e29aa0
 
c9d4715
462fddf
c9d4715
 
 
462fddf
c9d4715
462fddf
c9d4715
 
462fddf
 
 
c9d4715
 
462fddf
 
 
 
 
c9d4715
462fddf
 
13e82ee
462fddf
c9d4715
 
 
 
 
 
 
 
 
462fddf
c9d4715
462fddf
 
c9d4715
 
 
462fddf
c9d4715
462fddf
 
c9d4715
7e29aa0
a653421
689fb64
c9d4715
462fddf
 
 
 
 
 
 
c9d4715
 
462fddf
c9d4715
 
 
462fddf
61be320
 
462fddf
 
 
 
 
c9d4715
462fddf
 
 
a295d73
c3429f6
462fddf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import cv2
import numpy as np
import torch
from ultralytics import YOLO
import gradio as gr
from scipy.interpolate import interp1d
import uuid
import os

# Load the trained YOLOv8n model
# NOTE: this runs at import time; the single instance is reused by
# process_video for every frame.
model = YOLO("best.pt")

# Constants
# Used by lbw_decision to derive the stump width in pixels from frame width.
STUMPS_WIDTH = 0.2286  # meters
# NOTE(review): not referenced anywhere in the visible code.
BALL_DIAMETER = 0.073  # meters
# Assumed, not read from the video: drives the 3-second frame cap, the
# trajectory time axis and the output video frame rate.
FRAME_RATE = 30  # Input video frame rate
# Output fps is divided by this factor AND each frame is written this many times.
SLOW_MOTION_FACTOR = 6
# Passed as `conf` to model.predict.
CONF_THRESHOLD = 0.3
# Frames are resized to a RESIZE_DIM x RESIZE_DIM square before inference.
RESIZE_DIM = 640  # Resize frames for faster processing

def process_video(video_path):
    """Detect the ball in the opening frames of a video clip.

    Reads at most ``FRAME_RATE * 3`` frames. Each frame is resized to
    ``RESIZE_DIM`` x ``RESIZE_DIM`` for YOLO inference; every class-0 box is
    scaled back to the original resolution and drawn on the frame in green,
    and the center of the most confident class-0 box is recorded.

    Args:
        video_path: Path of the video to read. ``None``, an empty string or
            a path that does not exist yields the "not found" result instead
            of raising.

    Returns:
        tuple: ``(frames, ball_positions, debug_log)`` where ``frames`` is
        the list of annotated frames at their original size,
        ``ball_positions`` is a list of ``[x, y]`` centers (at most one per
        frame; frames without a detection contribute nothing) and
        ``debug_log`` is a newline-joined string with one line per frame plus
        a summary. A missing file returns
        ``([], [], "Error: Video file not found")``.
    """
    # Gradio passes None when nothing was uploaded; os.path.exists(None)
    # would raise TypeError, so test for a falsy path first.
    if not video_path or not os.path.exists(video_path):
        return [], [], "Error: Video file not found"

    cap = cv2.VideoCapture(video_path)
    frames = []
    ball_positions = []
    debug_log = []
    frame_count = 0
    total_detections = 0
    max_frames = FRAME_RATE * 3  # Limit to 3 seconds of frames

    try:
        while cap.isOpened() and frame_count < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1

            # Resize frame for faster YOLO inference
            frame_resized = cv2.resize(frame, (RESIZE_DIM, RESIZE_DIM), interpolation=cv2.INTER_AREA)
            results = model.predict(frame_resized, conf=CONF_THRESHOLD, imgsz=RESIZE_DIM)

            # Factors that map inference-space coordinates back to the
            # original frame size.
            scale_x = frame.shape[1] / RESIZE_DIM
            scale_y = frame.shape[0] / RESIZE_DIM

            detections = 0
            best_center = None
            best_conf = float("-inf")
            for detection in results[0].boxes:
                if detection.cls == 0:  # Ball class
                    detections += 1
                    x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
                    x1, x2 = x1 * scale_x, x2 * scale_x
                    y1, y2 = y1 * scale_y, y2 * scale_y
                    cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)

                    # Downstream code treats ball_positions as one sample per
                    # time step, so keep only the most confident box of the
                    # frame instead of appending every detection.
                    confidence = float(detection.conf[0])
                    if confidence > best_conf:
                        best_conf = confidence
                        best_center = [(x1 + x2) / 2, (y1 + y2) / 2]

            if best_center is not None:
                ball_positions.append(best_center)
            total_detections += detections

            # The annotated frame itself is stored; no extra copy is needed.
            frames.append(frame)
            debug_log.append(f"Frame {frame_count}: {detections} ball detections")
    finally:
        # Release the capture even if resizing or inference raises.
        cap.release()

    if not ball_positions:
        debug_log.append("No balls detected in any frame")
    else:
        debug_log.append(f"Total ball detections: {total_detections}")

    return frames, ball_positions, "\n".join(debug_log)

def estimate_trajectory(ball_positions, frames):
    """Interpolate and extrapolate the ball path through the detections.

    The i-th detection is assigned the time ``i / FRAME_RATE``. ``x`` is
    modelled linearly and ``y`` quadratically; both are evaluated on an
    evenly spaced grid that runs 0.5 s past the last detection.

    Args:
        ball_positions: Sequence of ``[x, y]`` ball centers in detection
            order.
        frames: Processed frames; only ``len(frames)`` is used, to size the
            output grid (``len(frames) + 10`` samples).

    Returns:
        tuple: ``(trajectory, t_all, message)``. On success ``trajectory`` is
        a list of ``(x, y)`` pairs and ``t_all`` the matching array of times.
        On failure both are ``[]`` and ``message`` describes the error.
    """
    if len(ball_positions) < 2:
        return [], [], "Error: Fewer than 2 ball detections for trajectory"

    x_coords = [pos[0] for pos in ball_positions]
    y_coords = [pos[1] for pos in ball_positions]
    times = np.arange(len(ball_positions)) / FRAME_RATE

    # A quadratic fit needs at least three samples. With exactly two
    # detections (which the guard above allows) fall back to a straight line
    # instead of always ending up in the error path.
    y_kind = 'quadratic' if len(ball_positions) >= 3 else 'linear'

    try:
        fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
        fy = interp1d(times, y_coords, kind=y_kind, fill_value="extrapolate")
    except Exception as e:
        # Deliberate best-effort: report the problem to the UI rather than crash.
        return [], [], f"Error in trajectory interpolation: {str(e)}"

    # Interpolate for all frames and future projection
    t_all = np.linspace(0, times[-1] + 0.5, len(frames) + 10)
    trajectory = list(zip(fx(t_all), fy(t_all)))
    return trajectory, t_all, "Trajectory estimated successfully"

def detect_impact_point(ball_positions, frames):
    """Pick the detection at which the ball is taken to hit the batsman.

    Scans the interior detections in order and returns the first one where
    the direction of travel turns by more than 45 degrees, or where the ball
    is within 10% of the frame width of the assumed batsman position
    (horizontal center, 80% of the way down the frame). If none qualifies,
    the last detection is used.

    Args:
        ball_positions: Sequence of ``[x, y]`` ball centers in detection
            order.
        frames: Processed frames; only ``frames[0].shape`` is read, and only
            when there are at least three detections.

    Returns:
        tuple: ``(impact_point, impact_idx)``. With fewer than three
        detections this is the last detection and its index, or
        ``(None, -1)`` when there are none.
    """
    if len(ball_positions) < 3:
        return ball_positions[-1] if ball_positions else None, len(ball_positions) - 1

    # Assume batsman is near stumps (bottom center of frame)
    frame_height, frame_width = frames[0].shape[:2]
    batsman_x = frame_width / 2
    batsman_y = frame_height * 0.8  # Approximate batsman position
    proximity_radius = frame_width * 0.1

    # Look for sudden change in trajectory or proximity to batsman
    triples = zip(ball_positions, ball_positions[1:], ball_positions[2:])
    for i, (prev_pos, pos, next_pos) in enumerate(triples, start=1):
        x, y = pos
        dx1, dy1 = x - prev_pos[0], y - prev_pos[1]
        dx2, dy2 = next_pos[0] - x, next_pos[1] - y

        # arctan2 yields angles in [-pi, pi], so the raw difference of two
        # nearly parallel headings on either side of the +/-pi seam is close
        # to 2*pi. Fold it into [0, pi] to get the actual turn.
        angle_change = abs(np.arctan2(dy2, dx2) - np.arctan2(dy1, dx1))
        if angle_change > np.pi:
            angle_change = 2 * np.pi - angle_change

        dist_to_batsman = np.hypot(x - batsman_x, y - batsman_y)
        if angle_change > np.pi / 4 or dist_to_batsman < proximity_radius:
            return ball_positions[i], i

    return ball_positions[-1], len(ball_positions) - 1

def lbw_decision(ball_positions, trajectory, frames):
    """Apply a simplified LBW rule to the detections and projected path.

    The stumps are assumed to sit at the horizontal center of the frame, 90%
    of the way down. The first detection is taken as the pitching point and
    the impact point comes from ``detect_impact_point``. The verdict is
    "Not Out" if either point lies outside the stump line, "Out" if any
    trajectory sample falls inside the stump zone, and "Not Out" otherwise.

    Args:
        ball_positions: Sequence of ``[x, y]`` ball centers in detection
            order.
        trajectory: Iterable of ``(x, y)`` projected path samples; may be
            empty.
        frames: Processed frames; only ``frames[0].shape`` is read.

    Returns:
        tuple: ``(decision, trajectory, pitch_point, impact_point)``. When
        there are no frames or fewer than two detections, the last three
        items are ``None``.
    """
    if not frames:
        return "Error: No frames processed", None, None, None
    if len(ball_positions) < 2:
        return "Not enough data (insufficient ball detections)", None, None, None

    frame_height, frame_width = frames[0].shape[:2]
    stumps_x = frame_width / 2
    stumps_y = frame_height * 0.9
    # NOTE(review): the 3.0 presumably assumes the frame spans about 3 m
    # horizontally - TODO confirm against the camera setup.
    stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
    left_edge = stumps_x - stumps_width_pixels / 2
    right_edge = stumps_x + stumps_width_pixels / 2

    pitch_point = ball_positions[0]
    impact_point, _ = detect_impact_point(ball_positions, frames)

    # Check pitching point
    pitch_x, pitch_y = pitch_point
    if pitch_x < left_edge or pitch_x > right_edge:
        return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", trajectory, pitch_point, impact_point

    # Check impact point
    impact_x, impact_y = impact_point
    if impact_x < left_edge or impact_x > right_edge:
        return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point

    # Check trajectory hitting stumps
    hits_stumps = any(
        abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1
        for x, y in trajectory
    )
    if hits_stumps:
        return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
    return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point

def generate_slow_motion(frames, trajectory, pitch_point, impact_point, impact_idx, output_path):
    """Write an annotated slow-motion replay to ``output_path``.

    Every frame is written ``SLOW_MOTION_FACTOR`` times at
    ``FRAME_RATE / SLOW_MOTION_FACTOR`` fps using the ``mp4v`` codec. The
    annotations are drawn directly onto the frames in ``frames``.

    Args:
        frames: List of frames, all the size of ``frames[0]``.
        trajectory: Sequence of ``(x, y)`` path samples, or ``None`` when no
            trajectory is available.
        pitch_point: ``[x, y]`` of the pitching point or ``None``; drawn on
            the first half of the frames.
        impact_point: ``[x, y]`` of the impact point or ``None``; drawn on
            frames fewer than 5 indices away from ``impact_idx``.
        impact_idx: Frame index around which the impact marker is shown.
        output_path: Destination file path.

    Returns:
        ``output_path`` on success; ``None`` when there are no frames or the
        video writer could not be opened.
    """
    if not frames:
        return None

    # lbw_decision returns None for the trajectory on its early exits;
    # treat that as "nothing to draw" rather than crashing.
    trajectory = [] if trajectory is None else list(trajectory)

    height, width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (width, height))
    if not out.isOpened():
        # No file will be produced, so do not hand a dangling path to the caller.
        out.release()
        return None

    try:
        for i, frame in enumerate(frames):
            # Draw trajectory up to current frame: sample j is shown once
            # j / SLOW_MOTION_FACTOR <= i, i.e. the first
            # i * SLOW_MOTION_FACTOR + 1 samples.
            for x, y in trajectory[:i * SLOW_MOTION_FACTOR + 1]:
                cv2.circle(frame, (int(x), int(y)), 5, (255, 0, 0), -1)  # Blue dots

            # Draw pitch point in early frames
            if pitch_point is not None and i < len(frames) // 2:
                x, y = pitch_point
                cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)  # Red circle
                cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

            # Draw impact point around impact frame
            if impact_point is not None and abs(i - impact_idx) < 5:
                x, y = impact_point
                cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)  # Yellow circle
                cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)

            for _ in range(SLOW_MOTION_FACTOR):
                out.write(frame)
    finally:
        # Finalize the file even if drawing fails part-way through.
        out.release()
    return output_path

def drs_review(video):
    """Run the full review pipeline for one uploaded clip.

    Chains ``process_video``, ``estimate_trajectory``, ``lbw_decision``,
    ``detect_impact_point`` and ``generate_slow_motion``.

    Args:
        video: Path of the uploaded video, or ``None`` / empty when nothing
            was uploaded.

    Returns:
        tuple: ``(report, replay_path)``. ``report`` is the decision followed
        by the debug log; ``replay_path`` is the path of the generated
        ``output_<uuid>.mp4`` file, or ``None`` if the video could not be
        processed.
    """
    # Nothing uploaded: report it the same way as a missing file instead of
    # letting the path check fail on None.
    if not video:
        return "Error: Failed to process video\nDebug Log:\nError: Video file not found", None

    frames, ball_positions, debug_log = process_video(video)
    if not frames:
        return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None

    estimated, _, trajectory_log = estimate_trajectory(ball_positions, frames)
    decision, trajectory, pitch_point, impact_point = lbw_decision(ball_positions, estimated, frames)
    if trajectory is None:
        # lbw_decision returns None placeholders on its early exits; keep the
        # estimated trajectory so the replay renderer always gets a sequence.
        trajectory = estimated
    _, impact_idx = detect_impact_point(ball_positions, frames)

    output_path = f"output_{uuid.uuid4()}.mp4"
    slow_motion_path = generate_slow_motion(frames, trajectory, pitch_point, impact_point, impact_idx, output_path)

    debug_output = f"{debug_log}\n{trajectory_log}"
    return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}", slow_motion_path

# Gradio interface
# One uploaded video in; two outputs, in the order drs_review returns them:
# the decision/debug text, then the path of the slow-motion replay video.
iface = gr.Interface(
    fn=drs_review,
    inputs=gr.Video(label="Upload Video Clip"),
    outputs=[
        gr.Textbox(label="DRS Decision and Debug Log"),
        gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue), Pitch Point (Red), Impact Point (Yellow)")
    ],
    title="AI-Powered DRS for LBW in Local Cricket",
)

# Launch the interface only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()