NanG01 committed on
Commit
26275ae
·
verified ·
1 Parent(s): 7a76894

files_upload

Browse files
Files changed (3) hide show
  1. action_model.py +161 -0
  2. best_model.pt +3 -0
  3. label_map.json +7 -0
action_model.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import json
4
+ import torch
5
+ import numpy as np
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torchvision import models, transforms
9
+ from datetime import datetime
10
+ from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
11
+
12
+ try:
13
+ from ucf101_config import UCF101_CLASSES
14
+ UCF101_AVAILABLE = True
15
+ except ImportError:
16
+ UCF101_AVAILABLE = False
17
+ print("[WARNING] UCF101 config not found, using default configuration")
18
+
19
class CNN_GRU(nn.Module):
    """Per-frame CNN feature extractor followed by a GRU clip classifier.

    Every frame of a clip is passed through a 2D CNN backbone, the per-frame
    feature vectors are fed to a GRU in temporal order, and the GRU output at
    the last time step goes through dropout and a linear classifier.

    Args:
        cnn_model: Backbone name: ``'mobilenetv2'`` (torchvision) or
            ``'efficientnet_b0'`` (timm, imported only when selected).
        hidden_size: Hidden size of the GRU.
        num_layers: Number of stacked GRU layers.
        num_classes: Number of output logits.
        dropout: Dropout probability applied before the classifier, and
            between GRU layers when ``num_layers > 1``.
        FREEZE_BACKBONE: When true, gradients are disabled for every
            backbone parameter.
        pretrained: When true (the default, matching the previous behaviour)
            the backbone starts from pretrained weights. Pass ``False`` to
            skip the weight download, e.g. when a checkpoint is loaded
            right afterwards.

    Raises:
        ValueError: If ``cnn_model`` is not one of the supported names.
    """

    def __init__(self, cnn_model='mobilenetv2', hidden_size=128, num_layers=1,
                 num_classes=5, dropout=0.5, FREEZE_BACKBONE=True,
                 pretrained=True):
        super().__init__()

        if cnn_model == 'mobilenetv2':
            # `pretrained=True` is deprecated in torchvision; IMAGENET1K_V1 is
            # the weight set that flag used to select.
            weights = MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None
            backbone = models.mobilenet_v2(weights=weights)
            self.cnn_out_features = backbone.last_channel
            self.cnn = backbone.features
        elif cnn_model == 'efficientnet_b0':
            import timm
            backbone = timm.create_model('efficientnet_b0', pretrained=pretrained)
            self.cnn_out_features = backbone.classifier.in_features
            # Drop the classification head so the model emits features.
            backbone.classifier = nn.Identity()
            self.cnn = backbone
        else:
            raise ValueError("Invalid CNN model")

        if FREEZE_BACKBONE:
            for p in self.cnn.parameters():
                p.requires_grad = False

        self.gru = nn.GRU(self.cnn_out_features,
                          hidden_size,
                          num_layers=num_layers,
                          batch_first=True,
                          # Inter-layer dropout is only meaningful with >1 layer.
                          dropout=dropout if num_layers > 1 else 0)

        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: Tensor of shape ``[B, T, C, H, W]``.

        Returns:
            Logits of shape ``[B, num_classes]``.
        """
        b, t, c, h, w = x.size()
        # reshape (unlike view) also accepts non-contiguous inputs such as
        # permuted or transposed clips.
        x = x.reshape(b * t, c, h, w)
        feats = self.cnn(x)
        # A backbone may return a spatial map [B*T, F, h', w'] or an already
        # pooled vector [B*T, F]; only the former needs global pooling.
        if feats.dim() == 4:
            feats = F.adaptive_avg_pool2d(feats, 1)
        feats = feats.reshape(b, t, -1)
        out, _ = self.gru(feats)
        # Use the GRU output at the final time step as the clip embedding.
        out = self.dropout(out[:, -1])
        return self.fc(out)
57
+
58
+
59
def get_transform(resize=(112, 112), augment=False):
    """Compose the per-frame preprocessing pipeline.

    The pipeline converts the input to a PIL image, resizes it to ``resize``,
    optionally applies colour jitter and a random horizontal flip, converts
    to a tensor and normalises each channel with a fixed mean/std.

    Args:
        resize: Target size passed to ``transforms.Resize``.
        augment: When true, insert the random augmentations after resizing.

    Returns:
        A ``transforms.Compose`` instance.
    """
    augmentation_steps = []
    if augment:
        augmentation_steps = [
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.RandomHorizontalFlip(p=0.5),
        ]

    pipeline = [
        transforms.ToPILImage(),
        transforms.Resize(resize),
        *augmentation_steps,
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(pipeline)
76
+
77
+
78
def preprocess_frames(frames, seq_len=16, resize=(112, 112), augment=False):
    """Turn a sequence of video frames into a fixed-length clip tensor.

    Exactly ``seq_len`` frames are selected: evenly spaced over the input when
    there are at least ``seq_len`` frames, otherwise the input is repeated
    from the start until the clip is full. Each selected frame is converted
    with ``cv2.COLOR_BGR2RGB`` and passed through ``get_transform``.

    Args:
        frames: Iterable of frames; they are treated as BGR images because
            ``cv2.COLOR_BGR2RGB`` is applied to each selected one.
        seq_len: Number of frames in the returned clip.
        resize: Target frame size forwarded to ``get_transform``.
        augment: Whether ``get_transform`` adds random augmentations. The
            transform is applied frame by frame.

    Returns:
        Tensor of shape ``[T, C, H, W]`` with ``T == seq_len``.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    frames = list(frames)
    total_frames = len(frames)
    if total_frames == 0:
        # Fail early with a clear message instead of an obscure padding error.
        raise ValueError("preprocess_frames() requires at least one frame")

    if total_frames >= seq_len:
        indices = np.linspace(0, total_frames - 1, seq_len, dtype=int)
    else:
        # Wrap around: 0, 1, ..., n-1, 0, 1, ... until seq_len indices exist.
        indices = np.arange(seq_len) % total_frames

    transform = get_transform(resize=resize, augment=augment)
    # Convert only the sampled frames; the rest are never used.
    clip = [
        transform(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB))
        for i in indices
    ]
    return torch.stack(clip)  # [T, C, H, W]
