# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

import cv2
import numpy as np
import torch
from einops import rearrange
from torchvision import transforms

from latentsync.utils.util import read_video, write_video

from .affine_transform import AlignRestore
from .face_detector import FaceDetector
def load_fixed_mask(resolution: int, mask_image_path: str = "latentsync/utils/mask.png") -> torch.Tensor:
    """Load the fixed mouth-region mask image as a channels-first float tensor.

    Args:
        resolution: Target side length; the mask is resized to (resolution, resolution).
        mask_image_path: Path to the mask image on disk.

    Returns:
        A float tensor of shape (c, h, w) with values in [0, 1].

    Raises:
        FileNotFoundError: If the mask image cannot be read from disk.
    """
    mask_image = cv2.imread(mask_image_path)
    if mask_image is None:
        # cv2.imread silently returns None on a missing/unreadable file, which
        # would otherwise surface as a cryptic cvtColor error — fail early instead.
        raise FileNotFoundError(f"Mask image not found or unreadable: {mask_image_path}")
    mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
    mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4) / 255.0
    mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w")
    return mask_image
class ImageProcessor:
    """Face-image preprocessing: resizing, normalization to [-1, 1],
    affine face alignment, and fixed-mask masking."""

    def __init__(self, resolution: int = 512, device: str = "cpu", mask_image=None):
        """
        Args:
            resolution: Side length all images are resized/aligned to.
            device: Device for the restorer and face detector. "cpu" disables
                face detection (affine_transform will raise).
            mask_image: Optional pre-loaded (c, h, w) mask tensor; when None,
                the fixed mask is loaded via load_fixed_mask().
        """
        self.resolution = resolution
        self.resize = transforms.Resize(
            (resolution, resolution), interpolation=transforms.InterpolationMode.BICUBIC, antialias=True
        )
        # Maps pixel values from [0, 1] to [-1, 1]
        self.normalize = transforms.Normalize([0.5], [0.5], inplace=True)
        self.restorer = AlignRestore(resolution=resolution, device=device)

        if mask_image is None:
            self.mask_image = load_fixed_mask(resolution)
        else:
            self.mask_image = mask_image

        if device == "cpu":
            # Face detection is not supported on CPU in this project; keep None
            # so affine_transform raises a clear error when called.
            self.face_detector = None
        else:
            self.face_detector = FaceDetector(device=device)

    def affine_transform(self, image: torch.Tensor) -> tuple:
        """Detect the face in `image`, warp it to canonical alignment, and
        resize it to self.resolution.

        NOTE(review): the original annotation said `-> np.ndarray`, but a
        3-tuple is returned; annotation corrected.

        Args:
            image: Input frame in (h, w, c) layout — presumably uint8 RGB as
                consumed by the face detector; TODO confirm against caller.

        Returns:
            Tuple (face, box, affine_matrix):
                face: (c, h, w) tensor of the aligned, resized face.
                box: [x1, y1, x2, y2] of the warped face region (pre-resize).
                affine_matrix: Matrix returned by the restorer's warp.

        Raises:
            NotImplementedError: When constructed with device="cpu".
            RuntimeError: When no face is detected in the image.
        """
        if self.face_detector is None:
            raise NotImplementedError("Using the CPU for face detection is not supported")
        bbox, landmark_2d_106 = self.face_detector(image)
        if bbox is None:
            raise RuntimeError("Face not detected")

        # Reduce the 106-point landmark set to three anchor points for alignment.
        pt_left_eye = np.mean(landmark_2d_106[[43, 48, 49, 51, 50]], axis=0)  # left eyebrow center
        pt_right_eye = np.mean(landmark_2d_106[101:106], axis=0)  # right eyebrow center
        pt_nose = np.mean(landmark_2d_106[[74, 77, 83, 86]], axis=0)  # nose center

        landmarks3 = np.round([pt_left_eye, pt_right_eye, pt_nose])
        face, affine_matrix = self.restorer.align_warp_face(image.copy(), landmarks3=landmarks3, smooth=True)
        box = [0, 0, face.shape[1], face.shape[0]]  # x1, y1, x2, y2
        face = cv2.resize(face, (self.resolution, self.resolution), interpolation=cv2.INTER_LANCZOS4)
        face = rearrange(torch.from_numpy(face), "h w c -> c h w")
        return face, box, affine_matrix

    def preprocess_fixed_mask_image(self, image: torch.Tensor, affine_transform: bool = False):
        """Normalize one image and apply the fixed mask.

        Args:
            image: (c, h, w) image with values in [0, 255].
            affine_transform: If True, face-align the image first; otherwise
                just resize it.

        Returns:
            Tuple (pixel_values, masked_pixel_values, mask): the normalized
            image in [-1, 1], the normalized image multiplied by the mask, and
            a single-channel (1, h, w) slice of the fixed mask.
        """
        if affine_transform:
            image, _, _ = self.affine_transform(image)
        else:
            image = self.resize(image)
        pixel_values = self.normalize(image / 255.0)
        masked_pixel_values = pixel_values * self.mask_image
        return pixel_values, masked_pixel_values, self.mask_image[0:1]

    def prepare_masks_and_masked_images(self, images: Union[torch.Tensor, np.ndarray], affine_transform: bool = False):
        """Batch version of preprocess_fixed_mask_image.

        Args:
            images: Batch of frames, (f, h, w, c) or (f, c, h, w), values in [0, 255].
            affine_transform: Forwarded to preprocess_fixed_mask_image.

        Returns:
            Tuple of stacked tensors (pixel_values, masked_pixel_values, masks).
        """
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)
        if images.shape[3] == 3:
            # Channels-last input: convert to channels-first.
            images = rearrange(images, "f h w c -> f c h w")

        results = [self.preprocess_fixed_mask_image(image, affine_transform=affine_transform) for image in images]

        pixel_values_list, masked_pixel_values_list, masks_list = list(zip(*results))
        return torch.stack(pixel_values_list), torch.stack(masked_pixel_values_list), torch.stack(masks_list)

    def process_images(self, images: Union[torch.Tensor, np.ndarray]):
        """Resize a batch of frames and normalize pixel values to [-1, 1].

        Args:
            images: Batch of frames, (f, h, w, c) or (f, c, h, w), values in [0, 255].

        Returns:
            Normalized (f, c, h, w) tensor.
        """
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)
        if images.shape[3] == 3:
            images = rearrange(images, "f h w c -> f c h w")
        images = self.resize(images)
        pixel_values = self.normalize(images / 255.0)
        return pixel_values
class VideoProcessor:
    """Applies per-frame affine face alignment to an entire video."""

    def __init__(self, resolution: int = 512, device: str = "cpu"):
        self.image_processor = ImageProcessor(resolution, device)

    def affine_transform_video(self, video_path):
        """Read a video, align the face in every frame, and return the aligned
        frames as a numpy array in (f, h, w, c) layout."""
        frames = read_video(video_path, change_fps=False)
        # affine_transform returns (face, box, affine_matrix); keep only the face.
        aligned = [self.image_processor.affine_transform(frame)[0] for frame in frames]
        stacked = torch.stack(aligned)
        return rearrange(stacked, "f c h w -> f h w c").numpy()
if __name__ == "__main__":
    # Manual smoke test: align faces in the demo clip and write the result.
    processor = VideoProcessor(256, "cuda")
    aligned_frames = processor.affine_transform_video("assets/demo2_video.mp4")
    write_video("output.mp4", aligned_frames, fps=25)