AjaykumarPilla commited on
Commit
58199b3
·
verified ·
1 Parent(s): c9d4715

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -102
app.py CHANGED
@@ -10,13 +10,13 @@ import os
10
  # Load the trained YOLOv8n model
11
  model = YOLO("best.pt")
12
 
13
- # Constants
14
- STUMPS_WIDTH = 0.2286 # meters
15
- BALL_DIAMETER = 0.073 # meters
16
  FRAME_RATE = 30 # Input video frame rate
17
- SLOW_MOTION_FACTOR = 6
18
- CONF_THRESHOLD = 0.3
19
- RESIZE_DIM = 640 # Resize frames for faster processing
20
 
21
  def process_video(video_path):
22
  if not os.path.exists(video_path):
@@ -25,30 +25,21 @@ def process_video(video_path):
25
  frames = []
26
  ball_positions = []
27
  debug_log = []
28
- frame_count = 0
29
- max_frames = FRAME_RATE * 3 # Limit to 3 seconds of frames
30
 
31
- while cap.isOpened() and frame_count < max_frames:
 
32
  ret, frame = cap.read()
33
  if not ret:
34
  break
35
  frame_count += 1
36
- # Resize frame for faster YOLO inference
37
- frame_resized = cv2.resize(frame, (RESIZE_DIM, RESIZE_DIM), interpolation=cv2.INTER_AREA)
38
- frames.append(frame.copy()) # Store original frame
39
- # Detect ball
40
- results = model.predict(frame_resized, conf=CONF_THRESHOLD, imgsz=RESIZE_DIM)
41
  detections = 0
42
- scale_x, scale_y = frame.shape[1] / RESIZE_DIM, frame.shape[0] / RESIZE_DIM
43
  for detection in results[0].boxes:
44
- if detection.cls == 0: # Ball class
45
  detections += 1
46
  x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
47
- # Scale coordinates back to original frame size
48
- x1, x2 = x1 * scale_x, x2 * scale_x
49
- y1, y2 = y1 * scale_y, y2 * scale_y
50
- ball_center = [(x1 + x2) / 2, (y1 + y2) / 2]
51
- ball_positions.append(ball_center)
52
  cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
53
  frames[-1] = frame
54
  debug_log.append(f"Frame {frame_count}: {detections} ball detections")
@@ -63,105 +54,66 @@ def process_video(video_path):
63
 
64
  def estimate_trajectory(ball_positions, frames):
65
  if len(ball_positions) < 2:
66
- return [], [], "Error: Fewer than 2 ball detections for trajectory"
 
 
 
67
  x_coords = [pos[0] for pos in ball_positions]
68
  y_coords = [pos[1] for pos in ball_positions]
69
  times = np.arange(len(ball_positions)) / FRAME_RATE
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  try:
72
  fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
73
  fy = interp1d(times, y_coords, kind='quadratic', fill_value="extrapolate")
74
  except Exception as e:
75
- return [], [], f"Error in trajectory interpolation: {str(e)}"
76
-
77
- # Interpolate for all frames and future projection
78
- t_all = np.linspace(0, times[-1] + 0.5, len(frames) + 10)
79
- x_all = fx(t_all)
80
- y_all = fy(t_all)
81
- trajectory = list(zip(x_all, y_all))
82
- return trajectory, t_all, "Trajectory estimated successfully"
83
-
84
- def detect_impact_point(ball_positions, frames):
85
- if len(ball_positions) < 3:
86
- return ball_positions[-1] if ball_positions else None, len(ball_positions) - 1
87
- # Assume batsman is near stumps (bottom center of frame)
88
- frame_height, frame_width = frames[0].shape[:2]
89
- batsman_x = frame_width / 2
90
- batsman_y = frame_height * 0.8 # Approximate batsman position
91
- min_dist = float('inf')
92
- impact_idx = len(ball_positions) - 1
93
- impact_point = ball_positions[-1]
94
-
95
- # Look for sudden change in trajectory or proximity to batsman
96
- for i in range(1, len(ball_positions) - 1):
97
- x, y = ball_positions[i]
98
- prev_x, prev_y = ball_positions[i-1]
99
- next_x, next_y = ball_positions[i+1]
100
- # Check direction change (simplified)
101
- dx1, dy1 = x - prev_x, y - prev_y
102
- dx2, dy2 = next_x - x, next_y - y
103
- angle_change = abs(np.arctan2(dy2, dx2) - np.arctan2(dy1, dx1))
104
- dist_to_batsman = np.sqrt((x - batsman_x)**2 + (y - batsman_y)**2)
105
- if angle_change > np.pi/4 or dist_to_batsman < frame_width * 0.1:
106
- impact_idx = i
107
- impact_point = ball_positions[i]
108
- break
109
-
110
- return impact_point, impact_idx
111
 
