|
import os |
|
import sys |
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0' |
|
|
|
sys.path.append(os.getcwd()) |
|
from glob import glob |
|
|
|
import numpy as np |
|
import json |
|
import smplx as smpl |
|
|
|
from nets import * |
|
from repro_nets import * |
|
from trainer.options import parse_args |
|
from data_utils import torch_data |
|
from trainer.config import load_JsonConfig |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.utils import data |
|
|
|
def init_model(model_name, model_path, args, config):
    """Build the generator wrapper selected by ``model_name`` and load its checkpoint.

    Args:
        model_name: Which wrapper to build. Supported values are ``'freeMo'``,
            ``'smplx_S2G'``, ``'StyleGestures'``, ``'Audio2Gestures'``, ``'S2G'``
            and ``'Tmpt'`` (the last two both build an ``S2G_Generator``).
        model_path: Checkpoint path, read with ``torch.load`` onto the CPU.
        args: Parsed command-line arguments, forwarded to the wrapper.
        config: Experiment config, forwarded to the wrapper. For
            ``'Audio2Gestures'`` this function sets
            ``config.Train.using_mspec_stat = False`` on the object passed in.

    Returns:
        The constructed wrapper with the checkpoint weights loaded.

    Raises:
        NotImplementedError: If ``model_name`` is not one of the supported values.
    """
    if model_name == 'freeMo':
        generator = freeMo_dev(args, config)
    elif model_name == 'smplx_S2G':
        generator = smplx_S2G(args, config)
    elif model_name == 'StyleGestures':
        generator = StyleGesture_Generator(args, config)
    elif model_name == 'Audio2Gestures':
        # Note: this mutates the caller's config object.
        config.Train.using_mspec_stat = False
        # NOTE(review): the zeros/ones tensors of shape (1, 1, 108) look like
        # placeholder normalization statistics -- confirm against
        # Audio2Gesture_Generator's constructor.
        generator = Audio2Gesture_Generator(
            args,
            config,
            torch.zeros([1, 1, 108]),
            torch.ones([1, 1, 108]),
        )
    elif model_name in ('S2G', 'Tmpt'):
        # Both names build the same wrapper class.
        generator = S2G_Generator(args, config)
    else:
        raise NotImplementedError(f"unsupported model_name: {model_name!r}")

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
    if model_name == 'smplx_S2G':
        # smplx_S2G checkpoints nest the inner network's weights one level deeper
        # and are loaded straight into the wrapped `.generator` module.
        generator.generator.load_state_dict(model_ckpt['generator']['generator'])
    elif 'generator' in model_ckpt:
        generator.load_state_dict(model_ckpt['generator'])
    else:
        # Presumably a bare state dict: wrap it under 'generator', the layout the
        # wrapper's load_state_dict appears to expect (confirm in the wrapper class).
        generator.load_state_dict({'generator': model_ckpt})

    return generator
|
|
|
|
|
|
|
def _split_time_windows(x, window=15):
    """Cut ``x`` into consecutive ``window``-long chunks along dim 1 and stack them along dim 0.

    The result is ordered chunk-major: all batch items of the first chunk, then
    all of the second, and so on. When dim 1 has length 75 and ``window`` is 15,
    this equals concatenating ``x[:, :15], x[:, 15:30], x[:, 30:45], x[:, 45:60],
    x[:, 60:]`` along dim 0. ``torch.cat`` raises if the chunks differ in length,
    i.e. when dim 1 is longer than ``window`` and not a multiple of it.
    """
    return torch.cat(torch.split(x, window, dim=1), dim=0)


