File size: 7,801 Bytes
41c03cf c3429f6 462fddf c3429f6 462fddf ba9faee 7e29aa0 c3429f6 7e29aa0 c3429f6 7e29aa0 462fddf 7e29aa0 c3429f6 7e29aa0 462fddf 42d2b87 462fddf 7e29aa0 42d2b87 a653421 7e29aa0 462fddf 7e29aa0 462fddf 7e29aa0 462fddf 7e29aa0 462fddf 7e29aa0 462fddf 7e29aa0 462fddf 7e29aa0 42d2b87 c3429f6 7e29aa0 462fddf c3429f6 462fddf 7e29aa0 462fddf c3429f6 462fddf 7e29aa0 462fddf 7e29aa0 462fddf 7e29aa0 462fddf 7e29aa0 462fddf 7e29aa0 462fddf 13e82ee 462fddf 13e82ee 462fddf 7e29aa0 462fddf 7e29aa0 462fddf 7e29aa0 462fddf 7e29aa0 a653421 689fb64 7e29aa0 462fddf 7e29aa0 462fddf 7e29aa0 462fddf 61be320 462fddf 7e29aa0 462fddf 7e29aa0 462fddf a295d73 c3429f6 462fddf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import cv2
import numpy as np
import torch
from ultralytics import YOLO
import gradio as gr
from scipy.interpolate import interp1d
import uuid
import os
import time
# --- Tunable constants for detection, sampling and replay -------------------
STUMPS_WIDTH: float = 0.2286  # metres, width across the stumps
BALL_DIAMETER: float = 0.073  # metres; not referenced by any code in this file
FRAME_RATE: int = 30  # assumed input frame rate (fps); not read from the video
SLOW_MOTION_FACTOR: int = 3  # replay slowdown multiplier (was 6; lowered for speed)
CONF_THRESHOLD: float = 0.3  # minimum detector confidence passed to model.predict
MAX_DETECTIONS: int = 10  # stop reading once this many ball positions are collected
PROCESS_EVERY_N_FRAME: int = 2  # only every N-th frame is analysed
RESIZE_FACTOR: float = 0.5  # detection runs on frames scaled by this factor

# Trained detector weights (presumably a YOLOv8n fine-tune); loaded once at import.
model = YOLO("best.pt")
def process_video(video_path):
    """Run the ball detector over a video and collect one ball centre per frame.

    Only every ``PROCESS_EVERY_N_FRAME``-th frame is analysed. Each analysed
    frame is downscaled by ``RESIZE_FACTOR`` before being passed to the
    module-level YOLO ``model``. Every class-0 (ball) box is drawn in green on
    the full-resolution frame, and the centre of the most confident one is
    recorded. Keeping a single position per analysed frame matches how
    ``estimate_trajectory`` assigns one timestamp per position. Reading stops
    once ``MAX_DETECTIONS`` positions have been collected.

    Args:
        video_path: Path to the video file. ``None`` or ``""`` is treated as
            missing.

    Returns:
        ``(frames, ball_positions, debug_log)``:

        - ``frames``: the analysed full-resolution frames, with boxes drawn on
          them.
        - ``ball_positions``: a list of ``[x, y]`` ball centres in
          original-frame pixels.
        - ``debug_log``: a newline-joined log string.

        When the file is missing, returns
        ``([], [], "Error: Video file not found")``.
    """
    start_time = time.time()
    if not video_path or not os.path.exists(video_path):
        return [], [], "Error: Video file not found"

    frames = []
    ball_positions = []
    debug_log = []
    frame_count = 0
    processed_frames = 0

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            debug_log.append("Error: Could not open video file")
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            if frame_count % PROCESS_EVERY_N_FRAME != 0:
                continue  # Skip frames between samples.
            processed_frames += 1

            # Detect on a downscaled copy for speed; boxes are scaled back below.
            frame_small = cv2.resize(frame, (0, 0), fx=RESIZE_FACTOR, fy=RESIZE_FACTOR)
            results = model.predict(frame_small, conf=CONF_THRESHOLD)

            detections = 0
            best_conf = None
            best_centre = None
            for detection in results[0].boxes:
                if int(detection.cls) != 0:  # Class 0 is the ball.
                    continue
                detections += 1
                # Scale coordinates back to the original frame size.
                x1, y1, x2, y2 = (
                    float(v) / RESIZE_FACTOR for v in detection.xyxy[0].cpu().numpy()
                )
                cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
                conf = float(detection.conf)
                if best_conf is None or conf > best_conf:
                    best_conf = conf
                    best_centre = [(x1 + x2) / 2, (y1 + y2) / 2]
            if best_centre is not None:
                ball_positions.append(best_centre)

            frames.append(frame)  # Full-resolution frame, now annotated.
            debug_log.append(f"Frame {frame_count}: {detections} ball detections")
            if len(ball_positions) >= MAX_DETECTIONS:
                debug_log.append(f"Stopping early after {len(ball_positions)} detections")
                break
    finally:
        # Release the capture even if detection raises part-way through.
        cap.release()

    debug_log.append(f"Processed {processed_frames} frames in {time.time() - start_time:.2f} seconds")
    if not ball_positions:
        debug_log.append("No balls detected in any frame")
    else:
        debug_log.append(f"Total ball detections: {len(ball_positions)}")
    return frames, ball_positions, "\n".join(debug_log)
