Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import torch.nn as nn | |
import torchvision.transforms as transforms | |
from pytorch_grad_cam import GradCAM | |
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
import logging | |
from app.model_architectures import ResNet50, LSTMPyTorch | |
# Configure root logging once at import time and obtain this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Choose the inference device: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
logger.info(f"Using device: {device}")

# Pretrained weight files; relative paths, so they resolve against the
# process's current working directory.
STATIC_MODEL_PATH = 'assets/models/FER_static_ResNet50_AffectNet.pt'
DYNAMIC_MODEL_PATH = 'assets/models/FER_dynamic_LSTM.pt'
def load_model(model_class, model_path, *args, **kwargs):
    """
    Instantiate a model and, if available, load its weights from a state dict.

    The returned model is always moved to the module-level ``device`` and put
    in evaluation mode, whether or not the weights file was found.

    :param model_class: Class of the model (ResNet50, LSTMPyTorch, etc.)
    :param model_path: Path to the saved model's state dict
    :param args: Positional arguments for model initialization
    :param kwargs: Keyword arguments for model initialization
    :return: The model on ``device`` in eval mode. If ``model_path`` does not
        exist, an error is logged and the model keeps its freshly initialized
        (untrained) weights.
    """
    model = model_class(*args, **kwargs)
    if os.path.exists(model_path):
        # NOTE: torch.load unpickles the file; only point this at trusted,
        # bundled checkpoints.
        state_dict = torch.load(model_path, map_location=device)
        # strict=False tolerates key mismatches between checkpoint and
        # architecture; surface them so silently skipped weights are visible.
        result = model.load_state_dict(state_dict, strict=False)
        missing_keys = getattr(result, "missing_keys", ())
        unexpected_keys = getattr(result, "unexpected_keys", ())
        if missing_keys:
            logger.warning(f"Missing keys when loading {model_path}: {list(missing_keys)}")
        if unexpected_keys:
            logger.warning(f"Unexpected keys when loading {model_path}: {list(unexpected_keys)}")
        logger.info(f"Loaded model from {model_path}")
    else:
        logger.error(f"Model file not found: {model_path}")
    # Both branches are used for inference. Previously the missing-file path
    # skipped eval() and returned a model in training mode, where dropout and
    # batch-norm behave differently.
    model.to(device)
    model.eval()
    return model
def load_models():
    """
    Build the models used by the app.

    The static ResNet50 model is always constructed via ``load_model``. The
    dynamic LSTMPyTorch model is only constructed when DYNAMIC_MODEL_PATH
    exists; otherwise an error is logged and ``None`` is returned in its place.
    A GradCAM instance is created over the static model, targeting its
    ``resnet.layer4`` module.

    :return: Tuple (static model, dynamic model or None, GradCAM instance)
    """
    static_model = load_model(ResNet50, STATIC_MODEL_PATH)

    # LSTMPyTorch constructor arguments, passed positionally in this order.
    # NOTE(review): these were marked as example values; confirm they match
    # the LSTMPyTorch signature and the checkpoint at DYNAMIC_MODEL_PATH.
    lstm_args = (
        2048,  # input_size: intended to match the ResNet50 feature size
        512,   # hidden_size
        2,     # num_layers
        7,     # num_classes: number of emotion classes
    )

    dynamic_model = None
    if not os.path.exists(DYNAMIC_MODEL_PATH):
        logger.error(f"Dynamic model file not found: {DYNAMIC_MODEL_PATH}")
    else:
        dynamic_model = load_model(LSTMPyTorch, DYNAMIC_MODEL_PATH, *lstm_args)

    # NOTE(review): assumes ResNet50 exposes a `.resnet` submodule with `layer4`.
    cam = GradCAM(model=static_model, target_layers=[static_model.resnet.layer4])
    return static_model, dynamic_model, cam