# File size: 8,479 Bytes
# 41c03cf c3429f6 462fddf c3429f6 462fddf ba9faee c3429f6 7e29aa0 c3429f6 58199b3 462fddf 58199b3 0a9dc78 58199b3 c3429f6 462fddf 42d2b87 462fddf 7e29aa0 58199b3 42d2b87 a653421 58199b3 462fddf 58199b3 462fddf 58199b3 462fddf 7e29aa0 462fddf 42d2b87 c3429f6 462fddf c9d4715 0a9dc78 462fddf c3429f6 0a9dc78 462fddf 58199b3 0a9dc78 58199b3 462fddf c9d4715 58199b3 0a9dc78 58199b3 0a9dc78 58199b3 c3429f6 462fddf 58199b3 462fddf 58199b3 0a9dc78 462fddf 58199b3 462fddf 13e82ee 462fddf c9d4715 58199b3 c9d4715 58199b3 462fddf 58199b3 462fddf c9d4715 58199b3 462fddf 58199b3 462fddf c9d4715 7e29aa0 a653421 689fb64 c9d4715 462fddf 58199b3 c9d4715 462fddf 58199b3 c9d4715 462fddf 61be320 462fddf c9d4715 462fddf 58199b3 462fddf a295d73 c3429f6 462fddf |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
import cv2
import numpy as np
import torch
from ultralytics import YOLO
import gradio as gr
from scipy.interpolate import interp1d
import uuid
import os
# Load the trained YOLOv8n model once, at import time; process_video calls
# model.predict on this shared instance for every frame.
# NOTE(review): "best.pt" is a relative path — presumably resolved against the
# process working directory; confirm for the deployment environment.
model = YOLO("best.pt")
# Constants for LBW decision and video processing
# STUMPS_WIDTH is converted to pixels in lbw_decision (divided by 3.0 and
# scaled by the frame width) to define the "in line with the stumps" band.
STUMPS_WIDTH = 0.2286 # meters (width of stumps)
# BALL_DIAMETER is not referenced anywhere else in the visible source.
BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
# FRAME_RATE is an assumed value: it is never read from the video itself. It
# converts frame/detection indices to seconds and sets the replay writer's fps.
FRAME_RATE = 30 # Input video frame rate
# SLOW_MOTION_FACTOR is consumed by generate_slow_motion when writing the replay.
SLOW_MOTION_FACTOR = 6 # For very slow motion (6x slower)
# CONF_THRESHOLD is passed as `conf` to model.predict in process_video.
CONF_THRESHOLD = 0.25 # Lowered confidence threshold to improve detection
# In estimate_trajectory, the first position whose y exceeds
# frame_height * IMPACT_ZONE_Y is taken as the impact point.
IMPACT_ZONE_Y = 0.85 # Fraction of frame height where impact is likely (near stumps)
def process_video(video_path):
    """Run ball detection on every frame of a video clip.

    Args:
        video_path: Path of the clip to analyse. ``None``, an empty string or
            a path that does not exist is reported as an error rather than
            raising.

    Returns:
        A ``(frames, ball_positions, debug_log)`` tuple.

        * ``frames`` -- the frames read from the clip, each with a green
          rectangle drawn around every class-0 detection.
        * ``ball_positions`` -- box centres ``[x, y]`` of all class-0
          detections in detection order, passed through
          ``interpolate_missing_positions`` when at least one was found.
          Positions carry no frame index, and a frame with several
          detections contributes several entries.
        * ``debug_log`` -- newline-joined per-frame detection counts plus a
          summary line, or an error message when the file is missing.
    """
    # The UI hands over None when no clip was uploaded; os.path.exists(None)
    # would raise TypeError, so treat "no path" like "missing file".
    if not video_path or not os.path.exists(video_path):
        return [], [], "Error: Video file not found"

    cap = cv2.VideoCapture(video_path)
    frames = []
    ball_positions = []
    debug_log = []
    frame_count = 0
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            results = model.predict(frame, conf=CONF_THRESHOLD)
            detections = 0
            for detection in results[0].boxes:
                if detection.cls == 0:  # Assuming class 0 is the ball
                    detections += 1
                    x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
                    ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
                    cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
            # Store the annotated frame directly; a pristine copy was never
            # kept, so copying first would only waste memory.
            frames.append(frame)
            debug_log.append(f"Frame {frame_count}: {detections} ball detections")
    finally:
        # Release the capture even if prediction or drawing raises.
        cap.release()

    if not ball_positions:
        debug_log.append("No balls detected in any frame")
    else:
        debug_log.append(f"Total ball detections: {len(ball_positions)}")
        # Interpolate missing detections
        ball_positions = interpolate_missing_positions(ball_positions, frame_count)
    return frames, ball_positions, "\n".join(debug_log)