def estimate_trajectory(ball_positions, frames, frame_rate=None, frame_step=None,
                        horizon=0.5, num_points=10):
    """Extrapolate the ball's path forward from its detected positions.

    Positions are treated as consecutive analysed samples spaced
    ``frame_step / frame_rate`` seconds apart. Sampled frames without a
    detection are not represented in ``ball_positions``, so this spacing is an
    approximation. x is extrapolated linearly. y is extrapolated with a
    quadratic spline when at least three positions exist (``interp1d`` needs
    three for ``kind="quadratic"``); otherwise it is extrapolated linearly.

    Args:
        ball_positions: Sequence of ``[x, y]`` pixel positions, oldest first.
        frames: Unused; kept for call-site compatibility.
        frame_rate: Source video frame rate. Defaults to ``FRAME_RATE``.
        frame_step: Number of source frames between consecutive positions.
            Defaults to ``PROCESS_EVERY_N_FRAME``, the sampling stride used
            when detecting.
        horizon: Seconds to extrapolate beyond the last position.
        num_points: Number of points in the projected path.

    Returns:
        ``(trajectory, t_future, debug_log)``. ``trajectory`` is a list of
        ``(x, y)`` pairs evaluated at the times in ``t_future``, which starts
        at the last observed sample. On failure, returns
        ``(None, None, error_message)``.
    """
    start_time = time.time()
    if len(ball_positions) < 2:
        return None, None, "Error: Fewer than 2 ball detections for trajectory"
    if frame_rate is None:
        frame_rate = FRAME_RATE
    if frame_step is None:
        frame_step = PROCESS_EVERY_N_FRAME

    x_coords = [pos[0] for pos in ball_positions]
    y_coords = [pos[1] for pos in ball_positions]
    # Consecutive samples are frame_step source frames apart, not one.
    times = np.arange(len(ball_positions)) * (frame_step / frame_rate)
    y_kind = "quadratic" if len(ball_positions) >= 3 else "linear"
    try:
        fx = interp1d(times, x_coords, kind="linear", fill_value="extrapolate")
        fy = interp1d(times, y_coords, kind=y_kind, fill_value="extrapolate")
    except Exception as e:  # Reported to the UI rather than crashing the request.
        return None, None, f"Error in trajectory interpolation: {e}"

    t_future = np.linspace(times[-1], times[-1] + horizon, num_points)
    x_future = fx(t_future)
    y_future = fy(t_future)
    debug_log = f"Trajectory estimated in {time.time() - start_time:.2f} seconds"
    return list(zip(x_future, y_future)), t_future, debug_log
def lbw_decision(ball_positions, trajectory, frames):
    """Make a simplified LBW call from ball detections and a projected path.

    The first detected position is taken as the pitch point and the last as the
    impact point. The stumps are modelled as a vertical band centred
    horizontally in the frame. A "hit" requires a trajectory point inside that
    band and within 10% of the frame height of the 90%-height line.

    Args:
        ball_positions: ``[x, y]`` pixel positions, oldest first.
        trajectory: Projected ``(x, y)`` points, or ``None``/empty.
        frames: Frames whose first element supplies the frame size
            (``.shape[:2]``).

    Returns:
        ``(decision_text, trajectory, pitch_point, impact_point)``. The last
        three are ``None`` when there is not enough data to decide.
    """
    if not frames:
        return "Error: No frames processed", None, None, None
    if not trajectory or len(ball_positions) < 2:
        return "Not enough data (insufficient ball detections)", None, None, None

    frame_height, frame_width = frames[0].shape[:2]
    stumps_x = frame_width / 2
    stumps_y = frame_height * 0.9
    # NOTE(review): the real stumps width (metres) divided by 3.0 reads like an
    # assumption that the frame spans roughly 3 m horizontally -- confirm.
    stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
    half_width = stumps_width_pixels / 2

    pitch_point = ball_positions[0]
    impact_point = ball_positions[-1]
    pitch_x, pitch_y = pitch_point
    impact_x, impact_y = impact_point

    def in_line(x):
        """Return True if *x* lies within the stumps' horizontal band (inclusive)."""
        return stumps_x - half_width <= x <= stumps_x + half_width

    # NOTE(review): under the Laws of Cricket only pitching outside *leg* stump
    # rules out LBW; this simplified check rejects either side.
    if not in_line(pitch_x):
        return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", trajectory, pitch_point, impact_point
    if not in_line(impact_x):
        return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point

    hits_stumps = any(
        abs(x - stumps_x) < half_width and abs(y - stumps_y) < frame_height * 0.1
        for x, y in trajectory
    )
    where = f"Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f}"
    if hits_stumps:
        return f"Out (Ball hits stumps, {where})", trajectory, pitch_point, impact_point
    return f"Not Out (Missing stumps, {where})", trajectory, pitch_point, impact_point
