# dyagnosys-free / app / model.py
# (hosting-page residue kept for provenance: "vitorcalvi's picture", "1", "18c46ab")
# model.py
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import logging
from app.model_architectures import ResNet50, LSTMPyTorch
# Configure logging at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Choose the compute backend: MPS is tried first, then CUDA, with CPU as the
# fallback when neither accelerator reports itself as available.
if torch.backends.mps.is_available():
    _device_name = 'mps'
elif torch.cuda.is_available():
    _device_name = 'cuda'
else:
    _device_name = 'cpu'
device = torch.device(_device_name)
logger.info(f"Using device: {device}")

# Checkpoint locations for the static and dynamic models (relative paths).
STATIC_MODEL_PATH = 'assets/models/FER_static_ResNet50_AffectNet.pt'
DYNAMIC_MODEL_PATH = 'assets/models/FER_dynamic_LSTM.pt'
def load_model(model_class, model_path, *args, **kwargs):
    """Build a model on ``device`` and load weights from ``model_path`` if possible.

    The model is created as ``model_class(*args, **kwargs)`` and moved to the
    module-level ``device``. If ``model_path`` exists, its contents are loaded
    with ``strict=False``; any missing or unexpected keys are logged as
    warnings. If the file does not exist, or loading raises, the problem is
    logged and the freshly initialised model is returned as-is.

    The returned model is always in evaluation mode, including on both
    fallback paths, so callers never receive a model left in training mode.

    Args:
        model_class: Callable that builds the model; its result must support
            ``.to()``, ``.eval()`` and ``.load_state_dict()``.
        model_path: Filesystem path of the saved state dict.
        *args: Positional arguments forwarded to ``model_class``.
        **kwargs: Keyword arguments forwarded to ``model_class``.

    Returns:
        The constructed model, on ``device`` and in evaluation mode.
    """
    model = model_class(*args, **kwargs).to(device)
    # Switch to inference mode up front so every return path below yields an
    # eval-mode model, not only the successful-load path.
    model.eval()

    if not os.path.exists(model_path):
        logger.warning(f"Model file not found at {model_path}. Initializing with random weights.")
        return model

    try:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only point this at trusted checkpoints.
        state_dict = torch.load(model_path, map_location=device)
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    except Exception as e:
        # Deliberate best-effort: a bad checkpoint must not stop the import.
        logger.error(f"Error loading model from {model_path}: {str(e)}")
        logger.info("Initializing with random weights.")
        return model

    if missing_keys:
        logger.warning(f"Missing keys when loading model from {model_path}: {missing_keys}")
    if unexpected_keys:
        logger.warning(f"Unexpected keys when loading model from {model_path}: {unexpected_keys}")
    logger.info(f"Model loaded successfully from {model_path}")
    return model
# Static model: a ResNet50 built for 7 output classes and 3 input channels.
pth_model_static = load_model(
    ResNet50,
    STATIC_MODEL_PATH,
    num_classes=7,
    channels=3,
)

# Dynamic model: an LSTM built for 2048-wide inputs, 256 hidden units,
# 2 layers and 7 output classes.
pth_model_dynamic = load_model(
    LSTMPyTorch,
    DYNAMIC_MODEL_PATH,
    input_size=2048,
    hidden_size=256,
    num_layers=2,
    num_classes=7,
)

# Grad-CAM is attached to the static model, targeting the last element of
# its ``layer4``.
target_layers = [pth_model_static.layer4[-1]]
cam = GradCAM(
    model=pth_model_static,
    target_layers=target_layers,
)
# Per-channel normalisation statistics (these values match the widely used
# ImageNet mean/std).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline: resize to 224x224, convert to a tensor, normalise.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
pth_transform = transforms.Compose(_preprocessing_steps)
def pth_processing(img):
    """Turn one image into a single-item batch tensor on ``device``.

    Applies ``pth_transform`` to ``img``, inserts a leading dimension of
    size 1, and moves the result to the module-level ``device``.
    """
    transformed = pth_transform(img)
    batched = transformed.unsqueeze(0)
    return batched.to(device)
# Additional utility functions...
def _main():
    """Log that the module-level model initialization has finished."""
    logger.info("Model initialization complete.")


if __name__ == "__main__":
    _main()