def interpolate_missing_positions(ball_positions, total_frames):
    """Resample ball positions onto the video's frame timestamps.

    The positions carry no frame index, so detection ``i`` is assumed to
    belong to time ``i / FRAME_RATE``.
    NOTE(review): because of that assumption this cannot fill real gaps
    between detections; doing so would need per-detection frame numbers from
    the caller.

    Args:
        ball_positions: Sequence of ``(x, y)`` pairs in detection order.
        total_frames: Number of frames in the clip; at most this many
            positions are returned.

    Returns:
        A list of ``(x, y)`` float tuples sampled at frame times up to the
        last detection time. The input is returned unchanged when it has
        fewer than two entries or when interpolation fails.
    """
    if len(ball_positions) < 2:
        return ball_positions

    detected_times = np.arange(len(ball_positions)) / FRAME_RATE
    # Frame i occurs at i / FRAME_RATE. Building the grid with np.linspace and
    # its default endpoint=True stretches the spacing to T / (FRAME_RATE * (T - 1)),
    # which shifts every sample and pushes the last one past the final
    # detection so that it gets dropped.
    frame_times = np.arange(total_frames) / FRAME_RATE
    sample_times = frame_times[frame_times <= detected_times[-1]]

    try:
        x_coords = np.asarray([pos[0] for pos in ball_positions], dtype=float)
        y_coords = np.asarray([pos[1] for pos in ball_positions], dtype=float)
        fx = interp1d(detected_times, x_coords, kind='linear', fill_value="extrapolate")
        fy = interp1d(detected_times, y_coords, kind='linear', fill_value="extrapolate")
        xs = fx(sample_times)
        ys = fy(sample_times)
    except (ValueError, TypeError, IndexError):
        # Best effort: malformed positions fall back to the raw detections.
        return ball_positions

    # Plain floats rather than 0-d arrays, so callers can format and compare
    # the coordinates without surprises.
    return [(float(x), float(y)) for x, y in zip(xs, ys)]
def estimate_trajectory(ball_positions, frames):
    """Fit a ball trajectory and pick the pitch and impact points.

    Args:
        ball_positions: Sequence of ``(x, y)`` ball centres; entry ``i`` is
            treated as occurring at ``i / FRAME_RATE`` seconds.
        frames: Sequence of video frames; only the height of the first one
            is used.

    Returns:
        A ``(trajectory, pitch_point, impact_point, message)`` tuple. On
        failure the first three items are ``None`` and ``message`` explains
        why. On success ``trajectory`` is a list of ``(x, y)`` samples
        covering the tracked path plus 0.5 s of extrapolation.

        ``pitch_point`` is simply the first position and ``impact_point`` the
        first one whose y exceeds ``frame_height * IMPACT_ZONE_Y`` (falling
        back to the last position).
        NOTE(review): the first tracked position is not necessarily where the
        ball actually bounced — confirm this approximation is acceptable.
    """
    if len(ball_positions) < 2:
        return None, None, None, "Error: Fewer than 2 ball detections for trajectory"
    # Without a frame there is no height to derive the impact zone from.
    if len(frames) == 0:
        return None, None, None, "Error: No frames available for trajectory"

    frame_height = frames[0].shape[0]
    # Extract x, y coordinates
    x_coords = [pos[0] for pos in ball_positions]
    y_coords = [pos[1] for pos in ball_positions]
    times = np.arange(len(ball_positions)) / FRAME_RATE

    # Find impact point (closest to batsman, near stumps)
    impact_threshold = frame_height * IMPACT_ZONE_Y
    impact_idx = next(
        (i for i, y in enumerate(y_coords) if y > impact_threshold),
        len(ball_positions) - 1,  # Fallback to last detection
    )
    pitch_point = ball_positions[0]
    impact_point = ball_positions[impact_idx]

    # Use positions up to impact for interpolation
    x_coords = x_coords[:impact_idx + 1]
    y_coords = y_coords[:impact_idx + 1]
    times = times[:impact_idx + 1]

    # A quadratic spline needs at least three samples; with only two, fall
    # back to a straight line instead of failing outright.
    y_kind = 'quadratic' if len(times) >= 3 else 'linear'
    try:
        fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
        fy = interp1d(times, y_coords, kind=y_kind, fill_value="extrapolate")
    except (ValueError, TypeError) as e:
        # E.g. a single sample left because the very first position already
        # lies in the impact zone.
        return None, None, None, f"Error in trajectory interpolation: {e}"

    # Project trajectory (detected + future)
    t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 10)
    x_full = fx(t_full)
    y_full = fy(t_full)
    trajectory = list(zip(x_full, y_full))
    return trajectory, pitch_point, impact_point, "Trajectory estimated successfully"
