# DRS_AI / app.py — author: AjaykumarPilla
# Commit c9d4715 ("Update app.py"), 8.54 kB
import cv2
import numpy as np
import torch
from ultralytics import YOLO
import gradio as gr
from scipy.interpolate import interp1d
import uuid
import os
# Ball detector, loaded once at import time and used by process_video.
# (Weights file "best.pt"; described originally as a trained YOLOv8n model.)
model = YOLO("best.pt")

# --- Physical sizes (metres) ---
BALL_DIAMETER = 0.073
STUMPS_WIDTH = 0.2286

# --- Video handling ---
FRAME_RATE = 30  # assumed frame rate of the input clip
SLOW_MOTION_FACTOR = 6  # each source frame is written this many times in the replay
RESIZE_DIM = 640  # frames are resized to RESIZE_DIM x RESIZE_DIM before inference

# --- Detection ---
CONF_THRESHOLD = 0.3  # minimum confidence passed to model.predict
def process_video(video_path):
    """Run the ball detector on the opening seconds of a video clip.

    Frames are read with OpenCV and resized to RESIZE_DIM x RESIZE_DIM for YOLO
    inference. Every class-0 box (assumed to be the ball) is drawn in green on
    the original-size frame.

    Args:
        video_path: Path to the uploaded video. May be None or empty when
            nothing was uploaded.

    Returns:
        A tuple ``(frames, ball_positions, debug_log)``:

        - ``frames``: annotated frames in read order (at most FRAME_RATE * 3).
        - ``ball_positions``: at most one ``[x, y]`` centre per frame, taken
          from that frame's most confident class-0 box, in original-frame
          pixel coordinates.
        - ``debug_log``: newline-joined log text. When the video cannot be
          read, it is an "Error: ..." message and both lists are empty.
    """
    if not video_path or not os.path.exists(video_path):
        return [], [], "Error: Video file not found"
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        return [], [], "Error: Could not open video file"

    frames = []
    ball_positions = []
    debug_log = []
    frame_count = 0
    # Limit to 3 seconds, assuming the clip runs at FRAME_RATE fps
    # (the file's real frame rate is not consulted).
    max_frames = FRAME_RATE * 3
    try:
        while frame_count < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            # Resize frame for faster YOLO inference.
            frame_resized = cv2.resize(frame, (RESIZE_DIM, RESIZE_DIM), interpolation=cv2.INTER_AREA)
            results = model.predict(frame_resized, conf=CONF_THRESHOLD, imgsz=RESIZE_DIM)
            scale_x = frame.shape[1] / RESIZE_DIM
            scale_y = frame.shape[0] / RESIZE_DIM

            detections = 0
            best_center, best_conf = None, float("-inf")
            for detection in results[0].boxes:
                if detection.cls != 0:  # class 0 is the ball
                    continue
                detections += 1
                x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
                # Scale coordinates back to the original frame size.
                x1, x2 = x1 * scale_x, x2 * scale_x
                y1, y2 = y1 * scale_y, y2 * scale_y
                cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
                # Keep only the most confident ball per frame. estimate_trajectory
                # treats consecutive positions as consecutive frames, so several
                # boxes from one frame would distort the time axis.
                conf = float(detection.conf)
                if conf > best_conf:
                    best_conf = conf
                    best_center = [(x1 + x2) / 2, (y1 + y2) / 2]
            # NOTE(review): frames without a detection are still skipped, so the
            # time steps between stored positions can be uneven.
            if best_center is not None:
                ball_positions.append(best_center)

            frames.append(frame)
            debug_log.append(f"Frame {frame_count}: {detections} ball detections")
    finally:
        # Release the capture even if inference raises.
        cap.release()

    if not ball_positions:
        debug_log.append("No balls detected in any frame")
    else:
        debug_log.append(f"Frames with a ball detection: {len(ball_positions)}")
    return frames, ball_positions, "\n".join(debug_log)
def estimate_trajectory(ball_positions, frames, frame_rate=None):
    """Fit and extrapolate a ball path through the detected centres.

    Positions are assumed to be one per frame, sampled at ``frame_rate`` fps.
    x is interpolated linearly. y is interpolated quadratically when there are
    at least 3 points, and linearly otherwise, because a quadratic spline needs
    3 samples. Both extrapolate 0.5 s past the last detection.

    Args:
        ball_positions: Sequence of ``[x, y]`` pixel centres.
        frames: Processed frames; only ``len(frames)`` is used, to size the output.
        frame_rate: Sampling rate in fps. Defaults to the module's FRAME_RATE.

    Returns:
        ``(trajectory, t_all, message)``. ``trajectory`` is a list of ``(x, y)``
        for ``len(frames) + 10`` evenly spaced times in ``t_all``. On failure
        the function returns ``([], [], "Error: ...")``.
    """
    if len(ball_positions) < 2:
        return [], [], "Error: Fewer than 2 ball detections for trajectory"
    if frame_rate is None:
        frame_rate = FRAME_RATE
    x_coords = [pos[0] for pos in ball_positions]
    y_coords = [pos[1] for pos in ball_positions]
    times = np.arange(len(ball_positions)) / frame_rate
    # A quadratic spline needs >= 3 samples; with 2 it would raise, so use linear.
    y_kind = 'quadratic' if len(ball_positions) >= 3 else 'linear'
    try:
        fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
        fy = interp1d(times, y_coords, kind=y_kind, fill_value="extrapolate")
    except Exception as e:
        return [], [], f"Error in trajectory interpolation: {str(e)}"
    # Interpolate for all frames and project 0.5 s into the future.
    t_all = np.linspace(0, times[-1] + 0.5, len(frames) + 10)
    trajectory = list(zip(fx(t_all), fy(t_all)))
    return trajectory, t_all, "Trajectory estimated successfully"
