# dyagnosys-free/app/video_processing.py
import cv2
import numpy as np
import torch
from PIL import Image
import mediapipe as mp
from app.model import pth_model_static, pth_model_dynamic, cam, pth_processing
from app.face_utils import get_box, display_info
from app.config import config_data
from app.__plot import statistics_plot
from .au_processing import features_to_au_intensities, au_statistics_plot
from pytorch_grad_cam.utils.image import show_cam_on_image

mp_face_mesh = mp.solutions.face_mesh

def preprocess_video_and_predict(video_path):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error opening video file: {video_path}")
        return None, None, None, None, None

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0 or fps != fps:  # Guard against zero, negative, or NaN FPS
        fps = 30  # Default FPS

    # Paths to save processed videos
    path_save_video_face = 'result_face.mp4'
    path_save_video_hm = 'result_hm.mp4'

    # Video writers for the face-crop and heatmap outputs (224x224)
    vid_writer_face = cv2.VideoWriter(path_save_video_face, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
    vid_writer_hm = cv2.VideoWriter(path_save_video_hm, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))

    # Rolling feature window for the dynamic model and per-frame results
    lstm_features = []
    count_frame = 1
    count_face = 0
    probs = []
    frames = []
    au_intensities_list = []

    # Cached outputs, reused on frames skipped by downsampling
    last_output = None
    last_heatmap = None
    last_au_intensities = None
    cur_face = None

    with mp_face_mesh.FaceMesh(
            max_num_faces=1,
            refine_landmarks=False,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5) as face_mesh:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = face_mesh.process(frame_rgb)

            if results.multi_face_landmarks:
                for face_landmarks in results.multi_face_landmarks:
                    startX, startY, endX, endY = get_box(face_landmarks, width, height)
                    cur_face = frame_rgb[startY:endY, startX:endX]

                    if count_face % config_data.FRAME_DOWNSAMPLING == 0:
                        # Run the static model on this (downsampled) face crop
                        cur_face_pil = Image.fromarray(cur_face)
                        cur_face_processed = pth_processing(cur_face_pil)
                        with torch.no_grad():
                            features = torch.nn.functional.relu(
                                pth_model_static.extract_features(cur_face_processed)
                            ).cpu().numpy()
                            au_intensities = features_to_au_intensities(
                                pth_model_static(cur_face_processed)
                            )

                        # Generate heatmap
                        grayscale_cam = cam(input_tensor=cur_face_processed)[0, :]
                        cur_face_resized = cv2.resize(cur_face, (224, 224), interpolation=cv2.INTER_AREA)
                        cur_face_normalized = np.float32(cur_face_resized) / 255
                        heatmap = show_cam_on_image(cur_face_normalized, grayscale_cam, use_rgb=False)
                        last_heatmap = heatmap
                        last_au_intensities = au_intensities

                        # Maintain a 10-step feature window for the dynamic model
                        if not lstm_features:
                            lstm_features = [features] * 10
                        else:
                            lstm_features = lstm_features[1:] + [features]

                        lstm_input = torch.from_numpy(np.vstack(lstm_features)).unsqueeze(0)
                        with torch.no_grad():
                            output = pth_model_dynamic(lstm_input).cpu().numpy()
                        last_output = output

                        if count_face == 0:
                            count_face += 1
                    else:
                        # Frame skipped by downsampling: reuse the most recent results
                        if last_output is not None:
                            output = last_output
                            heatmap = last_heatmap
                            au_intensities = last_au_intensities
                        else:
                            output = np.full((1, 7), np.nan)
                            au_intensities = np.full(24, np.nan)

                    probs.append(output[0])
                    frames.append(count_frame)
                    au_intensities_list.append(au_intensities)
            else:
                # No face detected in this frame: reset the feature window and log NaNs
                if last_output is not None:
                    lstm_features = []
                probs.append(np.full(7, np.nan))
                frames.append(count_frame)
                au_intensities_list.append(np.full(24, np.nan))

            if cur_face is not None:
                # Write the annotated face crop and heatmap to the output videos
                heatmap_frame = display_info(heatmap, f'Frame: {count_frame}', box_scale=0.3)
                cur_face_bgr = cv2.cvtColor(cur_face, cv2.COLOR_RGB2BGR)
                cur_face_resized = cv2.resize(cur_face_bgr, (224, 224), interpolation=cv2.INTER_AREA)
                cur_face_annotated = display_info(cur_face_resized, f'Frame: {count_frame}', box_scale=0.3)
                vid_writer_face.write(cur_face_annotated)
                vid_writer_hm.write(heatmap_frame)

            count_frame += 1
            if count_face != 0:
                count_face += 1
    cap.release()
    vid_writer_face.release()
    vid_writer_hm.release()

    stat = statistics_plot(frames, probs)
    au_stat = au_statistics_plot(frames, au_intensities_list)
    if not stat or not au_stat:
        return None, None, None, None, None

    return video_path, path_save_video_face, path_save_video_hm, stat, au_stat
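

# Illustrative usage sketch (not part of the original module): runs the pipeline
# on a local video when this file is executed directly. The fallback path
# "sample.mp4" is hypothetical, and the app.* imports above must resolve for
# this to work. On success, preprocess_video_and_predict returns the source
# path, the two rendered video paths, and the two statistics plots (as returned
# by statistics_plot and au_statistics_plot); on failure it returns five Nones.
if __name__ == "__main__":
    import sys

    # Take the video path from the command line, or fall back to a placeholder
    video = sys.argv[1] if len(sys.argv) > 1 else "sample.mp4"
    result = preprocess_video_and_predict(video)
    if result[0] is None:
        print("Processing failed: unreadable video or empty statistics.")
    else:
        _, face_video, heatmap_video, stat_fig, au_fig = result
        print(f"Annotated face video: {face_video}")
        print(f"CAM heatmap video: {heatmap_video}")
        # stat_fig and au_fig hold the emotion-probability and AU-intensity plots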