import os
import uuid

import cv2
import gradio as gr
import numpy as np
from scipy.interpolate import interp1d
from ultralytics import YOLO

# Load the trained YOLOv8n model from the Space's root directory
model = YOLO("best.pt") # Assumes best.pt is in the same directory as app.py
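# Optional sanity check: Ultralytics models expose an id -> name mapping, so the
# class index assumed for the ball below can be verified once at startup.
# Uncomment to inspect it, e.g. {0: 'ball'} if the model was trained that way.
# print(model.names)
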
# Constants for LBW decision
STUMPS_WIDTH = 0.2286 # meters (width of stumps)
BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
FRAME_RATE = 30  # Fallback frame rate; the true value can be read from the video metadata
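
# Optional helper (a sketch, not wired into the pipeline below): read the real
# frame rate from the video metadata instead of assuming FRAME_RATE.
# cv2.CAP_PROP_FPS can return 0 for some containers, hence the fallback.
def get_frame_rate(video_path, fallback=FRAME_RATE):
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    return fps if fps and fps > 0 else fallback
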
def process_video(video_path):
    # Read the clip frame by frame and run ball detection on each frame
    cap = cv2.VideoCapture(video_path)
    frames = []
    ball_positions = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # Detect the ball using the trained YOLOv8n model
        results = model.predict(frame, conf=0.5)  # Adjust confidence threshold if needed
        for detection in results[0].boxes:
            if int(detection.cls) == 0:  # Class 0 is assumed to be the ball
                x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
                # Track the centre of the bounding box as the ball position
                ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
                # Draw the bounding box on the frame for visualization
                cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
        frames.append(frame)
    cap.release()
    return frames, ball_positions
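
# Example usage (hypothetical file name):
#   frames, positions = process_video("sample_delivery.mp4")
#   print(f"{len(frames)} frames read, ball detected in {len(positions)} of them")
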
def estimate_trajectory(ball_positions, frames):
    # Simplified trajectory projection: interpolate the observed positions and
    # extrapolate forward (quadratic in y loosely approximates ballistic drop)
    if len(ball_positions) < 3:
        return None, None  # Quadratic interpolation needs at least three points
    # Extract x, y coordinates
    x_coords = [pos[0] for pos in ball_positions]
    y_coords = [pos[1] for pos in ball_positions]
    # Assumes one detection per frame; missed detections will distort this timeline
    times = np.arange(len(ball_positions)) / FRAME_RATE
    # Interpolate to smooth the trajectory, extrapolating past the last detection
    try:
        fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
        fy = interp1d(times, y_coords, kind='quadratic', fill_value="extrapolate")
    except ValueError:
        return None, None
    # Project the trajectory 0.5 seconds beyond the last detection
    t_future = np.linspace(times[-1], times[-1] + 0.5, 10)
    x_future = fx(t_future)
    y_future = fy(t_future)
    return list(zip(x_future, y_future)), t_future
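
# How the extrapolation behaves on toy data (a sketch, not part of the app):
#   f = interp1d([0.0, 0.1, 0.2], [10, 20, 40], kind='quadratic', fill_value="extrapolate")
#   f(0.3)  # -> 70.0: the fitted parabola simply continues past the last sample
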
def lbw_decision(ball_positions, trajectory, frames):
    # Simplified LBW logic; a rough approximation of the real law
    if not trajectory or len(ball_positions) < 2:
        return "Not enough data", None
    # Assume the stumps sit at the bottom centre of the frame (no camera calibration)
    frame_height, frame_width = frames[0].shape[:2]
    stumps_x = frame_width / 2
    stumps_y = frame_height * 0.9  # Approximate stumps position
    stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)  # Assume the frame spans ~3 m of pitch width
    # Check the pitching point (taken to be the first detected position). Note the
    # real law only rules out pitching outside leg stump; this check is stricter.
    pitch_x, pitch_y = ball_positions[0]
    if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
        return "Not Out (Pitched outside line)", None
    # Check the impact point (taken to be the last detected position)
    impact_x, impact_y = ball_positions[-1]
    if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
        return "Not Out (Impact outside line)", None
    # Check whether the projected trajectory would go on to hit the stumps
    for x, y in trajectory:
        if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
            return "Out", trajectory
    return "Not Out (Missing stumps)", trajectory
def generate_slow_motion(frames, trajectory, output_path):
    # Write the replay at half the frame rate with every frame duplicated,
    # giving roughly a 4x slow-motion effect, with the trajectory overlaid
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    height, width = frames[0].shape[:2]
    out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / 2, (width, height))
    for frame in frames:
        if trajectory:
            # Overlay the projected trajectory as blue dots on every frame
            for x, y in trajectory:
                cv2.circle(frame, (int(x), int(y)), 5, (255, 0, 0), -1)
        out.write(frame)
        out.write(frame)  # Duplicate each frame for the slow-motion effect
    out.release()
    return output_path
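
# Note: OpenCV's 'mp4v' codec is not universally playable in browsers. If the
# Gradio player shows a blank video, re-encoding the output to H.264 (for
# example with ffmpeg) is a common workaround.
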
def drs_review(video):
    # Full pipeline: detection -> trajectory projection -> decision -> replay
    if video is None or not os.path.exists(video):
        return "Error: Video file not found", None
    frames, ball_positions = process_video(video)
    if not frames:
        return "Error: Could not read any frames from the video", None
    trajectory, _ = estimate_trajectory(ball_positions, frames)
    decision, trajectory = lbw_decision(ball_positions, trajectory, frames)
    # Generate the slow-motion replay
    output_path = f"output_{uuid.uuid4()}.mp4"
    slow_motion_path = generate_slow_motion(frames, trajectory, output_path)
    return decision, slow_motion_path
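
# Local smoke test (hypothetical file name), bypassing the Gradio UI:
#   decision, replay = drs_review("sample_delivery.mp4")
#   print(decision, replay)
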
# Gradio interface
iface = gr.Interface(
fn=drs_review,
inputs=gr.Video(label="Upload Video Clip"),
outputs=[
gr.Textbox(label="DRS Decision"),
gr.Video(label="Slow-Motion Replay with Ball Detection and Trajectory")
],
title="AI-Powered DRS for LBW in Local Cricket",
description="Upload a video clip of a cricket delivery to get an LBW decision and slow-motion replay showing ball detection and trajectory."
)
if __name__ == "__main__":
iface.launch()