File size: 7,325 Bytes
41c03cf
 
ba9faee
 
 
6110fb8
ba9faee
 
 
 
 
6110fb8
 
 
ba9faee
 
 
 
 
 
 
6110fb8
 
 
 
 
ba9faee
 
 
 
 
 
 
 
 
 
 
 
 
 
41c03cf
 
ba9faee
6110fb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba9faee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6110fb8
 
ba9faee
 
 
 
 
6110fb8
 
ba9faee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6110fb8
 
ba9faee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6110fb8
ba9faee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41c03cf
ba9faee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6110fb8
ba9faee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import cv2
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
import gradio as gr
import os
from scipy.interpolate import interp1d
from scipy.optimize import curve_fit

# Load YOLOv5 model from yolov5 repository
from yolov5.models.experimental import attempt_load
from yolov5.utils.general import non_max_suppression, xywh2xyxy

# Real-world cricket dimensions, all in meters.  The analysis functions below
# use them to convert pixel measurements into pitch coordinates and speeds.
PITCH_LENGTH = 20.12  # Distance between the two sets of stumps (22 yards)
PITCH_WIDTH = 3.05    # Width of the pitch (10 feet)
STUMP_HEIGHT = 0.71   # Height of the stumps (28 inches); not referenced by the code visible in this file
STUMP_WIDTH = 0.2286  # Overall width of the wicket, i.e. all three stumps (9 inches)

# Load the detection model once, at import time, so every request reuses it.
# Inference runs on the GPU when CUDA is available and on the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# "best.pt" is a relative path, so it is resolved against the current working
# directory of the process.
# NOTE(review): newer YOLOv5 releases name this keyword `device` instead of
# `map_location` -- confirm against the pinned yolov5 version.
model = attempt_load("best.pt", map_location=device)
# Switch to inference mode (disables dropout / batch-norm statistic updates).
model.eval()

# Function to process video and detect ball
def process_video(video_path):
    """Run the detector on every frame of a video and track the ball.

    For each frame the single most confident detection is taken as the ball
    and the centre of its bounding box is recorded.  The bounce is estimated
    as the recorded centre with the largest y value, i.e. the point closest
    to the bottom of the image.

    Args:
        video_path: Path of the video file to analyse.

    Returns:
        A tuple ``(positions, frame_numbers, bounce_point, frame_rate,
        frame_width, frame_height)`` where ``positions`` is a list of
        ``(x, y)`` pixel centres, ``frame_numbers`` holds the frame index of
        each entry in ``positions``, ``bounce_point`` is the ``(x, y)`` centre
        chosen as the bounce (``None`` when nothing was detected) and the
        remaining values are the capture properties reported by OpenCV.
        All lists are empty when the video cannot be opened or no ball is
        detected.
    """
    cap = cv2.VideoCapture(video_path)
    positions = []
    frame_numbers = []
    bounce_point = None

    # The network merges feature maps taken at different strides, so both
    # frame dimensions must be multiples of the largest stride.  Fall back to
    # 32 if the loaded model does not expose its strides.
    stride = int(max(getattr(model, "stride", [32])))

    try:
        frame_rate = cap.get(cv2.CAP_PROP_FPS)
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        while cap.isOpened():
            # Read the index *before* grabbing, so it names the frame that
            # cap.read() is about to return.
            frame_num = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
            ret, frame = cap.read()
            if not ret:
                break

            # Preprocess frame for YOLOv5: BGR -> RGB, pad, scale to [0, 1].
            img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pad_bottom = -img.shape[0] % stride
            pad_right = -img.shape[1] % stride
            if pad_bottom or pad_right:
                # Padding only on the bottom/right keeps detection
                # coordinates in the original frame's pixel space.
                img = cv2.copyMakeBorder(
                    img, 0, pad_bottom, 0, pad_right,
                    cv2.BORDER_CONSTANT, value=(114, 114, 114),
                )
            img = torch.from_numpy(img).to(device).float() / 255.0
            img = img.permute(2, 0, 1).unsqueeze(0)  # [1, 3, H, W]

            # Run inference
            with torch.no_grad():
                pred = model(img)[0]
            pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)

            # One image per batch, so only the first entry is relevant.
            det = pred[0] if len(pred) else None
            if det is None or not len(det):
                continue

            # non_max_suppression already yields rows of
            # [x1, y1, x2, y2, conf, cls]; converting them again with
            # xywh2xyxy would corrupt the boxes.  Keep only the most
            # confident detection so each frame contributes one ball position.
            # NOTE(review): the class column is ignored -- assumes the model
            # detects a single "ball" class; confirm against the weights.
            best = det[det[:, 4].argmax()]
            x1, y1, x2, y2 = best[:4].tolist()
            x_center = (x1 + x2) / 2
            y_center = (y1 + y2) / 2
            positions.append((x_center, y_center))
            frame_numbers.append(frame_num)

            # Detect bounce: image y grows downwards, so the largest y is
            # the lowest point the ball reaches in the frame.
            if bounce_point is None or y_center > bounce_point[1]:
                bounce_point = (x_center, y_center)
    finally:
        # Always free the capture handle, even if inference raises.
        cap.release()

    return positions, frame_numbers, bounce_point, frame_rate, frame_width, frame_height

