# DRS_AI / app.py
# AjaykumarPilla's picture
# Update app.py
# d41a272 verified
# raw
# history blame
# 8.91 kB
import cv2
import numpy as np
import torch
from ultralytics import YOLO
import gradio as gr
from scipy.interpolate import interp1d
import uuid
import os
# Load the trained YOLOv8n model
# (weights come from "best.pt"; the load happens once, when this module is imported)
model = YOLO("best.pt")
# Constants for LBW decision and video processing
# STUMPS_WIDTH is scaled to pixels in lbw_decision as frame_width * (STUMPS_WIDTH / 3.0).
STUMPS_WIDTH = 0.2286 # meters (width of stumps)
# NOTE(review): BALL_DIAMETER is not referenced anywhere in the visible code.
BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
# FRAME_RATE is a fixed assumption: it is never read from the video file. It spaces
# the detection timestamps in estimate_trajectory and sets the output FPS in
# generate_slow_motion.
FRAME_RATE = 20 # Input video frame rate (reduced to 20 FPS)
# SLOW_MOTION_FACTOR is used twice in generate_slow_motion: it divides the writer's
# FPS and is also the number of times each frame is written.
SLOW_MOTION_FACTOR = 3 # Adjusted for 20 FPS (slower playback without being too slow)
# Passed as `conf` to model.predict in process_video.
CONF_THRESHOLD = 0.25 # Confidence threshold for detection
# The first detection whose y exceeds frame_height * IMPACT_ZONE_Y is taken as the
# impact point in estimate_trajectory.
IMPACT_ZONE_Y = 0.85 # Fraction of frame height where impact is likely (near stumps)
def process_video(video_path):
    """Run ball detection on every frame of a video clip.

    Each frame is passed to the module-level YOLO ``model``. Every class-0
    detection is outlined in green on the frame. Only the centre of the
    highest-confidence class-0 box is recorded, so a frame contributes at
    most one ball position; this keeps ``ball_positions`` and
    ``detection_frames`` aligned one-to-one with distinct frames.

    Args:
        video_path: Path of the video file to analyse. ``None`` or an empty
            string is treated the same as a missing file.

    Returns:
        A 4-tuple ``(frames, ball_positions, detection_frames, debug_log)``:
        the annotated frames, the ``[x, y]`` ball centres in pixels, the
        0-based frame index of each centre, and a newline-joined log. When
        the file is missing or cannot be opened the three lists are empty
        and the log holds an error message.
    """
    # Gradio passes None when nothing was uploaded; os.path.exists(None) would raise.
    if not video_path or not os.path.exists(video_path):
        return [], [], [], "Error: Video file not found"

    cap = cv2.VideoCapture(video_path)
    frames = []
    ball_positions = []
    detection_frames = []  # 0-based index of the frame behind each ball position
    debug_log = []
    try:
        if not cap.isOpened():
            return [], [], [], "Error: Could not open video file"

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_index = len(frames)

            results = model.predict(frame, conf=CONF_THRESHOLD)
            detections = 0
            best_conf = None
            best_center = None
            for detection in results[0].boxes:
                if int(detection.cls[0]) != 0:  # Assuming class 0 is the ball
                    continue
                detections += 1
                x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
                cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
                conf = float(detection.conf[0])
                if best_conf is None or conf > best_conf:
                    best_conf = conf
                    best_center = [(x1 + x2) / 2, (y1 + y2) / 2]

            if best_center is not None:
                ball_positions.append(best_center)
                detection_frames.append(frame_index)

            # The annotated frame itself is stored; no separate copy is needed.
            frames.append(frame)
            debug_log.append(f"Frame {frame_index + 1}: {detections} ball detections")
    finally:
        # Release the capture even if inference raises part-way through.
        cap.release()

    if not ball_positions:
        debug_log.append("No balls detected in any frame")
    else:
        debug_log.append(f"Total ball detections: {len(ball_positions)}")
    return frames, ball_positions, detection_frames, "\n".join(debug_log)
