File size: 2,536 Bytes
18c46ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# model.py

import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import logging
from app.model_architectures import ResNet50, LSTMPyTorch

# Locations of the serialized weights for the static and dynamic models.
STATIC_MODEL_PATH = 'assets/models/FER_static_ResNet50_AffectNet.pt'
DYNAMIC_MODEL_PATH = 'assets/models/FER_dynamic_LSTM.pt'

# Module-wide logger; records at INFO level and above are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pick the compute device in priority order: MPS first, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
logger.info(f"Using device: {device}")

def load_model(model_class, model_path, *args, **kwargs):
    """Build a model on the module-level device and load weights if present.

    The model is created as ``model_class(*args, **kwargs)`` and moved to
    ``device``. When ``model_path`` exists, its contents are read with
    ``torch.load`` and applied via ``load_state_dict(..., strict=False)``;
    any missing or unexpected keys are logged as warnings. Loading problems
    are logged rather than raised, so a model is always returned.

    The returned model is always switched to evaluation mode, whether or not
    a checkpoint was loaded, so callers get consistent inference behaviour.

    Args:
        model_class: Callable that constructs the model.
        model_path: Filesystem path of the saved state dict.
        *args: Positional arguments forwarded to ``model_class``.
        **kwargs: Keyword arguments forwarded to ``model_class``.

    Returns:
        The constructed model, on ``device`` and in evaluation mode.
    """
    model = model_class(*args, **kwargs).to(device)

    if not os.path.exists(model_path):
        logger.warning(f"Model file not found at {model_path}. Initializing with random weights.")
    else:
        try:
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code. Only point model_path at trusted checkpoints
            # (consider weights_only=True where the installed torch supports it).
            state_dict = torch.load(model_path, map_location=device)
            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        except Exception as e:
            # Deliberate best-effort: a bad checkpoint must not stop the module
            # from importing, so the failure is logged and the model is kept.
            logger.error(f"Error loading model from {model_path}: {str(e)}")
            logger.info("Initializing with random weights.")
        else:
            if missing_keys:
                logger.warning(f"Missing keys when loading model from {model_path}: {missing_keys}")
            if unexpected_keys:
                logger.warning(f"Unexpected keys when loading model from {model_path}: {unexpected_keys}")
            logger.info(f"Model loaded successfully from {model_path}")

    # Always leave training mode: previously this only happened after a
    # successful load, so fallback models were returned in train mode.
    model.eval()
    return model

# Static model: ResNet50 constructed with 7 output classes and 3 channels.
pth_model_static = load_model(
    ResNet50,
    STATIC_MODEL_PATH,
    num_classes=7,
    channels=3,
)

# Dynamic model: LSTMPyTorch constructed with the hyperparameters below.
pth_model_dynamic = load_model(
    LSTMPyTorch,
    DYNAMIC_MODEL_PATH,
    input_size=2048,
    hidden_size=256,
    num_layers=2,
    num_classes=7,
)

# GradCAM hooks the last element of the static model's ``layer4``.
_gradcam_layer = pth_model_static.layer4[-1]
target_layers = [_gradcam_layer]
cam = GradCAM(model=pth_model_static, target_layers=target_layers)

# Image preprocessing: resize to 224x224, convert to a tensor, then apply
# per-channel normalization with the mean/std constants below.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
pth_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

def pth_processing(img):
    """Run ``pth_transform`` on ``img`` and return a batched tensor on ``device``.

    The transformed result gets a new leading dimension of size 1 via
    ``unsqueeze(0)`` before being moved to the module-level device.
    ``img`` is presumably a PIL image accepted by ``pth_transform`` —
    verify against callers.
    """
    transformed = pth_transform(img)
    batched = transformed.unsqueeze(0)
    return batched.to(device)

# Additional utility functions...

if __name__ == "__main__":
    # When run as a script, all module-level setup above has already executed
    # by this point; just log that initialization finished.
    logger.info("Model initialization complete.")