Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| os.environ['CUDA_VISIBLE_DEVICES'] = '0' | |
| sys.path.append(os.getcwd()) | |
| from glob import glob | |
| import numpy as np | |
| import json | |
| import smplx as smpl | |
| from nets import * | |
| from repro_nets import * | |
| from trainer.options import parse_args | |
| from data_utils import torch_data | |
| from trainer.config import load_JsonConfig | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils import data | |
def init_model(model_name, model_path, args, config):
    """Build the generator selected by ``model_name`` and load its checkpoint.

    Args:
        model_name: Architecture key. Supported values are ``'freeMo'``,
            ``'smplx_S2G'``, ``'StyleGestures'``, ``'Audio2Gestures'``,
            ``'S2G'`` and ``'Tmpt'``.
        model_path: Path of the checkpoint file handed to ``torch.load``.
        args: Parsed command-line options, forwarded to the generator
            constructor.
        config: Loaded configuration object, forwarded to the generator
            constructor. For ``'Audio2Gestures'`` its
            ``Train.using_mspec_stat`` attribute is set to ``False`` in place.

    Returns:
        The constructed generator with the checkpoint weights loaded.

    Raises:
        NotImplementedError: If ``model_name`` is not one of the supported
            keys. The message names the rejected value.
    """
    if model_name == 'freeMo':
        generator = freeMo_dev(args, config)
    elif model_name == 'smplx_S2G':
        generator = smplx_S2G(args, config)
    elif model_name == 'StyleGestures':
        generator = StyleGesture_Generator(args, config)
    elif model_name == 'Audio2Gestures':
        config.Train.using_mspec_stat = False
        # NOTE(review): the zeros/ones tensors look like placeholder
        # normalization statistics (zero mean, unit std) -- confirm against
        # the Audio2Gesture_Generator signature.
        generator = Audio2Gesture_Generator(
            args,
            config,
            torch.zeros([1, 1, 108]),
            torch.ones([1, 1, 108]),
        )
    elif model_name in ('S2G', 'Tmpt'):
        # Both keys are served by the same generator class.
        generator = S2G_Generator(args, config)
    else:
        raise NotImplementedError(
            'unsupported model name: {!r}'.format(model_name)
        )

    # Deserialize onto the CPU regardless of the device the file was saved from.
    model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
    if model_name == 'smplx_S2G':
        # This model keeps its weights one level deeper, on the inner
        # ``generator`` attribute.
        generator.generator.load_state_dict(model_ckpt['generator']['generator'])
    elif 'generator' in model_ckpt:
        generator.load_state_dict(model_ckpt['generator'])
    else:
        # No 'generator' entry: the whole checkpoint is wrapped under that key
        # before being passed on. Presumably the generator's load_state_dict
        # expects this layout -- verify against the generator class.
        generator.load_state_dict({'generator': model_ckpt})
    return generator
def prevar_loader(data_root, speakers, args, config, model_path, device, generator):
    """Encode the 'pre' split and store the mean/std of the resulting codes.

    Every batch of the 'pre' split is cut into 15-step slices, passed through
    ``generator.generator.pre_pose_encoder``, and the per-dimension mean and
    standard deviation over all codes are saved to ``pre_variable.npy``.

    Args:
        data_root: Dataset root forwarded to ``torch_data``.
        speakers: Speaker selection forwarded to ``torch_data``.
        args: Unused here; kept for the caller's signature.
        config: Configuration object; ``Data.pose``, ``Data.aud`` and
            ``DataLoader.batch_size`` are read from it.
        model_path: Checkpoint path. The output file is placed in the
            directory of the text preceding the first ``'ckpt'`` in it.
        device: Device the batch tensors are moved to before encoding.
        generator: Object exposing ``generator.pre_pose_encoder``.

    Returns:
        Tuple ``(mean, std)`` of NumPy arrays computed over axis 0 of the
        concatenated codes.
    """
    # Derive the output location from the part of the path before 'ckpt'.
    prefix = model_path.split('ckpt')[0]
    out_file = os.path.join(os.path.dirname(prefix), "pre_variable.npy")

    pose_cfg = config.Data.pose
    aud_cfg = config.Data.aud
    dataset_builder = torch_data(
        data_root=data_root,
        speakers=speakers,
        split='pre',
        limbscaling=False,
        normalization=pose_cfg.normalization,
        norm_method=pose_cfg.norm_method,
        split_trans_zero=False,
        num_pre_frames=pose_cfg.pre_pose_length,
        num_generate_length=pose_cfg.generate_length,
        num_frames=15,
        aud_feat_win_size=aud_cfg.aud_feat_win_size,
        aud_feat_dim=aud_cfg.aud_feat_dim,
        feat_method=aud_cfg.feat_method,
        smplx=True,
        audio_sr=22000,
        convert_to_6d=pose_cfg.convert_to_6d,
        expression=pose_cfg.expression,
    )
    dataset_builder.get_dataset()
    loader = data.DataLoader(
        dataset_builder.all_dataset,
        batch_size=config.DataLoader.batch_size,
        shuffle=False,
        drop_last=True,
    )

    def _stack_slices(seq):
        # Cut dim 1 into four 15-step slices plus the remainder and stack them
        # along dim 0. The pieces must end up the same length for cat to work.
        pieces = [seq[:, :15], seq[:, 15:30], seq[:, 30:45], seq[:, 45:60], seq[:, 60:]]
        return torch.cat(pieces, dim=0)

    codes = []
    with torch.no_grad():
        for batch in loader:
            # Swap the last two axes so the slicing above runs over dim 1.
            # assumes batches arrive as (batch, features, steps) -- TODO confirm
            body = batch['poses'].to(device).to(torch.float32).permute(0, 2, 1)
            face = batch['expression'].to(device).to(torch.float32).permute(0, 2, 1)

            merged = torch.cat([_stack_slices(body), _stack_slices(face)], dim=-1)
            # Flatten everything after dim 0 and add a trailing singleton axis.
            merged = merged.reshape(merged.shape[0], -1, 1)

            encoded = generator.generator.pre_pose_encoder(merged)
            encoded = encoded.squeeze().detach().cpu()
            codes.append(np.asarray(encoded))

    stacked = np.concatenate(codes, axis=0)
    mean = np.mean(stacked, axis=0)
    std = np.std(stacked, axis=0)
    np.save(out_file, (mean, std), allow_pickle=True)
    return mean, std
def main():
    """Compute and store pre-pose statistics for the configured model.

    Parses the command-line options, loads the JSON config named by
    ``args.config_file``, builds the generator from ``args.model_path`` and
    then runs ``prevar_loader``, which writes ``pre_variable.npy``.
    """
    parser = parse_args()
    args = parser.parse_args()

    device = torch.device(args.gpu)
    # torch.cuda.set_device only accepts CUDA devices; skip it for anything
    # else so a CPU run does not fail before any work is done.
    if device.type == 'cuda':
        torch.cuda.set_device(device)

    config = load_JsonConfig(args.config_file)

    print('init model...')
    generator = init_model(config.Model.model_name, args.model_path, args, config)

    print('init pre-pose vectors...')
    # The returned (mean, std) is not needed here; the call is made for the
    # file it writes.
    prevar_loader(
        config.Data.data_root,
        args.speakers,
        args,
        config,
        args.model_path,
        device,
        generator,
    )


# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    main()