def estimate_trajectory(ball_positions, frames):
    """Fit and extrapolate the ball path from the detected centres.

    Detections are treated as evenly spaced samples, ``1 / FRAME_RATE``
    seconds apart. The pitch point is the first detection whose y exceeds
    75% of the frame height. The impact point is the first detection whose y
    exceeds ``IMPACT_ZONE_Y`` of the frame height, or the last detection if
    none does. Only samples up to the impact point are fitted: x linearly,
    y quadratically (linearly when fewer than three samples are available,
    since a quadratic fit needs at least three).

    Args:
        ball_positions: Sequence of ``[x, y]`` ball centres in pixels.
        frames: Sequence of frames; only the height of the first is used.

    Returns:
        ``(trajectory, pitch_point, impact_point, message)``. ``trajectory``
        is a list of ``(x, y)`` pairs sampled from the first fitted time to
        0.5 s past the last one. ``pitch_point`` may be ``None`` on success.
        On failure the first three items are ``None`` and ``message``
        explains why.
    """
    if len(ball_positions) < 2:
        return None, None, None, "Error: Fewer than 2 ball detections for trajectory"
    if not frames:
        return None, None, None, "Error: No frames available for trajectory estimation"

    frame_height = frames[0].shape[0]
    # Extract x, y coordinates
    x_coords = [pos[0] for pos in ball_positions]
    y_coords = [pos[1] for pos in ball_positions]
    times = np.arange(len(ball_positions)) / FRAME_RATE

    # Detect the pitch point: first detection near the bottom of the frame
    pitch_threshold = frame_height * 0.75
    pitch_point = next((pos for pos in ball_positions if pos[1] > pitch_threshold), None)

    # Find impact point (near stumps/batsman); fall back to the last detection
    impact_threshold = frame_height * IMPACT_ZONE_Y
    impact_idx = next(
        (i for i, y in enumerate(y_coords) if y > impact_threshold),
        len(ball_positions) - 1,
    )
    impact_point = ball_positions[impact_idx]

    # Use positions up to impact for interpolation
    x_coords = x_coords[:impact_idx + 1]
    y_coords = y_coords[:impact_idx + 1]
    times = times[:impact_idx + 1]

    # A quadratic fit needs three samples; with two, a straight line is the best available.
    y_kind = 'quadratic' if len(times) >= 3 else 'linear'
    try:
        fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
        fy = interp1d(times, y_coords, kind=y_kind, fill_value="extrapolate")
    except ValueError as e:
        # Raised e.g. when only one sample precedes the impact point.
        return None, None, None, f"Error in trajectory interpolation: {str(e)}"

    # Project trajectory (detected + future for LBW decision)
    t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 10)
    trajectory = list(zip(fx(t_full), fy(t_full)))
    return trajectory, pitch_point, impact_point, "Trajectory estimated successfully"
def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
    """Decide Out / Not Out from the pitch point, impact point and trajectory.

    The stumps are modelled as a box centred horizontally in the frame at
    90% of its height. The checks run in order: the pitch point must lie
    within the stumps' width, then the impact point must, then some
    trajectory sample must fall inside the stumps box.

    Args:
        ball_positions: Detected ball centres; only their count is used.
        trajectory: List of ``(x, y)`` pairs, or ``None`` / empty when no
            trajectory could be estimated.
        frames: Sequence of frames; only the size of the first is used.
        pitch_point: ``[x, y]`` of the pitch point, or ``None`` when none
            was found.
        impact_point: ``[x, y]`` of the impact point, or ``None``.

    Returns:
        ``(decision, trajectory, pitch_point, impact_point)`` where
        ``decision`` is a human-readable verdict string. The last three
        items are ``None`` when there are no frames or no trajectory.
    """
    if not frames:
        return "Error: No frames processed", None, None, None
    if not trajectory or len(ball_positions) < 2:
        return "Not enough data (insufficient ball detections)", None, None, None
    # A missing point cannot be unpacked below; report it instead of crashing.
    if pitch_point is None:
        return "Not enough data (pitch point not detected)", trajectory, None, impact_point
    if impact_point is None:
        return "Not enough data (impact point not detected)", trajectory, pitch_point, None

    frame_height, frame_width = frames[0].shape[:2]
    stumps_x = frame_width / 2
    stumps_y = frame_height * 0.9  # Position of the stumps at the bottom of the frame
    # NOTE(review): the 3.0 divisor presumably means the frame is assumed to
    # span about 3 m horizontally - confirm against the camera setup.
    stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
    half_width = stumps_width_pixels / 2
    left_edge = stumps_x - half_width
    right_edge = stumps_x + half_width

    pitch_x, pitch_y = pitch_point
    impact_x, impact_y = impact_point

    # Check pitching point - the ball should land between stumps
    if pitch_x < left_edge or pitch_x > right_edge:
        return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", trajectory, pitch_point, impact_point

    # Check impact point - the ball should hit within the stumps area
    if impact_x < left_edge or impact_x > right_edge:
        return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point

    # Check trajectory hitting stumps
    stumps_reach_y = frame_height * 0.1
    hits_stumps = any(
        abs(x - stumps_x) < half_width and abs(y - stumps_y) < stumps_reach_y
        for x, y in trajectory
    )
    if hits_stumps:
        return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
    return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
def _marker_center(point, min_y):
    """Return the integer pixel centre of *point* if its y exceeds *min_y*, else None."""
    if point is None:
        return None
    x, y = point
    if y > min_y:
        return int(x), int(y)
    return None


