# DRS_AI / app.py
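"""AI-powered DRS (Decision Review System) demo for LBW appeals in local
cricket: detects the ball with a custom-trained YOLOv8n model, interpolates
its trajectory, applies simplified LBW rules, and renders an annotated
slow-motion replay through a Gradio interface."""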
import cv2
import numpy as np
import torch
from ultralytics import YOLO
import gradio as gr
from scipy.interpolate import interp1d
import uuid
import os
# Load the trained YOLOv8n model (assumes custom ball-detection weights
# "best.pt" in the working directory, with class 0 as the ball)
model = YOLO("best.pt")
# Constants for LBW decision and video processing
STUMPS_WIDTH = 0.2286  # meters (regulation width of the stumps)
STUMPS_HEIGHT = 0.711  # meters (regulation height of the stumps)
FRAME_RATE = 20  # assumed frame rate of the input video (fps)
SLOW_MOTION_FACTOR = 2  # frame duplication factor; kept low for faster output
CONF_THRESHOLD = 0.25  # confidence threshold for ball detection
PITCH_ZONE_Y = 0.9  # fraction of frame height marking the pitch zone
IMPACT_ZONE_Y = 0.8  # fraction of frame height marking the impact zone
IMPACT_DELTA_Y = 50  # pixels; sudden y-position jump that signals impact
def process_video(video_path):
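    """Sample every 2nd frame of the clip and run ball detection on it.

    Returns the sampled frames, the (x, y) centers of the ball in frames with
    exactly one detection, the indices of those frames within the sampled
    list, and a newline-joined debug log.
    """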
if not os.path.exists(video_path):
return [], [], [], "Error: Video file not found"
cap = cv2.VideoCapture(video_path)
frames = []
ball_positions = []
detection_frames = []
debug_log = []
frame_count = 0
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
if frame_count % 2 == 0: # Process every 2nd frame
frames.append(frame.copy())
results = model.predict(frame, conf=CONF_THRESHOLD)
            detections = [det for det in results[0].boxes if int(det.cls) == 0]  # class 0 = ball
if len(detections) == 1:
x1, y1, x2, y2 = detections[0].xyxy[0].cpu().numpy()
ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
detection_frames.append(len(frames) - 1)
cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
frames[-1] = frame
debug_log.append(f"Frame {frame_count}: {len(detections)} ball detections")
frame_count += 1
cap.release()
if not ball_positions:
debug_log.append("No valid single-ball detections in any frame")
else:
debug_log.append(f"Total valid single-ball detections: {len(ball_positions)}")
return frames, ball_positions, detection_frames, "\n".join(debug_log)
def estimate_trajectory(ball_positions, detection_frames, frames):
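    """Fit an interpolated trajectory through the detected ball centers.

    Deduplicates near-identical positions, picks the pitch point (first
    position inside the pitch zone) and the impact point (first sudden
    y-jump inside the impact zone), then fits x(t) linearly and y(t)
    quadratically with scipy's interp1d and extrapolates past the impact.
    """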
if len(ball_positions) < 2:
return None, None, None, None, None, None, "Error: Fewer than 2 valid single-ball detections for trajectory"
frame_height = frames[0].shape[0]
# Filter to unique positions to reduce interpolation points
unique_positions = [ball_positions[0]]
for pos in ball_positions[1:]:
if abs(pos[0] - unique_positions[-1][0]) > 10 or abs(pos[1] - unique_positions[-1][1]) > 10:
unique_positions.append(pos)
x_coords = [pos[0] for pos in unique_positions]
y_coords = [pos[1] for pos in unique_positions]
times = np.array([i / FRAME_RATE for i in range(len(unique_positions))])
pitch_idx = 0
for i, y in enumerate(y_coords):
if y > frame_height * PITCH_ZONE_Y:
pitch_idx = i
break
pitch_point = unique_positions[pitch_idx]
pitch_frame = detection_frames[pitch_idx]
impact_idx = None
for i in range(1, len(y_coords)):
if (y_coords[i] > frame_height * IMPACT_ZONE_Y and
abs(y_coords[i] - y_coords[i-1]) > IMPACT_DELTA_Y):
impact_idx = i
break
if impact_idx is None:
impact_idx = len(y_coords) - 1
impact_point = unique_positions[impact_idx]
impact_frame = detection_frames[impact_idx]
x_coords = x_coords[:impact_idx + 1]
y_coords = y_coords[:impact_idx + 1]
times = times[:impact_idx + 1]
    try:
        fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
        # Quadratic interpolation needs at least 3 points; fall back to linear
        y_kind = 'quadratic' if len(times) >= 3 else 'linear'
        fy = interp1d(times, y_coords, kind=y_kind, fill_value="extrapolate")
    except Exception as e:
        return None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
vis_trajectory = list(zip(x_coords, y_coords))
    # Extend 0.5 s past the last detection, adding a handful of extrapolated points
    t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 5)
x_full = fx(t_full)
y_full = fy(t_full)
full_trajectory = list(zip(x_full, y_full))
debug_log = (f"Trajectory estimated successfully\n"
f"Pitch point at frame {pitch_frame + 1}: ({pitch_point[0]:.1f}, {pitch_point[1]:.1f})\n"
f"Impact point at frame {impact_frame + 1}: ({impact_point[0]:.1f}, {impact_point[1]:.1f})")
return full_trajectory, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, debug_log
def lbw_decision(ball_positions, full_trajectory, frames, pitch_point, impact_point):
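    """Apply simplified DRS rules and return a verdict string.

    Checks pitching in line, impact in line and below the batsman zone, and
    whether the projected trajectory passes through the stump zone; returns
    the verdict plus the trajectory, pitch point, and impact point.
    """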
if not frames:
return "Error: No frames processed", None, None, None
if not full_trajectory or len(ball_positions) < 2:
return "Not enough data (insufficient valid single-ball detections)", None, None, None
frame_height, frame_width = frames[0].shape[:2]
    # Assumed geometry: stumps centered horizontally, stump line at 90% of the
    # frame height, and a visible frame width of ~3 m for the meter-to-pixel
    # conversion
    stumps_x = frame_width / 2
    stumps_y = frame_height * 0.9
    stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
    batsman_area_y = frame_height * 0.8
pitch_x, pitch_y = pitch_point
impact_x, impact_y = impact_point
in_line_threshold = stumps_width_pixels / 2
if pitch_x < stumps_x - in_line_threshold or pitch_x > stumps_x + in_line_threshold:
return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", full_trajectory, pitch_point, impact_point
if impact_y < batsman_area_y or impact_x < stumps_x - in_line_threshold or impact_x > stumps_x + in_line_threshold:
return f"Not Out (Impact outside line or above batsman at x: {impact_x:.1f}, y: {impact_y:.1f})", full_trajectory, pitch_point, impact_point
    # Check whether the projected trajectory passes through the stump zone
    hit_x = None
    for x, y in full_trajectory:
        if (abs(x - stumps_x) < in_line_threshold and
            abs(y - stumps_y) < frame_height * 0.1):
            hit_x = x
            break
    if hit_x is not None:
        # A hit near the edge of the in-line band counts as a clip; with no
        # on-field decision to uphold, umpire's call defaults to Not Out
        if abs(hit_x - stumps_x) > in_line_threshold * 0.9:
            return f"Umpire's Call - Not Out (Ball clips stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", full_trajectory, pitch_point, impact_point
        return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", full_trajectory, pitch_point, impact_point
    return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", full_trajectory, pitch_point, impact_point
def generate_slow_motion(frames, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, detection_frames, decision, output_path):
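    """Render the annotated slow-motion replay to output_path.

    Draws the stump and crease lines, the trajectory polyline up to the
    current detection, the pitch and impact markers, and a wicket label when
    the decision is Out; each frame is written SLOW_MOTION_FACTOR times.
    """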
if not frames:
return None
frame_height, frame_width = frames[0].shape[:2]
stumps_x = frame_width / 2
stumps_y = frame_height * 0.9
    stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)  # same ~3 m frame-width assumption
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frame_width, frame_height))
    # vis_trajectory can be None when trajectory estimation failed
    trajectory_points = np.array(vis_trajectory or [], dtype=np.int32).reshape((-1, 1, 2))
for i, frame in enumerate(frames):
# Draw stumps (single line for efficiency)
cv2.line(frame, (int(stumps_x - stumps_width_pixels / 2), int(stumps_y)),
(int(stumps_x + stumps_width_pixels / 2), int(stumps_y)), (255, 255, 255), 2)
# Draw crease line
cv2.line(frame, (0, int(stumps_y)), (frame_width, int(stumps_y)), (255, 255, 0), 2)
if i in detection_frames and trajectory_points.size > 0:
idx = detection_frames.index(i) + 1
if idx <= len(trajectory_points):
cv2.polylines(frame, [trajectory_points[:idx]], False, (255, 0, 0), 2)
if pitch_point and i == pitch_frame:
x, y = pitch_point
cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 0), -1)
cv2.putText(frame, "Pitching Factor", (int(x) + 10, int(y) - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
        if impact_point and i == impact_frame:
            x, y = impact_point
            cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
            cv2.putText(frame, "Impact Factor", (int(x) + 10, int(y) + 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
            # Label the wicket only when the verdict is Out; "Not Out" also
            # contains "Out", so match the prefix rather than the substring
            if decision.startswith("Out"):
                cv2.putText(frame, "Wicket Factor", (int(stumps_x) - 50, int(stumps_y) - 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 165, 255), 1)
for _ in range(SLOW_MOTION_FACTOR):
out.write(frame)
out.release()
return output_path
def drs_review(video):
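    """End-to-end review pipeline: detect the ball, estimate its trajectory,
    rule on the LBW appeal, and render the annotated slow-motion replay."""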
frames, ball_positions, detection_frames, debug_log = process_video(video)
if not frames:
return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None
full_trajectory, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, trajectory_log = estimate_trajectory(ball_positions, detection_frames, frames)
decision, full_trajectory, pitch_point, impact_point = lbw_decision(ball_positions, full_trajectory, frames, pitch_point, impact_point)
output_path = f"output_{uuid.uuid4()}.mp4"
    slow_motion_path = generate_slow_motion(frames, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, detection_frames, decision, output_path)
debug_output = f"{debug_log}\n{trajectory_log}"
return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}", slow_motion_path
# Gradio interface
iface = gr.Interface(
fn=drs_review,
inputs=gr.Video(label="Upload Video Clip"),
outputs=[
gr.Textbox(label="DRS Decision and Debug Log"),
gr.Video(label="Optimized Slow-Motion Replay with Pitching Factor (Green), Impact Factor (Red), Wicket Factor (Orange), Stumps (White), Crease (Yellow)")
],
title="AI-Powered DRS for LBW in Local Cricket",
description="Upload a video clip of a cricket delivery to get an LBW decision and optimized slow-motion replay showing pitching factor (green circle), impact factor (red circle), wicket factor (orange text), stumps (white lines), and crease line (yellow line)."
)
if __name__ == "__main__":
iface.launch()