# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from latentsync.utils.util import read_video, write_video
from torchvision import transforms
import cv2
from einops import rearrange
import torch
import numpy as np
from typing import Any, List, Tuple, Union

from .affine_transform import AlignRestore
from .face_detector import FaceDetector


def load_fixed_mask(resolution: int, mask_image_path="latentsync/utils/mask.png") -> torch.Tensor:
    """Load the fixed mask image and return it as a channel-first tensor.

    The image is read with OpenCV, converted from BGR to RGB, resized to
    ``resolution x resolution`` with Lanczos interpolation and scaled by
    ``1 / 255``.

    Args:
        resolution: Side length, in pixels, of the square output mask.
        mask_image_path: Path of the mask image on disk.

    Returns:
        A tensor laid out as ``(c, h, w)``. Because the scaling is done in
        NumPy, the tensor keeps NumPy's float64 dtype.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``mask_image_path``.
    """
    mask_image = cv2.imread(mask_image_path)
    if mask_image is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        raise FileNotFoundError(f"Could not read mask image: {mask_image_path}")
    mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
    mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4) / 255.0
    mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w")
    return mask_image


class ImageProcessor:
    """Resize, normalize, mask and (optionally) face-align image frames."""

    def __init__(self, resolution: int = 512, device: str = "cpu", mask_image=None):
        """Build the transforms and helpers used by the processing methods.

        Args:
            resolution: Side length of the square output images.
            device: Device handed to ``AlignRestore`` and ``FaceDetector``.
                When it equals ``"cpu"`` no face detector is created.
            mask_image: Optional pre-loaded mask; when ``None`` the fixed mask
                is loaded with ``load_fixed_mask(resolution)``.
        """
        self.resolution = resolution
        self.resize = transforms.Resize(
            (resolution, resolution), interpolation=transforms.InterpolationMode.BICUBIC, antialias=True
        )
        # Callers pass values already scaled by 1/255; (x - 0.5) / 0.5 then
        # recentres them around zero.
        self.normalize = transforms.Normalize([0.5], [0.5], inplace=True)
        self.restorer = AlignRestore(resolution=resolution, device=device)

        if mask_image is None:
            self.mask_image = load_fixed_mask(resolution)
        else:
            self.mask_image = mask_image

        if device == "cpu":
            self.face_detector = None
        else:
            self.face_detector = FaceDetector(device=device)

    @staticmethod
    def _to_channel_first(images: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Convert a 4-D batch to a tensor, moving a trailing size-3 axis to position 1."""
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)
        if images.shape[3] == 3:
            images = rearrange(images, "f h w c -> f c h w")
        return images

    def affine_transform(self, image: np.ndarray) -> Tuple[torch.Tensor, List[int], Any]:
        """Detect a face in ``image`` and warp it to the aligned crop.

        Args:
            image: A single frame. ``image.copy()`` is called on it, so it
                must be array-like with a ``copy`` method (presumably an
                ``h w c`` NumPy frame — confirm against the face detector).

        Returns:
            A tuple ``(face, box, affine_matrix)`` where ``face`` is the
            aligned crop resized to ``resolution`` and laid out ``(c, h, w)``,
            ``box`` is ``[0, 0, width, height]`` of the crop *before* the
            resize, and ``affine_matrix`` is whatever
            ``AlignRestore.align_warp_face`` returned.

        Raises:
            NotImplementedError: If the processor was created for the CPU and
                therefore has no face detector.
            RuntimeError: If the detector finds no face.
        """
        if self.face_detector is None:
            raise NotImplementedError("Using the CPU for face detection is not supported")

        bbox, landmark_2d_106 = self.face_detector(image)
        if bbox is None:
            raise RuntimeError("Face not detected")

        pt_left_eye = np.mean(landmark_2d_106[[43, 48, 49, 51, 50]], axis=0)  # left eyebrow center
        pt_right_eye = np.mean(landmark_2d_106[101:106], axis=0)  # right eyebrow center
        pt_nose = np.mean(landmark_2d_106[[74, 77, 83, 86]], axis=0)  # nose center

        landmarks3 = np.round([pt_left_eye, pt_right_eye, pt_nose])

        face, affine_matrix = self.restorer.align_warp_face(image.copy(), landmarks3=landmarks3, smooth=True)
        # Record the crop size before it is resized to the target resolution.
        box = [0, 0, face.shape[1], face.shape[0]]  # x1, y1, x2, y2
        face = cv2.resize(face, (self.resolution, self.resolution), interpolation=cv2.INTER_LANCZOS4)
        face = rearrange(torch.from_numpy(face), "h w c -> c h w")
        return face, box, affine_matrix

    def preprocess_fixed_mask_image(self, image: torch.Tensor, affine_transform=False):
        """Normalize one frame and apply the fixed mask to it.

        Args:
            image: A single frame with values in the 0-255 range.
            affine_transform: When true the frame is face-aligned with
                ``self.affine_transform``; otherwise it is only resized.

        Returns:
            ``(pixel_values, masked_pixel_values, mask)`` where ``mask`` is
            the first channel of the fixed mask with its channel axis kept.
        """
        if affine_transform:
            # NOTE(review): affine_transform() calls image.copy(), which a
            # torch.Tensor does not provide — confirm what callers pass on
            # this path.
            image, _, _ = self.affine_transform(image)
        else:
            image = self.resize(image)
        # The division allocates a new tensor, so the in-place normalize does
        # not modify the caller's data.
        pixel_values = self.normalize(image / 255.0)
        masked_pixel_values = pixel_values * self.mask_image
        return pixel_values, masked_pixel_values, self.mask_image[0:1]

    def prepare_masks_and_masked_images(self, images: Union[torch.Tensor, np.ndarray], affine_transform=False):
        """Run ``preprocess_fixed_mask_image`` on every frame of a batch.

        Args:
            images: A 4-D batch of frames, channel-last or channel-first.
            affine_transform: Forwarded to ``preprocess_fixed_mask_image``.

        Returns:
            Three stacked tensors: pixel values, masked pixel values, masks.
        """
        images = self._to_channel_first(images)

        results = [self.preprocess_fixed_mask_image(image, affine_transform=affine_transform) for image in images]

        pixel_values_list, masked_pixel_values_list, masks_list = list(zip(*results))
        return torch.stack(pixel_values_list), torch.stack(masked_pixel_values_list), torch.stack(masks_list)

    def process_images(self, images: Union[torch.Tensor, np.ndarray]):
        """Resize a 4-D batch of frames and normalize it, without masking."""
        images = self._to_channel_first(images)
        images = self.resize(images)
        pixel_values = self.normalize(images / 255.0)
        return pixel_values


class VideoProcessor:
    """Apply ``ImageProcessor`` face alignment to every frame of a video."""

    def __init__(self, resolution: int = 512, device: str = "cpu"):
        self.image_processor = ImageProcessor(resolution, device)

    def affine_transform_video(self, video_path):
        """Read ``video_path`` and return its face-aligned frames.

        Returns:
            A NumPy array of the aligned frames laid out as ``(f, h, w, c)``.
        """
        video_frames = read_video(video_path, change_fps=False)
        # affine_transform returns (face, box, affine_matrix); only the face is kept.
        aligned = [self.image_processor.affine_transform(frame)[0] for frame in video_frames]
        results = torch.stack(aligned)
        results = rearrange(results, "f c h w -> f h w c").numpy()
        return results


if __name__ == "__main__":
    video_processor = VideoProcessor(256, "cuda")
    video_frames = video_processor.affine_transform_video("assets/demo2_video.mp4")
    write_video("output.mp4", video_frames, fps=25)