AjaykumarPilla commited on
Commit
7e29aa0
·
verified ·
1 Parent(s): b96ba5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -58
app.py CHANGED
@@ -6,156 +6,147 @@ import gradio as gr
6
  from scipy.interpolate import interp1d
7
  import uuid
8
  import os
 
9
 
10
- # Load the trained YOLOv8n model from the Space's root directory
11
- model = YOLO("best.pt") # Assumes best.pt is in the same directory as app.py
12
 
13
- # Constants for LBW decision and video processing
14
- STUMPS_WIDTH = 0.2286 # meters (width of stumps)
15
- BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
16
  FRAME_RATE = 30 # Input video frame rate
17
- SLOW_MOTION_FACTOR = 6 # For very slow motion (6x slower)
18
- CONF_THRESHOLD = 0.3 # Lowered confidence threshold for better detection
 
 
 
19
 
20
  def process_video(video_path):
21
- # Initialize video capture
22
  if not os.path.exists(video_path):
23
  return [], [], "Error: Video file not found"
24
  cap = cv2.VideoCapture(video_path)
25
  frames = []
26
  ball_positions = []
27
  debug_log = []
28
-
29
  frame_count = 0
 
 
30
  while cap.isOpened():
31
  ret, frame = cap.read()
32
  if not ret:
33
  break
34
  frame_count += 1
 
 
 
 
 
35
  frames.append(frame.copy()) # Store original frame
36
- # Detect ball using the trained YOLOv8n model
37
- results = model.predict(frame, conf=CONF_THRESHOLD)
38
  detections = 0
39
  for detection in results[0].boxes:
40
- if detection.cls == 0: # Assuming class 0 is the ball
41
  detections += 1
42
  x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
 
 
43
  ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
44
- # Draw bounding box on frame for visualization
45
  cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
46
- frames[-1] = frame # Update frame with bounding box
47
  debug_log.append(f"Frame {frame_count}: {detections} ball detections")
 
 
 
48
  cap.release()
49
 
 
50
  if not ball_positions:
51
  debug_log.append("No balls detected in any frame")
52
  else:
53
  debug_log.append(f"Total ball detections: {len(ball_positions)}")
54
-
55
  return frames, ball_positions, "\n".join(debug_log)
56
 
57
  def estimate_trajectory(ball_positions, frames):
58
- # Simplified physics-based trajectory projection
59
  if len(ball_positions) < 2:
60
  return None, None, "Error: Fewer than 2 ball detections for trajectory"
61
- # Extract x, y coordinates
62
  x_coords = [pos[0] for pos in ball_positions]
63
  y_coords = [pos[1] for pos in ball_positions]
64
  times = np.arange(len(ball_positions)) / FRAME_RATE
65
-
66
- # Interpolate to smooth trajectory
67
  try:
68
  fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
69
  fy = interp1d(times, y_coords, kind='quadratic', fill_value="extrapolate")
70
  except Exception as e:
71
  return None, None, f"Error in trajectory interpolation: {str(e)}"
72
-
73
- # Project trajectory forward (0.5 seconds post-impact)
74
  t_future = np.linspace(times[-1], times[-1] + 0.5, 10)
75
  x_future = fx(t_future)
76
  y_future = fy(t_future)
77
-
78
- return list(zip(x_future, y_future)), t_future, "Trajectory estimated successfully"
79
 
80
  def lbw_decision(ball_positions, trajectory, frames):
81
- # Simplified LBW logic
82
  if not frames:
83
  return "Error: No frames processed", None, None, None
84
  if not trajectory or len(ball_positions) < 2:
85
  return "Not enough data (insufficient ball detections)", None, None, None
86
-
87
- # Assume stumps are at the bottom center of the frame (calibration needed)
88
  frame_height, frame_width = frames[0].shape[:2]
89
  stumps_x = frame_width / 2
90
- stumps_y = frame_height * 0.9 # Approximate stumps position
91
- stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0) # Assume 3m pitch width
92
-
93
- # Store pitch and impact points
94
  pitch_point = ball_positions[0]
95
  impact_point = ball_positions[-1]
96
-
97
- # Check pitching point
98
  pitch_x, pitch_y = pitch_point
99
  if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
100
  return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", trajectory, pitch_point, impact_point
101
-
102
- # Check impact point
103
  impact_x, impact_y = impact_point
104
  if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
105
  return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
106
-
107
- # Check trajectory hitting stumps
108
  for x, y in trajectory:
109
  if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
110
  return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
 
111
  return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
112
 
113
  def generate_slow_motion(frames, trajectory, pitch_point, impact_point, output_path):
114
- # Generate very slow-motion video with ball detection, trajectory, and pitch/impact points
115
  if not frames:
116
  return None
117
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
118
  out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frames[0].shape[1], frames[0].shape[0]))