92
+
93
+
94
def load_action_model(model_path="best_model.pt", device='cpu',
                      num_classes=5, hidden_size=128):
    """Build a ``CNN_GRU`` and load its weights from ``model_path``.

    Args:
        model_path: Path to the checkpoint file.
        device: Device used for ``map_location`` and for the returned model.
        num_classes: Number of output classes the checkpoint was trained with.
        hidden_size: GRU hidden size the checkpoint was trained with.

    Returns:
        The model in eval mode on ``device``, or ``None`` when ``model_path``
        is not an existing file (an error line is printed in that case).
    """
    # isfile (not exists) so that a directory path is reported as missing
    # instead of failing later inside torch.load.
    if not os.path.isfile(model_path):
        print(f"[ERROR] Model file not found: {model_path}")
        return None

    model = CNN_GRU(num_classes=num_classes, hidden_size=hidden_size)

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source, and consider passing weights_only=True where the
    # installed torch version supports it.
    checkpoint = torch.load(model_path, map_location=device)

    # Accept training checkpoints that wrap the weights in a larger dict.
    if isinstance(checkpoint, dict):
        for key in ("model_state_dict", "state_dict"):
            if isinstance(checkpoint.get(key), dict):
                checkpoint = checkpoint[key]
                break

    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()
    print(f"[INFO] Loaded model from {model_path} on {device}")
    return model
106
+
107
+
108
def _load_idx_to_class(label_map_path):
    """Read a ``{class_name: index}`` JSON file as ``{index: class_name}``.

    Returns an empty dict when the file does not exist or cannot be parsed;
    in the latter case a warning line is printed.
    """
    if not os.path.exists(label_map_path):
        return {}
    try:
        with open(label_map_path, 'r', encoding='utf-8') as f:
            label_map = json.load(f)
        # Coerce to int so lookups by predicted index work even when the file
        # stores indices as strings.
        idx_to_class = {int(v): k for k, v in label_map.items()}
    except (OSError, ValueError, TypeError, AttributeError) as e:
        print(f"[WARNING] Could not load label map: {e}")
        return {}
    print(f"[INFO] Loaded {len(idx_to_class)} classes from {label_map_path}")
    return idx_to_class


