import gradio as gr
import torch
from ultralytics import YOLO
import cv2
import numpy as np
from PIL import Image
import os
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from scipy.interpolate import interp1d

# Load the custom-trained YOLO weights once at import time; every request reuses it.
model = YOLO("best.pt")

# Approximate pitch geometry (metres) shared by the pitch map and the LBW model.
PITCH_LENGTH = 20.12  # distance between the two sets of plotted stumps
PITCH_WIDTH = 2.7  # frame width is scaled onto [-PITCH_WIDTH / 2, PITCH_WIDTH / 2]
STUMPS_HALF_WIDTH = 0.135  # half-width of the plotted stumps / "in line" band
IMPACT_ZONE_START = 18.0  # x from which impact and projected wicket hits count
DEFAULT_FPS = 30.0  # used when the container reports no usable frame rate


class CentroidTracker:
    """Track objects across frames by greedily matching box centroids.

    Each detection is reduced to the centre of its bounding box. Centres in a
    new frame are matched to known objects by smallest Euclidean distance; an
    object left unmatched for more than ``max_disappeared`` consecutive
    updates is dropped.
    """

    def __init__(self, max_disappeared=50):
        self.next_object_id = 0
        self.objects = {}  # object_id -> latest centroid, np.ndarray of 2 ints
        self.disappeared = {}  # object_id -> consecutive updates without a match
        self.max_disappeared = max_disappeared

    def register(self, centroid):
        """Start tracking a new object at ``centroid`` under the next free id."""
        self.objects[self.next_object_id] = centroid
        self.disappeared[self.next_object_id] = 0
        self.next_object_id += 1

    def deregister(self, object_id):
        """Stop tracking ``object_id``."""
        del self.objects[object_id]
        del self.disappeared[object_id]

    def _mark_disappeared(self, object_id):
        """Count one missed update for ``object_id``; drop it once missing too long."""
        self.disappeared[object_id] += 1
        if self.disappeared[object_id] > self.max_disappeared:
            self.deregister(object_id)

    def update(self, rects):
        """Update tracks with this frame's boxes.

        Args:
            rects: sequence of ``(x1, y1, x2, y2)`` boxes detected in the frame.

        Returns:
            The live ``{object_id: centroid}`` mapping (the tracker's own dict).
        """
        if len(rects) == 0:
            for object_id in list(self.disappeared):
                self._mark_disappeared(object_id)
            return self.objects

        input_centroids = np.zeros((len(rects), 2), dtype="int")
        for i, (x1, y1, x2, y2) in enumerate(rects):
            input_centroids[i] = (int((x1 + x2) / 2.0), int((y1 + y2) / 2.0))

        if not self.objects:
            for centroid in input_centroids:
                self.register(centroid)
            return self.objects

        object_ids = list(self.objects.keys())
        object_centroids = np.array(list(self.objects.values()))
        # D[r, c] = distance from new detection r to known object c.
        D = np.sqrt(((input_centroids[:, None] - object_centroids[None, :]) ** 2).sum(axis=2))

        # Visit detections closest-first so the best matches claim objects first.
        rows = D.min(axis=1).argsort()
        cols = D.argmin(axis=1)[rows]

        used_rows = set()
        used_cols = set()
        for row, col in zip(rows, cols):
            if row in used_rows or col in used_cols:
                continue
            object_id = object_ids[col]
            self.objects[object_id] = input_centroids[row]
            self.disappeared[object_id] = 0
            used_rows.add(row)
            used_cols.add(col)

        # Handle BOTH kinds of leftovers: objects that found no detection age,
        # and detections that found no object become new tracks. (Handling only
        # one side depending on D.shape silently drops detections or never ages
        # objects when two detections compete for the same object.)
        for col in sorted(set(range(D.shape[1])) - used_cols):
            self._mark_disappeared(object_ids[col])
        for row in sorted(set(range(D.shape[0])) - used_rows):
            self.register(input_centroids[row])

        return self.objects


def _annotate_and_track(video_path, conf_threshold, iou_threshold, output_path):
    """Detect and track the ball frame by frame, writing an annotated video.

    Returns:
        ``(centroids, frame_width, frame_height)`` where ``centroids`` lists the
        pixel ``(x, y)`` of every tracked object observed in a frame, in order.

    Raises:
        gr.Error: if the video cannot be opened.
    """
    tracker = CentroidTracker(max_disappeared=10)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise gr.Error(f"Could not open video: {video_path}")

    out = None
    centroids = []
    try:
        # Read geometry/fps while the capture is open; after release() these read as 0.
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # also catches NaN
            fps = DEFAULT_FPS

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Ultralytics treats numpy images as BGR (OpenCV order), so the frame
            # is passed exactly as cv2 decoded it.
            results = model.predict(frame, conf=conf_threshold, iou=iou_threshold, verbose=False)

            rects = []
            for box in results[0].boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0])
                conf = float(box.conf[0])
                label = f"Ball: {conf:.2f}"
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
                rects.append((x1, y1, x2, y2))

            objects = tracker.update(rects)
            for object_id, centroid in objects.items():
                # OpenCV drawing functions want a tuple of Python ints.
                center = (int(centroid[0]), int(centroid[1]))
                cv2.circle(frame, center, 5, (0, 0, 255), -1)
                # Only record positions actually observed this frame; tracks kept
                # alive through missed detections would repeat stale points.
                if tracker.disappeared[object_id] == 0:
                    centroids.append(center)

            out.write(frame)
    finally:
        cap.release()
        if out is not None:
            out.release()

    return centroids, frame_width, frame_height


