# mmesa-gpu-gitex / app/model.py
# Page header (scraped): vitorcalvi, commit "1" (1484b84)
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import logging
from app.model_architectures import ResNet50, LSTMPyTorch
# Configure root logging at import time and obtain this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Choose the compute device in order of preference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
logger.info(f"Using device: {device}")

# Checkpoint locations; these are relative paths, resolved against the
# process's current working directory.
STATIC_MODEL_PATH = 'assets/models/FER_static_ResNet50_AffectNet.pt'
DYNAMIC_MODEL_PATH = 'assets/models/FER_dynamic_LSTM.pt'
def load_model(model_class, model_path, *args, **kwargs):
    """
    Instantiate ``model_class`` and load its weights from ``model_path`` if the file exists.

    The returned model is always moved to the module-level ``device`` and put
    in evaluation mode, whether or not the weights file was found.

    :param model_class: Class of the model (ResNet50, LSTMPyTorch, etc.)
    :param model_path: Path to the saved model's state dict
    :param args: Positional arguments for model initialization
    :param kwargs: Keyword arguments for model initialization
    :return: The model on ``device`` in eval mode. If ``model_path`` does not
        exist, an error is logged and the model keeps its freshly initialized
        (untrained) weights.
    """
    model = model_class(*args, **kwargs)
    if os.path.exists(model_path):
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Consider weights_only=True on torch versions that support it.
        state_dict = torch.load(model_path, map_location=device)
        # strict=False tolerates key mismatches. Report them instead of hiding
        # them, since a partially loaded model otherwise fails silently.
        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.warning(f"Missing keys when loading {model_path}: {incompatible.missing_keys}")
        if incompatible.unexpected_keys:
            logger.warning(f"Unexpected keys when loading {model_path}: {incompatible.unexpected_keys}")
        logger.info(f"Loaded model from {model_path}")
    else:
        logger.error(f"Model file not found: {model_path}")
    # Apply to both branches. Previously the missing-file path skipped eval(),
    # which returned a model in training mode (active dropout/batch-norm updates).
    model.to(device)
    model.eval()
    return model
def load_models(*, input_size=2048, hidden_size=512, num_layers=2, num_classes=7):
    """
    Load the ResNet50 static model, the LSTMPyTorch dynamic model, and GradCAM.

    The LSTM hyperparameters are keyword-only, so callers can match a different
    checkpoint without editing this function. The defaults are the values this
    module has always used. They must agree with the tensors stored at
    DYNAMIC_MODEL_PATH: ``load_state_dict(strict=False)`` skips missing or
    unexpected keys but still raises on tensor shape mismatches.

    :param input_size: LSTM input feature size (presumably the per-frame feature
        size produced by the static model; confirm against the checkpoint)
    :param hidden_size: LSTM hidden state size
    :param num_layers: Number of stacked LSTM layers
    :param num_classes: Number of output emotion classes
    :return: Tuple (static model, dynamic model or None if its weights file is
        missing, GradCAM instance targeting the static model's ``resnet.layer4``)
    """
    # Static model: if the file is missing, load_model logs an error and still
    # returns an (untrained) model, so GradCAM below always has a target.
    pth_model_static = load_model(ResNet50, STATIC_MODEL_PATH)

    # Dynamic model: unlike the static one, yield None rather than an untrained
    # network when the weights are absent.
    pth_model_dynamic = None
    if os.path.exists(DYNAMIC_MODEL_PATH):
        pth_model_dynamic = load_model(
            LSTMPyTorch, DYNAMIC_MODEL_PATH, input_size, hidden_size, num_layers, num_classes
        )
    else:
        logger.error(f"Dynamic model file not found: {DYNAMIC_MODEL_PATH}")

    # Class-activation maps are computed from the last ResNet stage.
    cam = GradCAM(model=pth_model_static, target_layers=[pth_model_static.resnet.layer4])
    return pth_model_static, pth_model_dynamic, cam