Spaces:
Build error
Build error
# model.py | |
import os | |
import torch | |
import torch.nn as nn | |
import torchvision.transforms as transforms | |
from pytorch_grad_cam import GradCAM | |
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
import logging | |
from app.model_architectures import ResNet50, LSTMPyTorch | |
# Module logger. basicConfig configures the root logger at INFO level; it is a
# no-op if the root logger already has handlers installed.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Choose the compute backend: Apple MPS first, then CUDA, falling back to CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
logger.info(f"Using device: {device}")

# Checkpoint locations. These are relative paths, so they resolve against the
# process's current working directory.
STATIC_MODEL_PATH = 'assets/models/FER_static_ResNet50_AffectNet.pt'
DYNAMIC_MODEL_PATH = 'assets/models/FER_dynamic_LSTM.pt'
def load_model(model_class, model_path, *args, **kwargs):
    """Build ``model_class`` on ``device`` and, if possible, load weights from ``model_path``.

    Loading is best-effort: if the checkpoint file does not exist or loading
    raises, the freshly constructed (randomly initialised) weights are kept and
    the problem is logged instead of propagated.

    Args:
        model_class: Callable that constructs the model (an ``nn.Module``
            subclass in this file's usage).
        model_path: Path to a checkpoint that ``torch.load`` returns as a
            ``state_dict`` mapping.
        *args: Positional arguments forwarded to ``model_class``.
        **kwargs: Keyword arguments forwarded to ``model_class``.

    Returns:
        The model instance, moved to ``device`` and switched to eval mode.
    """
    model = model_class(*args, **kwargs).to(device)
    if os.path.exists(model_path):
        try:
            # NOTE(review): torch.load unpickles the file; only load trusted
            # checkpoints. Consider weights_only=True on torch versions that
            # support it.
            state_dict = torch.load(model_path, map_location=device)
            # strict=False reports missing/unexpected keys instead of raising,
            # so they are logged below. Tensor shape mismatches still raise.
            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        except Exception as e:
            # Deliberate fallback to random weights; log the full traceback so
            # the failure is diagnosable.
            # NOTE(review): if load_state_dict raised part-way (e.g. a shape
            # mismatch), some tensors may already have been copied.
            logger.exception(f"Error loading model from {model_path}: {e}")
            logger.info("Initializing with random weights.")
        else:
            if missing_keys:
                logger.warning(f"Missing keys when loading model from {model_path}: {missing_keys}")
            if unexpected_keys:
                logger.warning(f"Unexpected keys when loading model from {model_path}: {unexpected_keys}")
            logger.info(f"Model loaded successfully from {model_path}")
    else:
        logger.warning(f"Model file not found at {model_path}. Initializing with random weights.")
    # Always switch to inference mode, including the random-weight fallbacks.
    # Previously only a successful load called eval(), so fallback models kept
    # train-mode Dropout/BatchNorm behavior.
    model.eval()
    return model
# Static classifier: ResNet50 configured for 3 input channels and 7 classes.
pth_model_static = load_model(
    ResNet50,
    STATIC_MODEL_PATH,
    channels=3,
    num_classes=7,
)

# Dynamic classifier: LSTM with 2048-dim inputs, 2 layers of 256 hidden units,
# 7 output classes. (Presumably consumes per-frame ResNet50 features — TODO confirm.)
pth_model_dynamic = load_model(
    LSTMPyTorch,
    DYNAMIC_MODEL_PATH,
    num_classes=7,
    num_layers=2,
    hidden_size=256,
    input_size=2048,
)

# Grad-CAM hooks the last sub-module of the static model's `layer4` stage.
target_layers = [pth_model_static.layer4[-1]]
cam = GradCAM(target_layers=target_layers, model=pth_model_static)

# Image preprocessing: fixed 224x224 resize, tensor conversion, then
# per-channel normalisation with the standard ImageNet mean/std values.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
pth_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
def pth_processing(img):
    """Turn a single image into a normalised, batched tensor on ``device``.

    Args:
        img: Input image. Presumably a ``PIL.Image.Image``, since
            ``pth_transform`` resizes and then applies ``ToTensor`` — TODO
            confirm against callers.

    Returns:
        Tensor of shape ``(1, C, 224, 224)`` on ``device`` (C is 3 after the
        RGB coercion below).
    """
    # The static model is built with channels=3 and Normalize carries
    # 3-channel statistics. Grayscale ("L"), palette ("P") or RGBA PIL images
    # would otherwise fail inside Normalize, so coerce non-RGB PIL images to RGB.
    if hasattr(img, "convert") and getattr(img, "mode", "RGB") != "RGB":
        img = img.convert("RGB")
    batch = pth_transform(img).unsqueeze(0)  # add leading batch dimension
    return batch.to(device)
def _log_ready():
    """Log that module-level model setup has finished."""
    logger.info("Model initialization complete.")


# When run as a script, the module-level setup above has already executed on
# import; just report completion.
if __name__ == "__main__":
    _log_ready()