112
- def lbw_decision(ball_positions, trajectory, frames):
113
- if not frames:
114
- return "Error: No frames processed", None, None, None
115
- if len(ball_positions) < 2:
116
- return "Not enough data (insufficient ball detections)", None, None, None
117
-
118
- frame_height, frame_width = frames[0].shape[:2]
119
- stumps_x = frame_width / 2
120
- stumps_y = frame_height * 0.9
121
- stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
122
-
123
- pitch_point = ball_positions[0]
124
- impact_point, impact_idx = detect_impact_point(ball_positions, frames)
125
-
126
- # Check pitching point
127
- pitch_x, pitch_y = pitch_point
128
- if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x Moderation: x > stumps_x + stumps_width_pixels / 2:
129
- return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", trajectory, pitch_point, impact_point
130
-
131
- # Check impact point
132
- impact_x, impact_y = impact_point
133
- if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
134
- return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
135
-
136
- # Check trajectory hitting stumps
137
- for x, y in trajectory:
138
- if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
139
- return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
140
  return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
141
 
142
- def generate_slow_motion(frames, trajectory, pitch_point, impact_point, impact_idx, output_path):
143
  if not frames:
144
  return None
145
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
146
  out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frames[0].shape[1], frames[0].shape[0]))
147
 
148
- for i, frame in enumerate(frames):
149
- # Draw trajectory up to current frame
150
- traj_points = [p for j, p in enumerate(trajectory) if j / SLOW_MOTION_FACTOR <= i]
151
- for x, y in traj_points:
152
- cv2.circle(frame, (int(x), int(y)), 5, (255, 0, 0), -1) # Blue dots
153
 
154
- # Draw pitch point in early frames
155
- if pitch_point and i < len(frames) // 2:
156
  x, y = pitch_point
157
- cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1) # Red circle
158
  cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
159
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
160
 
161
- # Draw impact point around impact frame
162
- if impact_point and abs(i - impact_idx) < 5:
163
  x, y = impact_point
164
- cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1) # Yellow circle
165
  cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
166
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
167
 
@@ -174,12 +126,11 @@ def drs_review(video):
174
  frames, ball_positions, debug_log = process_video(video)
175
  if not frames:
176
  return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None
177
- trajectory, _, trajectory_log = estimate_trajectory(ball_positions, frames)
178
- decision, trajectory, pitch_point, impact_point = lbw_decision(ball_positions, trajectory, frames)
179
- _, impact_idx = detect_impact_point(ball_positions, frames)
180
 
181
  output_path = f"output_{uuid.uuid4()}.mp4"
182
- slow_motion_path = generate_slow_motion(frames, trajectory, pitch_point, impact_point, impact_idx, output_path)
183
 
184
  debug_output = f"{debug_log}\n{trajectory_log}"
185
  return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}", slow_motion_path
@@ -193,6 +144,7 @@ iface = gr.Interface(
193
  gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue), Pitch Point (Red), Impact Point (Yellow)")
194
  ],
195
  title="AI-Powered DRS for LBW in Local Cricket",
 
196
  )
197
 
198
  if __name__ == "__main__":
 
10
  # Load the trained YOLOv8n model
11
  model = YOLO("best.pt")
12
 
13
+ # Constants for LBW decision and video processing
14
+ STUMPS_WIDTH = 0.2286 # meters (width of stumps)
15
+ BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
16
  FRAME_RATE = 30 # Input video frame rate
17
+ SLOW_MOTION_FACTOR = 6 # For very slow motion (6x slower)
18
+ CONF_THRESHOLD = 0.3 # Confidence threshold for detection
19
+ IMPACT_ZONE_Y = 0.85 # Fraction of frame height where impact is likely (near stumps)
20
 
21
  def process_video(video_path):