def detect_impact_point(ball_positions, frames):
    """Pick the detected ball position most likely to be the impact with the batsman.

    Scans interior positions in order and returns the first one where either:

    - the ball's heading turns by more than 45 degrees, or
    - the position lies within 10% of the frame width of the assumed batsman
      location (horizontal centre, 80% down the frame).

    Args:
        ball_positions: Sequence of ``(x, y)`` pixel centres.
        frames: Processed frames; only ``frames[0].shape`` is read, and only
            when there are at least 3 positions.

    Returns:
        ``(impact_point, impact_idx)``: an element of ``ball_positions`` and its
        index. With fewer than 3 positions, or when nothing qualifies, the last
        position is returned. An empty input gives ``(None, -1)``.
    """
    if len(ball_positions) < 3:
        return ball_positions[-1] if ball_positions else None, len(ball_positions) - 1

    frame_height, frame_width = frames[0].shape[:2]
    # Batsman is assumed to be near the stumps (bottom centre of frame).
    batsman_x = frame_width / 2
    batsman_y = frame_height * 0.8
    proximity = frame_width * 0.1

    for i in range(1, len(ball_positions) - 1):
        prev_x, prev_y = ball_positions[i - 1]
        x, y = ball_positions[i]
        next_x, next_y = ball_positions[i + 1]
        heading_in = np.arctan2(y - prev_y, x - prev_x)
        heading_out = np.arctan2(next_y - y, next_x - x)
        # Wrap the difference into [-pi, pi) so headings either side of the
        # +/-pi seam (e.g. 179 deg vs -179 deg) count as a small turn.
        turn = abs((heading_out - heading_in + np.pi) % (2 * np.pi) - np.pi)
        dist_to_batsman = np.hypot(x - batsman_x, y - batsman_y)
        if turn > np.pi / 4 or dist_to_batsman < proximity:
            return ball_positions[i], i

    return ball_positions[-1], len(ball_positions) - 1
def lbw_decision(ball_positions, trajectory, frames):
    """Make a heuristic LBW call from the detected ball path.

    The stumps are assumed to sit at the horizontal centre of the frame, 90% of
    the way down. The first detection is taken as the pitching point, and the
    impact point comes from ``detect_impact_point``.

    The decision is "Not Out" when either the pitching point or the impact
    point lies outside the stumps' horizontal band. Otherwise it is "Out" when
    any trajectory point falls inside the stumps box, and "Not Out" when none
    does. It is "Inconclusive" when no trajectory is available.

    Args:
        ball_positions: Sequence of ``(x, y)`` pixel centres.
        trajectory: List of ``(x, y)`` points (may be empty).
        frames: Processed frames; ``frames[0].shape`` gives the frame size.

    Returns:
        ``(decision, trajectory, pitch_point, impact_point)``. The last three
        are None when there are no frames or fewer than 2 detections.
    """
    if not frames:
        return "Error: No frames processed", None, None, None
    if len(ball_positions) < 2:
        return "Not enough data (insufficient ball detections)", None, None, None

    frame_height, frame_width = frames[0].shape[:2]
    stumps_x = frame_width / 2
    stumps_y = frame_height * 0.9
    # NOTE(review): mixes metres (STUMPS_WIDTH) with a fraction of the frame
    # width; presumably assumes ~3 m of ground spans the frame — confirm.
    stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
    half_width = stumps_width_pixels / 2
    left_edge = stumps_x - half_width
    right_edge = stumps_x + half_width

    pitch_point = ball_positions[0]
    impact_point, _ = detect_impact_point(ball_positions, frames)

    # Check pitching point.
    pitch_x, pitch_y = pitch_point
    if pitch_x < left_edge or pitch_x > right_edge:
        return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", trajectory, pitch_point, impact_point

    # Check impact point.
    impact_x, impact_y = impact_point
    if impact_x < left_edge or impact_x > right_edge:
        return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point

    location = f"Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f}"
    # Without a projected path we cannot tell whether the stumps would be hit.
    if not trajectory:
        return f"Inconclusive (No trajectory to project onto stumps, {location})", trajectory, pitch_point, impact_point

    # Check whether the projected trajectory hits the stumps.
    for x, y in trajectory:
        if abs(x - stumps_x) < half_width and abs(y - stumps_y) < frame_height * 0.1:
            return f"Out (Ball hits stumps, {location})", trajectory, pitch_point, impact_point
    return f"Not Out (Missing stumps, {location})", trajectory, pitch_point, impact_point
