# Adapted from https://github.com/guanjz20/StyleSync/blob/main/utils.py

import numpy as np
import cv2
import torch
from einops import rearrange
import kornia


class AlignRestore:
    """Align a face to a fixed 3-point template and paste a face crop back.

    ``align_warp_face`` estimates a similarity transform that maps three image
    landmarks onto ``self.face_template`` and warps the image into a crop of
    size ``self.face_size``. ``restore_img`` inverts that warp and blends a
    face crop back into the full frame using a feathered (eroded + blurred)
    mask.

    Only the 3-point template (``align_points=3``) is implemented.
    """

    def __init__(self, align_points=3, resolution=256, device="cpu", dtype=torch.float16):
        """Build the landmark template, crop size and paste mask.

        Args:
            align_points: number of alignment landmarks; only 3 is supported.
            resolution: target crop scale; template coordinates are scaled by
                ``resolution / 256 * 2.8``.
            device: torch device used for warping and blending.
            dtype: torch dtype used for warping and blending.

        Raises:
            ValueError: if ``align_points`` is not 3. Without this check, the
                template attributes would never be set.
        """
        if align_points != 3:
            raise ValueError(f"AlignRestore only supports align_points=3, got {align_points}")

        self.upscale_factor = 1
        ratio = resolution / 256 * 2.8
        self.crop_ratio = (ratio, ratio)
        # Three (x, y) template points, defined at the base scale and then scaled by `ratio`.
        self.face_template = np.array([[19 - 2, 30 - 10], [56 + 2, 30 - 10], [37.5, 45 - 5]]) * ratio
        # (width, height) of the aligned crop; warp calls pass it to kornia as (height, width).
        self.face_size = (int(75 * self.crop_ratio[0]), int(100 * self.crop_ratio[1]))

        # Smoothed translation bias carried between successive align_warp_face calls.
        self.p_bias = None
        self.device = device
        self.dtype = dtype
        # Padding colour for warps; kornia's "fill" padding takes a 3-element tensor (RGB).
        self.fill_value = torch.tensor([127, 127, 127], device=device, dtype=dtype)
        # All-ones (1, 1, H, W) mask covering the crop; warped back to mark the pasted area.
        self.mask = torch.ones((1, 1, self.face_size[1], self.face_size[0]), device=device, dtype=dtype)

    def align_warp_face(self, img, landmarks3, smooth=True):
        """Warp ``img`` so that ``landmarks3`` land on ``self.face_template``.

        Args:
            img: numpy image of shape (h, w, c).
            landmarks3: (3, 2) array/tensor of (x, y) landmark points in ``img``.
            smooth: if True, blend the translation with the previous call's
                bias (see ``transformation_from_points``). ``self.p_bias`` is
                updated either way with the value returned by that method.

        Returns:
            ``(cropped_face, affine_matrix)``:
                - ``cropped_face``: a uint8 numpy crop of shape
                  ``(face_size[1], face_size[0], c)``.
                - ``affine_matrix``: the (1, 2, 3) forward matrix tensor on
                  ``self.device``.
        """
        affine_matrix, self.p_bias = self.transformation_from_points(
            landmarks3, self.face_template, smooth, self.p_bias
        )

        img = rearrange(torch.from_numpy(img).to(device=self.device, dtype=self.dtype), "h w c -> c h w").unsqueeze(0)
        affine_matrix = torch.from_numpy(affine_matrix).to(device=self.device, dtype=self.dtype).unsqueeze(0)

        cropped_face = kornia.geometry.transform.warp_affine(
            img,
            affine_matrix,
            (self.face_size[1], self.face_size[0]),
            mode="bilinear",
            padding_mode="fill",
            fill_value=self.fill_value,
        )
        cropped_face = rearrange(cropped_face.squeeze(0), "c h w -> h w c").cpu().numpy().astype(np.uint8)
        return cropped_face, affine_matrix

    def restore_img(self, input_img, face, affine_matrix):
        """Inverse-warp ``face`` into ``input_img`` and blend it with a feathered mask.

        Args:
            input_img: numpy frame of shape (h, w, c) to paste into.
            face: torch tensor of shape (c, face_h, face_w). The ``/ 2 + 0.5``
                mapping below implies values in [-1, 1]. It is moved to
                ``self.device`` / ``self.dtype``.
            affine_matrix: forward (crop-from-frame) affine matrix, as returned
                by ``align_warp_face``. It may be a numpy array or a tensor of
                shape (2, 3) or (1, 2, 3).

        Returns:
            The blended frame as a uint8 numpy array of shape (h, w, c).
        """
        h, w, _ = input_img.shape

        # Accept ndarray or tensor, batched or not, on any device/dtype.
        affine_matrix = torch.as_tensor(affine_matrix).to(device=self.device, dtype=self.dtype)
        if affine_matrix.dim() == 2:
            affine_matrix = affine_matrix.unsqueeze(0)
        inv_affine_matrix = kornia.geometry.transform.invert_affine_transform(affine_matrix)

        face = face.to(device=self.device, dtype=self.dtype).unsqueeze(0)
        # The padding value does not matter: pixels outside the warped crop are zeroed by the mask below.
        inv_face = kornia.geometry.transform.warp_affine(
            face, inv_affine_matrix, (h, w), mode="bilinear", padding_mode="fill", fill_value=self.fill_value
        ).squeeze(0)
        inv_face = (inv_face / 2 + 0.5).clamp(0, 1) * 255  # [-1, 1] -> [0, 255]

        input_img = rearrange(torch.from_numpy(input_img).to(device=self.device, dtype=self.dtype), "h w c -> c h w")

        # Hard mask of the pasted region in frame coordinates: (1, 1, h, w).
        inv_mask = kornia.geometry.transform.warp_affine(self.mask, inv_affine_matrix, (h, w), padding_mode="zeros")
        kernel_size = int(2 * self.upscale_factor)
        inv_mask_erosion = kornia.morphology.erosion(
            inv_mask, torch.ones((kernel_size, kernel_size), device=self.device, dtype=self.dtype)
        )
        pasted_face = inv_mask_erosion.squeeze(0).expand_as(inv_face) * inv_face

        inv_soft_mask = self._soft_paste_mask(inv_mask_erosion).expand_as(inv_face)
        img_back = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * input_img
        # Clamp before the uint8 cast so fp16 rounding can never wrap around.
        img_back = rearrange(img_back, "c h w -> h w c").clamp(0, 255).contiguous().to(dtype=torch.uint8)
        return img_back.cpu().numpy()

    def _soft_paste_mask(self, inv_mask_erosion):
        """Turn a (1, 1, h, w) hard paste mask into a feathered (1, h, w) blend mask.

        The mask is shrunk by an erosion whose size scales with the square
        root of the masked area, then smoothed with a Gaussian blur of
        matching width.
        """
        total_face_area = torch.sum(inv_mask_erosion.float())
        w_edge = int(total_face_area**0.5) // 20
        erosion_radius = w_edge * 2
        # Erode on the CPU with OpenCV: the equivalent kornia erosion with a kernel this
        # large consumes a large amount of GPU memory.
        # NOTE(review): for very small faces w_edge is 0, giving an empty kernel, which
        # cv2.erode replaces with its default 3x3 element. Confirm this is intended.
        mask_np = inv_mask_erosion.squeeze().cpu().numpy().astype(np.float32)
        inv_mask_center = cv2.erode(mask_np, np.ones((erosion_radius, erosion_radius), np.uint8))
        inv_mask_center = torch.from_numpy(inv_mask_center).to(device=self.device, dtype=self.dtype)[None, None, ...]

        blur_size = w_edge * 2 + 1
        # Same sigma-from-ksize rule OpenCV's getGaussianKernel uses when sigma <= 0.
        sigma = 0.3 * ((blur_size - 1) * 0.5 - 1) + 0.8
        return kornia.filters.gaussian_blur2d(inv_mask_center, (blur_size, blur_size), (sigma, sigma)).squeeze(0)

    def transformation_from_points(
        self,
        points1: "np.ndarray | torch.Tensor",
        points0: "np.ndarray | torch.Tensor",
        smooth=True,
        p_bias=None,
    ):
        """Estimate the 2x3 similarity transform mapping ``points1`` onto ``points0``.

        Both point sets are (N, 2) (x, y) arrays with matching N (at least 3,
        since the third point drives smoothing). Each set is centered and
        divided by its standard deviation. The rotation comes from the SVD of
        the cross-covariance (orthogonal Procrustes, with a reflection fix).
        The scale is the ratio of the two standard deviations. ndarray inputs
        are converted to float32 tensors on ``self.device``; tensor inputs are
        cloned as-is.

        Args:
            points1: source points (e.g. detected landmarks).
            points0: destination points (e.g. ``self.face_template``).
            smooth: if True, add a translation bias derived from the third point,
                exponentially averaged with ``p_bias``.
            p_bias: bias from the previous call, or None.

        Returns:
            ``(M, p_bias)``:
                - ``M``: a (2, 3) numpy array such that
                  ``M @ [x, y, 1]^T`` approximates the matching ``points0`` point.
                - ``p_bias``: the updated bias; returned unchanged when
                  ``smooth`` is False.
        """
        points2 = (
            torch.tensor(points0, device=self.device, dtype=torch.float32)
            if isinstance(points0, np.ndarray)
            else points0.clone()
        )
        points1_tensor = (
            torch.tensor(points1, device=self.device, dtype=torch.float32)
            if isinstance(points1, np.ndarray)
            else points1.clone()
        )

        c1 = torch.mean(points1_tensor, dim=0)
        c2 = torch.mean(points2, dim=0)
        points1_centered = points1_tensor - c1
        points2_centered = points2 - c2

        s1 = torch.std(points1_centered)
        s2 = torch.std(points2_centered)
        points1_normalized = points1_centered / s1
        points2_normalized = points2_centered / s2

        covariance = points1_normalized.T @ points2_normalized
        # torch.svd is deprecated. torch.linalg.svd returns Vh = V^T, so V = Vh.T is identical.
        U, _, Vh = torch.linalg.svd(covariance)
        V = Vh.T.clone()
        R = V @ U.T

        if torch.det(R) < 0:
            # Reflection: flip the last right singular vector to obtain a proper rotation.
            V[:, -1] = -V[:, -1]
            R = V @ U.T

        scale = s2 / s1
        sR = scale * R
        T = c2.reshape(2, 1) - scale * (R @ c1.reshape(2, 1))
        M = torch.cat((sR, T), dim=1)

        if smooth:
            # NOTE(review): this bias is measured in std-normalized units but added to the
            # pixel-space translation column. This is kept as in upstream StyleSync.
            bias = points2_normalized[2] - points1_normalized[2]
            if p_bias is not None:
                bias = p_bias * 0.2 + bias * 0.8
            p_bias = bias
            M[:, 2] = M[:, 2] + bias

        return M.cpu().numpy(), p_bias