# Polynomial function for trajectory fitting
def poly_func(x, a, b, c):
    """Evaluate the quadratic ``a*x**2 + b*x + c`` at ``x``.

    Used as the model function when fitting the ball trajectory; ``x`` may be
    a scalar or anything supporting ``**``, ``*`` and ``+`` element-wise.
    """
    quadratic_term = a * x**2
    linear_term = b * x
    return quadratic_term + linear_term + c

# Predict trajectory and LBW decision
def predict_trajectory(positions, frame_numbers, frame_width, frame_height):
    """Fit the tracked ball path and decide whether it would hit the stumps.

    A quadratic in the frame number is fitted separately to the x and the y
    pixel coordinates.  The fit is then sampled at 100 evenly spaced points
    from the first detected frame to 10 frames past the last one, and the
    sampled path is tested against a stump region located at the bottom
    centre of the frame.

    Args:
        positions: List of ``(x, y)`` ball centres in pixels.
        frame_numbers: Frame index of each entry in ``positions``.
        frame_width: Frame width in pixels.
        frame_height: Frame height in pixels.

    Returns:
        ``(trajectory, decision)`` where ``trajectory`` is a list of
        ``(frame, x, y)`` samples and ``decision`` is ``"OUT"`` or
        ``"NOT OUT"``.  When fewer than three detections are available or the
        fit fails, ``trajectory`` is ``None`` and ``decision`` carries the
        error message instead.
    """
    if len(positions) < 3:
        return None, "Insufficient detections for trajectory prediction"

    frames = np.asarray(frame_numbers, dtype=float)
    x_coords = np.asarray([p[0] for p in positions], dtype=float)
    y_coords = np.asarray([p[1] for p in positions], dtype=float)

    # Fit polynomial to x and y coordinates.  Only the failures curve_fit
    # documents are treated as "fit failed"; anything else (including
    # KeyboardInterrupt) is allowed to propagate.
    try:
        popt_x, _ = curve_fit(poly_func, frames, x_coords)
        popt_y, _ = curve_fit(poly_func, frames, y_coords)
    except (RuntimeError, ValueError, TypeError):
        return None, "Failed to fit trajectory"

    # Extrapolate 10 frames beyond the last detection.
    future_frames = np.linspace(frames.min(), frames.max() + 10, 100)
    x_pred = poly_func(future_frames, *popt_x)
    y_pred = poly_func(future_frames, *popt_y)

    # The stumps are taken to sit at the bottom centre of the frame.
    stump_x = frame_width / 2
    stump_y = frame_height
    # stump_x is the *centre* of the wicket, so the horizontal tolerance is
    # half the wicket width, scaled from meters to pixels.
    half_stump_px = (STUMP_WIDTH / 2) * frame_width / PITCH_WIDTH
    vertical_tolerance_px = 50

    near_stump_line = np.abs(y_pred - stump_y) < vertical_tolerance_px
    within_stumps = np.abs(x_pred - stump_x) < half_stump_px
    stump_hit = bool(np.any(near_stump_line & within_stumps))

    lbw_decision = "OUT" if stump_hit else "NOT OUT"
    return list(zip(future_frames, x_pred, y_pred)), lbw_decision