119
-
120
  for frame in frames:
121
- # Draw trajectory
122
  if trajectory:
123
  for x, y in trajectory:
124
- cv2.circle(frame, (int(x), int(y)), 5, (255, 0, 0), -1) # Blue dots for trajectory
125
-
126
- # Draw pitch point (red circle with label)
127
  if pitch_point:
128
  x, y = pitch_point
129
- cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1) # Red circle
130
  cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
131
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
132
-
133
- # Draw impact point (yellow circle with label)
134
  if impact_point:
135
  x, y = impact_point
136
- cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1) # Yellow circle
137
  cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
138
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
139
-
140
- for _ in range(SLOW_MOTION_FACTOR): # Duplicate frames for very slow motion
141
  out.write(frame)
142
  out.release()
143
- return output_path
 
144
 
145
  def drs_review(video):
146
- # Process video and generate DRS output
147
  frames, ball_positions, debug_log = process_video(video)
148
  if not frames:
149
  return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None
150
  trajectory, _, trajectory_log = estimate_trajectory(ball_positions, frames)
151
  decision, trajectory, pitch_point, impact_point = lbw_decision(ball_positions, trajectory, frames)
152
-
153
- # Generate slow-motion replay with enhanced annotations
154
  output_path = f"output_{uuid.uuid4()}.mp4"
155
- slow_motion_path = generate_slow_motion(frames, trajectory, pitch_point, impact_point, output_path)
156
-
157
- # Combine debug logs for output
158
- debug_output = f"{debug_log}\n{trajectory_log}"
159
  return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}", slow_motion_path
160
 
161
  # Gradio interface
@@ -164,10 +155,10 @@ iface = gr.Interface(
164
  inputs=gr.Video(label="Upload Video Clip"),
165
  outputs=[
166
  gr.Textbox(label="DRS Decision and Debug Log"),
167
- gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue), Pitch Point (Red), Impact Point (Yellow)")
168
  ],
169
  title="AI-Powered DRS for LBW in Local Cricket",
170
- description="Upload a video clip of a cricket delivery to get an LBW decision and very slow-motion replay showing ball detection (green boxes), trajectory (blue dots), pitch point (red circle), and impact point (yellow circle)."
171
  )
172
 
173
  if __name__ == "__main__":
 
6
  from scipy.interpolate import interp1d
7
  import uuid
8
  import os
9
+ import time
10
 
11
+ # Load the trained YOLOv8n model
12
+ model = YOLO("best.pt")
13
 
14
+ # Constants
15
+ STUMPS_WIDTH = 0.2286 # meters
16
+ BALL_DIAMETER = 0.073 # meters
17
  FRAME_RATE = 30 # Input video frame rate
18
+ SLOW_MOTION_FACTOR = 3 # Reduced from 6 for faster processing
19
+ CONF_THRESHOLD = 0.3
20
+ MAX_DETECTIONS = 10 # Stop after detecting enough ball positions
21
+ PROCESS_EVERY_N_FRAME = 2 # Process every 2nd frame
22
+ RESIZE_FACTOR = 0.5 # Downscale frames to 50% for faster processing
23
 
24
  def process_video(video_path):
25
+ start_time = time.time()
26
  if not os.path.exists(video_path):
27
  return [], [], "Error: Video file not found"
28
  cap = cv2.VideoCapture(video_path)
29
  frames = []
30
  ball_positions = []
31
  debug_log = []
 
32
  frame_count = 0
33
+ processed_frames = 0
34
+
35
  while cap.isOpened():
36
  ret, frame = cap.read()
37
  if not ret:
38
  break
39
  frame_count += 1
40
+ if frame_count % PROCESS_EVERY_N_FRAME != 0:
41
+ continue # Skip frames
42
+ processed_frames += 1
43
+ # Resize frame for faster processing
44
+ frame_small = cv2.resize(frame, (0, 0), fx=RESIZE_FACTOR, fy=RESIZE_FACTOR)
45
  frames.append(frame.copy()) # Store original frame
46
+ # Detect ball
47
+ results = model.predict(frame_small, conf=CONF_THRESHOLD)
48
  detections = 0
49
  for detection in results[0].boxes:
50
+ if detection.cls == 0: # Ball class
51
  detections += 1
52
  x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
53
+ # Scale coordinates back to original frame size
54
+ x1, y1, x2, y2 = [v / RESIZE_FACTOR for v in [x1, y1, x2, y2]]
55
  ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
56
+ # Draw bounding box on original frame
57
  cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
58
+ frames[-1] = frame
59
  debug_log.append(f"Frame {frame_count}: {detections} ball detections")
60
+ if len(ball_positions) >= MAX_DETECTIONS:
61
+ debug_log.append(f"Stopping early after {len(ball_positions)} detections")
62
+ break
63
  cap.release()
