diff --git a/evaluation/FGD.py b/evaluation/FGD.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5521ee30b8751ef3bdd980ee91231ab78fea6e5
--- /dev/null
+++ b/evaluation/FGD.py
@@ -0,0 +1,199 @@
+import time
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from scipy import linalg
+import math
+from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d
+
+import warnings
+warnings.filterwarnings("ignore", category=RuntimeWarning) # ignore warnings
+
+
+change_angle = torch.tensor([6.0181e-05, 5.1597e-05, 2.1344e-04, 2.1899e-04])
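+# NOTE: per-joint normalizing constants consumed in get_BCscore(); presumably the average angle changes produced by get_MAAC().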
+class EmbeddingSpaceEvaluator:
+ def __init__(self, ae, vae, device):
+
+ # init embed net
+ self.ae = ae
+ # self.vae = vae
+
+ # storage
+ self.real_feat_list = []
+ self.generated_feat_list = []
+ self.real_joints_list = []
+ self.generated_joints_list = []
+ self.real_6d_list = []
+ self.generated_6d_list = []
+ self.audio_beat_list = []
+
+ def reset(self):
+ self.real_feat_list = []
+ self.generated_feat_list = []
+
+ def get_no_of_samples(self):
+ return len(self.real_feat_list)
+
+ def push_samples(self, generated_poses, real_poses):
+ # self.net.eval()
+ # convert poses to latent features
+ real_feat, real_poses = self.ae.extract(real_poses)
+ generated_feat, generated_poses = self.ae.extract(generated_poses)
+
+ num_joints = real_poses.shape[2] // 3
+
+ real_feat = real_feat.squeeze()
+ generated_feat = generated_feat.reshape(generated_feat.shape[0]*generated_feat.shape[1], -1)
+
+ self.real_feat_list.append(real_feat.data.cpu().numpy())
+ self.generated_feat_list.append(generated_feat.data.cpu().numpy())
+
+ # real_poses = matrix_to_rotation_6d(axis_angle_to_matrix(real_poses.reshape(-1, 3))).reshape(-1, num_joints, 6)
+ # generated_poses = matrix_to_rotation_6d(axis_angle_to_matrix(generated_poses.reshape(-1, 3))).reshape(-1, num_joints, 6)
+ #
+ # self.real_feat_list.append(real_poses.data.cpu().numpy())
+ # self.generated_feat_list.append(generated_poses.data.cpu().numpy())
+
+ def push_joints(self, generated_poses, real_poses):
+ self.real_joints_list.append(real_poses.data.cpu())
+ self.generated_joints_list.append(generated_poses.squeeze().data.cpu())
+
+ def push_aud(self, aud):
+ self.audio_beat_list.append(aud.squeeze().data.cpu())
+
+ def get_MAAC(self):
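+        # Mean absolute angle change (MAAC) of the selected arm joints over the real sequences; presumably the source of the change_angle constants above.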
+ ang_vel_list = []
+ for real_joints in self.real_joints_list:
+ real_joints[:, 15:21] = real_joints[:, 16:22]
+ vec = real_joints[:, 15:21] - real_joints[:, 13:19]
+ inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]])
+ inner_product = torch.clamp(inner_product, -1, 1, out=None)
+ angle = torch.acos(inner_product) / math.pi
+ ang_vel = (angle[1:] - angle[:-1]).abs().mean(dim=0)
+ ang_vel_list.append(ang_vel.unsqueeze(dim=0))
+ all_vel = torch.cat(ang_vel_list, dim=0)
+ MAAC = all_vel.mean(dim=0)
+ return MAAC
+
+ def get_BCscore(self):
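+        # Beat consistency: motion beats are detected as local minima of joint angular velocity (with a minimum drop of `thres`);
+        # each audio beat is then scored by its distance to the nearest motion beat through a Gaussian kernel of width `sigma`.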
+ thres = 0.01
+ sigma = 0.1
+ sum_1 = 0
+ total_beat = 0
+ for joints, audio_beat_time in zip(self.generated_joints_list, self.audio_beat_list):
+ motion_beat_time = []
+ if joints.dim() == 4:
+ joints = joints[0]
+ joints[:, 15:21] = joints[:, 16:22]
+ vec = joints[:, 15:21] - joints[:, 13:19]
+ inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]])
+ inner_product = torch.clamp(inner_product, -1, 1, out=None)
+ angle = torch.acos(inner_product) / math.pi
+ ang_vel = (angle[1:] - angle[:-1]).abs() / change_angle / len(change_angle)
+
+ angle_diff = torch.cat((torch.zeros(1, 4), ang_vel), dim=0)
+
+ sum_2 = 0
+ for i in range(angle_diff.shape[1]):
+ motion_beat_time = []
+ for t in range(1, joints.shape[0]-1):
+ if (angle_diff[t][i] < angle_diff[t - 1][i] and angle_diff[t][i] < angle_diff[t + 1][i]):
+ if (angle_diff[t - 1][i] - angle_diff[t][i] >= thres or angle_diff[t + 1][i] - angle_diff[
+ t][i] >= thres):
+ motion_beat_time.append(float(t) / 30.0)
+ if (len(motion_beat_time) == 0):
+ continue
+ motion_beat_time = torch.tensor(motion_beat_time)
+ sum = 0
+ for audio in audio_beat_time:
+ sum += np.power(math.e, -(np.power((audio.item() - motion_beat_time), 2)).min() / (2 * sigma * sigma))
+ sum_2 = sum_2 + sum
+ total_beat = total_beat + len(audio_beat_time)
+ sum_1 = sum_1 + sum_2
+ return sum_1/total_beat
+
+
+ def get_scores(self):
+ generated_feats = np.vstack(self.generated_feat_list)
+ real_feats = np.vstack(self.real_feat_list)
+
+ def frechet_distance(samples_A, samples_B):
+ A_mu = np.mean(samples_A, axis=0)
+ A_sigma = np.cov(samples_A, rowvar=False)
+ B_mu = np.mean(samples_B, axis=0)
+ B_sigma = np.cov(samples_B, rowvar=False)
+ try:
+ frechet_dist = self.calculate_frechet_distance(A_mu, A_sigma, B_mu, B_sigma)
+ except ValueError:
+ frechet_dist = 1e+10
+ return frechet_dist
+
+ ####################################################################
+ # frechet distance
+ frechet_dist = frechet_distance(generated_feats, real_feats)
+
+ ####################################################################
+ # distance between real and generated samples on the latent feature space
+ dists = []
+ for i in range(real_feats.shape[0]):
+ d = np.sum(np.absolute(real_feats[i] - generated_feats[i])) # MAE
+ dists.append(d)
+ feat_dist = np.mean(dists)
+
+ return frechet_dist, feat_dist
+
+ @staticmethod
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
+ """ from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py """
+ """Numpy implementation of the Frechet Distance.
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+ and X_2 ~ N(mu_2, C_2) is
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+ Stable version by Dougal J. Sutherland.
+ Params:
+ -- mu1 : Numpy array containing the activations of a layer of the
+ inception net (like returned by the function 'get_predictions')
+ for generated samples.
+ -- mu2 : The sample mean over activations, precalculated on an
+ representative data set.
+ -- sigma1: The covariance matrix over activations for generated samples.
+ -- sigma2: The covariance matrix over activations, precalculated on an
+ representative data set.
+ Returns:
+ -- : The Frechet Distance.
+ """
+
+ mu1 = np.atleast_1d(mu1)
+ mu2 = np.atleast_1d(mu2)
+
+ sigma1 = np.atleast_2d(sigma1)
+ sigma2 = np.atleast_2d(sigma2)
+
+ assert mu1.shape == mu2.shape, \
+ 'Training and test mean vectors have different lengths'
+ assert sigma1.shape == sigma2.shape, \
+ 'Training and test covariances have different dimensions'
+
+ diff = mu1 - mu2
+
+ # Product might be almost singular
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+ if not np.isfinite(covmean).all():
+ msg = ('fid calculation produces singular product; '
+ 'adding %s to diagonal of cov estimates') % eps
+ print(msg)
+ offset = np.eye(sigma1.shape[0]) * eps
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+ # Numerical error might give slight imaginary component
+ if np.iscomplexobj(covmean):
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+ m = np.max(np.abs(covmean.imag))
+ raise ValueError('Imaginary component {}'.format(m))
+ covmean = covmean.real
+
+ tr_covmean = np.trace(covmean)
+
+ return (diff.dot(diff) + np.trace(sigma1) +
+ np.trace(sigma2) - 2 * tr_covmean)
\ No newline at end of file
diff --git a/evaluation/__init__.py b/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/evaluation/__pycache__/__init__.cpython-37.pyc b/evaluation/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c1631bfe6e7bc4b06edead82accfc11190e66830
Binary files /dev/null and b/evaluation/__pycache__/__init__.cpython-37.pyc differ
diff --git a/evaluation/__pycache__/metrics.cpython-37.pyc b/evaluation/__pycache__/metrics.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc7c877f5fd731130c82d1b67a180ac9b69cc251
Binary files /dev/null and b/evaluation/__pycache__/metrics.cpython-37.pyc differ
diff --git a/evaluation/diversity_LVD.py b/evaluation/diversity_LVD.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfd7dd81118692ea0c361c6016afa29d98665315
--- /dev/null
+++ b/evaluation/diversity_LVD.py
@@ -0,0 +1,64 @@
+'''
+LVD: computed on samples generated from different initial poses
+diversity: computed on samples sharing the same initial pose
+'''
+import os
+import sys
+sys.path.append(os.getcwd())
+
+from glob import glob
+
+from argparse import ArgumentParser
+import json
+
+from evaluation.util import *
+from evaluation.metrics import *
+from tqdm import tqdm
+
+parser = ArgumentParser()
+parser.add_argument('--speaker', required=True, type=str)
+parser.add_argument('--post_fix', nargs='+', default=['base'], type=str)
+args = parser.parse_args()
+
+speaker = args.speaker
+test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
+
+LVD_list = []
+diversity_list = []
+
+for aud in tqdm(test_audios):
+ base_name = os.path.splitext(aud)[0]
+ gt_path = get_full_path(aud, speaker, 'val')
+ _, gt_poses, _ = get_gts(gt_path)
+ gt_poses = gt_poses[np.newaxis,...]
+ # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
+ for post_fix in args.post_fix:
+ pred_path = base_name + '_'+post_fix+'.json'
+ pred_poses = np.array(json.load(open(pred_path)))
+ # print(pred_poses.shape)#(B, seq_len, 108)
+ pred_poses = cvt25(pred_poses, gt_poses)
+ # print(pred_poses.shape)#(B, seq, pose_dim)
+
+ gt_valid_points = hand_points(gt_poses)
+ pred_valid_points = hand_points(pred_poses)
+
+ lvd = LVD(gt_valid_points, pred_valid_points)
+ # div = diversity(pred_valid_points)
+
+ LVD_list.append(lvd)
+ # diversity_list.append(div)
+
+ # gt_velocity = peak_velocity(gt_valid_points, order=2)
+ # pred_velocity = peak_velocity(pred_valid_points, order=2)
+
+ # gt_consistency = velocity_consistency(gt_velocity, pred_velocity)
+ # pred_consistency = velocity_consistency(pred_velocity, gt_velocity)
+
+ # gt_consistency_list.append(gt_consistency)
+ # pred_consistency_list.append(pred_consistency)
+
+lvd = np.mean(LVD_list)
+# diversity_list = np.mean(diversity_list)
+
+print('LVD:', lvd)
+# print("diversity:", diversity_list)
\ No newline at end of file
diff --git a/evaluation/get_quality_samples.py b/evaluation/get_quality_samples.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8ef393cd310aa2e75f871122a62f45dd525e47c
--- /dev/null
+++ b/evaluation/get_quality_samples.py
@@ -0,0 +1,62 @@
+'''
+Collect ground-truth and predicted pose samples for quality evaluation and dump them to a JSON file.
+'''
+import os
+import sys
+sys.path.append(os.getcwd())
+
+from glob import glob
+
+from argparse import ArgumentParser
+import json
+
+from evaluation.util import *
+from evaluation.metrics import *
+from tqdm import tqdm
+
+parser = ArgumentParser()
+parser.add_argument('--speaker', required=True, type=str)
+parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
+args = parser.parse_args()
+
+speaker = args.speaker
+test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
+
+quality_samples={'gt':[]}
+for post_fix in args.post_fix:
+ quality_samples[post_fix] = []
+
+for aud in tqdm(test_audios):
+ base_name = os.path.splitext(aud)[0]
+ gt_path = get_full_path(aud, speaker, 'val')
+ _, gt_poses, _ = get_gts(gt_path)
+ gt_poses = gt_poses[np.newaxis,...]
+ gt_valid_points = valid_points(gt_poses)
+ # print(gt_valid_points.shape)
+ quality_samples['gt'].append(gt_valid_points)
+
+ for post_fix in args.post_fix:
+ pred_path = base_name + '_'+post_fix+'.json'
+ pred_poses = np.array(json.load(open(pred_path)))
+ # print(pred_poses.shape)#(B, seq_len, 108)
+ pred_poses = cvt25(pred_poses, gt_poses)
+ # print(pred_poses.shape)#(B, seq, pose_dim)
+
+ pred_valid_points = valid_points(pred_poses)[0:1]
+ quality_samples[post_fix].append(pred_valid_points)
+
+quality_samples['gt'] = np.concatenate(quality_samples['gt'], axis=1)
+for post_fix in args.post_fix:
+ quality_samples[post_fix] = np.concatenate(quality_samples[post_fix], axis=1)
+
+print('gt:', quality_samples['gt'].shape)
+quality_samples['gt'] = quality_samples['gt'].tolist()
+for post_fix in args.post_fix:
+ print(post_fix, ':', quality_samples[post_fix].shape)
+ quality_samples[post_fix] = quality_samples[post_fix].tolist()
+
+save_dir = '../../experiments/'
+os.makedirs(save_dir, exist_ok=True)
+save_name = os.path.join(save_dir, 'quality_samples_%s.json'%(speaker))
+with open(save_name, 'w') as f:
+ json.dump(quality_samples, f)
+
diff --git a/evaluation/metrics.py b/evaluation/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..93dc8fde8b8de57cc2f0f7387afd5de7cb753835
--- /dev/null
+++ b/evaluation/metrics.py
@@ -0,0 +1,109 @@
+'''
+Warning: these metrics are for reference only and may have limited significance
+'''
+import os
+import sys
+sys.path.append(os.getcwd())
+import numpy as np
+import torch
+
+from data_utils.lower_body import rearrange, symmetry
+import torch.nn.functional as F
+
+def data_driven_baselines(gt_kps):
+ '''
+ gt_kps: T, D
+ '''
+ gt_velocity = np.abs(gt_kps[1:] - gt_kps[:-1])
+
+ mean= np.mean(gt_velocity, axis=0)[np.newaxis] #(1, D)
+ mean = np.mean(np.abs(gt_velocity-mean))
+ last_step = gt_kps[1] - gt_kps[0]
+ last_step = last_step[np.newaxis] #(1, D)
+ last_step = np.mean(np.abs(gt_velocity-last_step))
+ return last_step, mean
+
+def Batch_LVD(gt_kps, pr_kps, symmetrical, weight):
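+    # Batched variant: gt_kps is a single sequence while pr_kps carries a leading sample dimension
+    # (shapes assumed from the indexing below: gt_kps (T, J, 3), pr_kps (B, T, J, 3)).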
+ if gt_kps.shape[0] > pr_kps.shape[1]:
+ length = pr_kps.shape[1]
+ else:
+ length = gt_kps.shape[0]
+ gt_kps = gt_kps[:length]
+ pr_kps = pr_kps[:, :length]
+ global symmetry
+ symmetry = torch.tensor(symmetry).bool()
+
+ if symmetrical:
+        # rearrange joints to evaluate symmetry: 'ns' denotes non-symmetrical joints, 'ys' denotes symmetrical joints.
+ gt_kps = gt_kps[:, rearrange]
+ ns_gt_kps = gt_kps[:, ~symmetry]
+ ys_gt_kps = gt_kps[:, symmetry]
+ ys_gt_kps = ys_gt_kps.reshape(ys_gt_kps.shape[0], -1, 2, 3)
+ ns_gt_velocity = (ns_gt_kps[1:] - ns_gt_kps[:-1]).norm(p=2, dim=-1)
+ ys_gt_velocity = (ys_gt_kps[1:] - ys_gt_kps[:-1]).norm(p=2, dim=-1)
+ left_gt_vel = ys_gt_velocity[:, :, 0].sum(dim=-1)
+ right_gt_vel = ys_gt_velocity[:, :, 1].sum(dim=-1)
+ move_side = torch.where(left_gt_vel>right_gt_vel, torch.ones(left_gt_vel.shape).cuda(), torch.zeros(left_gt_vel.shape).cuda())
+ ys_gt_velocity = torch.mul(ys_gt_velocity[:, :, 0].transpose(0,1), move_side) + torch.mul(ys_gt_velocity[:, :, 1].transpose(0,1), ~move_side.bool())
+ ys_gt_velocity = ys_gt_velocity.transpose(0,1)
+ gt_velocity = torch.cat([ns_gt_velocity, ys_gt_velocity], dim=1)
+
+ pr_kps = pr_kps[:, :, rearrange]
+ ns_pr_kps = pr_kps[:, :, ~symmetry]
+ ys_pr_kps = pr_kps[:, :, symmetry]
+ ys_pr_kps = ys_pr_kps.reshape(ys_pr_kps.shape[0], ys_pr_kps.shape[1], -1, 2, 3)
+ ns_pr_velocity = (ns_pr_kps[:, 1:] - ns_pr_kps[:, :-1]).norm(p=2, dim=-1)
+ ys_pr_velocity = (ys_pr_kps[:, 1:] - ys_pr_kps[:, :-1]).norm(p=2, dim=-1)
+ left_pr_vel = ys_pr_velocity[:, :, :, 0].sum(dim=-1)
+ right_pr_vel = ys_pr_velocity[:, :, :, 1].sum(dim=-1)
+ move_side = torch.where(left_pr_vel > right_pr_vel, torch.ones(left_pr_vel.shape).cuda(),
+ torch.zeros(left_pr_vel.shape).cuda())
+ ys_pr_velocity = torch.mul(ys_pr_velocity[..., 0].permute(2, 0, 1), move_side) + torch.mul(
+            ys_pr_velocity[..., 1].permute(2, 0, 1), ~move_side.bool())  # boolean mask negation, matching the GT branch above
+ ys_pr_velocity = ys_pr_velocity.permute(1, 2, 0)
+ pr_velocity = torch.cat([ns_pr_velocity, ys_pr_velocity], dim=2)
+ else:
+ gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1)
+ pr_velocity = (pr_kps[:, 1:] - pr_kps[:, :-1]).norm(p=2, dim=-1)
+
+ if weight:
+ w = F.softmax(gt_velocity.sum(dim=1).normal_(), dim=0)
+ else:
+ w = 1 / gt_velocity.shape[0]
+
+ v_diff = ((pr_velocity - gt_velocity).abs().sum(dim=-1) * w).sum(dim=-1).mean()
+
+ return v_diff
+
+
+def LVD(gt_kps, pr_kps, symmetrical=False, weight=False):
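+    # LVD (assumed: landmark velocity difference): mean absolute difference between GT and predicted per-frame keypoint velocities.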
+ gt_kps = gt_kps.squeeze()
+ pr_kps = pr_kps.squeeze()
+ if len(pr_kps.shape) == 4:
+ return Batch_LVD(gt_kps, pr_kps, symmetrical, weight)
+ # length = np.minimum(gt_kps.shape[0], pr_kps.shape[0])
+ length = gt_kps.shape[0]-10
+ # gt_kps = gt_kps[25:length]
+ # pr_kps = pr_kps[25:length] #(T, D)
+ # if pr_kps.shape[0] < gt_kps.shape[0]:
+ # pr_kps = np.pad(pr_kps, [[0, int(gt_kps.shape[0]-pr_kps.shape[0])], [0, 0]], mode='constant')
+
+ gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1)
+ pr_velocity = (pr_kps[1:] - pr_kps[:-1]).norm(p=2, dim=-1)
+
+ return (pr_velocity-gt_velocity).abs().sum(dim=-1).mean()
+
+def diversity(kps):
+ '''
+ kps: bs, seq, dim
+ '''
+ dis_list = []
+ #the distance between each pair
+ for i in range(kps.shape[0]):
+ for j in range(i+1, kps.shape[0]):
+ seq_i = kps[i]
+ seq_j = kps[j]
+
+ dis = np.mean(np.abs(seq_i - seq_j))
+ dis_list.append(dis)
+ return np.mean(dis_list)
diff --git a/evaluation/mode_transition.py b/evaluation/mode_transition.py
new file mode 100644
index 0000000000000000000000000000000000000000..92cd0e5ecfe688b7a8add932af0303bd8d5ed947
--- /dev/null
+++ b/evaluation/mode_transition.py
@@ -0,0 +1,60 @@
+import os
+import sys
+sys.path.append(os.getcwd())
+
+from glob import glob
+
+from argparse import ArgumentParser
+import json
+
+from evaluation.util import *
+from evaluation.metrics import *
+from tqdm import tqdm
+
+parser = ArgumentParser()
+parser.add_argument('--speaker', required=True, type=str)
+parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
+args = parser.parse_args()
+
+speaker = args.speaker
+test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
+
+precision_list=[]
+recall_list=[]
+accuracy_list=[]
+
+for aud in tqdm(test_audios):
+ base_name = os.path.splitext(aud)[0]
+ gt_path = get_full_path(aud, speaker, 'val')
+ _, gt_poses, _ = get_gts(gt_path)
+ if gt_poses.shape[0] < 50:
+ continue
+ gt_poses = gt_poses[np.newaxis,...]
+ # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
+ for post_fix in args.post_fix:
+ pred_path = base_name + '_'+post_fix+'.json'
+ pred_poses = np.array(json.load(open(pred_path)))
+ # print(pred_poses.shape)#(B, seq_len, 108)
+ pred_poses = cvt25(pred_poses, gt_poses)
+ # print(pred_poses.shape)#(B, seq, pose_dim)
+
+ gt_valid_points = valid_points(gt_poses)
+ pred_valid_points = valid_points(pred_poses)
+
+ # print(gt_valid_points.shape, pred_valid_points.shape)
+
+ gt_mode_transition_seq = mode_transition_seq(gt_valid_points, speaker)#(B, N)
+ pred_mode_transition_seq = mode_transition_seq(pred_valid_points, speaker)#(B, N)
+
+ # baseline = np.random.randint(0, 2, size=pred_mode_transition_seq.shape)
+ # pred_mode_transition_seq = baseline
+ precision, recall, accuracy = mode_transition_consistency(pred_mode_transition_seq, gt_mode_transition_seq)
+ precision_list.append(precision)
+ recall_list.append(recall)
+ accuracy_list.append(accuracy)
+print(len(precision_list), len(recall_list), len(accuracy_list))
+precision_list = np.mean(precision_list)
+recall_list = np.mean(recall_list)
+accuracy_list = np.mean(accuracy_list)
+
+print('precision, recall, accu:', precision_list, recall_list, accuracy_list)
diff --git a/evaluation/peak_velocity.py b/evaluation/peak_velocity.py
new file mode 100644
index 0000000000000000000000000000000000000000..3842b918375176099cd60a8f9ede50d8920b3e4a
--- /dev/null
+++ b/evaluation/peak_velocity.py
@@ -0,0 +1,65 @@
+import os
+import sys
+sys.path.append(os.getcwd())
+
+from glob import glob
+
+from argparse import ArgumentParser
+import json
+
+from evaluation.util import *
+from evaluation.metrics import *
+from tqdm import tqdm
+
+parser = ArgumentParser()
+parser.add_argument('--speaker', required=True, type=str)
+parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
+args = parser.parse_args()
+
+speaker = args.speaker
+test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
+
+gt_consistency_list=[]
+pred_consistency_list=[]
+
+for aud in tqdm(test_audios):
+ base_name = os.path.splitext(aud)[0]
+ gt_path = get_full_path(aud, speaker, 'val')
+ _, gt_poses, _ = get_gts(gt_path)
+ gt_poses = gt_poses[np.newaxis,...]
+ # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
+ for post_fix in args.post_fix:
+ pred_path = base_name + '_'+post_fix+'.json'
+ pred_poses = np.array(json.load(open(pred_path)))
+ # print(pred_poses.shape)#(B, seq_len, 108)
+ pred_poses = cvt25(pred_poses, gt_poses)
+ # print(pred_poses.shape)#(B, seq, pose_dim)
+
+ gt_valid_points = hand_points(gt_poses)
+ pred_valid_points = hand_points(pred_poses)
+
+ gt_velocity = peak_velocity(gt_valid_points, order=2)
+ pred_velocity = peak_velocity(pred_valid_points, order=2)
+
+ gt_consistency = velocity_consistency(gt_velocity, pred_velocity)
+ pred_consistency = velocity_consistency(pred_velocity, gt_velocity)
+
+ gt_consistency_list.append(gt_consistency)
+ pred_consistency_list.append(pred_consistency)
+
+gt_consistency_list = np.concatenate(gt_consistency_list)
+pred_consistency_list = np.concatenate(pred_consistency_list)
+
+print(gt_consistency_list.max(), gt_consistency_list.min())
+print(pred_consistency_list.max(), pred_consistency_list.min())
+print(np.mean(gt_consistency_list), np.mean(pred_consistency_list))
+print(np.std(gt_consistency_list), np.std(pred_consistency_list))
+
+draw_cdf(gt_consistency_list, save_name='%s_gt.jpg'%(speaker), color='slateblue')
+draw_cdf(pred_consistency_list, save_name='%s_pred.jpg'%(speaker), color='lightskyblue')
+
+to_excel(gt_consistency_list, '%s_gt.xlsx'%(speaker))
+to_excel(pred_consistency_list, '%s_pred.xlsx'%(speaker))
+
+np.save('%s_gt.npy'%(speaker), gt_consistency_list)
+np.save('%s_pred.npy'%(speaker), pred_consistency_list)
\ No newline at end of file
diff --git a/evaluation/util.py b/evaluation/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3c18a0ab89ca812e4a2126bd18ee8e588b7ec43
--- /dev/null
+++ b/evaluation/util.py
@@ -0,0 +1,148 @@
+import os
+from glob import glob
+import numpy as np
+import json
+from matplotlib import pyplot as plt
+import pandas as pd
+def get_gts(clip):
+ '''
+ clip: abs path to the clip dir
+ '''
+ keypoints_files = sorted(glob(os.path.join(clip, 'keypoints_new/person_1')+'/*.json'))
+
+ upper_body_points = list(np.arange(0, 25))
+ poses = []
+ confs = []
+ neck_to_nose_len = []
+ mean_position = []
+ for kp_file in keypoints_files:
+ kp_load = json.load(open(kp_file, 'r'))['people'][0]
+ posepts = kp_load['pose_keypoints_2d']
+ lhandpts = kp_load['hand_left_keypoints_2d']
+ rhandpts = kp_load['hand_right_keypoints_2d']
+ facepts = kp_load['face_keypoints_2d']
+
+ neck = np.array(posepts).reshape(-1,3)[1]
+ nose = np.array(posepts).reshape(-1,3)[0]
+ x_offset = abs(neck[0]-nose[0])
+ y_offset = abs(neck[1]-nose[1])
+ neck_to_nose_len.append(y_offset)
+ mean_position.append([neck[0],neck[1]])
+
+ keypoints=np.array(posepts+lhandpts+rhandpts+facepts).reshape(-1,3)[:,:2]
+
+ upper_body = keypoints[upper_body_points, :]
+ hand_points = keypoints[25:, :]
+ keypoints = np.vstack([upper_body, hand_points])
+
+ poses.append(keypoints)
+
+ if len(neck_to_nose_len) > 0:
+ scale_factor = np.mean(neck_to_nose_len)
+ else:
+ raise ValueError(clip)
+ mean_position = np.mean(np.array(mean_position), axis=0)
+
+ unlocalized_poses = np.array(poses).copy()
+ localized_poses = []
+ for i in range(len(poses)):
+ keypoints = poses[i]
+ neck = keypoints[1].copy()
+
+ keypoints[:, 0] = (keypoints[:, 0] - neck[0]) / scale_factor
+ keypoints[:, 1] = (keypoints[:, 1] - neck[1]) / scale_factor
+ localized_poses.append(keypoints.reshape(-1))
+
+ localized_poses=np.array(localized_poses)
+ return unlocalized_poses, localized_poses, (scale_factor, mean_position)
+
+def get_full_path(wav_name, speaker, split):
+ '''
+ get clip path from aud file
+ '''
+ wav_name = os.path.basename(wav_name)
+ wav_name = os.path.splitext(wav_name)[0]
+ clip_name, vid_name = wav_name[:10], wav_name[11:]
+
+ full_path = os.path.join('pose_dataset/videos/', speaker, 'clips', vid_name, 'images/half', split, clip_name)
+
+ assert os.path.isdir(full_path), full_path
+
+ return full_path
+
+def smooth(res):
+ '''
+ res: (B, seq_len, pose_dim)
+ '''
+ window = [res[:, 7, :], res[:, 8, :], res[:, 9, :], res[:, 10, :], res[:, 11, :], res[:, 12, :]]
+ w_size=7
+ for i in range(10, res.shape[1]-3):
+ window.append(res[:, i+3, :])
+ if len(window) > w_size:
+ window = window[1:]
+
+ if (i%25) in [22, 23, 24, 0, 1, 2, 3]:
+            res[:, i, :] = np.mean(window, axis=0)  # average over the temporal window, not over the batch
+
+ return res
+
+def cvt25(pred_poses, gt_poses=None):
+ '''
+ gt_poses: (1, seq_len, 270), 135 *2
+ pred_poses: (B, seq_len, 108), 54 * 2
+ '''
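+    # Map the 54-point prediction into the 135-point GT layout: body joints 1-7 and the 42 hand points
+    # come from the prediction; all remaining keypoints keep the GT values.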
+ if gt_poses is None:
+ gt_poses = np.zeros_like(pred_poses)
+ else:
+ gt_poses = gt_poses.repeat(pred_poses.shape[0], axis=0)
+
+ length = min(pred_poses.shape[1], gt_poses.shape[1])
+ pred_poses = pred_poses[:, :length, :]
+ gt_poses = gt_poses[:, :length, :]
+ gt_poses = gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1, 2)
+ pred_poses = pred_poses.reshape(pred_poses.shape[0], pred_poses.shape[1], -1, 2)
+
+ gt_poses[:, :, [1, 2, 3, 4, 5, 6, 7], :] = pred_poses[:, :, 1:8, :]
+ gt_poses[:, :, 25:25+21+21, :] = pred_poses[:, :, 12:, :]
+
+ return gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1)
+
+def hand_points(seq):
+ '''
+ seq: (B, seq_len, 135*2)
+ hands only
+ '''
+ hand_idx = [1, 2, 3, 4,5 ,6,7] + list(range(25, 25+21+21))
+ seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2)
+ return seq[:, :, hand_idx, :].reshape(seq.shape[0], seq.shape[1], -1)
+
+def valid_points(seq):
+ '''
+ hands with some head points
+ '''
+ valid_idx = [0, 1, 2, 3, 4,5 ,6,7, 8, 9, 10, 11] + list(range(25, 25+21+21))
+ seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2)
+
+ seq = seq[:, :, valid_idx, :].reshape(seq.shape[0], seq.shape[1], -1)
+ assert seq.shape[-1] == 108, seq.shape
+ return seq
+
+def draw_cdf(seq, save_name='cdf.jpg', color='slateblue'):
+ plt.figure()
+ plt.hist(seq, bins=100, range=(0, 100), color=color)
+ plt.savefig(save_name)
+
+def to_excel(seq, save_name='res.xlsx'):
+ '''
+ seq: (T)
+ '''
+ df = pd.DataFrame(seq)
+ writer = pd.ExcelWriter(save_name)
+ df.to_excel(writer, 'sheet1')
+ writer.save()
+ writer.close()
+
+
+if __name__ == '__main__':
+ random_data = np.random.randint(0, 10, 100)
+ draw_cdf(random_data)
\ No newline at end of file
diff --git a/losses/__init__.py b/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..37fdea0eb6c3190e7001567cfe17dc296bf811e8
--- /dev/null
+++ b/losses/__init__.py
@@ -0,0 +1 @@
+from .losses import *
\ No newline at end of file
diff --git a/losses/__pycache__/__init__.cpython-37.pyc b/losses/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a099ba9ea1527c7efdbd683bf96e5cd43dd8d932
Binary files /dev/null and b/losses/__pycache__/__init__.cpython-37.pyc differ
diff --git a/losses/__pycache__/losses.cpython-37.pyc b/losses/__pycache__/losses.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ff2ede0fb38bc991969ee097ef54cc16cb499989
Binary files /dev/null and b/losses/__pycache__/losses.cpython-37.pyc differ
diff --git a/losses/losses.py b/losses/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4e433ca256b3ed54b77fbc6ca8751aa32959153
--- /dev/null
+++ b/losses/losses.py
@@ -0,0 +1,91 @@
+import os
+import sys
+
+sys.path.append(os.getcwd())
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+class KeypointLoss(nn.Module):
+ def __init__(self):
+ super(KeypointLoss, self).__init__()
+
+ def forward(self, pred_seq, gt_seq, gt_conf=None):
+ #pred_seq: (B, C, T)
+ if gt_conf is not None:
+ gt_conf = gt_conf >= 0.01
+ return F.mse_loss(pred_seq[gt_conf], gt_seq[gt_conf], reduction='mean')
+ else:
+ return F.mse_loss(pred_seq, gt_seq)
+
+
+class KLLoss(nn.Module):
+ def __init__(self, kl_tolerance):
+ super(KLLoss, self).__init__()
+ self.kl_tolerance = kl_tolerance
+
+ def forward(self, mu, var, mul=1):
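+        # NOTE: `var` is treated as the log-variance here (the KL term uses var.exp()).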
+ kl_tolerance = self.kl_tolerance * mul * var.shape[1] / 64
+ kld_loss = -0.5 * torch.sum(1 + var - mu**2 - var.exp(), dim=1)
+ # kld_loss = -0.5 * torch.sum(1 + (var-1) - (mu) ** 2 - (var-1).exp(), dim=1)
+ if self.kl_tolerance is not None:
+ # above_line = kld_loss[kld_loss > self.kl_tolerance]
+ # if len(above_line) > 0:
+ # kld_loss = torch.mean(kld_loss)
+ # else:
+ # kld_loss = 0
+ kld_loss = torch.where(kld_loss > kl_tolerance, kld_loss, torch.tensor(kl_tolerance, device='cuda'))
+ # else:
+ kld_loss = torch.mean(kld_loss)
+ return kld_loss
+
+
+class L2KLLoss(nn.Module):
+ def __init__(self, kl_tolerance):
+ super(L2KLLoss, self).__init__()
+ self.kl_tolerance = kl_tolerance
+
+ def forward(self, x):
+ # TODO: check
+ kld_loss = torch.sum(x ** 2, dim=1)
+ if self.kl_tolerance is not None:
+ above_line = kld_loss[kld_loss > self.kl_tolerance]
+ if len(above_line) > 0:
+ kld_loss = torch.mean(kld_loss)
+ else:
+ kld_loss = 0
+ else:
+ kld_loss = torch.mean(kld_loss)
+ return kld_loss
+
+class L2RegLoss(nn.Module):
+ def __init__(self):
+ super(L2RegLoss, self).__init__()
+
+ def forward(self, x):
+ #TODO: check
+ return torch.sum(x**2)
+
+
+class L2Loss(nn.Module):
+ def __init__(self):
+ super(L2Loss, self).__init__()
+
+ def forward(self, x):
+ # TODO: check
+ return torch.sum(x ** 2)
+
+
+class AudioLoss(nn.Module):
+ def __init__(self):
+ super(AudioLoss, self).__init__()
+
+ def forward(self, dynamics, gt_poses):
+        # note: gt_poses are mean-centered (normalized) before computing the loss
+ mean = torch.mean(gt_poses, dim=-1).unsqueeze(-1)
+ gt = gt_poses - mean
+ return F.mse_loss(dynamics, gt)
+
+L1Loss = nn.L1Loss
\ No newline at end of file
diff --git a/nets/LS3DCG.py b/nets/LS3DCG.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1b2728385c1c196501c73822644a461ead309ec
--- /dev/null
+++ b/nets/LS3DCG.py
@@ -0,0 +1,414 @@
+'''
+Not exactly the same as the official repo, but the results are good.
+'''
+import sys
+import os
+
+from data_utils.lower_body import c_index_3d, c_index_6d
+
+sys.path.append(os.getcwd())
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+import math
+
+from nets.base import TrainWrapperBaseClass
+from nets.layers import SeqEncoder1D
+from losses import KeypointLoss, L1Loss, KLLoss
+from data_utils.utils import get_melspec, get_mfcc_psf, get_mfcc_ta
+from nets.utils import denormalize
+
+class Conv1d_tf(nn.Conv1d):
+ """
+ Conv1d with the padding behavior from TF
+ modified from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(Conv1d_tf, self).__init__(*args, **kwargs)
+ self.padding = kwargs.get("padding", "same")
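+        # any value other than "VALID" falls through to TF-style "SAME" padding in forward()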
+
+ def _compute_padding(self, input, dim):
+ input_size = input.size(dim + 2)
+ filter_size = self.weight.size(dim + 2)
+ effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
+ out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
+ total_padding = max(
+ 0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
+ )
+ additional_padding = int(total_padding % 2 != 0)
+
+ return additional_padding, total_padding
+
+ def forward(self, input):
+ if self.padding == "VALID":
+ return F.conv1d(
+ input,
+ self.weight,
+ self.bias,
+ self.stride,
+ padding=0,
+ dilation=self.dilation,
+ groups=self.groups,
+ )
+ rows_odd, padding_rows = self._compute_padding(input, dim=0)
+ if rows_odd:
+ input = F.pad(input, [0, rows_odd])
+
+ return F.conv1d(
+ input,
+ self.weight,
+ self.bias,
+ self.stride,
+ padding=(padding_rows // 2),
+ dilation=self.dilation,
+ groups=self.groups,
+ )
+
+
+def ConvNormRelu(in_channels, out_channels, type='1d', downsample=False, k=None, s=None, norm='bn', padding='valid'):
+ if k is None and s is None:
+ if not downsample:
+ k = 3
+ s = 1
+ else:
+ k = 4
+ s = 2
+
+ if type == '1d':
+ conv_block = Conv1d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
+ if norm == 'bn':
+ norm_block = nn.BatchNorm1d(out_channels)
+ elif norm == 'ln':
+ norm_block = nn.LayerNorm(out_channels)
+ elif type == '2d':
+        conv_block = Conv2d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)  # NOTE: Conv2d_tf is not defined in this file; only the '1d' branch is used here
+ norm_block = nn.BatchNorm2d(out_channels)
+ else:
+ assert False
+
+ return nn.Sequential(
+ conv_block,
+ norm_block,
+ nn.LeakyReLU(0.2, True)
+ )
+
+class Decoder(nn.Module):
+ def __init__(self, in_ch, out_ch):
+ super(Decoder, self).__init__()
+ self.up1 = nn.Sequential(
+ ConvNormRelu(in_ch // 2 + in_ch, in_ch // 2),
+ ConvNormRelu(in_ch // 2, in_ch // 2),
+ nn.Upsample(scale_factor=2, mode='nearest')
+ )
+ self.up2 = nn.Sequential(
+ ConvNormRelu(in_ch // 4 + in_ch // 2, in_ch // 4),
+ ConvNormRelu(in_ch // 4, in_ch // 4),
+ nn.Upsample(scale_factor=2, mode='nearest')
+ )
+ self.up3 = nn.Sequential(
+ ConvNormRelu(in_ch // 8 + in_ch // 4, in_ch // 8),
+ ConvNormRelu(in_ch // 8, in_ch // 8),
+ nn.Conv1d(in_ch // 8, out_ch, 1, 1)
+ )
+
+ def forward(self, x, x1, x2, x3):
+ x = F.interpolate(x, x3.shape[2])
+ x = torch.cat([x, x3], dim=1)
+ x = self.up1(x)
+ x = F.interpolate(x, x2.shape[2])
+ x = torch.cat([x, x2], dim=1)
+ x = self.up2(x)
+ x = F.interpolate(x, x1.shape[2])
+ x = torch.cat([x, x1], dim=1)
+ x = self.up3(x)
+ return x
+
+
+class EncoderDecoder(nn.Module):
+ def __init__(self, n_frames, each_dim):
+ super().__init__()
+ self.n_frames = n_frames
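+        # U-Net style 1D audio encoder shared by three decoders (face+expression, body, hand).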
+
+ self.down1 = nn.Sequential(
+ ConvNormRelu(64, 64, '1d', False),
+ ConvNormRelu(64, 128, '1d', False),
+ )
+ self.down2 = nn.Sequential(
+ ConvNormRelu(128, 128, '1d', False),
+ ConvNormRelu(128, 256, '1d', False),
+ )
+ self.down3 = nn.Sequential(
+ ConvNormRelu(256, 256, '1d', False),
+ ConvNormRelu(256, 512, '1d', False),
+ )
+ self.down4 = nn.Sequential(
+ ConvNormRelu(512, 512, '1d', False),
+ ConvNormRelu(512, 1024, '1d', False),
+ )
+
+ self.down = nn.MaxPool1d(kernel_size=2)
+ self.up = nn.Upsample(scale_factor=2, mode='nearest')
+
+ self.face_decoder = Decoder(1024, each_dim[0] + each_dim[3])
+ self.body_decoder = Decoder(1024, each_dim[1])
+ self.hand_decoder = Decoder(1024, each_dim[2])
+
+ def forward(self, spectrogram, time_steps=None):
+ if time_steps is None:
+ time_steps = self.n_frames
+
+ x1 = self.down1(spectrogram)
+ x = self.down(x1)
+ x2 = self.down2(x)
+ x = self.down(x2)
+ x3 = self.down3(x)
+ x = self.down(x3)
+ x = self.down4(x)
+ x = self.up(x)
+
+ face = self.face_decoder(x, x1, x2, x3)
+ body = self.body_decoder(x, x1, x2, x3)
+ hand = self.hand_decoder(x, x1, x2, x3)
+
+ return face, body, hand
+
+
+class Generator(nn.Module):
+ def __init__(self,
+ each_dim,
+ training=False,
+ device=None
+ ):
+ super().__init__()
+
+ self.training = training
+ self.device = device
+
+ self.encoderdecoder = EncoderDecoder(15, each_dim)
+
+ def forward(self, in_spec, time_steps=None):
+ if time_steps is not None:
+ self.gen_length = time_steps
+
+ face, body, hand = self.encoderdecoder(in_spec)
+ out = torch.cat([face, body, hand], dim=1)
+ out = out.transpose(1, 2)
+
+ return out
+
+
+class Discriminator(nn.Module):
+ def __init__(self, input_dim):
+ super().__init__()
+ self.net = nn.Sequential(
+ ConvNormRelu(input_dim, 128, '1d'),
+ ConvNormRelu(128, 256, '1d'),
+ nn.MaxPool1d(kernel_size=2),
+ ConvNormRelu(256, 256, '1d'),
+ ConvNormRelu(256, 512, '1d'),
+ nn.MaxPool1d(kernel_size=2),
+ ConvNormRelu(512, 512, '1d'),
+ ConvNormRelu(512, 1024, '1d'),
+ nn.MaxPool1d(kernel_size=2),
+ nn.Conv1d(1024, 1, 1, 1),
+ nn.Sigmoid()
+ )
+
+ def forward(self, x):
+ x = x.transpose(1, 2)
+
+ out = self.net(x)
+ return out
+
+
+class TrainWrapper(TrainWrapperBaseClass):
+ def __init__(self, args, config) -> None:
+ self.args = args
+ self.config = config
+ self.device = torch.device(self.args.gpu)
+ self.global_step = 0
+ self.convert_to_6d = self.config.Data.pose.convert_to_6d
+ self.init_params()
+
+ self.generator = Generator(
+ each_dim=self.each_dim,
+ training=not self.args.infer,
+ device=self.device,
+ ).to(self.device)
+ self.discriminator = Discriminator(
+ input_dim=self.each_dim[1] + self.each_dim[2] + 64
+ ).to(self.device)
+ if self.convert_to_6d:
+ self.c_index = c_index_6d
+ else:
+ self.c_index = c_index_3d
+ self.MSELoss = KeypointLoss().to(self.device)
+ self.L1Loss = L1Loss().to(self.device)
+ super().__init__(args, config)
+
+ def init_params(self):
+ scale = 1
+
+ global_orient = round(0 * scale)
+ leye_pose = reye_pose = round(0 * scale)
+ jaw_pose = round(3 * scale)
+ body_pose = round((63 - 24) * scale)
+ left_hand_pose = right_hand_pose = round(45 * scale)
+
+ expression = 100
+
+ b_j = 0
+ jaw_dim = jaw_pose
+ b_e = b_j + jaw_dim
+ eye_dim = leye_pose + reye_pose
+ b_b = b_e + eye_dim
+ body_dim = global_orient + body_pose
+ b_h = b_b + body_dim
+ hand_dim = left_hand_pose + right_hand_pose
+ b_f = b_h + hand_dim
+ face_dim = expression
+
+ self.dim_list = [b_j, b_e, b_b, b_h, b_f]
+ self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
+ self.pose = int(self.full_dim / round(3 * scale))
+ self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
+
+ def __call__(self, bat):
+ assert (not self.args.infer), "infer mode"
+ self.global_step += 1
+
+ loss_dict = {}
+
+ aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
+ expression = bat['expression'].to(self.device).to(torch.float32)
+ jaw = poses[:, :3, :]
+ poses = poses[:, self.c_index, :]
+
+ pred = self.generator(in_spec=aud)
+
+ D_loss, D_loss_dict = self.get_loss(
+ pred_poses=pred.detach(),
+ gt_poses=poses,
+ aud=aud,
+ mode='training_D',
+ )
+
+ self.discriminator_optimizer.zero_grad()
+ D_loss.backward()
+ self.discriminator_optimizer.step()
+
+ G_loss, G_loss_dict = self.get_loss(
+ pred_poses=pred,
+ gt_poses=poses,
+ aud=aud,
+ expression=expression,
+ jaw=jaw,
+ mode='training_G',
+ )
+ self.generator_optimizer.zero_grad()
+ G_loss.backward()
+ self.generator_optimizer.step()
+
+ total_loss = None
+ loss_dict = {}
+ for key in list(D_loss_dict.keys()) + list(G_loss_dict.keys()):
+ loss_dict[key] = G_loss_dict.get(key, 0) + D_loss_dict.get(key, 0)
+
+ return total_loss, loss_dict
+
+ def get_loss(self,
+ pred_poses,
+ gt_poses,
+ aud=None,
+ jaw=None,
+ expression=None,
+ mode='training_G',
+ ):
+ loss_dict = {}
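+        # generator output layout (last dim): [jaw 0:3 | expression 3:103 | body 103:142 | hands 142:232]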
+ aud = aud.transpose(1, 2)
+ gt_poses = gt_poses.transpose(1, 2)
+ gt_aud = torch.cat([gt_poses, aud], dim=2)
+ pred_aud = torch.cat([pred_poses[:, :, 103:], aud], dim=2)
+
+ if mode == 'training_D':
+ dis_real = self.discriminator(gt_aud)
+ dis_fake = self.discriminator(pred_aud)
+ dis_error = self.MSELoss(torch.ones_like(dis_real).to(self.device), dis_real) + self.MSELoss(
+ torch.zeros_like(dis_fake).to(self.device), dis_fake)
+ loss_dict['dis'] = dis_error
+
+ return dis_error, loss_dict
+ elif mode == 'training_G':
+ jaw_loss = self.L1Loss(pred_poses[:, :, :3], jaw.transpose(1, 2))
+ face_loss = self.MSELoss(pred_poses[:, :, 3:103], expression.transpose(1, 2))
+ body_loss = self.L1Loss(pred_poses[:, :, 103:142], gt_poses[:, :, :39])
+ hand_loss = self.L1Loss(pred_poses[:, :, 142:], gt_poses[:, :, 39:])
+ l1_loss = jaw_loss + face_loss + body_loss + hand_loss
+
+ dis_output = self.discriminator(pred_aud)
+ gen_error = self.MSELoss(torch.ones_like(dis_output).to(self.device), dis_output)
+ gen_loss = self.config.Train.weights.keypoint_loss_weight * l1_loss + self.config.Train.weights.gan_loss_weight * gen_error
+
+ loss_dict['gen'] = gen_error
+ loss_dict['jaw_loss'] = jaw_loss
+ loss_dict['face_loss'] = face_loss
+ loss_dict['body_loss'] = body_loss
+ loss_dict['hand_loss'] = hand_loss
+ return gen_loss, loss_dict
+ else:
+ raise ValueError(mode)
+
+ def infer_on_audio(self, aud_fn, fps=30, initial_pose=None, norm_stats=None, id=None, B=1, **kwargs):
+ output = []
+ assert self.args.infer, "train mode"
+ self.generator.eval()
+
+ if self.config.Data.pose.normalization:
+ assert norm_stats is not None
+ data_mean = norm_stats[0]
+ data_std = norm_stats[1]
+
+ pre_length = self.config.Data.pose.pre_pose_length
+ generate_length = self.config.Data.pose.generate_length
+ # assert pre_length == initial_pose.shape[-1]
+ # pre_poses = initial_pose.permute(0, 2, 1).to(self.device).to(torch.float32)
+ # B = pre_poses.shape[0]
+
+ aud_feat = get_mfcc_ta(aud_fn, sr=22000, fps=fps, smlpx=True, type='mfcc').transpose(1, 0)
+ num_poses_to_generate = aud_feat.shape[-1]
+ aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
+ aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)
+
+ with torch.no_grad():
+ pred_poses = self.generator(aud_feat)
+ pred_poses = pred_poses.cpu().numpy()
+ output = pred_poses.squeeze()
+
+ return output
+
+ def generate(self, aud, id):
+ self.generator.eval()
+ pred_poses = self.generator(aud)
+ return pred_poses
+
+
+if __name__ == '__main__':
+ from trainer.options import parse_args
+
+ parser = parse_args()
+ args = parser.parse_args(
+ ['--exp_name', '0', '--data_root', '0', '--speakers', '0', '--pre_pose_length', '4', '--generate_length', '64',
+ '--infer'])
+
+ generator = TrainWrapper(args)
+
+ aud_fn = '../sample_audio/jon.wav'
+ initial_pose = torch.randn(64, 108, 4)
+ norm_stats = (np.random.randn(108), np.random.randn(108))
+    output = generator.infer_on_audio(aud_fn, initial_pose=initial_pose, norm_stats=norm_stats)  # pass by keyword: the second positional parameter is fps
+
+ print(output.shape)
diff --git a/nets/__init__.py b/nets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0669d82d7506b314fa3fccd7dd412445fa0b37e1
--- /dev/null
+++ b/nets/__init__.py
@@ -0,0 +1,8 @@
+from .smplx_face import TrainWrapper as s2g_face
+from .smplx_body_vq import TrainWrapper as s2g_body_vq
+from .smplx_body_pixel import TrainWrapper as s2g_body_pixel
+from .body_ae import TrainWrapper as s2g_body_ae
+from .LS3DCG import TrainWrapper as LS3DCG
+from .base import TrainWrapperBaseClass
+
+from .utils import normalize, denormalize
\ No newline at end of file
diff --git a/nets/__pycache__/LS3DCG.cpython-37.pyc b/nets/__pycache__/LS3DCG.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b805df22b6a5908c15e9a8d45380c287c864df79
Binary files /dev/null and b/nets/__pycache__/LS3DCG.cpython-37.pyc differ
diff --git a/nets/__pycache__/__init__.cpython-37.pyc b/nets/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bd77b05f4e2503c9766d2581325594482c8ab30c
Binary files /dev/null and b/nets/__pycache__/__init__.cpython-37.pyc differ
diff --git a/nets/__pycache__/base.cpython-37.pyc b/nets/__pycache__/base.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f94093ffb32eb5d7a08ee7ffc694f1dd667a3134
Binary files /dev/null and b/nets/__pycache__/base.cpython-37.pyc differ
diff --git a/nets/__pycache__/body_ae.cpython-37.pyc b/nets/__pycache__/body_ae.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a42d9b5d9ae7795677e92931304111ded3f4c68d
Binary files /dev/null and b/nets/__pycache__/body_ae.cpython-37.pyc differ
diff --git a/nets/__pycache__/init_model.cpython-37.pyc b/nets/__pycache__/init_model.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..089d2e7ec5d5bbd73d1fff7cef7c792c378de03a
Binary files /dev/null and b/nets/__pycache__/init_model.cpython-37.pyc differ
diff --git a/nets/__pycache__/layers.cpython-37.pyc b/nets/__pycache__/layers.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b0df91833195fee22a25b90b578cac16c53bf11d
Binary files /dev/null and b/nets/__pycache__/layers.cpython-37.pyc differ
diff --git a/nets/__pycache__/smplx_body_pixel.cpython-37.pyc b/nets/__pycache__/smplx_body_pixel.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f9be6845fada6e59cc304238f4c9e847ea0255a8
Binary files /dev/null and b/nets/__pycache__/smplx_body_pixel.cpython-37.pyc differ
diff --git a/nets/__pycache__/smplx_body_vq.cpython-37.pyc b/nets/__pycache__/smplx_body_vq.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..480f25b32e812601156f4ab65d5d8e1ad6915704
Binary files /dev/null and b/nets/__pycache__/smplx_body_vq.cpython-37.pyc differ
diff --git a/nets/__pycache__/smplx_face.cpython-37.pyc b/nets/__pycache__/smplx_face.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..12138283f09af4e1e631e9b8704ab8a8b273f3de
Binary files /dev/null and b/nets/__pycache__/smplx_face.cpython-37.pyc differ
diff --git a/nets/__pycache__/utils.cpython-37.pyc b/nets/__pycache__/utils.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8092271dc39faad72b0c77283492a2f75a0067c8
Binary files /dev/null and b/nets/__pycache__/utils.cpython-37.pyc differ
diff --git a/nets/base.py b/nets/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..08c07caa27ba642dd5a48cebfcf51e4e79edd574
--- /dev/null
+++ b/nets/base.py
@@ -0,0 +1,89 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+class TrainWrapperBaseClass():
+ def __init__(self, args, config) -> None:
+ self.init_optimizer()
+
+ def init_optimizer(self) -> None:
+ print('using Adam')
+ self.generator_optimizer = optim.Adam(
+ self.generator.parameters(),
+ lr = self.config.Train.learning_rate.generator_learning_rate,
+ betas=[0.9, 0.999]
+ )
+ if self.discriminator is not None:
+ self.discriminator_optimizer = optim.Adam(
+ self.discriminator.parameters(),
+ lr = self.config.Train.learning_rate.discriminator_learning_rate,
+ betas=[0.9, 0.999]
+ )
+
+ def __call__(self, bat):
+ raise NotImplementedError
+
+ def get_loss(self, **kwargs):
+ raise NotImplementedError
+
+ def state_dict(self):
+ model_state = {
+ 'generator': self.generator.state_dict(),
+ 'generator_optim': self.generator_optimizer.state_dict(),
+ 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
+ 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
+ }
+ return model_state
+
+ def parameters(self):
+ return self.generator.parameters()
+
+ def load_state_dict(self, state_dict):
+ if 'generator' in state_dict:
+ self.generator.load_state_dict(state_dict['generator'])
+ else:
+ self.generator.load_state_dict(state_dict)
+
+ if 'generator_optim' in state_dict and self.generator_optimizer is not None:
+ self.generator_optimizer.load_state_dict(state_dict['generator_optim'])
+
+ if self.discriminator is not None:
+ self.discriminator.load_state_dict(state_dict['discriminator'])
+
+ if 'discriminator_optim' in state_dict and self.discriminator_optimizer is not None:
+ self.discriminator_optimizer.load_state_dict(state_dict['discriminator_optim'])
+
+ def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, **kwargs):
+ raise NotImplementedError
+
+ def init_params(self):
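+        # Per-part dimension bookkeeping (cumulative offsets in dim_list): jaw | eyes | body | hands | face(expression);
+        # scale doubles when axis-angle is converted to the 6D rotation representation.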
+ if self.config.Data.pose.convert_to_6d:
+ scale = 2
+ else:
+ scale = 1
+
+ global_orient = round(0 * scale)
+ leye_pose = reye_pose = round(0 * scale)
+ jaw_pose = round(0 * scale)
+ body_pose = round((63 - 24) * scale)
+ left_hand_pose = right_hand_pose = round(45 * scale)
+ if self.expression:
+ expression = 100
+ else:
+ expression = 0
+
+ b_j = 0
+ jaw_dim = jaw_pose
+ b_e = b_j + jaw_dim
+ eye_dim = leye_pose + reye_pose
+ b_b = b_e + eye_dim
+ body_dim = global_orient + body_pose
+ b_h = b_b + body_dim
+ hand_dim = left_hand_pose + right_hand_pose
+ b_f = b_h + hand_dim
+ face_dim = expression
+
+ self.dim_list = [b_j, b_e, b_b, b_h, b_f]
+ self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
+ self.pose = int(self.full_dim / round(3 * scale))
+ self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
\ No newline at end of file
diff --git a/nets/body_ae.py b/nets/body_ae.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a9f8bc0ee92f8410da71d711bb19dbcc254d1af
--- /dev/null
+++ b/nets/body_ae.py
@@ -0,0 +1,152 @@
+import os
+import sys
+
+sys.path.append(os.getcwd())
+
+from nets.base import TrainWrapperBaseClass
+from nets.spg.s2glayers import Discriminator as D_S2G
+from nets.spg.vqvae_1d import AE as s2g_body
+import torch
+import torch.optim as optim
+import torch.nn.functional as F
+
+from data_utils.lower_body import c_index, c_index_3d, c_index_6d
+
+
+def separate_aa(aa):
+ aa = aa[:, :, :].reshape(aa.shape[0], aa.shape[1], -1, 5)
+ axis = F.normalize(aa[:, :, :, :3], dim=-1)
+ angle = F.normalize(aa[:, :, :, 3:5], dim=-1)
+ return axis, angle
+
+
+class TrainWrapper(TrainWrapperBaseClass):
+ '''
+ a wrapper receving a batch from data_utils and calculate loss
+ '''
+
+ def __init__(self, args, config):
+ self.args = args
+ self.config = config
+ self.device = torch.device(self.args.gpu)
+ self.global_step = 0
+
+ self.gan = False
+ self.convert_to_6d = self.config.Data.pose.convert_to_6d
+ self.preleng = self.config.Data.pose.pre_pose_length
+ self.expression = self.config.Data.pose.expression
+ self.epoch = 0
+ self.init_params()
+ self.num_classes = 4
+ self.g = s2g_body(self.each_dim[1] + self.each_dim[2], embedding_dim=64, num_embeddings=0,
+ num_hiddens=1024, num_residual_layers=2, num_residual_hiddens=512).to(self.device)
+ if self.gan:
+ self.discriminator = D_S2G(
+ pose_dim=110 + 64, pose=self.pose
+ ).to(self.device)
+ else:
+ self.discriminator = None
+
+ if self.convert_to_6d:
+ self.c_index = c_index_6d
+ else:
+ self.c_index = c_index_3d
+
+ super().__init__(args, config)
+
+ def init_optimizer(self):
+
+ self.g_optimizer = optim.Adam(
+ self.g.parameters(),
+ lr=self.config.Train.learning_rate.generator_learning_rate,
+ betas=[0.9, 0.999]
+ )
+
+ def state_dict(self):
+ model_state = {
+ 'g': self.g.state_dict(),
+ 'g_optim': self.g_optimizer.state_dict(),
+ 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
+ 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
+ }
+ return model_state
+
+
+ def __call__(self, bat):
+ # assert (not self.args.infer), "infer mode"
+ self.global_step += 1
+
+ total_loss = None
+ loss_dict = {}
+
+ aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
+
+ # id = bat['speaker'].to(self.device) - 20
+ # id = F.one_hot(id, self.num_classes)
+
+ poses = poses[:, self.c_index, :]
+ gt_poses = poses[:, :, self.preleng:].permute(0, 2, 1)
+
+ loss = 0
+ loss_dict, loss = self.vq_train(gt_poses[:, :], 'g', self.g, loss_dict, loss)
+
+ return total_loss, loss_dict
+
+ def vq_train(self, gt, name, model, dict, total_loss, pre=None):
+ x_recon = model(gt_poses=gt, pre_state=pre)
+ loss, loss_dict = self.get_loss(pred_poses=x_recon, gt_poses=gt, pre=pre)
+ # total_loss = total_loss + loss
+
+ if name == 'g':
+ optimizer_name = 'g_optimizer'
+
+ optimizer = getattr(self, optimizer_name)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ for key in list(loss_dict.keys()):
+ dict[name + key] = loss_dict.get(key, 0).item()
+ return dict, total_loss
+
+ def get_loss(self,
+ pred_poses,
+ gt_poses,
+ pre=None
+ ):
+ loss_dict = {}
+
+
+ rec_loss = torch.mean(torch.abs(pred_poses - gt_poses))
+ v_pr = pred_poses[:, 1:] - pred_poses[:, :-1]
+ v_gt = gt_poses[:, 1:] - gt_poses[:, :-1]
+ velocity_loss = torch.mean(torch.abs(v_pr - v_gt))
+
+ if pre is None:
+ f0_vel = 0
+ else:
+ v0_pr = pred_poses[:, 0] - pre[:, -1]
+ v0_gt = gt_poses[:, 0] - pre[:, -1]
+ f0_vel = torch.mean(torch.abs(v0_pr - v0_gt))
+
+ gen_loss = rec_loss + velocity_loss + f0_vel
+
+ loss_dict['rec_loss'] = rec_loss
+ loss_dict['velocity_loss'] = velocity_loss
+ # loss_dict['e_q_loss'] = e_q_loss
+ if pre is not None:
+ loss_dict['f0_vel'] = f0_vel
+
+ return gen_loss, loss_dict
+
+ def load_state_dict(self, state_dict):
+ self.g.load_state_dict(state_dict['g'])
+
+ def extract(self, x):
+ self.g.eval()
+ if x.shape[2] > self.full_dim:
+ if x.shape[2] == 239:
+ x = x[:, :, 102:]
+ x = x[:, :, self.c_index]
+ feat = self.g.encode(x)
+ return feat.transpose(1, 2), x
diff --git a/nets/init_model.py b/nets/init_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a849d9e49afc54048979ebc12c743c4a8b8620a7
--- /dev/null
+++ b/nets/init_model.py
@@ -0,0 +1,35 @@
+from nets import *
+
+
+def init_model(model_name, args, config):
+
+ if model_name == 's2g_face':
+ generator = s2g_face(
+ args,
+ config,
+ )
+ elif model_name == 's2g_body_vq':
+ generator = s2g_body_vq(
+ args,
+ config,
+ )
+ elif model_name == 's2g_body_pixel':
+ generator = s2g_body_pixel(
+ args,
+ config,
+ )
+ elif model_name == 's2g_body_ae':
+ generator = s2g_body_ae(
+ args,
+ config,
+ )
+ elif model_name == 's2g_LS3DCG':
+ generator = LS3DCG(
+ args,
+ config,
+ )
+ else:
+ raise ValueError
+ return generator
+
+
diff --git a/nets/layers.py b/nets/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..79251b42b6e0fe839ec04dc38472ef36165208ac
--- /dev/null
+++ b/nets/layers.py
@@ -0,0 +1,1052 @@
+import os
+import sys
+
+sys.path.append(os.getcwd())
+
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+# TODO: be aware of the actual network structures
+
+def get_log(x):
+ log = 0
+ while x > 1:
+ if x % 2 == 0:
+ x = x // 2
+ log += 1
+ else:
+ raise ValueError('x is not a power of 2')
+
+ return log
+
+
+class ConvNormRelu(nn.Module):
+ '''
+ (B,C_in,H,W) -> (B, C_out, H, W)
+ there exist some kernel size that makes the result is not H/s
+ #TODO: there might some problems with residual
+ '''
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ type='1d',
+ leaky=False,
+ downsample=False,
+ kernel_size=None,
+ stride=None,
+ padding=None,
+ p=0,
+ groups=1,
+ residual=False,
+ norm='bn'):
+ '''
+ conv-bn-relu
+ '''
+ super(ConvNormRelu, self).__init__()
+ self.residual = residual
+ self.norm_type = norm
+ # kernel_size = k
+ # stride = s
+
+ if kernel_size is None and stride is None:
+ if not downsample:
+ kernel_size = 3
+ stride = 1
+ else:
+ kernel_size = 4
+ stride = 2
+
+ if padding is None:
+ if isinstance(kernel_size, int) and isinstance(stride, tuple):
+ padding = tuple(int((kernel_size - st) / 2) for st in stride)
+ elif isinstance(kernel_size, tuple) and isinstance(stride, int):
+ padding = tuple(int((ks - stride) / 2) for ks in kernel_size)
+ elif isinstance(kernel_size, tuple) and isinstance(stride, tuple):
+ padding = tuple(int((ks - st) / 2) for ks, st in zip(kernel_size, stride))
+ else:
+ padding = int((kernel_size - stride) / 2)
+
+ if self.residual:
+ if downsample:
+ if type == '1d':
+ self.residual_layer = nn.Sequential(
+ nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding
+ )
+ )
+ elif type == '2d':
+ self.residual_layer = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding
+ )
+ )
+ else:
+ if in_channels == out_channels:
+ self.residual_layer = nn.Identity()
+ else:
+ if type == '1d':
+ self.residual_layer = nn.Sequential(
+ nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding
+ )
+ )
+ elif type == '2d':
+ self.residual_layer = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding
+ )
+ )
+
+ in_channels = in_channels * groups
+ out_channels = out_channels * groups
+ if type == '1d':
+ self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=kernel_size, stride=stride, padding=padding,
+ groups=groups)
+ self.norm = nn.BatchNorm1d(out_channels)
+ self.dropout = nn.Dropout(p=p)
+ elif type == '2d':
+ self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=kernel_size, stride=stride, padding=padding,
+ groups=groups)
+ self.norm = nn.BatchNorm2d(out_channels)
+ self.dropout = nn.Dropout2d(p=p)
+ if norm == 'gn':
+ self.norm = nn.GroupNorm(2, out_channels)
+ elif norm == 'ln':
+ self.norm = nn.LayerNorm(out_channels)
+ if leaky:
+ self.relu = nn.LeakyReLU(negative_slope=0.2)
+ else:
+ self.relu = nn.ReLU()
+
+ def forward(self, x, **kwargs):
+ if self.norm_type == 'ln':
+ out = self.dropout(self.conv(x))
+ out = self.norm(out.transpose(1,2)).transpose(1,2)
+ else:
+ out = self.norm(self.dropout(self.conv(x)))
+ if self.residual:
+ residual = self.residual_layer(x)
+ out += residual
+ return self.relu(out)
+
+
+class UNet1D(nn.Module):
+ def __init__(self,
+ input_channels,
+ output_channels,
+ max_depth=5,
+ kernel_size=None,
+ stride=None,
+ p=0,
+ groups=1):
+ super(UNet1D, self).__init__()
+ self.pre_downsampling_conv = nn.ModuleList([])
+ self.conv1 = nn.ModuleList([])
+ self.conv2 = nn.ModuleList([])
+ self.upconv = nn.Upsample(scale_factor=2, mode='nearest')
+ self.max_depth = max_depth
+ self.groups = groups
+
+ self.pre_downsampling_conv.append(ConvNormRelu(input_channels, output_channels,
+ type='1d', leaky=True, downsample=False,
+ kernel_size=kernel_size, stride=stride, p=p, groups=groups))
+ self.pre_downsampling_conv.append(ConvNormRelu(output_channels, output_channels,
+ type='1d', leaky=True, downsample=False,
+ kernel_size=kernel_size, stride=stride, p=p, groups=groups))
+
+ for i in range(self.max_depth):
+ self.conv1.append(ConvNormRelu(output_channels, output_channels,
+ type='1d', leaky=True, downsample=True,
+ kernel_size=kernel_size, stride=stride, p=p, groups=groups))
+
+ for i in range(self.max_depth):
+ self.conv2.append(ConvNormRelu(output_channels, output_channels,
+ type='1d', leaky=True, downsample=False,
+ kernel_size=kernel_size, stride=stride, p=p, groups=groups))
+
+ def forward(self, x):
+
+ input_size = x.shape[-1]
+
+        assert get_log(input_size) >= self.max_depth, \
+            'num_frames must be a power of 2 whose exponent is at least max_depth'
+
+ x = nn.Sequential(*self.pre_downsampling_conv)(x)
+
+ residuals = []
+ residuals.append(x)
+ for i, conv1 in enumerate(self.conv1):
+ x = conv1(x)
+ if i < self.max_depth - 1:
+ residuals.append(x)
+
+ for i, conv2 in enumerate(self.conv2):
+ x = self.upconv(x) + residuals[self.max_depth - i - 1]
+ x = conv2(x)
+
+ return x
+
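+# UNet1D usage sketch (illustrative): the temporal length must be a power of 2 whose
+# exponent is at least max_depth, so that max_depth halvings still reach length >= 1.
+#   unet = UNet1D(input_channels=64, output_channels=128, max_depth=5)
+#   y = unet(torch.randn(4, 64, 32))   # 32 = 2**5 -> bottleneck length 1 -> output (4, 128, 32)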
+
+class UNet2D(nn.Module):
+ def __init__(self):
+ super(UNet2D, self).__init__()
+        raise NotImplementedError('2D UNet is not implemented')
+
+
+class AudioPoseEncoder1D(nn.Module):
+ '''
+ (B, C, T) -> (B, C*2, T) -> ... -> (B, C_out, T)
+ '''
+
+ def __init__(self,
+ C_in,
+ C_out,
+ kernel_size=None,
+ stride=None,
+ min_layer_nums=None
+ ):
+ super(AudioPoseEncoder1D, self).__init__()
+ self.C_in = C_in
+ self.C_out = C_out
+
+ conv_layers = nn.ModuleList([])
+ cur_C = C_in
+ num_layers = 0
+ while cur_C < self.C_out:
+ conv_layers.append(ConvNormRelu(
+ in_channels=cur_C,
+ out_channels=cur_C * 2,
+ kernel_size=kernel_size,
+ stride=stride
+ ))
+ cur_C *= 2
+ num_layers += 1
+
+        if (cur_C != C_out) or (min_layer_nums is not None and num_layers < min_layer_nums):
+            while (cur_C != C_out) or (min_layer_nums is not None and num_layers < min_layer_nums):
+ conv_layers.append(ConvNormRelu(
+ in_channels=cur_C,
+ out_channels=C_out,
+ kernel_size=kernel_size,
+ stride=stride
+ ))
+ num_layers += 1
+ cur_C = C_out
+
+ self.conv_layers = nn.Sequential(*conv_layers)
+
+ def forward(self, x):
+ '''
+ x: (B, C, T)
+ '''
+ x = self.conv_layers(x)
+ return x
+
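+# AudioPoseEncoder1D sketch (illustrative): channels are doubled until C_out is reached
+# while the temporal length is preserved (default k=3, s=1 convolutions).
+#   enc = AudioPoseEncoder1D(C_in=64, C_out=256)
+#   y = enc(torch.randn(8, 64, 25))   # 64 -> 128 -> 256 channels, output (8, 256, 25)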
+
+class AudioPoseEncoder2D(nn.Module):
+ '''
+ (B, C, T) -> (B, 1, C, T) -> ... -> (B, C_out, T)
+ '''
+
+ def __init__(self):
+ raise NotImplementedError
+
+
+class AudioPoseEncoderRNN(nn.Module):
+ '''
+ (B, C, T)->(B, T, C)->(B, T, C_out)->(B, C_out, T)
+ '''
+
+ def __init__(self,
+ C_in,
+ hidden_size,
+ num_layers,
+ rnn_cell='gru',
+ bidirectional=False
+ ):
+ super(AudioPoseEncoderRNN, self).__init__()
+ if rnn_cell == 'gru':
+ self.cell = nn.GRU(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
+ bidirectional=bidirectional)
+ elif rnn_cell == 'lstm':
+ self.cell = nn.LSTM(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
+ bidirectional=bidirectional)
+ else:
+ raise ValueError('invalid rnn cell:%s' % (rnn_cell))
+
+ def forward(self, x, state=None):
+
+ x = x.permute(0, 2, 1)
+ x, state = self.cell(x, state)
+ x = x.permute(0, 2, 1)
+
+ return x
+
+
+class AudioPoseEncoderGraph(nn.Module):
+ '''
+ (B, C, T)->(B, 2, V, T)->(B, 2, T, V)->(B, D, T, V)
+ '''
+
+ def __init__(self,
+                 layers_config,  # expected to be a list of (C_in, C_out, kernel_size) tuples
+ A, # adjacent matrix (num_parts, V, V)
+ residual,
+ local_bn=False,
+ share_weights=False
+ ) -> None:
+ super().__init__()
+ self.A = A
+ self.num_joints = A.shape[1]
+ self.num_parts = A.shape[0]
+ self.C_in = layers_config[0][0]
+ self.C_out = layers_config[-1][1]
+
+ self.conv_layers = nn.ModuleList([
+ GraphConvNormRelu(
+ C_in=c_in,
+ C_out=c_out,
+ A=self.A,
+ residual=residual,
+ local_bn=local_bn,
+ kernel_size=k,
+ share_weights=share_weights
+ ) for (c_in, c_out, k) in layers_config
+ ])
+
+ self.conv_layers = nn.Sequential(*self.conv_layers)
+
+ def forward(self, x):
+ '''
+ x: (B, C, T), C should be num_joints*D
+ output: (B, D, T, V)
+ '''
+ B, C, T = x.shape
+        x = x.view(B, self.num_joints, self.C_in, T)  # (B, V, D, T); D is the per-joint feature dim, note that V comes first here
+ x = x.permute(0, 2, 3, 1) # (B, D, T, V)
+ assert x.shape[1] == self.C_in
+
+ x_conved = self.conv_layers(x)
+
+ # x_conved = x_conved.permute(0, 3, 1, 2).contiguous().view(B, self.C_out*self.num_joints, T)#(B, V*C_out, T)
+
+ return x_conved
+
+
+class SeqEncoder2D(nn.Module):
+    '''
+    seq encoder: encodes a sequence into a single vector
+    (B, C, T) -> (B, C_in, T, V) -> (B, 32, T, V) -> ... -> (B, C_out)
+    '''
+
+ def __init__(self,
+ C_in, # should be 2
+ T_in,
+ C_out,
+ num_joints,
+ min_layer_num=None,
+ residual=False
+ ):
+ super(SeqEncoder2D, self).__init__()
+ self.C_in = C_in
+ self.C_out = C_out
+ self.T_in = T_in
+ self.num_joints = num_joints
+
+ conv_layers = nn.ModuleList([])
+ conv_layers.append(ConvNormRelu(
+ in_channels=C_in,
+ out_channels=32,
+ type='2d',
+ residual=residual
+ ))
+
+ cur_C = 32
+ cur_H = T_in
+ cur_W = num_joints
+ num_layers = 1
+ while (cur_C < C_out) or (cur_H > 1) or (cur_W > 1):
+ ks = [3, 3]
+ st = [1, 1]
+
+ if cur_H > 1:
+ if cur_H > 4:
+ ks[0] = 4
+ st[0] = 2
+ else:
+ ks[0] = cur_H
+ st[0] = cur_H
+ if cur_W > 1:
+ if cur_W > 4:
+ ks[1] = 4
+ st[1] = 2
+ else:
+ ks[1] = cur_W
+ st[1] = cur_W
+
+ conv_layers.append(ConvNormRelu(
+ in_channels=cur_C,
+ out_channels=min(C_out, cur_C * 2),
+ type='2d',
+ kernel_size=tuple(ks),
+ stride=tuple(st),
+ residual=residual
+ ))
+ cur_C = min(cur_C * 2, C_out)
+ if cur_H > 1:
+ if cur_H > 4:
+ cur_H //= 2
+ else:
+ cur_H = 1
+ if cur_W > 1:
+ if cur_W > 4:
+ cur_W //= 2
+ else:
+ cur_W = 1
+ num_layers += 1
+
+ if min_layer_num is not None and (num_layers < min_layer_num):
+ while num_layers < min_layer_num:
+ conv_layers.append(ConvNormRelu(
+ in_channels=C_out,
+ out_channels=C_out,
+ type='2d',
+ kernel_size=1,
+ stride=1,
+ residual=residual
+ ))
+ num_layers += 1
+
+ self.conv_layers = nn.Sequential(*conv_layers)
+ self.num_layers = num_layers
+
+ def forward(self, x):
+ B, C, T = x.shape
+ x = x.view(B, self.num_joints, self.C_in, T) # (B, V, D, T) V in front
+ x = x.permute(0, 2, 3, 1) # (B, D, T, V)
+ assert x.shape[1] == self.C_in and x.shape[-1] == self.num_joints
+
+ x = self.conv_layers(x)
+ return x.squeeze()
+
+
+class SeqEncoder1D(nn.Module):
+ '''
+ (B, C, T)->(B, D)
+ '''
+
+ def __init__(self,
+ C_in,
+ C_out,
+ T_in,
+ min_layer_nums=None
+ ):
+ super(SeqEncoder1D, self).__init__()
+ conv_layers = nn.ModuleList([])
+ cur_C = C_in
+ cur_T = T_in
+ self.num_layers = 0
+ while (cur_C < C_out) or (cur_T > 1):
+ ks = 3
+ st = 1
+ if cur_T > 1:
+ if cur_T > 4:
+ ks = 4
+ st = 2
+ else:
+ ks = cur_T
+ st = cur_T
+
+ conv_layers.append(ConvNormRelu(
+ in_channels=cur_C,
+ out_channels=min(C_out, cur_C * 2),
+ type='1d',
+ kernel_size=ks,
+ stride=st
+ ))
+ cur_C = min(cur_C * 2, C_out)
+ if cur_T > 1:
+ if cur_T > 4:
+ cur_T = cur_T // 2
+ else:
+ cur_T = 1
+ self.num_layers += 1
+
+ if min_layer_nums is not None and (self.num_layers < min_layer_nums):
+ while self.num_layers < min_layer_nums:
+ conv_layers.append(ConvNormRelu(
+ in_channels=C_out,
+ out_channels=C_out,
+ type='1d',
+ kernel_size=1,
+ stride=1
+ ))
+ self.num_layers += 1
+ self.conv_layers = nn.Sequential(*conv_layers)
+
+ def forward(self, x):
+ x = self.conv_layers(x)
+ return x.squeeze()
+
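+# SeqEncoder1D sketch (illustrative): strided 1d convs shrink T to 1 while the channel
+# count grows to C_out, yielding one vector per sequence.
+#   enc = SeqEncoder1D(C_in=64, C_out=256, T_in=16)
+#   z = enc(torch.randn(8, 64, 16))   # (8, 64, 16) -> (8, 128, 8) -> (8, 256, 4) -> (8, 256, 1) -> (8, 256)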
+
+class SeqEncoderRNN(nn.Module):
+ '''
+ (B, C, T) -> (B, T, C) -> (B, D)
+ LSTM/GRU-FC
+ '''
+
+ def __init__(self,
+ hidden_size,
+ in_size,
+ num_rnn_layers,
+ rnn_cell='gru',
+ bidirectional=False
+ ):
+ super(SeqEncoderRNN, self).__init__()
+ self.hidden_size = hidden_size
+ self.in_size = in_size
+ self.num_rnn_layers = num_rnn_layers
+ self.bidirectional = bidirectional
+
+ if rnn_cell == 'gru':
+ self.cell = nn.GRU(input_size=self.in_size, hidden_size=self.hidden_size, num_layers=self.num_rnn_layers,
+ batch_first=True, bidirectional=bidirectional)
+ elif rnn_cell == 'lstm':
+ self.cell = nn.LSTM(input_size=self.in_size, hidden_size=self.hidden_size, num_layers=self.num_rnn_layers,
+ batch_first=True, bidirectional=bidirectional)
+
+ def forward(self, x, state=None):
+
+ x = x.permute(0, 2, 1)
+ B, T, C = x.shape
+ x, _ = self.cell(x, state)
+ if self.bidirectional:
+ out = torch.cat([x[:, -1, :self.hidden_size], x[:, 0, self.hidden_size:]], dim=-1)
+ else:
+ out = x[:, -1, :]
+ assert out.shape[0] == B
+ return out
+
+
+class SeqEncoderGraph(nn.Module):
+    '''
+    (B, C, T) -> (B, embedding_size)
+    '''
+
+ def __init__(self,
+ embedding_size,
+ layer_configs,
+ residual,
+ local_bn,
+ A,
+ T,
+ share_weights=False
+ ) -> None:
+ super().__init__()
+
+ self.C_in = layer_configs[0][0]
+ self.C_out = embedding_size
+
+ self.num_joints = A.shape[1]
+
+ self.graph_encoder = AudioPoseEncoderGraph(
+ layers_config=layer_configs,
+ A=A,
+ residual=residual,
+ local_bn=local_bn,
+ share_weights=share_weights
+ )
+
+ cur_C = layer_configs[-1][1]
+ self.spatial_pool = ConvNormRelu(
+ in_channels=cur_C,
+ out_channels=cur_C,
+ type='2d',
+ kernel_size=(1, self.num_joints),
+ stride=(1, 1),
+ padding=(0, 0)
+ )
+
+ temporal_pool = nn.ModuleList([])
+ cur_H = T
+ num_layers = 0
+ self.temporal_conv_info = []
+ while cur_C < self.C_out or cur_H > 1:
+ self.temporal_conv_info.append(cur_C)
+ ks = [3, 1]
+ st = [1, 1]
+
+ if cur_H > 1:
+ if cur_H > 4:
+ ks[0] = 4
+ st[0] = 2
+ else:
+ ks[0] = cur_H
+ st[0] = cur_H
+
+ temporal_pool.append(ConvNormRelu(
+ in_channels=cur_C,
+ out_channels=min(self.C_out, cur_C * 2),
+ type='2d',
+ kernel_size=tuple(ks),
+ stride=tuple(st)
+ ))
+ cur_C = min(cur_C * 2, self.C_out)
+
+ if cur_H > 1:
+ if cur_H > 4:
+ cur_H //= 2
+ else:
+ cur_H = 1
+
+ num_layers += 1
+
+ self.temporal_pool = nn.Sequential(*temporal_pool)
+ print("graph seq encoder info: temporal pool:", self.temporal_conv_info)
+ self.num_layers = num_layers
+ # need fc?
+
+ def forward(self, x):
+ '''
+ x: (B, C, T)
+ '''
+ B, C, T = x.shape
+ x = self.graph_encoder(x)
+ x = self.spatial_pool(x)
+ x = self.temporal_pool(x)
+ x = x.view(B, self.C_out)
+
+ return x
+
+
+class SeqDecoder2D(nn.Module):
+ '''
+ (B, D)->(B, D, 1, 1)->(B, C_out, C, T)->(B, C_out, T)
+ '''
+
+ def __init__(self):
+ super(SeqDecoder2D, self).__init__()
+ raise NotImplementedError
+
+
+class SeqDecoder1D(nn.Module):
+ '''
+ (B, D)->(B, D, 1)->...->(B, C_out, T)
+ '''
+
+ def __init__(self,
+ D_in,
+ C_out,
+ T_out,
+ min_layer_num=None
+ ):
+ super(SeqDecoder1D, self).__init__()
+ self.T_out = T_out
+ self.min_layer_num = min_layer_num
+
+ cur_t = 1
+
+ self.pre_conv = ConvNormRelu(
+ in_channels=D_in,
+ out_channels=C_out,
+ type='1d'
+ )
+ self.num_layers = 1
+ self.upconv = nn.Upsample(scale_factor=2, mode='nearest')
+ self.conv_layers = nn.ModuleList([])
+ cur_t *= 2
+ while cur_t <= T_out:
+ self.conv_layers.append(ConvNormRelu(
+ in_channels=C_out,
+ out_channels=C_out,
+ type='1d'
+ ))
+ cur_t *= 2
+ self.num_layers += 1
+
+ post_conv = nn.ModuleList([ConvNormRelu(
+ in_channels=C_out,
+ out_channels=C_out,
+ type='1d'
+ )])
+ self.num_layers += 1
+ if min_layer_num is not None and self.num_layers < min_layer_num:
+ while self.num_layers < min_layer_num:
+ post_conv.append(ConvNormRelu(
+ in_channels=C_out,
+ out_channels=C_out,
+ type='1d'
+ ))
+ self.num_layers += 1
+ self.post_conv = nn.Sequential(*post_conv)
+
+ def forward(self, x):
+
+ x = x.unsqueeze(-1)
+ x = self.pre_conv(x)
+ for conv in self.conv_layers:
+ x = self.upconv(x)
+ x = conv(x)
+
+ x = torch.nn.functional.interpolate(x, size=self.T_out, mode='nearest')
+ x = self.post_conv(x)
+ return x
+
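+# SeqDecoder1D sketch (illustrative): the latent vector is treated as a length-1 sequence,
+# repeatedly upsampled (nearest, x2) until it covers T_out, then interpolated to exactly T_out.
+#   dec = SeqDecoder1D(D_in=256, C_out=64, T_out=25)
+#   x = dec(torch.randn(8, 256))   # (8, 256) -> (8, 256, 1) -> ... -> (8, 64, 25)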
+
+class SeqDecoderRNN(nn.Module):
+ '''
+ (B, D)->(B, C_out, T)
+ '''
+
+ def __init__(self,
+ hidden_size,
+ C_out,
+ T_out,
+ num_layers,
+ rnn_cell='gru'
+ ):
+ super(SeqDecoderRNN, self).__init__()
+ self.num_steps = T_out
+ if rnn_cell == 'gru':
+ self.cell = nn.GRU(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
+ bidirectional=False)
+ elif rnn_cell == 'lstm':
+ self.cell = nn.LSTM(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
+ bidirectional=False)
+ else:
+ raise ValueError('invalid rnn cell:%s' % (rnn_cell))
+
+ self.fc = nn.Linear(hidden_size, C_out)
+
+ def forward(self, hidden, frame_0):
+ frame_0 = frame_0.permute(0, 2, 1)
+ dec_input = frame_0
+ outputs = []
+ for i in range(self.num_steps):
+ frame_out, hidden = self.cell(dec_input, hidden)
+ frame_out = self.fc(frame_out)
+ dec_input = frame_out
+ outputs.append(frame_out)
+ output = torch.cat(outputs, dim=1)
+ return output.permute(0, 2, 1)
+
+
+class SeqTranslator2D(nn.Module):
+    '''
+    (B, C_in, T_in) -> (B, 1, C_in, T_in) -> ... -> (B, C_out, T_out)
+    '''
+
+ def __init__(self,
+ C_in=64,
+ C_out=108,
+ T_in=75,
+ T_out=25,
+ residual=True
+ ):
+ super(SeqTranslator2D, self).__init__()
+        print("Warning: SeqTranslator2D uses hard-coded layer shapes (expects C_in=64, T_in=75 -> C_out=108, T_out=25)")
+ self.C_in = C_in
+ self.C_out = C_out
+ self.T_in = T_in
+ self.T_out = T_out
+ self.residual = residual
+
+ self.conv_layers = nn.Sequential(
+ ConvNormRelu(1, 32, '2d', kernel_size=5, stride=1),
+ ConvNormRelu(32, 32, '2d', kernel_size=5, stride=1, residual=self.residual),
+ ConvNormRelu(32, 32, '2d', kernel_size=5, stride=1, residual=self.residual),
+
+ ConvNormRelu(32, 64, '2d', kernel_size=5, stride=(4, 3)),
+ ConvNormRelu(64, 64, '2d', kernel_size=5, stride=1, residual=self.residual),
+ ConvNormRelu(64, 64, '2d', kernel_size=5, stride=1, residual=self.residual),
+
+ ConvNormRelu(64, 128, '2d', kernel_size=5, stride=(4, 1)),
+ ConvNormRelu(128, 108, '2d', kernel_size=3, stride=(4, 1)),
+ ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1, residual=self.residual),
+
+ ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1, residual=self.residual),
+ ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1),
+ )
+
+ def forward(self, x):
+ assert len(x.shape) == 3 and x.shape[1] == self.C_in and x.shape[2] == self.T_in
+ x = x.view(x.shape[0], 1, x.shape[1], x.shape[2])
+ x = self.conv_layers(x)
+ x = x.squeeze(2)
+ return x
+
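+# SeqTranslator2D sketch (illustrative; shapes follow the hard-coded layers above):
+#   trans = SeqTranslator2D()             # C_in=64, T_in=75 -> C_out=108, T_out=25
+#   y = trans(torch.randn(8, 64, 75))     # -> (8, 108, 25)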
+
+class SeqTranslator1D(nn.Module):
+ '''
+ (B, C, T)->(B, C_out, T)
+ '''
+
+ def __init__(self,
+ C_in,
+ C_out,
+ kernel_size=None,
+ stride=None,
+ min_layers_num=None,
+ residual=True,
+ norm='bn'
+ ):
+ super(SeqTranslator1D, self).__init__()
+
+ conv_layers = nn.ModuleList([])
+ conv_layers.append(ConvNormRelu(
+ in_channels=C_in,
+ out_channels=C_out,
+ type='1d',
+ kernel_size=kernel_size,
+ stride=stride,
+ residual=residual,
+ norm=norm
+ ))
+ self.num_layers = 1
+ if min_layers_num is not None and self.num_layers < min_layers_num:
+ while self.num_layers < min_layers_num:
+ conv_layers.append(ConvNormRelu(
+ in_channels=C_out,
+ out_channels=C_out,
+ type='1d',
+ kernel_size=kernel_size,
+ stride=stride,
+ residual=residual,
+ norm=norm
+ ))
+ self.num_layers += 1
+ self.conv_layers = nn.Sequential(*conv_layers)
+
+ def forward(self, x):
+ return self.conv_layers(x)
+
+
+class SeqTranslatorRNN(nn.Module):
+ '''
+ (B, C, T)->(B, C_out, T)
+ LSTM-FC
+ '''
+
+ def __init__(self,
+ C_in,
+ C_out,
+ hidden_size,
+ num_layers,
+ rnn_cell='gru'
+ ):
+ super(SeqTranslatorRNN, self).__init__()
+
+ if rnn_cell == 'gru':
+ self.enc_cell = nn.GRU(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
+ bidirectional=False)
+ self.dec_cell = nn.GRU(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
+ bidirectional=False)
+ elif rnn_cell == 'lstm':
+ self.enc_cell = nn.LSTM(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
+ bidirectional=False)
+ self.dec_cell = nn.LSTM(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
+ bidirectional=False)
+ else:
+ raise ValueError('invalid rnn cell:%s' % (rnn_cell))
+
+ self.fc = nn.Linear(hidden_size, C_out)
+
+ def forward(self, x, frame_0):
+
+ num_steps = x.shape[-1]
+ x = x.permute(0, 2, 1)
+ frame_0 = frame_0.permute(0, 2, 1)
+ _, hidden = self.enc_cell(x, None)
+
+ outputs = []
+ for i in range(num_steps):
+ inputs = frame_0
+ output_frame, hidden = self.dec_cell(inputs, hidden)
+ output_frame = self.fc(output_frame)
+ frame_0 = output_frame
+ outputs.append(output_frame)
+ outputs = torch.cat(outputs, dim=1)
+ return outputs.permute(0, 2, 1)
+
+
+class ResBlock(nn.Module):
+ def __init__(self,
+ input_dim,
+ fc_dim,
+ afn,
+ nfn
+ ):
+ '''
+ afn: activation fn
+ nfn: normalization fn
+ '''
+ super(ResBlock, self).__init__()
+
+ self.input_dim = input_dim
+ self.fc_dim = fc_dim
+ self.afn = afn
+ self.nfn = nfn
+
+        if self.afn != 'relu':
+            raise ValueError('ResBlock only supports the relu activation')
+
+        if self.nfn == 'layer_norm':
+            raise ValueError('layer_norm is not supported in ResBlock')
+
+ self.layers = nn.Sequential(
+ nn.Linear(self.input_dim, self.fc_dim // 2),
+ nn.ReLU(),
+ nn.Linear(self.fc_dim // 2, self.fc_dim // 2),
+ nn.ReLU(),
+ nn.Linear(self.fc_dim // 2, self.fc_dim),
+ nn.ReLU()
+ )
+
+ self.shortcut_layer = nn.Sequential(
+ nn.Linear(self.input_dim, self.fc_dim),
+ nn.ReLU(),
+ )
+
+ def forward(self, inputs):
+ return self.layers(inputs) + self.shortcut_layer(inputs)
+
+
+class AudioEncoder(nn.Module):
+ def __init__(self, channels, padding=3, kernel_size=8, conv_stride=2, conv_pool=None, augmentation=False):
+ super(AudioEncoder, self).__init__()
+ self.in_channels = channels[0]
+ self.augmentation = augmentation
+
+ model = []
+ acti = nn.LeakyReLU(0.2)
+
+ nr_layer = len(channels) - 1
+
+ for i in range(nr_layer):
+ if conv_pool is None:
+ model.append(nn.ReflectionPad1d(padding))
+ model.append(nn.Conv1d(channels[i], channels[i + 1], kernel_size=kernel_size, stride=conv_stride))
+ model.append(acti)
+ else:
+ model.append(nn.ReflectionPad1d(padding))
+ model.append(nn.Conv1d(channels[i], channels[i + 1], kernel_size=kernel_size, stride=conv_stride))
+ model.append(acti)
+ model.append(conv_pool(kernel_size=2, stride=2))
+
+ if self.augmentation:
+ model.append(
+ nn.Conv1d(channels[-1], channels[-1], kernel_size=kernel_size, stride=conv_stride)
+ )
+ model.append(acti)
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x):
+
+ x = x[:, :self.in_channels, :]
+ x = self.model(x)
+ return x
+
+
+class AudioDecoder(nn.Module):
+ def __init__(self, channels, kernel_size=7, ups=25):
+ super(AudioDecoder, self).__init__()
+
+ model = []
+ pad = (kernel_size - 1) // 2
+ acti = nn.LeakyReLU(0.2)
+
+ for i in range(len(channels) - 2):
+ model.append(nn.Upsample(scale_factor=2, mode='nearest'))
+ model.append(nn.ReflectionPad1d(pad))
+ model.append(nn.Conv1d(channels[i], channels[i + 1],
+ kernel_size=kernel_size, stride=1))
+ if i == 0 or i == 1:
+ model.append(nn.Dropout(p=0.2))
+ if not i == len(channels) - 2:
+ model.append(acti)
+
+ model.append(nn.Upsample(size=ups, mode='nearest'))
+ model.append(nn.ReflectionPad1d(pad))
+ model.append(nn.Conv1d(channels[-2], channels[-1],
+ kernel_size=kernel_size, stride=1))
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x):
+ return self.model(x)
+
+
+class Audio2Pose(nn.Module):
+ def __init__(self, pose_dim, embed_size, augmentation, ups=25):
+ super(Audio2Pose, self).__init__()
+ self.pose_dim = pose_dim
+ self.embed_size = embed_size
+ self.augmentation = augmentation
+
+ self.aud_enc = AudioEncoder(channels=[13, 64, 128, 256], padding=2, kernel_size=7, conv_stride=1,
+ conv_pool=nn.AvgPool1d, augmentation=self.augmentation)
+ if self.augmentation:
+ self.aud_dec = AudioDecoder(channels=[512, 256, 128, pose_dim])
+ else:
+ self.aud_dec = AudioDecoder(channels=[256, 256, 128, pose_dim], ups=ups)
+
+ if self.augmentation:
+ self.pose_enc = nn.Sequential(
+ nn.Linear(self.embed_size // 2, 256),
+ nn.LayerNorm(256)
+ )
+
+ def forward(self, audio_feat, dec_input=None):
+
+ B = audio_feat.shape[0]
+
+ aud_embed = self.aud_enc.forward(audio_feat)
+
+ if self.augmentation:
+ dec_input = dec_input.squeeze(0)
+ dec_embed = self.pose_enc(dec_input)
+ dec_embed = dec_embed.unsqueeze(2)
+ dec_embed = dec_embed.expand(dec_embed.shape[0], dec_embed.shape[1], aud_embed.shape[-1])
+ aud_embed = torch.cat([aud_embed, dec_embed], dim=1)
+
+ out = self.aud_dec.forward(aud_embed)
+ return out
+
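+# Audio2Pose sketch (illustrative, without augmentation): MFCC frames are encoded, decoded
+# and resampled to `ups` output frames regardless of the input length.
+#   a2p = Audio2Pose(pose_dim=129, embed_size=512, augmentation=False, ups=25)
+#   poses = a2p(torch.randn(2, 13, 100))   # -> (2, 129, 25)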
+
+if __name__ == '__main__':
+ import numpy as np
+ import os
+ import sys
+
+ test_model = SeqEncoder2D(
+ C_in=2,
+ T_in=25,
+ C_out=512,
+ num_joints=54,
+ )
+ print(test_model.num_layers)
+
+ input = torch.randn((64, 108, 25))
+ output = test_model(input)
+ print(output.shape)
\ No newline at end of file
diff --git a/nets/smplx_body_pixel.py b/nets/smplx_body_pixel.py
new file mode 100644
index 0000000000000000000000000000000000000000..02bb6c672ecf18371f1ff0f16732c8c16db9f2a8
--- /dev/null
+++ b/nets/smplx_body_pixel.py
@@ -0,0 +1,326 @@
+import os
+import sys
+
+import torch
+from torch.optim.lr_scheduler import StepLR
+
+sys.path.append(os.getcwd())
+
+from nets.layers import *
+from nets.base import TrainWrapperBaseClass
+from nets.spg.gated_pixelcnn_v2 import GatedPixelCNN as pixelcnn
+from nets.spg.vqvae_1d import VQVAE as s2g_body, Wav2VecEncoder
+from nets.spg.vqvae_1d import AudioEncoder
+from nets.utils import parse_audio, denormalize
+from data_utils import get_mfcc, get_melspec, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
+import numpy as np
+import torch.optim as optim
+import torch.nn.functional as F
+from sklearn.preprocessing import normalize
+
+from data_utils.lower_body import c_index, c_index_3d, c_index_6d
+from data_utils.utils import smooth_geom, get_mfcc_sepa
+
+
+class TrainWrapper(TrainWrapperBaseClass):
+    '''
+    a wrapper that receives a batch from data_utils and computes the loss
+    '''
+
+ def __init__(self, args, config):
+ self.args = args
+ self.config = config
+ self.device = torch.device(self.args.gpu)
+ self.global_step = 0
+
+ self.convert_to_6d = self.config.Data.pose.convert_to_6d
+ self.expression = self.config.Data.pose.expression
+ self.epoch = 0
+ self.init_params()
+ self.num_classes = 4
+ self.audio = True
+ self.composition = self.config.Model.composition
+ self.bh_model = self.config.Model.bh_model
+
+ if self.audio:
+ self.audioencoder = AudioEncoder(in_dim=64, num_hiddens=256, num_residual_layers=2, num_residual_hiddens=256).to(self.device)
+ else:
+ self.audioencoder = None
+ if self.convert_to_6d:
+ dim, layer = 512, 10
+ else:
+ dim, layer = 256, 15
+ self.generator = pixelcnn(2048, dim, layer, self.num_classes, self.audio, self.bh_model).to(self.device)
+ self.g_body = s2g_body(self.each_dim[1], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
+ num_residual_layers=2, num_residual_hiddens=512).to(self.device)
+ self.g_hand = s2g_body(self.each_dim[2], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
+ num_residual_layers=2, num_residual_hiddens=512).to(self.device)
+
+ model_path = self.config.Model.vq_path
+ model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
+ self.g_body.load_state_dict(model_ckpt['generator']['g_body'])
+ self.g_hand.load_state_dict(model_ckpt['generator']['g_hand'])
+
+ if torch.cuda.device_count() > 1:
+ self.g_body = torch.nn.DataParallel(self.g_body, device_ids=[0, 1])
+ self.g_hand = torch.nn.DataParallel(self.g_hand, device_ids=[0, 1])
+ self.generator = torch.nn.DataParallel(self.generator, device_ids=[0, 1])
+ if self.audioencoder is not None:
+ self.audioencoder = torch.nn.DataParallel(self.audioencoder, device_ids=[0, 1])
+
+ self.discriminator = None
+ if self.convert_to_6d:
+ self.c_index = c_index_6d
+ else:
+ self.c_index = c_index_3d
+
+ super().__init__(args, config)
+
+ def init_optimizer(self):
+
+ print('using Adam')
+ self.generator_optimizer = optim.Adam(
+ self.generator.parameters(),
+ lr=self.config.Train.learning_rate.generator_learning_rate,
+ betas=[0.9, 0.999]
+ )
+ if self.audioencoder is not None:
+ opt = self.config.Model.AudioOpt
+ if opt == 'Adam':
+ self.audioencoder_optimizer = optim.Adam(
+ self.audioencoder.parameters(),
+ lr=self.config.Train.learning_rate.generator_learning_rate,
+ betas=[0.9, 0.999]
+ )
+ else:
+ print('using SGD')
+ self.audioencoder_optimizer = optim.SGD(
+ filter(lambda p: p.requires_grad,self.audioencoder.parameters()),
+ lr=self.config.Train.learning_rate.generator_learning_rate*10,
+ momentum=0.9,
+ nesterov=False,
+ )
+
+ def state_dict(self):
+ model_state = {
+ 'generator': self.generator.state_dict(),
+ 'generator_optim': self.generator_optimizer.state_dict(),
+ 'audioencoder': self.audioencoder.state_dict() if self.audio else None,
+ 'audioencoder_optim': self.audioencoder_optimizer.state_dict() if self.audio else None,
+ 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
+ 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
+ }
+ return model_state
+
+ def load_state_dict(self, state_dict):
+
+ from collections import OrderedDict
+ new_state_dict = OrderedDict() # create new OrderedDict that does not contain `module.`
+ for k, v in state_dict.items():
+ sub_dict = OrderedDict()
+ if v is not None:
+ for k1, v1 in v.items():
+ name = k1.replace('module.', '')
+ sub_dict[name] = v1
+ new_state_dict[k] = sub_dict
+ state_dict = new_state_dict
+ if 'generator' in state_dict:
+ self.generator.load_state_dict(state_dict['generator'])
+ else:
+ self.generator.load_state_dict(state_dict)
+
+ if 'generator_optim' in state_dict and self.generator_optimizer is not None:
+ self.generator_optimizer.load_state_dict(state_dict['generator_optim'])
+
+ if self.discriminator is not None:
+ self.discriminator.load_state_dict(state_dict['discriminator'])
+
+ if 'discriminator_optim' in state_dict and self.discriminator_optimizer is not None:
+ self.discriminator_optimizer.load_state_dict(state_dict['discriminator_optim'])
+
+ if 'audioencoder' in state_dict and self.audioencoder is not None:
+ self.audioencoder.load_state_dict(state_dict['audioencoder'])
+
+ def init_params(self):
+ if self.config.Data.pose.convert_to_6d:
+ scale = 2
+ else:
+ scale = 1
+
+ global_orient = round(0 * scale)
+ leye_pose = reye_pose = round(0 * scale)
+ jaw_pose = round(0 * scale)
+ body_pose = round((63 - 24) * scale)
+ left_hand_pose = right_hand_pose = round(45 * scale)
+ if self.expression:
+ expression = 100
+ else:
+ expression = 0
+
+ b_j = 0
+ jaw_dim = jaw_pose
+ b_e = b_j + jaw_dim
+ eye_dim = leye_pose + reye_pose
+ b_b = b_e + eye_dim
+ body_dim = global_orient + body_pose
+ b_h = b_b + body_dim
+ hand_dim = left_hand_pose + right_hand_pose
+ b_f = b_h + hand_dim
+ face_dim = expression
+
+ self.dim_list = [b_j, b_e, b_b, b_h, b_f]
+ self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
+ self.pose = int(self.full_dim / round(3 * scale))
+ self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
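+        # Worked example (assuming convert_to_6d=False and expression=True): scale=1,
+        # jaw/eye dims are 0, body_dim = 63 - 24 = 39, hand_dim = 90, face_dim = 100,
+        # so full_dim = 129 and each_dim = [0, 39, 90, 100].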
+
+ def __call__(self, bat):
+ # assert (not self.args.infer), "infer mode"
+ self.global_step += 1
+
+ total_loss = None
+ loss_dict = {}
+
+ aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
+
+ id = bat['speaker'].to(self.device) - 20
+ # id = F.one_hot(id, self.num_classes)
+
+ poses = poses[:, self.c_index, :]
+
+ aud = aud.permute(0, 2, 1)
+ gt_poses = poses.permute(0, 2, 1)
+
+ with torch.no_grad():
+ self.g_body.eval()
+ self.g_hand.eval()
+ if torch.cuda.device_count() > 1:
+ _, body_latents = self.g_body.module.encode(gt_poses=gt_poses[..., :self.each_dim[1]], id=id)
+ _, hand_latents = self.g_hand.module.encode(gt_poses=gt_poses[..., self.each_dim[1]:], id=id)
+ else:
+ _, body_latents = self.g_body.encode(gt_poses=gt_poses[..., :self.each_dim[1]], id=id)
+ _, hand_latents = self.g_hand.encode(gt_poses=gt_poses[..., self.each_dim[1]:], id=id)
+ latents = torch.cat([body_latents.unsqueeze(dim=-1), hand_latents.unsqueeze(dim=-1)], dim=-1)
+ latents = latents.detach()
+
+ if self.audio:
+ audio = self.audioencoder(aud[:, :].transpose(1, 2), frame_num=latents.shape[1]*4).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
+ logits = self.generator(latents[:, :], id, audio)
+ else:
+ logits = self.generator(latents, id)
+ logits = logits.permute(0, 2, 3, 1).contiguous()
+
+ self.generator_optimizer.zero_grad()
+ if self.audio:
+ self.audioencoder_optimizer.zero_grad()
+
+ loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), latents.view(-1))
+ loss.backward()
+
+        grad = torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.config.Train.max_gradient_norm)
+
+        if torch.isnan(grad).sum() > 0:
+            print('Warning: NaN gradient norm encountered')
+
+ loss_dict['grad'] = grad.item()
+ loss_dict['ce_loss'] = loss.item()
+ self.generator_optimizer.step()
+ if self.audio:
+ self.audioencoder_optimizer.step()
+
+ return total_loss, loss_dict
+
+ def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, exp=None, var=None, w_pre=False, rand=None,
+ continuity=False, id=None, fps=15, sr=22000, B=1, am=None, am_sr=None, frame=0,**kwargs):
+ '''
+ initial_pose: (B, C, T), normalized
+ (aud_fn, txgfile) -> generated motion (B, T, C)
+ '''
+ output = []
+
+ assert self.args.infer, "train mode"
+ self.generator.eval()
+ self.g_body.eval()
+ self.g_hand.eval()
+
+ if continuity:
+ aud_feat, gap = get_mfcc_sepa(aud_fn, sr=sr, fps=fps)
+ else:
+ aud_feat = get_mfcc_ta(aud_fn, sr=sr, fps=fps, smlpx=True, type='mfcc', am=am)
+ aud_feat = aud_feat.transpose(1, 0)
+ aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
+ aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)
+
+ if id is None:
+ id = torch.tensor([0]).to(self.device)
+ else:
+ id = id.repeat(B)
+
+ with torch.no_grad():
+ aud_feat = aud_feat.permute(0, 2, 1)
+ if continuity:
+ self.audioencoder.eval()
+ pre_pose = {}
+ pre_pose['b'] = pre_pose['h'] = None
+ pre_latents, pre_audio, body_0, hand_0 = self.infer(aud_feat[:, :gap], frame, id, B, pre_pose=pre_pose)
+ pre_pose['b'] = body_0[:, :, -4:].transpose(1,2)
+ pre_pose['h'] = hand_0[:, :, -4:].transpose(1,2)
+ _, _, body_1, hand_1 = self.infer(aud_feat[:, gap:], frame, id, B, pre_latents, pre_audio, pre_pose)
+ body = torch.cat([body_0, body_1], dim=2)
+ hand = torch.cat([hand_0, hand_1], dim=2)
+
+ else:
+ if self.audio:
+ self.audioencoder.eval()
+ audio = self.audioencoder(aud_feat.transpose(1, 2), frame_num=frame).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
+ latents = self.generator.generate(id, shape=[audio.shape[2], 2], batch_size=B, aud_feat=audio)
+ else:
+ latents = self.generator.generate(id, shape=[aud_feat.shape[1]//4, 2], batch_size=B)
+
+ body_latents = latents[..., 0]
+ hand_latents = latents[..., 1]
+
+ body, _ = self.g_body.decode(b=body_latents.shape[0], w=body_latents.shape[1], latents=body_latents)
+ hand, _ = self.g_hand.decode(b=hand_latents.shape[0], w=hand_latents.shape[1], latents=hand_latents)
+
+ pred_poses = torch.cat([body, hand], dim=1).transpose(1,2).cpu().numpy()
+
+ output = pred_poses
+
+ return output
+
+ def infer(self, aud_feat, frame, id, B, pre_latents=None, pre_audio=None, pre_pose=None):
+ audio = self.audioencoder(aud_feat.transpose(1, 2), frame_num=frame).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
+ latents = self.generator.generate(id, shape=[audio.shape[2], 2], batch_size=B, aud_feat=audio,
+ pre_latents=pre_latents, pre_audio=pre_audio)
+
+ body_latents = latents[..., 0]
+ hand_latents = latents[..., 1]
+
+ body, _ = self.g_body.decode(b=body_latents.shape[0], w=body_latents.shape[1],
+ latents=body_latents, pre_state=pre_pose['b'])
+ hand, _ = self.g_hand.decode(b=hand_latents.shape[0], w=hand_latents.shape[1],
+ latents=hand_latents, pre_state=pre_pose['h'])
+
+ return latents, audio, body, hand
+
+ def generate(self, aud, id, frame_num=0):
+
+ self.generator.eval()
+ self.g_body.eval()
+ self.g_hand.eval()
+ aud_feat = aud.permute(0, 2, 1)
+ if self.audio:
+ self.audioencoder.eval()
+ audio = self.audioencoder(aud_feat.transpose(1, 2), frame_num=frame_num).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
+ latents = self.generator.generate(id, shape=[audio.shape[2], 2], batch_size=aud.shape[0], aud_feat=audio)
+ else:
+ latents = self.generator.generate(id, shape=[aud_feat.shape[1] // 4, 2], batch_size=aud.shape[0])
+
+ body_latents = latents[..., 0]
+ hand_latents = latents[..., 1]
+
+        body, _ = self.g_body.decode(b=body_latents.shape[0], w=body_latents.shape[1], latents=body_latents)
+        hand, _ = self.g_hand.decode(b=hand_latents.shape[0], w=hand_latents.shape[1], latents=hand_latents)
+
+ pred_poses = torch.cat([body, hand], dim=1).transpose(1, 2)
+ return pred_poses
diff --git a/nets/smplx_body_vq.py b/nets/smplx_body_vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..95770ac251b873de26b4530f0a37fe43ed5e14f5
--- /dev/null
+++ b/nets/smplx_body_vq.py
@@ -0,0 +1,302 @@
+import os
+import sys
+
+from torch.optim.lr_scheduler import StepLR
+
+sys.path.append(os.getcwd())
+
+from nets.layers import *
+from nets.base import TrainWrapperBaseClass
+from nets.spg.s2glayers import Generator as G_S2G, Discriminator as D_S2G
+from nets.spg.vqvae_1d import VQVAE as s2g_body
+from nets.utils import parse_audio, denormalize
+from data_utils import get_mfcc, get_melspec, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
+import numpy as np
+import torch.optim as optim
+import torch.nn.functional as F
+from sklearn.preprocessing import normalize
+
+from data_utils.lower_body import c_index, c_index_3d, c_index_6d
+
+
+class TrainWrapper(TrainWrapperBaseClass):
+    '''
+    a wrapper that receives a batch from data_utils and computes the loss
+    '''
+
+ def __init__(self, args, config):
+ self.args = args
+ self.config = config
+ self.device = torch.device(self.args.gpu)
+ self.global_step = 0
+
+ self.convert_to_6d = self.config.Data.pose.convert_to_6d
+ self.expression = self.config.Data.pose.expression
+ self.epoch = 0
+ self.init_params()
+ self.num_classes = 4
+ self.composition = self.config.Model.composition
+ if self.composition:
+ self.g_body = s2g_body(self.each_dim[1], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
+ num_residual_layers=2, num_residual_hiddens=512).to(self.device)
+ self.g_hand = s2g_body(self.each_dim[2], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
+ num_residual_layers=2, num_residual_hiddens=512).to(self.device)
+ else:
+ self.g = s2g_body(self.each_dim[1] + self.each_dim[2], embedding_dim=64, num_embeddings=config.Model.code_num,
+ num_hiddens=1024, num_residual_layers=2, num_residual_hiddens=512).to(self.device)
+
+ self.discriminator = None
+
+ if self.convert_to_6d:
+ self.c_index = c_index_6d
+ else:
+ self.c_index = c_index_3d
+
+ super().__init__(args, config)
+
+ def init_optimizer(self):
+ print('using Adam')
+ if self.composition:
+ self.g_body_optimizer = optim.Adam(
+ self.g_body.parameters(),
+ lr=self.config.Train.learning_rate.generator_learning_rate,
+ betas=[0.9, 0.999]
+ )
+ self.g_hand_optimizer = optim.Adam(
+ self.g_hand.parameters(),
+ lr=self.config.Train.learning_rate.generator_learning_rate,
+ betas=[0.9, 0.999]
+ )
+ else:
+ self.g_optimizer = optim.Adam(
+ self.g.parameters(),
+ lr=self.config.Train.learning_rate.generator_learning_rate,
+ betas=[0.9, 0.999]
+ )
+
+ def state_dict(self):
+ if self.composition:
+ model_state = {
+ 'g_body': self.g_body.state_dict(),
+ 'g_body_optim': self.g_body_optimizer.state_dict(),
+ 'g_hand': self.g_hand.state_dict(),
+ 'g_hand_optim': self.g_hand_optimizer.state_dict(),
+ 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
+ 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
+ }
+ else:
+ model_state = {
+ 'g': self.g.state_dict(),
+ 'g_optim': self.g_optimizer.state_dict(),
+ 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
+ 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
+ }
+ return model_state
+
+ def init_params(self):
+ if self.config.Data.pose.convert_to_6d:
+ scale = 2
+ else:
+ scale = 1
+
+ global_orient = round(0 * scale)
+ leye_pose = reye_pose = round(0 * scale)
+ jaw_pose = round(0 * scale)
+ body_pose = round((63 - 24) * scale)
+ left_hand_pose = right_hand_pose = round(45 * scale)
+ if self.expression:
+ expression = 100
+ else:
+ expression = 0
+
+ b_j = 0
+ jaw_dim = jaw_pose
+ b_e = b_j + jaw_dim
+ eye_dim = leye_pose + reye_pose
+ b_b = b_e + eye_dim
+ body_dim = global_orient + body_pose
+ b_h = b_b + body_dim
+ hand_dim = left_hand_pose + right_hand_pose
+ b_f = b_h + hand_dim
+ face_dim = expression
+
+ self.dim_list = [b_j, b_e, b_b, b_h, b_f]
+ self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
+ self.pose = int(self.full_dim / round(3 * scale))
+ self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
+
+ def __call__(self, bat):
+ # assert (not self.args.infer), "infer mode"
+ self.global_step += 1
+
+ total_loss = None
+ loss_dict = {}
+
+ aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
+
+ # id = bat['speaker'].to(self.device) - 20
+ # id = F.one_hot(id, self.num_classes)
+
+ poses = poses[:, self.c_index, :]
+ gt_poses = poses.permute(0, 2, 1)
+ b_poses = gt_poses[..., :self.each_dim[1]]
+ h_poses = gt_poses[..., self.each_dim[1]:]
+
+ if self.composition:
+ loss = 0
+ loss_dict, loss = self.vq_train(b_poses[:, :], 'b', self.g_body, loss_dict, loss)
+ loss_dict, loss = self.vq_train(h_poses[:, :], 'h', self.g_hand, loss_dict, loss)
+ else:
+ loss = 0
+ loss_dict, loss = self.vq_train(gt_poses[:, :], 'g', self.g, loss_dict, loss)
+
+ return total_loss, loss_dict
+
+ def vq_train(self, gt, name, model, dict, total_loss, pre=None):
+ e_q_loss, x_recon = model(gt_poses=gt, pre_state=pre)
+ loss, loss_dict = self.get_loss(pred_poses=x_recon, gt_poses=gt, e_q_loss=e_q_loss, pre=pre)
+ # total_loss = total_loss + loss
+
+ if name == 'b':
+ optimizer_name = 'g_body_optimizer'
+ elif name == 'h':
+ optimizer_name = 'g_hand_optimizer'
+ elif name == 'g':
+ optimizer_name = 'g_optimizer'
+ else:
+            raise ValueError("model's name must be 'b', 'h' or 'g'")
+ optimizer = getattr(self, optimizer_name)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ for key in list(loss_dict.keys()):
+ dict[name + key] = loss_dict.get(key, 0).item()
+ return dict, total_loss
+
+ def get_loss(self,
+ pred_poses,
+ gt_poses,
+ e_q_loss,
+ pre=None
+ ):
+ loss_dict = {}
+
+
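+        # Sketch of the total below: L1 pose reconstruction + VQ codebook/commitment loss
+        # + L1 velocity matching (+ a first-frame velocity term when a previous clip is given).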
+ rec_loss = torch.mean(torch.abs(pred_poses - gt_poses))
+ v_pr = pred_poses[:, 1:] - pred_poses[:, :-1]
+ v_gt = gt_poses[:, 1:] - gt_poses[:, :-1]
+ velocity_loss = torch.mean(torch.abs(v_pr - v_gt))
+
+ if pre is None:
+ f0_vel = 0
+ else:
+ v0_pr = pred_poses[:, 0] - pre[:, -1]
+ v0_gt = gt_poses[:, 0] - pre[:, -1]
+ f0_vel = torch.mean(torch.abs(v0_pr - v0_gt))
+
+ gen_loss = rec_loss + e_q_loss + velocity_loss + f0_vel
+
+ loss_dict['rec_loss'] = rec_loss
+ loss_dict['velocity_loss'] = velocity_loss
+ # loss_dict['e_q_loss'] = e_q_loss
+ if pre is not None:
+ loss_dict['f0_vel'] = f0_vel
+
+ return gen_loss, loss_dict
+
+ def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, exp=None, var=None, w_pre=False, continuity=False,
+ id=None, fps=15, sr=22000, smooth=False, **kwargs):
+ '''
+ initial_pose: (B, C, T), normalized
+ (aud_fn, txgfile) -> generated motion (B, T, C)
+ '''
+ output = []
+
+ assert self.args.infer, "train mode"
+ if self.composition:
+ self.g_body.eval()
+ self.g_hand.eval()
+ else:
+ self.g.eval()
+
+ if self.config.Data.pose.normalization:
+ assert norm_stats is not None
+ data_mean = norm_stats[0]
+ data_std = norm_stats[1]
+
+ # assert initial_pose.shape[-1] == pre_length
+ if initial_pose is not None:
+ gt = initial_pose[:, :, :].to(self.device).to(torch.float32)
+ pre_poses = initial_pose[:, :, :15].permute(0, 2, 1).to(self.device).to(torch.float32)
+ poses = initial_pose.permute(0, 2, 1).to(self.device).to(torch.float32)
+ B = pre_poses.shape[0]
+ else:
+ gt = None
+ pre_poses = None
+ B = 1
+
+ if type(aud_fn) == torch.Tensor:
+ aud_feat = torch.tensor(aud_fn, dtype=torch.float32).to(self.device)
+ num_poses_to_generate = aud_feat.shape[-1]
+ else:
+ aud_feat = get_mfcc_ta(aud_fn, sr=sr, fps=fps, smlpx=True, type='mfcc').transpose(1, 0)
+ aud_feat = aud_feat[:, :]
+ num_poses_to_generate = aud_feat.shape[-1]
+ aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
+ aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)
+
+ # pre_poses = torch.randn(pre_poses.shape).to(self.device).to(torch.float32)
+ if id is None:
+ id = F.one_hot(torch.tensor([[0]]), self.num_classes).to(self.device)
+
+ with torch.no_grad():
+ aud_feat = aud_feat.permute(0, 2, 1)
+ gt_poses = gt[:, self.c_index].permute(0, 2, 1)
+ if self.composition:
+ if continuity:
+ pred_poses_body = []
+ pred_poses_hand = []
+ pre_b = None
+ pre_h = None
+ for i in range(5):
+ _, pred_body = self.g_body(gt_poses=gt_poses[:, i*60:(i+1)*60, :self.each_dim[1]], pre_state=pre_b)
+ pre_b = pred_body[..., -1:].transpose(1,2)
+ pred_poses_body.append(pred_body)
+ _, pred_hand = self.g_hand(gt_poses=gt_poses[:, i*60:(i+1)*60, self.each_dim[1]:], pre_state=pre_h)
+ pre_h = pred_hand[..., -1:].transpose(1,2)
+ pred_poses_hand.append(pred_hand)
+
+ pred_poses_body = torch.cat(pred_poses_body, dim=2)
+ pred_poses_hand = torch.cat(pred_poses_hand, dim=2)
+ else:
+ _, pred_poses_body = self.g_body(gt_poses=gt_poses[..., :self.each_dim[1]], id=id)
+ _, pred_poses_hand = self.g_hand(gt_poses=gt_poses[..., self.each_dim[1]:], id=id)
+ pred_poses = torch.cat([pred_poses_body, pred_poses_hand], dim=1)
+ else:
+ _, pred_poses = self.g(gt_poses=gt_poses, id=id)
+ pred_poses = pred_poses.transpose(1, 2).cpu().numpy()
+ output = pred_poses
+
+ if self.config.Data.pose.normalization:
+ output = denormalize(output, data_mean, data_std)
+
+ if smooth:
+ lamda = 0.8
+ smooth_f = 10
+ frame = 149
+ for i in range(smooth_f):
+ f = frame + i
+ l = lamda * (i + 1) / smooth_f
+ output[0, f] = (1 - l) * output[0, f - 1] + l * output[0, f]
+
+ output = np.concatenate(output, axis=1)
+
+ return output
+
+ def load_state_dict(self, state_dict):
+ if self.composition:
+ self.g_body.load_state_dict(state_dict['g_body'])
+ self.g_hand.load_state_dict(state_dict['g_hand'])
+ else:
+ self.g.load_state_dict(state_dict['g'])
diff --git a/nets/smplx_face.py b/nets/smplx_face.py
new file mode 100644
index 0000000000000000000000000000000000000000..e591b9dab674770b60655f607892b068f412d75a
--- /dev/null
+++ b/nets/smplx_face.py
@@ -0,0 +1,238 @@
+import os
+import sys
+
+sys.path.append(os.getcwd())
+
+from nets.layers import *
+from nets.base import TrainWrapperBaseClass
+# from nets.spg.faceformer import Faceformer
+from nets.spg.s2g_face import Generator as s2g_face
+from losses import KeypointLoss
+from nets.utils import denormalize
+from data_utils import get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
+import numpy as np
+import torch.optim as optim
+import torch.nn.functional as F
+from sklearn.preprocessing import normalize
+import smplx
+
+
+class TrainWrapper(TrainWrapperBaseClass):
+    '''
+    a wrapper that receives a batch from data_utils and computes the loss
+    '''
+
+ def __init__(self, args, config):
+ self.args = args
+ self.config = config
+ self.device = torch.device(self.args.gpu)
+ self.global_step = 0
+
+ self.convert_to_6d = self.config.Data.pose.convert_to_6d
+ self.expression = self.config.Data.pose.expression
+ self.epoch = 0
+ self.init_params()
+ self.num_classes = 4
+
+ self.generator = s2g_face(
+ n_poses=self.config.Data.pose.generate_length,
+ each_dim=self.each_dim,
+ dim_list=self.dim_list,
+ training=not self.args.infer,
+ device=self.device,
+ identity=False if self.convert_to_6d else True,
+ num_classes=self.num_classes,
+ ).to(self.device)
+
+ # self.generator = Faceformer().to(self.device)
+
+ self.discriminator = None
+ self.am = None
+
+ self.MSELoss = KeypointLoss().to(self.device)
+ super().__init__(args, config)
+
+ def init_optimizer(self):
+ self.generator_optimizer = optim.SGD(
+ filter(lambda p: p.requires_grad,self.generator.parameters()),
+ lr=0.001,
+ momentum=0.9,
+ nesterov=False,
+ )
+
+ def init_params(self):
+ if self.convert_to_6d:
+ scale = 2
+ else:
+ scale = 1
+
+ global_orient = round(3 * scale)
+ leye_pose = reye_pose = round(3 * scale)
+ jaw_pose = round(3 * scale)
+ body_pose = round(63 * scale)
+ left_hand_pose = right_hand_pose = round(45 * scale)
+ if self.expression:
+ expression = 100
+ else:
+ expression = 0
+
+ b_j = 0
+ jaw_dim = jaw_pose
+ b_e = b_j + jaw_dim
+ eye_dim = leye_pose + reye_pose
+ b_b = b_e + eye_dim
+ body_dim = global_orient + body_pose
+ b_h = b_b + body_dim
+ hand_dim = left_hand_pose + right_hand_pose
+ b_f = b_h + hand_dim
+ face_dim = expression
+
+ self.dim_list = [b_j, b_e, b_b, b_h, b_f]
+ self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim + face_dim
+ self.pose = int(self.full_dim / round(3 * scale))
+ self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
+
+ def __call__(self, bat):
+ # assert (not self.args.infer), "infer mode"
+ self.global_step += 1
+
+ total_loss = None
+ loss_dict = {}
+
+ aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
+ id = bat['speaker'].to(self.device) - 20
+ id = F.one_hot(id, self.num_classes)
+
+ aud = aud.permute(0, 2, 1)
+ gt_poses = poses.permute(0, 2, 1)
+
+ if self.expression:
+ expression = bat['expression'].to(self.device).to(torch.float32)
+ gt_poses = torch.cat([gt_poses, expression.permute(0, 2, 1)], dim=2)
+
+ pred_poses, _ = self.generator(
+ aud,
+ gt_poses,
+ id,
+ )
+
+ G_loss, G_loss_dict = self.get_loss(
+ pred_poses=pred_poses,
+ gt_poses=gt_poses,
+ pre_poses=None,
+ mode='training_G',
+ gt_conf=None,
+ aud=aud,
+ )
+
+ self.generator_optimizer.zero_grad()
+ G_loss.backward()
+        grad = torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.config.Train.max_gradient_norm)
+ loss_dict['grad'] = grad.item()
+ self.generator_optimizer.step()
+
+ for key in list(G_loss_dict.keys()):
+ loss_dict[key] = G_loss_dict.get(key, 0).item()
+
+ return total_loss, loss_dict
+
+ def get_loss(self,
+ pred_poses,
+ gt_poses,
+ pre_poses,
+ aud,
+ mode='training_G',
+ gt_conf=None,
+ exp=1,
+ gt_nzero=None,
+ pre_nzero=None,
+ ):
+ loss_dict = {}
+
+
+ [b_j, b_e, b_b, b_h, b_f] = self.dim_list
+
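+        # Note: despite the variable name, this is an L1 (mean absolute) error over the
+        # first 6 pose channels, not a squared error.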
+ MSELoss = torch.mean(torch.abs(pred_poses[:, :, :6] - gt_poses[:, :, :6]))
+ if self.expression:
+ expl = torch.mean((pred_poses[:, :, -100:] - gt_poses[:, :, -100:])**2)
+ else:
+ expl = 0
+
+ gen_loss = expl + MSELoss
+
+ loss_dict['MSELoss'] = MSELoss
+ if self.expression:
+ loss_dict['exp_loss'] = expl
+
+ return gen_loss, loss_dict
+
+ def infer_on_audio(self, aud_fn, id=None, initial_pose=None, norm_stats=None, w_pre=False, frame=None, am=None, am_sr=16000, **kwargs):
+ '''
+ initial_pose: (B, C, T), normalized
+ (aud_fn, txgfile) -> generated motion (B, T, C)
+ '''
+ output = []
+
+ # assert self.args.infer, "train mode"
+ self.generator.eval()
+
+ if self.config.Data.pose.normalization:
+ assert norm_stats is not None
+ data_mean = norm_stats[0]
+ data_std = norm_stats[1]
+
+ # assert initial_pose.shape[-1] == pre_length
+ if initial_pose is not None:
+ gt = initial_pose[:,:,:].permute(0, 2, 1).to(self.generator.device).to(torch.float32)
+ pre_poses = initial_pose[:,:,:15].permute(0, 2, 1).to(self.generator.device).to(torch.float32)
+ poses = initial_pose.permute(0, 2, 1).to(self.generator.device).to(torch.float32)
+ B = pre_poses.shape[0]
+ else:
+ gt = None
+ pre_poses=None
+ B = 1
+
+ if type(aud_fn) == torch.Tensor:
+ aud_feat = torch.tensor(aud_fn, dtype=torch.float32).to(self.generator.device)
+ num_poses_to_generate = aud_feat.shape[-1]
+ else:
+ aud_feat = get_mfcc_ta(aud_fn, am=am, am_sr=am_sr, fps=30, encoder_choice='faceformer')
+ aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
+ aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.generator.device).transpose(1, 2)
+ if frame is None:
+ frame = aud_feat.shape[2]*30//16000
+ if id is None:
+ id = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32, device=self.generator.device)
+ else:
+ id = F.one_hot(id, self.num_classes).to(self.generator.device)
+
+ with torch.no_grad():
+ pred_poses = self.generator(aud_feat, pre_poses, id, time_steps=frame)[0]
+ pred_poses = pred_poses.cpu().numpy()
+ output = pred_poses
+
+ if self.config.Data.pose.normalization:
+ output = denormalize(output, data_mean, data_std)
+
+ return output
+
+
+ def generate(self, wv2_feat, frame):
+ '''
+ initial_pose: (B, C, T), normalized
+ (aud_fn, txgfile) -> generated motion (B, T, C)
+ '''
+ output = []
+
+ # assert self.args.infer, "train mode"
+ self.generator.eval()
+
+ B = 1
+
+ id = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32, device=self.generator.device)
+ id = id.repeat(wv2_feat.shape[0], 1)
+
+ with torch.no_grad():
+ pred_poses = self.generator(wv2_feat, None, id, time_steps=frame)[0]
+ return pred_poses
diff --git a/nets/spg/gated_pixelcnn_v2.py b/nets/spg/gated_pixelcnn_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eb0ec8c0c1dd4bc1e2b6e9f55fd1104b547b05f
--- /dev/null
+++ b/nets/spg/gated_pixelcnn_v2.py
@@ -0,0 +1,179 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ try:
+ nn.init.xavier_uniform_(m.weight.data)
+ m.bias.data.fill_(0)
+ except AttributeError:
+ print("Skipping initialization of ", classname)
+
+
+class GatedActivation(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ x, y = x.chunk(2, dim=1)
+        return torch.tanh(x) * torch.sigmoid(y)
+
+
+class GatedMaskedConv2d(nn.Module):
+ def __init__(self, mask_type, dim, kernel, residual=True, n_classes=10, bh_model=False):
+ super().__init__()
+        assert kernel % 2 == 1, "Kernel size must be odd"
+ self.mask_type = mask_type
+ self.residual = residual
+ self.bh_model = bh_model
+
+ self.class_cond_embedding = nn.Embedding(n_classes, 2 * dim)
+ self.class_cond_embedding = self.class_cond_embedding.to("cpu")
+
+        kernel_shp = (kernel // 2 + 1, 3 if self.bh_model else 1)  # (ceil(k/2), 3) in body-hand mode, else (ceil(k/2), 1)
+ padding_shp = (kernel // 2, 1 if self.bh_model else 0)
+ self.vert_stack = nn.Conv2d(
+ dim, dim * 2,
+ kernel_shp, 1, padding_shp
+ )
+
+ self.vert_to_horiz = nn.Conv2d(2 * dim, 2 * dim, 1)
+
+ kernel_shp = (1, 2)
+ padding_shp = (0, 1)
+ self.horiz_stack = nn.Conv2d(
+ dim, dim * 2,
+ kernel_shp, 1, padding_shp
+ )
+
+ self.horiz_resid = nn.Conv2d(dim, dim, 1)
+
+ self.gate = GatedActivation()
+
+ def make_causal(self):
+ self.vert_stack.weight.data[:, :, -1].zero_() # Mask final row
+ self.horiz_stack.weight.data[:, :, :, -1].zero_() # Mask final column
+
+ def forward(self, x_v, x_h, h):
+ if self.mask_type == 'A':
+ self.make_causal()
+
+ h = h.to(self.class_cond_embedding.weight.device)
+ h = self.class_cond_embedding(h)
+
+ h_vert = self.vert_stack(x_v)
+ h_vert = h_vert[:, :, :x_v.size(-2), :]
+ out_v = self.gate(h_vert + h[:, :, None, None])
+
+ if self.bh_model:
+ h_horiz = self.horiz_stack(x_h)
+ h_horiz = h_horiz[:, :, :, :x_h.size(-1)]
+ v2h = self.vert_to_horiz(h_vert)
+
+ out = self.gate(v2h + h_horiz + h[:, :, None, None])
+ if self.residual:
+ out_h = self.horiz_resid(out) + x_h
+ else:
+ out_h = self.horiz_resid(out)
+ else:
+ if self.residual:
+ out_v = self.horiz_resid(out_v) + x_v
+ else:
+ out_v = self.horiz_resid(out_v)
+ out_h = out_v
+
+ return out_v, out_h
+
+
+class GatedPixelCNN(nn.Module):
+ def __init__(self, input_dim=256, dim=64, n_layers=15, n_classes=10, audio=False, bh_model=False):
+ super().__init__()
+ self.dim = dim
+ self.audio = audio
+ self.bh_model = bh_model
+
+ if self.audio:
+ self.embedding_aud = nn.Conv2d(256, dim, 1, 1, padding=0)
+ self.fusion_v = nn.Conv2d(dim * 2, dim, 1, 1, padding=0)
+ self.fusion_h = nn.Conv2d(dim * 2, dim, 1, 1, padding=0)
+
+ # Create embedding layer to embed input
+ self.embedding = nn.Embedding(input_dim, dim)
+
+ # Building the PixelCNN layer by layer
+ self.layers = nn.ModuleList()
+
+ # Initial block with Mask-A convolution
+ # Rest with Mask-B convolutions
+ for i in range(n_layers):
+ mask_type = 'A' if i == 0 else 'B'
+ kernel = 7 if i == 0 else 3
+ residual = False if i == 0 else True
+
+ self.layers.append(
+ GatedMaskedConv2d(mask_type, dim, kernel, residual, n_classes, bh_model)
+ )
+
+ # Add the output layer
+ self.output_conv = nn.Sequential(
+ nn.Conv2d(dim, 512, 1),
+ nn.ReLU(True),
+ nn.Conv2d(512, input_dim, 1)
+ )
+
+ self.apply(weights_init)
+
+ self.dp = nn.Dropout(0.1)
+ self.to("cpu")
+
+ def forward(self, x, label, aud=None):
+ shp = x.size() + (-1,)
+ x = self.embedding(x.view(-1)).view(shp) # (B, H, W, C)
+        x = x.permute(0, 3, 1, 2)  # (B, C, H, W)
+
+ x_v, x_h = (x, x)
+ for i, layer in enumerate(self.layers):
+ if i == 1 and self.audio is True:
+ aud = self.embedding_aud(aud)
+ a = torch.ones(aud.shape[-2]).to(aud.device)
+ a = self.dp(a)
+ aud = (aud.transpose(-1, -2) * a).transpose(-1, -2)
+ x_v = self.fusion_v(torch.cat([x_v, aud], dim=1))
+ if self.bh_model:
+ x_h = self.fusion_h(torch.cat([x_h, aud], dim=1))
+ x_v, x_h = layer(x_v, x_h, label)
+
+ if self.bh_model:
+ return self.output_conv(x_h)
+ else:
+ return self.output_conv(x_v)
+
+ def generate(self, label, shape=(8, 8), batch_size=64, aud_feat=None, pre_latents=None, pre_audio=None):
+ param = next(self.parameters())
+ x = torch.zeros(
+ (batch_size, *shape),
+ dtype=torch.int64, device=param.device
+ )
+ if pre_latents is not None:
+ x = torch.cat([pre_latents, x], dim=1)
+ aud_feat = torch.cat([pre_audio, aud_feat], dim=2)
+ h0 = pre_latents.shape[1]
+ h = h0 + shape[0]
+ else:
+ h0 = 0
+ h = shape[0]
+
+ for i in range(h0, h):
+ for j in range(shape[1]):
+ if self.audio:
+ logits = self.forward(x, label, aud_feat)
+ else:
+ logits = self.forward(x, label)
+ probs = F.softmax(logits[:, :, i, j], -1)
+ x.data[:, i, j].copy_(
+ probs.multinomial(1).squeeze().data
+ )
+ return x[:, h0:h]
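+
+
+# Sampling sketch (illustrative; the training wrapper additionally conditions on audio features):
+#   pcnn = GatedPixelCNN(input_dim=2048, dim=256, n_layers=15, n_classes=4, audio=False, bh_model=True)
+#   codes = pcnn.generate(label=torch.tensor([0]), shape=(8, 2), batch_size=1)   # -> (1, 8, 2) latent code indices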
diff --git a/nets/spg/s2g_face.py b/nets/spg/s2g_face.py
new file mode 100644
index 0000000000000000000000000000000000000000..b221df6c7bff0912640dfffe46267f8a131cc829
--- /dev/null
+++ b/nets/spg/s2g_face.py
@@ -0,0 +1,226 @@
+'''
+not exactly the same as the official repo but the results are good
+'''
+import sys
+import os
+
+from transformers import Wav2Vec2Processor
+
+from .wav2vec import Wav2Vec2Model
+from torchaudio.sox_effects import apply_effects_tensor
+
+sys.path.append(os.getcwd())
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchaudio as ta
+import math
+from nets.layers import SeqEncoder1D, SeqTranslator1D, ConvNormRelu
+
+
+""" from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """
+
+
+def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000):
+ """
+ :param audio: 1 x T tensor containing a 16kHz audio signal
+ :param frame_rate: frame rate for video (we need one audio chunk per video frame)
+ :param chunk_size: number of audio samples per chunk
+ :return: num_chunks x chunk_size tensor containing sliced audio
+ """
+ samples_per_frame = 16000 // frame_rate
+ padding = (chunk_size - samples_per_frame) // 2
+ audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
+ anchor_points = list(range(chunk_size//2, audio.shape[-1]-chunk_size//2, samples_per_frame))
+ audio = torch.cat([audio[:, i-chunk_size//2:i+chunk_size//2] for i in anchor_points], dim=0)
+ return audio
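+
+
+# Hedged sketch (illustrative only): for 16 kHz audio and 30 fps video, audio_chunking
+# yields roughly one 16000-sample chunk per video frame.
+def _example_audio_chunking_shapes():
+    wav = torch.randn(1, 16000 * 2)  # 2 seconds of dummy 16 kHz audio
+    chunks = audio_chunking(wav, frame_rate=30, chunk_size=16000)
+    # about 2 s * 30 fps = 60 chunks, each 16000 samples long
+    return chunks.shape  # torch.Size([60, 16000])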
+
+
+class MeshtalkEncoder(nn.Module):
+ def __init__(self, latent_dim: int = 128, model_name: str = 'audio_encoder'):
+ """
+ :param latent_dim: size of the latent audio embedding
+ :param model_name: name of the model, used to load and save the model
+ """
+ super().__init__()
+
+ self.melspec = ta.transforms.MelSpectrogram(
+ sample_rate=16000, n_fft=2048, win_length=800, hop_length=160, n_mels=80
+ )
+
+ conv_len = 5
+ self.convert_dimensions = torch.nn.Conv1d(80, 128, kernel_size=conv_len)
+ self.weights_init(self.convert_dimensions)
+ self.receptive_field = conv_len
+
+ convs = []
+ for i in range(6):
+ dilation = 2 * (i % 3 + 1)
+ self.receptive_field += (conv_len - 1) * dilation
+ convs += [torch.nn.Conv1d(128, 128, kernel_size=conv_len, dilation=dilation)]
+ self.weights_init(convs[-1])
+ self.convs = torch.nn.ModuleList(convs)
+ self.code = torch.nn.Linear(128, latent_dim)
+
+ self.apply(lambda x: self.weights_init(x))
+
+ def weights_init(self, m):
+ if isinstance(m, torch.nn.Conv1d):
+ torch.nn.init.xavier_uniform_(m.weight)
+ try:
+ torch.nn.init.constant_(m.bias, .01)
+ except:
+ pass
+
+ def forward(self, audio: torch.Tensor):
+ """
+ :param audio: B x T x 16000 Tensor containing 1 sec of audio centered around the current time frame
+ :return: code: B x T x latent_dim Tensor containing a latent audio code/embedding
+ """
+ B, T = audio.shape[0], audio.shape[1]
+ x = self.melspec(audio).squeeze(1)
+ x = torch.log(x.clamp(min=1e-10, max=None))
+ if T == 1:
+ x = x.unsqueeze(1)
+
+ # Convert to the right dimensionality
+ x = x.view(-1, x.shape[2], x.shape[3])
+ x = F.leaky_relu(self.convert_dimensions(x), .2)
+
+ # Process stacks
+ for conv in self.convs:
+ x_ = F.leaky_relu(conv(x), .2)
+ if self.training:
+ x_ = F.dropout(x_, .2)
+ l = (x.shape[2] - x_.shape[2]) // 2
+ x = (x[:, :, l:-l] + x_) / 2
+
+ x = torch.mean(x, dim=-1)
+ x = x.view(B, T, x.shape[-1])
+ x = self.code(x)
+
+ return {"code": x}
+
+
+class AudioEncoder(nn.Module):
+ def __init__(self, in_dim, out_dim, identity=False, num_classes=0):
+ super().__init__()
+ self.identity = identity
+ if self.identity:
+ in_dim = in_dim + 64
+ self.id_mlp = nn.Conv1d(num_classes, 64, 1, 1)
+ self.first_net = SeqTranslator1D(in_dim, out_dim,
+ min_layers_num=3,
+ residual=True,
+ norm='ln'
+ )
+ self.grus = nn.GRU(out_dim, out_dim, 1, batch_first=True)
+ self.dropout = nn.Dropout(0.1)
+ # self.att = nn.MultiheadAttention(out_dim, 4, dropout=0.1, batch_first=True)
+
+ def forward(self, spectrogram, pre_state=None, id=None, time_steps=None):
+
+        spectrogram = self.dropout(spectrogram)
+ if self.identity:
+ id = id.reshape(id.shape[0], -1, 1).repeat(1, 1, spectrogram.shape[2]).to(torch.float32)
+ id = self.id_mlp(id)
+ spectrogram = torch.cat([spectrogram, id], dim=1)
+ x1 = self.first_net(spectrogram)# .permute(0, 2, 1)
+ if time_steps is not None:
+ x1 = F.interpolate(x1, size=time_steps, align_corners=False, mode='linear')
+ # x1, _ = self.att(x1, x1, x1)
+ # x1, hidden_state = self.grus(x1)
+ # x1 = x1.permute(0, 2, 1)
+        hidden_state = None
+
+ return x1, hidden_state
+
+
+class Generator(nn.Module):
+ def __init__(self,
+ n_poses,
+ each_dim: list,
+ dim_list: list,
+ training=False,
+ device=None,
+ identity=True,
+ num_classes=0,
+ ):
+ super().__init__()
+
+ self.training = training
+ self.device = device
+ self.gen_length = n_poses
+ self.identity = identity
+
+ norm = 'ln'
+ in_dim = 256
+ out_dim = 256
+
+ self.encoder_choice = 'faceformer'
+
+ if self.encoder_choice == 'meshtalk':
+ self.audio_encoder = MeshtalkEncoder(latent_dim=in_dim)
+ elif self.encoder_choice == 'faceformer':
+ # wav2vec 2.0 weights initialization
+ self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") # "vitouphy/wav2vec2-xls-r-300m-phoneme""facebook/wav2vec2-base-960h"
+ self.audio_encoder.feature_extractor._freeze_parameters()
+ self.audio_feature_map = nn.Linear(768, in_dim)
+ else:
+ self.audio_encoder = AudioEncoder(in_dim=64, out_dim=out_dim)
+
+ self.audio_middle = AudioEncoder(in_dim, out_dim, identity, num_classes)
+
+ self.dim_list = dim_list
+
+ self.decoder = nn.ModuleList()
+ self.final_out = nn.ModuleList()
+
+ self.decoder.append(nn.Sequential(
+ ConvNormRelu(out_dim, 64, norm=norm),
+ ConvNormRelu(64, 64, norm=norm),
+ ConvNormRelu(64, 64, norm=norm),
+ ))
+ self.final_out.append(nn.Conv1d(64, each_dim[0], 1, 1))
+
+ self.decoder.append(nn.Sequential(
+ ConvNormRelu(out_dim, out_dim, norm=norm),
+ ConvNormRelu(out_dim, out_dim, norm=norm),
+ ConvNormRelu(out_dim, out_dim, norm=norm),
+ ))
+ self.final_out.append(nn.Conv1d(out_dim, each_dim[3], 1, 1))
+
+ def forward(self, in_spec, gt_poses=None, id=None, pre_state=None, time_steps=None):
+ if self.training:
+ time_steps = gt_poses.shape[1]
+
+ # vector, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
+ if self.encoder_choice == 'meshtalk':
+ in_spec = audio_chunking(in_spec.squeeze(-1), frame_rate=30, chunk_size=16000)
+ feature = self.audio_encoder(in_spec.unsqueeze(0))["code"].transpose(1, 2)
+ elif self.encoder_choice == 'faceformer':
+ hidden_states = self.audio_encoder(in_spec.reshape(in_spec.shape[0], -1), frame_num=time_steps).last_hidden_state
+ feature = self.audio_feature_map(hidden_states).transpose(1, 2)
+ else:
+ feature, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
+
+ # hidden_states = in_spec
+
+ feature, _ = self.audio_middle(feature, id=id)
+
+ out = []
+
+        for i in range(len(self.decoder)):
+ mid = self.decoder[i](feature)
+ mid = self.final_out[i](mid)
+ out.append(mid)
+
+ out = torch.cat(out, dim=1)
+ out = out.transpose(1, 2)
+
+ return out, None
+
+
diff --git a/nets/spg/s2glayers.py b/nets/spg/s2glayers.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a439e6bc0c4973586d39f3b113aa3752ff077fa
--- /dev/null
+++ b/nets/spg/s2glayers.py
@@ -0,0 +1,522 @@
+'''
+Not exactly the same as the official repo, but the results are good.
+'''
+import sys
+import os
+
+sys.path.append(os.getcwd())
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from nets.layers import SeqEncoder1D, SeqTranslator1D
+
+""" from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """
+
+
+class Conv2d_tf(nn.Conv2d):
+ """
+ Conv2d with the padding behavior from TF
+ from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(Conv2d_tf, self).__init__(*args, **kwargs)
+ self.padding = kwargs.get("padding", "SAME")
+
+ def _compute_padding(self, input, dim):
+ input_size = input.size(dim + 2)
+ filter_size = self.weight.size(dim + 2)
+ effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
+ out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
+ total_padding = max(
+ 0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
+ )
+ additional_padding = int(total_padding % 2 != 0)
+
+ return additional_padding, total_padding
+
+ def forward(self, input):
+ if self.padding == "VALID":
+ return F.conv2d(
+ input,
+ self.weight,
+ self.bias,
+ self.stride,
+ padding=0,
+ dilation=self.dilation,
+ groups=self.groups,
+ )
+ rows_odd, padding_rows = self._compute_padding(input, dim=0)
+ cols_odd, padding_cols = self._compute_padding(input, dim=1)
+ if rows_odd or cols_odd:
+ input = F.pad(input, [0, cols_odd, 0, rows_odd])
+
+ return F.conv2d(
+ input,
+ self.weight,
+ self.bias,
+ self.stride,
+ padding=(padding_rows // 2, padding_cols // 2),
+ dilation=self.dilation,
+ groups=self.groups,
+ )
+
+
+class Conv1d_tf(nn.Conv1d):
+ """
+ Conv1d with the padding behavior from TF
+ modified from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(Conv1d_tf, self).__init__(*args, **kwargs)
+ self.padding = kwargs.get("padding")
+
+ def _compute_padding(self, input, dim):
+ input_size = input.size(dim + 2)
+ filter_size = self.weight.size(dim + 2)
+ effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
+ out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
+ total_padding = max(
+ 0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
+ )
+ additional_padding = int(total_padding % 2 != 0)
+
+ return additional_padding, total_padding
+
+ def forward(self, input):
+ # if self.padding == "valid":
+ # return F.conv1d(
+ # input,
+ # self.weight,
+ # self.bias,
+ # self.stride,
+ # padding=0,
+ # dilation=self.dilation,
+ # groups=self.groups,
+ # )
+ rows_odd, padding_rows = self._compute_padding(input, dim=0)
+ if rows_odd:
+ input = F.pad(input, [0, rows_odd])
+
+ return F.conv1d(
+ input,
+ self.weight,
+ self.bias,
+ self.stride,
+ padding=(padding_rows // 2),
+ dilation=self.dilation,
+ groups=self.groups,
+ )
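+
+
+# Hedged sketch (illustrative): the padding rule above reproduces TF "SAME" behaviour,
+# i.e. output length = ceil(L / stride) regardless of kernel size. Shown here with plain
+# F.conv1d and manual padding; all sizes are made up for the example.
+def _example_tf_same_padding_rule():
+    L, k, s, d = 15, 4, 2, 1  # input length, kernel size, stride, dilation
+    out_size = (L + s - 1) // s  # ceil(L / s) = 8
+    total_padding = max(0, (out_size - 1) * s + (k - 1) * d + 1 - L)
+    x = torch.randn(8, 16, L)
+    weight = torch.randn(32, 16, k)
+    x = F.pad(x, [0, total_padding % 2])  # one extra sample on the right if padding is odd
+    y = F.conv1d(x, weight, stride=s, padding=total_padding // 2)
+    return y.shape  # torch.Size([8, 32, 8])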
+
+
+def ConvNormRelu(in_channels, out_channels, type='1d', downsample=False, k=None, s=None, padding='valid', groups=1,
+ nonlinear='lrelu', bn='bn'):
+ if k is None and s is None:
+ if not downsample:
+ k = 3
+ s = 1
+ padding = 'same'
+ else:
+ k = 4
+ s = 2
+ padding = 'valid'
+
+ if type == '1d':
+ conv_block = Conv1d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding, groups=groups)
+ norm_block = nn.BatchNorm1d(out_channels)
+ elif type == '2d':
+ conv_block = Conv2d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding, groups=groups)
+ norm_block = nn.BatchNorm2d(out_channels)
+ else:
+ assert False
+ if bn != 'bn':
+ if bn == 'gn':
+ norm_block = nn.GroupNorm(1, out_channels)
+ elif bn == 'ln':
+ norm_block = nn.LayerNorm(out_channels)
+ else:
+ norm_block = nn.Identity()
+ if nonlinear == 'lrelu':
+ nlinear = nn.LeakyReLU(0.2, True)
+ elif nonlinear == 'tanh':
+ nlinear = nn.Tanh()
+ elif nonlinear == 'none':
+ nlinear = nn.Identity()
+
+ return nn.Sequential(
+ conv_block,
+ norm_block,
+ nlinear
+ )
+
+
+class UnetUp(nn.Module):
+ def __init__(self, in_ch, out_ch):
+ super(UnetUp, self).__init__()
+ self.conv = ConvNormRelu(in_ch, out_ch)
+
+ def forward(self, x1, x2):
+ # x1 = torch.repeat_interleave(x1, 2, dim=2)
+ # x1 = x1[:, :, :x2.shape[2]]
+ x1 = torch.nn.functional.interpolate(x1, size=x2.shape[2], mode='linear')
+ x = x1 + x2
+ x = self.conv(x)
+ return x
+
+
+class UNet(nn.Module):
+ def __init__(self, input_dim, dim):
+ super(UNet, self).__init__()
+ # dim = 512
+ self.down1 = nn.Sequential(
+ ConvNormRelu(input_dim, input_dim, '1d', False),
+ ConvNormRelu(input_dim, dim, '1d', False),
+ ConvNormRelu(dim, dim, '1d', False)
+ )
+ self.gru = nn.GRU(dim, dim, 1, batch_first=True)
+ self.down2 = ConvNormRelu(dim, dim, '1d', True)
+ self.down3 = ConvNormRelu(dim, dim, '1d', True)
+ self.down4 = ConvNormRelu(dim, dim, '1d', True)
+ self.down5 = ConvNormRelu(dim, dim, '1d', True)
+ self.down6 = ConvNormRelu(dim, dim, '1d', True)
+ self.up1 = UnetUp(dim, dim)
+ self.up2 = UnetUp(dim, dim)
+ self.up3 = UnetUp(dim, dim)
+ self.up4 = UnetUp(dim, dim)
+ self.up5 = UnetUp(dim, dim)
+
+ def forward(self, x1, pre_pose=None, w_pre=False):
+ x2_0 = self.down1(x1)
+ if w_pre:
+ i = 1
+ x2_pre = self.gru(x2_0[:,:,0:i].permute(0,2,1), pre_pose[:,:,-1:].permute(2,0,1).contiguous())[0].permute(0,2,1)
+ x2 = torch.cat([x2_pre, x2_0[:,:,i:]], dim=-1)
+ # x2 = torch.cat([pre_pose, x2_0], dim=2) # [B, 512, 15]
+ else:
+ # x2 = self.gru(x2_0.transpose(1, 2))[0].transpose(1,2)
+ x2 = x2_0
+ x3 = self.down2(x2)
+ x4 = self.down3(x3)
+ x5 = self.down4(x4)
+ x6 = self.down5(x5)
+ x7 = self.down6(x6)
+ x = self.up1(x7, x6)
+ x = self.up2(x, x5)
+ x = self.up3(x, x4)
+ x = self.up4(x, x3)
+ x = self.up5(x, x2) # [B, 512, 15]
+ return x, x2_0
+
+
+class AudioEncoder(nn.Module):
+ def __init__(self, n_frames, template_length, pose=False, common_dim=512):
+ super().__init__()
+ self.n_frames = n_frames
+ self.pose = pose
+ self.step = 0
+ self.weight = 0
+ if self.pose:
+ # self.first_net = nn.Sequential(
+ # ConvNormRelu(1, 64, '2d', False),
+ # ConvNormRelu(64, 64, '2d', True),
+ # ConvNormRelu(64, 128, '2d', False),
+ # ConvNormRelu(128, 128, '2d', True),
+ # ConvNormRelu(128, 256, '2d', False),
+ # ConvNormRelu(256, 256, '2d', True),
+ # ConvNormRelu(256, 256, '2d', False),
+ # ConvNormRelu(256, 256, '2d', False, padding='VALID')
+ # )
+ # decoder_layer = nn.TransformerDecoderLayer(d_model=args.feature_dim, nhead=4,
+ # dim_feedforward=2 * args.feature_dim, batch_first=True)
+ # a = nn.TransformerDecoder
+ self.first_net = SeqTranslator1D(256, 256,
+ min_layers_num=4,
+ residual=True
+ )
+ self.dropout_0 = nn.Dropout(0.1)
+ self.mu_fc = nn.Conv1d(256, 128, 1, 1)
+ self.var_fc = nn.Conv1d(256, 128, 1, 1)
+ self.trans_motion = SeqTranslator1D(common_dim, common_dim,
+ kernel_size=1,
+ stride=1,
+ min_layers_num=3,
+ residual=True
+ )
+ # self.att = nn.MultiheadAttention(64 + template_length, 4, dropout=0.1)
+ self.unet = UNet(128 + template_length, common_dim)
+
+ else:
+ self.first_net = SeqTranslator1D(256, 256,
+ min_layers_num=4,
+ residual=True
+ )
+ self.dropout_0 = nn.Dropout(0.1)
+ # self.att = nn.MultiheadAttention(256, 4, dropout=0.1)
+ self.unet = UNet(256, 256)
+ self.dropout_1 = nn.Dropout(0.0)
+
+ def forward(self, spectrogram, time_steps=None, template=None, pre_pose=None, w_pre=False):
+ self.step = self.step + 1
+ if self.pose:
+ spect = spectrogram.transpose(1, 2)
+ if w_pre:
+ spect = spect[:, :, :]
+
+ out = self.first_net(spect)
+ out = self.dropout_0(out)
+
+ mu = self.mu_fc(out)
+ var = self.var_fc(out)
+ audio = self.__reparam(mu, var)
+ # audio = out
+
+ # template = self.trans_motion(template)
+ x1 = torch.cat([audio, template], dim=1)#.permute(2,0,1)
+ # x1 = out
+ #x1, _ = self.att(x1, x1, x1)
+ #x1 = x1.permute(1,2,0)
+ x1, x2_0 = self.unet(x1, pre_pose=pre_pose, w_pre=w_pre)
+ else:
+ spectrogram = spectrogram.transpose(1, 2)
+ x1 = self.first_net(spectrogram)#.permute(2,0,1)
+ #out, _ = self.att(out, out, out)
+ #out = out.permute(1, 2, 0)
+ x1 = self.dropout_0(x1)
+ x1, x2_0 = self.unet(x1)
+ x1 = self.dropout_1(x1)
+ mu = None
+ var = None
+
+ return x1, (mu, var), x2_0
+
+ def __reparam(self, mu, log_var):
+ std = torch.exp(0.5 * log_var)
+        eps = torch.randn_like(std)  # keep eps on the same device as std (a hard-coded 'cuda' breaks CPU runs)
+ z = eps * std + mu
+ return z
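+
+
+# Hedged sketch (illustrative): the reparameterisation used in __reparam above keeps the
+# sampling step differentiable, so gradients reach mu and log_var even though z is random.
+def _example_reparam_gradients():
+    mu = torch.zeros(4, 64, requires_grad=True)
+    log_var = torch.zeros(4, 64, requires_grad=True)
+    std = torch.exp(0.5 * log_var)
+    z = torch.randn_like(std) * std + mu  # same rule as __reparam
+    z.sum().backward()
+    return mu.grad is not None and log_var.grad is not None  # True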
+
+
+class Generator(nn.Module):
+ def __init__(self,
+ n_poses,
+ pose_dim,
+ pose,
+ n_pre_poses,
+ each_dim: list,
+ dim_list: list,
+ use_template=False,
+ template_length=0,
+ training=False,
+ device=None,
+ separate=False,
+ expression=False
+ ):
+ super().__init__()
+
+ self.use_template = use_template
+ self.template_length = template_length
+ self.training = training
+ self.device = device
+ self.separate = separate
+ self.pose = pose
+ self.decoderf = True
+ self.expression = expression
+
+ common_dim = 256
+
+ if self.use_template:
+ assert template_length > 0
+ # self.KLLoss = KLLoss(kl_tolerance=self.config.Train.weights.kl_tolerance).to(self.device)
+ # self.pose_encoder = SeqEncoder1D(
+ # C_in=pose_dim,
+ # C_out=512,
+ # T_in=n_poses,
+ # min_layer_nums=6
+ #
+ # )
+ self.pose_encoder = SeqTranslator1D(pose_dim - 50, common_dim,
+ # kernel_size=1,
+ # stride=1,
+ min_layers_num=3,
+ residual=True
+ )
+ self.mu_fc = nn.Conv1d(common_dim, template_length, kernel_size=1, stride=1)
+ self.var_fc = nn.Conv1d(common_dim, template_length, kernel_size=1, stride=1)
+
+ else:
+ self.template_length = 0
+
+ self.gen_length = n_poses
+
+ self.audio_encoder = AudioEncoder(n_poses, template_length, True, common_dim)
+ self.speech_encoder = AudioEncoder(n_poses, template_length, False)
+
+ # self.pre_pose_encoder = SeqEncoder1D(
+ # C_in=pose_dim,
+ # C_out=128,
+ # T_in=15,
+ # min_layer_nums=3
+ #
+ # )
+ # self.pmu_fc = nn.Linear(128, 64)
+ # self.pvar_fc = nn.Linear(128, 64)
+
+ self.pre_pose_encoder = SeqTranslator1D(pose_dim-50, common_dim,
+ min_layers_num=5,
+ residual=True
+ )
+ self.decoder_in = 256 + 64
+ self.dim_list = dim_list
+
+ if self.separate:
+ self.decoder = nn.ModuleList()
+ self.final_out = nn.ModuleList()
+
+ self.decoder.append(nn.Sequential(
+ ConvNormRelu(256, 64),
+ ConvNormRelu(64, 64),
+ ConvNormRelu(64, 64),
+ ))
+ self.final_out.append(nn.Conv1d(64, each_dim[0], 1, 1))
+
+ self.decoder.append(nn.Sequential(
+ ConvNormRelu(common_dim, common_dim),
+ ConvNormRelu(common_dim, common_dim),
+ ConvNormRelu(common_dim, common_dim),
+ ))
+ self.final_out.append(nn.Conv1d(common_dim, each_dim[1], 1, 1))
+
+ self.decoder.append(nn.Sequential(
+ ConvNormRelu(common_dim, common_dim),
+ ConvNormRelu(common_dim, common_dim),
+ ConvNormRelu(common_dim, common_dim),
+ ))
+ self.final_out.append(nn.Conv1d(common_dim, each_dim[2], 1, 1))
+
+ if self.expression:
+ self.decoder.append(nn.Sequential(
+ ConvNormRelu(256, 256),
+ ConvNormRelu(256, 256),
+ ConvNormRelu(256, 256),
+ ))
+ self.final_out.append(nn.Conv1d(256, each_dim[3], 1, 1))
+ else:
+ self.decoder = nn.Sequential(
+ ConvNormRelu(self.decoder_in, 512),
+ ConvNormRelu(512, 512),
+ ConvNormRelu(512, 512),
+ ConvNormRelu(512, 512),
+ ConvNormRelu(512, 512),
+ ConvNormRelu(512, 512),
+ )
+ self.final_out = nn.Conv1d(512, pose_dim, 1, 1)
+
+ def __reparam(self, mu, log_var):
+ std = torch.exp(0.5 * log_var)
+ eps = torch.randn_like(std, device=self.device)
+ z = eps * std + mu
+ return z
+
+ def forward(self, in_spec, pre_poses, gt_poses, template=None, time_steps=None, w_pre=False, norm=True):
+ if time_steps is not None:
+ self.gen_length = time_steps
+
+ if self.use_template:
+ if self.training:
+ if w_pre:
+ in_spec = in_spec[:, 15:, :]
+ pre_pose = self.pre_pose_encoder(gt_poses[:, 14:15, :-50].permute(0, 2, 1))
+ pose_enc = self.pose_encoder(gt_poses[:, 15:, :-50].permute(0, 2, 1))
+ mu = self.mu_fc(pose_enc)
+ var = self.var_fc(pose_enc)
+ template = self.__reparam(mu, var)
+ else:
+ pre_pose = None
+ pose_enc = self.pose_encoder(gt_poses[:, :, :-50].permute(0, 2, 1))
+ mu = self.mu_fc(pose_enc)
+ var = self.var_fc(pose_enc)
+ template = self.__reparam(mu, var)
+ elif pre_poses is not None:
+ if w_pre:
+ pre_pose = pre_poses[:, -1:, :-50]
+ if norm:
+ pre_pose = pre_pose.reshape(1, -1, 55, 5)
+ pre_pose = torch.cat([F.normalize(pre_pose[..., :3], dim=-1),
+ F.normalize(pre_pose[..., 3:5], dim=-1)],
+ dim=-1).reshape(1, -1, 275)
+ pre_pose = self.pre_pose_encoder(pre_pose.permute(0, 2, 1))
+ template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length ]).to(
+ in_spec.device)
+ else:
+ pre_pose = None
+ template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length]).to(in_spec.device)
+            elif gt_poses is not None:
+                pre_pose = None
+                template = self.pre_pose_encoder(gt_poses[:, :, :-50].permute(0, 2, 1))
+ elif template is None:
+ pre_pose = None
+ template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length]).to(in_spec.device)
+        else:
+            pre_pose = None
+            template = None
+ mu = None
+ var = None
+
+ a_t_f, (mu2, var2), x2_0 = self.audio_encoder(in_spec, time_steps=time_steps, template=template, pre_pose=pre_pose, w_pre=w_pre)
+ s_f, _, _ = self.speech_encoder(in_spec, time_steps=time_steps)
+
+ out = []
+
+ if self.separate:
+            for i in range(len(self.decoder)):
+ if i == 0 or i == 3:
+ mid = self.decoder[i](s_f)
+ else:
+ mid = self.decoder[i](a_t_f)
+ mid = self.final_out[i](mid)
+ out.append(mid)
+ out = torch.cat(out, dim=1)
+
+ else:
+ out = self.decoder(a_t_f)
+ out = self.final_out(out)
+
+ out = out.transpose(1, 2)
+
+ if self.training:
+ if w_pre:
+ return out, template, mu, var, (mu2, var2, x2_0, pre_pose)
+ else:
+ return out, template, mu, var, (mu2, var2, None, None)
+ else:
+ return out
+
+
+class Discriminator(nn.Module):
+ def __init__(self, pose_dim, pose):
+ super().__init__()
+ self.net = nn.Sequential(
+ Conv1d_tf(pose_dim, 64, kernel_size=4, stride=2, padding='SAME'),
+ nn.LeakyReLU(0.2, True),
+ ConvNormRelu(64, 128, '1d', True),
+ ConvNormRelu(128, 256, '1d', k=4, s=1),
+ Conv1d_tf(256, 1, kernel_size=4, stride=1, padding='SAME'),
+ )
+
+ def forward(self, x):
+ x = x.transpose(1, 2)
+
+ out = self.net(x)
+ return out
+
+
+def main():
+    # quick smoke test: the discriminator maps (B, T, pose_dim) pose sequences to score maps
+    d = Discriminator(275, 55)
+    x = torch.randn([8, 60, 275])
+    result = d(x)
+    print(result.shape)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/nets/spg/vqvae_1d.py b/nets/spg/vqvae_1d.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cd15bd6439b949bf89098af274b3e7ccac9b5f5
--- /dev/null
+++ b/nets/spg/vqvae_1d.py
@@ -0,0 +1,235 @@
+import os
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .wav2vec import Wav2Vec2Model
+from .vqvae_modules import VectorQuantizerEMA, ConvNormRelu, Res_CNR_Stack
+
+
+
+class AudioEncoder(nn.Module):
+ def __init__(self, in_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
+ super(AudioEncoder, self).__init__()
+ self._num_hiddens = num_hiddens
+ self._num_residual_layers = num_residual_layers
+ self._num_residual_hiddens = num_residual_hiddens
+
+ self.project = ConvNormRelu(in_dim, self._num_hiddens // 4, leaky=True)
+
+ self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True)
+ self._down_1 = ConvNormRelu(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, residual=True,
+ sample='down')
+ self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True)
+ self._down_2 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens, leaky=True, residual=True, sample='down')
+ self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
+
+ def forward(self, x, frame_num=0):
+ h = self.project(x)
+ h = self._enc_1(h)
+ h = self._down_1(h)
+ h = self._enc_2(h)
+ h = self._down_2(h)
+ h = self._enc_3(h)
+ return h
+
+
+class Wav2VecEncoder(nn.Module):
+ def __init__(self, num_hiddens, num_residual_layers):
+ super(Wav2VecEncoder, self).__init__()
+ self._num_hiddens = num_hiddens
+ self._num_residual_layers = num_residual_layers
+
+ self.audio_encoder = Wav2Vec2Model.from_pretrained(
+ "facebook/wav2vec2-base-960h") # "vitouphy/wav2vec2-xls-r-300m-phoneme""facebook/wav2vec2-base-960h"
+ self.audio_encoder.feature_extractor._freeze_parameters()
+
+ self.project = ConvNormRelu(768, self._num_hiddens, leaky=True)
+
+ self._enc_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
+ self._down_1 = ConvNormRelu(self._num_hiddens, self._num_hiddens, leaky=True, residual=True, sample='down')
+ self._enc_2 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
+ self._down_2 = ConvNormRelu(self._num_hiddens, self._num_hiddens, leaky=True, residual=True, sample='down')
+ self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
+
+ def forward(self, x, frame_num):
+ h = self.audio_encoder(x.squeeze(), frame_num=frame_num).last_hidden_state.transpose(1, 2)
+ h = self.project(h)
+ h = self._enc_1(h)
+ h = self._down_1(h)
+ h = self._enc_2(h)
+ h = self._down_2(h)
+ h = self._enc_3(h)
+ return h
+
+
+class Encoder(nn.Module):
+ def __init__(self, in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
+ super(Encoder, self).__init__()
+ self._num_hiddens = num_hiddens
+ self._num_residual_layers = num_residual_layers
+ self._num_residual_hiddens = num_residual_hiddens
+
+ self.project = ConvNormRelu(in_dim, self._num_hiddens // 4, leaky=True)
+
+ self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True)
+ self._down_1 = ConvNormRelu(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, residual=True,
+ sample='down')
+ self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True)
+ self._down_2 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens, leaky=True, residual=True, sample='down')
+ self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
+
+ self.pre_vq_conv = nn.Conv1d(self._num_hiddens, embedding_dim, 1, 1)
+
+ def forward(self, x):
+ h = self.project(x)
+ h = self._enc_1(h)
+ h = self._down_1(h)
+ h = self._enc_2(h)
+ h = self._down_2(h)
+ h = self._enc_3(h)
+ h = self.pre_vq_conv(h)
+ return h
+
+
+class Frame_Enc(nn.Module):
+ def __init__(self, in_dim, num_hiddens):
+ super(Frame_Enc, self).__init__()
+ self.in_dim = in_dim
+ self.num_hiddens = num_hiddens
+
+ # self.enc = transformer_Enc(in_dim, num_hiddens, 2, 8, 256, 256, 256, 256, 0, dropout=0.1, n_position=4)
+ self.proj = nn.Conv1d(in_dim, num_hiddens, 1, 1)
+ self.enc = Res_CNR_Stack(num_hiddens, 2, leaky=True)
+ self.proj_1 = nn.Conv1d(256*4, num_hiddens, 1, 1)
+ self.proj_2 = nn.Conv1d(256*4, num_hiddens*2, 1, 1)
+
+ def forward(self, x):
+ # x = self.enc(x, None)[0].reshape(x.shape[0], -1, 1)
+ x = self.enc(self.proj(x)).reshape(x.shape[0], -1, 1)
+ second_last = self.proj_2(x)
+ last = self.proj_1(x)
+ return second_last, last
+
+
+
+class Decoder(nn.Module):
+ def __init__(self, out_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens, ae=False):
+ super(Decoder, self).__init__()
+ self._num_hiddens = num_hiddens
+ self._num_residual_layers = num_residual_layers
+ self._num_residual_hiddens = num_residual_hiddens
+
+ self.aft_vq_conv = nn.Conv1d(embedding_dim, self._num_hiddens, 1, 1)
+
+ self._dec_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
+ self._up_2 = ConvNormRelu(self._num_hiddens, self._num_hiddens // 2, leaky=True, residual=True, sample='up')
+ self._dec_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True)
+ self._up_3 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens // 4, leaky=True, residual=True,
+ sample='up')
+ self._dec_3 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True)
+
+ if ae:
+ self.frame_enc = Frame_Enc(out_dim, self._num_hiddens // 4)
+ self.gru_sl = nn.GRU(self._num_hiddens // 2, self._num_hiddens // 2, 1, batch_first=True)
+ self.gru_l = nn.GRU(self._num_hiddens // 4, self._num_hiddens // 4, 1, batch_first=True)
+
+ self.project = nn.Conv1d(self._num_hiddens // 4, out_dim, 1, 1)
+
+ def forward(self, h, last_frame=None):
+
+ h = self.aft_vq_conv(h)
+ h = self._dec_1(h)
+ h = self._up_2(h)
+ h = self._dec_2(h)
+ h = self._up_3(h)
+ h = self._dec_3(h)
+
+ recon = self.project(h)
+ return recon, None
+
+
+class Pre_VQ(nn.Module):
+ def __init__(self, num_hiddens, embedding_dim, num_chunks):
+ super(Pre_VQ, self).__init__()
+ self.conv = nn.Conv1d(num_hiddens, num_hiddens, 1, 1, 0, groups=num_chunks)
+ self.bn = nn.GroupNorm(num_chunks, num_hiddens)
+ self.relu = nn.ReLU()
+ self.proj = nn.Conv1d(num_hiddens, embedding_dim, 1, 1, 0, groups=num_chunks)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ x = self.relu(x)
+ x = self.proj(x)
+ return x
+
+
+class VQVAE(nn.Module):
+ """VQ-VAE"""
+
+ def __init__(self, in_dim, embedding_dim, num_embeddings,
+ num_hiddens, num_residual_layers, num_residual_hiddens,
+ commitment_cost=0.25, decay=0.99, share=False):
+ super().__init__()
+ self.in_dim = in_dim
+ self.embedding_dim = embedding_dim
+ self.num_embeddings = num_embeddings
+ self.share_code_vq = share
+
+ self.encoder = Encoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
+ self.vq_layer = VectorQuantizerEMA(embedding_dim, num_embeddings, commitment_cost, decay)
+ self.decoder = Decoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
+
+ def forward(self, gt_poses, id=None, pre_state=None):
+ z = self.encoder(gt_poses.transpose(1, 2))
+ if not self.training:
+ e, _ = self.vq_layer(z)
+ x_recon, cur_state = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
+ return e, x_recon
+
+ e, e_q_loss = self.vq_layer(z)
+ gt_recon, cur_state = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
+
+ return e_q_loss, gt_recon.transpose(1, 2)
+
+ def encode(self, gt_poses, id=None):
+ z = self.encoder(gt_poses.transpose(1, 2))
+ e, latents = self.vq_layer(z)
+ return e, latents
+
+ def decode(self, b, w, e=None, latents=None, pre_state=None):
+ if e is not None:
+ x = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
+ else:
+ e = self.vq_layer.quantize(latents)
+ e = e.view(b, w, -1).permute(0, 2, 1).contiguous()
+ x = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
+ return x
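+
+
+# Hedged usage sketch (sizes are illustrative assumptions, not the repo's configuration):
+# encode a pose sequence to discrete codes and reconstruct it through the decoder.
+def _example_vqvae_roundtrip():
+    model = VQVAE(in_dim=108, embedding_dim=64, num_embeddings=512,
+                  num_hiddens=256, num_residual_layers=2, num_residual_hiddens=256)
+    model.eval()
+    poses = torch.randn(2, 60, 108)  # (B, T, C) pose sequence
+    with torch.no_grad():
+        e, recon = model(poses)  # in eval mode forward returns codes and reconstruction
+    return recon.shape  # torch.Size([2, 108, 60])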
+
+
+class AE(nn.Module):
+ """VQ-VAE"""
+
+ def __init__(self, in_dim, embedding_dim, num_embeddings,
+ num_hiddens, num_residual_layers, num_residual_hiddens):
+ super().__init__()
+ self.in_dim = in_dim
+ self.embedding_dim = embedding_dim
+ self.num_embeddings = num_embeddings
+
+ self.encoder = Encoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
+ self.decoder = Decoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens, True)
+
+ def forward(self, gt_poses, id=None, pre_state=None):
+ z = self.encoder(gt_poses.transpose(1, 2))
+ if not self.training:
+ x_recon, cur_state = self.decoder(z, pre_state.transpose(1, 2) if pre_state is not None else None)
+ return z, x_recon
+ gt_recon, cur_state = self.decoder(z, pre_state.transpose(1, 2) if pre_state is not None else None)
+
+ return gt_recon.transpose(1, 2)
+
+ def encode(self, gt_poses, id=None):
+ z = self.encoder(gt_poses.transpose(1, 2))
+ return z
diff --git a/nets/spg/vqvae_modules.py b/nets/spg/vqvae_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c83bc0399bc3bc034881407ed49d223d8c86ba9
--- /dev/null
+++ b/nets/spg/vqvae_modules.py
@@ -0,0 +1,380 @@
+import os
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import datasets, transforms
+import matplotlib.pyplot as plt
+
+
+
+
+class CasualCT(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ leaky=False,
+ p=0,
+ groups=1, ):
+ '''
+        transposed conv - bn - relu ("Casual" in these class names means causal)
+ '''
+ super(CasualCT, self).__init__()
+ padding = 0
+ kernel_size = 2
+ stride = 2
+ in_channels = in_channels * groups
+ out_channels = out_channels * groups
+
+ self.conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=kernel_size, stride=stride, padding=padding,
+ groups=groups)
+ self.norm = nn.BatchNorm1d(out_channels)
+ self.dropout = nn.Dropout(p=p)
+ if leaky:
+ self.relu = nn.LeakyReLU(negative_slope=0.2)
+ else:
+ self.relu = nn.ReLU()
+
+ def forward(self, x, **kwargs):
+ out = self.norm(self.dropout(self.conv(x)))
+ return self.relu(out)
+
+
+class CasualConv(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ leaky=False,
+ p=0,
+ groups=1,
+ downsample=False):
+ '''
+        causal conv - bn - relu (optionally strided for downsampling)
+ '''
+ super(CasualConv, self).__init__()
+ padding = 0
+ kernel_size = 2
+ stride = 1
+ self.downsample = downsample
+ if self.downsample:
+ kernel_size = 2
+ stride = 2
+
+ in_channels = in_channels * groups
+ out_channels = out_channels * groups
+ self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=kernel_size, stride=stride, padding=padding,
+ groups=groups)
+ self.norm = nn.BatchNorm1d(out_channels)
+ self.dropout = nn.Dropout(p=p)
+ if leaky:
+ self.relu = nn.LeakyReLU(negative_slope=0.2)
+ else:
+ self.relu = nn.ReLU()
+
+ def forward(self, x, pre_state=None):
+ if not self.downsample:
+ if pre_state is not None:
+ x = torch.cat([pre_state, x], dim=-1)
+ else:
+ zeros = torch.zeros([x.shape[0], x.shape[1], 1], device=x.device)
+ x = torch.cat([zeros, x], dim=-1)
+ out = self.norm(self.dropout(self.conv(x)))
+ return self.relu(out)
+
+
+class ConvNormRelu(nn.Module):
+ '''
+    (B, C_in, T) -> (B, C_out, T)
+    some kernel sizes do not give an output length of exactly T/s when downsampling
+    # TODO: there might be some problems with the residual path
+ '''
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ leaky=False,
+ sample='none',
+ p=0,
+ groups=1,
+ residual=False,
+ norm='bn'):
+ '''
+ conv-bn-relu
+ '''
+ super(ConvNormRelu, self).__init__()
+ self.residual = residual
+ self.norm_type = norm
+ padding = 1
+
+ if sample == 'none':
+ kernel_size = 3
+ stride = 1
+ elif sample == 'one':
+ padding = 0
+ kernel_size = stride = 1
+ else:
+ kernel_size = 4
+ stride = 2
+
+ if self.residual:
+ if sample == 'down':
+ self.residual_layer = nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding)
+ elif sample == 'up':
+ self.residual_layer = nn.ConvTranspose1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding)
+ else:
+ if in_channels == out_channels:
+ self.residual_layer = nn.Identity()
+ else:
+ self.residual_layer = nn.Sequential(
+ nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding
+ )
+ )
+
+ in_channels = in_channels * groups
+ out_channels = out_channels * groups
+ if sample == 'up':
+ self.conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=kernel_size, stride=stride, padding=padding,
+ groups=groups)
+ else:
+ self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=kernel_size, stride=stride, padding=padding,
+ groups=groups)
+ self.norm = nn.BatchNorm1d(out_channels)
+ self.dropout = nn.Dropout(p=p)
+ if leaky:
+ self.relu = nn.LeakyReLU(negative_slope=0.2)
+ else:
+ self.relu = nn.ReLU()
+
+ def forward(self, x, **kwargs):
+ out = self.norm(self.dropout(self.conv(x)))
+ if self.residual:
+ residual = self.residual_layer(x)
+ out += residual
+ return self.relu(out)
+
+
+class Res_CNR_Stack(nn.Module):
+ def __init__(self,
+ channels,
+ layers,
+ sample='none',
+ leaky=False,
+ casual=False,
+ ):
+ super(Res_CNR_Stack, self).__init__()
+
+        if casual:
+            kernel_size = 1
+            padding = 0
+            conv = CasualConv
+        else:
+            kernel_size = 3
+            padding = 1
+            conv = ConvNormRelu
+
+        if sample == 'one':
+            kernel_size = 1
+            padding = 0
+
+        self._layers = nn.ModuleList()
+        for i in range(layers):
+            # CasualConv does not accept a `sample` argument, so only pass it to ConvNormRelu
+            if casual:
+                self._layers.append(conv(channels, channels, leaky=leaky))
+            else:
+                self._layers.append(conv(channels, channels, leaky=leaky, sample=sample))
+        self.conv = nn.Conv1d(channels, channels, kernel_size, 1, padding)
+ self.norm = nn.BatchNorm1d(channels)
+ self.relu = nn.ReLU()
+
+ def forward(self, x, pre_state=None):
+ # cur_state = []
+ h = x
+ for i in range(self._layers.__len__()):
+ # cur_state.append(h[..., -1:])
+ h = self._layers[i](h, pre_state=pre_state[i] if pre_state is not None else None)
+ h = self.norm(self.conv(h))
+ return self.relu(h + x)
+
+
+class ExponentialMovingAverage(nn.Module):
+ """Maintains an exponential moving average for a value.
+
+ This module keeps track of a hidden exponential moving average that is
+ initialized as a vector of zeros which is then normalized to give the average.
+ This gives us a moving average which isn't biased towards either zero or the
+ initial value. Reference (https://arxiv.org/pdf/1412.6980.pdf)
+
+ Initially:
+ hidden_0 = 0
+ Then iteratively:
+ hidden_i = hidden_{i-1} - (hidden_{i-1} - value) * (1 - decay)
+ average_i = hidden_i / (1 - decay^i)
+ """
+
+ def __init__(self, init_value, decay):
+ super().__init__()
+
+ self.decay = decay
+ self.counter = 0
+ self.register_buffer("hidden", torch.zeros_like(init_value))
+
+ def forward(self, value):
+ self.counter += 1
+ self.hidden.sub_((self.hidden - value) * (1 - self.decay))
+ average = self.hidden / (1 - self.decay ** self.counter)
+ return average
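+
+
+# Hedged numeric sketch (illustrative): with decay=0.9 and a constant input of 1.0 the
+# bias-corrected average is 1.0 from the very first step, while the raw hidden state
+# alone would start near 0.1 and converge slowly.
+def _example_ema_bias_correction():
+    ema = ExponentialMovingAverage(torch.zeros(1), decay=0.9)
+    averages = [ema(torch.ones(1)).item() for _ in range(3)]
+    return averages  # approximately [1.0, 1.0, 1.0]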
+
+
+class VectorQuantizerEMA(nn.Module):
+ """
+ VQ-VAE layer: Input any tensor to be quantized. Use EMA to update embeddings.
+ Args:
+ embedding_dim (int): the dimensionality of the tensors in the
+ quantized space. Inputs to the modules must be in this format as well.
+ num_embeddings (int): the number of vectors in the quantized space.
+ commitment_cost (float): scalar which controls the weighting of the loss terms (see
+ equation 4 in the paper - this variable is Beta).
+ decay (float): decay for the moving averages.
+ epsilon (float): small float constant to avoid numerical instability.
+ """
+
+ def __init__(self, embedding_dim, num_embeddings, commitment_cost, decay,
+ epsilon=1e-5):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.num_embeddings = num_embeddings
+ self.commitment_cost = commitment_cost
+ self.epsilon = epsilon
+
+ # initialize embeddings as buffers
+ embeddings = torch.empty(self.num_embeddings, self.embedding_dim)
+ nn.init.xavier_uniform_(embeddings)
+ self.register_buffer("embeddings", embeddings)
+ self.ema_dw = ExponentialMovingAverage(self.embeddings, decay)
+
+ # also maintain ema_cluster_size, which record the size of each embedding
+ self.ema_cluster_size = ExponentialMovingAverage(torch.zeros((self.num_embeddings,)), decay)
+
+ def forward(self, x):
+        # [B, C, W] -> [B, W, C]
+        x = x.permute(0, 2, 1).contiguous()
+        # [B, W, C] -> [BW, C]
+        flat_x = x.reshape(-1, self.embedding_dim)
+
+ encoding_indices = self.get_code_indices(flat_x)
+ quantized = self.quantize(encoding_indices)
+ quantized = quantized.view_as(x) # [B, W, C]
+
+ if not self.training:
+ quantized = quantized.permute(0, 2, 1).contiguous()
+ return quantized, encoding_indices.view(quantized.shape[0], quantized.shape[2])
+
+ # update embeddings with EMA
+ with torch.no_grad():
+ encodings = F.one_hot(encoding_indices, self.num_embeddings).float()
+ updated_ema_cluster_size = self.ema_cluster_size(torch.sum(encodings, dim=0))
+ n = torch.sum(updated_ema_cluster_size)
+ updated_ema_cluster_size = ((updated_ema_cluster_size + self.epsilon) /
+ (n + self.num_embeddings * self.epsilon) * n)
+ dw = torch.matmul(encodings.t(), flat_x) # sum encoding vectors of each cluster
+ updated_ema_dw = self.ema_dw(dw)
+ normalised_updated_ema_w = (
+ updated_ema_dw / updated_ema_cluster_size.reshape(-1, 1))
+ self.embeddings.data = normalised_updated_ema_w
+
+ # commitment loss
+ e_latent_loss = F.mse_loss(x, quantized.detach())
+ loss = self.commitment_cost * e_latent_loss
+
+ # Straight Through Estimator
+ quantized = x + (quantized - x).detach()
+
+ quantized = quantized.permute(0, 2, 1).contiguous()
+ return quantized, loss
+
+ def get_code_indices(self, flat_x):
+ # compute L2 distance
+ distances = (
+ torch.sum(flat_x ** 2, dim=1, keepdim=True) +
+ torch.sum(self.embeddings ** 2, dim=1) -
+ 2. * torch.matmul(flat_x, self.embeddings.t())
+ ) # [N, M]
+ encoding_indices = torch.argmin(distances, dim=1) # [N,]
+ return encoding_indices
+
+ def quantize(self, encoding_indices):
+ """Returns embedding tensor for a batch of indices."""
+ return F.embedding(encoding_indices, self.embeddings)
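+
+
+# Hedged usage sketch (illustrative sizes): quantize a (B, C, W) feature map against the
+# codebook; in eval mode the layer returns the quantized features and their code indices.
+def _example_vector_quantizer():
+    vq = VectorQuantizerEMA(embedding_dim=64, num_embeddings=512, commitment_cost=0.25, decay=0.99)
+    vq.eval()
+    z = torch.randn(2, 64, 15)
+    quantized, indices = vq(z)
+    return quantized.shape, indices.shape  # torch.Size([2, 64, 15]) and torch.Size([2, 15])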
+
+
+
+class Casual_Encoder(nn.Module):
+ def __init__(self, in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
+ super(Casual_Encoder, self).__init__()
+ self._num_hiddens = num_hiddens
+ self._num_residual_layers = num_residual_layers
+ self._num_residual_hiddens = num_residual_hiddens
+
+ self.project = nn.Conv1d(in_dim, self._num_hiddens // 4, 1, 1)
+ self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True, casual=True)
+ self._down_1 = CasualConv(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, downsample=True)
+ self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True, casual=True)
+ self._down_2 = CasualConv(self._num_hiddens // 2, self._num_hiddens, leaky=True, downsample=True)
+ self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True, casual=True)
+ # self.pre_vq_conv = nn.Conv1d(self._num_hiddens, embedding_dim, 1, 1)
+
+ def forward(self, x):
+ h = self.project(x)
+        # Res_CNR_Stack returns a single tensor, so there is no per-layer state to unpack
+        h = self._enc_1(h)
+        h = self._down_1(h)
+        h = self._enc_2(h)
+        h = self._down_2(h)
+        h = self._enc_3(h)
+ # h = self.pre_vq_conv(h)
+ return h
+
+
+class Casual_Decoder(nn.Module):
+ def __init__(self, out_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
+ super(Casual_Decoder, self).__init__()
+ self._num_hiddens = num_hiddens
+ self._num_residual_layers = num_residual_layers
+ self._num_residual_hiddens = num_residual_hiddens
+
+ # self.aft_vq_conv = nn.Conv1d(embedding_dim, self._num_hiddens, 1, 1)
+ self._dec_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True, casual=True)
+ self._up_2 = CasualCT(self._num_hiddens, self._num_hiddens // 2, leaky=True)
+ self._dec_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True, casual=True)
+ self._up_3 = CasualCT(self._num_hiddens // 2, self._num_hiddens // 4, leaky=True)
+ self._dec_3 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True, casual=True)
+ self.project = nn.Conv1d(self._num_hiddens//4, out_dim, 1, 1)
+
+ def forward(self, h, pre_state=None):
+        # Res_CNR_Stack currently returns a single tensor (its per-layer state tracking is
+        # commented out), so no intermediate states are collected here
+        cur_state = []
+        # h = self.aft_vq_conv(x)
+        h = self._dec_1(h, pre_state[0] if pre_state is not None else None)
+        h = self._up_2(h)
+        h = self._dec_2(h, pre_state[1] if pre_state is not None else None)
+        h = self._up_3(h)
+        h = self._dec_3(h, pre_state[2] if pre_state is not None else None)
+ recon = self.project(h)
+ return recon, cur_state
\ No newline at end of file
diff --git a/nets/spg/wav2vec.py b/nets/spg/wav2vec.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5d0eff66e67de14ceba283fa6ce43f156c7ddc2
--- /dev/null
+++ b/nets/spg/wav2vec.py
@@ -0,0 +1,143 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import copy
+import math
+from transformers import Wav2Vec2Model,Wav2Vec2Config
+from transformers.modeling_outputs import BaseModelOutput
+from typing import Optional, Tuple
+_CONFIG_FOR_DOC = "Wav2Vec2Config"
+
+# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model
+# initialize our encoder with the pre-trained wav2vec 2.0 weights.
+def _compute_mask_indices(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ attention_mask: Optional[torch.Tensor] = None,
+ min_masks: int = 0,
+) -> np.ndarray:
+ bsz, all_sz = shape
+ mask = np.full((bsz, all_sz), False)
+
+ all_num_mask = int(
+ mask_prob * all_sz / float(mask_length)
+ + np.random.rand()
+ )
+ all_num_mask = max(min_masks, all_num_mask)
+ mask_idcs = []
+ padding_mask = attention_mask.ne(1) if attention_mask is not None else None
+ for i in range(bsz):
+ if padding_mask is not None:
+ sz = all_sz - padding_mask[i].long().sum().item()
+ num_mask = int(
+ mask_prob * sz / float(mask_length)
+ + np.random.rand()
+ )
+ num_mask = max(min_masks, num_mask)
+ else:
+ sz = all_sz
+ num_mask = all_num_mask
+
+ lengths = np.full(num_mask, mask_length)
+
+ if sum(lengths) == 0:
+ lengths[0] = min(mask_length, sz - 1)
+
+ min_len = min(lengths)
+ if sz - min_len <= num_mask:
+ min_len = sz - num_mask - 1
+
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+ mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+
+ min_len = min([len(m) for m in mask_idcs])
+ for i, mask_idc in enumerate(mask_idcs):
+ if len(mask_idc) > min_len:
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+ mask[i, mask_idc] = True
+ return mask
+
+# linear interpolation layer
+def linear_interpolation(features, input_fps, output_fps, output_len=None):
+ features = features.transpose(1, 2)
+ seq_len = features.shape[2] / float(input_fps)
+ if output_len is None:
+ output_len = int(seq_len * output_fps)
+ output_features = F.interpolate(features,size=output_len,align_corners=False,mode='linear')
+ return output_features.transpose(1, 2)
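+
+
+# Hedged sketch (illustrative): wav2vec 2.0 emits features at roughly 50 Hz; the helper
+# above resamples them along the time axis, e.g. down to a 30 fps video frame rate.
+def _example_linear_interpolation():
+    feats_50hz = torch.randn(1, 100, 768)  # about 2 s of features at ~50 Hz
+    feats_30fps = linear_interpolation(feats_50hz, input_fps=50, output_fps=30)
+    return feats_30fps.shape  # torch.Size([1, 60, 768])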
+
+
+class Wav2Vec2Model(Wav2Vec2Model):
+ def __init__(self, config):
+ super().__init__(config)
+ def forward(
+ self,
+ input_values,
+ attention_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ frame_num=None
+ ):
+ self.config.output_attentions = True
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ hidden_states = self.feature_extractor(input_values)
+ hidden_states = hidden_states.transpose(1, 2)
+
+ hidden_states = linear_interpolation(hidden_states, 50, 30,output_len=frame_num)
+
+ if attention_mask is not None:
+ output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
+ attention_mask = torch.zeros(
+ hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device
+ )
+ attention_mask[
+ (torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)
+ ] = 1
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+
+ hidden_states = self.feature_projection(hidden_states)
+
+ if self.config.apply_spec_augment and self.training:
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+ if self.config.mask_time_prob > 0:
+ mask_time_indices = _compute_mask_indices(
+ (batch_size, sequence_length),
+ self.config.mask_time_prob,
+ self.config.mask_time_length,
+ attention_mask=attention_mask,
+ min_masks=2,
+ )
+ hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype)
+ if self.config.mask_feature_prob > 0:
+ mask_feature_indices = _compute_mask_indices(
+ (batch_size, hidden_size),
+ self.config.mask_feature_prob,
+ self.config.mask_feature_length,
+ )
+ mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
+ hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
+ encoder_outputs = self.encoder(
+ hidden_states[0],
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = encoder_outputs[0]
+ if not return_dict:
+ return (hidden_states,) + encoder_outputs[1:]
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
diff --git a/nets/utils.py b/nets/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..41b03374e8023016b3bec4f66ab16cb421222ef1
--- /dev/null
+++ b/nets/utils.py
@@ -0,0 +1,122 @@
+import json
+import textgrid as tg
+import numpy as np
+
+def get_parameter_size(model):
+ total_num = sum(p.numel() for p in model.parameters())
+ trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ return total_num, trainable_num
+
+def denormalize(kps, data_mean, data_std):
+ '''
+ kps: (B, T, C)
+ '''
+ data_std = data_std.reshape(1, 1, -1)
+ data_mean = data_mean.reshape(1, 1, -1)
+ return (kps * data_std) + data_mean
+
+def normalize(kps, data_mean, data_std):
+ '''
+ kps: (B, T, C)
+ '''
+ data_std = data_std.squeeze().reshape(1, 1, -1)
+ data_mean = data_mean.squeeze().reshape(1, 1, -1)
+
+ return (kps-data_mean) / data_std
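+
+
+# Hedged sketch (illustrative): normalize and denormalize are inverses given the same
+# per-channel statistics.
+def _example_normalize_roundtrip():
+    kps = np.random.randn(2, 60, 108)
+    mean, std = kps.mean(axis=(0, 1)), kps.std(axis=(0, 1)) + 1e-8
+    restored = denormalize(normalize(kps, mean, std), mean, std)
+    return np.allclose(kps, restored)  # True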
+
+def parse_audio(textgrid_file):
+ '''a demo implementation'''
+ words=['but', 'as', 'to', 'that', 'with', 'of', 'the', 'and', 'or', 'not', 'which', 'what', 'this', 'for', 'because', 'if', 'so', 'just', 'about', 'like', 'by', 'how', 'from', 'whats', 'now', 'very', 'that', 'also', 'actually', 'who', 'then', 'well', 'where', 'even', 'today', 'between', 'than', 'when']
+ txt=tg.TextGrid.fromFile(textgrid_file)
+
+ total_time=int(np.ceil(txt.maxTime))
+ code_seq=np.zeros(total_time)
+
+ word_level=txt[0]
+
+ for i in range(len(word_level)):
+ start_time=word_level[i].minTime
+ end_time=word_level[i].maxTime
+ mark=word_level[i].mark
+
+ if mark in words:
+ start=int(np.round(start_time))
+ end=int(np.round(end_time))
+
+ if start >= len(code_seq) or end >= len(code_seq):
+ code_seq[-1] = 1
+ else:
+ code_seq[start]=1
+
+ return code_seq
+
+
+def get_path(model_name, model_type):
+ if model_name == 's2g_body_pixel':
+ if model_type == 'mfcc':
+ return './experiments/2022-10-09-smplx_S2G-body-pixel-aud-3p/ckpt-99.pth'
+ elif model_type == 'wv2':
+ return './experiments/2022-10-28-smplx_S2G-body-pixel-wv2-sg2/ckpt-99.pth'
+ elif model_type == 'random':
+ return './experiments/2022-10-09-smplx_S2G-body-pixel-random-3p/ckpt-99.pth'
+ elif model_type == 'wbhmodel':
+ return './experiments/2022-11-02-smplx_S2G-body-pixel-w-bhmodel/ckpt-99.pth'
+ elif model_type == 'wobhmodel':
+ return './experiments/2022-11-02-smplx_S2G-body-pixel-wo-bhmodel/ckpt-99.pth'
+ elif model_name == 's2g_body':
+ if model_type == 'a+m-vae':
+ return './experiments/2022-10-19-smplx_S2G-body-audio-motion-vae/ckpt-99.pth'
+ elif model_type == 'a-vae':
+ return './experiments/2022-10-18-smplx_S2G-body-audiovae/ckpt-99.pth'
+ elif model_type == 'a-ed':
+ return './experiments/2022-10-18-smplx_S2G-body-audioae/ckpt-99.pth'
+ elif model_name == 's2g_LS3DCG':
+ return './experiments/2022-10-19-smplx_S2G-LS3DCG/ckpt-99.pth'
+ elif model_name == 's2g_body_vq':
+ if model_type == 'n_com_1024':
+ return './experiments/2022-10-29-smplx_S2G-body-vq-cn1024/ckpt-99.pth'
+ elif model_type == 'n_com_2048':
+ return './experiments/2022-10-29-smplx_S2G-body-vq-cn2048/ckpt-99.pth'
+ elif model_type == 'n_com_4096':
+ return './experiments/2022-10-29-smplx_S2G-body-vq-cn4096/ckpt-99.pth'
+ elif model_type == 'n_com_8192':
+ return './experiments/2022-11-02-smplx_S2G-body-vq-cn8192/ckpt-99.pth'
+ elif model_type == 'n_com_16384':
+ return './experiments/2022-11-02-smplx_S2G-body-vq-cn16384/ckpt-99.pth'
+ elif model_type == 'n_com_170000':
+ return './experiments/2022-10-30-smplx_S2G-body-vq-cn170000/ckpt-99.pth'
+ elif model_type == 'com_1024':
+ return './experiments/2022-10-29-smplx_S2G-body-vq-composition/ckpt-99.pth'
+ elif model_type == 'com_2048':
+ return './experiments/2022-10-31-smplx_S2G-body-vq-composition2048/ckpt-99.pth'
+ elif model_type == 'com_4096':
+ return './experiments/2022-10-31-smplx_S2G-body-vq-composition4096/ckpt-99.pth'
+ elif model_type == 'com_8192':
+ return './experiments/2022-11-02-smplx_S2G-body-vq-composition8192/ckpt-99.pth'
+ elif model_type == 'com_16384':
+ return './experiments/2022-11-02-smplx_S2G-body-vq-composition16384/ckpt-99.pth'
+
+
+def get_dpath(model_name, model_type):
+ if model_name == 's2g_body_pixel':
+ if model_type == 'audio':
+ return './experiments/2022-10-26-smplx_S2G-d-pixel-aud/ckpt-9.pth'
+ elif model_type == 'wv2':
+ return './experiments/2022-11-04-smplx_S2G-d-pixel-wv2/ckpt-9.pth'
+ elif model_type == 'random':
+ return './experiments/2022-10-26-smplx_S2G-d-pixel-random/ckpt-9.pth'
+ elif model_type == 'wbhmodel':
+ return './experiments/2022-11-10-smplx_S2G-hD-wbhmodel/ckpt-9.pth'
+ # return './experiments/2022-11-05-smplx_S2G-d-pixel-wbhmodel/ckpt-9.pth'
+ elif model_type == 'wobhmodel':
+ return './experiments/2022-11-10-smplx_S2G-hD-wobhmodel/ckpt-9.pth'
+ # return './experiments/2022-11-05-smplx_S2G-d-pixel-wobhmodel/ckpt-9.pth'
+ elif model_name == 's2g_body':
+ if model_type == 'a+m-vae':
+ return './experiments/2022-10-26-smplx_S2G-d-audio+motion-vae/ckpt-9.pth'
+ elif model_type == 'a-vae':
+ return './experiments/2022-10-26-smplx_S2G-d-audio-vae/ckpt-9.pth'
+ elif model_type == 'a-ed':
+ return './experiments/2022-10-26-smplx_S2G-d-audio-ae/ckpt-9.pth'
+ elif model_name == 's2g_LS3DCG':
+ return './experiments/2022-10-26-smplx_S2G-d-ls3dcg/ckpt-9.pth'
\ No newline at end of file
diff --git a/scripts/.idea/__init__.py b/scripts/.idea/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/scripts/.idea/aws.xml b/scripts/.idea/aws.xml
new file mode 100644
index 0000000000000000000000000000000000000000..b63b642cfb4254fc0f7058903abc5b481895c4ef
--- /dev/null
+++ b/scripts/.idea/aws.xml
@@ -0,0 +1,11 @@
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/scripts/.idea/deployment.xml b/scripts/.idea/deployment.xml
new file mode 100644
index 0000000000000000000000000000000000000000..14f2c41a46d0210c6395ed0e7bfd3b630211f699
--- /dev/null
+++ b/scripts/.idea/deployment.xml
@@ -0,0 +1,70 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/scripts/.idea/get_prevar.py b/scripts/.idea/get_prevar.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4b2dfb1892e9ff79c8074f35e84d897c64ff673
--- /dev/null
+++ b/scripts/.idea/get_prevar.py
@@ -0,0 +1,132 @@
+import os
+import sys
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+sys.path.append(os.getcwd())
+from glob import glob
+
+import numpy as np
+import json
+import smplx as smpl
+
+from nets import *
+from repro_nets import *
+from trainer.options import parse_args
+from data_utils import torch_data
+from trainer.config import load_JsonConfig
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils import data
+
+def init_model(model_name, model_path, args, config):
+ if model_name == 'freeMo':
+ # generator = freeMo_Generator(args)
+ # generator = freeMo_Generator(args)
+ generator = freeMo_dev(args, config)
+ # generator.load_state_dict(torch.load(model_path)['generator'])
+ elif model_name == 'smplx_S2G':
+ generator = smplx_S2G(args, config)
+ elif model_name == 'StyleGestures':
+ generator = StyleGesture_Generator(
+ args,
+ config
+ )
+ elif model_name == 'Audio2Gestures':
+ config.Train.using_mspec_stat = False
+ generator = Audio2Gesture_Generator(
+ args,
+ config,
+ torch.zeros([1, 1, 108]),
+ torch.ones([1, 1, 108])
+ )
+ elif model_name == 'S2G':
+ generator = S2G_Generator(
+ args,
+ config,
+ )
+ elif model_name == 'Tmpt':
+ generator = S2G_Generator(
+ args,
+ config,
+ )
+ else:
+ raise NotImplementedError
+
+ model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
+ if model_name == 'smplx_S2G':
+ generator.generator.load_state_dict(model_ckpt['generator']['generator'])
+ elif 'generator' in list(model_ckpt.keys()):
+ generator.load_state_dict(model_ckpt['generator'])
+ else:
+ model_ckpt = {'generator': model_ckpt}
+ generator.load_state_dict(model_ckpt)
+
+ return generator
+
+
+
+def prevar_loader(data_root, speakers, args, config, model_path, device, generator):
+ path = model_path.split('ckpt')[0]
+ file = os.path.join(os.path.dirname(path), "pre_variable.npy")
+ data_base = torch_data(
+ data_root=data_root,
+ speakers=speakers,
+ split='pre',
+ limbscaling=False,
+ normalization=config.Data.pose.normalization,
+ norm_method=config.Data.pose.norm_method,
+ split_trans_zero=False,
+ num_pre_frames=config.Data.pose.pre_pose_length,
+ num_generate_length=config.Data.pose.generate_length,
+ num_frames=15,
+ aud_feat_win_size=config.Data.aud.aud_feat_win_size,
+ aud_feat_dim=config.Data.aud.aud_feat_dim,
+ feat_method=config.Data.aud.feat_method,
+ smplx=True,
+ audio_sr=22000,
+ convert_to_6d=config.Data.pose.convert_to_6d,
+ expression=config.Data.pose.expression
+ )
+
+ data_base.get_dataset()
+ pre_set = data_base.all_dataset
+ pre_loader = data.DataLoader(pre_set, batch_size=config.DataLoader.batch_size, shuffle=False, drop_last=True)
+
+ total_pose = []
+
+ with torch.no_grad():
+ for bat in pre_loader:
+ pose = bat['poses'].to(device).to(torch.float32)
+ expression = bat['expression'].to(device).to(torch.float32)
+ pose = pose.permute(0, 2, 1)
+ pose = torch.cat([pose[:, :15], pose[:, 15:30], pose[:, 30:45], pose[:, 45:60], pose[:, 60:]], dim=0)
+ expression = expression.permute(0, 2, 1)
+ expression = torch.cat([expression[:, :15], expression[:, 15:30], expression[:, 30:45], expression[:, 45:60], expression[:, 60:]], dim=0)
+ pose = torch.cat([pose, expression], dim=-1)
+ pose = pose.reshape(pose.shape[0], -1, 1)
+ pose_code = generator.generator.pre_pose_encoder(pose).squeeze().detach().cpu()
+ total_pose.append(np.asarray(pose_code))
+ total_pose = np.concatenate(total_pose, axis=0)
+ mean = np.mean(total_pose, axis=0)
+ std = np.std(total_pose, axis=0)
+ prevar = (mean, std)
+ np.save(file, prevar, allow_pickle=True)
+
+ return mean, std
+
+def main():
+ parser = parse_args()
+ args = parser.parse_args()
+ device = torch.device(args.gpu)
+ torch.cuda.set_device(device)
+
+ config = load_JsonConfig(args.config_file)
+
+ print('init model...')
+ generator = init_model(config.Model.model_name, args.model_path, args, config)
+ print('init pre-pose vectors...')
+ mean, std = prevar_loader(config.Data.data_root, args.speakers, args, config, args.model_path, device, generator)
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/scripts/.idea/lower body b/scripts/.idea/lower body
new file mode 100644
index 0000000000000000000000000000000000000000..1efda13cfb1455b382ced16ed1ddb16d1716ae7f
--- /dev/null
+++ b/scripts/.idea/lower body
@@ -0,0 +1 @@
+0, 1, 3, 4, 6, 7, 9, 10,
\ No newline at end of file
diff --git a/scripts/.idea/test.png b/scripts/.idea/test.png
new file mode 100644
index 0000000000000000000000000000000000000000..b028ad2b8c4fb0b6a89571ae62d2654908c93d9c
Binary files /dev/null and b/scripts/.idea/test.png differ
diff --git a/scripts/.idea/testtext.py b/scripts/.idea/testtext.py
new file mode 100644
index 0000000000000000000000000000000000000000..af0185442ae8b1e84c9ea64c671dde6da394046c
--- /dev/null
+++ b/scripts/.idea/testtext.py
@@ -0,0 +1,24 @@
+import cv2
+
+# Path of the image to read
+path = r'test.png'
+# Read the image from disk with cv2.imread
+image1 = cv2.imread(path)
+# Name of the window the image will be displayed in
+window_name1 = 'image'
+# Font for the overlay text
+font1 = cv2.FONT_HERSHEY_SIMPLEX
+# Bottom-left corner of the text
+org1 = (50, 50)
+# Font scale of the text
+fontScale1 = 1
+# Text color in BGR (white)
+color1 = (255, 255, 255)
+# Line thickness of the text, in pixels
+thickness1 = 2
+# Draw the text on the image with cv2.putText
+image_1 = cv2.putText(image1, 'CAT IN BOX', org1, font1, fontScale1, color1, thickness1, cv2.LINE_AA)
+# Displaying the output image
+cv2.imshow(window_name1, image_1)
+cv2.waitKey(0)
+cv2.destroyAllWindows()
diff --git a/scripts/__init__.py b/scripts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/scripts/__pycache__/__init__.cpython-37.pyc b/scripts/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ac07c57d7d79de0f98f4a3f0f868c3c4fbcf474
Binary files /dev/null and b/scripts/__pycache__/__init__.cpython-37.pyc differ
diff --git a/scripts/__pycache__/diversity.cpython-37.pyc b/scripts/__pycache__/diversity.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..33875c7b891d8fe4b20c894d7fccd4fb3ddb645e
Binary files /dev/null and b/scripts/__pycache__/diversity.cpython-37.pyc differ
diff --git a/scripts/continuity.py b/scripts/continuity.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1e5eaebe0a5f45c745443a6d4d959544c595051
--- /dev/null
+++ b/scripts/continuity.py
@@ -0,0 +1,200 @@
+import os
+import sys
+# os.environ["PYOPENGL_PLATFORM"] = "egl"
+from transformers import Wav2Vec2Processor
+from visualise.rendering import RenderTool
+
+sys.path.append(os.getcwd())
+from glob import glob
+
+import numpy as np
+import json
+import smplx as smpl
+
+from nets import *
+from trainer.options import parse_args
+from data_utils import torch_data
+from trainer.config import load_JsonConfig
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils import data
+from scripts.diversity import init_model, init_dataloader, get_vertices
+from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
+from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
+import time
+
+
+global_orient = torch.tensor([3.0747, -0.0158, -0.0152])
+
+
+def infer(data_root, g_body, g_face, g_body2, exp_name, infer_loader, infer_set, device, norm_stats, smplx,
+ smplx_model, rendertool, args=None, config=None, var=None):
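+ # for every test clip: decode the ground-truth poses, generate body motion from the audio with
+ # continuity=True, pad the (zeroed) jaw/expression channels back in, convert the full SMPL-X
+ # sequence to vertices and render a continuity comparison video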
+ am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
+ am_sr = 16000
+ num_sample = 1
+ face = False
+ if face:
+ body_static = torch.zeros([1, 162], device='cuda')
+ body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)
+ stand = False
+ j = 0
+ gt_0 = None
+
+ for bat in infer_loader:
+ poses_ = bat['poses'].to(torch.float32).to(device)
+ if poses_.shape[-1] == 300:
+ j = j + 1
+ if j > 1000:
+ continue
+ id = bat['speaker'].to('cuda') - 20
+ if config.Data.pose.expression:
+ expression = bat['expression'].to(device).to(torch.float32)
+ poses = torch.cat([poses_, expression], dim=1)
+ else:
+ poses = poses_
+ cur_wav_file = bat['aud_file'][0]
+ betas = bat['betas'][0].to(torch.float64).to('cuda')
+ # betas = torch.zeros([1, 300], dtype=torch.float64).to('cuda')
+ gt = poses.to('cuda').squeeze().transpose(1, 0)
+ if config.Data.pose.normalization:
+ gt = denormalize(gt, norm_stats[0], norm_stats[1]).squeeze(dim=0)
+ if config.Data.pose.convert_to_6d:
+ if config.Data.pose.expression:
+ gt_exp = gt[:, -100:]
+ gt = gt[:, :-100]
+
+ gt = gt.reshape(gt.shape[0], -1, 6)
+ gt = matrix_to_axis_angle(rotation_6d_to_matrix(gt)).reshape(gt.shape[0], -1)
+ gt = torch.cat([gt, gt_exp], -1)
+ if face:
+ gt = torch.cat([gt[:, :3], body_static.repeat(gt.shape[0], 1), gt[:, -100:]], dim=-1)
+
+ result_list = [gt]
+
+ # cur_wav_file = '.\\training_data\\french-V4.wav'
+
+ # pred_face = g_face.infer_on_audio(cur_wav_file,
+ # initial_pose=poses_,
+ # norm_stats=None,
+ # w_pre=False,
+ # # id=id,
+ # frame=None,
+ # am=am,
+ # am_sr=am_sr
+ # )
+ #
+ # pred_face = torch.tensor(pred_face).squeeze().to('cuda')
+
+ pred_face = torch.zeros([gt.shape[0], 103], device='cuda')
+ pred_jaw = pred_face[:, :3]
+ pred_face = pred_face[:, 3:]
+
+ # id = torch.tensor([0], device='cuda')
+
+ for i in range(num_sample):
+ pred_res = g_body.infer_on_audio(cur_wav_file,
+ initial_pose=poses_,
+ norm_stats=norm_stats,
+ txgfile=None,
+ id=id,
+ var=var,
+ fps=30,
+ continuity=True,
+ smooth=False
+ )
+ pred = torch.tensor(pred_res).squeeze().to('cuda')
+
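+ # match the body sequence length to the face track: repeat the last frame to pad, or truncate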
+ if pred.shape[0] < pred_face.shape[0]:
+ repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
+ pred = torch.cat([pred, repeat_frame], dim=0)
+ else:
+ pred = pred[:pred_face.shape[0], :]
+
+ if config.Data.pose.convert_to_6d:
+ pred = pred.reshape(pred.shape[0], -1, 6)
+ pred = matrix_to_axis_angle(rotation_6d_to_matrix(pred))
+ pred = pred.reshape(pred.shape[0], -1)
+
+ pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)
+ # pred[:, 9:12] = global_orient
+ pred = part2full(pred, stand)
+ if face:
+ pred = torch.cat([pred[:, :3], body_static.repeat(pred.shape[0], 1), pred[:, -100:]], dim=-1)
+ # result_list[0] = poses2pred(result_list[0], stand)
+ # if gt_0 is None:
+ # gt_0 = gt
+ # pred = pred2poses(pred, gt_0)
+ # result_list[0] = poses2poses(result_list[0], gt_0)
+
+ result_list.append(pred)
+
+ vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)
+
+ result_list = [res.to('cpu') for res in result_list]
+ result_array = np.concatenate(result_list[1:], axis=0)
+ file_name = 'visualise/video/' + config.Log.name + '/' + \
+ cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1]
+ np.save(file_name, result_array)
+
+ rendertool._render_continuity(cur_wav_file, vertices_list[1], frame=60)
+
+
+def main():
+ parser = parse_args()
+ args = parser.parse_args()
+ device = torch.device(args.gpu)
+ torch.cuda.set_device(device)
+
+ config = load_JsonConfig(args.config_file)
+
+ smplx = True
+
+ os.environ['smplx_npz_path'] = config.smplx_npz_path
+ os.environ['extra_joint_path'] = config.extra_joint_path
+ os.environ['j14_regressor_path'] = config.j14_regressor_path
+
+ print('init model...')
+ body_model_name = 's2g_body_pixel'
+ body_model_path = './experiments/2022-12-31-smplx_S2G-body-pixel-conti-wide/ckpt-99.pth' # './experiments/2022-10-09-smplx_S2G-body-pixel-aud-3p/ckpt-99.pth'
+ generator = init_model(body_model_name, body_model_path, args, config)
+
+ # face_model_name = 's2g_face'
+ # face_model_path = './experiments/2022-10-15-smplx_S2G-face-sgd-3p-wv2/ckpt-99.pth' # './experiments/2022-09-28-smplx_S2G-face-faceformer-3d/ckpt-99.pth'
+ # generator_face = init_model(face_model_name, face_model_path, args, config)
+ generator_face = None
+ print('init dataloader...')
+ infer_set, infer_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config)
+
+ print('init smplx model...')
+ dtype = torch.float64
+ model_params = dict(model_path='E:/PycharmProjects/Motion-Projects/models',
+ model_type='smplx',
+ create_global_orient=True,
+ create_body_pose=True,
+ create_betas=True,
+ num_betas=300,
+ create_left_hand_pose=True,
+ create_right_hand_pose=True,
+ use_pca=False,
+ flat_hand_mean=False,
+ create_expression=True,
+ num_expression_coeffs=100,
+ num_pca_comps=12,
+ create_jaw_pose=True,
+ create_leye_pose=True,
+ create_reye_pose=True,
+ create_transl=False,
+ # gender='ne',
+ dtype=dtype, )
+ smplx_model = smpl.create(**model_params).to('cuda')
+ print('init rendertool...')
+ rendertool = RenderTool('visualise/video/' + config.Log.name)
+
+ infer(config.Data.data_root, generator, generator_face, None, args.exp_name, infer_loader, infer_set, device,
+ norm_stats, smplx, smplx_model, rendertool, args, config, (None, None))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/demo.py b/scripts/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7e2f9e48a7f20faed2ffb6aab729215c878ddb1
--- /dev/null
+++ b/scripts/demo.py
@@ -0,0 +1,303 @@
+import os
+import sys
+# os.environ["PYOPENGL_PLATFORM"] = "egl"
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+sys.path.append(os.getcwd())
+
+from transformers import Wav2Vec2Processor
+from glob import glob
+
+import numpy as np
+import json
+import smplx as smpl
+
+from nets import *
+from trainer.options import parse_args
+from data_utils import torch_data
+from trainer.config import load_JsonConfig
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils import data
+from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
+from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
+from visualise.rendering import RenderTool
+
+# run the demo on CPU by default
+device = 'cpu'
+
+def init_model(model_name, model_path, args, config):
+ if model_name == 's2g_face':
+ generator = s2g_face(
+ args,
+ config,
+ )
+ elif model_name == 's2g_body_vq':
+ generator = s2g_body_vq(
+ args,
+ config,
+ )
+ elif model_name == 's2g_body_pixel':
+ generator = s2g_body_pixel(
+ args,
+ config,
+ )
+ elif model_name == 's2g_LS3DCG':
+ generator = LS3DCG(
+ args,
+ config,
+ )
+ else:
+ raise NotImplementedError
+
+ model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
+ if model_name == 'smplx_S2G':
+ generator.generator.load_state_dict(model_ckpt['generator']['generator'])
+
+ elif 'generator' in list(model_ckpt.keys()):
+ generator.load_state_dict(model_ckpt['generator'])
+ else:
+ model_ckpt = {'generator': model_ckpt}
+ generator.load_state_dict(model_ckpt)
+
+ return generator
+
+
+def init_dataloader(data_root, speakers, args, config):
+ if data_root.endswith('.csv'):
+ raise NotImplementedError
+ else:
+ data_class = torch_data
+ if 'smplx' in config.Model.model_name or 's2g' in config.Model.model_name:
+ data_base = torch_data(
+ data_root=data_root,
+ speakers=speakers,
+ split='test',
+ limbscaling=False,
+ normalization=config.Data.pose.normalization,
+ norm_method=config.Data.pose.norm_method,
+ split_trans_zero=False,
+ num_pre_frames=config.Data.pose.pre_pose_length,
+ num_generate_length=config.Data.pose.generate_length,
+ num_frames=30,
+ aud_feat_win_size=config.Data.aud.aud_feat_win_size,
+ aud_feat_dim=config.Data.aud.aud_feat_dim,
+ feat_method=config.Data.aud.feat_method,
+ smplx=True,
+ audio_sr=22000,
+ convert_to_6d=config.Data.pose.convert_to_6d,
+ expression=config.Data.pose.expression,
+ config=config
+ )
+ else:
+ data_base = torch_data(
+ data_root=data_root,
+ speakers=speakers,
+ split='val',
+ limbscaling=False,
+ normalization=config.Data.pose.normalization,
+ norm_method=config.Data.pose.norm_method,
+ split_trans_zero=False,
+ num_pre_frames=config.Data.pose.pre_pose_length,
+ aud_feat_win_size=config.Data.aud.aud_feat_win_size,
+ aud_feat_dim=config.Data.aud.aud_feat_dim,
+ feat_method=config.Data.aud.feat_method
+ )
+ if config.Data.pose.normalization:
+ norm_stats_fn = os.path.join(os.path.dirname(args.model_path), "norm_stats.npy")
+ norm_stats = np.load(norm_stats_fn, allow_pickle=True)
+ data_base.data_mean = norm_stats[0]
+ data_base.data_std = norm_stats[1]
+ else:
+ norm_stats = None
+
+ data_base.get_dataset()
+ infer_set = data_base.all_dataset
+ infer_loader = data.DataLoader(data_base.all_dataset, batch_size=1, shuffle=False)
+
+ return infer_set, infer_loader, norm_stats
+
+
+def get_vertices(smplx_model, betas, result_list, exp, require_pose=False):
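+ # each frame in result_list packs SMPL-X parameters as:
+ # [0:3] jaw, [3:6] left eye, [6:9] right eye, [9:12] global orient,
+ # [12:75] body, [75:120] left hand, [120:165] right hand, [165:265] expression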
+ vertices_list = []
+ poses_list = []
+ expression = torch.zeros([1, 50])
+
+ for i in result_list:
+ vertices = []
+ poses = []
+ for j in range(i.shape[0]):
+ output = smplx_model(betas=betas,
+ expression=i[j][165:265].unsqueeze_(dim=0) if exp else expression,
+ jaw_pose=i[j][0:3].unsqueeze_(dim=0),
+ leye_pose=i[j][3:6].unsqueeze_(dim=0),
+ reye_pose=i[j][6:9].unsqueeze_(dim=0),
+ global_orient=i[j][9:12].unsqueeze_(dim=0),
+ body_pose=i[j][12:75].unsqueeze_(dim=0),
+ left_hand_pose=i[j][75:120].unsqueeze_(dim=0),
+ right_hand_pose=i[j][120:165].unsqueeze_(dim=0),
+ return_verts=True)
+ vertices.append(output.vertices.detach().cpu().numpy().squeeze())
+ # pose = torch.cat([output.body_pose, output.left_hand_pose, output.right_hand_pose], dim=1)
+ pose = output.body_pose
+ poses.append(pose.detach().cpu())
+ vertices = np.asarray(vertices)
+ vertices_list.append(vertices)
+ poses = torch.cat(poses, dim=0)
+ poses_list.append(poses)
+ if require_pose:
+ return vertices_list, poses_list
+ else:
+ return vertices_list, None
+
+
+global_orient = torch.tensor([3.0747, -0.0158, -0.0152])
+
+
+def infer(g_body, g_face, smplx_model, rendertool, config, args):
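+ # generate facial motion (jaw + expression) from the audio with g_face and body motion with
+ # g_body, assemble them into full SMPL-X parameter sequences, then convert to vertices and
+ # render the result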
+ betas = torch.zeros([1, 300], dtype=torch.float64).to(device)
+ am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
+ am_sr = 16000
+ num_sample = args.num_sample
+ cur_wav_file = args.audio_file
+ id = args.id
+ face = args.only_face
+ stand = args.stand
+ if face:
+ body_static = torch.zeros([1, 162], device=device)
+ body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)
+
+ result_list = []
+
+ pred_face = g_face.infer_on_audio(cur_wav_file,
+ initial_pose=None,
+ norm_stats=None,
+ w_pre=False,
+ # id=id,
+ frame=None,
+ am=am,
+ am_sr=am_sr
+ )
+ pred_face = torch.tensor(pred_face).squeeze().to(device)
+ # pred_face = torch.zeros([gt.shape[0], 105])
+
+ if config.Data.pose.convert_to_6d:
+ pred_jaw = pred_face[:, :6].reshape(pred_face.shape[0], -1, 6)
+ pred_jaw = matrix_to_axis_angle(rotation_6d_to_matrix(pred_jaw)).reshape(pred_face.shape[0], -1)
+ pred_face = pred_face[:, 6:]
+ else:
+ pred_jaw = pred_face[:, :3]
+ pred_face = pred_face[:, 3:]
+
+ id = torch.tensor([id], device=device)
+
+ for i in range(num_sample):
+ pred_res = g_body.infer_on_audio(cur_wav_file,
+ initial_pose=None,
+ norm_stats=None,
+ txgfile=None,
+ id=id,
+ var=None,
+ fps=30,
+ w_pre=False
+ )
+ pred = torch.tensor(pred_res).squeeze().to(device)
+
+ if pred.shape[0] < pred_face.shape[0]:
+ repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
+ pred = torch.cat([pred, repeat_frame], dim=0)
+ else:
+ pred = pred[:pred_face.shape[0], :]
+
+ if config.Data.pose.convert_to_6d:
+ pred = pred.reshape(pred.shape[0], -1, 6)
+ pred = matrix_to_axis_angle(rotation_6d_to_matrix(pred))
+ pred = pred.reshape(pred.shape[0], -1)
+
+ if config.Model.model_name == 's2g_LS3DCG':
+ pred = torch.cat([pred[:, :3], pred[:, 103:], pred[:, 3:103]], dim=-1)
+ else:
+ pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)
+
+ # pred[:, 9:12] = global_orient
+ pred = part2full(pred, stand)
+ if face:
+ pred = torch.cat([pred[:, :3], body_static.repeat(pred.shape[0], 1), pred[:, -100:]], dim=-1)
+ # result_list[0] = poses2pred(result_list[0], stand)
+ # if gt_0 is None:
+ # gt_0 = gt
+ # pred = pred2poses(pred, gt_0)
+ # result_list[0] = poses2poses(result_list[0], gt_0)
+
+ result_list.append(pred)
+
+
+ vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)
+
+ result_list = [res.to('cpu') for res in result_list]
+ result_array = np.concatenate(result_list[:], axis=0)
+ file_name = 'visualise/video/' + config.Log.name + '/' + \
+ cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1]
+ np.save(file_name, result_array)
+
+ rendertool._render_sequences(cur_wav_file, vertices_list, stand=stand, face=face, whole_body=args.whole_body)
+
+
+def main():
+ parser = parse_args()
+ args = parser.parse_args()
+ # device = torch.device(args.gpu)
+ # torch.cuda.set_device(device)
+
+
+ config = load_JsonConfig(args.config_file)
+
+ face_model_name = args.face_model_name
+ face_model_path = args.face_model_path
+ body_model_name = args.body_model_name
+ body_model_path = args.body_model_path
+ smplx_path = './visualise/'
+
+ os.environ['smplx_npz_path'] = config.smplx_npz_path
+ os.environ['extra_joint_path'] = config.extra_joint_path
+ os.environ['j14_regressor_path'] = config.j14_regressor_path
+
+ print('init model...')
+ generator = init_model(body_model_name, body_model_path, args, config)
+ generator2 = None
+ generator_face = init_model(face_model_name, face_model_path, args, config)
+
+ print('init smplx model...')
+ dtype = torch.float64
+ model_params = dict(model_path=smplx_path,
+ model_type='smplx',
+ create_global_orient=True,
+ create_body_pose=True,
+ create_betas=True,
+ num_betas=300,
+ create_left_hand_pose=True,
+ create_right_hand_pose=True,
+ use_pca=False,
+ flat_hand_mean=False,
+ create_expression=True,
+ num_expression_coeffs=100,
+ num_pca_comps=12,
+ create_jaw_pose=True,
+ create_leye_pose=True,
+ create_reye_pose=True,
+ create_transl=False,
+ # gender='ne',
+ dtype=dtype, )
+ smplx_model = smpl.create(**model_params).to(device)
+ print('init rendertool...')
+ rendertool = RenderTool('visualise/video/' + config.Log.name)
+
+ infer(generator, generator_face, smplx_model, rendertool, config, args)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/diversity.py b/scripts/diversity.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4fe023cd7e20e508a881528346aa1251fbd02b5
--- /dev/null
+++ b/scripts/diversity.py
@@ -0,0 +1,352 @@
+import os
+import sys
+# os.environ["PYOPENGL_PLATFORM"] = "egl"
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+sys.path.append(os.getcwd())
+
+from transformers import Wav2Vec2Processor
+from glob import glob
+
+import numpy as np
+import json
+import smplx as smpl
+
+from nets import *
+from trainer.options import parse_args
+from data_utils import torch_data
+from trainer.config import load_JsonConfig
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils import data
+from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
+from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
+from visualise.rendering import RenderTool
+
+import time
+
+
+def init_model(model_name, model_path, args, config):
+ if model_name == 's2g_face':
+ generator = s2g_face(
+ args,
+ config,
+ )
+ elif model_name == 's2g_body_vq':
+ generator = s2g_body_vq(
+ args,
+ config,
+ )
+ elif model_name == 's2g_body_pixel':
+ generator = s2g_body_pixel(
+ args,
+ config,
+ )
+ elif model_name == 's2g_LS3DCG':
+ generator = LS3DCG(
+ args,
+ config,
+ )
+ else:
+ raise NotImplementedError
+
+ model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
+ if model_name == 'smplx_S2G':
+ generator.generator.load_state_dict(model_ckpt['generator']['generator'])
+
+ elif 'generator' in list(model_ckpt.keys()):
+ generator.load_state_dict(model_ckpt['generator'])
+ else:
+ model_ckpt = {'generator': model_ckpt}
+ generator.load_state_dict(model_ckpt)
+
+ return generator
+
+
+def init_dataloader(data_root, speakers, args, config):
+ if data_root.endswith('.csv'):
+ raise NotImplementedError
+ else:
+ data_class = torch_data
+ if 'smplx' in config.Model.model_name or 's2g' in config.Model.model_name:
+ data_base = torch_data(
+ data_root=data_root,
+ speakers=speakers,
+ split='test',
+ limbscaling=False,
+ normalization=config.Data.pose.normalization,
+ norm_method=config.Data.pose.norm_method,
+ split_trans_zero=False,
+ num_pre_frames=config.Data.pose.pre_pose_length,
+ num_generate_length=config.Data.pose.generate_length,
+ num_frames=30,
+ aud_feat_win_size=config.Data.aud.aud_feat_win_size,
+ aud_feat_dim=config.Data.aud.aud_feat_dim,
+ feat_method=config.Data.aud.feat_method,
+ smplx=True,
+ audio_sr=22000,
+ convert_to_6d=config.Data.pose.convert_to_6d,
+ expression=config.Data.pose.expression,
+ config=config
+ )
+ else:
+ data_base = torch_data(
+ data_root=data_root,
+ speakers=speakers,
+ split='val',
+ limbscaling=False,
+ normalization=config.Data.pose.normalization,
+ norm_method=config.Data.pose.norm_method,
+ split_trans_zero=False,
+ num_pre_frames=config.Data.pose.pre_pose_length,
+ aud_feat_win_size=config.Data.aud.aud_feat_win_size,
+ aud_feat_dim=config.Data.aud.aud_feat_dim,
+ feat_method=config.Data.aud.feat_method
+ )
+ if config.Data.pose.normalization:
+ norm_stats_fn = os.path.join(os.path.dirname(args.model_path), "norm_stats.npy")
+ norm_stats = np.load(norm_stats_fn, allow_pickle=True)
+ data_base.data_mean = norm_stats[0]
+ data_base.data_std = norm_stats[1]
+ else:
+ norm_stats = None
+
+ data_base.get_dataset()
+ infer_set = data_base.all_dataset
+ infer_loader = data.DataLoader(data_base.all_dataset, batch_size=1, shuffle=False)
+
+ return infer_set, infer_loader, norm_stats
+
+
+def get_vertices(smplx_model, betas, result_list, exp, require_pose=False):
+ vertices_list = []
+ poses_list = []
+ expression = torch.zeros([1, 50])
+
+ for i in result_list:
+ vertices = []
+ poses = []
+ for j in range(i.shape[0]):
+ output = smplx_model(betas=betas,
+ expression=i[j][165:265].unsqueeze_(dim=0) if exp else expression,
+ jaw_pose=i[j][0:3].unsqueeze_(dim=0),
+ leye_pose=i[j][3:6].unsqueeze_(dim=0),
+ reye_pose=i[j][6:9].unsqueeze_(dim=0),
+ global_orient=i[j][9:12].unsqueeze_(dim=0),
+ body_pose=i[j][12:75].unsqueeze_(dim=0),
+ left_hand_pose=i[j][75:120].unsqueeze_(dim=0),
+ right_hand_pose=i[j][120:165].unsqueeze_(dim=0),
+ return_verts=True)
+ vertices.append(output.vertices.detach().cpu().numpy().squeeze())
+ # pose = torch.cat([output.body_pose, output.left_hand_pose, output.right_hand_pose], dim=1)
+ pose = output.body_pose
+ poses.append(pose.detach().cpu())
+ vertices = np.asarray(vertices)
+ vertices_list.append(vertices)
+ poses = torch.cat(poses, dim=0)
+ poses_list.append(poses)
+ if require_pose:
+ return vertices_list, poses_list
+ else:
+ return vertices_list, None
+
+
+global_orient = torch.tensor([3.0747, -0.0158, -0.0152])
+
+
+def infer(data_root, g_body, g_face, g_body2, exp_name, infer_loader, infer_set, device, norm_stats, smplx,
+ smplx_model, rendertool, args=None, config=None):
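+ # for each test clip: decode the ground-truth poses, predict face and body motion from the
+ # audio, optionally run a second body generator (g_body2) for comparison, and render the
+ # generated sequences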
+ am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
+ am_sr = 16000
+ num_sample = 1
+ face = False
+ if face:
+ body_static = torch.zeros([1, 162], device='cuda')
+ body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)
+ stand = False
+ j = 0
+ gt_0 = None
+
+ for bat in infer_loader:
+ poses_ = bat['poses'].to(torch.float32).to(device)
+ if poses_.shape[-1] == 300:
+ j = j + 1
+ if j > 1000:
+ continue
+ id = bat['speaker'].to('cuda') - 20
+ if config.Data.pose.expression:
+ expression = bat['expression'].to(device).to(torch.float32)
+ poses = torch.cat([poses_, expression], dim=1)
+ else:
+ poses = poses_
+ cur_wav_file = bat['aud_file'][0]
+ betas = bat['betas'][0].to(torch.float64).to('cuda')
+ # betas = torch.zeros([1, 300], dtype=torch.float64).to('cuda')
+ gt = poses.to('cuda').squeeze().transpose(1, 0)
+ if config.Data.pose.normalization:
+ gt = denormalize(gt, norm_stats[0], norm_stats[1]).squeeze(dim=0)
+ if config.Data.pose.convert_to_6d:
+ if config.Data.pose.expression:
+ gt_exp = gt[:, -100:]
+ gt = gt[:, :-100]
+
+ gt = gt.reshape(gt.shape[0], -1, 6)
+
+ gt = matrix_to_axis_angle(rotation_6d_to_matrix(gt)).reshape(gt.shape[0], -1)
+ gt = torch.cat([gt, gt_exp], -1)
+ if face:
+ gt = torch.cat([gt[:, :3], body_static.repeat(gt.shape[0], 1), gt[:, -100:]], dim=-1)
+
+ result_list = [gt]
+
+ # cur_wav_file = '.\\training_data\\1_song_(Vocals).wav'
+
+ pred_face = g_face.infer_on_audio(cur_wav_file,
+ initial_pose=poses_,
+ norm_stats=None,
+ w_pre=False,
+ # id=id,
+ frame=None,
+ am=am,
+ am_sr=am_sr
+ )
+
+ pred_face = torch.tensor(pred_face).squeeze().to('cuda')
+ # pred_face = torch.zeros([gt.shape[0], 105])
+
+ if config.Data.pose.convert_to_6d:
+ pred_jaw = pred_face[:, :6].reshape(pred_face.shape[0], -1, 6)
+ pred_jaw = matrix_to_axis_angle(rotation_6d_to_matrix(pred_jaw)).reshape(pred_face.shape[0], -1)
+ pred_face = pred_face[:, 6:]
+ else:
+ pred_jaw = pred_face[:, :3]
+ pred_face = pred_face[:, 3:]
+
+ # id = torch.tensor([0], device='cuda')
+
+ for i in range(num_sample):
+ pred_res = g_body.infer_on_audio(cur_wav_file,
+ initial_pose=poses_,
+ norm_stats=norm_stats,
+ txgfile=None,
+ id=id,
+ # var=var,
+ fps=30,
+ w_pre=False
+ )
+ pred = torch.tensor(pred_res).squeeze().to('cuda')
+
+ if pred.shape[0] < pred_face.shape[0]:
+ repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
+ pred = torch.cat([pred, repeat_frame], dim=0)
+ else:
+ pred = pred[:pred_face.shape[0], :]
+
+ if config.Data.pose.convert_to_6d:
+ pred = pred.reshape(pred.shape[0], -1, 6)
+ pred = matrix_to_axis_angle(rotation_6d_to_matrix(pred))
+ pred = pred.reshape(pred.shape[0], -1)
+
+ pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)
+ # pred[:, 9:12] = global_orient
+ pred = part2full(pred, stand)
+ if face:
+ pred = torch.cat([pred[:, :3], body_static.repeat(pred.shape[0], 1), pred[:, -100:]], dim=-1)
+ result_list[0] = poses2pred(result_list[0], stand)
+ # if gt_0 is None:
+ # gt_0 = gt
+ # pred = pred2poses(pred, gt_0)
+ # result_list[0] = poses2poses(result_list[0], gt_0)
+
+ result_list.append(pred)
+
+ if g_body2 is not None:
+ pred_res2 = g_body2.infer_on_audio(cur_wav_file,
+ initial_pose=poses_,
+ norm_stats=norm_stats,
+ txgfile=None,
+ # var=var,
+ fps=30,
+ w_pre=False
+ )
+ pred2 = torch.tensor(pred_res2).squeeze().to('cuda')
+ pred2 = torch.cat([pred2[:, :3], pred2[:, 103:], pred2[:, 3:103]], dim=-1)
+ # pred2 = part2full(pred2, stand)
+ # result_list[0] = poses2pred(result_list[0], stand)
+ # if gt_0 is None:
+ # gt_0 = gt
+ # pred2 = pred2poses(pred2, gt_0)
+ # result_list[0] = poses2poses(result_list[0], gt_0)
+ result_list[1] = pred2
+
+ vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)
+
+ result_list = [res.to('cpu') for res in result_list]
+ result_array = np.concatenate(result_list[1:], axis=0)
+ file_name = 'visualise/video/' + config.Log.name + '/' + \
+ cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1]
+ np.save(file_name, result_array)
+
+ rendertool._render_sequences(cur_wav_file, vertices_list[1:], stand=stand, face=face)
+
+
+def main():
+ parser = parse_args()
+ args = parser.parse_args()
+ device = torch.device(args.gpu)
+ torch.cuda.set_device(device)
+
+ config = load_JsonConfig(args.config_file)
+
+ face_model_name = args.face_model_name
+ face_model_path = args.face_model_path
+ body_model_name = args.body_model_name
+ body_model_path = args.body_model_path
+ smplx_path = './visualise/'
+
+ os.environ['smplx_npz_path'] = config.smplx_npz_path
+ os.environ['extra_joint_path'] = config.extra_joint_path
+ os.environ['j14_regressor_path'] = config.j14_regressor_path
+
+ print('init model...')
+ generator = init_model(body_model_name, body_model_path, args, config)
+ generator2 = None
+ generator_face = init_model(face_model_name, face_model_path, args, config)
+ print('init dataloader...')
+ infer_set, infer_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config)
+
+ print('init smplx model...')
+ dtype = torch.float64
+ model_params = dict(model_path=smplx_path,
+ model_type='smplx',
+ create_global_orient=True,
+ create_body_pose=True,
+ create_betas=True,
+ num_betas=300,
+ create_left_hand_pose=True,
+ create_right_hand_pose=True,
+ use_pca=False,
+ flat_hand_mean=False,
+ create_expression=True,
+ num_expression_coeffs=100,
+ num_pca_comps=12,
+ create_jaw_pose=True,
+ create_leye_pose=True,
+ create_reye_pose=True,
+ create_transl=False,
+ # gender='ne',
+ dtype=dtype, )
+ smplx_model = smpl.create(**model_params).to('cuda')
+ print('init rendertool...')
+ rendertool = RenderTool('visualise/video/' + config.Log.name)
+
+ infer(config.Data.data_root, generator, generator_face, generator2, args.exp_name, infer_loader, infer_set, device,
+ norm_stats, True, smplx_model, rendertool, args, config)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/test_body.py b/scripts/test_body.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c0d346ab2ea4963f766769e05187e661b5f091e
--- /dev/null
+++ b/scripts/test_body.py
@@ -0,0 +1,252 @@
+import os
+import sys
+
+
+os.environ['CUDA_VISIBLE_DEVICES'] = '3'
+sys.path.append(os.getcwd())
+
+from tqdm import tqdm
+from transformers import Wav2Vec2Processor
+
+from evaluation.FGD import EmbeddingSpaceEvaluator
+
+from evaluation.metrics import LVD
+
+import numpy as np
+import smplx as smpl
+
+from data_utils.lower_body import part2full, poses2pred
+from data_utils.utils import get_mfcc_ta
+from nets import *
+from nets.utils import get_path, get_dpath
+from trainer.options import parse_args
+from data_utils import torch_data
+from trainer.config import load_JsonConfig
+
+import torch
+from torch.utils import data
+from data_utils.get_j import to3d, get_joints
+
+
+def init_model(model_name, model_path, args, config):
+ if model_name == 's2g_face':
+ generator = s2g_face(
+ args,
+ config,
+ )
+ elif model_name == 's2g_body_vq':
+ generator = s2g_body_vq(
+ args,
+ config,
+ )
+ elif model_name == 's2g_body_pixel':
+ generator = s2g_body_pixel(
+ args,
+ config,
+ )
+ elif model_name == 's2g_body_ae':
+ generator = s2g_body_ae(
+ args,
+ config,
+ )
+ else:
+ raise NotImplementedError
+
+ model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
+ generator.load_state_dict(model_ckpt['generator'])
+
+ return generator
+
+
+def init_dataloader(data_root, speakers, args, config):
+ data_base = torch_data(
+ data_root=data_root,
+ speakers=speakers,
+ split='test',
+ limbscaling=False,
+ normalization=config.Data.pose.normalization,
+ norm_method=config.Data.pose.norm_method,
+ split_trans_zero=False,
+ num_pre_frames=config.Data.pose.pre_pose_length,
+ num_generate_length=config.Data.pose.generate_length,
+ num_frames=30,
+ aud_feat_win_size=config.Data.aud.aud_feat_win_size,
+ aud_feat_dim=config.Data.aud.aud_feat_dim,
+ feat_method=config.Data.aud.feat_method,
+ smplx=True,
+ audio_sr=22000,
+ convert_to_6d=config.Data.pose.convert_to_6d,
+ expression=config.Data.pose.expression,
+ config=config
+ )
+
+ if config.Data.pose.normalization:
+ norm_stats_fn = os.path.join(os.path.dirname(args.model_path), "norm_stats.npy")
+ norm_stats = np.load(norm_stats_fn, allow_pickle=True)
+ data_base.data_mean = norm_stats[0]
+ data_base.data_std = norm_stats[1]
+ else:
+ norm_stats = None
+
+ data_base.get_dataset()
+ test_set = data_base.all_dataset
+ test_loader = data.DataLoader(test_set, batch_size=1, shuffle=False)
+
+ return test_set, test_loader, norm_stats
+
+
+def body_loss(gt, prs):
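+ # gt: [T, J, 3] ground-truth joints; prs: [S, T, J, 3] holding S sampled predictions;
+ # diversity is the variance across the S samples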
+ loss_dict = {}
+ # LVD
+ v_diff = LVD(gt[:, :22, :], prs[:, :, :22, :], symmetrical=False, weight=False)
+ loss_dict['LVD'] = v_diff
+ # Accuracy
+ error = (gt - prs).norm(p=2, dim=-1).sum(dim=-1).mean()
+ loss_dict['error'] = error
+ # Diversity
+ var = prs.var(dim=0).norm(p=2, dim=-1).sum(dim=-1).mean()
+ loss_dict['diverse'] = var
+
+ return loss_dict
+
+
+def test(test_loader, generator, FGD_handler, smplx_model, config):
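+ # for each test clip: generate B body-motion samples from the audio, push latent features,
+ # joints and audio beats into the FGD handler, accumulate LVD / error / diversity, and finally
+ # report FGD, feature distance and the beat-consistency score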
+ print('start testing')
+
+ am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
+ am_sr = 16000
+
+ loss_dict = {}
+ B = 2
+ with torch.no_grad():
+ count = 0
+ for bat in tqdm(test_loader, desc="Testing......"):
+ count = count + 1
+ # if count == 10:
+ # break
+ _, poses, exp = bat['aud_feat'].to('cuda').to(torch.float32), bat['poses'].to('cuda').to(torch.float32), \
+ bat['expression'].to('cuda').to(torch.float32)
+ id = bat['speaker'].to('cuda') - 20
+ betas = bat['betas'][0].to('cuda').to(torch.float64)
+ poses = torch.cat([poses, exp], dim=-2).transpose(-1, -2)
+
+ cur_wav_file = bat['aud_file'][0]
+
+ zero_face = torch.zeros([B, poses.shape[1], 103], device='cuda')
+
+ joints_list = []
+
+ pred = generator.infer_on_audio(cur_wav_file,
+ id=id,
+ fps=30,
+ B=B,
+ am=am,
+ am_sr=am_sr,
+ frame=poses.shape[0]
+ )
+ pred = torch.tensor(pred, device='cuda')
+
+ FGD_handler.push_samples(pred, poses)
+
+ poses = poses.squeeze()
+ poses = to3d(poses, config)
+
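+ # if the generator also outputs face channels, drop the leading 103 dims
+ # (3 jaw + 100 expression) and keep only the body part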
+ if pred.shape[2] > 129:
+ pred = pred[:, :, 103:]
+
+ pred = torch.cat([zero_face[:, :pred.shape[1], :3], pred, zero_face[:, :pred.shape[1], 3:]], dim=-1)
+ full_pred = []
+ for j in range(B):
+ f_pred = part2full(pred[j])
+ full_pred.append(f_pred)
+
+ for i in range(len(full_pred)):
+ full_pred[i] = full_pred[i].unsqueeze(dim=0)
+ full_pred = torch.cat(full_pred, dim=0)
+
+ pred_joints = get_joints(smplx_model, betas, full_pred)
+
+ poses = poses2pred(poses)
+ poses = torch.cat([zero_face[0, :, :3], poses[:, 3:165], zero_face[0, :, 3:]], dim=-1)
+ gt_joints = get_joints(smplx_model, betas, poses[:pred_joints.shape[1]])
+ FGD_handler.push_joints(pred_joints, gt_joints)
+ aud = get_mfcc_ta(cur_wav_file, fps=30, sr=16000, am='not None', encoder_choice='onset')
+ FGD_handler.push_aud(torch.from_numpy(aud))
+
+ bat_loss_dict = body_loss(gt_joints, pred_joints)
+
+ if loss_dict:  # not empty
+ for key in list(bat_loss_dict.keys()):
+ loss_dict[key] += bat_loss_dict[key]
+ else:
+ for key in list(bat_loss_dict.keys()):
+ loss_dict[key] = bat_loss_dict[key]
+ for key in loss_dict.keys():
+ loss_dict[key] = loss_dict[key] / count
+ print(key + '=' + str(loss_dict[key].item()))
+
+ # MAAC = FGD_handler.get_MAAC()
+ # print(MAAC)
+ fgd_dist, feat_dist = FGD_handler.get_scores()
+ print('fgd_dist=', fgd_dist.item())
+ print('feat_dist=', feat_dist.item())
+ BCscore = FGD_handler.get_BCscore()
+ print('Beat consistency score=', BCscore)
+
+
+
+
+
+def main():
+ parser = parse_args()
+ args = parser.parse_args()
+ device = torch.device(args.gpu)
+ torch.cuda.set_device(device)
+
+ config = load_JsonConfig(args.config_file)
+
+ os.environ['smplx_npz_path'] = config.smplx_npz_path
+ os.environ['extra_joint_path'] = config.extra_joint_path
+ os.environ['j14_regressor_path'] = config.j14_regressor_path
+
+ print('init dataloader...')
+ test_set, test_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config)
+ print('init model...')
+ model_name = args.body_model_name
+ # model_path = get_path(model_name, model_type)
+ model_path = args.body_model_path
+ generator = init_model(model_name, model_path, args, config)
+
+ ae = init_model('s2g_body_ae', './experiments/feature_extractor.pth', args,
+ config)
+ FGD_handler = EmbeddingSpaceEvaluator(ae, None, 'cuda')
+
+ print('init smplx model...')
+ dtype = torch.float64
+ smplx_path = './visualise/'
+ model_params = dict(model_path=smplx_path,
+ model_type='smplx',
+ create_global_orient=True,
+ create_body_pose=True,
+ create_betas=True,
+ num_betas=300,
+ create_left_hand_pose=True,
+ create_right_hand_pose=True,
+ use_pca=False,
+ flat_hand_mean=False,
+ create_expression=True,
+ num_expression_coeffs=100,
+ num_pca_comps=12,
+ create_jaw_pose=True,
+ create_leye_pose=True,
+ create_reye_pose=True,
+ create_transl=False,
+ dtype=dtype, )
+
+ smplx_model = smpl.create(**model_params).to('cuda')
+
+ test(test_loader, generator, FGD_handler, smplx_model, config)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/test_face.py b/scripts/test_face.py
new file mode 100644
index 0000000000000000000000000000000000000000..a889b6451fd32e91e7e3c3646faf17e54d147e78
--- /dev/null
+++ b/scripts/test_face.py
@@ -0,0 +1,209 @@
+import os
+import sys
+
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+sys.path.append(os.getcwd())
+
+from tqdm import tqdm
+from transformers import Wav2Vec2Processor
+
+from evaluation.metrics import LVD
+
+import numpy as np
+import smplx as smpl
+
+from nets import *
+from trainer.options import parse_args
+from data_utils import torch_data
+from trainer.config import load_JsonConfig
+from data_utils.get_j import get_joints
+
+import torch
+from torch.utils import data
+
+
+def init_model(model_name, model_path, args, config):
+ if model_name == 's2g_face':
+ generator = s2g_face(
+ args,
+ config,
+ )
+ elif model_name == 's2g_body_vq':
+ generator = s2g_body_vq(
+ args,
+ config,
+ )
+ elif model_name == 's2g_body_pixel':
+ generator = s2g_body_pixel(
+ args,
+ config,
+ )
+ else:
+ raise NotImplementedError
+
+ model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
+ if model_name == 'smplx_S2G':
+ generator.generator.load_state_dict(model_ckpt['generator']['generator'])
+ elif 'generator' in list(model_ckpt.keys()):
+ generator.load_state_dict(model_ckpt['generator'])
+ else:
+ model_ckpt = {'generator': model_ckpt}
+ generator.load_state_dict(model_ckpt)
+
+ return generator
+
+
+def init_dataloader(data_root, speakers, args, config):
+ data_base = torch_data(
+ data_root=data_root,
+ speakers=speakers,
+ split='test',
+ limbscaling=False,
+ normalization=config.Data.pose.normalization,
+ norm_method=config.Data.pose.norm_method,
+ split_trans_zero=False,
+ num_pre_frames=config.Data.pose.pre_pose_length,
+ num_generate_length=config.Data.pose.generate_length,
+ num_frames=30,
+ aud_feat_win_size=config.Data.aud.aud_feat_win_size,
+ aud_feat_dim=config.Data.aud.aud_feat_dim,
+ feat_method=config.Data.aud.feat_method,
+ smplx=True,
+ audio_sr=22000,
+ convert_to_6d=config.Data.pose.convert_to_6d,
+ expression=config.Data.pose.expression,
+ config=config
+ )
+
+ if config.Data.pose.normalization:
+ norm_stats_fn = os.path.join(os.path.dirname(args.model_path), "norm_stats.npy")
+ norm_stats = np.load(norm_stats_fn, allow_pickle=True)
+ data_base.data_mean = norm_stats[0]
+ data_base.data_std = norm_stats[1]
+ else:
+ norm_stats = None
+
+ data_base.get_dataset()
+ test_set = data_base.all_dataset
+ test_loader = data.DataLoader(test_set, batch_size=1, shuffle=False)
+
+ return test_set, test_loader, norm_stats
+
+
+def face_loss(gt, gt_param, pr, pr_param):
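+ # compare jaw joints (indices 22:25) and facial landmark joints (74:) between ground truth and
+ # prediction, and report LVD over the same face joints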
+ loss_dict = {}
+
+ jaw_xyz = gt[:, 22:25, :] - pr[:, 22:25, :]
+ jaw_dist = jaw_xyz.norm(p=2, dim=-1)
+ jaw_dist = jaw_dist.sum(dim=-1).mean()
+ loss_dict['jaw_l1'] = jaw_dist
+
+ landmark_xyz = gt[:, 74:] - pr[:, 74:]
+ landmark_dist = landmark_xyz.norm(p=2, dim=-1)
+ landmark_dist = landmark_dist.sum(dim=-1).mean()
+ loss_dict['landmark_l1'] = landmark_dist
+
+ face_gt = torch.cat([gt[:, 22:25], gt[:, 74:]], dim=1)
+ face_pr = torch.cat([pr[:, 22:25], pr[:, 74:]], dim=1)
+
+ loss_dict['LVD'] = LVD(face_gt, face_pr, symmetrical=False, weight=False)
+
+ return loss_dict
+
+
+def test(test_loader, generator, smplx_model, args, config):
+ print('start testing')
+
+ am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
+ am_sr = 16000
+
+ loss_dict = {}
+ with torch.no_grad():
+ i = 0
+ for bat in tqdm(test_loader, desc="Testing......"):
+ i = i + 1
+ aud, poses, exp = bat['aud_feat'].to('cuda').to(torch.float32), bat['poses'].to('cuda').to(torch.float32), \
+ bat['expression'].to('cuda').to(torch.float32)
+ id = bat['speaker'].to('cuda') - 20
+ betas = bat['betas'][0].to('cuda').to(torch.float64)
+ poses = torch.cat([poses, exp], dim=-2).transpose(-1, -2).squeeze()
+ # poses = to3d(poses, config)
+
+ cur_wav_file = bat['aud_file'][0]
+ pred_face = generator.infer_on_audio(cur_wav_file,
+ id=id,
+ frame=poses.shape[0],
+ am=am,
+ am_sr=am_sr
+ )
+
+ pred_face = torch.tensor(pred_face).to('cuda').squeeze()
+ if pred_face.shape[1] > 103:
+ pred_face = pred_face[:, :103]
+ zero_poses = torch.zeros([pred_face.shape[0], 162], device='cuda')
+
+ full_param = torch.cat([pred_face[:, :3], zero_poses, pred_face[:, 3:]], dim=-1)
+
+ poses[:, 3:165] = full_param[:, 3:165]
+ gt_joints = get_joints(smplx_model, betas, poses)
+ pred_joints = get_joints(smplx_model, betas, full_param)
+ bat_loss_dict = face_loss(gt_joints, poses, pred_joints, full_param)
+
+ if loss_dict:  # not empty
+ for key in list(bat_loss_dict.keys()):
+ loss_dict[key] += bat_loss_dict[key]
+ else:
+ for key in list(bat_loss_dict.keys()):
+ loss_dict[key] = bat_loss_dict[key]
+ for key in loss_dict.keys():
+ loss_dict[key] = loss_dict[key] / i
+ print(key + '=' + str(loss_dict[key].item()))
+
+
+def main():
+ parser = parse_args()
+ args = parser.parse_args()
+ device = torch.device(args.gpu)
+ torch.cuda.set_device(device)
+
+ config = load_JsonConfig(args.config_file)
+
+ os.environ['smplx_npz_path'] = config.smplx_npz_path
+ os.environ['extra_joint_path'] = config.extra_joint_path
+ os.environ['j14_regressor_path'] = config.j14_regressor_path
+
+ print('init dataloader...')
+ test_set, test_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config)
+ print('init model...')
+ face_model_name = args.face_model_name
+ face_model_path = args.face_model_path
+ generator_face = init_model(face_model_name, face_model_path, args, config)
+
+ print('init smplx model...')
+ dtype = torch.float64
+ smplx_path = './visualise/'
+ model_params = dict(model_path=smplx_path,
+ model_type='smplx',
+ create_global_orient=True,
+ create_body_pose=True,
+ create_betas=True,
+ num_betas=300,
+ create_left_hand_pose=True,
+ create_right_hand_pose=True,
+ use_pca=False,
+ flat_hand_mean=False,
+ create_expression=True,
+ num_expression_coeffs=100,
+ num_pca_comps=12,
+ create_jaw_pose=True,
+ create_leye_pose=True,
+ create_reye_pose=True,
+ create_transl=False,
+ dtype=dtype, )
+ smplx_model = smpl.create(**model_params).to('cuda')
+
+ test(test_loader, generator_face, smplx_model, args, config)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/test_vq.py b/scripts/test_vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..9204d404030731118cca08c204314fbf84138c47
--- /dev/null
+++ b/scripts/test_vq.py
@@ -0,0 +1,91 @@
+import os
+import sys
+
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+sys.path.append(os.getcwd())
+
+from tqdm import tqdm
+from transformers import Wav2Vec2Processor
+
+from evaluation.metrics import LVD
+
+import numpy as np
+import smplx as smpl
+
+from data_utils.lower_body import part2full, poses2pred, c_index_3d
+from nets import *
+from nets.utils import get_path, get_dpath
+from trainer.options import parse_args
+from data_utils import torch_data
+from trainer.config import load_JsonConfig
+
+import torch
+from torch.utils import data
+from data_utils.get_j import to3d, get_joints
+from scripts.test_body import init_model, init_dataloader
+
+
+def test(test_loader, generator, config):
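+ # run the s2g_body_vq model on each test clip (conditioned on the ground-truth poses) and
+ # report the absolute error ('capacity') against the ground truth on the selected body
+ # channels (c_index_3d)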
+ print('start testing')
+
+ loss_dict = {}
+ B = 1
+ with torch.no_grad():
+ count = 0
+ for bat in tqdm(test_loader, desc="Testing......"):
+ count = count + 1
+ aud, poses, exp = bat['aud_feat'].to('cuda').to(torch.float32), bat['poses'].to('cuda').to(torch.float32), \
+ bat['expression'].to('cuda').to(torch.float32)
+ id = bat['speaker'].to('cuda') - 20
+ betas = bat['betas'][0].to('cuda').to(torch.float64)
+ poses = torch.cat([poses, exp], dim=-2).transpose(-1, -2).squeeze()
+ poses = to3d(poses, config).unsqueeze(dim=0).transpose(1, 2)
+ # poses = poses[:, c_index_3d, :]
+
+ cur_wav_file = bat['aud_file'][0]
+
+ pred = generator.infer_on_audio(cur_wav_file,
+ initial_pose=poses,
+ id=id,
+ fps=30,
+ B=B
+ )
+ pred = torch.tensor(pred, device='cuda')
+ bat_loss_dict = {'capacity': (poses[:, c_index_3d, :pred.shape[0]].transpose(1,2) - pred).abs().sum(-1).mean()}
+
+ if loss_dict:  # not empty
+ for key in list(bat_loss_dict.keys()):
+ loss_dict[key] += bat_loss_dict[key]
+ else:
+ for key in list(bat_loss_dict.keys()):
+ loss_dict[key] = bat_loss_dict[key]
+ for key in loss_dict.keys():
+ loss_dict[key] = loss_dict[key] / count
+ print(key + '=' + str(loss_dict[key].item()))
+
+
+def main():
+ parser = parse_args()
+ args = parser.parse_args()
+ device = torch.device(args.gpu)
+ torch.cuda.set_device(device)
+
+ config = load_JsonConfig(args.config_file)
+
+ os.environ['smplx_npz_path'] = config.smplx_npz_path
+ os.environ['extra_joint_path'] = config.extra_joint_path
+ os.environ['j14_regressor_path'] = config.j14_regressor_path
+
+ print('init dataloader...')
+ test_set, test_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config)
+ print('init model...')
+ model_name = 's2g_body_vq'
+ model_type = 'n_com_8192'
+ model_path = get_path(model_name, model_type)
+ generator = init_model(model_name, model_path, args, config)
+
+ test(test_loader, generator, config)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/train.py b/scripts/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..808847d3da170fa251fb6a7aeb6a0b252860762b
--- /dev/null
+++ b/scripts/train.py
@@ -0,0 +1,11 @@
+import os
+import sys
+# os.chdir('/home/jovyan/Co-Speech-Motion-Generation/src')
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+sys.path.append(os.getcwd())
+
+from trainer import Trainer
+
+if __name__ == '__main__':
+ trainer = Trainer()
+ trainer.train()
\ No newline at end of file
diff --git a/trainer/Trainer.py b/trainer/Trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2d51961e5333ae0f71d20c610eff1ce88043654
--- /dev/null
+++ b/trainer/Trainer.py
@@ -0,0 +1,278 @@
+import os
+import sys
+
+sys.path.append(os.getcwd())
+
+from data_utils import torch_data
+
+from trainer.options import parse_args
+from trainer.config import load_JsonConfig
+from nets.init_model import init_model
+
+import torch
+import torch.utils.data as data
+import torch.optim as optim
+import numpy as np
+import random
+import logging
+import time
+import shutil
+
+def prn_obj(obj):
+ print('\n'.join(['%s:%s' % item for item in obj.__dict__.items()]))
+
+
+
+
+
+class Trainer():
+ def __init__(self) -> None:
+ parser = parse_args()
+ self.args = parser.parse_args()
+ self.config = load_JsonConfig(self.args.config_file)
+
+ os.environ['smplx_npz_path']=self.config.smplx_npz_path
+ os.environ['extra_joint_path']=self.config.extra_joint_path
+ os.environ['j14_regressor_path']=self.config.j14_regressor_path
+
+ # torch.set_default_dtype(torch.float64)
+ # wandb_run = wandb.init(project=f's2g_sweep')
+
+ # if self.args.use_wandb:
+ # print('starting wandb sweep agent...')
+ # wandb_key = 'e3d537403fce5c8a99893c2cbe20a8d49a79358d'
+ # os.environ['WANDB_API_KEY'] = wandb_key
+ #
+ # default_config=dict(w_b=1,w_h=10)
+ # wandb.init(config=default_config)
+ # self.config.param.w_b=wandb.config.w_b
+ # self.config.param.w_h=wandb.config.w_h
+ # self.config.Train.epochs=30
+
+ # if self.args.use_wandb:
+ # print('starting wandb sweep agent...')
+ # wandb_key = 'e3d537403fce5c8a99893c2cbe20a8d49a79358d'
+ # os.environ['WANDB_API_KEY'] = wandb_key
+ #
+ # wandb.init(config=self.args, project="s2g_sweep")
+ # # wandb.config.update(self.args)
+ #
+ # self.config.param.w_b=self.args.w_b
+ # self.config.param.w_h=self.args.w_h
+ # self.config.Train.epochs=30
+
+ self.device = torch.device(self.args.gpu)
+ torch.cuda.set_device(self.device)
+ self.setup_seed(self.args.seed)
+ self.set_train_dir()
+
+ shutil.copy(self.args.config_file, self.train_dir)
+
+ self.generator = init_model(self.config.Model.model_name, self.args, self.config)
+ self.init_dataloader()
+ self.start_epoch = 0
+ self.global_steps = 0
+ if self.args.resume:
+ self.resume()
+ # self.init_optimizer()
+
+ def setup_seed(self, seed):
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ torch.backends.cudnn.deterministic = True
+
+ def set_train_dir(self):
+ time_stamp = time.strftime('%Y-%m-%d',time.localtime(time.time()))
+ train_dir = os.path.join(os.getcwd(), self.args.save_dir, os.path.normpath(
+ time_stamp + '-' + self.args.exp_name + '-' + self.config.Log.name))
+ # train_dir= os.path.join(os.getcwd(), self.args.save_dir, os.path.normpath(time_stamp+'-'+self.args.exp_name+'-'+time.strftime("%H:%M:%S")))
+ os.makedirs(train_dir, exist_ok=True)
+ log_file=os.path.join(train_dir, 'train.log')
+
+ fmt="%(asctime)s-%(lineno)d-%(message)s"
+ logging.basicConfig(
+ stream=sys.stdout, level=logging.INFO,format=fmt, datefmt='%m/%d %I:%M:%S %p'
+ )
+ fh=logging.FileHandler(log_file)
+ fh.setFormatter(logging.Formatter(fmt))
+ logging.getLogger().addHandler(fh)
+ self.train_dir = train_dir
+
+ def resume(self):
+ print('resume from a previous ckpt')
+ ckpt = torch.load(self.args.pretrained_pth)
+ self.generator.load_state_dict(ckpt['generator'])
+ self.start_epoch = ckpt['epoch']
+ self.global_steps = ckpt['global_steps']
+ self.generator.global_step = self.global_steps
+
+
+ def init_dataloader(self):
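+ # build the training dataloader(s) by model family: 'freeMo' models use separate trans/zero
+ # loaders, 'smplx'/'s2g' models use a single loader over all_dataset, and anything else falls
+ # back to the default loader; normalization stats are saved to norm_stats.npy in the train dir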
+ if 'freeMo' in self.config.Model.model_name:
+ if self.config.Data.data_root.endswith('.csv'):
+ raise NotImplementedError
+ else:
+ data_class = torch_data
+
+ self.train_set = data_class(
+ data_root=self.config.Data.data_root,
+ speakers=self.args.speakers,
+ split='train',
+ limbscaling=self.config.Data.pose.augmentation,
+ normalization=self.config.Data.pose.normalization,
+ norm_method=self.config.Data.pose.norm_method,
+ split_trans_zero=True,
+ num_pre_frames=self.config.Data.pose.pre_pose_length,
+ num_frames=self.config.Data.pose.generate_length,
+ aud_feat_win_size=self.config.Data.aud.aud_feat_win_size,
+ aud_feat_dim=self.config.Data.aud.aud_feat_dim,
+ feat_method=self.config.Data.aud.feat_method,
+ context_info=self.config.Data.aud.context_info
+ )
+
+ if self.config.Data.pose.normalization:
+ self.norm_stats = (self.train_set.data_mean, self.train_set.data_std)
+ save_file = os.path.join(self.train_dir, 'norm_stats.npy')
+ np.save(save_file, self.norm_stats, allow_pickle=True)
+
+ self.train_set.get_dataset()
+ self.trans_set = self.train_set.trans_dataset
+ self.zero_set = self.train_set.zero_dataset
+
+ self.trans_loader = data.DataLoader(self.trans_set, batch_size=self.config.DataLoader.batch_size, shuffle=True, num_workers=self.config.DataLoader.num_workers, drop_last=True)
+ self.zero_loader = data.DataLoader(self.zero_set, batch_size=self.config.DataLoader.batch_size, shuffle=True, num_workers=self.config.DataLoader.num_workers, drop_last=True)
+ elif 'smplx' in self.config.Model.model_name or 's2g' in self.config.Model.model_name:
+ data_class = torch_data
+
+ self.train_set = data_class(
+ data_root=self.config.Data.data_root,
+ speakers=self.args.speakers,
+ split='train',
+ limbscaling=self.config.Data.pose.augmentation,
+ normalization=self.config.Data.pose.normalization,
+ norm_method=self.config.Data.pose.norm_method,
+ split_trans_zero=False,
+ num_pre_frames=self.config.Data.pose.pre_pose_length,
+ num_frames=self.config.Data.pose.generate_length,
+ num_generate_length=self.config.Data.pose.generate_length,
+ aud_feat_win_size=self.config.Data.aud.aud_feat_win_size,
+ aud_feat_dim=self.config.Data.aud.aud_feat_dim,
+ feat_method=self.config.Data.aud.feat_method,
+ context_info=self.config.Data.aud.context_info,
+ smplx=True,
+ audio_sr=22000,
+ convert_to_6d=self.config.Data.pose.convert_to_6d,
+ expression=self.config.Data.pose.expression,
+ config=self.config
+ )
+ if self.config.Data.pose.normalization:
+ self.norm_stats = (self.train_set.data_mean, self.train_set.data_std)
+ save_file = os.path.join(self.train_dir, 'norm_stats.npy')
+ np.save(save_file, self.norm_stats, allow_pickle=True)
+ self.train_set.get_dataset()
+ self.train_loader = data.DataLoader(self.train_set.all_dataset,
+ batch_size=self.config.DataLoader.batch_size, shuffle=True,
+ num_workers=self.config.DataLoader.num_workers, drop_last=True)
+ else:
+ data_class = torch_data
+
+ self.train_set = data_class(
+ data_root=self.config.Data.data_root,
+ speakers=self.args.speakers,
+ split='train',
+ limbscaling=self.config.Data.pose.augmentation,
+ normalization=self.config.Data.pose.normalization,
+ norm_method=self.config.Data.pose.norm_method,
+ split_trans_zero=False,
+ num_pre_frames=self.config.Data.pose.pre_pose_length,
+ num_frames=self.config.Data.pose.generate_length,
+ aud_feat_win_size=self.config.Data.aud.aud_feat_win_size,
+ aud_feat_dim=self.config.Data.aud.aud_feat_dim,
+ feat_method=self.config.Data.aud.feat_method,
+ context_info=self.config.Data.aud.context_info
+ )
+
+ if self.config.Data.pose.normalization:
+ self.norm_stats = (self.train_set.data_mean, self.train_set.data_std)
+ save_file = os.path.join(self.train_dir, 'norm_stats.npy')
+ np.save(save_file, self.norm_stats, allow_pickle=True)
+
+ self.train_set.get_dataset()
+
+ self.train_loader = data.DataLoader(self.train_set.all_dataset, batch_size=self.config.DataLoader.batch_size, shuffle=True, num_workers=self.config.DataLoader.num_workers, drop_last=True)
+
+
+ def init_optimizer(self):
+ pass
+
+ def print_func(self, loss_dict, steps):
+ info_str = ['global_steps:%d'%(self.global_steps)]
+ info_str += ['%s:%.4f'%(key, loss_dict[key]/steps) for key in list(loss_dict.keys())]
+ logging.info(','.join(info_str))
+
+ def save_model(self, epoch):
+ # if 'vq' in self.config.Model.model_name:
+ # state_dict = {
+ # 'g_body': self.g_body.state_dict(),
+ # 'g_hand': self.g_hand.state_dict(),
+ # 'epoch': epoch,
+ # 'global_steps': self.global_steps
+ # }
+ # else:
+ state_dict = {
+ 'generator': self.generator.state_dict(),
+ 'epoch': epoch,
+ 'global_steps': self.global_steps
+ }
+ save_name = os.path.join(self.train_dir, 'ckpt-%d.pth'%(epoch))
+ torch.save(state_dict, save_name)
+
+ def train_epoch(self, epoch):
+ epoch_loss_dict = {}  # it is best to track how each loss changes over the epoch
+ epoch_steps = 0
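+ # freeMo trains on paired batches from (trans_loader, zero_loader); all other models iterate the single train_loader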
+ if 'freeMo' in self.config.Model.model_name:
+ for bat in zip(self.trans_loader, self.zero_loader):
+ self.global_steps += 1
+ epoch_steps += 1
+ _, loss_dict = self.generator(bat)
+
+ if epoch_loss_dict:  # non-empty: accumulate into the existing entries
+ for key in list(loss_dict.keys()):
+ epoch_loss_dict[key] += loss_dict[key]
+ else:
+ for key in list(loss_dict.keys()):
+ epoch_loss_dict[key] = loss_dict[key]
+
+ if self.global_steps % self.config.Log.print_every == 0:
+ self.print_func(epoch_loss_dict, epoch_steps)
+ else:
+ # e.g. self.config.Model.model_name == 'smplx_S2G'
+ for bat in self.train_loader:
+ # if epoch_steps == 1000:
+ # break
+ self.global_steps += 1
+ epoch_steps += 1
+ bat['epoch'] = epoch
+
+ _, loss_dict = self.generator(bat)
+ if epoch_loss_dict:  # non-empty: accumulate into the existing entries
+ for key in list(loss_dict.keys()):
+ epoch_loss_dict[key] += loss_dict[key]
+ else:
+ for key in list(loss_dict.keys()):
+ epoch_loss_dict[key] = loss_dict[key]
+ if self.global_steps % self.config.Log.print_every == 0:
+ self.print_func(epoch_loss_dict, epoch_steps)
+
+ def train(self):
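+ # main loop: run train_epoch once per epoch and checkpoint every Log.save_every epochs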
+ logging.info('start_training')
+ self.total_loss_dict = {}
+ for epoch in range(self.start_epoch, self.config.Train.epochs):
+ logging.info('epoch:%d'%(epoch))
+ self.train_epoch(epoch)
+ # self.generator.scheduler.step()
+ # logging.info('learning rate:%d' % (self.generator.scheduler.get_lr()[0]))
+ if (epoch+1)%self.config.Log.save_every == 0 or (epoch+1) == 30:
+ self.save_model(epoch)
diff --git a/trainer/__init__.py b/trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbc802e6491847c16f3a8265802d5e4979cf9ce6
--- /dev/null
+++ b/trainer/__init__.py
@@ -0,0 +1 @@
+from .Trainer import Trainer
\ No newline at end of file
diff --git a/trainer/__pycache__/Trainer.cpython-37.pyc b/trainer/__pycache__/Trainer.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a634dc552a2c11cc8c28f9592fcaa8100e0c00b1
Binary files /dev/null and b/trainer/__pycache__/Trainer.cpython-37.pyc differ
diff --git a/trainer/__pycache__/__init__.cpython-37.pyc b/trainer/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f75a56ff65a254623cd0b821a5f0facf4c38ca1a
Binary files /dev/null and b/trainer/__pycache__/__init__.cpython-37.pyc differ
diff --git a/trainer/__pycache__/config.cpython-37.pyc b/trainer/__pycache__/config.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db1272522833f6dc0882de8a6cc3661619066c89
Binary files /dev/null and b/trainer/__pycache__/config.cpython-37.pyc differ
diff --git a/trainer/__pycache__/options.cpython-37.pyc b/trainer/__pycache__/options.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..10583bab84a08378a3f073009c4b4e44677bb366
Binary files /dev/null and b/trainer/__pycache__/options.cpython-37.pyc differ
diff --git a/trainer/config.py b/trainer/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..01fe699efefc9e6268608746f0d5b6f36e4aa977
--- /dev/null
+++ b/trainer/config.py
@@ -0,0 +1,27 @@
+'''
+load config from json file
+'''
+import json
+import os
+
+
+class Object:
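+ # recursively wraps nested dicts so config entries read as attributes, e.g. config.Data.pose.normalization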
+ def __init__(self, config:dict) -> None:
+ for key in list(config.keys()):
+ if isinstance(config[key], dict):
+ setattr(self, key, Object(config[key]))
+ else:
+ setattr(self, key, config[key])
+
+def load_JsonConfig(json_file):
+ with open(json_file, 'r') as f:
+ config = json.load(f)
+
+ return Object(config)
+
+
+if __name__ == '__main__':
+ config = load_JsonConfig('config/style_gestures.json')
+ print(dir(config))
\ No newline at end of file
diff --git a/trainer/options.py b/trainer/options.py
new file mode 100644
index 0000000000000000000000000000000000000000..031b5147c3d0a9b12d63078243608b0c43286a3c
--- /dev/null
+++ b/trainer/options.py
@@ -0,0 +1,37 @@
+from argparse import ArgumentParser
+
+def parse_args():
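+ # builds and returns the ArgumentParser; callers are expected to call .parse_args() on the result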
+ parser = ArgumentParser()
+ parser.add_argument('--gpu', default=0, type=int)
+ parser.add_argument('--save_dir', default='experiments', type=str)
+ parser.add_argument('--exp_name', default='smplx_S2G', type=str)
+ parser.add_argument('--speakers', nargs='+')
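+ # e.g. --speakers oliver chemistry seth conan (space-separated list)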
+ parser.add_argument('--seed', default=1, type=int)
+ parser.add_argument('--model_name', type=str)
+
+ #for Tmpt and S2G
+ parser.add_argument('--use_template', action='store_true')
+ parser.add_argument('--template_length', default=0, type=int)
+
+ #for training from a ckpt
+ parser.add_argument('--resume', action='store_true')
+ parser.add_argument('--pretrained_pth', default=None, type=str)
+ parser.add_argument('--style_layer_norm', action='store_true')
+
+ #required
+ parser.add_argument('--config_file', default='./config/style_gestures.json', type=str)
+
+ # for visualization and test
+ parser.add_argument('--audio_file', default=None, type=str)
+ parser.add_argument('--id', default=0, type=int, help='0=oliver, 1=chemistry, 2=seth, 3=conan')
+ parser.add_argument('--only_face', action='store_true')
+ parser.add_argument('--stand', action='store_true')
+ parser.add_argument('--whole_body', action='store_true')
+ parser.add_argument('--num_sample', default=1, type=int)
+ parser.add_argument('--face_model_name', default='s2g_face', type=str)
+ parser.add_argument('--face_model_path', default='./experiments/2022-10-15-smplx_S2G-face-3d/ckpt-99.pth', type=str)
+ parser.add_argument('--body_model_name', default='s2g_body_pixel', type=str)
+ parser.add_argument('--body_model_path', default='./experiments/2022-11-02-smplx_S2G-body-pixel-3d/ckpt-99.pth', type=str)
+ parser.add_argument('--infer', action='store_true')
+
+ return parser
\ No newline at end of file
diff --git a/trainer/training_config.cfg b/trainer/training_config.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..513c2ddea486a021bb28ff0a1e3cb51faff9b327
--- /dev/null
+++ b/trainer/training_config.cfg
@@ -0,0 +1,48 @@
+[Input Output]
+checkpoint_dir = ./training
+expression_basis_fname = ./training_data/init_expression_basis.npy
+template_fname = ./template/FLAME_sample.ply
+deepspeech_graph_fname = ./ds_graph/output_graph.pb
+face_or_body = body
+verts_mmaps_path = ./training_data/data_verts.npy
+raw_audio_path = ./training_data/raw_audio_fixed.pkl
+processed_audio_path = ./training_data/processed_audio_deepspeech.pkl
+templates_path = ./training_data/templates.pkl
+data2array_verts_path = ./training_data/subj_seq_to_idx.pkl
+
+[Audio Parameters]
+audio_feature_type = deepspeech
+num_audio_features = 29
+audio_window_size = 16
+audio_window_stride = 1
+condition_speech_features = True
+speech_encoder_size_factor = 1.0
+
+[Model Parameters]
+num_vertices = 10475
+expression_dim = 50
+init_expression = False
+num_consecutive_frames = 30
+absolute_reconstruction_loss = False
+velocity_weight = 10.0
+acceleration_weight = 0.0
+verts_regularizer_weight = 0.0
+
+[Data Setup]
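+# sequence entries below appear to follow <clip index>-<start HH'MM'SS>-<end HH'MM'SS>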
+subject_for_training = speeker_oliver
+sequence_for_training = 0-00'00'05-00'00'10 1-00'00'32-00'00'37 2-00'01'05-00'01'10
+subject_for_validation = speeker_oliver
+sequence_for_validation = 2-00'01'05-00'01'10
+subject_for_testing = speeker_oliver
+sequence_for_testing = 2-00'01'05-00'01'10
+
+[Learning Parameters]
+batch_size = 64
+learning_rate = 1e-4
+decay_rate = 1.0
+epoch_num = 1000
+adam_beta1_value = 0.9
+
+[Visualization Parameters]
+num_render_sequences = 3
+
diff --git a/voca/__pycache__/rendering.cpython-37.pyc b/voca/__pycache__/rendering.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4856c56a47519fe7bb6437ca8f0acf40eb29038b
Binary files /dev/null and b/voca/__pycache__/rendering.cpython-37.pyc differ
diff --git a/voca/rendering.py b/voca/rendering.py
new file mode 100644
index 0000000000000000000000000000000000000000..68e933f2f63736561c7b74d29c1335eecee9a22a
--- /dev/null
+++ b/voca/rendering.py
@@ -0,0 +1,177 @@
+'''
+Max-Planck-Gesellschaft zur Foerderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights on this
+computer program.
+
+You can only use this computer program if you have closed a license agreement with MPG or you get the right to use
+the computer program from someone who is authorized to grant you that right.
+
+Any use of the computer program without a valid license is prohibited and liable to prosecution.
+
+Copyright 2019 Max-Planck-Gesellschaft zur Foerderung der Wissenschaften e.V. (MPG). acting on behalf of its
+Max Planck Institute for Intelligent Systems and the Max Planck Institute for Biological Cybernetics.
+All rights reserved.
+
+More information about VOCA is available at http://voca.is.tue.mpg.de.
+For comments or questions, please email us at voca@tue.mpg.de
+'''
+
+from __future__ import division
+import os
+ # os.environ['PYOPENGL_PLATFORM'] = 'osmesa'  # Uncomment this line when running remotely (headless rendering)
+import cv2
+import pyrender
+import trimesh
+import tempfile
+import numpy as np
+import matplotlib as mpl
+import matplotlib.cm as cm
+
+
+def get_unit_factor(unit):
+ if unit == 'mm':
+ return 1000.0
+ elif unit == 'cm':
+ return 100.0
+ elif unit == 'm':
+ return 1.0
+ else:
+ raise ValueError('Unit not supported')
+
+
+def render_mesh_helper(mesh, t_center, rot=np.zeros(3), tex_img=None, v_colors=None,
+ errors=None, error_unit='m', min_dist_in_mm=0.0, max_dist_in_mm=3.0, z_offset=1.0, xmag=0.5,
+ y=0.7, z=1, camera='o', r=None):
+ camera_params = {'c': np.array([0, 0]),
+ 'k': np.array([-0.19816071, 0.92822711, 0, 0, 0]),
+ 'f': np.array([5000, 5000])}
+
+ frustum = {'near': 0.01, 'far': 3.0, 'height': 800, 'width': 800}
+
+ v, f = mesh
+ v = cv2.Rodrigues(rot)[0].dot((v - t_center).T).T + t_center
+
+ texture_rendering = tex_img is not None and hasattr(mesh, 'vt') and hasattr(mesh, 'ft')
+ if texture_rendering:
+ intensity = 0.5
+ tex = pyrender.Texture(source=tex_img, source_channels='RGB')
+ material = pyrender.material.MetallicRoughnessMaterial(baseColorTexture=tex)
+
+ # Workaround as pyrender requires number of vertices and uv coordinates to be the same
+ temp_filename = '%s.obj' % next(tempfile._get_candidate_names())
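+ # note: tempfile._get_candidate_names() is a private helper; the temporary .obj only exists for the round-trip through trimesh below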
+ mesh.write_obj(temp_filename)
+ tri_mesh = trimesh.load(temp_filename, process=False)
+ try:
+ os.remove(temp_filename)
+ except OSError:
+ print('Failed deleting temporary file - %s' % temp_filename)
+ render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, material=material)
+ elif errors is not None:
+ intensity = 0.5
+ unit_factor = get_unit_factor('mm') / get_unit_factor(error_unit)
+ errors = unit_factor * errors
+
+ norm = mpl.colors.Normalize(vmin=min_dist_in_mm, vmax=max_dist_in_mm)
+ cmap = cm.get_cmap(name='jet')
+ colormapper = cm.ScalarMappable(norm=norm, cmap=cmap)
+ rgba_per_v = colormapper.to_rgba(errors)
+ rgb_per_v = rgba_per_v[:, 0:3]
+ elif v_colors is not None:
+ intensity = 0.5
+ rgb_per_v = v_colors
+ else:
+ intensity = 6.
+ rgb_per_v = None
+
+ color = np.array([0.3, 0.5, 0.55])
+
+ if not texture_rendering:
+ tri_mesh = trimesh.Trimesh(vertices=v, faces=f, vertex_colors=rgb_per_v)
+ render_mesh = pyrender.Mesh.from_trimesh(tri_mesh,
+ smooth=True,
+ material=pyrender.MetallicRoughnessMaterial(
+ metallicFactor=0.05,
+ roughnessFactor=0.7,
+ alphaMode='OPAQUE',
+ baseColorFactor=(color[0], color[1], color[2], 1.0)
+ ))
+
+ scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[255, 255, 255])
+
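+ # camera selection: 'o' = orthographic, 'i' = intrinsics from camera_params, 'y' = perspective with a 90-degree fov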
+ if camera == 'o':
+ ymag = xmag * z_offset
+ camera = pyrender.OrthographicCamera(xmag=xmag, ymag=ymag)
+ elif camera == 'i':
+ camera = pyrender.IntrinsicsCamera(fx=camera_params['f'][0],
+ fy=camera_params['f'][1],
+ cx=camera_params['c'][0],
+ cy=camera_params['c'][1],
+ znear=frustum['near'],
+ zfar=frustum['far'])
+ elif camera == 'y':
+ camera = pyrender.PerspectiveCamera(yfov=(np.pi / 2.0))
+
+ scene.add(render_mesh, pose=np.eye(4))
+
+ camera_pose = np.eye(4)
+ camera_pose[:3, 3] = np.array([0, 0.7, 1.0 - z_offset])
+ scene.add(camera, pose=[[1, 0, 0, 0],
+ [0, 1, 0, y], # 0.25
+ [0, 0, 1, z], # 0.2
+ [0, 0, 0, 1]])
+
+
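+ # lighting: the directional-light branch below is disabled (if False); instead several point lights and two spot lights are placed around the mesh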
+ angle = np.pi / 6.0
+ # pos = camera_pose[:3,3]
+ pos = np.array([0, 0.7, 2.0])
+ if False:
+ light_color = np.array([1., 1., 1.])
+ light = pyrender.DirectionalLight(color=light_color, intensity=intensity)
+
+ light_pose = np.eye(4)
+ light_pose[:3, 3] = np.array([0, 0.7, 2.0])
+ scene.add(light, pose=light_pose.copy())
+ else:
+ light = pyrender.PointLight(color=np.array([1.0, 1.0, 1.0]) * 0.2, intensity=2)
+ light_pose = np.eye(4)
+ light_pose[:3, 3] = [0, -1, 1]
+ scene.add(light, pose=light_pose)
+
+ light_pose[:3, 3] = [0, 1, 1]
+ scene.add(light, pose=light_pose)
+
+ light_pose[:3, 3] = [-1, 1, 2]
+ scene.add(light, pose=light_pose)
+
+ spot_l = pyrender.SpotLight(color=np.ones(3), intensity=15.0,
+ innerConeAngle=np.pi / 3, outerConeAngle=np.pi / 2)
+
+ light_pose[:3, 3] = [-1, 2, 2]
+ scene.add(spot_l, pose=light_pose)
+
+ light_pose[:3, 3] = [1, 2, 2]
+ scene.add(spot_l, pose=light_pose)
+
+ # light_pose[:3,3] = cv2.Rodrigues(np.array([angle, 0, 0]))[0].dot(pos)
+ # scene.add(light, pose=light_pose.copy())
+ #
+ # light_pose[:3,3] = cv2.Rodrigues(np.array([-angle, 0, 0]))[0].dot(pos)
+ # scene.add(light, pose=light_pose.copy())
+ #
+ # light_pose[:3,3] = cv2.Rodrigues(np.array([0, -angle, 0]))[0].dot(pos)
+ # scene.add(light, pose=light_pose.copy())
+ #
+ # light_pose[:3,3] = cv2.Rodrigues(np.array([0, angle, 0]))[0].dot(pos)
+ # scene.add(light, pose=light_pose.copy())
+
+ # pyrender.Viewer(scene)
+
+ flags = pyrender.RenderFlags.SKIP_CULL_FACES
+ # a shared OffscreenRenderer can be passed in via `r`; otherwise create (and clean up) a local one
+ owns_renderer = r is None
+ if owns_renderer:
+ r = pyrender.OffscreenRenderer(viewport_width=frustum['width'], viewport_height=frustum['height'])
+ try:
+ color, _ = r.render(scene, flags=flags)
+ except Exception:
+ print('pyrender: Failed rendering frame')
+ color = np.zeros((frustum['height'], frustum['width'], 3), dtype='uint8')
+ finally:
+ if owns_renderer:
+ r.delete()
+
+ return color[..., ::-1]