AjaykumarPilla commited on
Commit
ba9faee
·
verified ·
1 Parent(s): 3320b0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -82
app.py CHANGED
@@ -1,87 +1,198 @@
1
- import gradio as gr
2
- import torch
3
  import cv2
4
  import numpy as np
5
- import matplotlib.pyplot as plt
6
- from pathlib import Path
7
-
8
- # Load the trained YOLOv5 model from a local path (directly using PyTorch)
9
- def load_model(model_path):
10
- # Load model using torch.load, ensuring no downloading from Hugging Face
11
- model = torch.load(model_path, map_location="cpu") # Load model from the local path
12
- model.eval() # Set the model to evaluation mode
13
- return model
14
-
15
- # Load the model directly from the local file
16
- model = load_model("best.pt") # Ensure 'best.pt' is in the correct path
17
-
18
- # Function to process the video and calculate ball trajectory, speed, and visualize the pitch
19
- def process_video(video_file):
20
- # Load video file using OpenCV
21
- video = cv2.VideoCapture(video_file.name)
22
- ball_positions = []
23
- speed_data = []
24
-
25
- frame_count = 0
26
- last_position = None
27
-
28
- while video.isOpened():
29
- ret, frame = video.read()
 
 
 
 
 
 
 
30
  if not ret:
31
  break