def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
    """Turn a fitted trajectory into an LBW verdict.

    The stumps are modelled as a vertical band centred horizontally in the
    frame, at 90% of the frame height, whose width is
    ``frame_width * (STUMPS_WIDTH / 3.0)`` pixels.

    Args:
        ball_positions: Tracked ball centres; at least two are required.
        trajectory: Iterable of ``(x, y)`` trajectory samples, or a falsy
            value when no trajectory could be fitted.
        frames: Video frames; the first one supplies the frame size.
        pitch_point: ``(x, y)`` of the pitching point.
        impact_point: ``(x, y)`` of the impact point.

    Returns:
        ``(decision_text, trajectory, pitch_point, impact_point)``. When there
        are no frames or not enough data, the last three items are ``None``.
    """
    if not frames:
        return "Error: No frames processed", None, None, None
    if not (trajectory and len(ball_positions) >= 2):
        return "Not enough data (insufficient ball detections)", None, None, None

    height, width = frames[0].shape[:2]
    centre_x = width / 2
    centre_y = height * 0.9
    band_px = width * (STUMPS_WIDTH / 3.0)
    left_edge = centre_x - band_px / 2
    right_edge = centre_x + band_px / 2

    pitch_x, pitch_y = pitch_point
    impact_x, impact_y = impact_point
    overlay = (trajectory, pitch_point, impact_point)

    # Pitching outside the line of the stumps rules LBW out immediately.
    if pitch_x < left_edge or pitch_x > right_edge:
        verdict = f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})"
        return (verdict, *overlay)

    # So does an impact outside the line.
    if impact_x < left_edge or impact_x > right_edge:
        verdict = f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})"
        return (verdict, *overlay)

    # Out if any trajectory sample falls inside the stumps window.
    half_band = band_px / 2
    y_tolerance = height * 0.1
    hits_stumps = any(
        abs(tx - centre_x) < half_band and abs(ty - centre_y) < y_tolerance
        for tx, ty in trajectory
    )
    if hits_stumps:
        verdict = f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})"
    else:
        verdict = f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})"
    return (verdict, *overlay)
def _to_pixel(point):
    """Return ``point`` as an ``(int, int)`` pixel pair, or None if it cannot be drawn."""
    try:
        x, y = (float(c) for c in point)
    except (TypeError, ValueError):
        return None
    # Extrapolated trajectories can yield NaN/inf; int() would raise on those.
    if not (np.isfinite(x) and np.isfinite(y)):
        return None
    # Coordinates this large are far outside any frame and only risk
    # overflowing the drawing calls.
    if abs(x) > 1e6 or abs(y) > 1e6:
        return None
    return int(x), int(y)


