vitorcalvi commited on
Commit
662fda2
Β·
1 Parent(s): fdc3f39
Files changed (1) hide show
  1. app/model.py +27 -1
app/model.py CHANGED
@@ -20,6 +20,15 @@ STATIC_MODEL_PATH = 'assets/models/FER_static_ResNet50_AffectNet.pt'
20
  DYNAMIC_MODEL_PATH = 'assets/models/FER_dynamic_LSTM.pt'
21
 
22
  def load_model(model_class, model_path, *args, **kwargs):
 
 
 
 
 
 
 
 
 
23
  model = model_class(*args, **kwargs)
24
  if os.path.exists(model_path):
25
  model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
@@ -32,9 +41,26 @@ def load_model(model_class, model_path, *args, **kwargs):
32
  return model
33
 
34
  def load_models():
 
 
 
 
 
 
35
  pth_model_static = load_model(ResNet50, STATIC_MODEL_PATH)
36
- pth_model_dynamic = load_model(LSTMPyTorch, DYNAMIC_MODEL_PATH)
37
 
 
 
 
 
 
 
 
 
 
 
38
  cam = GradCAM(model=pth_model_static, target_layers=[pth_model_static.layer4], use_cuda=device == 'cuda')
39
 
40
  return pth_model_static, pth_model_dynamic, cam
 
 
 
20
  DYNAMIC_MODEL_PATH = 'assets/models/FER_dynamic_LSTM.pt'
21
 
22
  def load_model(model_class, model_path, *args, **kwargs):
23
+ """
24
+ Load a model from a given path and instantiate it with the provided arguments.
25
+
26
+ :param model_class: Class of the model (ResNet50, LSTMPyTorch, etc.)
27
+ :param model_path: Path to the saved model's weights
28
+ :param args: Positional arguments for model initialization
29
+ :param kwargs: Keyword arguments for model initialization
30
+ :return: The loaded model
31
+ """
32
  model = model_class(*args, **kwargs)
33
  if os.path.exists(model_path):
34
  model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
 
41
  return model
42
 
43
  def load_models():
44
+ """
45
+ Load the ResNet50 static model and LSTMPyTorch dynamic model along with GradCAM.
46
+
47
+ :return: Tuple (static model, dynamic model, GradCAM instance)
48
+ """
49
+ # Load the static ResNet50 model
50
  pth_model_static = load_model(ResNet50, STATIC_MODEL_PATH)
 
51
 
52
+ # Define LSTMPyTorch parameters (set the correct values for your model)
53
+ input_size = 2048 # Example value: This should match the feature size from the ResNet50 output
54
+ hidden_size = 512 # Example value: Adjust based on your LSTM architecture
55
+ num_layers = 2 # Example value: Number of layers in the LSTM
56
+ num_classes = 7 # Example value: Number of emotion classes
57
+
58
+ # Load the dynamic LSTM model with the correct arguments
59
+ pth_model_dynamic = load_model(LSTMPyTorch, DYNAMIC_MODEL_PATH, input_size, hidden_size, num_layers, num_classes)
60
+
61
+ # Initialize GradCAM
62
  cam = GradCAM(model=pth_model_static, target_layers=[pth_model_static.layer4], use_cuda=device == 'cuda')
63
 
64
  return pth_model_static, pth_model_dynamic, cam
65
+
66
+ # Optionally, additional utility functions for model processing can be added here.