# model.py
"""Model loading and image preprocessing for the static and dynamic FER models.

On import this module:
  * picks a torch device (MPS, then CUDA, then CPU),
  * builds the static ResNet50 model and the dynamic LSTM model and loads their
    weights from ``assets/models`` when the checkpoint files exist,
  * sets up a GradCAM explainer on the static model, and
  * defines the preprocessing pipeline used by ``pth_processing``.
"""
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import logging

from app.model_architectures import ResNet50, LSTMPyTorch

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Determine the device: prefer Apple MPS, then CUDA, falling back to CPU.
device = torch.device(
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)
logger.info(f"Using device: {device}")

# Define paths (relative to the current working directory).
STATIC_MODEL_PATH = 'assets/models/FER_static_ResNet50_AffectNet.pt'
DYNAMIC_MODEL_PATH = 'assets/models/FER_dynamic_LSTM.pt'


def load_model(model_class, model_path, *args, **kwargs):
    """Instantiate ``model_class`` on ``device`` and load weights from ``model_path``.

    Loading is best-effort: if the file is missing or cannot be loaded, the
    freshly constructed (randomly initialised) model is returned and the problem
    is logged. Weights are loaded with ``strict=False``, so missing or unexpected
    keys only produce warnings.

    Args:
        model_class: Callable (typically an ``nn.Module`` subclass) used to build
            the model.
        model_path: Path to a checkpoint holding a state dict.
        *args, **kwargs: Forwarded to ``model_class``.

    Returns:
        The model on ``device``, always switched to evaluation mode.
    """
    model = model_class(*args, **kwargs).to(device)

    if os.path.exists(model_path):
        try:
            # NOTE(review): torch.load unpickles the file, so only load trusted
            # checkpoints. Consider weights_only=True on torch versions that
            # support it.
            state_dict = torch.load(model_path, map_location=device)
            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        except Exception as e:
            # Deliberate best-effort fallback. logger.exception keeps the
            # traceback that logging str(e) alone would lose. If load_state_dict
            # failed part-way, some parameters may already have been copied.
            logger.exception(f"Error loading model from {model_path}: {e}")
            logger.info("Initializing with random weights.")
        else:
            if missing_keys:
                logger.warning(f"Missing keys when loading model from {model_path}: {missing_keys}")
            if unexpected_keys:
                logger.warning(f"Unexpected keys when loading model from {model_path}: {unexpected_keys}")
            logger.info(f"Model loaded successfully from {model_path}")
    else:
        logger.warning(f"Model file not found at {model_path}. Initializing with random weights.")

    # Switch to inference mode on every path, including the random-weights
    # fallbacks, so Dropout/BatchNorm do not run in training mode.
    model.eval()
    return model


# Load the static model
pth_model_static = load_model(ResNet50, STATIC_MODEL_PATH, num_classes=7, channels=3)

# Load the dynamic model
pth_model_dynamic = load_model(
    LSTMPyTorch,
    DYNAMIC_MODEL_PATH,
    input_size=2048,
    hidden_size=256,
    num_layers=2,
    num_classes=7,
)

# Set up GradCAM on the last block of the static model's `layer4`.
# Assumes ResNet50 in app.model_architectures exposes a torchvision-style
# indexable `layer4`. TODO: confirm if the architecture changes.
target_layers = [pth_model_static.layer4[-1]]
cam = GradCAM(model=pth_model_static, target_layers=target_layers)

# Define image preprocessing: resize to 224x224, convert to a tensor, and
# normalise with the per-channel mean/std constants below (the standard
# ImageNet values).
pth_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def pth_processing(img):
    """Apply ``pth_transform`` to ``img``, add a leading batch dimension, and move it to ``device``.

    ``img`` is presumably a 3-channel PIL image or ndarray accepted by
    ``transforms.ToTensor`` (TODO confirm with callers).
    """
    img = pth_transform(img).unsqueeze(0).to(device)
    return img


# Additional utility functions...

if __name__ == "__main__":
    logger.info("Model initialization complete.")