def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, output_path):
    """Write an annotated slow-motion replay to *output_path*.

    On frames that have a detection, the trajectory is drawn as a blue
    polyline that grows by one point per detection. The pitch point (red)
    and impact point (yellow) are drawn on every frame, each at its own
    coordinates, provided the point lies below 75% / ``IMPACT_ZONE_Y`` of
    the frame height respectively. Frames are annotated in place.

    Args:
        frames: Frames to write; all are assumed to share the first
            frame's size.
        trajectory: List of ``(x, y)`` pairs, or ``None`` / empty when no
            trajectory is available (no polyline is drawn then).
        pitch_point: ``[x, y]`` of the pitch point, or ``None``.
        impact_point: ``[x, y]`` of the impact point, or ``None``.
        detection_frames: 0-based indices of the frames with a detection.
        output_path: Destination file for the ``mp4v``-encoded video.

    Returns:
        *output_path*, or ``None`` when *frames* is empty.
    """
    if not frames:
        return None

    frame_height, frame_width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # NOTE(review): the FPS is divided by SLOW_MOTION_FACTOR and each frame is
    # also written SLOW_MOTION_FACTOR times, so playback is slowed by the
    # factor squared - confirm this is intended.
    out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frame_width, frame_height))

    if trajectory:
        trajectory_points = np.array(trajectory[:len(detection_frames)], dtype=np.int32).reshape((-1, 1, 2))
    else:
        # No trajectory could be estimated; keep an empty polyline so the frames still render.
        trajectory_points = np.empty((0, 1, 2), dtype=np.int32)

    # Position of each detected frame within detection_frames (first occurrence wins).
    detection_rank = {}
    for rank, frame_idx in enumerate(detection_frames):
        detection_rank.setdefault(frame_idx, rank)

    # Resolved once with separate coordinates, so the pitch marker can never be
    # drawn at the impact point's position.
    pitch_marker = _marker_center(pitch_point, frame_height * 0.75)
    impact_marker = _marker_center(impact_point, frame_height * IMPACT_ZONE_Y)

    try:
        for i, frame in enumerate(frames):
            # Draw trajectory (blue line) only for frames with detections
            rank = detection_rank.get(i)
            if rank is not None and trajectory_points.size > 0:
                cv2.polylines(frame, [trajectory_points[:rank + 1]], False, (255, 0, 0), 2)

            # Draw pitch point (red circle with label)
            if pitch_marker is not None:
                px, py = pitch_marker
                cv2.circle(frame, (px, py), 8, (0, 0, 255), -1)
                cv2.putText(frame, "Pitch Point", (px + 10, py - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

            # Draw impact point (yellow circle with label)
            if impact_marker is not None:
                ix, iy = impact_marker
                cv2.circle(frame, (ix, iy), 8, (0, 255, 255), -1)
                cv2.putText(frame, "Impact Point", (ix + 10, iy + 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)

            # Write frames to output video
            for _ in range(SLOW_MOTION_FACTOR):
                out.write(frame)
    finally:
        # Finalise the file even if drawing or writing fails.
        out.release()
    return output_path
def drs_review(video):
    """Run the full review pipeline on an uploaded clip.

    The clip goes through detection, trajectory estimation and the LBW
    decision, and an annotated slow-motion replay is rendered.

    Args:
        video: Path of the uploaded video file; ``None`` or empty when
            nothing was uploaded.

    Returns:
        ``(message, replay_path)``. ``message`` is the decision text or an
        error message; ``replay_path`` is the rendered video file, or
        ``None`` when the clip could not be processed.
    """
    if not video:
        return "Error: No video provided", None

    frames, ball_positions, detection_frames, _debug_log = process_video(video)
    if not frames:
        return "Error: Failed to process video", None

    trajectory, pitch_point, impact_point, _trajectory_log = estimate_trajectory(ball_positions, frames)
    decision, trajectory, pitch_point, impact_point = lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point)

    output_path = f"output_{uuid.uuid4()}.mp4"
    # With too few detections the trajectory is None; pass an empty list so the
    # replay (with detection boxes only) is still rendered instead of crashing.
    replay_trajectory = trajectory if trajectory is not None else []
    slow_motion_path = generate_slow_motion(frames, replay_trajectory, pitch_point, impact_point, detection_frames, output_path)
    return f"DRS Decision: {decision}", slow_motion_path
# Gradio interface: one uploaded clip in; the verdict text and the annotated replay out.
_REPLAY_LABEL = "Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow)"
_APP_DESCRIPTION = "Upload a video clip of a cricket delivery to get an LBW decision and slow-motion replay showing ball detection (green boxes), trajectory (blue line), pitch point (red circle), and impact point (yellow circle)."

iface = gr.Interface(
    fn=drs_review,
    inputs=gr.Video(label="Upload Video Clip"),
    outputs=[
        gr.Textbox(label="DRS Decision"),
        gr.Video(label=_REPLAY_LABEL),
    ],
    title="AI-Powered DRS for LBW in Local Cricket",
    description=_APP_DESCRIPTION,
)

# Start the web UI only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()