# Map pitch location
def map_pitch(bounce_point, frame_width, frame_height):
    """Convert a bounce position from pixels to pitch coordinates in meters.

    The frame width is mapped onto the pitch width (centred on zero) and the
    frame height onto the pitch length, with the bottom of the frame at 0.

    Returns:
        ``(pitch_x, pitch_y)`` in meters, or ``(None, "No bounce detected")``
        when ``bounce_point`` is ``None``.
    """
    if bounce_point is not None:
        pixel_x, pixel_y = bounce_point
        across = (pixel_x / frame_width) * PITCH_WIDTH - PITCH_WIDTH / 2
        along = (1 - pixel_y / frame_height) * PITCH_LENGTH
        return across, along

    return None, "No bounce detected"

# Estimate ball speed
def estimate_speed(positions, frame_numbers, frame_rate, frame_width):
    """Estimate the average ball speed from consecutive detections.

    The speed of each segment between two successive detections is the pixel
    distance divided by the real time elapsed between their frames, so gaps
    where the ball was not detected for a few frames do not inflate the
    result.  Pixels are converted to meters by mapping the frame width onto
    the pitch length, and the per-segment speeds are averaged.

    Args:
        positions: List of ``(x, y)`` ball centres in pixels.
        frame_numbers: Frame index of each entry in ``positions``.
        frame_rate: Video frame rate in frames per second.
        frame_width: Frame width in pixels.

    Returns:
        ``(speed_kmh, status)``.  ``speed_kmh`` is ``None`` when the speed
        cannot be computed (fewer than two detections, no two detections in
        different frames, or a non-positive frame rate / frame width), in
        which case ``status`` explains why.
    """
    if len(positions) < 2:
        return None, "Insufficient detections for speed estimation"
    # OpenCV reports 0 for properties it cannot determine; avoid dividing by it.
    if not frame_rate or frame_rate <= 0 or not frame_width or frame_width <= 0:
        return None, "Invalid frame rate or frame width for speed estimation"

    pixel_speeds = []  # pixels per second, one entry per usable segment
    segments = zip(positions, positions[1:], frame_numbers, frame_numbers[1:])
    for (x1, y1), (x2, y2), frame_a, frame_b in segments:
        frame_gap = frame_b - frame_a
        if frame_gap <= 0:
            # Two detections in the same frame carry no timing information.
            continue
        pixel_dist = np.hypot(x2 - x1, y2 - y1)
        pixel_speeds.append(pixel_dist * frame_rate / frame_gap)

    if not pixel_speeds:
        return None, "Insufficient detections for speed estimation"

    pixel_to_meter = PITCH_LENGTH / frame_width
    # m/s -> km/h
    avg_speed_kmh = np.mean(pixel_speeds) * pixel_to_meter * 3.6
    return avg_speed_kmh, "Speed calculated successfully"

# Create pitch map visualization
def create_pitch_map(pitch_x, pitch_y):
    """Build a top-down Plotly figure of the pitch with the bounce marked.

    The pitch is drawn as a translucent green rectangle centred on x = 0 and
    the stumps as a brown rectangle at the far end.  A red marker is added
    for the bounce only when both coordinates are given.

    Returns:
        The ``plotly.graph_objects.Figure`` that was built.
    """
    half_pitch = PITCH_WIDTH / 2
    half_stump = STUMP_WIDTH / 2

    outlines = (
        # Pitch surface
        dict(
            type="rect", x0=-half_pitch, y0=0, x1=half_pitch, y1=PITCH_LENGTH,
            line=dict(color="Green"), fillcolor="Green", opacity=0.3,
        ),
        # Stumps, drawn as a thin strip at the far end of the pitch
        dict(
            type="rect", x0=-half_stump, y0=PITCH_LENGTH-0.1, x1=half_stump, y1=PITCH_LENGTH,
            line=dict(color="Brown"), fillcolor="Brown",
        ),
    )

    fig = go.Figure()
    for outline in outlines:
        fig.add_shape(**outline)

    bounce_known = not (pitch_x is None or pitch_y is None)
    if bounce_known:
        bounce_marker = go.Scatter(
            x=[pitch_x],
            y=[pitch_y],
            mode="markers",
            marker=dict(size=10, color="Red"),
            name="Bounce Point",
        )
        fig.add_trace(bounce_marker)

    fig.update_layout(
        title="Pitch Map",
        xaxis_title="Width (m)",
        yaxis_title="Length (m)",
        xaxis_range=[-half_pitch, half_pitch],
        yaxis_range=[0, PITCH_LENGTH],
    )
    return fig

