files_upload
Browse files- action_model.py +161 -0
- best_model.pt +3 -0
- label_map.json +7 -0
action_model.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import json
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torchvision import models, transforms
|
9 |
+
from datetime import datetime
|
10 |
+
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
|
11 |
+
|
12 |
+
try:
|
13 |
+
from ucf101_config import UCF101_CLASSES
|
14 |
+
UCF101_AVAILABLE = True
|
15 |
+
except ImportError:
|
16 |
+
UCF101_AVAILABLE = False
|
17 |
+
print("[WARNING] UCF101 config not found, using default configuration")
|
18 |
+
|
19 |
+
class CNN_GRU(nn.Module):
|
20 |
+
def __init__(self, cnn_model='mobilenetv2', hidden_size=128, num_layers=1,
|
21 |
+
num_classes=5, dropout=0.5, FREEZE_BACKBONE=True):
|
22 |
+
super(CNN_GRU, self).__init__()
|
23 |
+
|
24 |
+
if cnn_model == 'mobilenetv2':
|
25 |
+
cnn = models.mobilenet_v2(pretrained=True)
|
26 |
+
self.cnn_out_features = cnn.last_channel
|
27 |
+
self.cnn = cnn.features
|
28 |
+
elif cnn_model == 'efficientnet_b0':
|
29 |
+
import timm
|
30 |
+
cnn = timm.create_model('efficientnet_b0', pretrained=True)
|
31 |
+
self.cnn_out_features = cnn.classifier.in_features
|
32 |
+
cnn.classifier = nn.Identity()
|
33 |
+
self.cnn = cnn
|
34 |
+
else:
|
35 |
+
raise ValueError("Invalid CNN model")
|
36 |
+
|
37 |
+
if FREEZE_BACKBONE:
|
38 |
+
for p in self.cnn.parameters():
|
39 |
+
p.requires_grad = False
|
40 |
+
|
41 |
+
self.gru = nn.GRU(self.cnn_out_features,
|
42 |
+
hidden_size,
|
43 |
+
num_layers=num_layers,
|
44 |
+
batch_first=True, dropout=dropout if num_layers > 1 else 0)
|
45 |
+
|
46 |
+
self.dropout = nn.Dropout(dropout)
|
47 |
+
self.fc = nn.Linear(hidden_size, num_classes)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
b, t, c, h, w = x.size()
|
51 |
+
x = x.view(b * t, c, h, w)
|
52 |
+
feats = self.cnn(x)
|
53 |
+
feats = F.adaptive_avg_pool2d(feats, 1).view(b, t, -1)
|
54 |
+
out, _ = self.gru(feats)
|
55 |
+
out = self.dropout(out[:, -1])
|
56 |
+
return self.fc(out)
|
57 |
+
|
58 |
+
|
59 |
+
def get_transform(resize=(112, 112), augment=False):
|
60 |
+
transforms_list = [
|
61 |
+
transforms.ToPILImage(),
|
62 |
+
transforms.Resize(resize),
|
63 |
+
]
|
64 |
+
|
65 |
+
if augment:
|
66 |
+
transforms_list.extend([
|
67 |
+
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
|
68 |
+
transforms.RandomHorizontalFlip(p=0.5),
|
69 |
+
])
|
70 |
+
transforms_list.extend([
|
71 |
+
transforms.ToTensor(),
|
72 |
+
transforms.Normalize([0.485, 0.456, 0.406],
|
73 |
+
[0.229, 0.224, 0.225]),
|
74 |
+
])
|
75 |
+
return transforms.Compose(transforms_list)
|
76 |
+
|
77 |
+
|
78 |
+
def preprocess_frames(frames, seq_len=16, resize=(112, 112), augment=False):
|
79 |
+
transform = get_transform(resize=resize, augment=augment)
|
80 |
+
rgb_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames]
|
81 |
+
total_frames = len(rgb_frames)
|
82 |
+
|
83 |
+
if total_frames >= seq_len:
|
84 |
+
indices = np.linspace(0, total_frames - 1, seq_len, dtype=int)
|
85 |
+
else:
|
86 |
+
indices = np.pad(np.arange(total_frames), (0, seq_len - total_frames), mode='wrap')
|
87 |
+
|
88 |
+
sampled_frames = [rgb_frames[i] for i in indices]
|
89 |
+
transformed_frames = [transform(frame) for frame in sampled_frames]
|
90 |
+
frames_tensor = torch.stack(transformed_frames) # [T, C, H, W]
|
91 |
+
return frames_tensor
|
92 |
+
|
93 |
+
|
94 |
+
def load_action_model(model_path="best_model.pt", device='cpu',
|
95 |
+
num_classes=5, hidden_size=128):
|
96 |
+
if not os.path.exists(model_path):
|
97 |
+
print(f"[ERROR] Model file not found: {model_path}")
|
98 |
+
return None
|
99 |
+
model = CNN_GRU(num_classes=num_classes, hidden_size=hidden_size)
|
100 |
+
checkpoint = torch.load(model_path, map_location=device)
|
101 |
+
model.load_state_dict(checkpoint)
|
102 |
+
model.to(device)
|
103 |
+
model.eval()
|
104 |
+
print(f"[INFO] Loaded model from {model_path} on {device}")
|
105 |
+
return model
|
106 |
+
|
107 |
+
|
108 |
+
def predict_action(model, frames_tensor, label_map_path="label_map.json", device="cpu", top_k=3):
|
109 |
+
if model is None:
|
110 |
+
return {"action": "Model not loaded", "confidence": 0.0, "top_predictions": []}
|
111 |
+
|
112 |
+
idx_to_class = {}
|
113 |
+
if os.path.exists(label_map_path):
|
114 |
+
try:
|
115 |
+
with open(label_map_path, 'r') as f:
|
116 |
+
label_map = json.load(f)
|
117 |
+
idx_to_class = {v: k for k, v in label_map.items()}
|
118 |
+
print(f"[INFO] Loaded {len(idx_to_class)} classes from {label_map_path}")
|
119 |
+
except Exception as e:
|
120 |
+
print(f"[WARNING] Could not load label map: {e}")
|
121 |
+
|
122 |
+
if not idx_to_class and UCF101_AVAILABLE:
|
123 |
+
idx_to_class = {i: class_name for i, class_name in enumerate(UCF101_CLASSES)}
|
124 |
+
print("[INFO] Using default UCF101 class mapping")
|
125 |
+
elif not idx_to_class:
|
126 |
+
idx_to_class = {0: 'CricketShot', 1: 'PlayingCello', 2: 'Punch',
|
127 |
+
3: 'ShavingBeard', 4: 'TennisSwing'}
|
128 |
+
print("[WARNING] Using minimal default labels.")
|
129 |
+
|
130 |
+
try:
|
131 |
+
with torch.no_grad():
|
132 |
+
frames_tensor = frames_tensor.unsqueeze(0).to(device) # [1, T, C, H, W]
|
133 |
+
output = model(frames_tensor)
|
134 |
+
probabilities = torch.softmax(output, dim=1)
|
135 |
+
top_k_probs, top_k_indices = torch.topk(probabilities, min(top_k, probabilities.size(1)))
|
136 |
+
|
137 |
+
predicted_idx = top_k_indices[0][0].item()
|
138 |
+
predicted_class = idx_to_class.get(predicted_idx, f"Class_{predicted_idx}")
|
139 |
+
confidence = top_k_probs[0][0].item()
|
140 |
+
|
141 |
+
top_predictions = [
|
142 |
+
{"class": idx_to_class.get(idx.item(), f"Class_{idx.item()}"),
|
143 |
+
"confidence": prob.item()}
|
144 |
+
for prob, idx in zip(top_k_probs[0], top_k_indices[0])
|
145 |
+
]
|
146 |
+
|
147 |
+
return {
|
148 |
+
"action": predicted_class,
|
149 |
+
"confidence": confidence,
|
150 |
+
"top_predictions": top_predictions
|
151 |
+
}
|
152 |
+
except Exception as e:
|
153 |
+
print(f"[ERROR] Prediction failed: {e}")
|
154 |
+
return {"action": "Error", "confidence": 0.0, "top_predictions": []}
|
155 |
+
|
156 |
+
|
157 |
+
def log_action_prediction(action_label, confidence, log_file="logs/action_log.txt"):
|
158 |
+
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
159 |
+
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
|
160 |
+
with open(log_file, 'a', encoding='utf-8') as f:
|
161 |
+
f.write(f"[{timestamp}] ACTION: {action_label} (confidence: {confidence:.2f})\n")
|
best_model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:59c240bbc50fb6027bd31a0e0470a95450ab95c82d07cb9b13ae1b87da32821c
|
3 |
+
size 11307065
|
label_map.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"CricketShot": 0,
|
3 |
+
"PlayingCello": 1,
|
4 |
+
"Punch": 2,
|
5 |
+
"ShavingBeard": 3,
|
6 |
+
"TennisSwing": 4
|
7 |
+
}
|