import gradio as gr
import cv2
import numpy as np
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so plotting works on headless servers
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from ultralytics import YOLO


# Load the custom-trained YOLO ball-detection weights.
model = YOLO("best.pt")


class CentroidTracker:
    """Greedy nearest-centroid tracker that keeps object IDs stable across frames."""

    def __init__(self, max_disappeared=50):
        self.next_object_id = 0
        self.objects = {}        # object_id -> centroid (x, y)
        self.disappeared = {}    # object_id -> consecutive frames without a match
        self.max_disappeared = max_disappeared

    def register(self, centroid):
        self.objects[self.next_object_id] = centroid
        self.disappeared[self.next_object_id] = 0
        self.next_object_id += 1

    def deregister(self, object_id):
        del self.objects[object_id]
        del self.disappeared[object_id]

    def update(self, rects):
        # No detections this frame: age every tracked object and drop stale ones.
        if len(rects) == 0:
            for object_id in list(self.disappeared.keys()):
                self.disappeared[object_id] += 1
                if self.disappeared[object_id] > self.max_disappeared:
                    self.deregister(object_id)
            return self.objects

        # Convert bounding boxes to centroids.
        input_centroids = np.zeros((len(rects), 2), dtype="int")
        for i, (x1, y1, x2, y2) in enumerate(rects):
            input_centroids[i] = (int((x1 + x2) / 2.0), int((y1 + y2) / 2.0))

        if len(self.objects) == 0:
            for i in range(len(input_centroids)):
                self.register(input_centroids[i])
        else:
            object_ids = list(self.objects.keys())
            object_centroids = np.array(list(self.objects.values()))

            # Pairwise Euclidean distances: rows are new detections, columns are
            # existing objects. Match greedily, smallest distance first.
            D = np.sqrt(((input_centroids[:, None] - object_centroids) ** 2).sum(axis=2))
            rows = D.min(axis=1).argsort()
            cols = D.argmin(axis=1)[rows]

            used_rows = set()
            used_cols = set()
            for row, col in zip(rows, cols):
                if row in used_rows or col in used_cols:
                    continue
                object_id = object_ids[col]
                self.objects[object_id] = input_centroids[row]
                self.disappeared[object_id] = 0
                used_rows.add(row)
                used_cols.add(col)

            unused_rows = set(range(D.shape[0])).difference(used_rows)
            unused_cols = set(range(D.shape[1])).difference(used_cols)

            if D.shape[0] >= D.shape[1]:
                # More detections than tracked objects: register the new ones.
                for row in unused_rows:
                    self.register(input_centroids[row])
            else:
                # Fewer detections than objects: age the unmatched objects.
                for col in unused_cols:
                    object_id = object_ids[col]
                    self.disappeared[object_id] += 1
                    if self.disappeared[object_id] > self.max_disappeared:
                        self.deregister(object_id)

        return self.objects
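
# A minimal sketch of driving the tracker directly, with hypothetical boxes for
# illustration only:
#   tracker = CentroidTracker(max_disappeared=10)
#   tracker.update([(100, 120, 140, 160)])  # frame 1: one detection, ID 0 registered
#   tracker.update([(104, 126, 144, 166)])  # frame 2: matched to ID 0 by distance
#   tracker.update([])                      # frame 3: no detection, ID 0 ages by one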


def detect_and_track_ball(video_path, conf_threshold=0.5, iou_threshold=0.5):
    """
    Detect and track the ball in a video, generate a pitch map, and predict
    the LBW outcome.

    Args:
        video_path: Path to the uploaded video.
        conf_threshold: Confidence threshold for detection.
        iou_threshold: IoU threshold for non-max suppression.

    Returns:
        Tuple of (annotated video path, pitch map image path, LBW decision).
    """
    tracker = CentroidTracker(max_disappeared=10)
    cap = cv2.VideoCapture(video_path)

    # Read frame geometry and frame rate up front: cap.get() returns 0 after
    # cap.release(), so these must not be queried later.
    frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0

    output_path = "output_video.mp4"
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (frame_w, frame_h))

    centroids = []
    pitch_points = []

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Ultralytics expects RGB; OpenCV reads BGR.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = model.predict(frame_rgb, conf=conf_threshold, iou=iou_threshold, verbose=False)

        # Draw every detection and collect its box for the tracker.
        rects = []
        for box in results[0].boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            conf = float(box.conf[0])
            label = f"Ball: {conf:.2f}"
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(frame, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
            rects.append((x1, y1, x2, y2))

        # Update the tracker and record the ball's centroid for the pitch map.
        objects = tracker.update(rects)
        for object_id, centroid in objects.items():
            cv2.circle(frame, (int(centroid[0]), int(centroid[1])), 5, (0, 0, 255), -1)
            centroids.append(centroid)

        out.write(frame)

    cap.release()
    out.release()

    # Build the pitch map. The pitch is drawn 20.12 m (22 yards) long, with the
    # stumps at each end and a dashed centre line.
    pitch_map_path = "pitch_map.png"
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.set_xlim(0, 22)
    ax.set_ylim(-1.5, 1.5)
    ax.set_xlabel("Length (m)")
    ax.set_ylabel("Width (m)")
    ax.set_title("Pitch Map with Ball Trajectory")

    ax.plot([20.12, 20.12], [-0.135, 0.135], 'k-', lw=5)
    ax.plot([0, 0], [-0.135, 0.135], 'k-', lw=5)
    ax.plot([0, 20.12], [0, 0], 'k--')

    if centroids:
        # Linearly map image coordinates to pitch coordinates: frame rows map to
        # length down the pitch, frame columns to width across it. This assumes
        # the camera looks straight down the pitch.
        x_coords = [20.12 - (c[1] / frame_h) * 20.12 for c in centroids]
        y_coords = [(c[0] / frame_w) * 2.7 - 1.35 for c in centroids]
        ax.plot(x_coords, y_coords, 'ro-', label="Ball Trajectory")
        pitch_points = list(zip(x_coords, y_coords))
        ax.legend()

    plt.savefig(pitch_map_path)
    plt.close(fig)
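
    # Worked example of the frame-to-pitch mapping above (illustrative numbers):
    # with a 720-row frame, a centroid on row 360 maps to
    # 20.12 - (360 / 720) * 20.12 = 10.06 m, i.e. halfway down the pitch.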

    # Simplified LBW heuristic in pitch coordinates (metres): "pitching" checks
    # whether the first half of the trajectory lands in line with the stumps,
    # "impact" whether the second half reaches the batter's end, and "wickets"
    # whether the extrapolated path would go on to hit the stumps.
    lbw_decision = "Not Out"
    if pitch_points:
        half = len(pitch_points) // 2
        pitching = any(0 <= x <= 20.12 and -0.135 <= y <= 0.135
                       for x, y in pitch_points[:half])
        impact = any(18 <= x <= 20.12 for x, y in pitch_points[half:])
        wickets = False  # default when there are too few points to extrapolate

        # Extrapolate the trajectory beyond the last observation: t = 1 is the
        # final tracked point, t = 1.5 projects half the clip length ahead.
        # A quadratic fit needs at least three points.
        if len(x_coords) > 2:
            t = np.linspace(0, 1, len(x_coords))
            f_x = interp1d(t, x_coords, kind='quadratic', fill_value="extrapolate")
            f_y = interp1d(t, y_coords, kind='quadratic', fill_value="extrapolate")
            x_future = float(f_x(1.5))
            y_future = float(f_y(1.5))
            wickets = (18 <= x_future <= 20.12) and (-0.135 <= y_future <= 0.135)

        if pitching and impact and wickets:
            lbw_decision = "Out"
        elif pitching and impact:
            lbw_decision = "Umpire's Call"

    return output_path, pitch_map_path, lbw_decision
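
# A quick way to exercise the pipeline without the UI (the clip name
# "sample_delivery.mp4" is illustrative, not shipped with the app):
#   video_out, pitch_map, decision = detect_and_track_ball(
#       "sample_delivery.mp4", conf_threshold=0.4, iou_threshold=0.5)
#   print(decision)  # "Out", "Not Out", or "Umpire's Call"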


# Gradio UI: one video input, two detection thresholds, and three outputs
# (annotated video, pitch map, LBW decision).
with gr.Blocks() as demo:
    gr.Markdown("# DRS Review System for Cricket")
    gr.Markdown("Upload a cricket video to analyze ball tracking, pitch mapping, "
                "and LBW review. Adjust thresholds for detection accuracy.")

    video_input = gr.Video(label="Upload Cricket Video")
    conf_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5,
                            label="Confidence Threshold")
    iou_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5,
                           label="IoU Threshold")
    output_video = gr.Video(label="Annotated Video with Ball Tracking")
    output_image = gr.Image(label="Pitch Map")
    output_text = gr.Textbox(label="LBW Decision")
    submit_button = gr.Button("Analyze DRS")

    submit_button.click(
        fn=detect_and_track_ball,
        inputs=[video_input, conf_slider, iou_slider],
        outputs=[output_video, output_image, output_text],
    )

demo.launch()
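
# For a temporary public link while testing remotely, Gradio supports:
#   demo.launch(share=True)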