# LatentSync: latentsync/utils/affine_transform.py
# (Uploaded to Hugging Face by welher via huggingface_hub; commit 8d11d43.)
# Adapted from https://github.com/guanjz20/StyleSync/blob/main/utils.py
import numpy as np
import cv2
import torch
from einops import rearrange
import kornia
class AlignRestore(object):
    """Align a face to a fixed 3-point template and paste a processed crop back.

    ``align_warp_face`` estimates a similarity transform that maps three
    landmarks onto ``face_template`` and warps the frame into a crop of size
    ``face_size`` (width, height). ``restore_img`` inverts that transform and
    blends a processed crop back into the full frame using a feathered mask.
    """

    def __init__(self, align_points=3, resolution=256, device="cpu", dtype=torch.float16):
        """Build the landmark template and the reusable tensors.

        Args:
            align_points: Number of alignment landmarks. Only ``3`` is supported.
            resolution: Scales the template and crop size; 256 gives a scale of 2.8.
            device: Torch device used for warping and blending.
            dtype: Torch dtype used for warping and blending.

        Raises:
            ValueError: If ``align_points`` is not 3.
        """
        if align_points != 3:
            raise ValueError(f"Unsupported align_points={align_points}; only 3 is supported.")
        # Only used to size the small mask-erosion kernel in restore_img.
        self.upscale_factor = 1
        ratio = resolution / 256 * 2.8
        self.crop_ratio = (ratio, ratio)
        # Three (x, y) template points in a 75x100 (width x height) reference crop, scaled by `ratio`.
        self.face_template = np.array([[19 - 2, 30 - 10], [56 + 2, 30 - 10], [37.5, 45 - 5]]) * ratio
        # (width, height) of the aligned crop.
        self.face_size = (int(75 * self.crop_ratio[0]), int(100 * self.crop_ratio[1]))
        # Translation bias carried between calls when smoothing is enabled.
        self.p_bias = None
        self.device = device
        self.dtype = dtype
        # Gray fill for output pixels that map outside the source image.
        self.fill_value = torch.tensor([127, 127, 127], device=device, dtype=dtype)
        # All-ones mask in crop space; warped back to find the pasted region in the frame.
        self.mask = torch.ones((1, 1, self.face_size[1], self.face_size[0]), device=device, dtype=dtype)

    def align_warp_face(self, img, landmarks3, smooth=True):
        """Warp ``img`` so that ``landmarks3`` land on ``self.face_template``.

        Args:
            img: Frame as an ``(H, W, C)`` numpy array (C must be 3 to match ``fill_value``).
            landmarks3: Three ``(x, y)`` points in ``img`` (array, list or tensor, shape ``(3, 2)``).
            smooth: If True, blend the translation bias with the one from the
                previous call (``self.p_bias``), presumably to reduce frame-to-frame jitter.

        Returns:
            ``(cropped_face, affine_matrix)``: the aligned crop as a ``uint8``
            ``(face_size[1], face_size[0], C)`` numpy array, and the ``(1, 2, 3)``
            forward affine matrix as a tensor on ``self.device`` / ``self.dtype``.
        """
        affine_matrix, self.p_bias = self.transformation_from_points(
            landmarks3, self.face_template, smooth, self.p_bias
        )
        img = rearrange(torch.from_numpy(img).to(device=self.device, dtype=self.dtype), "h w c -> c h w").unsqueeze(0)
        affine_matrix = torch.from_numpy(affine_matrix).to(device=self.device, dtype=self.dtype).unsqueeze(0)
        cropped_face = kornia.geometry.transform.warp_affine(
            img,
            affine_matrix,
            (self.face_size[1], self.face_size[0]),  # kornia expects (height, width)
            mode="bilinear",
            padding_mode="fill",
            fill_value=self.fill_value,
        )
        cropped_face = rearrange(cropped_face.squeeze(0), "c h w -> h w c").cpu().numpy().astype(np.uint8)
        return cropped_face, affine_matrix

    def restore_img(self, input_img, face, affine_matrix):
        """Paste ``face`` back into ``input_img`` using the inverse of ``affine_matrix``.

        Args:
            input_img: Original frame as an ``(H, W, C)`` numpy array with values in [0, 255].
            face: Processed crop as a ``(C, h, w)`` tensor with values in [-1, 1];
                presumably ``(h, w)`` equals ``(face_size[1], face_size[0])`` - TODO confirm.
            affine_matrix: Forward transform (as returned by ``align_warp_face``),
                a numpy array or tensor of shape ``(2, 3)`` or ``(1, 2, 3)``.

        Returns:
            The blended frame as a ``uint8`` ``(H, W, C)`` numpy array.
        """
        h, w, _ = input_img.shape
        affine_matrix = torch.as_tensor(affine_matrix).to(device=self.device, dtype=self.dtype)
        if affine_matrix.dim() == 2:
            affine_matrix = affine_matrix.unsqueeze(0)
        inv_affine_matrix = kornia.geometry.transform.invert_affine_transform(affine_matrix)

        # Warp the crop back to frame coordinates and map [-1, 1] -> [0, 255].
        face = face.to(device=self.device, dtype=self.dtype).unsqueeze(0)
        inv_face = kornia.geometry.transform.warp_affine(
            face, inv_affine_matrix, (h, w), mode="bilinear", padding_mode="fill", fill_value=self.fill_value
        ).squeeze(0)
        inv_face = (inv_face / 2 + 0.5).clamp(0, 1) * 255

        input_img = rearrange(torch.from_numpy(input_img).to(device=self.device, dtype=self.dtype), "h w c -> c h w")

        # Frame region covered by the crop, shrunk by a small erosion.
        inv_mask = kornia.geometry.transform.warp_affine(
            self.mask, inv_affine_matrix, (h, w), padding_mode="zeros"
        )  # (1, 1, h, w)
        kernel_size = int(2 * self.upscale_factor)
        inv_mask_erosion = kornia.morphology.erosion(
            inv_mask,
            torch.ones((kernel_size, kernel_size), device=self.device, dtype=self.dtype),
        )
        pasted_face = inv_mask_erosion.squeeze(0).expand_as(inv_face) * inv_face

        # Feather width scales with face size: 1/20 of the side of a square of equal area.
        total_face_area = torch.sum(inv_mask_erosion.float())
        w_edge = int(total_face_area**0.5) // 20
        erosion_radius = w_edge * 2

        # Erode on CPU: doing this large-kernel erosion with kornia consumes a lot of GPU memory.
        inv_mask_center = inv_mask_erosion[0, 0].cpu().numpy().astype(np.float32)
        if erosion_radius > 0:
            # cv2.erode treats an empty kernel as a 3x3 default, so only call it for a real kernel.
            inv_mask_center = cv2.erode(inv_mask_center, np.ones((erosion_radius, erosion_radius), np.uint8))
        inv_mask_center = torch.from_numpy(inv_mask_center).to(device=self.device, dtype=self.dtype)[None, None, ...]

        blur_size = w_edge * 2 + 1
        # OpenCV's default sigma for a Gaussian kernel of this size.
        sigma = 0.3 * ((blur_size - 1) * 0.5 - 1) + 0.8
        inv_soft_mask = kornia.filters.gaussian_blur2d(
            inv_mask_center, (blur_size, blur_size), (sigma, sigma)
        ).squeeze(0)
        inv_soft_mask_3d = inv_soft_mask.expand_as(inv_face)

        img_back = inv_soft_mask_3d * pasted_face + (1 - inv_soft_mask_3d) * input_img
        img_back = rearrange(img_back, "c h w -> h w c").contiguous().to(dtype=torch.uint8)
        return img_back.cpu().numpy()

    def transformation_from_points(self, points1: torch.Tensor, points0: torch.Tensor, smooth=True, p_bias=None):
        """Estimate the 2x3 similarity transform mapping ``points1`` onto ``points0``.

        Solves orthogonal Procrustes on centered, std-normalized points. If the
        SVD solution is a reflection, the axis of the smallest singular value is
        negated so the result is a proper rotation.

        Args:
            points1: Source points, shape ``(N, 2)`` (array, list or tensor of any float dtype).
            points0: Target points, shape ``(N, 2)`` (array, list or tensor of any float dtype).
            smooth: If True, add the third point's normalized offset to the
                translation, blended as ``0.2 * p_bias + 0.8 * current`` when
                ``p_bias`` is given. Requires ``N >= 3``.
            p_bias: Bias returned by the previous call, or None.

        Returns:
            ``(M, p_bias)``: ``M`` is a ``(2, 3)`` float32 numpy array; ``p_bias``
            is the bias to pass to the next call (unchanged when ``smooth`` is False).
        """
        # Bring both point sets to one float32 device so the matmuls never mix dtypes/devices.
        points2 = torch.as_tensor(points0, dtype=torch.float32, device=self.device)
        points1_tensor = torch.as_tensor(points1, dtype=torch.float32, device=self.device)

        c1 = points1_tensor.mean(dim=0)
        c2 = points2.mean(dim=0)
        points1_centered = points1_tensor - c1
        points2_centered = points2 - c2
        s1 = torch.std(points1_centered)
        s2 = torch.std(points2_centered)
        points1_normalized = points1_centered / s1
        points2_normalized = points2_centered / s2

        covariance = points1_normalized.T @ points2_normalized
        # torch.linalg.svd returns V^T (the deprecated torch.svd returned V).
        U, _, Vh = torch.linalg.svd(covariance)
        V = Vh.T
        R = V @ U.T
        if torch.det(R) < 0:
            V = V.clone()
            V[:, -1] = -V[:, -1]
            R = V @ U.T

        scale = s2 / s1
        sR = scale * R
        T = c2.reshape(2, 1) - scale * (R @ c1.reshape(2, 1))
        M = torch.cat((sR, T), dim=1)

        if smooth:
            bias = points2_normalized[2] - points1_normalized[2]
            if p_bias is not None:
                bias = p_bias * 0.2 + bias * 0.8
            p_bias = bias
            M[:, 2] = M[:, 2] + bias
        return M.cpu().numpy(), p_bias