def _to_pitch_coords(centroids, frame_width, frame_height):
    """Linearly map pixel centroids onto approximate pitch coordinates (metres).

    Frame top (y=0) maps to x=PITCH_LENGTH and frame bottom to x=0; frame width
    maps onto [-PITCH_WIDTH / 2, PITCH_WIDTH / 2]. This is a simplified scaling,
    not a camera calibration.
    """
    x_coords = [PITCH_LENGTH - (cy / frame_height) * PITCH_LENGTH for _, cy in centroids]
    y_coords = [(cx / frame_width) * PITCH_WIDTH - PITCH_WIDTH / 2 for cx, _ in centroids]
    return x_coords, y_coords


def _draw_pitch_map(x_coords, y_coords, pitch_map_path):
    """Render the pitch outline and ball trajectory to ``pitch_map_path``.

    Uses the object-oriented ``Figure`` API rather than pyplot so concurrent
    requests never share pyplot's global current-figure state.
    """
    fig = Figure(figsize=(8, 4))
    ax = fig.subplots()
    ax.set_xlim(0, 22)  # Pitch length in meters (approx)
    ax.set_ylim(-1.5, 1.5)  # Pitch width (approx)
    ax.set_xlabel("Length (m)")
    ax.set_ylabel("Width (m)")
    ax.set_title("Pitch Map with Ball Trajectory")

    # Stumps at both ends. NOTE(review): the LBW checks evaluate impact and
    # wicket hits near x=PITCH_LENGTH, i.e. they treat this end as the batter's
    # end — confirm against the camera orientation.
    ax.plot([PITCH_LENGTH, PITCH_LENGTH], [-STUMPS_HALF_WIDTH, STUMPS_HALF_WIDTH], 'k-', lw=5)
    ax.plot([0, 0], [-STUMPS_HALF_WIDTH, STUMPS_HALF_WIDTH], 'k-', lw=5)
    ax.plot([0, PITCH_LENGTH], [0, 0], 'k--')  # Pitch center line

    if x_coords:
        ax.plot(x_coords, y_coords, 'ro-', label="Ball Trajectory")
        ax.legend()
    fig.savefig(pitch_map_path)


def _lbw_decision(x_coords, y_coords):
    """Return "Out", "Umpire's Call" or "Not Out" for a mapped trajectory.

    Simplified model: the first half of the points must include one in line
    with the stumps ("pitching"), the second half one inside the impact zone
    ("impact"), and a quadratic extrapolation past the last point must land on
    the stumps ("wickets").
    """
    if not x_coords:
        return "Not Out"

    half = len(x_coords) // 2
    pitching = any(
        0 <= x <= PITCH_LENGTH and -STUMPS_HALF_WIDTH <= y <= STUMPS_HALF_WIDTH
        for x, y in zip(x_coords[:half], y_coords[:half])
    )
    impact = any(IMPACT_ZONE_START <= x <= PITCH_LENGTH for x in x_coords[half:])

    # Without enough points to fit a trajectory there is no predicted wicket hit.
    # (Previously `wickets` was left unbound here and the check below raised.)
    wickets = False
    if len(x_coords) > 2:  # quadratic interpolation needs at least 3 points
        t = np.linspace(0, 1, len(x_coords))
        f_x = interp1d(t, x_coords, kind='quadratic', fill_value="extrapolate")
        f_y = interp1d(t, y_coords, kind='quadratic', fill_value="extrapolate")
        t_future = np.array([1.5])  # Predict beyond impact
        x_future = f_x(t_future)[0]
        y_future = f_y(t_future)[0]
        wickets = (IMPACT_ZONE_START <= x_future <= PITCH_LENGTH) and (
            -STUMPS_HALF_WIDTH <= y_future <= STUMPS_HALF_WIDTH
        )

    if pitching and impact and wickets:
        return "Out"
    if pitching and impact:
        return "Umpire's Call"  # Marginal case
    return "Not Out"


def detect_and_track_ball(video_path, conf_threshold=0.5, iou_threshold=0.5):
    """
    Detect and track ball in video, generate pitch map, and predict LBW outcome.

    Args:
        video_path: Path to uploaded video
        conf_threshold: Confidence threshold for detection
        iou_threshold: IoU threshold for non-max suppression

    Returns:
        Tuple of (annotated video path, pitch map image path, LBW decision)

    Raises:
        gr.Error: if no video was supplied or it cannot be opened.
    """
    if not video_path:
        raise gr.Error("Please upload a video to analyze.")

    output_path = "output_video.mp4"
    pitch_map_path = "pitch_map.png"

    centroids, frame_width, frame_height = _annotate_and_track(
        video_path, conf_threshold, iou_threshold, output_path
    )
    x_coords, y_coords = _to_pitch_coords(centroids, frame_width, frame_height)
    _draw_pitch_map(x_coords, y_coords, pitch_map_path)
    lbw_decision = _lbw_decision(x_coords, y_coords)

    return output_path, pitch_map_path, lbw_decision


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# DRS Review System for Cricket")
    gr.Markdown(
        "Upload a cricket video to analyze ball tracking, pitch mapping, and LBW review. "
        "Adjust thresholds for detection accuracy."
    )

    video_input = gr.Video(label="Upload Cricket Video")
    conf_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Confidence Threshold")
    iou_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="IoU Threshold")

    output_video = gr.Video(label="Annotated Video with Ball Tracking")
    output_image = gr.Image(label="Pitch Map")
    output_text = gr.Textbox(label="LBW Decision")

    submit_button = gr.Button("Analyze DRS")
    submit_button.click(
        fn=detect_and_track_ball,
        inputs=[video_input, conf_slider, iou_slider],
        outputs=[output_video, output_image, output_text],
    )

demo.launch()