32
-
33
- frame_count += 1
34
-
35
- # Run the model on the frame to detect the ball
36
- results = model(frame) # Using the pre-trained YOLOv5 model
37
-
38
- # Extract the ball position (assuming class 0 = ball)
39
- ball_detections = results.pandas().xywh
40
- ball = ball_detections[ball_detections['class'] == 0] # Class 0 for ball (adjust if needed)
41
-
42
- if not ball.empty:
43
- ball_x = ball.iloc[0]['xmin'] + (ball.iloc[0]['xmax'] - ball.iloc[0]['xmin']) / 2
44
- ball_y = ball.iloc[0]['ymin'] + (ball.iloc[0]['ymax'] - ball.iloc[0]['ymin']) / 2
45
- ball_positions.append((frame_count, ball_x, ball_y)) # Track position in each frame
46
-
47
- if last_position is not None:
48
- # Calculate speed based on pixel displacement between frames
49
- distance = np.sqrt((ball_x - last_position[1]) ** 2 + (ball_y - last_position[2]) ** 2)
50
- fps = video.get(cv2.CAP_PROP_FPS) # Frames per second of the video
51
- speed = distance * fps # Speed = distance / time (time between frames is 1/fps)
52
- speed_data.append(speed)
53
-
54
- last_position = (frame_count, ball_x, ball_y) # Update last position
55
-
56
- video.release()
57
-
58
- # Ball trajectory plot
59
- plot_trajectory(ball_positions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- # Return results
62
- avg_speed = np.mean(speed_data) if speed_data else 0
63
- return f"Average Ball Speed: {avg_speed:.2f} pixels per second"
64
-
65
- # Function to plot ball trajectory using matplotlib
66
- def plot_trajectory(ball_positions):
67
- x_positions = [pos[1] for pos in ball_positions]
68
- y_positions = [pos[2] for pos in ball_positions]
69
-
70
- plt.figure(figsize=(10, 6))
71
- plt.plot(x_positions, y_positions, label="Ball Trajectory", color='b')
72
- plt.title("Ball Trajectory on Pitch")
73
- plt.xlabel("X Position (pitch width)")
74
- plt.ylabel("Y Position (pitch length)")
75
- plt.grid(True)
76
- plt.legend()
77
- plt.show()
78
-
79
- # Gradio interface for the app
80
- iface = gr.Interface(
81
- fn=process_video, # Function to call when video is uploaded
82
- inputs=gr.inputs.File(label="Upload a Video File"), # File input (video)
83
- outputs="text", # Output the result as text
84
- live=True # Keep the interface live
85
- )
86
-
87
- iface.launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import cv2
2
  import numpy as np
3
+ import pandas as pd
4
+ import plotly.express as px
5
+ import plotly.graph_objects as go
6
+ from ultralytics import YOLO
7
+ import gradio as gr
8
+ import os
9
+ from scipy.interpolate import interp1d
10
+ from scipy.optimize import curve_fit
11
+
12
+ # Load YOLOv5 model
13
+ model = YOLO("best.pt") # Path to your best.pt
14
+
15
+ # Cricket pitch dimensions (in meters)
16
+ PITCH_LENGTH = 20.12 # Length of cricket pitch (stumps to stumps)
17
+ PITCH_WIDTH = 3.05 # Width of pitch
18
+ STUMP_HEIGHT = 0.71 # Stump height
19
+ STUMP_WIDTH = 0.2286 # Stump width (including bails)
20
+
21
+ # Function to process video and detect ball
22
+ def process_video(video_path):
23
+ cap = cv2.VideoCapture(video_path)
24
+ frame_rate = cap.get(cv2.CAP_PROP_FPS)
25
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
26
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
27
+ positions = []
28
+ frame_numbers = []
29
+ bounce_frame = None
30
+ bounce_point = None
31
+
32
+ while cap.isOpened():
33
+ frame_num = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
34
+ ret, frame = cap.read()
35
  if not ret:
36
  break
37
+
38
+ # Run YOLOv5 detection
39
+ results = model(frame)
40
+ detections = results[0].boxes.xywh.cpu().numpy() # [x_center, y_center, width, height]
41
+
42
+ for det in detections:
43
+ x_center, y_center, _, _ = det
44
+ positions.append((x_center, y_center))
45
+ frame_numbers.append(frame_num)
46
+
47
+ # Detect bounce (lowest y_center point)
48
+ if bounce_frame is None or y_center > positions[bounce_frame][1]:
49
+ bounce_frame = len(frame_numbers) - 1
50
+ bounce_point = (x_center, y_center)
51
+
52
+ cap.release()
53
+ return positions, frame_numbers, bounce_point, frame_rate, frame_width, frame_height
54
+
55
+ # Polynomial function for trajectory fitting
56
+ def poly_func(x, a, b, c):
57
+ return a * x**2 + b * x + c
58
+
59
+ # Predict trajectory and LBW decision
60
+ def predict_trajectory(positions, frame_numbers, frame_width, frame_height):
61
+ if len(positions) < 3:
62
+ return None, "Insufficient detections for trajectory prediction"
63
+
64
+ x_coords = [p[0] for p in positions]
65
+ y_coords = [p[1] for p in positions]
66
+ frames = np.array(frame_numbers)
67
+
68
+ # Fit polynomial to x and y coordinates
69
+ try:
70
+ popt_x, _ = curve_fit(poly_func, frames, x_coords)
71
+ popt_y, _ = curve_fit(poly_func, frames, y_coords)
72
+ except:
73
+ return None, "Failed to fit trajectory"
74
+
75
+ # Extrapolate to stumps (assume stumps at y=frame_height)
76
+ frame_max = max(frames) + 10 # Predict 10 frames ahead
77
+ future_frames = np.linspace(min(frames), frame_max, 100)
78
+ x_pred = poly_func(future_frames, *popt_x)
79
+ y_pred = poly_func(future_frames, *popt_y)
80
+
81
+ # Check if trajectory hits stumps
82
+ stump_x = frame_width / 2 # Assume stumps at center of frame
83
+ stump_y = frame_height # Assume stumps at bottom of frame
84
+ stump_hit = False
85
+ for x, y in zip(x_pred, y_pred):
86
+ if abs(y - stump_y) < 50 and abs(x - stump_x) < STUMP_WIDTH * frame_width / PITCH_WIDTH:
87
+ stump_hit = True
88
+ break
89
+
90
+ lbw_decision = "OUT" if stump_hit else "NOT OUT"
91
+ return list(zip(future_frames, x_pred, y_pred)), lbw_decision
92
+
93
+ # Map pitch location
94
+ def map_pitch(bounce_point, frame_width, frame_height):
95
+ if bounce_point is None:
96
+ return None, "No bounce detected"
97
+
98
+ x, y = bounce_point
99
+ # Convert pixel coordinates to pitch coordinates
100
+ pitch_x = (x / frame_width) * PITCH_WIDTH - PITCH_WIDTH / 2 # Center at 0
101
+ pitch_y = (1 - y / frame_height) * PITCH_LENGTH # Bottom of frame = 0
102
+ return pitch_x, pitch_y
103
+
104
+ # Estimate ball speed
105
+ def estimate_speed(positions, frame_numbers, frame_rate, frame_width):
106
+ if len(positions) < 2:
107
+ return None, "Insufficient detections for speed estimation"
108
+
109
+ # Calculate distance in pixels between consecutive detections
110
+ distances = []
111
+ for i in range(1, len(positions)):
112
+ x1, y1 = positions[i-1]
113
+ x2, y2 = positions[i]
114
+ pixel_dist = np.sqrt((x2 - x1)**2 + (y2 - y1)**2)
115
+ distances.append(pixel_dist)
116
+
117
+ # Convert to meters (assume pitch length = frame height)
118
+ pixel_to_meter = PITCH_LENGTH / frame_width
119
+ distances_m = [d * pixel_to_meter for d in distances]
120
+
121
+ # Speed in m/s
122
+ time_interval = 1 / frame_rate
123
+ speeds = [d / time_interval for d in distances_m]
124
+ avg_speed_kmh = np.mean(speeds) * 3.6 # Convert m/s to km/h
125
+ return avg_speed_kmh, "Speed calculated successfully"
126
+
127
+ # Create pitch map visualization
128
+ def create_pitch_map(pitch_x, pitch_y):
129
+ fig = go.Figure()
130
+ # Draw pitch rectangle
131
+ fig.add_shape(
132
+ type="rect", x0=-PITCH_WIDTH/2, y0=0, x1=PITCH_WIDTH/2, y1=PITCH_LENGTH,
133
+ line=dict(color="Green"), fillcolor="Green", opacity=0.3
134
+ )
135
+ # Draw stumps
136
+ fig.add_shape(
137
+ type="rect", x0=-STUMP_WIDTH/2, y0=PITCH_LENGTH-0.1, x1=STUMP_WIDTH/2, y1=PITCH_LENGTH,
138
+ line=dict(color="Brown"), fillcolor="Brown"
139
+ )
140
+ # Plot bounce point
141
+ if pitch_x is not None and pitch_y is not None:
142
+ fig.add_trace(go.Scatter(x=[pitch_x], y=[pitch_y], mode="markers", marker=dict(size=10, color="Red"), name="Bounce Point"))
143
 
144
+ fig.update_layout(
145
+ title="Pitch Map", xaxis_title="Width (m)", yaxis_title="Length (m)",
146
+ xaxis_range=[-PITCH_WIDTH/2, PITCH_WIDTH/2], yaxis_range=[0, PITCH_LENGTH]
147
+ )
148
+ return fig
149
+
150
+ # Main Gradio function
151
+ def drs_analysis(video):
152
+ # Save uploaded video temporarily
153
+ video_path = "temp_video.mp4"
154
+ with open(video_path, "wb") as f:
155
+ f.write(video.read())
156
+
157
+ # Process video
158
+ positions, frame_numbers, bounce_point, frame_rate, frame_width, frame_height = process_video(video_path)
159
+ if not positions:
160
+ return None, None, "No ball detected in video", None
161
+
162
+ # Predict trajectory
163
+ trajectory, lbw_decision = predict_trajectory(positions, frame_numbers, frame_width, frame_height)
164
+ if trajectory is None:
165
+ return None, None, lbw_decision, None
166
+
167
+ # Map pitch
168
+ pitch_x, pitch_y = map_pitch(bounce_point, frame_width, frame_height)
169
+
170
+ # Estimate speed
171
+ speed_kmh, speed_status = estimate_speed(positions, frame_numbers, frame_rate, frame_width)
172
+
173
+ # Create trajectory plot
174
+ trajectory_df = pd.DataFrame(trajectory, columns=["Frame", "X", "Y"])
175
+ fig_traj = px.line(trajectory_df, x="X", y="Y", title="Ball Trajectory (Pixel Coordinates)")
176
+ fig_traj.update_yaxes(autorange="reversed") # Invert y-axis to match video frame
177
+
178
+ # Create pitch map
179
+ fig_pitch = create_pitch_map(pitch_x, pitch_y)
180
+
181
+ # Clean up
182
+ os.remove(video_path)
183
+
184
+ return fig_traj, fig_pitch, f"LBW Decision: {lbw_decision}\nSpeed: {speed_kmh:.2f} km/h", video_path
185
+
186
+ # Gradio interface
187
+ with gr.Blocks() as demo:
188
+ gr.Markdown("## Cricket DRS Analysis")
189
+ video_input = gr.Video(label="Upload Video Clip")
190
+ btn = gr.Button("Analyze")
191
+ trajectory_output = gr.Plot(label="Ball Trajectory")
192
+ pitch_output = gr.Plot(label="Pitch Map")
193
+ text_output = gr.Textbox(label="Analysis Results")
194
+ video_output = gr.Video(label="Processed Video")
195
+ btn.click(drs_analysis, inputs=video_input, outputs=[trajectory_output, pitch_output, text_output, video_output])
196
+
197
+ if __name__ == "__main__":
198
+ demo.launch()