Spaces:
Sleeping
Sleeping
Update gully_drs_core/ball_detection.py
Browse files- gully_drs_core/ball_detection.py +22 -37
gully_drs_core/ball_detection.py
CHANGED
|
@@ -5,53 +5,37 @@ import numpy as np
|
|
| 5 |
from .model_utils import load_model
|
| 6 |
|
| 7 |
def find_bounce_point(path):
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
"""
|
| 11 |
-
for i in range(1, len(path) - 1):
|
| 12 |
-
if path[i - 1][1] > path[i][1] < path[i + 1][1]: # y decreases then increases
|
| 13 |
return path[i]
|
| 14 |
return None
|
| 15 |
|
| 16 |
def estimate_speed(ball_path, fps, px_to_m=0.01):
|
| 17 |
-
"""
|
| 18 |
-
Estimate speed in km/h based on pixel distance and frame rate.
|
| 19 |
-
Assumes 1 pixel ≈ 1cm (adjust px_to_m for better accuracy).
|
| 20 |
-
"""
|
| 21 |
if len(ball_path) < 2:
|
| 22 |
return 0.0
|
| 23 |
-
|
| 24 |
p1 = ball_path[0]
|
| 25 |
-
p2 = ball_path[min(5, len(ball_path)
|
| 26 |
-
|
| 27 |
-
dx = p2[0] - p1[0]
|
| 28 |
-
dy = p2[1] - p1[1]
|
| 29 |
dist_px = (dx**2 + dy**2)**0.5
|
| 30 |
dist_m = dist_px * px_to_m
|
| 31 |
-
time_s = (min(5, len(ball_path)
|
| 32 |
-
|
| 33 |
speed_kmh = (dist_m / time_s) * 3.6 if time_s > 0 else 0
|
| 34 |
return round(speed_kmh, 1)
|
| 35 |
|
| 36 |
def analyze_video(file_path):
|
| 37 |
-
"""
|
| 38 |
-
Main processing function:
|
| 39 |
-
- Detects the ball using YOLOv8
|
| 40 |
-
- Builds trajectory from valid frames
|
| 41 |
-
- Detects bounce, impact, stump zone intersection
|
| 42 |
-
- Returns decision + video frame overlays
|
| 43 |
-
"""
|
| 44 |
model = load_model()
|
| 45 |
cap = cv2.VideoCapture(file_path)
|
| 46 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
| 47 |
-
width
|
| 48 |
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 49 |
|
| 50 |
-
frames = []
|
| 51 |
ball_path = []
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
-
max_jump = 100 # pixels
|
| 54 |
last_point = None
|
|
|
|
| 55 |
|
| 56 |
while True:
|
| 57 |
ret, frame = cap.read()
|
|
@@ -62,20 +46,22 @@ def analyze_video(file_path):
|
|
| 62 |
valid_detection = None
|
| 63 |
|
| 64 |
for r in results:
|
| 65 |
-
|
|
|
|
| 66 |
if len(ball_detections) == 1:
|
| 67 |
box = ball_detections[0]
|
| 68 |
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
| 69 |
-
cx = (x1 + x2) // 2
|
| 70 |
-
cy = (y1 + y2) // 2
|
| 71 |
|
| 72 |
-
#
|
| 73 |
if last_point:
|
| 74 |
-
dx = cx - last_point[0]
|
| 75 |
-
dy = cy - last_point[1]
|
| 76 |
jump = (dx**2 + dy**2)**0.5
|
| 77 |
if jump > max_jump:
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
valid_detection = (cx, cy)
|
| 81 |
last_point = valid_detection
|
|
@@ -85,15 +71,13 @@ def analyze_video(file_path):
|
|
| 85 |
ball_path.append(valid_detection)
|
| 86 |
|
| 87 |
frames.append(frame)
|
|
|
|
| 88 |
|
| 89 |
cap.release()
|
| 90 |
|
| 91 |
-
# Calculate analysis outputs
|
| 92 |
bounce_point = find_bounce_point(ball_path)
|
| 93 |
impact_point = ball_path[-1] if ball_path else None
|
| 94 |
-
speed_kmh = estimate_speed(ball_path, fps)
|
| 95 |
|
| 96 |
-
# Define stump zone area
|
| 97 |
stump_zone = (
|
| 98 |
width // 2 - 30,
|
| 99 |
height - 100,
|
|
@@ -101,13 +85,14 @@ def analyze_video(file_path):
|
|
| 101 |
height
|
| 102 |
)
|
| 103 |
|
| 104 |
-
# LBW decision: does ball impact land in stump zone?
|
| 105 |
decision = "OUT" if (
|
| 106 |
impact_point and
|
| 107 |
stump_zone[0] <= impact_point[0] <= stump_zone[2] and
|
| 108 |
stump_zone[1] <= impact_point[1] <= stump_zone[3]
|
| 109 |
) else "NOT OUT"
|
| 110 |
|
|
|
|
|
|
|
| 111 |
return {
|
| 112 |
"trajectory": ball_path,
|
| 113 |
"fps": fps,
|
|
|
|
| 5 |
from .model_utils import load_model
|
| 6 |
|
| 7 |
def find_bounce_point(path):
|
| 8 |
+
for i in range(1, len(path)-1):
|
| 9 |
+
if path[i-1][1] > path[i][1] < path[i+1][1]: # y dips = bounce
|
|
|
|
|
|
|
|
|
|
| 10 |
return path[i]
|
| 11 |
return None
|
| 12 |
|
| 13 |
def estimate_speed(ball_path, fps, px_to_m=0.01):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
if len(ball_path) < 2:
|
| 15 |
return 0.0
|
|
|
|
| 16 |
p1 = ball_path[0]
|
| 17 |
+
p2 = ball_path[min(5, len(ball_path)-1)]
|
| 18 |
+
dx, dy = p2[0] - p1[0], p2[1] - p1[1]
|
|
|
|
|
|
|
| 19 |
dist_px = (dx**2 + dy**2)**0.5
|
| 20 |
dist_m = dist_px * px_to_m
|
| 21 |
+
time_s = (min(5, len(ball_path)-1)) / fps
|
|
|
|
| 22 |
speed_kmh = (dist_m / time_s) * 3.6 if time_s > 0 else 0
|
| 23 |
return round(speed_kmh, 1)
|
| 24 |
|
| 25 |
def analyze_video(file_path):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
model = load_model()
|
| 27 |
cap = cv2.VideoCapture(file_path)
|
| 28 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
| 29 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 30 |
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 31 |
|
|
|
|
| 32 |
ball_path = []
|
| 33 |
+
frames = []
|
| 34 |
+
|
| 35 |
+
max_jump = 100 # max allowed jump (pixels) between consecutive ball detections
|
| 36 |
|
|
|
|
| 37 |
last_point = None
|
| 38 |
+
frame_idx = 0
|
| 39 |
|
| 40 |
while True:
|
| 41 |
ret, frame = cap.read()
|
|
|
|
| 46 |
valid_detection = None
|
| 47 |
|
| 48 |
for r in results:
|
| 49 |
+
# Accept only if exactly one detection of cricket ball class (e.g., class 0)
|
| 50 |
+
ball_detections = [box for box in r.boxes if int(box.cls[0]) == 0]
|
| 51 |
if len(ball_detections) == 1:
|
| 52 |
box = ball_detections[0]
|
| 53 |
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
| 54 |
+
cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
|
|
|
|
| 55 |
|
| 56 |
+
# Check jump threshold from last point
|
| 57 |
if last_point:
|
| 58 |
+
dx, dy = cx - last_point[0], cy - last_point[1]
|
|
|
|
| 59 |
jump = (dx**2 + dy**2)**0.5
|
| 60 |
if jump > max_jump:
|
| 61 |
+
# Reject outlier
|
| 62 |
+
frames.append(frame)
|
| 63 |
+
frame_idx += 1
|
| 64 |
+
continue
|
| 65 |
|
| 66 |
valid_detection = (cx, cy)
|
| 67 |
last_point = valid_detection
|
|
|
|
| 71 |
ball_path.append(valid_detection)
|
| 72 |
|
| 73 |
frames.append(frame)
|
| 74 |
+
frame_idx += 1
|
| 75 |
|
| 76 |
cap.release()
|
| 77 |
|
|
|
|
| 78 |
bounce_point = find_bounce_point(ball_path)
|
| 79 |
impact_point = ball_path[-1] if ball_path else None
|
|
|
|
| 80 |
|
|
|
|
| 81 |
stump_zone = (
|
| 82 |
width // 2 - 30,
|
| 83 |
height - 100,
|
|
|
|
| 85 |
height
|
| 86 |
)
|
| 87 |
|
|
|
|
| 88 |
decision = "OUT" if (
|
| 89 |
impact_point and
|
| 90 |
stump_zone[0] <= impact_point[0] <= stump_zone[2] and
|
| 91 |
stump_zone[1] <= impact_point[1] <= stump_zone[3]
|
| 92 |
) else "NOT OUT"
|
| 93 |
|
| 94 |
+
speed_kmh = estimate_speed(ball_path, fps)
|
| 95 |
+
|
| 96 |
return {
|
| 97 |
"trajectory": ball_path,
|
| 98 |
"fps": fps,
|