Spaces:
Runtime error
Runtime error
# coding: utf-8
"""
Utility functions and classes to handle feature extraction and model loading.
"""
import os | |
import os.path as osp | |
import torch | |
from collections import OrderedDict | |
import numpy as np | |
from scipy.spatial import ConvexHull # pylint: disable=E0401,E0611 | |
from typing import Union | |
import cv2 | |
from ..modules.spade_generator import SPADEDecoder | |
from ..modules.warping_network import WarpingNetwork | |
from ..modules.motion_extractor import MotionExtractor | |
from ..modules.appearance_feature_extractor import AppearanceFeatureExtractor | |
from ..modules.stitching_retargeting_network import StitchingRetargetingNetwork | |
def tensor_to_numpy(data: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """Convert a ``torch.Tensor`` into a ``numpy.ndarray``.

    Args:
        data: A tensor or an array.

    Returns:
        For a tensor, its values detached from the autograd graph, moved to
        host memory and exposed as a NumPy array. Any other input is returned
        unchanged.
    """
    if isinstance(data, torch.Tensor):
        # ``detach()`` is the supported way to drop autograd tracking; the
        # legacy ``.data`` accessor hides in-place modifications from autograd.
        return data.detach().cpu().numpy()
    return data
def calc_motion_multiplier(
    kp_source: Union[np.ndarray, torch.Tensor],
    kp_driving_initial: Union[np.ndarray, torch.Tensor]
) -> float:
    """Return the motion multiplier between the source and the first driving frame.

    Both keypoint sets are converted to NumPy, their leading axis is squeezed
    away, and a convex hull is built around each. The multiplier is
    ``sqrt(source hull volume) / sqrt(driving hull volume)``.
    """
    source_pts, driving_pts = (
        tensor_to_numpy(keypoints) for keypoints in (kp_source, kp_driving_initial)
    )
    source_extent = ConvexHull(source_pts.squeeze(0)).volume
    driving_extent = ConvexHull(driving_pts.squeeze(0)).volume
    # NOTE(review): a cube-root variant (np.cbrt on both hull volumes) was left
    # disabled in the original; confirm which root matches the keypoint dimensionality.
    return np.sqrt(source_extent) / np.sqrt(driving_extent)
def suffix(filename):
    """Return the text after the last dot: ``a.jpg -> jpg``.

    An empty string is returned when ``filename`` contains no dot.
    """
    _, dot, extension = filename.rpartition(".")
    return extension if dot else ""
def prefix(filename):
    """Return the text before the last dot: ``a.jpg -> a``.

    The input is returned unchanged when it contains no dot.
    """
    stem, dot, _ = filename.rpartition(".")
    return stem if dot else filename
def basename(filename):
    """Return the final path component without its extension: ``a/b/c.jpg -> c``."""
    last_component = osp.basename(filename)
    return prefix(last_component)
def remove_suffix(filepath):
    """Return the path with the file extension dropped: ``a/b/c.jpg -> a/b/c``."""
    parent_dir = osp.dirname(filepath)
    stem = basename(filepath)
    return osp.join(parent_dir, stem)
def is_image(file_path):
    """Return True when the path ends with a known image extension (case-insensitive)."""
    lowered = file_path.lower()
    return lowered.endswith(
        ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp')
    )
def is_video(file_path):
    """Return True for a known video extension (case-insensitive) or an existing directory.

    NOTE(review): directories presumably stand for folders of frames; confirm with callers.
    """
    has_video_extension = file_path.lower().endswith((".mp4", ".mov", ".avi", ".webm"))
    return has_video_extension or osp.isdir(file_path)
def is_template(file_path):
    """Return True when the path ends with ``.pkl`` (case-sensitive)."""
    return file_path.endswith(".pkl")
def mkdir(d, log=False):
    """Create directory ``d`` if nothing exists at that path, then return ``d``.

    Returning the path lets callers create and assign it in a single
    expression. When ``log`` is truthy, a message is printed only if this call
    actually created the directory.
    """
    if osp.exists(d):
        return d
    os.makedirs(d, exist_ok=True)
    if log:
        print(f"Make dir: {d}")
    return d
def squeeze_tensor_to_numpy(tensor):
    """Drop a size-1 leading axis from ``tensor`` and return it as a NumPy array.

    The tensor is detached from the autograd graph, squeezed along axis 0
    (a no-op when that axis is not of size 1), moved to host memory and
    converted to ``numpy.ndarray``.
    """
    # ``detach()`` replaces the legacy ``.data`` accessor.
    return tensor.detach().squeeze(0).cpu().numpy()
def dct2device(dct: dict, device):
    """Move every value of ``dct`` onto ``device``, in place, and return ``dct``.

    Values that are not already tensors are first wrapped with ``torch.tensor``.
    """
    for name, value in dct.items():
        as_tensor = value if isinstance(value, torch.Tensor) else torch.tensor(value)
        # Re-assigning an existing key does not resize the dict, so this is
        # safe while iterating.
        dct[name] = as_tensor.to(device)
    return dct
def concat_feat(kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
    """Flatten two keypoint batches and concatenate them along the feature axis.

    Args:
        kp_source: (bs, k, 3)
        kp_driving: (bs, k, 3)

    Returns:
        Tensor of shape (bs, 2k*3): the flattened source keypoints followed by
        the flattened driving keypoints.

    Raises:
        AssertionError: if the two inputs have different batch sizes.
    """
    bs_src = kp_source.shape[0]
    bs_dri = kp_driving.shape[0]
    assert bs_src == bs_dri, 'batch size must be equal'

    # ``reshape`` also accepts non-contiguous inputs (e.g. transposed or sliced
    # tensors), where ``view`` raises; for contiguous inputs it is equivalent.
    flat_source = kp_source.reshape(bs_src, -1)
    flat_driving = kp_driving.reshape(bs_dri, -1)
    return torch.cat([flat_source, flat_driving], dim=1)
def remove_ddp_dumplicate_key(state_dict):
    """Return a new ``OrderedDict`` whose keys have every ``'module.'`` removed.

    Entry order follows ``state_dict``; values are carried over unchanged.

    NOTE(review): the substring is removed wherever it occurs in a key, not
    only as a leading prefix — confirm that no checkpoint key relies on an
    inner ``module.`` segment.
    """
    return OrderedDict(
        (name.replace('module.', ''), value) for name, value in state_dict.items()
    )
def load_model(ckpt_path, model_config, device, model_type):
    """Build the network(s) for ``model_type``, load weights and switch to eval mode.

    Args:
        ckpt_path: Path of the checkpoint file passed to ``torch.load``.
        model_config: Mapping; the constructor arguments are read from
            ``model_config['model_params'][f'{model_type}_params']``.
        device: Device the network(s) are moved to.
        model_type: One of ``'appearance_feature_extractor'``,
            ``'motion_extractor'``, ``'warping_module'``, ``'spade_generator'``
            or ``'stitching_retargeting_module'``.

    Returns:
        The loaded network in eval mode. For
        ``'stitching_retargeting_module'`` a dict with the keys
        ``'stitching'``, ``'lip'`` and ``'eye'``, each holding a loaded
        ``StitchingRetargetingNetwork`` in eval mode.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """
    single_network_classes = {
        'appearance_feature_extractor': AppearanceFeatureExtractor,
        'motion_extractor': MotionExtractor,
        'warping_module': WarpingNetwork,
        'spade_generator': SPADEDecoder,
    }
    is_stitching_bundle = model_type == 'stitching_retargeting_module'

    # Validate before touching the config so that an unsupported type reports
    # the ValueError below instead of a KeyError from the config lookup.
    if not is_stitching_bundle and model_type not in single_network_classes:
        raise ValueError(f"Unknown model type: {model_type}")

    model_params = model_config['model_params'][f'{model_type}_params']

    # NOTE(review): torch.load unpickles the checkpoint; only load files from
    # trusted sources.
    if is_stitching_bundle:
        # One checkpoint file stores three sub-networks under their own keys.
        checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
        sub_networks = (
            # (config / result key, checkpoint key)
            ('stitching', 'retarget_shoulder'),
            ('lip', 'retarget_mouth'),
            ('eye', 'retarget_eye'),
        )
        bundle = {}
        for name, ckpt_key in sub_networks:
            network = StitchingRetargetingNetwork(**model_params.get(name))
            network.load_state_dict(remove_ddp_dumplicate_key(checkpoint[ckpt_key]))
            network = network.to(device)
            network.eval()
            bundle[name] = network
        return bundle

    model = single_network_classes[model_type](**model_params).to(device)
    # Tensors are first mapped to CPU storage; load_state_dict then copies them
    # into the parameters already placed on ``device``.
    model.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage))
    model.eval()
    return model
def load_description(fp):
    """Read the file at ``fp`` as UTF-8 text and return its full contents."""
    with open(fp, 'r', encoding='utf-8') as handle:
        return handle.read()
def is_square_video(video_path):
    """Return True when the video's reported frame width equals its frame height.

    NOTE(review): the capture is not checked for having opened successfully;
    presumably an unreadable path reports 0 for both properties and therefore
    compares equal — confirm against callers.
    """
    capture = cv2.VideoCapture(video_path)
    frame_width, frame_height = (
        int(capture.get(prop))
        for prop in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT)
    )
    capture.release()
    # A UI notice about forcing a crop for non-square videos was left disabled here.
    return frame_width == frame_height
def clean_state_dict(state_dict):
    """Return a new ``OrderedDict`` with one leading ``'module.'`` stripped from each key.

    Keys that do not start with ``'module.'`` are copied unchanged; entry
    order follows ``state_dict``.
    """
    marker = 'module.'
    cleaned = OrderedDict()
    for name, value in state_dict.items():
        target = name[len(marker):] if name.startswith(marker) else name
        cleaned[target] = value
    return cleaned