# ghostv1/utils/inference/image_processing.py
import base64
from io import BytesIO
from typing import Callable, List
import numpy as np
import torch
import cv2
from .masks import face_mask_static
from matplotlib import pyplot as plt
from insightface.utils import face_align


def crop_face(image_full: np.ndarray, app: Callable, crop_size: int) -> List[np.ndarray]:
    """
    Detect the face in the image, align it and crop it to crop_size x crop_size.
    """
    # Detect facial keypoints with the provided detector wrapper.
    kps = app.get(image_full, crop_size)
    # Estimate the affine transform that aligns the first detected face.
    M, _ = face_align.estimate_norm(kps[0], crop_size, mode='None')
    # Warp the full image with that transform to get the aligned crop.
    align_img = cv2.warpAffine(image_full, M, (crop_size, crop_size), borderValue=0.0)
    return [align_img]


def normalize_and_torch(image: np.ndarray) -> torch.Tensor:
    """
    Normalize an image to [-1, 1] and convert it to a CUDA torch tensor in NCHW layout.
    """
    image = torch.tensor(image.copy(), dtype=torch.float32).cuda()
    # Scale pixel values to [0, 1] if the input is in [0, 255].
    if image.max() > 1.:
        image = image / 255.
    # HWC -> CHW and add a batch dimension.
    image = image.permute(2, 0, 1).unsqueeze(0)
    # Shift and scale from [0, 1] to [-1, 1].
    image = (image - 0.5) / 0.5
    return image


def normalize_and_torch_batch(frames: np.ndarray) -> torch.Tensor:
    """
    Normalize a batch of images to [-1, 1] and convert it to a CUDA torch tensor in NCHW layout.
    """
    batch_frames = torch.from_numpy(frames.copy()).cuda()
    # Scale pixel values to [0, 1] if the input is in [0, 255].
    if batch_frames.max() > 1.:
        batch_frames = batch_frames / 255.
    # NHWC -> NCHW.
    batch_frames = batch_frames.permute(0, 3, 1, 2)
    # Shift and scale from [0, 1] to [-1, 1].
    batch_frames = (batch_frames - 0.5) / 0.5
    return batch_frames
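
# A quick sanity check of the normalization convention used above (kept as a
# comment because the call requires a CUDA device): an 8-bit pixel value v is
# mapped to (v / 255 - 0.5) / 0.5, so 0 -> -1.0, 127.5 -> 0.0 and 255 -> +1.0.
#
#   img = np.full((4, 4, 3), 255, dtype=np.uint8)
#   t = normalize_and_torch(img)
#   # t.shape == (1, 3, 4, 4); every value equals 1.0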


def get_final_image(final_frames: List[np.ndarray],
                    crop_frames: List[np.ndarray],
                    full_frame: np.ndarray,
                    tfm_arrays: List[np.ndarray],
                    handler) -> np.ndarray:
    """
    Blend the swapped face crops back into the full frame using soft face masks.
    """
    final = full_frame.copy()
    params = [None for i in range(len(final_frames))]

    for i in range(len(final_frames)):
        frame = cv2.resize(final_frames[i][0], (224, 224))

        # Landmarks of the swapped crop and of the original target crop.
        landmarks = handler.get_without_detection_without_transform(frame)
        landmarks_tgt = handler.get_without_detection_without_transform(crop_frames[i][0])

        # Soft mask covering the face region of the target crop.
        mask, _ = face_mask_static(crop_frames[i][0], landmarks, landmarks_tgt, params[i])

        # Invert the crop transform to map the swap and the mask back into full-frame coordinates.
        mat_rev = cv2.invertAffineTransform(tfm_arrays[i][0])
        swap_t = cv2.warpAffine(frame, mat_rev, (full_frame.shape[1], full_frame.shape[0]), borderMode=cv2.BORDER_REPLICATE)
        mask_t = cv2.warpAffine(mask, mat_rev, (full_frame.shape[1], full_frame.shape[0]))
        mask_t = np.expand_dims(mask_t, 2)

        # Alpha-blend: masked pixels come from the swap, the rest stay untouched.
        final = mask_t * swap_t + (1 - mask_t) * final

    final = np.array(final, dtype='uint8')
    return final


def show_images(images: List[np.ndarray],
                titles=None,
                figsize=(20, 5),
                fontsize=15):
    """
    Show BGR images side by side in a single matplotlib row.
    """
    if titles:
        assert len(titles) == len(images), "Number of titles should match the number of images"
    fig, axes = plt.subplots(1, len(images), figsize=figsize)
    # plt.subplots returns a bare Axes for a single image; make it iterable.
    axes = np.atleast_1d(axes)
    for idx, (ax, image) in enumerate(zip(axes, images)):
        # OpenCV images are BGR; flip channels to RGB for display.
        ax.imshow(image[:, :, ::-1])
        if titles:
            ax.set_title(titles[idx], fontsize=fontsize)
        ax.axis("off")
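

# --- Illustrative usage (a minimal sketch, not part of the module API) -------
# The names below are placeholders: `detector` stands for the project's face
# detection wrapper exposing .get(image, crop_size), and the actual swap-model
# call is omitted. The snippet only shows how the helpers in this file are
# meant to be chained together.
#
#   source = cv2.imread("source.jpg")
#   target = cv2.imread("target.jpg")
#
#   # Align and crop both faces to 224x224.
#   source_crop = crop_face(source, detector, 224)[0]
#   target_crop = crop_face(target, detector, 224)[0]
#
#   # [-1, 1] CUDA tensors in NCHW layout for the model.
#   source_tensor = normalize_and_torch(source_crop)
#   target_tensor = normalize_and_torch(target_crop)
#
#   show_images([source_crop, target_crop], titles=["source", "target"])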