vitorcalvi committed
Commit fdc3f39 · 1 Parent(s): 27ef047
Files changed (1)
  1. app/app_utils.py +13 -101
app/app_utils.py CHANGED
@@ -6,25 +6,22 @@ import cv2
 from pytorch_grad_cam.utils.image import show_cam_on_image
 import matplotlib.pyplot as plt
 
-# Importing necessary components for the Gradio app
-from app.model import pth_model_static, pth_model_dynamic, cam, pth_processing
+# Importing necessary components
+from app.model import load_models
 from app.face_utils import get_box, display_info
 from app.config import DICT_EMO, config_data
 from app.plot import statistics_plot
 
+# Initialize the face mesh detector
 mp_face_mesh = mp.solutions.face_mesh
 
-def get_device():
-    if torch.backends.mps.is_available():
-        return torch.device("mps")
-    elif torch.cuda.is_available():
-        return torch.device("cuda")
-    else:
-        return torch.device("cpu")
-
-device = get_device()
+# Set the device directly
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
 
+# Load models and GradCAM
+pth_model_static, pth_model_dynamic, cam = load_models()
+
 # Move models to the selected device
 pth_model_static = pth_model_static.to(device)
 pth_model_dynamic = pth_model_dynamic.to(device)
@@ -152,7 +149,7 @@ def preprocess_video_and_predict(video):
                     startX, startY, endX, endY = get_box(fl, w, h)
                     cur_face = frame_copy[startY:endY, startX: endX]
 
-                    if count_face%config_data.FRAME_DOWNSAMPLING == 0:
+                    if count_face % config_data.FRAME_DOWNSAMPLING == 0:
                         cur_face_copy = pth_processing(Image.fromarray(cur_face)).to(device)
                         with torch.no_grad():
                             features = torch.nn.functional.relu(pth_model_static.extract_features(cur_face_copy)).detach().cpu().numpy()
@@ -228,8 +225,10 @@ def preprocess_video_and_predict(video):
 
     return video, path_save_video_face, path_save_video_hm, stat, au_stat
 
-# The rest of the functions remain the same
-# ...
+def features_to_au_intensities(features):
+    features_np = features.detach().cpu().numpy()[0]
+    au_intensities = (features_np - features_np.min()) / (features_np.max() - features_np.min())
+    return au_intensities[:24]  # Assuming we want 24 AUs
 
 def au_statistics_plot(frames, au_intensities_list):
     fig, ax = plt.subplots(figsize=(12, 6))
@@ -244,90 +243,3 @@ def au_statistics_plot(frames, au_intensities_list):
     ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
     plt.tight_layout()
     return fig
-
-def preprocess_video_and_predict_sleep_quality(video):
-    cap = cv2.VideoCapture(video)
-    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-    fps = np.round(cap.get(cv2.CAP_PROP_FPS))
-
-    path_save_video_original = 'result_original.mp4'
-    path_save_video_face = 'result_face.mp4'
-    path_save_video_sleep = 'result_sleep.mp4'
-
-    vid_writer_original = cv2.VideoWriter(path_save_video_original, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
-    vid_writer_face = cv2.VideoWriter(path_save_video_face, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
-    vid_writer_sleep = cv2.VideoWriter(path_save_video_sleep, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
-
-    frames = []
-    sleep_quality_scores = []
-    eye_bags_images = []
-
-    with mp_face_mesh.FaceMesh(
-            max_num_faces=1,
-            refine_landmarks=False,
-            min_detection_confidence=0.5,
-            min_tracking_confidence=0.5) as face_mesh:
-
-        while cap.isOpened():
-            ret, frame = cap.read()
-            if not ret:
-                break
-
-            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-            results = face_mesh.process(frame_rgb)
-
-            if results.multi_face_landmarks:
-                for fl in results.multi_face_landmarks:
-                    startX, startY, endX, endY = get_box(fl, w, h)
-                    cur_face = frame_rgb[startY:endY, startX:endX]
-
-                    sleep_quality_score, eye_bags_image = analyze_sleep_quality(cur_face)
-                    sleep_quality_scores.append(sleep_quality_score)
-                    eye_bags_images.append(cv2.resize(eye_bags_image, (224, 224)))
-
-                    sleep_quality_viz = create_sleep_quality_visualization(cur_face, sleep_quality_score)
-
-                    cur_face = cv2.resize(cur_face, (224, 224))
-
-                    vid_writer_face.write(cv2.cvtColor(cur_face, cv2.COLOR_RGB2BGR))
-                    vid_writer_sleep.write(sleep_quality_viz)
-
-            vid_writer_original.write(frame)
-            frames.append(len(frames) + 1)
-
-    cap.release()
-    vid_writer_original.release()
-    vid_writer_face.release()
-    vid_writer_sleep.release()
-
-    sleep_stat = sleep_quality_statistics_plot(frames, sleep_quality_scores)
-
-    if eye_bags_images:
-        average_eye_bags_image = np.mean(np.array(eye_bags_images), axis=0).astype(np.uint8)
-    else:
-        average_eye_bags_image = np.zeros((224, 224, 3), dtype=np.uint8)
-
-    return (path_save_video_original, path_save_video_face, path_save_video_sleep,
-            average_eye_bags_image, sleep_stat)
-
-def analyze_sleep_quality(face_image):
-    # Placeholder function - implement your sleep quality analysis here
-    sleep_quality_score = np.random.random()
-    eye_bags_image = cv2.resize(face_image, (224, 224))
-    return sleep_quality_score, eye_bags_image
-
-def create_sleep_quality_visualization(face_image, sleep_quality_score):
-    viz = face_image.copy()
-    cv2.putText(viz, f"Sleep Quality: {sleep_quality_score:.2f}", (10, 30),
-                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
-    return cv2.cvtColor(viz, cv2.COLOR_RGB2BGR)
-
-def sleep_quality_statistics_plot(frames, sleep_quality_scores):
-    # Placeholder function - implement your statistics plotting here
-    fig, ax = plt.subplots()
-    ax.plot(frames, sleep_quality_scores)
-    ax.set_xlabel('Frame')
-    ax.set_ylabel('Sleep Quality Score')
-    ax.set_title('Sleep Quality Over Time')
-    return fig
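
For reference, a minimal usage sketch of the module as it stands after this commit. The entry point and its five return values are taken from the diff above; the file name "input.mp4" is a hypothetical placeholder, not part of the repository.

    # Minimal usage sketch (assumption: "input.mp4" is a placeholder local video file).
    from app.app_utils import preprocess_video_and_predict

    # Runs face detection, emotion prediction, and AU extraction on the clip,
    # returning the source video path, the rendered face and heatmap videos,
    # and the two statistics figures.
    video, face_video, heatmap_video, stat_fig, au_stat_fig = preprocess_video_and_predict("input.mp4")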