import cv2
import numpy as np
import torch
from PIL import Image
import mediapipe as mp
from app.model import pth_model_static, pth_model_dynamic, cam, pth_processing
from app.face_utils import get_box, display_info
from app.config import config_data
from app.__plot import statistics_plot
from .au_processing import features_to_au_intensities, au_statistics_plot
from pytorch_grad_cam.utils.image import show_cam_on_image

mp_face_mesh = mp.solutions.face_mesh


def preprocess_video_and_predict(video_path):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error opening video file: {video_path}")
        return None, None, None, None, None

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0 or fps != fps:  # fps != fps is True only for NaN
        fps = 30  # Fall back to a default frame rate

    # Paths to save processed videos
    path_save_video_face = 'result_face.mp4'
    path_save_video_hm = 'result_hm.mp4'

    # Video writers for the cropped-face stream and the Grad-CAM heatmap stream
    vid_writer_face = cv2.VideoWriter(
        path_save_video_face, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
    vid_writer_hm = cv2.VideoWriter(
        path_save_video_hm, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))

    lstm_features = []       # rolling window of static features fed to the dynamic model
    count_frame = 1
    count_face = 0
    probs = []
    frames = []
    au_intensities_list = []
    last_output = None       # cached predictions, reused on downsampled frames
    last_heatmap = None
    last_au_intensities = None
    cur_face = None

    with mp_face_mesh.FaceMesh(
            max_num_faces=1,
            refine_landmarks=False,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5) as face_mesh:

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = face_mesh.process(frame_rgb)

            if results.multi_face_landmarks:
                for face_landmarks in results.multi_face_landmarks:
                    startX, startY, endX, endY = get_box(face_landmarks, width, height)
                    cur_face = frame_rgb[startY:endY, startX:endX]

                    # Run the models only on every FRAME_DOWNSAMPLING-th face frame;
                    # intermediate frames reuse the cached predictions in the else branch.
                    if count_face % config_data.FRAME_DOWNSAMPLING == 0:
                        cur_face_pil = Image.fromarray(cur_face)
                        cur_face_processed = pth_processing(cur_face_pil)
                        with torch.no_grad():
                            features = torch.nn.functional.relu(
                                pth_model_static.extract_features(cur_face_processed)
                            ).cpu().numpy()
                            au_intensities = features_to_au_intensities(
                                pth_model_static(cur_face_processed)
                            )

                        # Generate Grad-CAM heatmap for the current face crop
                        grayscale_cam = cam(input_tensor=cur_face_processed)[0, :]
                        cur_face_resized = cv2.resize(
                            cur_face, (224, 224), interpolation=cv2.INTER_AREA)
                        cur_face_normalized = np.float32(cur_face_resized) / 255
                        heatmap = show_cam_on_image(
                            cur_face_normalized, grayscale_cam, use_rgb=False)
                        last_heatmap = heatmap
                        last_au_intensities = au_intensities

                        # Seed the window with 10 copies of the first feature vector,
                        # then slide it by one on each subsequent update
                        if not lstm_features:
                            lstm_features = [features] * 10
                        else:
                            lstm_features = lstm_features[1:] + [features]

                        lstm_input = torch.from_numpy(np.vstack(lstm_features)).unsqueeze(0)
                        with torch.no_grad():
                            output = pth_model_dynamic(lstm_input).cpu().numpy()
                        last_output = output

                        if count_face == 0:
                            count_face += 1
                    else:
                        # Downsampled frame: reuse the cached predictions,
                        # or record NaNs if nothing has been computed yet
                        if last_output is not None:
                            output = last_output
                            heatmap = last_heatmap
                            au_intensities = last_au_intensities
                        else:
                            output = np.full((1, 7), np.nan)
                            au_intensities = np.full(24, np.nan)

                    probs.append(output[0])
                    frames.append(count_frame)
                    au_intensities_list.append(au_intensities)
            else:
                # No face detected: reset the feature window and record NaNs
                if last_output is not None:
                    lstm_features = []
                probs.append(np.full(7, np.nan))
                frames.append(count_frame)
                au_intensities_list.append(np.full(24, np.nan))

            # Write the most recent face crop and heatmap, annotated with the frame index
            if cur_face is not None:
                heatmap_frame = display_info(
                    heatmap, f'Frame: {count_frame}', box_scale=0.3)
                cur_face_bgr = cv2.cvtColor(cur_face, cv2.COLOR_RGB2BGR)
                cur_face_resized = cv2.resize(
                    cur_face_bgr, (224, 224), interpolation=cv2.INTER_AREA)
                cur_face_annotated = display_info(
                    cur_face_resized, f'Frame: {count_frame}', box_scale=0.3)
                vid_writer_face.write(cur_face_annotated)
                vid_writer_hm.write(heatmap_frame)

            count_frame += 1
            if count_face != 0:
                count_face += 1

    cap.release()
    vid_writer_face.release()
    vid_writer_hm.release()

    stat = statistics_plot(frames, probs)
    au_stat = au_statistics_plot(frames, au_intensities_list)

    if not stat or not au_stat:
        return None, None, None, None, None

    return video_path, path_save_video_face, path_save_video_hm, stat, au_stat
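

# Minimal usage sketch (illustrative, not part of the original module): run the
# pipeline on a clip and report where the annotated videos were written.
# The file name 'sample.mp4' is a placeholder; any readable video path works.
if __name__ == '__main__':
    result = preprocess_video_and_predict('sample.mp4')
    if result[0] is None:
        print('Processing failed: could not open the video or build the plots.')
    else:
        src, face_video, heatmap_video, stat, au_stat = result
        print(f'Processed {src}: faces -> {face_video}, heatmaps -> {heatmap_video}')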