def prevar_loader(data_root, speakers, args, config, model_path, device, generator):
    """Encode the 'pre' split's pose windows and save mean/std of the resulting codes.

    Builds a ``torch_data`` dataset with ``split='pre'``. Under ``torch.no_grad()``,
    each batch's poses and expressions are cut into 15-frame windows, concatenated
    along the last dim, flattened and passed through
    ``generator.generator.pre_pose_encoder``. The per-dimension mean and standard
    deviation (over axis 0) of all codes are saved as ``(mean, std)`` with
    ``np.save`` to ``pre_variable.npy``. That file goes in
    ``os.path.dirname`` of the part of ``model_path`` before its last ``'ckpt'``.

    Args:
        data_root: Dataset root, forwarded to ``torch_data``.
        speakers: Speaker list, forwarded to ``torch_data``.
        args: Unused here; kept for interface compatibility.
        config: Experiment config supplying the pose/audio data settings and
            ``DataLoader.batch_size``.
        model_path: Checkpoint path used to locate the output directory.
        device: Device that each batch is moved to.
        generator: Wrapper whose ``generator.pre_pose_encoder`` produces the codes.

    Returns:
        ``(mean, std)`` numpy arrays of the encoder codes.

    Raises:
        ValueError: If the loader yields no batches. Because ``drop_last=True``,
            this happens when the split has fewer samples than the batch size.
    """
    # Split on the LAST 'ckpt' so that a 'ckpt' substring earlier in the path
    # (e.g. '/mnt/ckpts/run/ckpt-99.pth') does not redirect the output elsewhere.
    run_prefix = model_path.rsplit('ckpt', 1)[0]
    out_file = os.path.join(os.path.dirname(run_prefix), "pre_variable.npy")

    data_base = torch_data(
        data_root=data_root,
        speakers=speakers,
        split='pre',
        limbscaling=False,
        normalization=config.Data.pose.normalization,
        norm_method=config.Data.pose.norm_method,
        split_trans_zero=False,
        num_pre_frames=config.Data.pose.pre_pose_length,
        num_generate_length=config.Data.pose.generate_length,
        num_frames=15,
        aud_feat_win_size=config.Data.aud.aud_feat_win_size,
        aud_feat_dim=config.Data.aud.aud_feat_dim,
        feat_method=config.Data.aud.feat_method,
        smplx=True,
        audio_sr=22000,
        convert_to_6d=config.Data.pose.convert_to_6d,
        expression=config.Data.pose.expression
    )
    data_base.get_dataset()
    pre_set = data_base.all_dataset
    pre_loader = data.DataLoader(pre_set, batch_size=config.DataLoader.batch_size, shuffle=False, drop_last=True)

    # NOTE(review): pre_pose_encoder is not switched to eval mode here; if it
    # contains dropout/batch-norm layers, the statistics depend on train mode.
    total_pose = []
    with torch.no_grad():
        for bat in pre_loader:
            pose = bat['poses'].to(device).to(torch.float32)
            expression = bat['expression'].to(device).to(torch.float32)

            # After permute, dim 1 is sliced into 15-step windows, so it is
            # presumably time (inputs (B, C, T) -> (B, T, C)) -- confirm with torch_data.
            pose = _split_time_windows(pose.permute(0, 2, 1))
            expression = _split_time_windows(expression.permute(0, 2, 1))

            pose = torch.cat([pose, expression], dim=-1)
            pose = pose.reshape(pose.shape[0], -1, 1)

            # NOTE(review): squeeze() with no dim drops every size-1 dim; this
            # assumes the code dimension itself is never 1.
            pose_code = generator.generator.pre_pose_encoder(pose).squeeze().detach().cpu()
            total_pose.append(pose_code.numpy())

    if not total_pose:
        raise ValueError(
            "prevar_loader: the 'pre' split produced no full batches "
            "(fewer samples than DataLoader.batch_size with drop_last=True)"
        )

    total_pose = np.concatenate(total_pose, axis=0)
    mean = np.mean(total_pose, axis=0)
    std = np.std(total_pose, axis=0)

    np.save(out_file, (mean, std), allow_pickle=True)

    return mean, std
|
|
|
def main():
    """Command-line entry point: load the configured generator and compute pre-pose statistics.

    Parses the command-line arguments and loads the JSON config from
    ``args.config_file``. It then builds and loads the generator with
    ``init_model`` and runs ``prevar_loader``, which writes ``pre_variable.npy``.

    Returns:
        The ``(mean, std)`` arrays produced by ``prevar_loader``. The original
        discarded them; returning them lets programmatic callers use them.
    """
    parser = parse_args()
    args = parser.parse_args()

    device = torch.device(args.gpu)
    torch.cuda.set_device(device)

    config = load_JsonConfig(args.config_file)

    print('init model...')
    generator = init_model(config.Model.model_name, args.model_path, args, config)

    print('init pre-pose vectors...')
    mean, std = prevar_loader(config.Data.data_root, args.speakers, args, config, args.model_path, device, generator)
    return mean, std
|
|
|
# Only run when executed as a script, so importing this module (e.g. to reuse
# init_model or prevar_loader) does not start the whole pipeline.
if __name__ == '__main__':
    main()