def generate_slow_motion(frames, trajectory, pitch_point, impact_point, output_path):
    """Write a slowed-down replay with trajectory and key points drawn on it.

    Every frame gets the trajectory as blue dots, the pitch point as a
    labelled red circle and the impact point as a labelled yellow circle.
    Overlays are drawn on a copy, so the caller's frames are not modified.

    The slow-down is applied exactly once: each frame is written a single
    time at ``FRAME_RATE / SLOW_MOTION_FACTOR`` fps. Lowering the fps and also
    repeating every frame ``SLOW_MOTION_FACTOR`` times would slow playback by
    the square of the factor.

    Args:
        frames: Sequence of equally sized frames.
        trajectory: Iterable of ``(x, y)`` points, or None.
        pitch_point: ``(x, y)`` pair, or None.
        impact_point: ``(x, y)`` pair, or None.
        output_path: Destination file for the ``mp4v``-encoded video.

    Returns:
        ``output_path`` on success; ``None`` when there are no frames or the
        video writer could not be opened.
    """
    if not frames:
        return None

    height, width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (width, height))

    # The overlay geometry is identical for every frame, so convert it once.
    trajectory_px = [px for px in map(_to_pixel, trajectory or ()) if px is not None]
    pitch_px = _to_pixel(pitch_point)
    impact_px = _to_pixel(impact_point)

    try:
        if not out.isOpened():
            return None
        for frame in frames:
            canvas = frame.copy()
            # Draw full trajectory (blue dots)
            for px in trajectory_px:
                cv2.circle(canvas, px, 5, (255, 0, 0), -1)
            # Draw pitch point (red circle with label)
            if pitch_px is not None:
                x, y = pitch_px
                cv2.circle(canvas, (x, y), 8, (0, 0, 255), -1)
                cv2.putText(canvas, "Pitch Point", (x + 10, y - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
            # Draw impact point (yellow circle with label)
            if impact_px is not None:
                x, y = impact_px
                cv2.circle(canvas, (x, y), 8, (0, 255, 255), -1)
                cv2.putText(canvas, "Impact Point", (x + 10, y + 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
            out.write(canvas)
    finally:
        # Finalise the file even if drawing or writing fails part-way.
        out.release()
    return output_path
def drs_review(video):
    """Run the full review pipeline on one uploaded clip.

    Chains ``process_video``, ``estimate_trajectory``, ``lbw_decision`` and
    ``generate_slow_motion``. The replay is written to a uniquely named
    ``output_<uuid>.mp4`` in the current directory.

    Args:
        video: Path of the uploaded clip, or ``None`` when nothing was
            uploaded.

    Returns:
        ``(report, replay_path)`` where ``report`` holds the decision and the
        debug log, and ``replay_path`` is the replay file (``None`` when the
        clip could not be processed).
    """
    # Submitting the form without a clip passes None; report that instead of
    # crashing further down the pipeline.
    if video is None:
        return "Error: Failed to process video\nDebug Log:\nError: No video provided", None

    frames, ball_positions, debug_log = process_video(video)
    if not frames:
        return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None

    trajectory, pitch_point, impact_point, trajectory_log = estimate_trajectory(ball_positions, frames)
    decision, trajectory, pitch_point, impact_point = lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point)

    output_path = f"output_{uuid.uuid4()}.mp4"
    slow_motion_path = generate_slow_motion(frames, trajectory, pitch_point, impact_point, output_path)

    debug_output = f"{debug_log}\n{trajectory_log}"
    return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}", slow_motion_path
# Gradio interface: a single video input is passed to drs_review, whose two
# return values fill the text box (decision + debug log) and the replay video.
iface = gr.Interface(
    fn=drs_review,
    inputs=gr.Video(label="Upload Video Clip"),
    outputs=[
        gr.Textbox(label="DRS Decision and Debug Log"),
        gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue), Pitch Point (Red), Impact Point (Yellow)")
    ],
    title="AI-Powered DRS for LBW in Local Cricket",
    description="Upload a video clip of a cricket delivery to get an LBW decision and very slow-motion replay showing ball detection (green boxes), trajectory (blue dots), pitch point (red circle), and impact point (yellow circle)."
)

# Launch the UI only when run as a script. The stray "|" that followed the
# launch() call was viewer residue and a syntax error, so it has been removed.
if __name__ == "__main__":
    iface.launch()