import cv2
import numpy as np
import torch
from PIL import Image
import mediapipe as mp

from app.model import pth_model_static, pth_model_dynamic, cam, pth_processing
from app.face_utils import get_box, display_info
from app.config import config_data
from app.__plot import statistics_plot
from .au_processing import features_to_au_intensities, au_statistics_plot
from pytorch_grad_cam.utils.image import show_cam_on_image

mp_face_mesh = mp.solutions.face_mesh
def preprocess_video_and_predict(video_path):
    """Run per-frame emotion and AU prediction on a video file.

    Returns the input path, the paths of the saved cropped-face and Grad-CAM
    heatmap videos, and the emotion/AU statistics plots, or five None values
    if the video cannot be opened or no statistics can be produced.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error opening video file: {video_path}")
        return None, None, None, None, None

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0 or fps != fps:  # Handle zero or NaN FPS
        fps = 30  # Default FPS

    # Paths to save processed videos
    path_save_video_face = 'result_face.mp4'
    path_save_video_hm = 'result_hm.mp4'

    # Video writers
    vid_writer_face = cv2.VideoWriter(path_save_video_face, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
    vid_writer_hm = cv2.VideoWriter(path_save_video_hm, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))

    lstm_features = []
    count_frame = 1
    count_face = 0
    probs = []
    frames = []
    au_intensities_list = []
    last_output = None
    last_heatmap = None
    last_au_intensities = None
    cur_face = None

    with mp_face_mesh.FaceMesh(
            max_num_faces=1,
            refine_landmarks=False,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5) as face_mesh:
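        # Process the clip frame by frame: detect a face with MediaPipe Face Mesh,
        # crop it, and run the static and dynamic emotion models on the crop.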
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = face_mesh.process(frame_rgb)

            if results.multi_face_landmarks:
                for face_landmarks in results.multi_face_landmarks:
                    startX, startY, endX, endY = get_box(face_landmarks, width, height)
                    cur_face = frame_rgb[startY:endY, startX:endX]

                    if count_face % config_data.FRAME_DOWNSAMPLING == 0:
                        cur_face_pil = Image.fromarray(cur_face)
                        cur_face_processed = pth_processing(cur_face_pil)
                        with torch.no_grad():
                            features = torch.nn.functional.relu(
                                pth_model_static.extract_features(cur_face_processed)
                            ).cpu().numpy()
                            au_intensities = features_to_au_intensities(
                                pth_model_static(cur_face_processed)
                            )

                        # Generate heatmap
                        grayscale_cam = cam(input_tensor=cur_face_processed)[0, :]
                        cur_face_resized = cv2.resize(cur_face, (224, 224), interpolation=cv2.INTER_AREA)
                        cur_face_normalized = np.float32(cur_face_resized) / 255
                        heatmap = show_cam_on_image(cur_face_normalized, grayscale_cam, use_rgb=False)
                        last_heatmap = heatmap
                        last_au_intensities = au_intensities

                        # Maintain a sliding window of the 10 most recent feature vectors
                        # for the dynamic model; the first processed face seeds the window.
                        if not lstm_features:
                            lstm_features = [features] * 10
                        else:
                            lstm_features = lstm_features[1:] + [features]
                        lstm_input = torch.from_numpy(np.vstack(lstm_features)).unsqueeze(0)
                        with torch.no_grad():
                            output = pth_model_dynamic(lstm_input).cpu().numpy()
                        last_output = output

                        if count_face == 0:
                            count_face += 1
                    else:
                        # Downsampled frame: reuse the most recent predictions,
                        # or NaN placeholders if no face has been processed yet.
                        if last_output is not None:
                            output = last_output
                            heatmap = last_heatmap
                            au_intensities = last_au_intensities
                        else:
                            output = np.full((1, 7), np.nan)
                            au_intensities = np.full(24, np.nan)

                    probs.append(output[0])
                    frames.append(count_frame)
                    au_intensities_list.append(au_intensities)
            else:
                # No face detected in this frame: reset the feature window and record NaNs.
                if last_output is not None:
                    lstm_features = []
                    probs.append(np.full(7, np.nan))
                    frames.append(count_frame)
                    au_intensities_list.append(np.full(24, np.nan))
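            # Write the annotated cropped face and its Grad-CAM heatmap to the output videos.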
            if cur_face is not None:
                heatmap_frame = display_info(heatmap, f'Frame: {count_frame}', box_scale=0.3)
                cur_face_bgr = cv2.cvtColor(cur_face, cv2.COLOR_RGB2BGR)
                cur_face_resized = cv2.resize(cur_face_bgr, (224, 224), interpolation=cv2.INTER_AREA)
                cur_face_annotated = display_info(cur_face_resized, f'Frame: {count_frame}', box_scale=0.3)
                vid_writer_face.write(cur_face_annotated)
                vid_writer_hm.write(heatmap_frame)

            count_frame += 1
            if count_face != 0:
                count_face += 1
    cap.release()
    vid_writer_face.release()
    vid_writer_hm.release()

    stat = statistics_plot(frames, probs)
    au_stat = au_statistics_plot(frames, au_intensities_list)
    if not stat or not au_stat:
        return None, None, None, None, None

    return video_path, path_save_video_face, path_save_video_hm, stat, au_stat
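

# Minimal usage sketch (not part of the original module). Because of the relative
# import above, run it as a package module, e.g. `python -m app.<this_module>`,
# from the project root; the module name and the sample clip path below are
# hypothetical placeholders.
if __name__ == "__main__":
    result = preprocess_video_and_predict("sample_video.mp4")  # hypothetical input path
    if result[0] is None:
        print("Processing failed: the video could not be opened or no statistics were produced.")
    else:
        src_path, face_video, hm_video, stat_fig, au_stat_fig = result
        print(f"Cropped-face video: {face_video}")
        print(f"Grad-CAM heatmap video: {hm_video}")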