def generate_slow_motion(frames, trajectory, pitch_point, impact_point, impact_idx, output_path):
    """Render an annotated slow-motion MP4 replay.

    Each frame is written SLOW_MOTION_FACTOR times at
    FRAME_RATE / SLOW_MOTION_FACTOR fps. The annotations are:

    - a growing prefix of trajectory points (blue);
    - the pitch point (red) during the first half of the frames;
    - the impact point (yellow) within 4 frames of ``impact_idx``.

    The caller's frames are not modified.

    Args:
        frames: Frames to render; ``frames[0].shape`` fixes the output size.
        trajectory: List of ``(x, y)`` points, or None (treated as empty).
        pitch_point: ``(x, y)`` or None.
        impact_point: ``(x, y)`` or None.
        impact_idx: Index the impact marker is centred on, or None.
        output_path: Destination file path.

    Returns:
        ``output_path``, or None when there are no frames or the video writer
        could not be opened.
    """
    if not frames:
        return None
    # lbw_decision returns trajectory=None when data is insufficient.
    trajectory = trajectory or []
    height, width = frames[0].shape[:2]
    # NOTE(review): 'mp4v' output may not play in every browser; 'avc1' is an
    # option if the OpenCV build supports it.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (width, height))
    if not out.isOpened():
        out.release()
        return None
    try:
        for i, frame in enumerate(frames):
            canvas = frame.copy()
            # Same points as filtering with `j / SLOW_MOTION_FACTOR <= i`.
            # NOTE(review): confirm this frame-to-sample mapping is intended.
            for x, y in trajectory[: i * SLOW_MOTION_FACTOR + 1]:
                if np.isfinite(x) and np.isfinite(y):
                    cv2.circle(canvas, (int(x), int(y)), 5, (255, 0, 0), -1)  # Blue dots
            # Draw pitch point in early frames.
            if pitch_point is not None and i < len(frames) // 2:
                x, y = pitch_point
                cv2.circle(canvas, (int(x), int(y)), 8, (0, 0, 255), -1)  # Red circle
                cv2.putText(canvas, "Pitch Point", (int(x) + 10, int(y) - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
            # Draw impact point around the impact frame.
            if impact_point is not None and impact_idx is not None and abs(i - impact_idx) < 5:
                x, y = impact_point
                cv2.circle(canvas, (int(x), int(y)), 8, (0, 255, 255), -1)  # Yellow circle
                cv2.putText(canvas, "Impact Point", (int(x) + 10, int(y) + 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
            for _ in range(SLOW_MOTION_FACTOR):
                out.write(canvas)
    finally:
        # Flush and close the file even if drawing fails midway.
        out.release()
    return output_path
def drs_review(video):
    """Gradio callback: run the full DRS pipeline on an uploaded clip.

    Args:
        video: File path supplied by the Gradio Video input. It is None when
            nothing was uploaded.

    Returns:
        ``(report_text, replay_path)``. ``replay_path`` is None when no replay
        could be produced.
    """
    if not video:
        return "Error: No video uploaded", None
    frames, ball_positions, debug_log = process_video(video)
    if not frames:
        return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None
    trajectory, _, trajectory_log = estimate_trajectory(ball_positions, frames)
    decision, trajectory, pitch_point, impact_point = lbw_decision(ball_positions, trajectory, frames)
    _, impact_idx = detect_impact_point(ball_positions, frames)
    # Unique name so concurrent requests do not overwrite each other's replay.
    output_path = f"output_{uuid.uuid4()}.mp4"
    # lbw_decision yields trajectory=None with < 2 detections; render without it.
    slow_motion_path = generate_slow_motion(frames, trajectory or [], pitch_point, impact_point, impact_idx, output_path)
    debug_output = f"{debug_log}\n{trajectory_log}"
    return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}", slow_motion_path
# Gradio UI: one uploaded clip in; a text report plus an annotated replay out.
iface = gr.Interface(
    title="AI-Powered DRS for LBW in Local Cricket",
    fn=drs_review,
    inputs=gr.Video(label="Upload Video Clip"),
    outputs=[
        gr.Textbox(label="DRS Decision and Debug Log"),
        gr.Video(
            label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue), Pitch Point (Red), Impact Point (Yellow)"
        ),
    ],
)

if __name__ == "__main__":
    iface.launch()