22
  if not os.path.exists(video_path):
 
25
  frames = []
26
  ball_positions = []
27
  debug_log = []
 
 
28
 
29
+ frame_count = 0
30
+ while cap.isOpened():
31
  ret, frame = cap.read()
32
  if not ret:
33
  break
34
  frame_count += 1
35
+ frames.append(frame.copy())
36
+ results = model.predict(frame, conf=CONF_THRESHOLD)
 
 
 
37
  detections = 0
 
38
  for detection in results[0].boxes:
39
+ if detection.cls == 0: # Assuming class 0 is the ball
40
  detections += 1
41
  x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
42
+ ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
 
 
 
 
43
  cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
44
  frames[-1] = frame
45
  debug_log.append(f"Frame {frame_count}: {detections} ball detections")
 
54
 
55
  def estimate_trajectory(ball_positions, frames):
56
  if len(ball_positions) < 2:
57
+ return None, None, None, "Error: Fewer than 2 ball detections for trajectory"
58
+ frame_height = frames[0].shape[0]
59
+
60
+ # Extract x, y coordinates
61
  x_coords = [pos[0] for pos in ball_positions]
62
  y_coords = [pos[1] for pos in ball_positions]
63
  times = np.arange(len(ball_positions)) / FRAME_RATE
64
 
65
+ # Find impact point (closest to batsman, near stumps)
66
+ impact_idx = None
67
+ for i, y in enumerate(y_coords):
68
+ if y > frame_height * IMPACT_ZONE_Y: # Ball is near stumps/batsman
69
+ impact_idx = i
70
+ break
71
+ if impact_idx is None:
72
+ impact_idx = len(ball_positions) - 1 # Fallback to last detection
73
+
74
+ impact_point = ball_positions[impact_idx]
75
+
76
+ # Use positions up to impact for interpolation
77
+ x_coords = x_coords[:impact_idx + 1]
78
+ y_coords = y_coords[:impact_idx + 1]
79
+ times = times[:impact_idx + 1]
80
+
81
  try:
82
  fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
83
  fy = interp1d(times, y_coords, kind='quadratic', fill_value="extrapolate")
84
  except Exception as e:
85
+ return None, None, None, f"Error in trajectory interpolation: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
+ # Project trajectory (detected + future)
88
+ t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 10)
89
+ x_full = fx(t_full)
90
+ y_full = fy(t_full)
91
+ trajectory = listagus = f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
93
 
94
+ def generate_slow_motion(frames, trajectory, pitch_point, impact_point, output_path):
95
  if not frames:
96
  return None
97
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
98
  out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frames[0].shape[1], frames[0].shape[0]))
99
 
100
+ for frame in frames:
101
+ # Draw full trajectory (blue dots)
102
+ if trajectory:
103
+ for x, y in trajectory:
104
+ cv2.circle(frame, (int(x), int(y)), 5, (255, 0, 0), -1)
105
 
106
+ # Draw pitch point (red circle with label)
107
+ if pitch_point:
108
  x, y = pitch_point
109
+ cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
110
  cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
111
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
112
 
113
+ # Draw impact point (yellow circle with label)
114
+ if impact_point:
115
  x, y = impact_point
116
+ cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)
117
  cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
118
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
119
 
 
126
  frames, ball_positions, debug_log = process_video(video)
127
  if not frames:
128
  return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None
129
+ trajectory, pitch_point, impact_point, trajectory_log = estimate_trajectory(ball_positions, frames)
130
+ decision, trajectory, pitch_point, impact_point = lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point)
 
131
 
132
  output_path = f"output_{uuid.uuid4()}.mp4"
133
+ slow_motion_path = generate_slow_motion(frames, trajectory, pitch_point, impact_point, output_path)
134
 
135
  debug_output = f"{debug_log}\n{trajectory_log}"
136
  return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}", slow_motion_path
 
144
  gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue), Pitch Point (Red), Impact Point (Yellow)")
145
  ],
146
  title="AI-Powered DRS for LBW in Local Cricket",
147
+ description="Upload a video clip of a cricket delivery to get an LBW decision and very slow-motion replay showing ball detection (green boxes), trajectory (blue dots), pitch point (red circle), and impact point (yellow circle)."
148
  )
149
 
150
  if __name__ == "__main__":