'''
Not exactly the same as the official repo, but the results are good.
'''
import os
import sys

sys.path.append(os.getcwd())  # must run before the project-local imports below

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math

from nets.base import TrainWrapperBaseClass
from nets.layers import SeqEncoder1D
from nets.utils import denormalize
from losses import KeypointLoss, L1Loss, KLLoss
from data_utils.lower_body import c_index_3d, c_index_6d
from data_utils.utils import get_melspec, get_mfcc_psf, get_mfcc_ta

class Conv1d_tf(nn.Conv1d):
    """
    Conv1d with the padding behavior from TF,
    modified from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
    """

    def __init__(self, *args, **kwargs):
        super(Conv1d_tf, self).__init__(*args, **kwargs)
        # Replace nn.Conv1d's numeric padding attribute with a TF-style
        # string; the actual padding is computed on the fly in forward().
        self.padding = kwargs.get("padding", "same")

    def _compute_padding(self, input, dim):
        input_size = input.size(dim + 2)
        filter_size = self.weight.size(dim + 2)
        effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
        out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
        total_padding = max(
            0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
        )
        additional_padding = int(total_padding % 2 != 0)
        return additional_padding, total_padding

    def forward(self, input):
        # NOTE: this comparison is case-sensitive; ConvNormRelu below passes
        # the lowercase 'valid', so in this file every convolution falls
        # through to the TF 'SAME'-padding branch.
        if self.padding == "VALID":
            return F.conv1d(
                input,
                self.weight,
                self.bias,
                self.stride,
                padding=0,
                dilation=self.dilation,
                groups=self.groups,
            )
        rows_odd, padding_rows = self._compute_padding(input, dim=0)
        if rows_odd:
            # TF pads the extra sample on the right when the total is odd.
            input = F.pad(input, [0, rows_odd])
        return F.conv1d(
            input,
            self.weight,
            self.bias,
            self.stride,
            padding=(padding_rows // 2),
            dilation=self.dilation,
            groups=self.groups,
        )
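

# A minimal sketch (not part of the model) of the padding behavior actually
# used in this file: the lowercase 'valid' string does not match "VALID" in
# forward(), so both convolutions below take the TF 'SAME' branch, giving
# output length ceil(T / stride) regardless of kernel size.
def _demo_conv1d_tf():
    x = torch.randn(2, 8, 15)                                       # (B, C, T)
    same = Conv1d_tf(8, 8, kernel_size=3, stride=1, padding='valid')
    assert same(x).shape == (2, 8, 15)                              # T preserved
    strided = Conv1d_tf(8, 8, kernel_size=4, stride=2, padding='valid')
    assert strided(x).shape == (2, 8, 8)                            # ceil(15 / 2)
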
def ConvNormRelu(in_channels, out_channels, type='1d', downsample=False, k=None, s=None, norm='bn', padding='valid'):
    if k is None and s is None:
        if not downsample:
            k = 3
            s = 1
        else:
            k = 4
            s = 2
    if type == '1d':
        conv_block = Conv1d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
        if norm == 'bn':
            norm_block = nn.BatchNorm1d(out_channels)
        elif norm == 'ln':
            # NOTE: nn.LayerNorm normalizes the last dimension, so this branch
            # would need channel-last input; it is unused in this file.
            norm_block = nn.LayerNorm(out_channels)
    elif type == '2d':
        # NOTE: Conv2d_tf is not defined in this file; only the '1d' path is
        # exercised here.
        conv_block = Conv2d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
        norm_block = nn.BatchNorm2d(out_channels)
    else:
        assert False
    return nn.Sequential(
        conv_block,
        norm_block,
        nn.LeakyReLU(0.2, True)
    )
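

# A minimal usage sketch: Conv1d_tf (effectively 'SAME' padding, see the note
# in Conv1d_tf.forward) -> BatchNorm1d -> LeakyReLU, keeping the time length.
def _demo_conv_norm_relu():
    block = ConvNormRelu(64, 128, '1d', downsample=False)   # k=3, s=1
    x = torch.randn(2, 64, 88)
    assert block(x).shape == (2, 128, 88)
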
class Decoder(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Decoder, self).__init__()
        self.up1 = nn.Sequential(
            ConvNormRelu(in_ch // 2 + in_ch, in_ch // 2),
            ConvNormRelu(in_ch // 2, in_ch // 2),
            nn.Upsample(scale_factor=2, mode='nearest')
        )
        self.up2 = nn.Sequential(
            ConvNormRelu(in_ch // 4 + in_ch // 2, in_ch // 4),
            ConvNormRelu(in_ch // 4, in_ch // 4),
            nn.Upsample(scale_factor=2, mode='nearest')
        )
        self.up3 = nn.Sequential(
            ConvNormRelu(in_ch // 8 + in_ch // 4, in_ch // 8),
            ConvNormRelu(in_ch // 8, in_ch // 8),
            nn.Conv1d(in_ch // 8, out_ch, 1, 1)
        )

    def forward(self, x, x1, x2, x3):
        x = F.interpolate(x, x3.shape[2])
        x = torch.cat([x, x3], dim=1)
        x = self.up1(x)
        x = F.interpolate(x, x2.shape[2])
        x = torch.cat([x, x2], dim=1)
        x = self.up2(x)
        x = F.interpolate(x, x1.shape[2])
        x = torch.cat([x, x1], dim=1)
        x = self.up3(x)
        return x
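

# A minimal shape sketch of the U-Net style decoding above: at each stage the
# input is resized to the matching encoder feature's length, concatenated with
# it as a skip connection, then refined. Feature shapes assume the
# EncoderDecoder below (skips of 128/256/512 channels at lengths T, T/2, T/4).
def _demo_decoder():
    B, T = 2, 88
    dec = Decoder(1024, 39)
    x = torch.randn(B, 1024, T // 4)    # upsampled bottleneck
    x1 = torch.randn(B, 128, T)         # skip from down1
    x2 = torch.randn(B, 256, T // 2)    # skip from down2
    x3 = torch.randn(B, 512, T // 4)    # skip from down3
    assert dec(x, x1, x2, x3).shape == (B, 39, T)
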
class EncoderDecoder(nn.Module):
    def __init__(self, n_frames, each_dim):
        super().__init__()
        self.n_frames = n_frames
        self.down1 = nn.Sequential(
            ConvNormRelu(64, 64, '1d', False),
            ConvNormRelu(64, 128, '1d', False),
        )
        self.down2 = nn.Sequential(
            ConvNormRelu(128, 128, '1d', False),
            ConvNormRelu(128, 256, '1d', False),
        )
        self.down3 = nn.Sequential(
            ConvNormRelu(256, 256, '1d', False),
            ConvNormRelu(256, 512, '1d', False),
        )
        self.down4 = nn.Sequential(
            ConvNormRelu(512, 512, '1d', False),
            ConvNormRelu(512, 1024, '1d', False),
        )
        self.down = nn.MaxPool1d(kernel_size=2)
        self.up = nn.Upsample(scale_factor=2, mode='nearest')

        # Separate decoder heads; the face head predicts jaw and expression
        # jointly, hence each_dim[0] + each_dim[3] output channels.
        self.face_decoder = Decoder(1024, each_dim[0] + each_dim[3])
        self.body_decoder = Decoder(1024, each_dim[1])
        self.hand_decoder = Decoder(1024, each_dim[2])

    def forward(self, spectrogram, time_steps=None):
        # time_steps is accepted for interface compatibility but unused: the
        # output length is set by the input length via the skip connections.
        if time_steps is None:
            time_steps = self.n_frames

        x1 = self.down1(spectrogram)
        x = self.down(x1)
        x2 = self.down2(x)
        x = self.down(x2)
        x3 = self.down3(x)
        x = self.down(x3)
        x = self.down4(x)
        x = self.up(x)

        face = self.face_decoder(x, x1, x2, x3)
        body = self.body_decoder(x, x1, x2, x3)
        hand = self.hand_decoder(x, x1, x2, x3)
        return face, body, hand
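

# A minimal end-to-end shape sketch of the encoder-decoder, using the part
# dimensions computed in TrainWrapper.init_params below:
def _demo_encoder_decoder():
    B, T = 2, 88
    net = EncoderDecoder(n_frames=15, each_dim=[3, 39, 90, 100])
    face, body, hand = net(torch.randn(B, 64, T))   # 64-dim audio features
    assert face.shape == (B, 103, T)                # jaw (3) + expression (100)
    assert body.shape == (B, 39, T)
    assert hand.shape == (B, 90, T)
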
class Generator(nn.Module):
    def __init__(self,
                 each_dim,
                 training=False,
                 device=None
                 ):
        super().__init__()
        # NOTE: this overwrites nn.Module's built-in `training` flag, which
        # train()/eval() normally manage.
        self.training = training
        self.device = device
        self.encoderdecoder = EncoderDecoder(15, each_dim)

    def forward(self, in_spec, time_steps=None):
        if time_steps is not None:
            self.gen_length = time_steps  # stored but not used downstream

        face, body, hand = self.encoderdecoder(in_spec)
        out = torch.cat([face, body, hand], dim=1)
        out = out.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        return out
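

# Output layout sketch: the concatenated channels are
# [jaw (3) | expression (100) | body (39) | hands (90)] = 232 dims, which is
# what the [:3], [3:103], [103:142], [142:] slices in TrainWrapper.get_loss
# below index into.
def _demo_generator():
    gen = Generator(each_dim=[3, 39, 90, 100])
    out = gen(torch.randn(2, 64, 88))
    assert out.shape == (2, 88, 232)
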
class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.net = nn.Sequential(
            ConvNormRelu(input_dim, 128, '1d'),
            ConvNormRelu(128, 256, '1d'),
            nn.MaxPool1d(kernel_size=2),
            ConvNormRelu(256, 256, '1d'),
            ConvNormRelu(256, 512, '1d'),
            nn.MaxPool1d(kernel_size=2),
            ConvNormRelu(512, 512, '1d'),
            ConvNormRelu(512, 1024, '1d'),
            nn.MaxPool1d(kernel_size=2),
            nn.Conv1d(1024, 1, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
        out = self.net(x)
        return out
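

# A minimal sketch: the discriminator scores body+hand pose channels
# concatenated with the 64-dim audio features, and the three pooling stages
# yield patch-wise real/fake probabilities at 1/8 temporal resolution.
def _demo_discriminator():
    D = Discriminator(input_dim=39 + 90 + 64)
    scores = D(torch.randn(2, 88, 193))     # (B, T, pose + audio)
    assert scores.shape == (2, 1, 11)       # 88 / 8 patch scores in (0, 1)
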
class TrainWrapper(TrainWrapperBaseClass):
    def __init__(self, args, config) -> None:
        self.args = args
        self.config = config
        self.device = torch.device(self.args.gpu)
        self.global_step = 0
        self.convert_to_6d = self.config.Data.pose.convert_to_6d
        self.init_params()

        self.generator = Generator(
            each_dim=self.each_dim,
            training=not self.args.infer,
            device=self.device,
        ).to(self.device)
        self.discriminator = Discriminator(
            input_dim=self.each_dim[1] + self.each_dim[2] + 64
        ).to(self.device)
        if self.convert_to_6d:
            self.c_index = c_index_6d
        else:
            self.c_index = c_index_3d
        self.MSELoss = KeypointLoss().to(self.device)
        self.L1Loss = L1Loss().to(self.device)
        super().__init__(args, config)

    def init_params(self):
        scale = 1
        global_orient = round(0 * scale)
        leye_pose = reye_pose = round(0 * scale)
        jaw_pose = round(3 * scale)
        body_pose = round((63 - 24) * scale)
        left_hand_pose = right_hand_pose = round(45 * scale)
        expression = 100

        b_j = 0
        jaw_dim = jaw_pose
        b_e = b_j + jaw_dim
        eye_dim = leye_pose + reye_pose
        b_b = b_e + eye_dim
        body_dim = global_orient + body_pose
        b_h = b_b + body_dim
        hand_dim = left_hand_pose + right_hand_pose
        b_f = b_h + hand_dim
        face_dim = expression

        self.dim_list = [b_j, b_e, b_b, b_h, b_f]
        self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
        self.pose = int(self.full_dim / round(3 * scale))
        self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
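
    # Worked out with scale = 1:
    #   jaw_dim = 3, eye_dim = 0, body_dim = 0 + (63 - 24) = 39,
    #   hand_dim = 45 + 45 = 90, face_dim = 100
    #   dim_list = [0, 3, 3, 42, 132]
    #   full_dim = 3 + 0 + 39 + 90 = 132 pose channels (132 / 3 = 44 joints)
    #   each_dim = [3, 39, 90, 100]
    # The generator therefore emits 3 + 100 + 39 + 90 = 232 channels per frame.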

    def __call__(self, bat):
        assert (not self.args.infer), "infer mode"
        self.global_step += 1

        aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
        expression = bat['expression'].to(self.device).to(torch.float32)
        jaw = poses[:, :3, :]
        poses = poses[:, self.c_index, :]

        pred = self.generator(in_spec=aud)

        # Discriminator step: detach the prediction so gradients do not flow
        # into the generator.
        D_loss, D_loss_dict = self.get_loss(
            pred_poses=pred.detach(),
            gt_poses=poses,
            aud=aud,
            mode='training_D',
        )
        self.discriminator_optimizer.zero_grad()
        D_loss.backward()
        self.discriminator_optimizer.step()

        # Generator step.
        G_loss, G_loss_dict = self.get_loss(
            pred_poses=pred,
            gt_poses=poses,
            aud=aud,
            expression=expression,
            jaw=jaw,
            mode='training_G',
        )
        self.generator_optimizer.zero_grad()
        G_loss.backward()
        self.generator_optimizer.step()

        # Both networks are already updated above, so no combined loss is
        # returned; the merged dict is for logging only.
        total_loss = None
        loss_dict = {}
        for key in list(D_loss_dict.keys()) + list(G_loss_dict.keys()):
            loss_dict[key] = G_loss_dict.get(key, 0) + D_loss_dict.get(key, 0)

        return total_loss, loss_dict

    def get_loss(self,
                 pred_poses,
                 gt_poses,
                 aud=None,
                 jaw=None,
                 expression=None,
                 mode='training_G',
                 ):
        loss_dict = {}
        aud = aud.transpose(1, 2)
        gt_poses = gt_poses.transpose(1, 2)
        gt_aud = torch.cat([gt_poses, aud], dim=2)
        # Only the body and hand channels (everything past jaw + expression,
        # i.e. index 103 onward) are paired with audio for the discriminator.
        pred_aud = torch.cat([pred_poses[:, :, 103:], aud], dim=2)

        if mode == 'training_D':
            dis_real = self.discriminator(gt_aud)
            dis_fake = self.discriminator(pred_aud)
            # Least-squares-style GAN objective: real -> 1, fake -> 0.
            dis_error = self.MSELoss(torch.ones_like(dis_real).to(self.device), dis_real) + self.MSELoss(
                torch.zeros_like(dis_fake).to(self.device), dis_fake)
            loss_dict['dis'] = dis_error

            return dis_error, loss_dict
        elif mode == 'training_G':
            jaw_loss = self.L1Loss(pred_poses[:, :, :3], jaw.transpose(1, 2))
            face_loss = self.MSELoss(pred_poses[:, :, 3:103], expression.transpose(1, 2))
            body_loss = self.L1Loss(pred_poses[:, :, 103:142], gt_poses[:, :, :39])
            hand_loss = self.L1Loss(pred_poses[:, :, 142:], gt_poses[:, :, 39:])
            l1_loss = jaw_loss + face_loss + body_loss + hand_loss

            dis_output = self.discriminator(pred_aud)
            gen_error = self.MSELoss(torch.ones_like(dis_output).to(self.device), dis_output)
            gen_loss = self.config.Train.weights.keypoint_loss_weight * l1_loss + self.config.Train.weights.gan_loss_weight * gen_error

            loss_dict['gen'] = gen_error
            loss_dict['jaw_loss'] = jaw_loss
            loss_dict['face_loss'] = face_loss
            loss_dict['body_loss'] = body_loss
            loss_dict['hand_loss'] = hand_loss
            return gen_loss, loss_dict
        else:
            raise ValueError(mode)
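
    # The adversarial terms above follow a least-squares-style objective on the
    # discriminator's sigmoid outputs, with the MSE-based KeypointLoss as the
    # criterion. Schematically (mse standing in for self.MSELoss):
    #   D_loss = mse(1, D(gt_poses, aud)) + mse(0, D(pred_poses, aud))
    #   G_loss = keypoint_loss_weight * (jaw + face + body + hand terms)
    #          + gan_loss_weight * mse(1, D(pred_poses, aud))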

    def infer_on_audio(self, aud_fn, fps=30, initial_pose=None, norm_stats=None, id=None, B=1, **kwargs):
        output = []
        assert self.args.infer, "train mode"
        self.generator.eval()

        if self.config.Data.pose.normalization:
            assert norm_stats is not None
            # Stats are validated but this baseline returns raw predictions.
            data_mean = norm_stats[0]
            data_std = norm_stats[1]

        pre_length = self.config.Data.pose.pre_pose_length
        generate_length = self.config.Data.pose.generate_length
        # assert pre_length == initial_pose.shape[-1]
        # pre_poses = initial_pose.permute(0, 2, 1).to(self.device).to(torch.float32)
        # B = pre_poses.shape[0]

        aud_feat = get_mfcc_ta(aud_fn, sr=22000, fps=fps, smlpx=True, type='mfcc').transpose(1, 0)
        num_poses_to_generate = aud_feat.shape[-1]
        aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
        aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            pred_poses = self.generator(aud_feat)
        pred_poses = pred_poses.cpu().numpy()
        output = pred_poses.squeeze()

        return output
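
    # Usage sketch (paths and stats are placeholders; B repeats the same audio
    # features across the batch):
    #   wrapper = TrainWrapper(args, config)   # args.infer must be set
    #   out = wrapper.infer_on_audio('sample.wav', fps=30,
    #                                norm_stats=(mean, std))
    #   # out: (T, 232) numpy array -> jaw | expression | body | hands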

    def generate(self, aud, id):
        self.generator.eval()
        pred_poses = self.generator(aud)
        return pred_poses


if __name__ == '__main__':
    from trainer.options import parse_args

    parser = parse_args()
    args = parser.parse_args(
        ['--exp_name', '0', '--data_root', '0', '--speakers', '0', '--pre_pose_length', '4', '--generate_length', '64',
         '--infer'])

    # NOTE: TrainWrapper requires a config object in addition to args; this
    # smoke test assumes the project's JSON config loader is available.
    from trainer.config import load_JsonConfig
    config = load_JsonConfig(args.config_file)

    generator = TrainWrapper(args, config)

    aud_fn = '../sample_audio/jon.wav'
    initial_pose = torch.randn(64, 108, 4)
    norm_stats = (np.random.randn(108), np.random.randn(108))
    # Pass by keyword: the second positional parameter of infer_on_audio is
    # fps, not initial_pose.
    output = generator.infer_on_audio(aud_fn, initial_pose=initial_pose, norm_stats=norm_stats)

    print(output.shape)