File size: 7,325 Bytes
41c03cf ba9faee 6110fb8 ba9faee 6110fb8 ba9faee 6110fb8 ba9faee 41c03cf ba9faee 6110fb8 ba9faee 6110fb8 ba9faee 6110fb8 ba9faee 6110fb8 ba9faee 6110fb8 ba9faee 41c03cf ba9faee 6110fb8 ba9faee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
import cv2
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
import gradio as gr
import os
from scipy.interpolate import interp1d
from scipy.optimize import curve_fit
# Load YOLOv5 model from yolov5 repository
from yolov5.models.experimental import attempt_load
from yolov5.utils.general import non_max_suppression, xywh2xyxy
# Cricket pitch geometry used to convert between pixels and metres.
# All values are in metres.
PITCH_LENGTH: float = 20.12  # stump-to-stump distance (~22 yards)
PITCH_WIDTH: float = 3.05  # width of the pitch strip (~10 ft)
STUMP_HEIGHT: float = 0.71  # height of the stumps above the ground
STUMP_WIDTH: float = 0.2286  # overall width of the wicket (9 in)
# Load the trained ball detector once, at import time, and switch it to
# inference mode. Inference runs on the GPU when CUDA is available,
# otherwise on the CPU.
# NOTE(review): some yolov5 releases take `device=` instead of
# `map_location=` in attempt_load -- confirm against the installed version.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = attempt_load("best.pt", map_location=device).eval()
def _pad_to_stride(image, stride, fill=114):
    """Pad an H x W x C image on the bottom/right so both sides are multiples of *stride*.

    Padding only on the bottom and right leaves the top-left origin in place,
    so detector coordinates remain valid pixel coordinates of the original
    frame. `fill` is the constant value written into the padded area.
    """
    height, width = image.shape[:2]
    pad_bottom = (-height) % stride
    pad_right = (-width) % stride
    if pad_bottom == 0 and pad_right == 0:
        return image
    return np.pad(
        image,
        ((0, pad_bottom), (0, pad_right), (0, 0)),
        mode="constant",
        constant_values=fill,
    )


def _detect_ball(frame, stride):
    """Return the (x, y) box centre of the most confident detection in a BGR frame, or None."""
    rgb = _pad_to_stride(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), stride)
    img = torch.from_numpy(rgb).to(device).float() / 255.0
    img = img.permute(2, 0, 1).unsqueeze(0)  # HWC -> [1, 3, H, W]
    with torch.no_grad():
        pred = model(img)[0]
    # non_max_suppression returns one tensor per image, with rows already in
    # [x1, y1, x2, y2, conf, cls] form -- no further box conversion is needed.
    det = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)[0]
    if det is None or len(det) == 0:
        return None
    best = det[det[:, 4].argmax()]
    x_center = float((best[0] + best[2]) / 2)
    y_center = float((best[1] + best[3]) / 2)
    return x_center, y_center


