# ghost/utils/training/image_processing.py
# Author: Jagrut Thakare
# v1 (commit 9be8aa9)
import cv2
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
# Preprocessing pipeline matching the ArcFace training normalization:
# converts a PIL image (HWC, [0, 255]) to a CHW float tensor in [0, 1],
# then maps it to [-1, 1] via (x - 0.5) / 0.5 per channel.
# torch2image() below applies the exact inverse of this normalization.
transformer_Arcface = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
def torch2image(torch_image: torch.Tensor) -> np.ndarray:
    """Convert a normalized CHW tensor (or NCHW batch) to a uint8 HWC image.

    Undoes the ``Normalize([0.5]*3, [0.5]*3)`` applied by
    ``transformer_Arcface`` (x * 0.5 + 0.5), scales to [0, 255] and returns
    a uint8 array with channels last.

    Args:
        torch_image: tensor of shape (3, H, W), or a batch (N, 3, H, W);
            only the first 8 images of a batch are used.

    Returns:
        (H, W, 3) uint8 array for a single image, or (H, min(N, 8) * W, 3)
        for a batch — the batch images laid side by side along the width.
    """
    is_batch = torch_image.dim() == 4
    if is_batch:
        # Cap the batch at 8 images to keep the preview strip a sane width.
        torch_image = torch_image[:8]

    device = torch_image.device
    # Inverse of the (x - 0.5) / 0.5 normalization; create the constants
    # directly on the input's device instead of building on CPU and moving.
    mean = torch.tensor([0.5, 0.5, 0.5], device=device).view(-1, 1, 1)
    std = torch.tensor([0.5, 0.5, 0.5], device=device).view(-1, 1, 1)
    denorm_image = std * torch_image + mean

    # Move channels last: NCHW -> NHWC, or CHW -> HWC.
    if is_batch:
        denorm_image = denorm_image.permute(0, 2, 3, 1)
    else:
        denorm_image = denorm_image.permute(1, 2, 0)

    np_image = denorm_image.detach().cpu().numpy()
    np_image = np.clip(np_image * 255., 0, 255).astype(np.uint8)

    if is_batch:
        # Concatenate the leading (batch) axis horizontally into one image.
        return np.concatenate(np_image, axis=1)
    return np_image
def make_image_list(images) -> np.ndarray:
    """Stack the given tensors vertically into one uint8 preview image.

    Each element of ``images`` is converted with ``torch2image`` and the
    results are concatenated along the height (row) axis.
    """
    converted = [torch2image(tensor) for tensor in images]
    return np.concatenate(converted, axis=0)
def read_torch_image(path: str) -> torch.tensor:
    """Load an image from disk as a normalized (1, 3, 256, 256) tensor.

    The file is read with OpenCV (BGR), resized to 256x256, converted to
    RGB, mapped to [-1, 1] by ``transformer_Arcface`` and given a leading
    batch dimension.
    """
    bgr = cv2.imread(path)
    bgr = cv2.resize(bgr, (256, 256))
    # OpenCV loads BGR; reverse the channel axis to hand PIL an RGB array.
    rgb = Image.fromarray(bgr[:, :, ::-1])
    tensor = transformer_Arcface(rgb)
    # Add the batch dimension: (3, 256, 256) -> (1, 3, 256, 256).
    return tensor.view(-1, *tensor.shape)
def get_faceswap(source_path: str, target_path: str,
                 G: 'generator model', netArc: 'arcface model',
                 device: 'torch device') -> np.ndarray:
    """Run a face swap and return a side-by-side comparison strip.

    Args:
        source_path: path to the identity (source) face image.
        target_path: path to the attribute (target) face image.
        G: generator network called as ``G(target, identity_embedding)``.
        netArc: ArcFace network producing identity embeddings.
        device: torch device both inputs are moved to.

    Returns:
        uint8 HWC array of [source | target | swapped] concatenated along
        the width axis, each panel 256x256.
    """
    source = read_torch_image(source_path).to(device)
    # BUG FIX: was `target.cuda()`, which ignored the `device` argument and
    # crashed on CPU-only machines; use the same device as `source`.
    target = read_torch_image(target_path).to(device)

    # Pure inference helper: run both networks under no_grad so no autograd
    # graph is built (the original only wrapped G, wasting memory on netArc).
    with torch.no_grad():
        # ArcFace expects 112x112 input.
        embeds = netArc(F.interpolate(source, [112, 112], mode='bilinear',
                                      align_corners=False))
        Yt, _ = G(target, embeds)

    Yt = torch2image(Yt)
    source = torch2image(source)
    target = torch2image(target)
    return np.concatenate((cv2.resize(source, (256, 256)), target, Yt), axis=1)