import logging
import os

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

from app.model_architectures import ResNet50, LSTMPyTorch

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Determine the device: prefer Apple MPS, then CUDA, otherwise fall back to CPU
device = torch.device(
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)
logger.info("Using device: %s", device)

# Define paths
STATIC_MODEL_PATH = 'assets/models/FER_static_ResNet50_AffectNet.pt'
DYNAMIC_MODEL_PATH = 'assets/models/FER_dynamic_LSTM.pt'


def load_model(model_class, model_path, *args, **kwargs):
    """
    Instantiate a model and, when a weights file exists, load its weights.

    The returned model is always moved to the module-level ``device`` and put
    into evaluation mode. If ``model_path`` does not exist, an error is logged
    and the model is returned with whatever weights its constructor produced.

    :param model_class: Class of the model (ResNet50, LSTMPyTorch, etc.)
    :param model_path: Path to the saved model's weights (a state dict)
    :param args: Positional arguments for model initialization
    :param kwargs: Keyword arguments for model initialization
    :return: The model instance, on ``device`` and in eval mode
    """
    model = model_class(*args, **kwargs)

    if os.path.exists(model_path):
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source, and
        # consider weights_only=True if the supported torch version allows it.
        state_dict = torch.load(model_path, map_location=device)

        # strict=False tolerates key mismatches between checkpoint and model.
        # Report them so a partially loaded model does not go unnoticed.
        incompatible = model.load_state_dict(state_dict, strict=False)
        # getattr guards against a model class that overrides load_state_dict
        # without returning the usual (missing_keys, unexpected_keys) result.
        missing_keys = getattr(incompatible, 'missing_keys', None)
        unexpected_keys = getattr(incompatible, 'unexpected_keys', None)
        if missing_keys:
            logger.warning(
                "Keys missing from %s: %s", model_path, list(missing_keys)
            )
        if unexpected_keys:
            logger.warning(
                "Unexpected keys in %s: %s", model_path, list(unexpected_keys)
            )
        logger.info("Loaded model from %s", model_path)
    else:
        logger.error("Model file not found: %s", model_path)

    # Apply device placement and eval mode on both paths so that callers
    # never receive a model left in training mode.
    model = model.to(device)
    model.eval()
    return model


def load_models():
    """
    Load the ResNet50 static model and LSTMPyTorch dynamic model along with GradCAM.

    The dynamic model is only built when its weights file exists; otherwise an
    error is logged and ``None`` is returned in its place.

    :return: Tuple (static model, dynamic model or None, GradCAM instance)
    """
    # Load the static ResNet50 model
    pth_model_static = load_model(ResNet50, STATIC_MODEL_PATH)

    # Define LSTMPyTorch parameters
    input_size = 2048  # Example value: This should match the feature size from the ResNet50 output
    hidden_size = 512  # Example value: Adjust based on your LSTM architecture
    num_layers = 2  # Example value: Number of layers in the LSTM
    num_classes = 7  # Example value: Number of emotion classes

    # Load the dynamic LSTM model (if available)
    pth_model_dynamic = None
    if os.path.exists(DYNAMIC_MODEL_PATH):
        pth_model_dynamic = load_model(
            LSTMPyTorch,
            DYNAMIC_MODEL_PATH,
            input_size,
            hidden_size,
            num_layers,
            num_classes,
        )
    else:
        logger.error(f"Dynamic model file not found: {DYNAMIC_MODEL_PATH}")

    # Initialize GradCAM on the static model's layer4
    cam = GradCAM(
        model=pth_model_static,
        target_layers=[pth_model_static.resnet.layer4],
    )

    return pth_model_static, pth_model_dynamic, cam