def predict_action(model, frames_tensor, label_map_path="label_map.json", device="cpu", top_k=3):
    """Run the model on one clip and return the most likely actions.

    Args:
        model: Callable model, or ``None`` if loading failed.
        frames_tensor: Clip tensor of shape ``[T, C, H, W]``; a batch
            dimension is added before the forward pass.
        label_map_path: JSON file mapping class name to class index. When it
            is missing or unreadable, the UCF101 class list is used if it was
            importable, otherwise a built-in five-class mapping.
        device: Device the clip is moved to before inference.
        top_k: Maximum number of ranked predictions to return.

    Returns:
        Dict with keys ``"action"`` (best class name), ``"confidence"`` (its
        softmax probability) and ``"top_predictions"`` (list of
        ``{"class", "confidence"}`` dicts, best first). If ``model`` is
        ``None`` the action is ``"Model not loaded"``; if inference raises,
        the action is ``"Error"``. Both cases carry confidence ``0.0`` and an
        empty prediction list.
    """
    if model is None:
        return {"action": "Model not loaded", "confidence": 0.0, "top_predictions": []}

    idx_to_class = _load_idx_to_class(label_map_path)
    if not idx_to_class:
        if UCF101_AVAILABLE:
            idx_to_class = {i: class_name for i, class_name in enumerate(UCF101_CLASSES)}
            print("[INFO] Using default UCF101 class mapping")
        else:
            idx_to_class = {0: 'CricketShot', 1: 'PlayingCello', 2: 'Punch',
                            3: 'ShavingBeard', 4: 'TennisSwing'}
            print("[WARNING] Using minimal default labels.")

    # Deliberately best-effort: any inference failure is reported in the
    # result rather than propagated to the caller.
    try:
        with torch.no_grad():
            frames_tensor = frames_tensor.unsqueeze(0).to(device)  # [1, T, C, H, W]
            output = model(frames_tensor)
            probabilities = torch.softmax(output, dim=1)
            # Never ask for more entries than there are classes.
            top_k_probs, top_k_indices = torch.topk(probabilities, min(top_k, probabilities.size(1)))

        predicted_idx = top_k_indices[0][0].item()
        predicted_class = idx_to_class.get(predicted_idx, f"Class_{predicted_idx}")
        confidence = top_k_probs[0][0].item()

        top_predictions = [
            {"class": idx_to_class.get(idx.item(), f"Class_{idx.item()}"),
             "confidence": prob.item()}
            for prob, idx in zip(top_k_probs[0], top_k_indices[0])
        ]

        return {
            "action": predicted_class,
            "confidence": confidence,
            "top_predictions": top_predictions
        }
    except Exception as e:
        print(f"[ERROR] Prediction failed: {e}")
        return {"action": "Error", "confidence": 0.0, "top_predictions": []}
155
+
156
+
157
def log_action_prediction(action_label, confidence, log_file="logs/action_log.txt"):
    """Append one timestamped prediction line to ``log_file``.

    The parent directory of ``log_file`` is created if needed. The line has
    the form ``[YYYY-MM-DD HH:MM:SS.mmm] ACTION: <label> (confidence: 0.00)``.

    Args:
        action_label: Predicted action name.
        confidence: Confidence value, written with two decimals.
        log_file: Path of the log file, opened in append mode as UTF-8.
    """
    log_dir = os.path.dirname(log_file)
    # A bare filename has no directory part; os.makedirs("") would raise.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    # Trim microseconds down to milliseconds.
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    with open(log_file, 'a', encoding='utf-8') as f:
        f.write(f"[{timestamp}] ACTION: {action_label} (confidence: {confidence:.2f})\n")
best_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59c240bbc50fb6027bd31a0e0470a95450ab95c82d07cb9b13ae1b87da32821c
3
+ size 11307065
label_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "CricketShot": 0,
3
+ "PlayingCello": 1,
4
+ "Punch": 2,
5
+ "ShavingBeard": 3,
6
+ "TennisSwing": 4
7
+ }