Spaces:
Build error
Build error
import base64 | |
from io import BytesIO | |
from typing import Callable, List | |
import numpy as np | |
import torch | |
import cv2 | |
from .masks import face_mask_static | |
from matplotlib import pyplot as plt | |
from insightface.utils import face_align | |
def crop_face(image_full: np.ndarray, app: Callable, crop_size: int) -> np.ndarray: | |
""" | |
Crop face from image and resize | |
""" | |
kps = app.get(image_full, crop_size) | |
M, _ = face_align.estimate_norm(kps[0], crop_size, mode ='None') | |
align_img = cv2.warpAffine(image_full, M, (crop_size, crop_size), borderValue=0.0) | |
return [align_img] | |
def normalize_and_torch(image: np.ndarray) -> torch.tensor: | |
""" | |
Normalize image and transform to torch | |
""" | |
image = torch.tensor(image.copy(), dtype=torch.float32).cuda() | |
if image.max() > 1.: | |
image = image/255. | |
image = image.permute(2, 0, 1).unsqueeze(0) | |
image = (image - 0.5) / 0.5 | |
return image | |
def normalize_and_torch_batch(frames: np.ndarray) -> torch.tensor: | |
""" | |
Normalize batch images and transform to torch | |
""" | |
batch_frames = torch.from_numpy(frames.copy()).cuda() | |
if batch_frames.max() > 1.: | |
batch_frames = batch_frames/255. | |
batch_frames = batch_frames.permute(0, 3, 1, 2) | |
batch_frames = (batch_frames - 0.5)/0.5 | |
return batch_frames | |
def get_final_image(final_frames: List[np.ndarray], | |
crop_frames: List[np.ndarray], | |
full_frame: np.ndarray, | |
tfm_arrays: List[np.ndarray], | |
handler) -> None: | |
""" | |
Create final video from frames | |
""" | |
final = full_frame.copy() | |
params = [None for i in range(len(final_frames))] | |
for i in range(len(final_frames)): | |
frame = cv2.resize(final_frames[i][0], (224, 224)) | |
landmarks = handler.get_without_detection_without_transform(frame) | |
landmarks_tgt = handler.get_without_detection_without_transform(crop_frames[i][0]) | |
mask, _ = face_mask_static(crop_frames[i][0], landmarks, landmarks_tgt, params[i]) | |
mat_rev = cv2.invertAffineTransform(tfm_arrays[i][0]) | |
swap_t = cv2.warpAffine(frame, mat_rev, (full_frame.shape[1], full_frame.shape[0]), borderMode=cv2.BORDER_REPLICATE) | |
mask_t = cv2.warpAffine(mask, mat_rev, (full_frame.shape[1], full_frame.shape[0])) | |
mask_t = np.expand_dims(mask_t, 2) | |
final = mask_t*swap_t + (1-mask_t)*final | |
final = np.array(final, dtype='uint8') | |
return final | |
def show_images(images: List[np.ndarray], | |
titles=None, | |
figsize=(20, 5), | |
fontsize=15): | |
if titles: | |
assert len(titles) == len(images), "Amount of images should be the same as the amount of titles" | |
fig, axes = plt.subplots(1, len(images), figsize=figsize) | |
for idx, (ax, image) in enumerate(zip(axes, images)): | |
ax.imshow(image[:, :, ::-1]) | |
if titles: | |
ax.set_title(titles[idx], fontsize=fontsize) | |
ax.axis("off") | |