DRS_AI / app.py
AjaykumarPilla's picture
Update app.py
42d2b87 verified
raw
history blame
7.67 kB
import gradio as gr
import torch
from ultralytics import YOLO
import cv2
import numpy as np
from PIL import Image
import os
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
# Load the detection weights once at import time so every request reuses the
# same model instance; detect_and_track_ball labels every box it returns "Ball".
# NOTE(review): the original comment said "YOLOv5", but ultralytics.YOLO loads
# several YOLO generations — confirm which architecture "best.pt" was trained on.
model = YOLO("best.pt")
class CentroidTracker:
    """Assign stable integer IDs to detections by nearest-centroid matching.

    Each call to ``update`` takes the bounding boxes detected in one frame,
    greedily matches their centroids to the currently tracked objects by
    Euclidean distance, registers unmatched detections as new objects, and
    counts consecutive frames in which a tracked object went unmatched.

    Attributes:
        next_object_id: ID the next registered object will receive.
        objects: Mapping object ID -> centroid (NumPy int array ``[x, y]``).
        disappeared: Mapping object ID -> consecutive frames without a match.
        max_disappeared: An object is dropped once its ``disappeared`` count
            exceeds this value.
    """

    def __init__(self, max_disappeared=50):
        self.next_object_id = 0
        self.objects = {}
        self.disappeared = {}
        self.max_disappeared = max_disappeared

    def register(self, centroid):
        """Start tracking ``centroid`` under a fresh object ID."""
        self.objects[self.next_object_id] = centroid
        self.disappeared[self.next_object_id] = 0
        self.next_object_id += 1

    def deregister(self, object_id):
        """Stop tracking ``object_id``."""
        del self.objects[object_id]
        del self.disappeared[object_id]

    def _mark_missing(self, object_id):
        """Record one unmatched frame for ``object_id``; drop it past the limit."""
        self.disappeared[object_id] += 1
        if self.disappeared[object_id] > self.max_disappeared:
            self.deregister(object_id)

    def update(self, rects):
        """Update the tracks with one frame's bounding boxes.

        Args:
            rects: Sequence of ``(x1, y1, x2, y2)`` boxes detected in the frame.

        Returns:
            The ``objects`` dict (object ID -> centroid) after the update.
            Objects that went unmatched but have not yet exceeded
            ``max_disappeared`` are still included at their last centroid.
        """
        if len(rects) == 0:
            for object_id in list(self.disappeared):
                self._mark_missing(object_id)
            return self.objects

        input_centroids = np.zeros((len(rects), 2), dtype="int")
        for i, (x1, y1, x2, y2) in enumerate(rects):
            input_centroids[i] = (int((x1 + x2) / 2.0), int((y1 + y2) / 2.0))

        if not self.objects:
            for centroid in input_centroids:
                self.register(centroid)
            return self.objects

        object_ids = list(self.objects.keys())
        object_centroids = np.asarray(list(self.objects.values()))
        # D[i, j] = distance from detection i to tracked object j.
        D = np.sqrt(
            ((input_centroids[:, None] - object_centroids[None, :]) ** 2).sum(axis=2)
        )

        # Greedy matching: visit detections closest-first; each detection may
        # only claim its nearest object, and each object is claimed at most once.
        rows = D.min(axis=1).argsort()
        cols = D.argmin(axis=1)[rows]
        used_rows = set()
        used_cols = set()
        for row, col in zip(rows.tolist(), cols.tolist()):
            if row in used_rows or col in used_cols:
                continue
            object_id = object_ids[col]
            self.objects[object_id] = input_centroids[row]
            self.disappeared[object_id] = 0
            used_rows.add(row)
            used_cols.add(col)

        # Unmatched objects AND unmatched detections can coexist in one frame
        # (e.g. two detections whose nearest object is the same), so both sets
        # must be handled regardless of which side is larger.
        for col in range(D.shape[1]):
            if col not in used_cols:
                self._mark_missing(object_ids[col])
        for row in range(D.shape[0]):
            if row not in used_rows:
                self.register(input_centroids[row])
        return self.objects
