AjaykumarPilla committed
Commit 7e44cd3 · verified · 1 Parent(s): 0a9dc78

Update app.py

Files changed (1)
  1. app.py +43 -51
app.py CHANGED
@@ -15,15 +15,17 @@ STUMPS_WIDTH = 0.2286 # meters (width of stumps)
 BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
 FRAME_RATE = 30 # Input video frame rate
 SLOW_MOTION_FACTOR = 6 # For very slow motion (6x slower)
-CONF_THRESHOLD = 0.25 # Lowered confidence threshold to improve detection
+CONF_THRESHOLD = 0.25 # Confidence threshold for detection
 IMPACT_ZONE_Y = 0.85 # Fraction of frame height where impact is likely (near stumps)
+IMPACT_DELTA_Y = 50 # Pixels for detecting sudden y-position change (impact)
 
 def process_video(video_path):
     if not os.path.exists(video_path):
-        return [], [], "Error: Video file not found"
+        return [], [], [], "Error: Video file not found"
     cap = cv2.VideoCapture(video_path)
     frames = []
     ball_positions = []
+    detection_frames = [] # Track frames with detections
     debug_log = []
 
     frame_count = 0
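A quick arithmetic check on these constants: generate_slow_motion (later in this diff) both lowers the writer fps to FRAME_RATE / SLOW_MOTION_FACTOR and writes each input frame SLOW_MOTION_FACTOR times, so as written the two effects compound. A back-of-the-envelope sketch in plain Python:

    # Illustration only: how FRAME_RATE and SLOW_MOTION_FACTOR interact in the replay.
    FRAME_RATE = 30         # input frames per second
    SLOW_MOTION_FACTOR = 6  # writer fps divisor, and copies written per input frame

    output_fps = FRAME_RATE / SLOW_MOTION_FACTOR                      # 5.0 fps for cv2.VideoWriter
    output_frames_per_input_second = FRAME_RATE * SLOW_MOTION_FACTOR  # 180 duplicated frames

    # One second of input becomes 180 frames played at 5 fps = 36 seconds of replay,
    # i.e. an effective slowdown of 36x rather than 6x.
    print(output_frames_per_input_second / output_fps)  # 36.0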
@@ -40,6 +42,7 @@ def process_video(video_path):
             detections += 1
             x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
             ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
+            detection_frames.append(frame_count - 1) # Store frame index (0-based)
             cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
             frames[-1] = frame
         debug_log.append(f"Frame {frame_count}: {detections} ball detections")
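A minimal, dependency-free sketch of the bookkeeping this hunk adds: keep detections above CONF_THRESHOLD, store the box centre in ball_positions, and record the 0-based frame index in detection_frames. The synthetic detections below are made up for illustration; the real app reads them from the YOLO results object.

    CONF_THRESHOLD = 0.25

    # Synthetic per-frame detections: (confidence, (x1, y1, x2, y2)) tuples.
    frames_detections = [
        [(0.10, (5, 5, 15, 15))],                            # frame 1: below threshold, ignored
        [(0.80, (100, 40, 110, 50))],                        # frame 2: kept
        [],                                                  # frame 3: no detection
        [(0.60, (140, 80, 150, 90)), (0.20, (0, 0, 5, 5))],  # frame 4: first box kept
    ]

    ball_positions, detection_frames = [], []
    frame_count = 0
    for dets in frames_detections:
        frame_count += 1
        for conf, (x1, y1, x2, y2) in dets:
            if conf > CONF_THRESHOLD:
                ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
                detection_frames.append(frame_count - 1)  # 0-based frame index, as in the hunk

    print(ball_positions)    # [[105.0, 45.0], [145.0, 85.0]]
    print(detection_frames)  # [1, 3]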
@@ -50,48 +53,34 @@ def process_video(video_path):
     else:
         debug_log.append(f"Total ball detections: {len(ball_positions)}")
 
-    # Interpolate missing detections
-    if ball_positions:
-        ball_positions = interpolate_missing_positions(ball_positions, frame_count)
-
-    return frames, ball_positions, "\n".join(debug_log)
+    return frames, ball_positions, detection_frames, "\n".join(debug_log)
 
-def interpolate_missing_positions(ball_positions, total_frames):
-    if len(ball_positions) < 2:
-        return ball_positions
-    times = np.linspace(0, total_frames / FRAME_RATE, total_frames)
-    detected_times = [i / FRAME_RATE for i, _ in enumerate(ball_positions)]
-    x_coords = [pos[0] for pos in ball_positions]
-    y_coords = [pos[1] for pos in ball_positions]
-
-    try:
-        fx = interp1d(detected_times, x_coords, kind='linear', fill_value="extrapolate")
-        fy = interp1d(detected_times, y_coords, kind='linear', fill_value="extrapolate")
-        interpolated_positions = [(fx(t), fy(t)) for t in times if t <= detected_times[-1]]
-        return interpolated_positions
-    except:
-        return ball_positions
-
-def estimate_trajectory(ball_positions, frames):
+def estimate_trajectory(ball_positions, frames, detection_frames):
     if len(ball_positions) < 2:
-        return None, None, None, "Error: Fewer than 2 ball detections for trajectory"
+        return None, None, None, None, None, "Error: Fewer than 2 ball detections for trajectory"
     frame_height = frames[0].shape[0]
 
     # Extract x, y coordinates
     x_coords = [pos[0] for pos in ball_positions]
     y_coords = [pos[1] for pos in ball_positions]
-    times = np.arange(len(ball_positions)) / FRAME_RATE
+    times = np.array(detection_frames) / FRAME_RATE
+
+    # Identify pitch point (first detection)
+    pitch_point = ball_positions[0]
+    pitch_frame = detection_frames[0]
 
-    # Find impact point (closest to batsman, near stumps)
+    # Find impact point (sudden change in y or near stumps)
     impact_idx = None
-    for i, y in enumerate(y_coords):
-        if y > frame_height * IMPACT_ZONE_Y: # Ball is near stumps/batsman
+    impact_frame = None
+    for i in range(1, len(y_coords)):
+        if y_coords[i] > frame_height * IMPACT_ZONE_Y or abs(y_coords[i] - y_coords[i-1]) > IMPACT_DELTA_Y:
             impact_idx = i
+            impact_frame = detection_frames[i]
             break
     if impact_idx is None:
-        impact_idx = len(ball_positions) - 1 # Fallback to last detection
+        impact_idx = len(ball_positions) - 1
+        impact_frame = detection_frames[-1]
 
-    pitch_point = ball_positions[0]
     impact_point = ball_positions[impact_idx]
 
     # Use positions up to impact for interpolation
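The impact heuristic introduced above picks the first detection whose y-coordinate either enters the bottom IMPACT_ZONE_Y band of the frame or jumps by more than IMPACT_DELTA_Y pixels from the previous detection, falling back to the last detection. A small runnable illustration with made-up values:

    FRAME_HEIGHT = 720
    IMPACT_ZONE_Y = 0.85
    IMPACT_DELTA_Y = 50

    def find_impact_idx(y_coords):
        # Mirrors the loop in estimate_trajectory: skip index 0, stop at the first hit.
        for i in range(1, len(y_coords)):
            near_stumps = y_coords[i] > FRAME_HEIGHT * IMPACT_ZONE_Y        # below 612 px here
            sudden_jump = abs(y_coords[i] - y_coords[i - 1]) > IMPACT_DELTA_Y
            if near_stumps or sudden_jump:
                return i
        return len(y_coords) - 1  # fallback: last detection

    y_coords = [300, 330, 365, 430, 470]  # smooth descent, then a 65 px jump at index 3
    print(find_impact_idx(y_coords))      # 3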
@@ -103,15 +92,16 @@ def estimate_trajectory(ball_positions, frames):
         fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
         fy = interp1d(times, y_coords, kind='quadratic', fill_value="extrapolate")
     except Exception as e:
-        return None, None, None, f"Error in trajectory interpolation: {str(e)}"
+        return None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
 
-    # Project trajectory (detected + future)
+    # Project trajectory (detected + future for LBW decision)
     t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 10)
     x_full = fx(t_full)
     y_full = fy(t_full)
     trajectory = list(zip(x_full, y_full))
 
-    return trajectory, pitch_point, impact_point, "Trajectory estimated successfully"
+    debug_log = f"Trajectory estimated successfully\nPitch point at frame {pitch_frame + 1}: ({pitch_point[0]:.1f}, {pitch_point[1]:.1f})\nImpact point at frame {impact_frame + 1}: ({impact_point[0]:.1f}, {impact_point[1]:.1f})"
+    return trajectory, pitch_point, impact_point, pitch_frame, impact_frame, debug_log
 
 def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
     if not frames:
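The projection step relies on scipy.interpolate.interp1d with fill_value="extrapolate", which lets the fitted curves be evaluated 0.5 s beyond the last detection (the quadratic fit needs at least three samples). A self-contained sketch with illustrative detection frames and positions:

    import numpy as np
    from scipy.interpolate import interp1d

    FRAME_RATE = 30
    detection_frames = [0, 3, 6, 9]                   # illustrative frame indices
    times = np.array(detection_frames) / FRAME_RATE   # 0.0, 0.1, 0.2, 0.3 s
    x_coords = [100.0, 110.0, 121.0, 133.0]
    y_coords = [300.0, 340.0, 395.0, 465.0]

    fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
    fy = interp1d(times, y_coords, kind='quadratic', fill_value="extrapolate")

    # Same construction as the hunk: detected span plus 0.5 s of projection.
    t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 10)
    trajectory = list(zip(fx(t_full), fy(t_full)))
    print(trajectory[-1])  # position extrapolated 0.5 s past the last detection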
@@ -141,31 +131,33 @@ def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
         return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
     return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
 
-def generate_slow_motion(frames, trajectory, pitch_point, impact_point, output_path):
+def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path):
     if not frames:
         return None
     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
     out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frames[0].shape[1], frames[0].shape[0]))
 
-    for frame in frames:
-        # Draw full trajectory (blue dots)
-        if trajectory:
-            for x, y in trajectory:
-                cv2.circle(frame, (int(x), int(y)), 5, (255, 0, 0), -1)
+    # Extract trajectory points up to impact for visualization
+    trajectory_points = np.array(trajectory[:len(detection_frames)], dtype=np.int32).reshape((-1, 1, 2))
+
+    for i, frame in enumerate(frames):
+        # Draw trajectory (blue line) only for frames with detections
+        if i in detection_frames and trajectory_points.size > 0:
+            cv2.polylines(frame, [trajectory_points[:detection_frames.index(i) + 1]], False, (255, 0, 0), 2)
 
-        # Draw pitch point (red circle with label)
-        if pitch_point:
+        # Draw pitch point (red circle with label) only in pitch frame
+        if pitch_point and i == pitch_frame:
            x, y = pitch_point
            cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
            cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
-                       cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
+                       cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
 
-        # Draw impact point (yellow circle with label)
-        if impact_point:
+        # Draw impact point (yellow circle with label) only in impact frame
+        if impact_point and i == impact_frame:
            x, y = impact_point
            cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)
            cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
-                       cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
+                       cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
 
         for _ in range(SLOW_MOTION_FACTOR):
             out.write(frame)
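cv2.polylines expects an int32 point array of shape (N, 1, 2), which is why the hunk reshapes the trajectory before drawing, and the slow motion itself comes from writing each annotated frame SLOW_MOTION_FACTOR times. A minimal overlay sketch on a synthetic frame, with made-up points and no video I/O:

    import numpy as np
    import cv2

    frame = np.zeros((720, 1280, 3), dtype=np.uint8)  # stand-in for a decoded video frame
    trajectory = [(640, 200), (650, 300), (660, 420), (665, 500)]

    # Same formatting as generate_slow_motion: int32, shape (N, 1, 2).
    pts = np.array(trajectory, dtype=np.int32).reshape((-1, 1, 2))
    cv2.polylines(frame, [pts], False, (255, 0, 0), 2)  # blue path (BGR)

    # Pitch and impact markers, drawn the same way as in the hunk.
    cv2.circle(frame, trajectory[0], 8, (0, 0, 255), -1)     # red pitch point
    cv2.circle(frame, trajectory[-1], 8, (0, 255, 255), -1)  # yellow impact point
    cv2.putText(frame, "Impact Point", (trajectory[-1][0] + 10, trajectory[-1][1] + 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
    cv2.imwrite("overlay_preview.png", frame)  # inspect the overlay without writing a video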
@@ -173,14 +165,14 @@ def generate_slow_motion(frames, trajectory, pitch_point, impact_point, output_p
     return output_path
 
 def drs_review(video):
-    frames, ball_positions, debug_log = process_video(video)
+    frames, ball_positions, detection_frames, debug_log = process_video(video)
     if not frames:
         return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None
-    trajectory, pitch_point, impact_point, trajectory_log = estimate_trajectory(ball_positions, frames)
+    trajectory, pitch_point, impact_point, pitch_frame, impact_frame, trajectory_log = estimate_trajectory(ball_positions, frames, detection_frames)
     decision, trajectory, pitch_point, impact_point = lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point)
 
     output_path = f"output_{uuid.uuid4()}.mp4"
-    slow_motion_path = generate_slow_motion(frames, trajectory, pitch_point, impact_point, output_path)
+    slow_motion_path = generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path)
 
     debug_output = f"{debug_log}\n{trajectory_log}"
     return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}", slow_motion_path
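Since drs_review now unpacks six values from estimate_trajectory, every return path in that function has to produce six items, including the error paths. A tiny illustration of the failure path with a stub (estimate_trajectory_stub is hypothetical, not part of app.py):

    def estimate_trajectory_stub(ball_positions):
        # Error path mirrors the updated function: five Nones plus a log string,
        # so the caller's six-way unpacking below never raises ValueError.
        if len(ball_positions) < 2:
            return None, None, None, None, None, "Error: Fewer than 2 ball detections for trajectory"
        raise NotImplementedError("success path omitted in this sketch")

    trajectory, pitch_point, impact_point, pitch_frame, impact_frame, log = \
        estimate_trajectory_stub([])
    print(log)  # Error: Fewer than 2 ball detections for trajectory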
@@ -191,10 +183,10 @@ iface = gr.Interface(
     inputs=gr.Video(label="Upload Video Clip"),
     outputs=[
         gr.Textbox(label="DRS Decision and Debug Log"),
-        gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue), Pitch Point (Red), Impact Point (Yellow)")
+        gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow)")
     ],
     title="AI-Powered DRS for LBW in Local Cricket",
-    description="Upload a video clip of a cricket delivery to get an LBW decision and very slow-motion replay showing ball detection (green boxes), trajectory (blue dots), pitch point (red circle), and impact point (yellow circle)."
+    description="Upload a video clip of a cricket delivery to get an LBW decision and very slow-motion replay showing ball detection (green boxes), trajectory (blue line), pitch point (red circle), and impact point (yellow circle)."
 )
 
 if __name__ == "__main__":
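For reference, the surrounding Gradio wiring follows the stock gr.Interface pattern: a function returning (decision text, replay path) mapped onto a Video input and [Textbox, Video] outputs. A minimal self-contained sketch (dummy_review is a placeholder, not the app's drs_review):

    import gradio as gr

    def dummy_review(video):
        # Placeholder: echo the uploaded clip back with a fixed decision string.
        return "DRS Decision: Not Out (demo stub)", video

    iface = gr.Interface(
        fn=dummy_review,
        inputs=gr.Video(label="Upload Video Clip"),
        outputs=[
            gr.Textbox(label="DRS Decision and Debug Log"),
            gr.Video(label="Replay"),
        ],
        title="AI-Powered DRS for LBW in Local Cricket",
    )

    if __name__ == "__main__":
        iface.launch()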
 