vitorcalvi committed
Commit fd37619 · 1 Parent(s): 793723d
Files changed (5)
  1. app.py +0 -11
  2. app/app_utils.py +0 -333
  3. app/face_utils.py +0 -68
  4. app/model.py +0 -78
  5. app/model_architectures.py +0 -46
app.py CHANGED
@@ -4,13 +4,6 @@ from tabs.FACS_analysis import create_facs_analysis_tab
  from ui_components import CUSTOM_CSS, HEADER_HTML, DISCLAIMER_HTML
  import spaces # Importing spaces to utilize Zero GPU

- # Initialize Zero GPU
- if torch.cuda.is_available():
-     zero = torch.Tensor([0]).cuda()
-     print(f"Initial device: {zero.device}")
- else:
-     zero = torch.Tensor([0])
-     print("CUDA is not available. Using CPU.")

  # Define the tab structure
  TAB_STRUCTURE = [
@@ -22,10 +15,6 @@ TAB_STRUCTURE = [
  # Decorate GPU-dependent function with Zero GPU
  @spaces.GPU(duration=120) # Allocates GPU for 120 seconds when needed
  def create_demo():
-     if torch.cuda.is_available():
-         print(f"Device inside create_demo: {zero.device}")
-     else:
-         print("CUDA is not available inside create_demo.")

      # Gradio blocks to create the interface
      with gr.Blocks(css=CUSTOM_CSS) as demo:
 
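Note: with Zero GPU, @spaces.GPU(duration=120) only allocates a GPU while the decorated function runs (as the inline comment states), so device selection can happen inside that function instead of at import time. A minimal sketch of this pattern, assuming the Hugging Face spaces package; run_on_gpu is a hypothetical example and not part of this commit:

import torch
import spaces  # Hugging Face Spaces Zero GPU helper

@spaces.GPU(duration=120)  # GPU is attached only while this call runs
def run_on_gpu():
    # Hypothetical helper: pick the device at call time rather than at module import
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.zeros(1, device=device)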
app/app_utils.py DELETED
@@ -1,333 +0,0 @@
- import torch
- import numpy as np
- import mediapipe as mp
- from PIL import Image
- import cv2
- from pytorch_grad_cam.utils.image import show_cam_on_image
- import matplotlib.pyplot as plt
-
- # Importing necessary components for the Gradio app
- from app.model import pth_model_static, pth_model_dynamic, cam, pth_processing
- from app.face_utils import get_box, display_info
- from app.config import DICT_EMO, config_data
- from app.plot import statistics_plot
-
- mp_face_mesh = mp.solutions.face_mesh
-
- def get_device():
-     if torch.backends.mps.is_available():
-         return torch.device("mps")
-     elif torch.cuda.is_available():
-         return torch.device("cuda")
-     else:
-         return torch.device("cpu")
-
- device = get_device()
- print(f"Using device: {device}")
-
- # Move models to the selected device
- pth_model_static = pth_model_static.to(device)
- pth_model_dynamic = pth_model_dynamic.to(device)
-
- def preprocess_image_and_predict(inp):
-     inp = np.array(inp)
-
-     if inp is None:
-         return None, None, None
-
-     try:
-         h, w = inp.shape[:2]
-     except Exception:
-         return None, None, None
-
-     with mp_face_mesh.FaceMesh(
-         max_num_faces=1,
-         refine_landmarks=False,
-         min_detection_confidence=0.5,
-         min_tracking_confidence=0.5,
-     ) as face_mesh:
-         results = face_mesh.process(inp)
-         if results.multi_face_landmarks:
-             for fl in results.multi_face_landmarks:
-                 startX, startY, endX, endY = get_box(fl, w, h)
-                 cur_face = inp[startY:endY, startX:endX]
-                 cur_face_n = pth_processing(Image.fromarray(cur_face)).to(device)
-                 with torch.no_grad():
-                     prediction = (
-                         torch.nn.functional.softmax(pth_model_static(cur_face_n), dim=1)
-                         .detach()
-                         .cpu()
-                         .numpy()[0]
-                     )
-                 confidences = {DICT_EMO[i]: float(prediction[i]) for i in range(7)}
-                 grayscale_cam = cam(input_tensor=cur_face_n)
-                 grayscale_cam = grayscale_cam[0, :]
-                 cur_face_hm = cv2.resize(cur_face,(224,224))
-                 cur_face_hm = np.float32(cur_face_hm) / 255
-                 heatmap = show_cam_on_image(cur_face_hm, grayscale_cam, use_rgb=True)
-
-     return cur_face, heatmap, confidences
-
- def preprocess_frame_and_predict_aus(frame):
-     if len(frame.shape) == 2:
-         frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
-     elif frame.shape[2] == 4:
-         frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
-
-     with mp_face_mesh.FaceMesh(
-         max_num_faces=1,
-         refine_landmarks=False,
-         min_detection_confidence=0.5,
-         min_tracking_confidence=0.5
-     ) as face_mesh:
-         results = face_mesh.process(frame)
-
-         if results.multi_face_landmarks:
-             h, w = frame.shape[:2]
-             for fl in results.multi_face_landmarks:
-                 startX, startY, endX, endY = get_box(fl, w, h)
-                 cur_face = frame[startY:endY, startX:endX]
-                 cur_face_n = pth_processing(Image.fromarray(cur_face)).to(device)
-
-                 with torch.no_grad():
-                     features = pth_model_static(cur_face_n)
-                     au_intensities = features_to_au_intensities(features)
-
-                 grayscale_cam = cam(input_tensor=cur_face_n)
-                 grayscale_cam = grayscale_cam[0, :]
-                 cur_face_hm = cv2.resize(cur_face, (224, 224))
-                 cur_face_hm = np.float32(cur_face_hm) / 255
-                 heatmap = show_cam_on_image(cur_face_hm, grayscale_cam, use_rgb=True)
-
-                 return cur_face, au_intensities, heatmap
-
-     return None, None, None
-
- def features_to_au_intensities(features):
-     features_np = features.detach().cpu().numpy()[0]
-     au_intensities = (features_np - features_np.min()) / (features_np.max() - features_np.min())
-     return au_intensities[:24]  # Assuming we want 24 AUs
-
- def preprocess_video_and_predict(video):
-     cap = cv2.VideoCapture(video)
-     w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-     h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-     fps = np.round(cap.get(cv2.CAP_PROP_FPS))
-
-     path_save_video_face = 'result_face.mp4'
-     vid_writer_face = cv2.VideoWriter(path_save_video_face, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
-
-     path_save_video_hm = 'result_hm.mp4'
-     vid_writer_hm = cv2.VideoWriter(path_save_video_hm, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
-
-     lstm_features = []
-     count_frame = 1
-     count_face = 0
-     probs = []
-     frames = []
-     au_intensities_list = []
-     last_output = None
-     last_heatmap = None
-     last_au_intensities = None
-     cur_face = None
-
-     with mp_face_mesh.FaceMesh(
-         max_num_faces=1,
-         refine_landmarks=False,
-         min_detection_confidence=0.5,
-         min_tracking_confidence=0.5) as face_mesh:
-
-         while cap.isOpened():
-             _, frame = cap.read()
-             if frame is None: break
-
-             frame_copy = frame.copy()
-             frame_copy.flags.writeable = False
-             frame_copy = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)
-             results = face_mesh.process(frame_copy)
-             frame_copy.flags.writeable = True
-
-             if results.multi_face_landmarks:
-                 for fl in results.multi_face_landmarks:
-                     startX, startY, endX, endY = get_box(fl, w, h)
-                     cur_face = frame_copy[startY:endY, startX: endX]
-
-                     if count_face%config_data.FRAME_DOWNSAMPLING == 0:
-                         cur_face_copy = pth_processing(Image.fromarray(cur_face)).to(device)
-                         with torch.no_grad():
-                             features = torch.nn.functional.relu(pth_model_static.extract_features(cur_face_copy)).detach().cpu().numpy()
-                             au_intensities = features_to_au_intensities(pth_model_static(cur_face_copy))
-
-                         grayscale_cam = cam(input_tensor=cur_face_copy)
-                         grayscale_cam = grayscale_cam[0, :]
-                         cur_face_hm = cv2.resize(cur_face,(224,224), interpolation = cv2.INTER_AREA)
-                         cur_face_hm = np.float32(cur_face_hm) / 255
-                         heatmap = show_cam_on_image(cur_face_hm, grayscale_cam, use_rgb=False)
-                         last_heatmap = heatmap
-                         last_au_intensities = au_intensities
-
-                         if len(lstm_features) == 0:
-                             lstm_features = [features]*10
-                         else:
-                             lstm_features = lstm_features[1:] + [features]
-
-                         lstm_f = torch.from_numpy(np.vstack(lstm_features)).to(device)
-                         lstm_f = torch.unsqueeze(lstm_f, 0)
-                         with torch.no_grad():
-                             output = pth_model_dynamic(lstm_f).detach().cpu().numpy()
-                         last_output = output
-
-                         if count_face == 0:
-                             count_face += 1
-
-                     else:
-                         if last_output is not None:
-                             output = last_output
-                             heatmap = last_heatmap
-                             au_intensities = last_au_intensities
-
-                         elif last_output is None:
-                             output = np.empty((1, 7))
-                             output[:] = np.nan
-                             au_intensities = np.empty(24)
-                             au_intensities[:] = np.nan
-
-                     probs.append(output[0])
-                     frames.append(count_frame)
-                     au_intensities_list.append(au_intensities)
-             else:
-                 if last_output is not None:
-                     lstm_features = []
-                     empty = np.empty((7))
-                     empty[:] = np.nan
-                     probs.append(empty)
-                     frames.append(count_frame)
-                     au_intensities_list.append(np.full(24, np.nan))
-
-             if cur_face is not None:
-                 heatmap_f = display_info(heatmap, 'Frame: {}'.format(count_frame), box_scale=.3)
-
-                 cur_face = cv2.cvtColor(cur_face, cv2.COLOR_RGB2BGR)
-                 cur_face = cv2.resize(cur_face, (224,224), interpolation = cv2.INTER_AREA)
-                 cur_face = display_info(cur_face, 'Frame: {}'.format(count_frame), box_scale=.3)
-                 vid_writer_face.write(cur_face)
-                 vid_writer_hm.write(heatmap_f)
-
-             count_frame += 1
-             if count_face != 0:
-                 count_face += 1
-
-     vid_writer_face.release()
-     vid_writer_hm.release()
-
-     stat = statistics_plot(frames, probs)
-     au_stat = au_statistics_plot(frames, au_intensities_list)
-
-     if not stat or not au_stat:
-         return None, None, None, None, None
-
-     return video, path_save_video_face, path_save_video_hm, stat, au_stat
-
- # The rest of the functions remain the same
- # ...
-
- def au_statistics_plot(frames, au_intensities_list):
-     fig, ax = plt.subplots(figsize=(12, 6))
-     au_intensities_array = np.array(au_intensities_list)
-
-     for i in range(au_intensities_array.shape[1]):
-         ax.plot(frames, au_intensities_array[:, i], label=f'AU{i+1}')
-
-     ax.set_xlabel('Frame')
-     ax.set_ylabel('AU Intensity')
-     ax.set_title('Action Unit Intensities Over Time')
-     ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
-     plt.tight_layout()
-     return fig
-
- def preprocess_video_and_predict_sleep_quality(video):
-     cap = cv2.VideoCapture(video)
-     w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-     h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-     fps = np.round(cap.get(cv2.CAP_PROP_FPS))
-
-     path_save_video_original = 'result_original.mp4'
-     path_save_video_face = 'result_face.mp4'
-     path_save_video_sleep = 'result_sleep.mp4'
-
-     vid_writer_original = cv2.VideoWriter(path_save_video_original, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
-     vid_writer_face = cv2.VideoWriter(path_save_video_face, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
-     vid_writer_sleep = cv2.VideoWriter(path_save_video_sleep, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
-
-     frames = []
-     sleep_quality_scores = []
-     eye_bags_images = []
-
-     with mp_face_mesh.FaceMesh(
-         max_num_faces=1,
-         refine_landmarks=False,
-         min_detection_confidence=0.5,
-         min_tracking_confidence=0.5) as face_mesh:
-
-         while cap.isOpened():
-             ret, frame = cap.read()
-             if not ret:
-                 break
-
-             frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-             results = face_mesh.process(frame_rgb)
-
-             if results.multi_face_landmarks:
-                 for fl in results.multi_face_landmarks:
-                     startX, startY, endX, endY = get_box(fl, w, h)
-                     cur_face = frame_rgb[startY:endY, startX:endX]
-
-                     sleep_quality_score, eye_bags_image = analyze_sleep_quality(cur_face)
-                     sleep_quality_scores.append(sleep_quality_score)
-                     eye_bags_images.append(cv2.resize(eye_bags_image, (224, 224)))
-
-                     sleep_quality_viz = create_sleep_quality_visualization(cur_face, sleep_quality_score)
-
-                     cur_face = cv2.resize(cur_face, (224, 224))
-
-                     vid_writer_face.write(cv2.cvtColor(cur_face, cv2.COLOR_RGB2BGR))
-                     vid_writer_sleep.write(sleep_quality_viz)
-
-             vid_writer_original.write(frame)
-             frames.append(len(frames) + 1)
-
-     cap.release()
-     vid_writer_original.release()
-     vid_writer_face.release()
-     vid_writer_sleep.release()
-
-     sleep_stat = sleep_quality_statistics_plot(frames, sleep_quality_scores)
-
-     if eye_bags_images:
-         average_eye_bags_image = np.mean(np.array(eye_bags_images), axis=0).astype(np.uint8)
-     else:
-         average_eye_bags_image = np.zeros((224, 224, 3), dtype=np.uint8)
-
-     return (path_save_video_original, path_save_video_face, path_save_video_sleep,
-             average_eye_bags_image, sleep_stat)
-
- def analyze_sleep_quality(face_image):
-     # Placeholder function - implement your sleep quality analysis here
-     sleep_quality_score = np.random.random()
-     eye_bags_image = cv2.resize(face_image, (224, 224))
-     return sleep_quality_score, eye_bags_image
-
- def create_sleep_quality_visualization(face_image, sleep_quality_score):
-     viz = face_image.copy()
-     cv2.putText(viz, f"Sleep Quality: {sleep_quality_score:.2f}", (10, 30),
-                 cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
-     return cv2.cvtColor(viz, cv2.COLOR_RGB2BGR)
-
- def sleep_quality_statistics_plot(frames, sleep_quality_scores):
-     # Placeholder function - implement your statistics plotting here
-     fig, ax = plt.subplots()
-     ax.plot(frames, sleep_quality_scores)
-     ax.set_xlabel('Frame')
-     ax.set_ylabel('Sleep Quality Score')
-     ax.set_title('Sleep Quality Over Time')
-     return fig
app/face_utils.py DELETED
@@ -1,68 +0,0 @@
- """
- File: face_utils.py
- Author: Elena Ryumina and Dmitry Ryumin
- Description: This module contains utility functions related to facial landmarks and image processing.
- License: MIT License
- """
-
- import numpy as np
- import math
- import cv2
-
-
- def norm_coordinates(normalized_x, normalized_y, image_width, image_height):
-     x_px = min(math.floor(normalized_x * image_width), image_width - 1)
-     y_px = min(math.floor(normalized_y * image_height), image_height - 1)
-     return x_px, y_px
-
-
- def get_box(fl, w, h):
-     idx_to_coors = {}
-     for idx, landmark in enumerate(fl.landmark):
-         landmark_px = norm_coordinates(landmark.x, landmark.y, w, h)
-         if landmark_px:
-             idx_to_coors[idx] = landmark_px
-
-     x_min = np.min(np.asarray(list(idx_to_coors.values()))[:, 0])
-     y_min = np.min(np.asarray(list(idx_to_coors.values()))[:, 1])
-     endX = np.max(np.asarray(list(idx_to_coors.values()))[:, 0])
-     endY = np.max(np.asarray(list(idx_to_coors.values()))[:, 1])
-
-     (startX, startY) = (max(0, x_min), max(0, y_min))
-     (endX, endY) = (min(w - 1, endX), min(h - 1, endY))
-
-     return startX, startY, endX, endY
-
- def display_info(img, text, margin=1.0, box_scale=1.0):
-     img_copy = img.copy()
-     img_h, img_w, _ = img_copy.shape
-     line_width = int(min(img_h, img_w) * 0.001)
-     thickness = max(int(line_width / 3), 1)
-
-     font_face = cv2.FONT_HERSHEY_SIMPLEX
-     font_color = (0, 0, 0)
-     font_scale = thickness / 1.5
-
-     t_w, t_h = cv2.getTextSize(text, font_face, font_scale, None)[0]
-
-     margin_n = int(t_h * margin)
-     sub_img = img_copy[0 + margin_n: 0 + margin_n + t_h + int(2 * t_h * box_scale),
-                        img_w - t_w - margin_n - int(2 * t_h * box_scale): img_w - margin_n]
-
-     white_rect = np.ones(sub_img.shape, dtype=np.uint8) * 255
-
-     img_copy[0 + margin_n: 0 + margin_n + t_h + int(2 * t_h * box_scale),
-              img_w - t_w - margin_n - int(2 * t_h * box_scale):img_w - margin_n] = cv2.addWeighted(sub_img, 0.5, white_rect, .5, 1.0)
-
-     cv2.putText(img=img_copy,
-                 text=text,
-                 org=(img_w - t_w - margin_n - int(2 * t_h * box_scale) // 2,
-                      0 + margin_n + t_h + int(2 * t_h * box_scale) // 2),
-                 fontFace=font_face,
-                 fontScale=font_scale,
-                 color=font_color,
-                 thickness=thickness,
-                 lineType=cv2.LINE_AA,
-                 bottomLeftOrigin=False)
-
-     return img_copy
app/model.py DELETED
@@ -1,78 +0,0 @@
- import os
- import torch
- import torch.nn as nn
- import torchvision.transforms as transforms
- from pytorch_grad_cam import GradCAM
- from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
- import logging
- from app.model_architectures import ResNet50, LSTMPyTorch
-
- # Set up logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
- # Determine the device
- device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
- logger.info(f"Using device: {device}")
-
- # Define paths
- STATIC_MODEL_PATH = 'assets/models/FER_static_ResNet50_AffectNet.pt'
- DYNAMIC_MODEL_PATH = 'assets/models/FER_dynamic_LSTM.pt'
-
- def load_model(model_class, model_path, *args, **kwargs):
-     model = model_class(*args, **kwargs).to(device)
-     if os.path.exists(model_path):
-         try:
-             model.load_state_dict(torch.load(model_path, map_location=device))
-             model.eval()
-             logger.info(f"Model loaded successfully from {model_path}")
-         except Exception as e:
-             logger.error(f"Error loading model from {model_path}: {str(e)}")
-             logger.info("Initializing with random weights.")
-     else:
-         logger.warning(f"Model file not found at {model_path}. Initializing with random weights.")
-     return model
-
- # Load the static model
- pth_model_static = load_model(ResNet50, STATIC_MODEL_PATH, num_classes=7, channels=3)
-
- # Load the dynamic model
- pth_model_dynamic = load_model(LSTMPyTorch, DYNAMIC_MODEL_PATH, input_size=2048, hidden_size=256, num_layers=2, num_classes=7)
-
- # Set up GradCAM
- target_layers = [pth_model_static.resnet.layer4[-1]]
- cam = GradCAM(model=pth_model_static, target_layers=target_layers)
-
- # Define image preprocessing
- pth_transform = transforms.Compose([
-     transforms.Resize((224, 224)),
-     transforms.ToTensor(),
-     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
- ])
-
- def pth_processing(img):
-     img = pth_transform(img).unsqueeze(0).to(device)
-     return img
-
- def predict_emotion(img):
-     with torch.no_grad():
-         output = pth_model_static(pth_processing(img))
-         _, predicted = torch.max(output, 1)
-     return predicted.item()
-
- def get_emotion_probabilities(img):
-     with torch.no_grad():
-         output = nn.functional.softmax(pth_model_static(pth_processing(img)), dim=1)
-     return output.squeeze().cpu().numpy()
-
- def generate_cam(img):
-     input_tensor = pth_processing(img)
-     targets = [ClassifierOutputTarget(predict_emotion(img))]
-     grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
-     return grayscale_cam[0, :]
-
- # Add any other necessary functions or variables here
-
- if __name__ == "__main__":
-     logger.info("Model initialization complete.")
-     # You can add some test code here to verify everything is working correctly
app/model_architectures.py DELETED
@@ -1,46 +0,0 @@
- import torch
- import torch.nn as nn
- import torchvision.models as models
-
- class ResNet50(nn.Module):
-     def __init__(self, num_classes=7, channels=3):
-         super(ResNet50, self).__init__()
-         self.resnet = models.resnet50(pretrained=True)
-         # Modify the first convolutional layer if channels != 3
-         if channels != 3:
-             self.resnet.conv1 = nn.Conv2d(channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
-         num_features = self.resnet.fc.in_features
-         self.resnet.fc = nn.Linear(num_features, num_classes)
-
-     def forward(self, x):
-         return self.resnet(x)
-
-     def extract_features(self, x):
-         x = self.resnet.conv1(x)
-         x = self.resnet.bn1(x)
-         x = self.resnet.relu(x)
-         x = self.resnet.maxpool(x)
-
-         x = self.resnet.layer1(x)
-         x = self.resnet.layer2(x)
-         x = self.resnet.layer3(x)
-         x = self.resnet.layer4(x)
-
-         x = self.resnet.avgpool(x)
-         x = torch.flatten(x, 1)
-         return x
-
- class LSTMPyTorch(nn.Module):
-     def __init__(self, input_size, hidden_size, num_layers, num_classes):
-         super(LSTMPyTorch, self).__init__()
-         self.hidden_size = hidden_size
-         self.num_layers = num_layers
-         self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
-         self.fc = nn.Linear(hidden_size, num_classes)
-
-     def forward(self, x):
-         h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
-         c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
-         out, _ = self.lstm(x, (h0, c0))
-         out = self.fc(out[:, -1, :])
-         return out