# Main Gradio function
def drs_analysis(video):
    """Gradio callback: analyse an uploaded clip and build all outputs.

    Args:
        video: The uploaded clip.  Gradio's ``Video`` component supplies a
            file path; a file-like object exposing ``read()`` is also
            accepted and is spooled to a temporary file first.

    Returns:
        A 4-tuple ``(trajectory_figure, pitch_map_figure, summary_text,
        video_path)``.  The figures are ``None`` when the analysis cannot be
        completed, in which case ``summary_text`` holds the reason.  No
        annotated video is rendered; ``video_path`` is the uploaded clip
        itself, or ``None`` when there is no file left to show.
    """
    if video is None:
        return None, None, "No video provided", None

    temp_path = None
    if hasattr(video, "read"):
        # File-like input: write it to a uniquely named temp file so that
        # concurrent requests cannot overwrite each other's data.
        import tempfile

        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp.write(video.read())
        temp_path = tmp.name
        video_path = temp_path
    else:
        video_path = os.fspath(video)

    try:
        (positions, frame_numbers, bounce_point,
         frame_rate, frame_width, frame_height) = process_video(video_path)
    finally:
        # process_video is the only consumer of the file, so the temporary
        # copy can be removed as soon as it returns or raises.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)

    # Never hand the player a path that has just been deleted.
    playable_video = video_path if temp_path is None else None

    if not positions:
        return None, None, "No ball detected in video", None

    trajectory, lbw_decision = predict_trajectory(positions, frame_numbers, frame_width, frame_height)
    if trajectory is None:
        # On failure the second element carries the error message.
        return None, None, lbw_decision, None

    pitch_x, pitch_y = map_pitch(bounce_point, frame_width, frame_height)
    speed_kmh, speed_status = estimate_speed(positions, frame_numbers, frame_rate, frame_width)

    trajectory_df = pd.DataFrame(trajectory, columns=["Frame", "X", "Y"])
    fig_traj = px.line(trajectory_df, x="X", y="Y", title="Ball Trajectory (Pixel Coordinates)")
    # Image y grows downwards; flip the axis so the plot matches the video.
    fig_traj.update_yaxes(autorange="reversed")

    fig_pitch = create_pitch_map(pitch_x, pitch_y)

    # Fall back to the status message when no speed could be computed.
    speed_text = f"{speed_kmh:.2f} km/h" if speed_kmh is not None else speed_status
    summary = f"LBW Decision: {lbw_decision}\nSpeed: {speed_text}"

    return fig_traj, fig_pitch, summary, playable_video

# Gradio interface
# Components are rendered in the order they are created inside the Blocks
# context: heading, upload widget, button, then the four result widgets.
with gr.Blocks() as demo:
    gr.Markdown("## Cricket DRS Analysis")
    video_input = gr.Video(label="Upload Video Clip")
    btn = gr.Button("Analyze")
    trajectory_output = gr.Plot(label="Ball Trajectory")
    pitch_output = gr.Plot(label="Pitch Map")
    text_output = gr.Textbox(label="Analysis Results")
    video_output = gr.Video(label="Processed Video")
    # The outputs list must stay in the same order as the 4-tuple returned by
    # drs_analysis: (trajectory figure, pitch figure, text, video path).
    btn.click(drs_analysis, inputs=video_input, outputs=[trajectory_output, pitch_output, text_output, video_output])

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()