def _draw_overlay(frame, trajectory, pitch_point, impact_point):
    """Draw the projected path and pitch/impact markers onto ``frame`` in place."""
    if trajectory:
        for x, y in trajectory:
            cv2.circle(frame, (int(x), int(y)), 5, (255, 0, 0), -1)  # blue (BGR)
    if pitch_point is not None:
        x, y = pitch_point
        cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)  # red
        cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
    if impact_point is not None:
        x, y = impact_point
        cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)  # yellow
        cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)


def generate_slow_motion(frames, trajectory, pitch_point, impact_point, output_path):
    """Write an annotated, slowed-down replay of ``frames`` to ``output_path``.

    Each frame is copied, so the caller's frames are not modified. The copy is
    overlaid with the trajectory and the pitch/impact markers, then written
    ``SLOW_MOTION_FACTOR`` times. The output is encoded as ``mp4v``.

    Args:
        frames: Non-empty list of equally sized BGR images.
        trajectory: Projected ``(x, y)`` points, or ``None``/empty.
        pitch_point: ``(x, y)`` of the pitch point, or ``None``.
        impact_point: ``(x, y)`` of the impact point, or ``None``.
        output_path: Destination file path.

    Returns:
        ``(output_path, debug_log)`` on success. ``(None, error_message)``
        when there are no frames or the video writer cannot be opened.
    """
    start_time = time.time()
    if not frames:
        return None, "Error: No frames to write to slow-motion video"

    height, width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    # The slowdown compounds: the output fps is divided by SLOW_MOTION_FACTOR
    # and every frame is also written SLOW_MOTION_FACTOR times.
    out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (width, height))
    if not out.isOpened():
        out.release()
        return None, f"Error: Could not open video writer for {output_path}"
    try:
        for frame in frames:
            annotated = frame.copy()
            _draw_overlay(annotated, trajectory, pitch_point, impact_point)
            for _ in range(SLOW_MOTION_FACTOR):
                out.write(annotated)
    finally:
        out.release()

    debug_log = f"Slow-motion video generated in {time.time() - start_time:.2f} seconds"
    return output_path, debug_log
def drs_review(video):
    """Gradio callback: run detection, trajectory, LBW decision and replay.

    Args:
        video: Path to the uploaded clip from the ``gr.Video`` input.
            Presumably ``None`` when nothing was uploaded.

    Returns:
        ``(report, replay_path)``: the decision followed by a debug log, and
        the path of the generated slow-motion video (``None`` on failure).
    """
    start_time = time.time()
    if not video:
        return "Error: Failed to process video\nDebug Log:\nError: No video uploaded", None

    frames, ball_positions, debug_log = process_video(video)
    if not frames:
        return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None

    trajectory, _, trajectory_log = estimate_trajectory(ball_positions, frames)
    decision, trajectory, pitch_point, impact_point = lbw_decision(ball_positions, trajectory, frames)

    output_path = f"output_{uuid.uuid4()}.mp4"
    result = generate_slow_motion(frames, trajectory, pitch_point, impact_point, output_path)
    # Accept a bare None as well as a (path, log) tuple so a missing replay
    # does not crash the request on unpacking.
    if isinstance(result, tuple):
        slow_motion_path, slow_motion_log = result
    else:
        slow_motion_path, slow_motion_log = result, "Slow-motion video not generated"

    debug_output = (
        f"{debug_log}\n{trajectory_log}\n{slow_motion_log}\n"
        f"Total processing time: {time.time() - start_time:.2f} seconds"
    )
    return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}", slow_motion_path
# Gradio UI: one uploaded clip in; the decision/debug text and annotated replay out.
_video_input = gr.Video(label="Upload Video Clip")
_outputs = [
    gr.Textbox(label="DRS Decision and Debug Log"),
    gr.Video(label="Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue), Pitch Point (Red), Impact Point (Yellow)"),
]
iface = gr.Interface(
    fn=drs_review,
    inputs=_video_input,
    outputs=_outputs,
    title="AI-Powered DRS for LBW in Local Cricket",
    description="Upload a 3-second video clip of a cricket delivery to get an LBW decision and slow-motion replay showing ball detection (green boxes), trajectory (blue dots), pitch point (red circle), and impact point (yellow circle).",
)

if __name__ == "__main__":
    iface.launch()