dschandra commited on
Commit
9ffd92a
·
verified ·
1 Parent(s): 4292326

Delete lbw_detector.py

Browse files
Files changed (1) hide show
  1. lbw_detector.py +0 -127
lbw_detector.py DELETED
@@ -1,127 +0,0 @@
1
- # lbw_detector.py
2
- import torch
3
- import torch.nn as nn
4
- import numpy as np
5
- from torchvision import transforms
6
- import cv2
7
- from utils import extract_frames
8
- from trajectory_predictor import predict_trajectory
9
- from visualizer import draw_visuals
10
-
11
- # -----------------------------
12
- # Model Definition (UNet-lite)
13
- # -----------------------------
14
- class DoubleConv(nn.Module):
15
- def __init__(self, in_channels, out_channels):
16
- super(DoubleConv, self).__init__()
17
- self.double_conv = nn.Sequential(
18
- nn.Conv2d(in_channels, out_channels, 3, padding=1), nn.ReLU(inplace=True),
19
- nn.Conv2d(out_channels, out_channels, 3, padding=1), nn.ReLU(inplace=True)
20
- )
21
-
22
- def forward(self, x):
23
- return self.double_conv(x)
24
-
25
- class UNet(nn.Module):
26
- def __init__(self, in_channels=3, out_channels=1): # ✅ Match model file
27
- super(UNet, self).__init__()
28
- self.down1 = DoubleConv(in_channels, 64)
29
- self.down2 = DoubleConv(64, 128)
30
- self.middle = DoubleConv(128, 256)
31
-
32
- self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
33
- self.upconv2 = DoubleConv(256, 128)
34
-
35
- self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
36
- self.upconv1 = DoubleConv(128, 64)
37
-
38
- self.final = nn.Conv2d(64, out_channels, 1)
39
-
40
- def forward(self, x):
41
- d1 = self.down1(x)
42
- d2 = self.down2(nn.MaxPool2d(2)(d1))
43
- m = self.middle(nn.MaxPool2d(2)(d2))
44
-
45
- u2 = self.up2(m)
46
- u2 = self.upconv2(torch.cat([u2, d2], dim=1))
47
-
48
- u1 = self.up1(u2)
49
- u1 = self.upconv1(torch.cat([u1, d1], dim=1))
50
-
51
- out = self.final(u1)
52
- return out
53
-
54
- # -----------------------------
55
- # Load Model
56
- # -----------------------------
57
- model_path = "models/lbw_drs_unet_model.pth"
58
- device = "cpu"
59
-
60
- model = UNet(in_channels=3, out_channels=1) # ✅ Must match trained model
61
- state_dict = torch.load(model_path, map_location=device)
62
- model.load_state_dict(state_dict)
63
- model.to(device)
64
- model.eval()
65
-
66
- # -----------------------------
67
- # Frame Preprocessing
68
- # -----------------------------
69
- transform = transforms.Compose([
70
- transforms.ToTensor(),
71
- ])
72
-
73
- def detect_objects_with_model(frame):
74
- """Run segmentation on a frame using the custom model"""
75
- input_tensor = transform(frame).unsqueeze(0).to(device)
76
- with torch.no_grad():
77
- output = model(input_tensor)
78
- mask = torch.sigmoid(output).squeeze().cpu().numpy()
79
- return mask # shape: (H, W)
80
-
81
- # -----------------------------
82
- # Main Analysis Function
83
- # -----------------------------
84
- def analyze_video(video_path):
85
- frames = extract_frames(video_path)
86
- ball_positions = []
87
- impact_frame_idx = None
88
- impact_zone = "unknown"
89
-
90
- for i, frame in enumerate(frames):
91
- mask = detect_objects_with_model(frame)
92
-
93
- ball_mask = mask > 0.5
94
- pad_mask = None # Not used with single-channel model
95
-
96
- # Ball center detection
97
- contours, _ = cv2.findContours(ball_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
98
- if contours:
99
- largest = max(contours, key=cv2.contourArea)
100
- M = cv2.moments(largest)
101
- if M['m00'] != 0:
102
- cx = int(M['m10']/M['m00'])
103
- cy = int(M['m01']/M['m00'])
104
- ball_positions.append((i, cx, cy))
105
-
106
- # Optional: skip impact detection for now (no pad_mask)
107
- # If you later upgrade to multi-class model, enable below
108
- # if pad_mask is not None and contours:
109
- # overlap = np.logical_and(ball_mask, pad_mask).sum()
110
- # if overlap > 10:
111
- # impact_frame_idx = i
112
- # impact_zone = "pad"
113
- # break
114
-
115
- trajectory = predict_trajectory(ball_positions)
116
- decision = "OUT" if trajectory_hits_stumps(trajectory) else "NOT OUT"
117
- result_path = draw_visuals(frames, ball_positions, trajectory, impact_frame_idx, decision)
118
- return result_path, decision
119
-
120
- # -----------------------------
121
- # Basic Rule: Hits Stumps
122
- # -----------------------------
123
- def trajectory_hits_stumps(trajectory):
124
- for (x, y) in trajectory:
125
- if 300 < x < 340 and y < 480:
126
- return True
127
- return False