64
 
65
+ debug_log.append(f"Processed {processed_frames} frames in {time.time() - start_time:.2f} seconds")
66
  if not ball_positions:
67
  debug_log.append("No balls detected in any frame")
68
  else:
69
  debug_log.append(f"Total ball detections: {len(ball_positions)}")
 
70
  return frames, ball_positions, "\n".join(debug_log)
71
 
72
  def estimate_trajectory(ball_positions, frames):
73
+ start_time = time.time()
74
  if len(ball_positions) < 2:
75
  return None, None, "Error: Fewer than 2 ball detections for trajectory"
 
76
  x_coords = [pos[0] for pos in ball_positions]
77
  y_coords = [pos[1] for pos in ball_positions]
78
  times = np.arange(len(ball_positions)) / FRAME_RATE
 
 
79
  try:
80
  fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
81
  fy = interp1d(times, y_coords, kind='quadratic', fill_value="extrapolate")
82
  except Exception as e:
83
  return None, None, f"Error in trajectory interpolation: {str(e)}"
 
 
84
  t_future = np.linspace(times[-1], times[-1] + 0.5, 10)
85
  x_future = fx(t_future)
86
  y_future = fy(t_future)
87
+ debug_log = f"Trajectory estimated in {time.time() - start_time:.2f} seconds"
88
+ return list(zip(x_future, y_future)), t_future, debug_log
89
 
90
  def lbw_decision(ball_positions, trajectory, frames):
91
+ start_time = time.time()
92
  if not frames:
93
  return "Error: No frames processed", None, None, None
94
  if not trajectory or len(ball_positions) < 2:
95
  return "Not enough data (insufficient ball detections)", None, None, None
 
 
96
  frame_height, frame_width = frames[0].shape[:2]
97
  stumps_x = frame_width / 2
98
+ stumps_y = frame_height * 0.9
99
+ stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
 
 
100
  pitch_point = ball_positions[0]
101
  impact_point = ball_positions[-1]
 
 
102
  pitch_x, pitch_y = pitch_point
103
  if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
104
  return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", trajectory, pitch_point, impact_point
 
 
105
  impact_x, impact_y = impact_point
106
  if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
107
  return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
 
 
108
  for x, y in trajectory:
109
  if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
110
  return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
111
+ debug_log = f"LBW decision computed in {time.time() - start_time:.2f} seconds"
112
  return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
113
 
114
  def generate_slow_motion(frames, trajectory, pitch_point, impact_point, output_path):
115
+ start_time = time.time()
116
  if not frames:
117
  return None
118
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
119
  out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frames[0].shape[1], frames[0].shape[0]))
 
120
  for frame in frames:
 
121
  if trajectory:
122
  for x, y in trajectory:
123
+ cv2.circle(frame, (int(x), int(y)), 5, (255, 0, 0), -1)
 
 
124
  if pitch_point:
125
  x, y = pitch_point
126
+ cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
127
  cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
128
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
 
 
129
  if impact_point:
130
  x, y = impact_point
131
+ cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)
132
  cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
133
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
134
+ for _ in range(SLOW_MOTION_FACTOR):
 
135
  out.write(frame)
136
  out.release()
137
+ debug_log = f"Slow-motion video generated in {time.time() - start_time:.2f} seconds"
138
+ return output_path, debug_log
139
 
140
  def drs_review(video):
141
+ start_time = time.time()
142
  frames, ball_positions, debug_log = process_video(video)
143
  if not frames:
144
  return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None
145
  trajectory, _, trajectory_log = estimate_trajectory(ball_positions, frames)
146
  decision, trajectory, pitch_point, impact_point = lbw_decision(ball_positions, trajectory, frames)
 
 
147
  output_path = f"output_{uuid.uuid4()}.mp4"
148
+ slow_motion_path, slow_motion_log = generate_slow_motion(frames, trajectory, pitch_point, impact_point, output_path)
149
+ debug_output = f"{debug_log}\n{trajectory_log}\n{slow_motion_log}\nTotal processing time: {time.time() - start_time:.2f} seconds"
 
 
150
  return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}", slow_motion_path
151
 
152
  # Gradio interface
 
155
  inputs=gr.Video(label="Upload Video Clip"),
156
  outputs=[
157
  gr.Textbox(label="DRS Decision and Debug Log"),
158
+ gr.Video(label="Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue), Pitch Point (Red), Impact Point (Yellow)")
159
  ],
160
  title="AI-Powered DRS for LBW in Local Cricket",
161
+ description="Upload a 3-second video clip of a cricket delivery to get an LBW decision and slow-motion replay showing ball detection (green boxes), trajectory (blue dots), pitch point (red circle), and impact point (yellow circle)."
162
  )
163
 
164
  if __name__ == "__main__":