vitorcalvi committed
Commit 27ef047 · 1 Parent(s): 5164deb
Files changed (4)
  1. .DS_Store +0 -0
  2. app.py +3 -3
  3. app/model.py +14 -52
  4. tabs/FACS_analysis.py +16 -46
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
app.py CHANGED
@@ -14,7 +14,7 @@ TAB_STRUCTURE = [
 def create_demo():
     # Import model-related functions here to ensure spaces is imported first
     from app.model import load_models
-
+
     # Load models outside of the Gradio blocks
     pth_model_static, pth_model_dynamic, cam = load_models()

@@ -28,11 +28,11 @@ def create_demo():
                 with gr.Tab(sub_tab):
                     create_fn(pth_model_static, pth_model_dynamic, cam)
         gr.HTML(DISCLAIMER_HTML)
-
+
     return demo

 # Create the demo instance
 demo = create_demo()

 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
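Both app.py hunks are whitespace-only: each removed line is re-added unchanged, so behavior is untouched. For orientation, `create_demo()` loads the models once, walks `TAB_STRUCTURE`, and hands every sub-tab factory the same model/CAM triple. A minimal sketch of that pattern, with a hypothetical two-level `TAB_STRUCTURE`, since the commit shows only fragments of app.py:

import gradio as gr
from tabs.FACS_analysis import create_facs_analysis_tab  # one of the real tab factories

# Hypothetical structure; the real TAB_STRUCTURE is defined near the top of app.py.
TAB_STRUCTURE = [
    ("FACS", [("FACS Analysis", create_facs_analysis_tab)]),
]

def create_demo_sketch(load_models):
    # Load models once, outside the Blocks context, exactly as the diff does.
    pth_model_static, pth_model_dynamic, cam = load_models()
    with gr.Blocks() as demo:
        for tab_name, sub_tabs in TAB_STRUCTURE:
            with gr.Tab(tab_name):
                for sub_tab, create_fn in sub_tabs:
                    with gr.Tab(sub_tab):
                        # Every factory receives the same shared models and CAM object.
                        create_fn(pth_model_static, pth_model_dynamic, cam)
    return demo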
app/model.py CHANGED
@@ -20,59 +20,21 @@ STATIC_MODEL_PATH = 'assets/models/FER_static_ResNet50_AffectNet.pt'
 DYNAMIC_MODEL_PATH = 'assets/models/FER_dynamic_LSTM.pt'

 def load_model(model_class, model_path, *args, **kwargs):
-    model = model_class(*args, **kwargs).to(device)
+    model = model_class(*args, **kwargs)
     if os.path.exists(model_path):
-        try:
-            model.load_state_dict(torch.load(model_path, map_location=device))
-            model.eval()
-            logger.info(f"Model loaded successfully from {model_path}")
-        except Exception as e:
-            logger.error(f"Error loading model from {model_path}: {str(e)}")
-            logger.info("Initializing with random weights.")
+        model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
+        model.to(device)
+        model.eval()
+        logger.info(f"Loaded model from {model_path}")
     else:
-        logger.warning(f"Model file not found at {model_path}. Initializing with random weights.")
+        logger.error(f"Model file not found: {model_path}")
+        model = model.to(device)
     return model

-# Load the static model
-pth_model_static = load_model(ResNet50, STATIC_MODEL_PATH, num_classes=7, channels=3)
-
-# Load the dynamic model
-pth_model_dynamic = load_model(LSTMPyTorch, DYNAMIC_MODEL_PATH, input_size=2048, hidden_size=256, num_layers=2, num_classes=7)
-
-# Set up GradCAM
-target_layers = [pth_model_static.resnet.layer4[-1]]
-cam = GradCAM(model=pth_model_static, target_layers=target_layers)
-
-# Define image preprocessing
-pth_transform = transforms.Compose([
-    transforms.Resize((224, 224)),
-    transforms.ToTensor(),
-    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-])
-
-def pth_processing(img):
-    img = pth_transform(img).unsqueeze(0).to(device)
-    return img
-
-def predict_emotion(img):
-    with torch.no_grad():
-        output = pth_model_static(pth_processing(img))
-        _, predicted = torch.max(output, 1)
-    return predicted.item()
-
-def get_emotion_probabilities(img):
-    with torch.no_grad():
-        output = nn.functional.softmax(pth_model_static(pth_processing(img)), dim=1)
-    return output.squeeze().cpu().numpy()
-
-def generate_cam(img):
-    input_tensor = pth_processing(img)
-    targets = [ClassifierOutputTarget(predict_emotion(img))]
-    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
-    return grayscale_cam[0, :]
-
-# Add any other necessary functions or variables here
-
-if __name__ == "__main__":
-    logger.info("Model initialization complete.")
-    # You can add some test code here to verify everything is working correctly
+def load_models():
+    pth_model_static = load_model(ResNet50, STATIC_MODEL_PATH)
+    pth_model_dynamic = load_model(LSTMPyTorch, DYNAMIC_MODEL_PATH)
+
+    cam = GradCAM(model=pth_model_static, target_layers=[pth_model_static.layer4], use_cuda=device == 'cuda')
+
+    return pth_model_static, pth_model_dynamic, cam
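The refactor collapses the module-level setup (model loading, GradCAM wiring, preprocessing, and the prediction helpers) into a single `load_models()` entry point. Three details are worth verifying against the rest of the repo: `load_model` is now called without the old `num_classes`/`input_size` keyword arguments, so `ResNet50` and `LSTMPyTorch` must supply matching defaults; `strict=False` silently tolerates missing or unexpected checkpoint keys; and the GradCAM call assumes the model exposes `layer4` directly (the old code used `pth_model_static.resnet.layer4[-1]`) and passes `use_cuda`, a keyword that newer pytorch-grad-cam releases removed, so the package version may need pinning. A minimal smoke test of the new interface, assuming the defaults line up:

import torch
from app.model import load_models

pth_model_static, pth_model_dynamic, cam = load_models()

# Push one ImageNet-sized RGB frame through the static model as a sanity check.
dummy = torch.randn(1, 3, 224, 224)
device = next(pth_model_static.parameters()).device
with torch.no_grad():
    logits = pth_model_static(dummy.to(device))
print(logits.shape)  # expected: torch.Size([1, 7]) for the seven emotion classes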
tabs/FACS_analysis.py CHANGED
@@ -1,55 +1,25 @@
+from gradio import Interface
+from app.app_utils import preprocess_frame_and_predict_aus
 import gradio as gr
-import cv2
-import numpy as np
 import matplotlib.pyplot as plt
-from app.app_utils import preprocess_frame_and_predict_aus
-
-# Define the AUs associated with stress, anxiety, and depression
-STRESS_AUS = [4, 7, 17, 23, 24]
-ANXIETY_AUS = [1, 2, 4, 5, 20]
-DEPRESSION_AUS = [1, 4, 15, 17]

+# Define stress, anxiety, and depression AU mappings
+STRESS_AUS = [1, 2, 4]
+ANXIETY_AUS = [5, 9, 14]
+DEPRESSION_AUS = [15, 17, 20]
 AU_DESCRIPTIONS = {
-    1: "Inner Brow Raiser",
-    2: "Outer Brow Raiser",
-    4: "Brow Lowerer",
-    5: "Upper Lid Raiser",
-    7: "Lid Tightener",
-    15: "Lip Corner Depressor",
-    17: "Chin Raiser",
-    20: "Lip Stretcher",
-    23: "Lip Tightener",
-    24: "Lip Pressor"
+    1: "Inner Brow Raiser", 2: "Outer Brow Raiser", 4: "Brow Lowerer",
+    5: "Upper Lid Raiser", 9: "Nose Wrinkler", 14: "Dimpler",
+    15: "Lip Corner Depressor", 17: "Chin Raiser", 20: "Lip Stretcher"
 }

-def normalize_score(score):
-    return max(0, min(1, (score + 1.5) / 3))  # Adjust the range as needed
+def process_video_for_facs(video):
+    frames, avg_au_intensities = preprocess_frame_and_predict_aus(video)
+
+    # Calculate emotional state scores
+    def normalize_score(score):
+        return max(0, min(1, score))

-def process_video_for_facs(video_path):
-    cap = cv2.VideoCapture(video_path)
-    frames = []
-    au_intensities_list = []
-
-    while True:
-        ret, frame = cap.read()
-        if not ret:
-            break
-
-        processed_frame, au_intensities, _ = preprocess_frame_and_predict_aus(frame)
-
-        if processed_frame is not None and au_intensities is not None:
-            frames.append(processed_frame)
-            au_intensities_list.append(au_intensities)
-
-    cap.release()
-
-    if not frames:
-        return None, None
-
-    # Calculate average AU intensities
-    avg_au_intensities = np.mean(au_intensities_list, axis=0)
-
-    # Calculate and normalize emotional state scores
     stress_score = normalize_score(np.mean([avg_au_intensities[au-1] for au in STRESS_AUS if au <= len(avg_au_intensities)]))
     anxiety_score = normalize_score(np.mean([avg_au_intensities[au-1] for au in ANXIETY_AUS if au <= len(avg_au_intensities)]))
     depression_score = normalize_score(np.mean([avg_au_intensities[au-1] for au in DEPRESSION_AUS if au <= len(avg_au_intensities)]))
@@ -82,7 +52,7 @@ def process_video_for_facs(video_path):

     return frames[-1], fig  # Return the last processed frame and the plot

-def create_facs_analysis_tab():
+def create_facs_analysis_tab(pth_model_static, pth_model_dynamic, cam):
     with gr.Row():
         with gr.Column(scale=1):
             input_video = gr.Video()
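Two caveats in the new version of this file: `from gradio import Interface` is added but never used, and the commit removes `import numpy as np` while the surviving score lines still call `np.mean`, so `process_video_for_facs` will fail with a NameError at call time unless numpy is imported somewhere outside the hunks shown. The scoring itself maps AU number N to index N-1 in `avg_au_intensities`, which assumes the predictor returns intensities ordered by AU index. A small worked example of that mapping, with made-up intensities:

import numpy as np

STRESS_AUS = [1, 2, 4]  # AU numbers from the new mapping above

# Made-up intensities for AU1..AU5; real values come from preprocess_frame_and_predict_aus.
avg_au_intensities = np.array([0.8, 0.2, 0.0, 0.5, 0.1])

def normalize_score(score):
    return max(0, min(1, score))  # clamp to [0, 1], as in the new inner helper

# AU number N reads index N-1; AU numbers beyond the vector length are skipped.
stress_score = normalize_score(
    np.mean([avg_au_intensities[au - 1] for au in STRESS_AUS
             if au <= len(avg_au_intensities)])
)
print(stress_score)  # (0.8 + 0.2 + 0.5) / 3 = 0.5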