# Approximate pitch geometry (metres) shared by the pitch map and the LBW model.
_STUMPS_GAP_M = 20.12  # distance between the two sets of stumps
_STUMP_HALF_WIDTH_M = 0.135  # half-width used to draw stumps and test "in line"
_MAPPED_WIDTH_M = 2.7  # width the frame's horizontal axis is stretched onto
_IMPACT_ZONE_START_M = 18.0  # pitch length from which a point counts as "impact"


def _track_video(video_path, output_path, conf_threshold, iou_threshold):
    """Run detection + tracking on every frame and write the annotated video.

    Returns:
        ``(centroids, frame_width, frame_height)`` where ``centroids`` is a list
        of ``(x, y)`` pixel tuples, one per tracked object per frame.
    """
    tracker = CentroidTracker(max_disappeared=10)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise gr.Error(f"Could not open video: {video_path}")
    # Read the frame geometry while the capture is still open: after release()
    # cap.get() returns 0, which broke the pitch-map scaling.
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    src_fps = cap.get(cv2.CAP_PROP_FPS)
    # Keep the source playback speed; fall back to 30 when FPS is unknown (0/NaN).
    fps = src_fps if src_fps > 0 else 30.0
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    centroids = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Ultralytics treats numpy inputs as BGR (OpenCV order), so the
            # frame is passed as read rather than converted to RGB.
            results = model.predict(frame, conf=conf_threshold, iou=iou_threshold)
            rects = []
            for box in results[0].boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0])
                label = f"Ball: {float(box.conf[0]):.2f}"
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
                rects.append((x1, y1, x2, y2))
            for centroid in tracker.update(rects).values():
                # OpenCV drawing calls need a tuple of Python ints, not a NumPy row.
                point = (int(centroid[0]), int(centroid[1]))
                cv2.circle(frame, point, 5, (0, 0, 255), -1)
                centroids.append(point)
            out.write(frame)
    finally:
        cap.release()
        out.release()
    return centroids, width, height


