# ghost/utils/inference/image_processing.py
# Jagrut Thakare
# v1
# 9be8aa9
import base64
from io import BytesIO
from typing import Callable, List
import numpy as np
import torch
import cv2
from .masks import face_mask_static
from matplotlib import pyplot as plt
from insightface.utils import face_align
def crop_face(image_full: np.ndarray, app: Callable, crop_size: int) -> List[np.ndarray]:
    """
    Crop the face out of an image and warp it to a square crop.

    Args:
        image_full: Image the face is taken from.
        app: Detector object. ``app.get(image_full, crop_size)`` must return
            something indexable; its first item is used as the keypoints
            passed to ``face_align.estimate_norm`` (presumably the first
            detected face -- confirm against the detector implementation).
        crop_size: Side length of the square output crop.

    Returns:
        A one-element list containing the aligned ``crop_size`` x
        ``crop_size`` crop. The list wrapper is part of the existing
        contract and is kept as-is.

    Note:
        No check is made that a face was found; whatever ``kps[0]`` raises
        for a missing detection propagates to the caller.
    """
    kps = app.get(image_full, crop_size)
    # estimate_norm returns a pair; only the affine matrix is needed here.
    M, _ = face_align.estimate_norm(kps[0], crop_size, mode='None')
    align_img = cv2.warpAffine(
        image_full, M, (crop_size, crop_size), borderValue=0.0
    )
    return [align_img]
def normalize_and_torch(image: np.ndarray, device: str = "cuda") -> torch.Tensor:
    """
    Normalize a single image and convert it to a batched torch tensor.

    Args:
        image: Image array indexed as (height, width, channels); axis 2 is
            moved to the front by the permute below.
        device: Torch device the tensor is moved to. Defaults to ``"cuda"``,
            which matches the previous hard-coded ``.cuda()`` behaviour;
            pass ``"cpu"`` to run without a GPU.

    Returns:
        A float32 tensor of shape (1, channels, height, width) with values
        mapped through ``(x - 0.5) / 0.5``.
    """
    # .copy() yields a contiguous array, so views with negative strides
    # (e.g. a channel-reversed image) can be converted as well.
    tensor = torch.tensor(image.copy(), dtype=torch.float32).to(device)
    # Heuristic: any value above 1 means the data is on a 0..255 scale.
    # An image whose maximum is <= 1 is left unscaled.
    if tensor.max() > 1.0:
        tensor = tensor / 255.0
    tensor = tensor.permute(2, 0, 1).unsqueeze(0)
    return (tensor - 0.5) / 0.5
def normalize_and_torch_batch(frames: np.ndarray, device: str = "cuda") -> torch.Tensor:
    """
    Normalize a batch of images and convert it to a torch tensor.

    Args:
        frames: Batch array indexed as (batch, height, width, channels);
            the permute below moves axis 3 to position 1.
        device: Torch device the tensor is moved to. Defaults to ``"cuda"``,
            which matches the previous hard-coded ``.cuda()`` behaviour;
            pass ``"cpu"`` to run without a GPU.

    Returns:
        A tensor of shape (batch, channels, height, width) with values
        mapped through ``(x - 0.5) / 0.5``. Unlike ``normalize_and_torch``
        no dtype is forced here: the result dtype follows torch's type
        promotion of the input dtype.
    """
    batch = torch.from_numpy(frames.copy()).to(device)
    # Heuristic: any value above 1 means the data is on a 0..255 scale.
    # A batch whose maximum is <= 1 is left unscaled.
    if batch.max() > 1.0:
        batch = batch / 255.0
    batch = batch.permute(0, 3, 1, 2)
    return (batch - 0.5) / 0.5
def get_final_image(final_frames: List[np.ndarray],
                    crop_frames: List[np.ndarray],
                    full_frame: np.ndarray,
                    tfm_arrays: List[np.ndarray],
                    handler) -> np.ndarray:
    """
    Blend generated face crops back into the full image.

    Args:
        final_frames: One entry per face; element ``[0]`` of each entry is
            the generated face image.
        crop_frames: One entry per face; element ``[0]`` of each entry is
            the corresponding crop taken from the full image.
        full_frame: Image the faces are pasted into. It is not modified;
            a copy is blended and returned.
        tfm_arrays: One entry per face; element ``[0]`` of each entry is the
            affine matrix that is inverted to map the crop back onto
            ``full_frame``.
        handler: Object providing
            ``get_without_detection_without_transform(image)``, whose result
            is passed to ``face_mask_static`` as landmarks.

    Returns:
        The blended image as a ``uint8`` array with the same height and
        width as ``full_frame``. With no faces this is a ``uint8`` copy of
        ``full_frame``.
    """
    final = full_frame.copy()
    # cv2.warpAffine takes the destination size as (width, height).
    full_size = (full_frame.shape[1], full_frame.shape[0])

    for i, swapped in enumerate(final_frames):
        # Fixed 224x224 working size -- presumably the size of the crops in
        # crop_frames / tfm_arrays; TODO confirm against the caller.
        frame = cv2.resize(swapped[0], (224, 224))
        target_crop = crop_frames[i][0]

        landmarks = handler.get_without_detection_without_transform(frame)
        landmarks_tgt = handler.get_without_detection_without_transform(target_crop)

        # The fourth argument was always None (it came from a list of Nones),
        # so it is passed directly.
        mask, _ = face_mask_static(target_crop, landmarks, landmarks_tgt, None)
        mat_rev = cv2.invertAffineTransform(tfm_arrays[i][0])

        swap_t = cv2.warpAffine(frame, mat_rev, full_size,
                                borderMode=cv2.BORDER_REPLICATE)
        mask_t = cv2.warpAffine(mask, mat_rev, full_size)
        # Add a trailing axis so the mask broadcasts over the channels.
        mask_t = np.expand_dims(mask_t, 2)

        # Blend: generated face where the mask is set, running result elsewhere.
        final = mask_t * swap_t + (1 - mask_t) * final

    return np.array(final, dtype='uint8')
def show_images(images: List[np.ndarray],
                titles=None,
                figsize=(20, 5),
                fontsize=15):
    """
    Plot images side by side in a single row of matplotlib axes.

    Args:
        images: Images indexed as (height, width, channels). The channel
            axis is reversed before display (presumably BGR -> RGB for
            OpenCV images -- confirm against the callers).
        titles: Optional sequence of titles, one per image.
        figsize: Figure size passed to ``plt.subplots``.
        fontsize: Font size used for the titles.

    Raises:
        AssertionError: If ``titles`` is given and its length differs from
            the number of images.
    """
    if titles:
        assert len(titles) == len(images), "Amount of images should be the same as the amount of titles"

    # squeeze=False always yields a 2-D array of axes. With the default,
    # a single image returns a bare Axes object and iterating over it fails.
    _fig, axes = plt.subplots(1, len(images), figsize=figsize, squeeze=False)
    for idx, (ax, image) in enumerate(zip(axes[0], images)):
        ax.imshow(image[:, :, ::-1])
        if titles:
            ax.set_title(titles[idx], fontsize=fontsize)
        ax.axis("off")