File size: 6,334 Bytes
26275ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import os
import cv2
import json
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from datetime import datetime
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
# Optional dependency: the UCF101 class list lives in a separate config module.
# UCF101_AVAILABLE records whether that import succeeded so later code can
# decide whether UCF101_CLASSES may be used.
try:
    from ucf101_config import UCF101_CLASSES
except ImportError:
    UCF101_AVAILABLE = False
    print("[WARNING] UCF101 config not found, using default configuration")
else:
    UCF101_AVAILABLE = True
class CNN_GRU(nn.Module):
    """Per-frame CNN feature extractor followed by a GRU and a linear classifier.

    Args:
        cnn_model: Backbone name, either ``'mobilenetv2'`` (torchvision) or
            ``'efficientnet_b0'`` (timm, imported lazily).
        hidden_size: Hidden size of the GRU.
        num_layers: Number of stacked GRU layers.
        num_classes: Size of the classifier output.
        dropout: Dropout probability applied before the classifier and, only
            when ``num_layers > 1``, between GRU layers.
        FREEZE_BACKBONE: When true, ``requires_grad`` is disabled on every
            backbone parameter.
        pretrained: When true (the default, matching the previous behaviour)
            the backbone is created with pretrained weights. Pass ``False``
            to skip that when a full checkpoint is loaded afterwards.

    Raises:
        ValueError: If ``cnn_model`` is not one of the supported names.
    """

    def __init__(self, cnn_model='mobilenetv2', hidden_size=128, num_layers=1,
                 num_classes=5, dropout=0.5, FREEZE_BACKBONE=True,
                 pretrained=True):
        super().__init__()

        if cnn_model == 'mobilenetv2':
            # ``pretrained=True`` is deprecated in torchvision; IMAGENET1K_V1
            # is the weight set that flag used to select.
            weights = models.MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None
            cnn = models.mobilenet_v2(weights=weights)
            self.cnn_out_features = cnn.last_channel
            self.cnn = cnn.features
        elif cnn_model == 'efficientnet_b0':
            import timm
            cnn = timm.create_model('efficientnet_b0', pretrained=pretrained)
            self.cnn_out_features = cnn.classifier.in_features
            # Drop the classification head so the model outputs features.
            cnn.classifier = nn.Identity()
            self.cnn = cnn
        else:
            raise ValueError("Invalid CNN model")

        if FREEZE_BACKBONE:
            for p in self.cnn.parameters():
                p.requires_grad = False

        self.gru = nn.GRU(
            self.cnn_out_features,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # nn.GRU only applies dropout between layers, so it is disabled
            # for a single-layer GRU.
            dropout=dropout if num_layers > 1 else 0,
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: Tensor of shape ``[B, T, C, H, W]``.

        Returns:
            Tensor of shape ``[B, num_classes]`` computed from the GRU output
            at the last time step.
        """
        b, t, c, h, w = x.size()
        # reshape (not view) so non-contiguous inputs are accepted as well.
        feats = self.cnn(x.reshape(b * t, c, h, w))
        # A backbone that returns a spatial feature map needs global pooling;
        # one that already returns pooled [N, F] features (as the timm model
        # with its classifier replaced by Identity is expected to) must skip
        # it, because adaptive_avg_pool2d rejects 2-D input.
        if feats.dim() == 4:
            feats = F.adaptive_avg_pool2d(feats, 1)
        feats = feats.reshape(b, t, -1)
        out, _ = self.gru(feats)
        out = self.dropout(out[:, -1])
        return self.fc(out)
def get_transform(resize=(112, 112), augment=False):
    """Build the per-frame preprocessing pipeline.

    The pipeline converts the input to a PIL image, resizes it to ``resize``,
    optionally applies colour jitter and a random horizontal flip, converts
    to a tensor and normalises each channel with the mean/std given below.

    Args:
        resize: Target size passed to ``transforms.Resize``.
        augment: When truthy, insert the random augmentation steps after the
            resize.

    Returns:
        A ``transforms.Compose`` wrapping the steps in order.
    """
    augmentation_steps = (
        [
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.RandomHorizontalFlip(p=0.5),
        ]
        if augment
        else []
    )
    pipeline = [
        transforms.ToPILImage(),
        transforms.Resize(resize),
        *augmentation_steps,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(pipeline)
def preprocess_frames(frames, seq_len=16, resize=(112, 112), augment=False):
    """Sample ``seq_len`` frames from a clip and turn them into a tensor.

    When the clip has at least ``seq_len`` frames, indices are spread evenly
    from the first to the last frame. A shorter clip is extended by cycling
    through its frames from the start until ``seq_len`` is reached.

    Args:
        frames: Iterable of frames; each one is converted with
            ``cv2.COLOR_BGR2RGB`` before the transform is applied.
        seq_len: Number of frames in the returned sequence.
        resize: Target size forwarded to ``get_transform``.
        augment: Forwarded to ``get_transform``.

    Returns:
        Tensor of shape ``[T, C, H, W]`` with ``T == seq_len``.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    frame_list = list(frames)
    total_frames = len(frame_list)
    if total_frames == 0:
        # Fail early with a clear message instead of an obscure padding error.
        raise ValueError("preprocess_frames() requires at least one frame")

    if total_frames >= seq_len:
        indices = np.linspace(0, total_frames - 1, seq_len, dtype=int)
    else:
        # Equivalent to wrap-padding 0..total_frames-1 on the right.
        indices = np.arange(seq_len) % total_frames
    indices = indices.tolist()

    transform = get_transform(resize=resize, augment=augment)

    # Convert only the frames that are actually used, once per unique index.
    rgb_by_index = {
        i: cv2.cvtColor(frame_list[i], cv2.COLOR_BGR2RGB) for i in set(indices)
    }

    # NOTE(review): the transform is invoked once per sampled frame, so with
    # augment=True the random jitter/flip is drawn independently for each
    # frame of the clip - confirm that this is intended.
    transformed_frames = [transform(rgb_by_index[i]) for i in indices]
    return torch.stack(transformed_frames)  # [T, C, H, W]
def load_action_model(model_path="best_model.pt", device='cpu',
                      num_classes=5, hidden_size=128):
    """Create a ``CNN_GRU`` and populate it from a saved state dict.

    Args:
        model_path: Path of the file passed to ``torch.load``.
        device: Device used as ``map_location`` and as the target of
            ``model.to``.
        num_classes: Forwarded to ``CNN_GRU``.
        hidden_size: Forwarded to ``CNN_GRU``.

    Returns:
        The model in eval mode, or ``None`` (after printing an error) when
        ``model_path`` does not exist.

    NOTE(review): ``torch.load`` unpickles the file, so only checkpoints from
    trusted sources should be loaded here.
    """
    if os.path.exists(model_path):
        net = CNN_GRU(num_classes=num_classes, hidden_size=hidden_size)
        state_dict = torch.load(model_path, map_location=device)
        net.load_state_dict(state_dict)
        net.to(device).eval()
        print(f"[INFO] Loaded model from {model_path} on {device}")
        return net

    print(f"[ERROR] Model file not found: {model_path}")
    return None
def _load_idx_to_class(label_map_path):
    """Read a JSON label map and return an ``{index: class_name}`` dict.

    Two layouts are accepted: ``{"ClassName": 0, ...}`` (inverted, as before)
    and ``{"0": "ClassName", ...}`` (digit-string keys with string values).
    Returns an empty dict when the file is missing or cannot be parsed.
    """
    if not os.path.exists(label_map_path):
        return {}
    try:
        # JSON is read as UTF-8 explicitly so non-ASCII class names do not
        # depend on the platform's default encoding.
        with open(label_map_path, 'r', encoding='utf-8') as f:
            label_map = json.load(f)
        idx_to_class = {}
        for key, value in label_map.items():
            if isinstance(value, str) and isinstance(key, str) and key.isdigit():
                idx_to_class[int(key)] = value
            else:
                idx_to_class[value] = key
    except Exception as e:
        print(f"[WARNING] Could not load label map: {e}")
        return {}
    print(f"[INFO] Loaded {len(idx_to_class)} classes from {label_map_path}")
    return idx_to_class


def predict_action(model, frames_tensor, label_map_path="label_map.json", device="cpu", top_k=3):
    """Run the model on one clip and return the top predictions.

    Args:
        model: Callable returning class scores of shape ``[1, num_classes]``,
            or ``None`` when no model could be loaded.
        frames_tensor: Clip tensor; a batch dimension is added before it is
            passed to the model.
        label_map_path: JSON label map file; when it is missing or unusable
            the UCF101 list (if available) or a 5-class default is used.
        device: Device the clip is moved to before inference.
        top_k: Number of predictions to report, clamped to
            ``1..num_classes``.

    Returns:
        Dict with ``"action"``, ``"confidence"`` and ``"top_predictions"``
        (a list of ``{"class", "confidence"}`` dicts, best first). When the
        model is ``None`` or inference fails, a placeholder dict with
        confidence ``0.0`` and no top predictions is returned instead.
    """
    if model is None:
        return {"action": "Model not loaded", "confidence": 0.0, "top_predictions": []}

    idx_to_class = _load_idx_to_class(label_map_path)
    if not idx_to_class and UCF101_AVAILABLE:
        idx_to_class = {i: class_name for i, class_name in enumerate(UCF101_CLASSES)}
        print("[INFO] Using default UCF101 class mapping")
    elif not idx_to_class:
        idx_to_class = {0: 'CricketShot', 1: 'PlayingCello', 2: 'Punch',
                        3: 'ShavingBeard', 4: 'TennisSwing'}
        print("[WARNING] Using minimal default labels.")

    # Best-effort boundary: any inference failure is reported, not raised.
    try:
        with torch.no_grad():
            frames_tensor = frames_tensor.unsqueeze(0).to(device)  # [1, T, C, H, W]
            output = model(frames_tensor)
            probabilities = torch.softmax(output, dim=1)
            # Always keep at least the best class so a non-positive top_k
            # still yields a valid prediction.
            k = max(1, min(top_k, probabilities.size(1)))
            top_k_probs, top_k_indices = torch.topk(probabilities, k)

        predicted_idx = top_k_indices[0][0].item()
        predicted_class = idx_to_class.get(predicted_idx, f"Class_{predicted_idx}")
        confidence = top_k_probs[0][0].item()
        top_predictions = [
            {"class": idx_to_class.get(idx.item(), f"Class_{idx.item()}"),
             "confidence": prob.item()}
            for prob, idx in zip(top_k_probs[0], top_k_indices[0])
        ]
        return {
            "action": predicted_class,
            "confidence": confidence,
            "top_predictions": top_predictions
        }
    except Exception as e:
        print(f"[ERROR] Prediction failed: {e}")
        return {"action": "Error", "confidence": 0.0, "top_predictions": []}
def log_action_prediction(action_label, confidence, log_file="logs/action_log.txt"):
    """Append one timestamped prediction line to ``log_file``.

    Args:
        action_label: Label written after ``ACTION:``.
        confidence: Numeric confidence, written with two decimals.
        log_file: Destination file; its parent directory is created when the
            path contains one.
    """
    log_dir = os.path.dirname(log_file)
    # A bare file name has an empty dirname, and os.makedirs("") would raise.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Trim microseconds down to milliseconds.
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    with open(log_file, 'a', encoding='utf-8') as f:
        f.write(f"[{timestamp}] ACTION: {action_label} (confidence: {confidence:.2f})\n")