def _draw_pitch_map(centroids, frame_width, frame_height, pitch_map_path):
    """Save a top-down pitch diagram with stumps and the mapped ball path.

    Pixel coordinates are mapped linearly: frame y -> pitch length (top of the
    frame is the far end), frame x -> pitch width. This is a crude scaling, not
    a camera calibration.

    Returns:
        List of ``(length_m, width_m)`` points; empty when nothing was tracked.
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.set_xlim(0, 22)  # Pitch length in meters (approx)
    ax.set_ylim(-1.5, 1.5)  # Pitch width (approx)
    ax.set_xlabel("Length (m)")
    ax.set_ylabel("Width (m)")
    ax.set_title("Pitch Map with Ball Trajectory")
    stumps_y = [-_STUMP_HALF_WIDTH_M, _STUMP_HALF_WIDTH_M]
    ax.plot([_STUMPS_GAP_M, _STUMPS_GAP_M], stumps_y, 'k-', lw=5)  # Stumps at bowling end
    ax.plot([0, 0], stumps_y, 'k-', lw=5)  # Stumps at batting end
    ax.plot([0, _STUMPS_GAP_M], [0, 0], 'k--')  # Pitch center line
    pitch_points = []
    if centroids and frame_width > 0 and frame_height > 0:
        x_coords = [_STUMPS_GAP_M - (cy / frame_height) * _STUMPS_GAP_M for _, cy in centroids]
        y_coords = [(cx / frame_width) * _MAPPED_WIDTH_M - _MAPPED_WIDTH_M / 2 for cx, _ in centroids]
        ax.plot(x_coords, y_coords, 'ro-', label="Ball Trajectory")
        ax.legend()
        pitch_points = list(zip(x_coords, y_coords))
    fig.savefig(pitch_map_path)
    plt.close(fig)
    return pitch_points


def _lbw_decision(pitch_points):
    """Apply the simplified LBW heuristic to the mapped pitch points.

    - pitching: some point in the first half lies within the stump half-width.
    - impact: some point in the second half lies in the impact zone.
    - wickets: a quadratic-spline extrapolation (t = 1.5 on a 0..1 timeline)
      lands in the impact zone within the stump half-width. It needs at least
      3 points; with fewer it is treated as not hitting.

    Returns:
        "Out", "Umpire's Call" (pitching and impact without wickets) or "Not Out".
    """
    if not pitch_points:
        return "Not Out"
    half = len(pitch_points) // 2
    pitching = any(
        0 <= x <= _STUMPS_GAP_M and -_STUMP_HALF_WIDTH_M <= y <= _STUMP_HALF_WIDTH_M
        for x, y in pitch_points[:half]
    )
    impact = any(_IMPACT_ZONE_START_M <= x <= _STUMPS_GAP_M for x, _ in pitch_points[half:])
    # Previously unbound for <= 2 points, which raised NameError below.
    wickets = False
    if len(pitch_points) > 2:  # interp1d(kind='quadratic') needs >= 3 samples
        x_coords = [x for x, _ in pitch_points]
        y_coords = [y for _, y in pitch_points]
        t = np.linspace(0, 1, len(pitch_points))
        f_x = interp1d(t, x_coords, kind='quadratic', fill_value="extrapolate")
        f_y = interp1d(t, y_coords, kind='quadratic', fill_value="extrapolate")
        x_future = float(f_x(1.5))  # Predict beyond impact
        y_future = float(f_y(1.5))
        wickets = (_IMPACT_ZONE_START_M <= x_future <= _STUMPS_GAP_M) and (
            -_STUMP_HALF_WIDTH_M <= y_future <= _STUMP_HALF_WIDTH_M
        )
    if pitching and impact and wickets:
        return "Out"
    if pitching and impact:
        return "Umpire's Call"  # Marginal case
    return "Not Out"


def detect_and_track_ball(video_path, conf_threshold=0.5, iou_threshold=0.5):
    """
    Detect and track ball in video, generate pitch map, and predict LBW outcome.
    Args:
        video_path: Path to uploaded video
        conf_threshold: Confidence threshold for detection
        iou_threshold: IoU threshold for non-max suppression
    Returns:
        Tuple of (annotated video path, pitch map image path, LBW decision)
    Raises:
        gr.Error: if no video was supplied or it cannot be opened.
    """
    if not video_path:
        raise gr.Error("Please upload a cricket video.")
    output_path = "output_video.mp4"
    pitch_map_path = "pitch_map.png"
    centroids, width, height = _track_video(video_path, output_path, conf_threshold, iou_threshold)
    pitch_points = _draw_pitch_map(centroids, width, height, pitch_map_path)
    return output_path, pitch_map_path, _lbw_decision(pitch_points)
# Gradio interface: upload a video, tune the detector thresholds, then show the
# annotated video, the pitch map and the LBW decision from detect_and_track_ball.
with gr.Blocks() as demo:
    gr.Markdown("# DRS Review System for Cricket")
    gr.Markdown("Upload a cricket video to analyze ball tracking, pitch mapping, and LBW review. Adjust thresholds for detection accuracy.")
    video_input = gr.Video(label="Upload Cricket Video")
    conf_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Confidence Threshold")
    iou_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="IoU Threshold")
    output_video = gr.Video(label="Annotated Video with Ball Tracking")
    output_image = gr.Image(label="Pitch Map")
    output_text = gr.Textbox(label="LBW Decision")
    submit_button = gr.Button("Analyze DRS")
    # Input/output order must match detect_and_track_ball's parameters and
    # its returned tuple.
    submit_button.click(
        fn=detect_and_track_ball,
        inputs=[video_input, conf_slider, iou_slider],
        outputs=[output_video, output_image, output_text],
    )

if __name__ == "__main__":
    # Start the server only when run as a script (`python app.py`), so merely
    # importing this module does not block on a running server.
    demo.launch()