# lbw_drs_app_new / lbw_detector.py
# Author: dschandra — commit 3bf3b3c ("Create lbw_detector.py"), 2.38 kB
from ultralytics import YOLO
import cv2
import numpy as np
import os
# Load the YOLO model (custom-trained, or pretrained with compatible classes).
# The weights path defaults to the stock "yolov8n.pt"; set the LBW_MODEL_PATH
# environment variable (e.g. to "lbw_yolov8.pt") to use a custom-trained
# model without editing this file.
# NOTE(review): stock yolov8n.pt weights are presumably trained on a generic
# dataset whose class ids 0-4 do not mean ball/bat/pad/stump/player — confirm
# the mapping below before trusting detections from the default weights.
MODEL_PATH = os.environ.get("LBW_MODEL_PATH", "yolov8n.pt")
model = YOLO(MODEL_PATH)

# Class id -> label for the detector's output. This must match the class
# mapping of the model loaded above; update it for your custom model.
CLASS_NAMES = {
    0: "ball",
    1: "bat",
    2: "pad",
    3: "stump",
    4: "player",
}
def _best_centers(result, class_names):
    """Map each class label in one detector result to its best box center.

    When several boxes share a label, the one with the highest confidence is
    kept. Ids missing from ``class_names`` are grouped under ``"unknown"``.

    Args:
        result: One per-image result exposing ``boxes.data`` (tensor-like,
            with ``.cpu().numpy()``). Each row is assumed to be
            ``x1, y1, x2, y2, [track_id,] conf, cls`` — TODO confirm against
            the detector in use.
        class_names: Mapping of integer class id to label.

    Returns:
        dict: label -> ``(center_x, center_y)`` as ints. Insertion order is
        the order in which each label first appears.
    """
    best = {}  # label -> (confidence, (cx, cy))
    # Convert the whole tensor once instead of copying each row separately.
    rows = result.boxes.data.cpu().numpy()
    for x1, y1, x2, y2, *_, conf, cls in rows:
        label = class_names.get(int(cls), "unknown")
        if label not in best or conf > best[label][0]:
            best[label] = (conf, (int((x1 + x2) / 2), int((y1 + y2) / 2)))
    return {label: center for label, (_, center) in best.items()}


def _impact_with(ball, frame_objects, proximity_px):
    """Return "pad" or "bat" if that object's center is near the ball, else None.

    "Near" means the center is strictly closer than ``proximity_px`` on both
    axes. The pad is checked before the bat, so "pad" wins when both qualify.
    """
    bx, by = ball
    for target in ("pad", "bat"):
        if target in frame_objects:
            tx, ty = frame_objects[target]
            if abs(bx - tx) < proximity_px and abs(by - ty) < proximity_px:
                return target
    return None


def detect_lbw_event(frames, *, proximity_px=30, detector=None, class_names=None):
    """
    Detect ball, bat, pad, stump and player in each frame and find the impact.

    Frames are fed to the detector one at a time. Processing stops at the
    first frame where the ball's center is within ``proximity_px`` (strictly,
    on both axes) of the pad's or the bat's center. Later frames are never
    run through the detector.

    Args:
        frames: Iterable of images, each passed unchanged to ``detector``.
        proximity_px: Per-axis pixel distance below which ball and pad/bat
            centers count as an impact. Defaults to 30.
        detector: Callable ``detector(frame)`` returning a sequence whose
            first element exposes ``boxes.data``. Defaults to the module-level
            ``model``.
        class_names: Mapping of class id to label. Defaults to the
            module-level ``CLASS_NAMES``.

    Returns:
        dict: {
            "ball_positions": list of ``(frame_idx, x, y)`` tuples, one per
                processed frame in which a ball was detected (best box),
                up to and including the impact frame,
            "impact_frame": index of the impact frame, or -1 if none,
            "impact_type": "pad" or "bat", or None if no impact,
            "objects_per_frame": one dict per processed frame mapping
                label -> ``(x, y)`` center of its highest-confidence box,
        }
    """
    if detector is None:
        detector = model
    if class_names is None:
        class_names = CLASS_NAMES

    ball_positions = []
    impact_frame = -1
    impact_type = None
    objects_per_frame = []

    for idx, frame in enumerate(frames):
        frame_objects = _best_centers(detector(frame)[0], class_names)
        objects_per_frame.append(frame_objects)

        ball = frame_objects.get("ball")
        if ball is None:
            continue
        # Exactly one position per frame keeps the trajectory consistent
        # with the center used for the impact test below.
        ball_positions.append((idx, *ball))

        hit = _impact_with(ball, frame_objects, proximity_px)
        if hit is not None:
            impact_frame = idx
            impact_type = hit
            break

    return {
        "ball_positions": ball_positions,
        "impact_frame": impact_frame,
        "impact_type": impact_type,
        "objects_per_frame": objects_per_frame,
    }