def process_video(video_path):
    """Run the YOLOv5 ball detector over every frame of a video.

    At most one detection is kept per frame (the highest-confidence one), so
    downstream trajectory fitting and speed estimation see a single ball
    position per frame.

    Args:
        video_path: Path to a video file readable by ``cv2.VideoCapture``.

    Returns:
        A 6-tuple ``(positions, frame_numbers, bounce_point, frame_rate,
        frame_width, frame_height)``:

        * ``positions``: list of ``(x, y)`` box centres in original-frame
          pixel coordinates.
        * ``frame_numbers``: 0-based frame index of each entry in
          ``positions`` (same length, increasing).
        * ``bounce_point``: the entry of ``positions`` with the largest ``y``
          (the lowest point in the image), or ``None`` if nothing was detected.
        * ``frame_rate``, ``frame_width``, ``frame_height``: as reported by
          OpenCV for the file.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        frame_rate = cap.get(cv2.CAP_PROP_FPS)
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # YOLOv5 downsamples by up to its largest stride; inputs whose sides
        # are not multiples of it make the feature-map concatenations fail.
        stride = int(torch.as_tensor(getattr(model, "stride", 32)).max())

        positions = []
        frame_numbers = []
        bounce_point = None
        frame_num = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break  # end of stream, or the file could not be opened
            detection = _detect_ball(frame, stride)
            if detection is not None:
                positions.append(detection)
                frame_numbers.append(frame_num)
                # Bounce heuristic: lowest point in the image (largest y).
                if bounce_point is None or detection[1] > bounce_point[1]:
                    bounce_point = detection
            frame_num += 1
    finally:
        cap.release()
    return positions, frame_numbers, bounce_point, frame_rate, frame_width, frame_height
def poly_func(x, a, b, c):
    """Evaluate the quadratic ``a*x**2 + b*x + c``.

    This is the model function handed to ``curve_fit``. *x* may be a scalar
    or a NumPy array. The fitted coefficients come back in the order
    ``(a, b, c)``.
    """
    quadratic_term = a * x**2
    linear_term = b * x
    return quadratic_term + linear_term + c
def predict_trajectory(positions, frame_numbers, frame_width, frame_height,
                       extrapolate_frames=10, hit_tolerance_px=50):
    """Fit quadratic x(t) and y(t) ball paths, extrapolate them, and decide LBW.

    Args:
        positions: Sequence of ``(x, y)`` ball centres in pixels.
        frame_numbers: Frame index for each entry of *positions*.
        frame_width: Frame width in pixels.
        frame_height: Frame height in pixels.
        extrapolate_frames: How many frames beyond the last detection the
            fitted path is extended.
        hit_tolerance_px: Maximum vertical pixel distance from the stump line
            for a point to count as reaching the stumps.

    Returns:
        ``(trajectory, decision)``. *trajectory* is a list of 100
        ``(frame, x, y)`` samples along the fitted path, and *decision* is
        ``"OUT"`` or ``"NOT OUT"``. On failure returns ``(None, message)``.
    """
    if len(positions) < 3:
        return None, "Insufficient detections for trajectory prediction"
    frames = np.asarray(frame_numbers, dtype=float)
    coords = np.asarray(positions, dtype=float)
    x_coords, y_coords = coords[:, 0], coords[:, 1]

    # Fit an independent quadratic in time to each pixel coordinate.
    try:
        popt_x, _ = curve_fit(poly_func, frames, x_coords)
        popt_y, _ = curve_fit(poly_func, frames, y_coords)
    except (RuntimeError, ValueError, TypeError):
        # RuntimeError: no convergence; ValueError: NaN/inf in the data;
        # TypeError: fewer data points than parameters.
        return None, "Failed to fit trajectory"

    # Extrapolate a little past the last detection, towards the stumps.
    future_frames = np.linspace(frames.min(), frames.max() + extrapolate_frames, 100)
    x_pred = poly_func(future_frames, *popt_x)
    y_pred = poly_func(future_frames, *popt_y)

    # NOTE(review): assumes the stumps sit at the bottom-centre of the frame
    # and that the frame width spans the pitch width -- confirm for the camera.
    stump_x = frame_width / 2
    stump_y = frame_height
    # Half the wicket width in pixels: the ball must be within this distance
    # of the wicket's centre line, not within a full width of it.
    half_wicket_px = (STUMP_WIDTH / 2) * frame_width / PITCH_WIDTH
    near_stump_line = np.abs(y_pred - stump_y) < hit_tolerance_px
    within_wicket = np.abs(x_pred - stump_x) < half_wicket_px
    stump_hit = bool(np.any(near_stump_line & within_wicket))

    lbw_decision = "OUT" if stump_hit else "NOT OUT"
    return list(zip(future_frames, x_pred, y_pred)), lbw_decision
def map_pitch(bounce_point, frame_width, frame_height):
    """Convert a pixel-space bounce point to pitch coordinates in metres.

    The x coordinate is scaled so the frame width spans ``PITCH_WIDTH``,
    centred on 0. The y coordinate is flipped and scaled so the bottom of the
    frame maps to 0 and the top to ``PITCH_LENGTH``.

    Args:
        bounce_point: ``(x, y)`` pixel position, or ``None``.
        frame_width: Frame width in pixels.
        frame_height: Frame height in pixels.

    Returns:
        ``(pitch_x, pitch_y)`` in metres. Returns ``(None, None)`` when there
        is no bounce point or the frame size is not positive.
    """
    # Return (None, None) rather than a message: callers unpack the result
    # straight into coordinates and test them against None.
    if bounce_point is None or frame_width <= 0 or frame_height <= 0:
        return None, None
    x, y = bounce_point
    pitch_x = (x / frame_width) * PITCH_WIDTH - PITCH_WIDTH / 2
    pitch_y = (1 - y / frame_height) * PITCH_LENGTH
    return pitch_x, pitch_y
def estimate_speed(positions, frame_numbers, frame_rate, frame_width):
    """Estimate the ball's average speed in km/h from consecutive detections.

    Each pair of consecutive detections gives one segment speed:

    * the pixel distance between them, converted to metres;
    * divided by the real elapsed time, ``(f2 - f1) / frame_rate``.

    The time uses the actual frame gap, so skipped frames are accounted for.
    Pairs from the same frame carry no timing information and are ignored.
    The result is the mean of the segment speeds.

    Args:
        positions: Sequence of ``(x, y)`` ball centres in pixels.
        frame_numbers: Frame index for each entry of *positions*.
        frame_rate: Video frame rate in frames per second.
        frame_width: Frame width in pixels.

    Returns:
        ``(speed_kmh, status_message)``. Returns ``(None, reason)`` when the
        speed cannot be estimated.
    """
    if len(positions) < 2:
        return None, "Insufficient detections for speed estimation"
    if not frame_rate > 0:  # also rejects NaN
        return None, "Unknown frame rate; cannot estimate speed"
    if frame_width <= 0:
        return None, "Unknown frame width; cannot estimate speed"

    # NOTE(review): assumes the frame width spans one pitch length -- this
    # calibration is a heuristic; confirm against the camera setup.
    pixel_to_meter = PITCH_LENGTH / frame_width

    speeds = []
    for (x1, y1), (x2, y2), f1, f2 in zip(positions, positions[1:],
                                          frame_numbers, frame_numbers[1:]):
        elapsed = (f2 - f1) / frame_rate
        if elapsed <= 0:
            continue
        distance_m = np.hypot(x2 - x1, y2 - y1) * pixel_to_meter
        speeds.append(distance_m / elapsed)

    if not speeds:
        return None, "Insufficient detections for speed estimation"
    avg_speed_kmh = float(np.mean(speeds)) * 3.6  # m/s -> km/h
    return avg_speed_kmh, "Speed calculated successfully"
def create_pitch_map(pitch_x, pitch_y):
    """Build a Plotly figure of the pitch (width x length, metres) with the bounce point.

    The figure contains:

    * a translucent green rectangle for the pitch;
    * a thin brown strip for the wicket at the ``PITCH_LENGTH`` end;
    * a red marker at ``(pitch_x, pitch_y)``, added only when both
      coordinates are not None.

    The axes are fixed to the pitch extent.
    """
    half_pitch = PITCH_WIDTH / 2
    half_wicket = STUMP_WIDTH / 2
    fig = go.Figure()

    # Pitch surface.
    fig.add_shape(
        type="rect",
        x0=-half_pitch, y0=0,
        x1=half_pitch, y1=PITCH_LENGTH,
        line={"color": "Green"},
        fillcolor="Green",
        opacity=0.3,
    )
    # Wicket, drawn as a 0.1 m strip at the far end of the pitch.
    fig.add_shape(
        type="rect",
        x0=-half_wicket, y0=PITCH_LENGTH - 0.1,
        x1=half_wicket, y1=PITCH_LENGTH,
        line={"color": "Brown"},
        fillcolor="Brown",
    )

    has_bounce = pitch_x is not None and pitch_y is not None
    if has_bounce:
        bounce_marker = go.Scatter(
            x=[pitch_x],
            y=[pitch_y],
            mode="markers",
            marker={"size": 10, "color": "Red"},
            name="Bounce Point",
        )
        fig.add_trace(bounce_marker)

    fig.update_layout(
        title="Pitch Map",
        xaxis_title="Width (m)",
        yaxis_title="Length (m)",
        xaxis_range=[-half_pitch, half_pitch],
        yaxis_range=[0, PITCH_LENGTH],
    )
    return fig
def _analyze_video(video_path):
    """Run detection, trajectory, pitch mapping and speed estimation on one file; return (traj_fig, pitch_fig, text)."""
    (positions, frame_numbers, bounce_point,
     frame_rate, frame_width, frame_height) = process_video(video_path)
    if not positions:
        return None, None, "No ball detected in video"

    trajectory, lbw_decision = predict_trajectory(positions, frame_numbers, frame_width, frame_height)
    if trajectory is None:
        return None, None, lbw_decision  # lbw_decision holds the failure reason here

    pitch_x, pitch_y = map_pitch(bounce_point, frame_width, frame_height)
    if pitch_x is None:
        # Guard against a (None, message) result: never plot a status string.
        pitch_y = None
    speed_kmh, speed_status = estimate_speed(positions, frame_numbers, frame_rate, frame_width)

    trajectory_df = pd.DataFrame(trajectory, columns=["Frame", "X", "Y"])
    fig_traj = px.line(trajectory_df, x="X", y="Y", title="Ball Trajectory (Pixel Coordinates)")
    fig_traj.update_yaxes(autorange="reversed")  # image y grows downwards
    fig_pitch = create_pitch_map(pitch_x, pitch_y)

    speed_text = f"{speed_kmh:.2f} km/h" if speed_kmh is not None else speed_status
    return fig_traj, fig_pitch, f"LBW Decision: {lbw_decision}\nSpeed: {speed_text}"


def drs_analysis(video):
    """Gradio callback: run the DRS analysis on an uploaded clip.

    Args:
        video: The value from the ``gr.Video`` input. This is normally a
            filepath string (or ``os.PathLike``). A file-like object exposing
            ``.read()`` is also accepted. ``None`` when nothing was uploaded.

    Returns:
        A 4-tuple matching the Blocks outputs:
        ``(trajectory_figure, pitch_figure, result_text, video_path)``.
        When the analysis fails, the figures and the video path are ``None``
        and ``result_text`` explains why.
    """
    import tempfile

    if video is None:
        return None, None, "No video provided", None

    if isinstance(video, (str, os.PathLike)):
        # Analyse the uploaded file in place and hand the same path back for
        # display; this file is not ours to delete.
        video_path = os.fspath(video)
        fig_traj, fig_pitch, message = _analyze_video(video_path)
        return fig_traj, fig_pitch, message, (video_path if fig_traj is not None else None)

    # File-like upload: spool it to a uniquely named temp file (a fixed name
    # would be clobbered by concurrent requests) and always remove it.
    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    try:
        with tmp:  # close before OpenCV opens it by name
            tmp.write(video.read())
        fig_traj, fig_pitch, message = _analyze_video(tmp.name)
    finally:
        os.remove(tmp.name)
    # The temp copy has been deleted, so there is no file left to display.
    return fig_traj, fig_pitch, message, None
# Gradio interface: one video input and an Analyze button. The four outputs
# receive, in order, the values returned by drs_analysis.
with gr.Blocks() as demo:
    gr.Markdown("## Cricket DRS Analysis")
    video_input = gr.Video(label="Upload Video Clip")
    btn = gr.Button("Analyze")
    trajectory_output = gr.Plot(label="Ball Trajectory")
    pitch_output = gr.Plot(label="Pitch Map")
    text_output = gr.Textbox(label="Analysis Results")
    video_output = gr.Video(label="Processed Video")
    btn.click(
        drs_analysis,
        inputs=video_input,
        outputs=[trajectory_output, pitch_output, text_output, video_output],
    )

if __name__ == "__main__":
    demo.launch()