diff --git a/evaluation/FGD.py b/evaluation/FGD.py new file mode 100644 index 0000000000000000000000000000000000000000..d5521ee30b8751ef3bdd980ee91231ab78fea6e5 --- /dev/null +++ b/evaluation/FGD.py @@ -0,0 +1,199 @@ +import time + +import numpy as np +import torch +import torch.nn.functional as F +from scipy import linalg +import math +from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d + +import warnings +warnings.filterwarnings("ignore", category=RuntimeWarning) # ignore warnings + + +change_angle = torch.tensor([6.0181e-05, 5.1597e-05, 2.1344e-04, 2.1899e-04]) +class EmbeddingSpaceEvaluator: + def __init__(self, ae, vae, device): + + # init embed net + self.ae = ae + # self.vae = vae + + # storage + self.real_feat_list = [] + self.generated_feat_list = [] + self.real_joints_list = [] + self.generated_joints_list = [] + self.real_6d_list = [] + self.generated_6d_list = [] + self.audio_beat_list = [] + + def reset(self): + self.real_feat_list = [] + self.generated_feat_list = [] + + def get_no_of_samples(self): + return len(self.real_feat_list) + + def push_samples(self, generated_poses, real_poses): + # self.net.eval() + # convert poses to latent features + real_feat, real_poses = self.ae.extract(real_poses) + generated_feat, generated_poses = self.ae.extract(generated_poses) + + num_joints = real_poses.shape[2] // 3 + + real_feat = real_feat.squeeze() + generated_feat = generated_feat.reshape(generated_feat.shape[0]*generated_feat.shape[1], -1) + + self.real_feat_list.append(real_feat.data.cpu().numpy()) + self.generated_feat_list.append(generated_feat.data.cpu().numpy()) + + # real_poses = matrix_to_rotation_6d(axis_angle_to_matrix(real_poses.reshape(-1, 3))).reshape(-1, num_joints, 6) + # generated_poses = matrix_to_rotation_6d(axis_angle_to_matrix(generated_poses.reshape(-1, 3))).reshape(-1, num_joints, 6) + # + # self.real_feat_list.append(real_poses.data.cpu().numpy()) + # self.generated_feat_list.append(generated_poses.data.cpu().numpy()) + + def push_joints(self, generated_poses, real_poses): + self.real_joints_list.append(real_poses.data.cpu()) + self.generated_joints_list.append(generated_poses.squeeze().data.cpu()) + + def push_aud(self, aud): + self.audio_beat_list.append(aud.squeeze().data.cpu()) + + def get_MAAC(self): + ang_vel_list = [] + for real_joints in self.real_joints_list: + real_joints[:, 15:21] = real_joints[:, 16:22] + vec = real_joints[:, 15:21] - real_joints[:, 13:19] + inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]]) + inner_product = torch.clamp(inner_product, -1, 1, out=None) + angle = torch.acos(inner_product) / math.pi + ang_vel = (angle[1:] - angle[:-1]).abs().mean(dim=0) + ang_vel_list.append(ang_vel.unsqueeze(dim=0)) + all_vel = torch.cat(ang_vel_list, dim=0) + MAAC = all_vel.mean(dim=0) + return MAAC + + def get_BCscore(self): + thres = 0.01 + sigma = 0.1 + sum_1 = 0 + total_beat = 0 + for joints, audio_beat_time in zip(self.generated_joints_list, self.audio_beat_list): + motion_beat_time = [] + if joints.dim() == 4: + joints = joints[0] + joints[:, 15:21] = joints[:, 16:22] + vec = joints[:, 15:21] - joints[:, 13:19] + inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]]) + inner_product = torch.clamp(inner_product, -1, 1, out=None) + angle = torch.acos(inner_product) / math.pi + ang_vel = (angle[1:] - angle[:-1]).abs() / change_angle / len(change_angle) + + angle_diff = torch.cat((torch.zeros(1, 4), ang_vel), dim=0) + + sum_2 = 0 + for i in range(angle_diff.shape[1]): + 
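+ # Beat-consistency (BC) sketch, added here as comments for readability: motion beats for
+ # joint chain i are local minima of the angular-velocity signal `angle_diff` (strictly below
+ # both neighbours, with a drop of at least `thres` from at least one of them), converted to
+ # seconds at 30 fps. Each audio beat a then contributes exp(-min_t (a - t_motion)^2 / (2*sigma^2)),
+ # and get_BCscore returns the sum of these contributions divided by the total number of audio beats.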
motion_beat_time = [] + for t in range(1, joints.shape[0]-1): + if (angle_diff[t][i] < angle_diff[t - 1][i] and angle_diff[t][i] < angle_diff[t + 1][i]): + if (angle_diff[t - 1][i] - angle_diff[t][i] >= thres or angle_diff[t + 1][i] - angle_diff[ + t][i] >= thres): + motion_beat_time.append(float(t) / 30.0) + if (len(motion_beat_time) == 0): + continue + motion_beat_time = torch.tensor(motion_beat_time) + sum = 0 + for audio in audio_beat_time: + sum += np.power(math.e, -(np.power((audio.item() - motion_beat_time), 2)).min() / (2 * sigma * sigma)) + sum_2 = sum_2 + sum + total_beat = total_beat + len(audio_beat_time) + sum_1 = sum_1 + sum_2 + return sum_1/total_beat + + + def get_scores(self): + generated_feats = np.vstack(self.generated_feat_list) + real_feats = np.vstack(self.real_feat_list) + + def frechet_distance(samples_A, samples_B): + A_mu = np.mean(samples_A, axis=0) + A_sigma = np.cov(samples_A, rowvar=False) + B_mu = np.mean(samples_B, axis=0) + B_sigma = np.cov(samples_B, rowvar=False) + try: + frechet_dist = self.calculate_frechet_distance(A_mu, A_sigma, B_mu, B_sigma) + except ValueError: + frechet_dist = 1e+10 + return frechet_dist + + #################################################################### + # frechet distance + frechet_dist = frechet_distance(generated_feats, real_feats) + + #################################################################### + # distance between real and generated samples on the latent feature space + dists = [] + for i in range(real_feats.shape[0]): + d = np.sum(np.absolute(real_feats[i] - generated_feats[i])) # MAE + dists.append(d) + feat_dist = np.mean(dists) + + return frechet_dist, feat_dist + + @staticmethod + def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """ from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py """ + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + Returns: + -- : The Frechet Distance. 
+ """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) \ No newline at end of file diff --git a/evaluation/__init__.py b/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/evaluation/__pycache__/__init__.cpython-37.pyc b/evaluation/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1631bfe6e7bc4b06edead82accfc11190e66830 Binary files /dev/null and b/evaluation/__pycache__/__init__.cpython-37.pyc differ diff --git a/evaluation/__pycache__/metrics.cpython-37.pyc b/evaluation/__pycache__/metrics.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc7c877f5fd731130c82d1b67a180ac9b69cc251 Binary files /dev/null and b/evaluation/__pycache__/metrics.cpython-37.pyc differ diff --git a/evaluation/diversity_LVD.py b/evaluation/diversity_LVD.py new file mode 100644 index 0000000000000000000000000000000000000000..cfd7dd81118692ea0c361c6016afa29d98665315 --- /dev/null +++ b/evaluation/diversity_LVD.py @@ -0,0 +1,64 @@ +''' +LVD: different initial pose +diversity: same initial pose +''' +import os +import sys +sys.path.append(os.getcwd()) + +from glob import glob + +from argparse import ArgumentParser +import json + +from evaluation.util import * +from evaluation.metrics import * +from tqdm import tqdm + +parser = ArgumentParser() +parser.add_argument('--speaker', required=True, type=str) +parser.add_argument('--post_fix', nargs='+', default=['base'], type=str) +args = parser.parse_args() + +speaker = args.speaker +test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker))) + +LVD_list = [] +diversity_list = [] + +for aud in tqdm(test_audios): + base_name = os.path.splitext(aud)[0] + gt_path = get_full_path(aud, speaker, 'val') + _, gt_poses, _ = get_gts(gt_path) + gt_poses = gt_poses[np.newaxis,...] 
+ # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face + for post_fix in args.post_fix: + pred_path = base_name + '_'+post_fix+'.json' + pred_poses = np.array(json.load(open(pred_path))) + # print(pred_poses.shape)#(B, seq_len, 108) + pred_poses = cvt25(pred_poses, gt_poses) + # print(pred_poses.shape)#(B, seq, pose_dim) + + gt_valid_points = hand_points(gt_poses) + pred_valid_points = hand_points(pred_poses) + + lvd = LVD(gt_valid_points, pred_valid_points) + # div = diversity(pred_valid_points) + + LVD_list.append(lvd) + # diversity_list.append(div) + + # gt_velocity = peak_velocity(gt_valid_points, order=2) + # pred_velocity = peak_velocity(pred_valid_points, order=2) + + # gt_consistency = velocity_consistency(gt_velocity, pred_velocity) + # pred_consistency = velocity_consistency(pred_velocity, gt_velocity) + + # gt_consistency_list.append(gt_consistency) + # pred_consistency_list.append(pred_consistency) + +lvd = np.mean(LVD_list) +# diversity_list = np.mean(diversity_list) + +print('LVD:', lvd) +# print("diversity:", diversity_list) \ No newline at end of file diff --git a/evaluation/get_quality_samples.py b/evaluation/get_quality_samples.py new file mode 100644 index 0000000000000000000000000000000000000000..b8ef393cd310aa2e75f871122a62f45dd525e47c --- /dev/null +++ b/evaluation/get_quality_samples.py @@ -0,0 +1,62 @@ +''' +''' +import os +import sys +sys.path.append(os.getcwd()) + +from glob import glob + +from argparse import ArgumentParser +import json + +from evaluation.util import * +from evaluation.metrics import * +from tqdm import tqdm + +parser = ArgumentParser() +parser.add_argument('--speaker', required=True, type=str) +parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str) +args = parser.parse_args() + +speaker = args.speaker +test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker))) + +quality_samples={'gt':[]} +for post_fix in args.post_fix: + quality_samples[post_fix] = [] + +for aud in tqdm(test_audios): + base_name = os.path.splitext(aud)[0] + gt_path = get_full_path(aud, speaker, 'val') + _, gt_poses, _ = get_gts(gt_path) + gt_poses = gt_poses[np.newaxis,...] 
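+ # Added comment: valid_points (see evaluation/util.py) keeps 12 upper-body keypoints plus the
+ # 21+21 hand keypoints, i.e. 54 points * 2 coordinates = 108 values per frame, which is what
+ # the shape assert in util.py checks.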
+ gt_valid_points = valid_points(gt_poses) + # print(gt_valid_points.shape) + quality_samples['gt'].append(gt_valid_points) + + for post_fix in args.post_fix: + pred_path = base_name + '_'+post_fix+'.json' + pred_poses = np.array(json.load(open(pred_path))) + # print(pred_poses.shape)#(B, seq_len, 108) + pred_poses = cvt25(pred_poses, gt_poses) + # print(pred_poses.shape)#(B, seq, pose_dim) + + pred_valid_points = valid_points(pred_poses)[0:1] + quality_samples[post_fix].append(pred_valid_points) + +quality_samples['gt'] = np.concatenate(quality_samples['gt'], axis=1) +for post_fix in args.post_fix: + quality_samples[post_fix] = np.concatenate(quality_samples[post_fix], axis=1) + +print('gt:', quality_samples['gt'].shape) +quality_samples['gt'] = quality_samples['gt'].tolist() +for post_fix in args.post_fix: + print(post_fix, ':', quality_samples[post_fix].shape) + quality_samples[post_fix] = quality_samples[post_fix].tolist() + +save_dir = '../../experiments/' +os.makedirs(save_dir, exist_ok=True) +save_name = os.path.join(save_dir, 'quality_samples_%s.json'%(speaker)) +with open(save_name, 'w') as f: + json.dump(quality_samples, f) + diff --git a/evaluation/metrics.py b/evaluation/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..93dc8fde8b8de57cc2f0f7387afd5de7cb753835 --- /dev/null +++ b/evaluation/metrics.py @@ -0,0 +1,109 @@ +''' +Warning: metrics are for reference only, may have limited significance +''' +import os +import sys +sys.path.append(os.getcwd()) +import numpy as np +import torch + +from data_utils.lower_body import rearrange, symmetry +import torch.nn.functional as F + +def data_driven_baselines(gt_kps): + ''' + gt_kps: T, D + ''' + gt_velocity = np.abs(gt_kps[1:] - gt_kps[:-1]) + + mean= np.mean(gt_velocity, axis=0)[np.newaxis] #(1, D) + mean = np.mean(np.abs(gt_velocity-mean)) + last_step = gt_kps[1] - gt_kps[0] + last_step = last_step[np.newaxis] #(1, D) + last_step = np.mean(np.abs(gt_velocity-last_step)) + return last_step, mean + +def Batch_LVD(gt_kps, pr_kps, symmetrical, weight): + if gt_kps.shape[0] > pr_kps.shape[1]: + length = pr_kps.shape[1] + else: + length = gt_kps.shape[0] + gt_kps = gt_kps[:length] + pr_kps = pr_kps[:, :length] + global symmetry + symmetry = torch.tensor(symmetry).bool() + + if symmetrical: + # rearrange for compute symmetric. ns means non-symmetrical joints, ys means symmetrical joints. 
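+ # Sketch of the symmetric branch below (added comment): joints are reordered with `rearrange`,
+ # split into non-symmetric (ns) and symmetric (ys) groups, and per-frame L2 velocities are taken.
+ # For each symmetric pair, the side with the larger summed velocity (computed separately for the
+ # ground truth and for each prediction) is marked as the moving side via `move_side`, and only that
+ # side's velocity is kept before concatenating with the non-symmetric velocities.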
+ gt_kps = gt_kps[:, rearrange] + ns_gt_kps = gt_kps[:, ~symmetry] + ys_gt_kps = gt_kps[:, symmetry] + ys_gt_kps = ys_gt_kps.reshape(ys_gt_kps.shape[0], -1, 2, 3) + ns_gt_velocity = (ns_gt_kps[1:] - ns_gt_kps[:-1]).norm(p=2, dim=-1) + ys_gt_velocity = (ys_gt_kps[1:] - ys_gt_kps[:-1]).norm(p=2, dim=-1) + left_gt_vel = ys_gt_velocity[:, :, 0].sum(dim=-1) + right_gt_vel = ys_gt_velocity[:, :, 1].sum(dim=-1) + move_side = torch.where(left_gt_vel>right_gt_vel, torch.ones(left_gt_vel.shape).cuda(), torch.zeros(left_gt_vel.shape).cuda()) + ys_gt_velocity = torch.mul(ys_gt_velocity[:, :, 0].transpose(0,1), move_side) + torch.mul(ys_gt_velocity[:, :, 1].transpose(0,1), ~move_side.bool()) + ys_gt_velocity = ys_gt_velocity.transpose(0,1) + gt_velocity = torch.cat([ns_gt_velocity, ys_gt_velocity], dim=1) + + pr_kps = pr_kps[:, :, rearrange] + ns_pr_kps = pr_kps[:, :, ~symmetry] + ys_pr_kps = pr_kps[:, :, symmetry] + ys_pr_kps = ys_pr_kps.reshape(ys_pr_kps.shape[0], ys_pr_kps.shape[1], -1, 2, 3) + ns_pr_velocity = (ns_pr_kps[:, 1:] - ns_pr_kps[:, :-1]).norm(p=2, dim=-1) + ys_pr_velocity = (ys_pr_kps[:, 1:] - ys_pr_kps[:, :-1]).norm(p=2, dim=-1) + left_pr_vel = ys_pr_velocity[:, :, :, 0].sum(dim=-1) + right_pr_vel = ys_pr_velocity[:, :, :, 1].sum(dim=-1) + move_side = torch.where(left_pr_vel > right_pr_vel, torch.ones(left_pr_vel.shape).cuda(), + torch.zeros(left_pr_vel.shape).cuda()) + ys_pr_velocity = torch.mul(ys_pr_velocity[..., 0].permute(2, 0, 1), move_side) + torch.mul( + ys_pr_velocity[..., 1].permute(2, 0, 1), ~move_side.long()) + ys_pr_velocity = ys_pr_velocity.permute(1, 2, 0) + pr_velocity = torch.cat([ns_pr_velocity, ys_pr_velocity], dim=2) + else: + gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1) + pr_velocity = (pr_kps[:, 1:] - pr_kps[:, :-1]).norm(p=2, dim=-1) + + if weight: + w = F.softmax(gt_velocity.sum(dim=1).normal_(), dim=0) + else: + w = 1 / gt_velocity.shape[0] + + v_diff = ((pr_velocity - gt_velocity).abs().sum(dim=-1) * w).sum(dim=-1).mean() + + return v_diff + + +def LVD(gt_kps, pr_kps, symmetrical=False, weight=False): + gt_kps = gt_kps.squeeze() + pr_kps = pr_kps.squeeze() + if len(pr_kps.shape) == 4: + return Batch_LVD(gt_kps, pr_kps, symmetrical, weight) + # length = np.minimum(gt_kps.shape[0], pr_kps.shape[0]) + length = gt_kps.shape[0]-10 + # gt_kps = gt_kps[25:length] + # pr_kps = pr_kps[25:length] #(T, D) + # if pr_kps.shape[0] < gt_kps.shape[0]: + # pr_kps = np.pad(pr_kps, [[0, int(gt_kps.shape[0]-pr_kps.shape[0])], [0, 0]], mode='constant') + + gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1) + pr_velocity = (pr_kps[1:] - pr_kps[:-1]).norm(p=2, dim=-1) + + return (pr_velocity-gt_velocity).abs().sum(dim=-1).mean() + +def diversity(kps): + ''' + kps: bs, seq, dim + ''' + dis_list = [] + #the distance between each pair + for i in range(kps.shape[0]): + for j in range(i+1, kps.shape[0]): + seq_i = kps[i] + seq_j = kps[j] + + dis = np.mean(np.abs(seq_i - seq_j)) + dis_list.append(dis) + return np.mean(dis_list) diff --git a/evaluation/mode_transition.py b/evaluation/mode_transition.py new file mode 100644 index 0000000000000000000000000000000000000000..92cd0e5ecfe688b7a8add932af0303bd8d5ed947 --- /dev/null +++ b/evaluation/mode_transition.py @@ -0,0 +1,60 @@ +import os +import sys +sys.path.append(os.getcwd()) + +from glob import glob + +from argparse import ArgumentParser +import json + +from evaluation.util import * +from evaluation.metrics import * +from tqdm import tqdm + +parser = ArgumentParser() +parser.add_argument('--speaker', required=True, 
type=str) +parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str) +args = parser.parse_args() + +speaker = args.speaker +test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker))) + +precision_list=[] +recall_list=[] +accuracy_list=[] + +for aud in tqdm(test_audios): + base_name = os.path.splitext(aud)[0] + gt_path = get_full_path(aud, speaker, 'val') + _, gt_poses, _ = get_gts(gt_path) + if gt_poses.shape[0] < 50: + continue + gt_poses = gt_poses[np.newaxis,...] + # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face + for post_fix in args.post_fix: + pred_path = base_name + '_'+post_fix+'.json' + pred_poses = np.array(json.load(open(pred_path))) + # print(pred_poses.shape)#(B, seq_len, 108) + pred_poses = cvt25(pred_poses, gt_poses) + # print(pred_poses.shape)#(B, seq, pose_dim) + + gt_valid_points = valid_points(gt_poses) + pred_valid_points = valid_points(pred_poses) + + # print(gt_valid_points.shape, pred_valid_points.shape) + + gt_mode_transition_seq = mode_transition_seq(gt_valid_points, speaker)#(B, N) + pred_mode_transition_seq = mode_transition_seq(pred_valid_points, speaker)#(B, N) + + # baseline = np.random.randint(0, 2, size=pred_mode_transition_seq.shape) + # pred_mode_transition_seq = baseline + precision, recall, accuracy = mode_transition_consistency(pred_mode_transition_seq, gt_mode_transition_seq) + precision_list.append(precision) + recall_list.append(recall) + accuracy_list.append(accuracy) +print(len(precision_list), len(recall_list), len(accuracy_list)) +precision_list = np.mean(precision_list) +recall_list = np.mean(recall_list) +accuracy_list = np.mean(accuracy_list) + +print('precision, recall, accu:', precision_list, recall_list, accuracy_list) diff --git a/evaluation/peak_velocity.py b/evaluation/peak_velocity.py new file mode 100644 index 0000000000000000000000000000000000000000..3842b918375176099cd60a8f9ede50d8920b3e4a --- /dev/null +++ b/evaluation/peak_velocity.py @@ -0,0 +1,65 @@ +import os +import sys +sys.path.append(os.getcwd()) + +from glob import glob + +from argparse import ArgumentParser +import json + +from evaluation.util import * +from evaluation.metrics import * +from tqdm import tqdm + +parser = ArgumentParser() +parser.add_argument('--speaker', required=True, type=str) +parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str) +args = parser.parse_args() + +speaker = args.speaker +test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker))) + +gt_consistency_list=[] +pred_consistency_list=[] + +for aud in tqdm(test_audios): + base_name = os.path.splitext(aud)[0] + gt_path = get_full_path(aud, speaker, 'val') + _, gt_poses, _ = get_gts(gt_path) + gt_poses = gt_poses[np.newaxis,...] 
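+ # Added comment: hand_points (evaluation/util.py) keeps the 7 arm joints plus the 21+21 hand
+ # keypoints; peak_velocity and velocity_consistency (imported from evaluation.metrics) are then
+ # evaluated in both directions, and their distributions are plotted as CDFs and exported to
+ # .xlsx / .npy files below.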
+ # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face + for post_fix in args.post_fix: + pred_path = base_name + '_'+post_fix+'.json' + pred_poses = np.array(json.load(open(pred_path))) + # print(pred_poses.shape)#(B, seq_len, 108) + pred_poses = cvt25(pred_poses, gt_poses) + # print(pred_poses.shape)#(B, seq, pose_dim) + + gt_valid_points = hand_points(gt_poses) + pred_valid_points = hand_points(pred_poses) + + gt_velocity = peak_velocity(gt_valid_points, order=2) + pred_velocity = peak_velocity(pred_valid_points, order=2) + + gt_consistency = velocity_consistency(gt_velocity, pred_velocity) + pred_consistency = velocity_consistency(pred_velocity, gt_velocity) + + gt_consistency_list.append(gt_consistency) + pred_consistency_list.append(pred_consistency) + +gt_consistency_list = np.concatenate(gt_consistency_list) +pred_consistency_list = np.concatenate(pred_consistency_list) + +print(gt_consistency_list.max(), gt_consistency_list.min()) +print(pred_consistency_list.max(), pred_consistency_list.min()) +print(np.mean(gt_consistency_list), np.mean(pred_consistency_list)) +print(np.std(gt_consistency_list), np.std(pred_consistency_list)) + +draw_cdf(gt_consistency_list, save_name='%s_gt.jpg'%(speaker), color='slateblue') +draw_cdf(pred_consistency_list, save_name='%s_pred.jpg'%(speaker), color='lightskyblue') + +to_excel(gt_consistency_list, '%s_gt.xlsx'%(speaker)) +to_excel(pred_consistency_list, '%s_pred.xlsx'%(speaker)) + +np.save('%s_gt.npy'%(speaker), gt_consistency_list) +np.save('%s_pred.npy'%(speaker), pred_consistency_list) \ No newline at end of file diff --git a/evaluation/util.py b/evaluation/util.py new file mode 100644 index 0000000000000000000000000000000000000000..a3c18a0ab89ca812e4a2126bd18ee8e588b7ec43 --- /dev/null +++ b/evaluation/util.py @@ -0,0 +1,148 @@ +import os +from glob import glob +import numpy as np +import json +from matplotlib import pyplot as plt +import pandas as pd +def get_gts(clip): + ''' + clip: abs path to the clip dir + ''' + keypoints_files = sorted(glob(os.path.join(clip, 'keypoints_new/person_1')+'/*.json')) + + upper_body_points = list(np.arange(0, 25)) + poses = [] + confs = [] + neck_to_nose_len = [] + mean_position = [] + for kp_file in keypoints_files: + kp_load = json.load(open(kp_file, 'r'))['people'][0] + posepts = kp_load['pose_keypoints_2d'] + lhandpts = kp_load['hand_left_keypoints_2d'] + rhandpts = kp_load['hand_right_keypoints_2d'] + facepts = kp_load['face_keypoints_2d'] + + neck = np.array(posepts).reshape(-1,3)[1] + nose = np.array(posepts).reshape(-1,3)[0] + x_offset = abs(neck[0]-nose[0]) + y_offset = abs(neck[1]-nose[1]) + neck_to_nose_len.append(y_offset) + mean_position.append([neck[0],neck[1]]) + + keypoints=np.array(posepts+lhandpts+rhandpts+facepts).reshape(-1,3)[:,:2] + + upper_body = keypoints[upper_body_points, :] + hand_points = keypoints[25:, :] + keypoints = np.vstack([upper_body, hand_points]) + + poses.append(keypoints) + + if len(neck_to_nose_len) > 0: + scale_factor = np.mean(neck_to_nose_len) + else: + raise ValueError(clip) + mean_position = np.mean(np.array(mean_position), axis=0) + + unlocalized_poses = np.array(poses).copy() + localized_poses = [] + for i in range(len(poses)): + keypoints = poses[i] + neck = keypoints[1].copy() + + keypoints[:, 0] = (keypoints[:, 0] - neck[0]) / scale_factor + keypoints[:, 1] = (keypoints[:, 1] - neck[1]) / scale_factor + localized_poses.append(keypoints.reshape(-1)) + + localized_poses=np.array(localized_poses) + return unlocalized_poses, localized_poses, 
(scale_factor, mean_position) + +def get_full_path(wav_name, speaker, split): + ''' + get clip path from aud file + ''' + wav_name = os.path.basename(wav_name) + wav_name = os.path.splitext(wav_name)[0] + clip_name, vid_name = wav_name[:10], wav_name[11:] + + full_path = os.path.join('pose_dataset/videos/', speaker, 'clips', vid_name, 'images/half', split, clip_name) + + assert os.path.isdir(full_path), full_path + + return full_path + +def smooth(res): + ''' + res: (B, seq_len, pose_dim) + ''' + window = [res[:, 7, :], res[:, 8, :], res[:, 9, :], res[:, 10, :], res[:, 11, :], res[:, 12, :]] + w_size=7 + for i in range(10, res.shape[1]-3): + window.append(res[:, i+3, :]) + if len(window) > w_size: + window = window[1:] + + if (i%25) in [22, 23, 24, 0, 1, 2, 3]: + res[:, i, :] = np.mean(window, axis=1) + + return res + +def cvt25(pred_poses, gt_poses=None): + ''' + gt_poses: (1, seq_len, 270), 135 *2 + pred_poses: (B, seq_len, 108), 54 * 2 + ''' + if gt_poses is None: + gt_poses = np.zeros_like(pred_poses) + else: + gt_poses = gt_poses.repeat(pred_poses.shape[0], axis=0) + + length = min(pred_poses.shape[1], gt_poses.shape[1]) + pred_poses = pred_poses[:, :length, :] + gt_poses = gt_poses[:, :length, :] + gt_poses = gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1, 2) + pred_poses = pred_poses.reshape(pred_poses.shape[0], pred_poses.shape[1], -1, 2) + + gt_poses[:, :, [1, 2, 3, 4, 5, 6, 7], :] = pred_poses[:, :, 1:8, :] + gt_poses[:, :, 25:25+21+21, :] = pred_poses[:, :, 12:, :] + + return gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1) + +def hand_points(seq): + ''' + seq: (B, seq_len, 135*2) + hands only + ''' + hand_idx = [1, 2, 3, 4,5 ,6,7] + list(range(25, 25+21+21)) + seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2) + return seq[:, :, hand_idx, :].reshape(seq.shape[0], seq.shape[1], -1) + +def valid_points(seq): + ''' + hands with some head points + ''' + valid_idx = [0, 1, 2, 3, 4,5 ,6,7, 8, 9, 10, 11] + list(range(25, 25+21+21)) + seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2) + + seq = seq[:, :, valid_idx, :].reshape(seq.shape[0], seq.shape[1], -1) + assert seq.shape[-1] == 108, seq.shape + return seq + +def draw_cdf(seq, save_name='cdf.jpg', color='slatebule'): + plt.figure() + plt.hist(seq, bins=100, range=(0, 100), color=color) + plt.savefig(save_name) + +def to_excel(seq, save_name='res.xlsx'): + ''' + seq: (T) + ''' + df = pd.DataFrame(seq) + writer = pd.ExcelWriter(save_name) + df.to_excel(writer, 'sheet1') + writer.save() + writer.close() + + +if __name__ == '__main__': + random_data = np.random.randint(0, 10, 100) + draw_cdf(random_data) \ No newline at end of file diff --git a/losses/__init__.py b/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..37fdea0eb6c3190e7001567cfe17dc296bf811e8 --- /dev/null +++ b/losses/__init__.py @@ -0,0 +1 @@ +from .losses import * \ No newline at end of file diff --git a/losses/__pycache__/__init__.cpython-37.pyc b/losses/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a099ba9ea1527c7efdbd683bf96e5cd43dd8d932 Binary files /dev/null and b/losses/__pycache__/__init__.cpython-37.pyc differ diff --git a/losses/__pycache__/losses.cpython-37.pyc b/losses/__pycache__/losses.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff2ede0fb38bc991969ee097ef54cc16cb499989 Binary files /dev/null and b/losses/__pycache__/losses.cpython-37.pyc differ diff --git a/losses/losses.py b/losses/losses.py new 
file mode 100644 index 0000000000000000000000000000000000000000..c4e433ca256b3ed54b77fbc6ca8751aa32959153 --- /dev/null +++ b/losses/losses.py @@ -0,0 +1,91 @@ +import os +import sys + +sys.path.append(os.getcwd()) + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +class KeypointLoss(nn.Module): + def __init__(self): + super(KeypointLoss, self).__init__() + + def forward(self, pred_seq, gt_seq, gt_conf=None): + #pred_seq: (B, C, T) + if gt_conf is not None: + gt_conf = gt_conf >= 0.01 + return F.mse_loss(pred_seq[gt_conf], gt_seq[gt_conf], reduction='mean') + else: + return F.mse_loss(pred_seq, gt_seq) + + +class KLLoss(nn.Module): + def __init__(self, kl_tolerance): + super(KLLoss, self).__init__() + self.kl_tolerance = kl_tolerance + + def forward(self, mu, var, mul=1): + kl_tolerance = self.kl_tolerance * mul * var.shape[1] / 64 + kld_loss = -0.5 * torch.sum(1 + var - mu**2 - var.exp(), dim=1) + # kld_loss = -0.5 * torch.sum(1 + (var-1) - (mu) ** 2 - (var-1).exp(), dim=1) + if self.kl_tolerance is not None: + # above_line = kld_loss[kld_loss > self.kl_tolerance] + # if len(above_line) > 0: + # kld_loss = torch.mean(kld_loss) + # else: + # kld_loss = 0 + kld_loss = torch.where(kld_loss > kl_tolerance, kld_loss, torch.tensor(kl_tolerance, device='cuda')) + # else: + kld_loss = torch.mean(kld_loss) + return kld_loss + + +class L2KLLoss(nn.Module): + def __init__(self, kl_tolerance): + super(L2KLLoss, self).__init__() + self.kl_tolerance = kl_tolerance + + def forward(self, x): + # TODO: check + kld_loss = torch.sum(x ** 2, dim=1) + if self.kl_tolerance is not None: + above_line = kld_loss[kld_loss > self.kl_tolerance] + if len(above_line) > 0: + kld_loss = torch.mean(kld_loss) + else: + kld_loss = 0 + else: + kld_loss = torch.mean(kld_loss) + return kld_loss + +class L2RegLoss(nn.Module): + def __init__(self): + super(L2RegLoss, self).__init__() + + def forward(self, x): + #TODO: check + return torch.sum(x**2) + + +class L2Loss(nn.Module): + def __init__(self): + super(L2Loss, self).__init__() + + def forward(self, x): + # TODO: check + return torch.sum(x ** 2) + + +class AudioLoss(nn.Module): + def __init__(self): + super(AudioLoss, self).__init__() + + def forward(self, dynamics, gt_poses): + #pay attention, normalized + mean = torch.mean(gt_poses, dim=-1).unsqueeze(-1) + gt = gt_poses - mean + return F.mse_loss(dynamics, gt) + +L1Loss = nn.L1Loss \ No newline at end of file diff --git a/nets/LS3DCG.py b/nets/LS3DCG.py new file mode 100644 index 0000000000000000000000000000000000000000..b1b2728385c1c196501c73822644a461ead309ec --- /dev/null +++ b/nets/LS3DCG.py @@ -0,0 +1,414 @@ +''' +not exactly the same as the official repo but the results are good +''' +import sys +import os + +from data_utils.lower_body import c_index_3d, c_index_6d + +sys.path.append(os.getcwd()) + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import math + +from nets.base import TrainWrapperBaseClass +from nets.layers import SeqEncoder1D +from losses import KeypointLoss, L1Loss, KLLoss +from data_utils.utils import get_melspec, get_mfcc_psf, get_mfcc_ta +from nets.utils import denormalize + +class Conv1d_tf(nn.Conv1d): + """ + Conv1d with the padding behavior from TF + modified from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py + """ + + def __init__(self, *args, **kwargs): + super(Conv1d_tf, 
self).__init__(*args, **kwargs) + self.padding = kwargs.get("padding", "same") + + def _compute_padding(self, input, dim): + input_size = input.size(dim + 2) + filter_size = self.weight.size(dim + 2) + effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1 + out_size = (input_size + self.stride[dim] - 1) // self.stride[dim] + total_padding = max( + 0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size + ) + additional_padding = int(total_padding % 2 != 0) + + return additional_padding, total_padding + + def forward(self, input): + if self.padding == "VALID": + return F.conv1d( + input, + self.weight, + self.bias, + self.stride, + padding=0, + dilation=self.dilation, + groups=self.groups, + ) + rows_odd, padding_rows = self._compute_padding(input, dim=0) + if rows_odd: + input = F.pad(input, [0, rows_odd]) + + return F.conv1d( + input, + self.weight, + self.bias, + self.stride, + padding=(padding_rows // 2), + dilation=self.dilation, + groups=self.groups, + ) + + +def ConvNormRelu(in_channels, out_channels, type='1d', downsample=False, k=None, s=None, norm='bn', padding='valid'): + if k is None and s is None: + if not downsample: + k = 3 + s = 1 + else: + k = 4 + s = 2 + + if type == '1d': + conv_block = Conv1d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding) + if norm == 'bn': + norm_block = nn.BatchNorm1d(out_channels) + elif norm == 'ln': + norm_block = nn.LayerNorm(out_channels) + elif type == '2d': + conv_block = Conv2d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding) + norm_block = nn.BatchNorm2d(out_channels) + else: + assert False + + return nn.Sequential( + conv_block, + norm_block, + nn.LeakyReLU(0.2, True) + ) + +class Decoder(nn.Module): + def __init__(self, in_ch, out_ch): + super(Decoder, self).__init__() + self.up1 = nn.Sequential( + ConvNormRelu(in_ch // 2 + in_ch, in_ch // 2), + ConvNormRelu(in_ch // 2, in_ch // 2), + nn.Upsample(scale_factor=2, mode='nearest') + ) + self.up2 = nn.Sequential( + ConvNormRelu(in_ch // 4 + in_ch // 2, in_ch // 4), + ConvNormRelu(in_ch // 4, in_ch // 4), + nn.Upsample(scale_factor=2, mode='nearest') + ) + self.up3 = nn.Sequential( + ConvNormRelu(in_ch // 8 + in_ch // 4, in_ch // 8), + ConvNormRelu(in_ch // 8, in_ch // 8), + nn.Conv1d(in_ch // 8, out_ch, 1, 1) + ) + + def forward(self, x, x1, x2, x3): + x = F.interpolate(x, x3.shape[2]) + x = torch.cat([x, x3], dim=1) + x = self.up1(x) + x = F.interpolate(x, x2.shape[2]) + x = torch.cat([x, x2], dim=1) + x = self.up2(x) + x = F.interpolate(x, x1.shape[2]) + x = torch.cat([x, x1], dim=1) + x = self.up3(x) + return x + + +class EncoderDecoder(nn.Module): + def __init__(self, n_frames, each_dim): + super().__init__() + self.n_frames = n_frames + + self.down1 = nn.Sequential( + ConvNormRelu(64, 64, '1d', False), + ConvNormRelu(64, 128, '1d', False), + ) + self.down2 = nn.Sequential( + ConvNormRelu(128, 128, '1d', False), + ConvNormRelu(128, 256, '1d', False), + ) + self.down3 = nn.Sequential( + ConvNormRelu(256, 256, '1d', False), + ConvNormRelu(256, 512, '1d', False), + ) + self.down4 = nn.Sequential( + ConvNormRelu(512, 512, '1d', False), + ConvNormRelu(512, 1024, '1d', False), + ) + + self.down = nn.MaxPool1d(kernel_size=2) + self.up = nn.Upsample(scale_factor=2, mode='nearest') + + self.face_decoder = Decoder(1024, each_dim[0] + each_dim[3]) + self.body_decoder = Decoder(1024, each_dim[1]) + self.hand_decoder = Decoder(1024, each_dim[2]) + + def forward(self, spectrogram, time_steps=None): + if time_steps is None: + 
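+ # Shape note (added comment): the audio feature is passed through down1..down4 with max-pooling
+ # in between, and the three decoders (face, body, hand) all consume the bottleneck together with
+ # the skip features x1, x2, x3, U-Net style.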
time_steps = self.n_frames + + x1 = self.down1(spectrogram) + x = self.down(x1) + x2 = self.down2(x) + x = self.down(x2) + x3 = self.down3(x) + x = self.down(x3) + x = self.down4(x) + x = self.up(x) + + face = self.face_decoder(x, x1, x2, x3) + body = self.body_decoder(x, x1, x2, x3) + hand = self.hand_decoder(x, x1, x2, x3) + + return face, body, hand + + +class Generator(nn.Module): + def __init__(self, + each_dim, + training=False, + device=None + ): + super().__init__() + + self.training = training + self.device = device + + self.encoderdecoder = EncoderDecoder(15, each_dim) + + def forward(self, in_spec, time_steps=None): + if time_steps is not None: + self.gen_length = time_steps + + face, body, hand = self.encoderdecoder(in_spec) + out = torch.cat([face, body, hand], dim=1) + out = out.transpose(1, 2) + + return out + + +class Discriminator(nn.Module): + def __init__(self, input_dim): + super().__init__() + self.net = nn.Sequential( + ConvNormRelu(input_dim, 128, '1d'), + ConvNormRelu(128, 256, '1d'), + nn.MaxPool1d(kernel_size=2), + ConvNormRelu(256, 256, '1d'), + ConvNormRelu(256, 512, '1d'), + nn.MaxPool1d(kernel_size=2), + ConvNormRelu(512, 512, '1d'), + ConvNormRelu(512, 1024, '1d'), + nn.MaxPool1d(kernel_size=2), + nn.Conv1d(1024, 1, 1, 1), + nn.Sigmoid() + ) + + def forward(self, x): + x = x.transpose(1, 2) + + out = self.net(x) + return out + + +class TrainWrapper(TrainWrapperBaseClass): + def __init__(self, args, config) -> None: + self.args = args + self.config = config + self.device = torch.device(self.args.gpu) + self.global_step = 0 + self.convert_to_6d = self.config.Data.pose.convert_to_6d + self.init_params() + + self.generator = Generator( + each_dim=self.each_dim, + training=not self.args.infer, + device=self.device, + ).to(self.device) + self.discriminator = Discriminator( + input_dim=self.each_dim[1] + self.each_dim[2] + 64 + ).to(self.device) + if self.convert_to_6d: + self.c_index = c_index_6d + else: + self.c_index = c_index_3d + self.MSELoss = KeypointLoss().to(self.device) + self.L1Loss = L1Loss().to(self.device) + super().__init__(args, config) + + def init_params(self): + scale = 1 + + global_orient = round(0 * scale) + leye_pose = reye_pose = round(0 * scale) + jaw_pose = round(3 * scale) + body_pose = round((63 - 24) * scale) + left_hand_pose = right_hand_pose = round(45 * scale) + + expression = 100 + + b_j = 0 + jaw_dim = jaw_pose + b_e = b_j + jaw_dim + eye_dim = leye_pose + reye_pose + b_b = b_e + eye_dim + body_dim = global_orient + body_pose + b_h = b_b + body_dim + hand_dim = left_hand_pose + right_hand_pose + b_f = b_h + hand_dim + face_dim = expression + + self.dim_list = [b_j, b_e, b_b, b_h, b_f] + self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim + self.pose = int(self.full_dim / round(3 * scale)) + self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim] + + def __call__(self, bat): + assert (not self.args.infer), "infer mode" + self.global_step += 1 + + loss_dict = {} + + aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32) + expression = bat['expression'].to(self.device).to(torch.float32) + jaw = poses[:, :3, :] + poses = poses[:, self.c_index, :] + + pred = self.generator(in_spec=aud) + + D_loss, D_loss_dict = self.get_loss( + pred_poses=pred.detach(), + gt_poses=poses, + aud=aud, + mode='training_D', + ) + + self.discriminator_optimizer.zero_grad() + D_loss.backward() + self.discriminator_optimizer.step() + + G_loss, G_loss_dict = self.get_loss( + pred_poses=pred, + 
gt_poses=poses, + aud=aud, + expression=expression, + jaw=jaw, + mode='training_G', + ) + self.generator_optimizer.zero_grad() + G_loss.backward() + self.generator_optimizer.step() + + total_loss = None + loss_dict = {} + for key in list(D_loss_dict.keys()) + list(G_loss_dict.keys()): + loss_dict[key] = G_loss_dict.get(key, 0) + D_loss_dict.get(key, 0) + + return total_loss, loss_dict + + def get_loss(self, + pred_poses, + gt_poses, + aud=None, + jaw=None, + expression=None, + mode='training_G', + ): + loss_dict = {} + aud = aud.transpose(1, 2) + gt_poses = gt_poses.transpose(1, 2) + gt_aud = torch.cat([gt_poses, aud], dim=2) + pred_aud = torch.cat([pred_poses[:, :, 103:], aud], dim=2) + + if mode == 'training_D': + dis_real = self.discriminator(gt_aud) + dis_fake = self.discriminator(pred_aud) + dis_error = self.MSELoss(torch.ones_like(dis_real).to(self.device), dis_real) + self.MSELoss( + torch.zeros_like(dis_fake).to(self.device), dis_fake) + loss_dict['dis'] = dis_error + + return dis_error, loss_dict + elif mode == 'training_G': + jaw_loss = self.L1Loss(pred_poses[:, :, :3], jaw.transpose(1, 2)) + face_loss = self.MSELoss(pred_poses[:, :, 3:103], expression.transpose(1, 2)) + body_loss = self.L1Loss(pred_poses[:, :, 103:142], gt_poses[:, :, :39]) + hand_loss = self.L1Loss(pred_poses[:, :, 142:], gt_poses[:, :, 39:]) + l1_loss = jaw_loss + face_loss + body_loss + hand_loss + + dis_output = self.discriminator(pred_aud) + gen_error = self.MSELoss(torch.ones_like(dis_output).to(self.device), dis_output) + gen_loss = self.config.Train.weights.keypoint_loss_weight * l1_loss + self.config.Train.weights.gan_loss_weight * gen_error + + loss_dict['gen'] = gen_error + loss_dict['jaw_loss'] = jaw_loss + loss_dict['face_loss'] = face_loss + loss_dict['body_loss'] = body_loss + loss_dict['hand_loss'] = hand_loss + return gen_loss, loss_dict + else: + raise ValueError(mode) + + def infer_on_audio(self, aud_fn, fps=30, initial_pose=None, norm_stats=None, id=None, B=1, **kwargs): + output = [] + assert self.args.infer, "train mode" + self.generator.eval() + + if self.config.Data.pose.normalization: + assert norm_stats is not None + data_mean = norm_stats[0] + data_std = norm_stats[1] + + pre_length = self.config.Data.pose.pre_pose_length + generate_length = self.config.Data.pose.generate_length + # assert pre_length == initial_pose.shape[-1] + # pre_poses = initial_pose.permute(0, 2, 1).to(self.device).to(torch.float32) + # B = pre_poses.shape[0] + + aud_feat = get_mfcc_ta(aud_fn, sr=22000, fps=fps, smlpx=True, type='mfcc').transpose(1, 0) + num_poses_to_generate = aud_feat.shape[-1] + aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0) + aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device) + + with torch.no_grad(): + pred_poses = self.generator(aud_feat) + pred_poses = pred_poses.cpu().numpy() + output = pred_poses.squeeze() + + return output + + def generate(self, aud, id): + self.generator.eval() + pred_poses = self.generator(aud) + return pred_poses + + +if __name__ == '__main__': + from trainer.options import parse_args + + parser = parse_args() + args = parser.parse_args( + ['--exp_name', '0', '--data_root', '0', '--speakers', '0', '--pre_pose_length', '4', '--generate_length', '64', + '--infer']) + + generator = TrainWrapper(args) + + aud_fn = '../sample_audio/jon.wav' + initial_pose = torch.randn(64, 108, 4) + norm_stats = (np.random.randn(108), np.random.randn(108)) + output = generator.infer_on_audio(aud_fn, initial_pose, norm_stats) + + print(output.shape) diff --git 
a/nets/__init__.py b/nets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0669d82d7506b314fa3fccd7dd412445fa0b37e1 --- /dev/null +++ b/nets/__init__.py @@ -0,0 +1,8 @@ +from .smplx_face import TrainWrapper as s2g_face +from .smplx_body_vq import TrainWrapper as s2g_body_vq +from .smplx_body_pixel import TrainWrapper as s2g_body_pixel +from .body_ae import TrainWrapper as s2g_body_ae +from .LS3DCG import TrainWrapper as LS3DCG +from .base import TrainWrapperBaseClass + +from .utils import normalize, denormalize \ No newline at end of file diff --git a/nets/__pycache__/LS3DCG.cpython-37.pyc b/nets/__pycache__/LS3DCG.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b805df22b6a5908c15e9a8d45380c287c864df79 Binary files /dev/null and b/nets/__pycache__/LS3DCG.cpython-37.pyc differ diff --git a/nets/__pycache__/__init__.cpython-37.pyc b/nets/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd77b05f4e2503c9766d2581325594482c8ab30c Binary files /dev/null and b/nets/__pycache__/__init__.cpython-37.pyc differ diff --git a/nets/__pycache__/base.cpython-37.pyc b/nets/__pycache__/base.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f94093ffb32eb5d7a08ee7ffc694f1dd667a3134 Binary files /dev/null and b/nets/__pycache__/base.cpython-37.pyc differ diff --git a/nets/__pycache__/body_ae.cpython-37.pyc b/nets/__pycache__/body_ae.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a42d9b5d9ae7795677e92931304111ded3f4c68d Binary files /dev/null and b/nets/__pycache__/body_ae.cpython-37.pyc differ diff --git a/nets/__pycache__/init_model.cpython-37.pyc b/nets/__pycache__/init_model.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..089d2e7ec5d5bbd73d1fff7cef7c792c378de03a Binary files /dev/null and b/nets/__pycache__/init_model.cpython-37.pyc differ diff --git a/nets/__pycache__/layers.cpython-37.pyc b/nets/__pycache__/layers.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0df91833195fee22a25b90b578cac16c53bf11d Binary files /dev/null and b/nets/__pycache__/layers.cpython-37.pyc differ diff --git a/nets/__pycache__/smplx_body_pixel.cpython-37.pyc b/nets/__pycache__/smplx_body_pixel.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9be6845fada6e59cc304238f4c9e847ea0255a8 Binary files /dev/null and b/nets/__pycache__/smplx_body_pixel.cpython-37.pyc differ diff --git a/nets/__pycache__/smplx_body_vq.cpython-37.pyc b/nets/__pycache__/smplx_body_vq.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..480f25b32e812601156f4ab65d5d8e1ad6915704 Binary files /dev/null and b/nets/__pycache__/smplx_body_vq.cpython-37.pyc differ diff --git a/nets/__pycache__/smplx_face.cpython-37.pyc b/nets/__pycache__/smplx_face.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12138283f09af4e1e631e9b8704ab8a8b273f3de Binary files /dev/null and b/nets/__pycache__/smplx_face.cpython-37.pyc differ diff --git a/nets/__pycache__/utils.cpython-37.pyc b/nets/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8092271dc39faad72b0c77283492a2f75a0067c8 Binary files /dev/null and b/nets/__pycache__/utils.cpython-37.pyc differ diff --git a/nets/base.py b/nets/base.py new file mode 100644 index 
0000000000000000000000000000000000000000..08c07caa27ba642dd5a48cebfcf51e4e79edd574 --- /dev/null +++ b/nets/base.py @@ -0,0 +1,89 @@ +import torch +import torch.nn as nn +import torch.optim as optim + +class TrainWrapperBaseClass(): + def __init__(self, args, config) -> None: + self.init_optimizer() + + def init_optimizer(self) -> None: + print('using Adam') + self.generator_optimizer = optim.Adam( + self.generator.parameters(), + lr = self.config.Train.learning_rate.generator_learning_rate, + betas=[0.9, 0.999] + ) + if self.discriminator is not None: + self.discriminator_optimizer = optim.Adam( + self.discriminator.parameters(), + lr = self.config.Train.learning_rate.discriminator_learning_rate, + betas=[0.9, 0.999] + ) + + def __call__(self, bat): + raise NotImplementedError + + def get_loss(self, **kwargs): + raise NotImplementedError + + def state_dict(self): + model_state = { + 'generator': self.generator.state_dict(), + 'generator_optim': self.generator_optimizer.state_dict(), + 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None, + 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None + } + return model_state + + def parameters(self): + return self.generator.parameters() + + def load_state_dict(self, state_dict): + if 'generator' in state_dict: + self.generator.load_state_dict(state_dict['generator']) + else: + self.generator.load_state_dict(state_dict) + + if 'generator_optim' in state_dict and self.generator_optimizer is not None: + self.generator_optimizer.load_state_dict(state_dict['generator_optim']) + + if self.discriminator is not None: + self.discriminator.load_state_dict(state_dict['discriminator']) + + if 'discriminator_optim' in state_dict and self.discriminator_optimizer is not None: + self.discriminator_optimizer.load_state_dict(state_dict['discriminator_optim']) + + def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, **kwargs): + raise NotImplementedError + + def init_params(self): + if self.config.Data.pose.convert_to_6d: + scale = 2 + else: + scale = 1 + + global_orient = round(0 * scale) + leye_pose = reye_pose = round(0 * scale) + jaw_pose = round(0 * scale) + body_pose = round((63 - 24) * scale) + left_hand_pose = right_hand_pose = round(45 * scale) + if self.expression: + expression = 100 + else: + expression = 0 + + b_j = 0 + jaw_dim = jaw_pose + b_e = b_j + jaw_dim + eye_dim = leye_pose + reye_pose + b_b = b_e + eye_dim + body_dim = global_orient + body_pose + b_h = b_b + body_dim + hand_dim = left_hand_pose + right_hand_pose + b_f = b_h + hand_dim + face_dim = expression + + self.dim_list = [b_j, b_e, b_b, b_h, b_f] + self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim + self.pose = int(self.full_dim / round(3 * scale)) + self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim] \ No newline at end of file diff --git a/nets/body_ae.py b/nets/body_ae.py new file mode 100644 index 0000000000000000000000000000000000000000..3a9f8bc0ee92f8410da71d711bb19dbcc254d1af --- /dev/null +++ b/nets/body_ae.py @@ -0,0 +1,152 @@ +import os +import sys + +sys.path.append(os.getcwd()) + +from nets.base import TrainWrapperBaseClass +from nets.spg.s2glayers import Discriminator as D_S2G +from nets.spg.vqvae_1d import AE as s2g_body +import torch +import torch.optim as optim +import torch.nn.functional as F + +from data_utils.lower_body import c_index, c_index_3d, c_index_6d + + +def separate_aa(aa): + aa = aa[:, :, :].reshape(aa.shape[0], aa.shape[1], -1, 
5) + axis = F.normalize(aa[:, :, :, :3], dim=-1) + angle = F.normalize(aa[:, :, :, 3:5], dim=-1) + return axis, angle + + +class TrainWrapper(TrainWrapperBaseClass): + ''' + a wrapper receving a batch from data_utils and calculate loss + ''' + + def __init__(self, args, config): + self.args = args + self.config = config + self.device = torch.device(self.args.gpu) + self.global_step = 0 + + self.gan = False + self.convert_to_6d = self.config.Data.pose.convert_to_6d + self.preleng = self.config.Data.pose.pre_pose_length + self.expression = self.config.Data.pose.expression + self.epoch = 0 + self.init_params() + self.num_classes = 4 + self.g = s2g_body(self.each_dim[1] + self.each_dim[2], embedding_dim=64, num_embeddings=0, + num_hiddens=1024, num_residual_layers=2, num_residual_hiddens=512).to(self.device) + if self.gan: + self.discriminator = D_S2G( + pose_dim=110 + 64, pose=self.pose + ).to(self.device) + else: + self.discriminator = None + + if self.convert_to_6d: + self.c_index = c_index_6d + else: + self.c_index = c_index_3d + + super().__init__(args, config) + + def init_optimizer(self): + + self.g_optimizer = optim.Adam( + self.g.parameters(), + lr=self.config.Train.learning_rate.generator_learning_rate, + betas=[0.9, 0.999] + ) + + def state_dict(self): + model_state = { + 'g': self.g.state_dict(), + 'g_optim': self.g_optimizer.state_dict(), + 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None, + 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None + } + return model_state + + + def __call__(self, bat): + # assert (not self.args.infer), "infer mode" + self.global_step += 1 + + total_loss = None + loss_dict = {} + + aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32) + + # id = bat['speaker'].to(self.device) - 20 + # id = F.one_hot(id, self.num_classes) + + poses = poses[:, self.c_index, :] + gt_poses = poses[:, :, self.preleng:].permute(0, 2, 1) + + loss = 0 + loss_dict, loss = self.vq_train(gt_poses[:, :], 'g', self.g, loss_dict, loss) + + return total_loss, loss_dict + + def vq_train(self, gt, name, model, dict, total_loss, pre=None): + x_recon = model(gt_poses=gt, pre_state=pre) + loss, loss_dict = self.get_loss(pred_poses=x_recon, gt_poses=gt, pre=pre) + # total_loss = total_loss + loss + + if name == 'g': + optimizer_name = 'g_optimizer' + + optimizer = getattr(self, optimizer_name) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + for key in list(loss_dict.keys()): + dict[name + key] = loss_dict.get(key, 0).item() + return dict, total_loss + + def get_loss(self, + pred_poses, + gt_poses, + pre=None + ): + loss_dict = {} + + + rec_loss = torch.mean(torch.abs(pred_poses - gt_poses)) + v_pr = pred_poses[:, 1:] - pred_poses[:, :-1] + v_gt = gt_poses[:, 1:] - gt_poses[:, :-1] + velocity_loss = torch.mean(torch.abs(v_pr - v_gt)) + + if pre is None: + f0_vel = 0 + else: + v0_pr = pred_poses[:, 0] - pre[:, -1] + v0_gt = gt_poses[:, 0] - pre[:, -1] + f0_vel = torch.mean(torch.abs(v0_pr - v0_gt)) + + gen_loss = rec_loss + velocity_loss + f0_vel + + loss_dict['rec_loss'] = rec_loss + loss_dict['velocity_loss'] = velocity_loss + # loss_dict['e_q_loss'] = e_q_loss + if pre is not None: + loss_dict['f0_vel'] = f0_vel + + return gen_loss, loss_dict + + def load_state_dict(self, state_dict): + self.g.load_state_dict(state_dict['g']) + + def extract(self, x): + self.g.eval() + if x.shape[2] > self.full_dim: + if x.shape[2] == 
239: + x = x[:, :, 102:] + x = x[:, :, self.c_index] + feat = self.g.encode(x) + return feat.transpose(1, 2), x diff --git a/nets/init_model.py b/nets/init_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a849d9e49afc54048979ebc12c743c4a8b8620a7 --- /dev/null +++ b/nets/init_model.py @@ -0,0 +1,35 @@ +from nets import * + + +def init_model(model_name, args, config): + + if model_name == 's2g_face': + generator = s2g_face( + args, + config, + ) + elif model_name == 's2g_body_vq': + generator = s2g_body_vq( + args, + config, + ) + elif model_name == 's2g_body_pixel': + generator = s2g_body_pixel( + args, + config, + ) + elif model_name == 's2g_body_ae': + generator = s2g_body_ae( + args, + config, + ) + elif model_name == 's2g_LS3DCG': + generator = LS3DCG( + args, + config, + ) + else: + raise ValueError + return generator + + diff --git a/nets/layers.py b/nets/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..79251b42b6e0fe839ec04dc38472ef36165208ac --- /dev/null +++ b/nets/layers.py @@ -0,0 +1,1052 @@ +import os +import sys + +sys.path.append(os.getcwd()) + +import torch +import torch.nn as nn +import numpy as np + + +# TODO: be aware of the actual netork structures + +def get_log(x): + log = 0 + while x > 1: + if x % 2 == 0: + x = x // 2 + log += 1 + else: + raise ValueError('x is not a power of 2') + + return log + + +class ConvNormRelu(nn.Module): + ''' + (B,C_in,H,W) -> (B, C_out, H, W) + there exist some kernel size that makes the result is not H/s + #TODO: there might some problems with residual + ''' + + def __init__(self, + in_channels, + out_channels, + type='1d', + leaky=False, + downsample=False, + kernel_size=None, + stride=None, + padding=None, + p=0, + groups=1, + residual=False, + norm='bn'): + ''' + conv-bn-relu + ''' + super(ConvNormRelu, self).__init__() + self.residual = residual + self.norm_type = norm + # kernel_size = k + # stride = s + + if kernel_size is None and stride is None: + if not downsample: + kernel_size = 3 + stride = 1 + else: + kernel_size = 4 + stride = 2 + + if padding is None: + if isinstance(kernel_size, int) and isinstance(stride, tuple): + padding = tuple(int((kernel_size - st) / 2) for st in stride) + elif isinstance(kernel_size, tuple) and isinstance(stride, int): + padding = tuple(int((ks - stride) / 2) for ks in kernel_size) + elif isinstance(kernel_size, tuple) and isinstance(stride, tuple): + padding = tuple(int((ks - st) / 2) for ks, st in zip(kernel_size, stride)) + else: + padding = int((kernel_size - stride) / 2) + + if self.residual: + if downsample: + if type == '1d': + self.residual_layer = nn.Sequential( + nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding + ) + ) + elif type == '2d': + self.residual_layer = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding + ) + ) + else: + if in_channels == out_channels: + self.residual_layer = nn.Identity() + else: + if type == '1d': + self.residual_layer = nn.Sequential( + nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding + ) + ) + elif type == '2d': + self.residual_layer = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding + ) + ) + + in_channels = in_channels * groups + out_channels = 
out_channels * groups + if type == '1d': + self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + groups=groups) + self.norm = nn.BatchNorm1d(out_channels) + self.dropout = nn.Dropout(p=p) + elif type == '2d': + self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + groups=groups) + self.norm = nn.BatchNorm2d(out_channels) + self.dropout = nn.Dropout2d(p=p) + if norm == 'gn': + self.norm = nn.GroupNorm(2, out_channels) + elif norm == 'ln': + self.norm = nn.LayerNorm(out_channels) + if leaky: + self.relu = nn.LeakyReLU(negative_slope=0.2) + else: + self.relu = nn.ReLU() + + def forward(self, x, **kwargs): + if self.norm_type == 'ln': + out = self.dropout(self.conv(x)) + out = self.norm(out.transpose(1,2)).transpose(1,2) + else: + out = self.norm(self.dropout(self.conv(x))) + if self.residual: + residual = self.residual_layer(x) + out += residual + return self.relu(out) + + +class UNet1D(nn.Module): + def __init__(self, + input_channels, + output_channels, + max_depth=5, + kernel_size=None, + stride=None, + p=0, + groups=1): + super(UNet1D, self).__init__() + self.pre_downsampling_conv = nn.ModuleList([]) + self.conv1 = nn.ModuleList([]) + self.conv2 = nn.ModuleList([]) + self.upconv = nn.Upsample(scale_factor=2, mode='nearest') + self.max_depth = max_depth + self.groups = groups + + self.pre_downsampling_conv.append(ConvNormRelu(input_channels, output_channels, + type='1d', leaky=True, downsample=False, + kernel_size=kernel_size, stride=stride, p=p, groups=groups)) + self.pre_downsampling_conv.append(ConvNormRelu(output_channels, output_channels, + type='1d', leaky=True, downsample=False, + kernel_size=kernel_size, stride=stride, p=p, groups=groups)) + + for i in range(self.max_depth): + self.conv1.append(ConvNormRelu(output_channels, output_channels, + type='1d', leaky=True, downsample=True, + kernel_size=kernel_size, stride=stride, p=p, groups=groups)) + + for i in range(self.max_depth): + self.conv2.append(ConvNormRelu(output_channels, output_channels, + type='1d', leaky=True, downsample=False, + kernel_size=kernel_size, stride=stride, p=p, groups=groups)) + + def forward(self, x): + + input_size = x.shape[-1] + + assert get_log( + input_size) >= self.max_depth, 'num_frames must be a power of 2 and its power must be greater than max_depth' + + x = nn.Sequential(*self.pre_downsampling_conv)(x) + + residuals = [] + residuals.append(x) + for i, conv1 in enumerate(self.conv1): + x = conv1(x) + if i < self.max_depth - 1: + residuals.append(x) + + for i, conv2 in enumerate(self.conv2): + x = self.upconv(x) + residuals[self.max_depth - i - 1] + x = conv2(x) + + return x + + +class UNet2D(nn.Module): + def __init__(self): + super(UNet2D, self).__init__() + raise NotImplementedError('2D Unet is wierd') + + +class AudioPoseEncoder1D(nn.Module): + ''' + (B, C, T) -> (B, C*2, T) -> ... 
-> (B, C_out, T) + ''' + + def __init__(self, + C_in, + C_out, + kernel_size=None, + stride=None, + min_layer_nums=None + ): + super(AudioPoseEncoder1D, self).__init__() + self.C_in = C_in + self.C_out = C_out + + conv_layers = nn.ModuleList([]) + cur_C = C_in + num_layers = 0 + while cur_C < self.C_out: + conv_layers.append(ConvNormRelu( + in_channels=cur_C, + out_channels=cur_C * 2, + kernel_size=kernel_size, + stride=stride + )) + cur_C *= 2 + num_layers += 1 + + if (cur_C != C_out) or (min_layer_nums is not None and num_layers < min_layer_nums): + while (cur_C != C_out) or num_layers < min_layer_nums: + conv_layers.append(ConvNormRelu( + in_channels=cur_C, + out_channels=C_out, + kernel_size=kernel_size, + stride=stride + )) + num_layers += 1 + cur_C = C_out + + self.conv_layers = nn.Sequential(*conv_layers) + + def forward(self, x): + ''' + x: (B, C, T) + ''' + x = self.conv_layers(x) + return x + + +class AudioPoseEncoder2D(nn.Module): + ''' + (B, C, T) -> (B, 1, C, T) -> ... -> (B, C_out, T) + ''' + + def __init__(self): + raise NotImplementedError + + +class AudioPoseEncoderRNN(nn.Module): + ''' + (B, C, T)->(B, T, C)->(B, T, C_out)->(B, C_out, T) + ''' + + def __init__(self, + C_in, + hidden_size, + num_layers, + rnn_cell='gru', + bidirectional=False + ): + super(AudioPoseEncoderRNN, self).__init__() + if rnn_cell == 'gru': + self.cell = nn.GRU(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, + bidirectional=bidirectional) + elif rnn_cell == 'lstm': + self.cell = nn.LSTM(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, + bidirectional=bidirectional) + else: + raise ValueError('invalid rnn cell:%s' % (rnn_cell)) + + def forward(self, x, state=None): + + x = x.permute(0, 2, 1) + x, state = self.cell(x, state) + x = x.permute(0, 2, 1) + + return x + + +class AudioPoseEncoderGraph(nn.Module): + ''' + (B, C, T)->(B, 2, V, T)->(B, 2, T, V)->(B, D, T, V) + ''' + + def __init__(self, + layers_config, # 理应是(C_in, C_out, kernel_size)的list + A, # adjacent matrix (num_parts, V, V) + residual, + local_bn=False, + share_weights=False + ) -> None: + super().__init__() + self.A = A + self.num_joints = A.shape[1] + self.num_parts = A.shape[0] + self.C_in = layers_config[0][0] + self.C_out = layers_config[-1][1] + + self.conv_layers = nn.ModuleList([ + GraphConvNormRelu( + C_in=c_in, + C_out=c_out, + A=self.A, + residual=residual, + local_bn=local_bn, + kernel_size=k, + share_weights=share_weights + ) for (c_in, c_out, k) in layers_config + ]) + + self.conv_layers = nn.Sequential(*self.conv_layers) + + def forward(self, x): + ''' + x: (B, C, T), C should be num_joints*D + output: (B, D, T, V) + ''' + B, C, T = x.shape + x = x.view(B, self.num_joints, self.C_in, T) # (B, V, D, T),D:每个joint的特征维度,注意这里V在前面 + x = x.permute(0, 2, 3, 1) # (B, D, T, V) + assert x.shape[1] == self.C_in + + x_conved = self.conv_layers(x) + + # x_conved = x_conved.permute(0, 3, 1, 2).contiguous().view(B, self.C_out*self.num_joints, T)#(B, V*C_out, T) + + return x_conved + + +class SeqEncoder2D(nn.Module): + ''' + seq_encoder, encoding a seq to a vector + (B, C, T)->(B, 2, V, T)->(B, 2, T, V) -> (B, 32, )->...->(B, C_out) + ''' + + def __init__(self, + C_in, # should be 2 + T_in, + C_out, + num_joints, + min_layer_num=None, + residual=False + ): + super(SeqEncoder2D, self).__init__() + self.C_in = C_in + self.C_out = C_out + self.T_in = T_in + self.num_joints = num_joints + + conv_layers = nn.ModuleList([]) + conv_layers.append(ConvNormRelu( + 
in_channels=C_in, + out_channels=32, + type='2d', + residual=residual + )) + + cur_C = 32 + cur_H = T_in + cur_W = num_joints + num_layers = 1 + while (cur_C < C_out) or (cur_H > 1) or (cur_W > 1): + ks = [3, 3] + st = [1, 1] + + if cur_H > 1: + if cur_H > 4: + ks[0] = 4 + st[0] = 2 + else: + ks[0] = cur_H + st[0] = cur_H + if cur_W > 1: + if cur_W > 4: + ks[1] = 4 + st[1] = 2 + else: + ks[1] = cur_W + st[1] = cur_W + + conv_layers.append(ConvNormRelu( + in_channels=cur_C, + out_channels=min(C_out, cur_C * 2), + type='2d', + kernel_size=tuple(ks), + stride=tuple(st), + residual=residual + )) + cur_C = min(cur_C * 2, C_out) + if cur_H > 1: + if cur_H > 4: + cur_H //= 2 + else: + cur_H = 1 + if cur_W > 1: + if cur_W > 4: + cur_W //= 2 + else: + cur_W = 1 + num_layers += 1 + + if min_layer_num is not None and (num_layers < min_layer_num): + while num_layers < min_layer_num: + conv_layers.append(ConvNormRelu( + in_channels=C_out, + out_channels=C_out, + type='2d', + kernel_size=1, + stride=1, + residual=residual + )) + num_layers += 1 + + self.conv_layers = nn.Sequential(*conv_layers) + self.num_layers = num_layers + + def forward(self, x): + B, C, T = x.shape + x = x.view(B, self.num_joints, self.C_in, T) # (B, V, D, T) V in front + x = x.permute(0, 2, 3, 1) # (B, D, T, V) + assert x.shape[1] == self.C_in and x.shape[-1] == self.num_joints + + x = self.conv_layers(x) + return x.squeeze() + + +class SeqEncoder1D(nn.Module): + ''' + (B, C, T)->(B, D) + ''' + + def __init__(self, + C_in, + C_out, + T_in, + min_layer_nums=None + ): + super(SeqEncoder1D, self).__init__() + conv_layers = nn.ModuleList([]) + cur_C = C_in + cur_T = T_in + self.num_layers = 0 + while (cur_C < C_out) or (cur_T > 1): + ks = 3 + st = 1 + if cur_T > 1: + if cur_T > 4: + ks = 4 + st = 2 + else: + ks = cur_T + st = cur_T + + conv_layers.append(ConvNormRelu( + in_channels=cur_C, + out_channels=min(C_out, cur_C * 2), + type='1d', + kernel_size=ks, + stride=st + )) + cur_C = min(cur_C * 2, C_out) + if cur_T > 1: + if cur_T > 4: + cur_T = cur_T // 2 + else: + cur_T = 1 + self.num_layers += 1 + + if min_layer_nums is not None and (self.num_layers < min_layer_nums): + while self.num_layers < min_layer_nums: + conv_layers.append(ConvNormRelu( + in_channels=C_out, + out_channels=C_out, + type='1d', + kernel_size=1, + stride=1 + )) + self.num_layers += 1 + self.conv_layers = nn.Sequential(*conv_layers) + + def forward(self, x): + x = self.conv_layers(x) + return x.squeeze() + + +class SeqEncoderRNN(nn.Module): + ''' + (B, C, T) -> (B, T, C) -> (B, D) + LSTM/GRU-FC + ''' + + def __init__(self, + hidden_size, + in_size, + num_rnn_layers, + rnn_cell='gru', + bidirectional=False + ): + super(SeqEncoderRNN, self).__init__() + self.hidden_size = hidden_size + self.in_size = in_size + self.num_rnn_layers = num_rnn_layers + self.bidirectional = bidirectional + + if rnn_cell == 'gru': + self.cell = nn.GRU(input_size=self.in_size, hidden_size=self.hidden_size, num_layers=self.num_rnn_layers, + batch_first=True, bidirectional=bidirectional) + elif rnn_cell == 'lstm': + self.cell = nn.LSTM(input_size=self.in_size, hidden_size=self.hidden_size, num_layers=self.num_rnn_layers, + batch_first=True, bidirectional=bidirectional) + + def forward(self, x, state=None): + + x = x.permute(0, 2, 1) + B, T, C = x.shape + x, _ = self.cell(x, state) + if self.bidirectional: + out = torch.cat([x[:, -1, :self.hidden_size], x[:, 0, self.hidden_size:]], dim=-1) + else: + out = x[:, -1, :] + assert out.shape[0] == B + return out + + +class 
SeqEncoderGraph(nn.Module): + ''' + ''' + + def __init__(self, + embedding_size, + layer_configs, + residual, + local_bn, + A, + T, + share_weights=False + ) -> None: + super().__init__() + + self.C_in = layer_configs[0][0] + self.C_out = embedding_size + + self.num_joints = A.shape[1] + + self.graph_encoder = AudioPoseEncoderGraph( + layers_config=layer_configs, + A=A, + residual=residual, + local_bn=local_bn, + share_weights=share_weights + ) + + cur_C = layer_configs[-1][1] + self.spatial_pool = ConvNormRelu( + in_channels=cur_C, + out_channels=cur_C, + type='2d', + kernel_size=(1, self.num_joints), + stride=(1, 1), + padding=(0, 0) + ) + + temporal_pool = nn.ModuleList([]) + cur_H = T + num_layers = 0 + self.temporal_conv_info = [] + while cur_C < self.C_out or cur_H > 1: + self.temporal_conv_info.append(cur_C) + ks = [3, 1] + st = [1, 1] + + if cur_H > 1: + if cur_H > 4: + ks[0] = 4 + st[0] = 2 + else: + ks[0] = cur_H + st[0] = cur_H + + temporal_pool.append(ConvNormRelu( + in_channels=cur_C, + out_channels=min(self.C_out, cur_C * 2), + type='2d', + kernel_size=tuple(ks), + stride=tuple(st) + )) + cur_C = min(cur_C * 2, self.C_out) + + if cur_H > 1: + if cur_H > 4: + cur_H //= 2 + else: + cur_H = 1 + + num_layers += 1 + + self.temporal_pool = nn.Sequential(*temporal_pool) + print("graph seq encoder info: temporal pool:", self.temporal_conv_info) + self.num_layers = num_layers + # need fc? + + def forward(self, x): + ''' + x: (B, C, T) + ''' + B, C, T = x.shape + x = self.graph_encoder(x) + x = self.spatial_pool(x) + x = self.temporal_pool(x) + x = x.view(B, self.C_out) + + return x + + +class SeqDecoder2D(nn.Module): + ''' + (B, D)->(B, D, 1, 1)->(B, C_out, C, T)->(B, C_out, T) + ''' + + def __init__(self): + super(SeqDecoder2D, self).__init__() + raise NotImplementedError + + +class SeqDecoder1D(nn.Module): + ''' + (B, D)->(B, D, 1)->...->(B, C_out, T) + ''' + + def __init__(self, + D_in, + C_out, + T_out, + min_layer_num=None + ): + super(SeqDecoder1D, self).__init__() + self.T_out = T_out + self.min_layer_num = min_layer_num + + cur_t = 1 + + self.pre_conv = ConvNormRelu( + in_channels=D_in, + out_channels=C_out, + type='1d' + ) + self.num_layers = 1 + self.upconv = nn.Upsample(scale_factor=2, mode='nearest') + self.conv_layers = nn.ModuleList([]) + cur_t *= 2 + while cur_t <= T_out: + self.conv_layers.append(ConvNormRelu( + in_channels=C_out, + out_channels=C_out, + type='1d' + )) + cur_t *= 2 + self.num_layers += 1 + + post_conv = nn.ModuleList([ConvNormRelu( + in_channels=C_out, + out_channels=C_out, + type='1d' + )]) + self.num_layers += 1 + if min_layer_num is not None and self.num_layers < min_layer_num: + while self.num_layers < min_layer_num: + post_conv.append(ConvNormRelu( + in_channels=C_out, + out_channels=C_out, + type='1d' + )) + self.num_layers += 1 + self.post_conv = nn.Sequential(*post_conv) + + def forward(self, x): + + x = x.unsqueeze(-1) + x = self.pre_conv(x) + for conv in self.conv_layers: + x = self.upconv(x) + x = conv(x) + + x = torch.nn.functional.interpolate(x, size=self.T_out, mode='nearest') + x = self.post_conv(x) + return x + + +class SeqDecoderRNN(nn.Module): + ''' + (B, D)->(B, C_out, T) + ''' + + def __init__(self, + hidden_size, + C_out, + T_out, + num_layers, + rnn_cell='gru' + ): + super(SeqDecoderRNN, self).__init__() + self.num_steps = T_out + if rnn_cell == 'gru': + self.cell = nn.GRU(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, + bidirectional=False) + elif rnn_cell == 'lstm': + self.cell = 
nn.LSTM(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, + bidirectional=False) + else: + raise ValueError('invalid rnn cell:%s' % (rnn_cell)) + + self.fc = nn.Linear(hidden_size, C_out) + + def forward(self, hidden, frame_0): + frame_0 = frame_0.permute(0, 2, 1) + dec_input = frame_0 + outputs = [] + for i in range(self.num_steps): + frame_out, hidden = self.cell(dec_input, hidden) + frame_out = self.fc(frame_out) + dec_input = frame_out + outputs.append(frame_out) + output = torch.cat(outputs, dim=1) + return output.permute(0, 2, 1) + + +class SeqTranslator2D(nn.Module): + ''' + (B, C, T)->(B, 1, C, T)-> ... -> (B, 1, C_out, T_out) + ''' + + def __init__(self, + C_in=64, + C_out=108, + T_in=75, + T_out=25, + residual=True + ): + super(SeqTranslator2D, self).__init__() + print("Warning: hard coded") + self.C_in = C_in + self.C_out = C_out + self.T_in = T_in + self.T_out = T_out + self.residual = residual + + self.conv_layers = nn.Sequential( + ConvNormRelu(1, 32, '2d', kernel_size=5, stride=1), + ConvNormRelu(32, 32, '2d', kernel_size=5, stride=1, residual=self.residual), + ConvNormRelu(32, 32, '2d', kernel_size=5, stride=1, residual=self.residual), + + ConvNormRelu(32, 64, '2d', kernel_size=5, stride=(4, 3)), + ConvNormRelu(64, 64, '2d', kernel_size=5, stride=1, residual=self.residual), + ConvNormRelu(64, 64, '2d', kernel_size=5, stride=1, residual=self.residual), + + ConvNormRelu(64, 128, '2d', kernel_size=5, stride=(4, 1)), + ConvNormRelu(128, 108, '2d', kernel_size=3, stride=(4, 1)), + ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1, residual=self.residual), + + ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1, residual=self.residual), + ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1), + ) + + def forward(self, x): + assert len(x.shape) == 3 and x.shape[1] == self.C_in and x.shape[2] == self.T_in + x = x.view(x.shape[0], 1, x.shape[1], x.shape[2]) + x = self.conv_layers(x) + x = x.squeeze(2) + return x + + +class SeqTranslator1D(nn.Module): + ''' + (B, C, T)->(B, C_out, T) + ''' + + def __init__(self, + C_in, + C_out, + kernel_size=None, + stride=None, + min_layers_num=None, + residual=True, + norm='bn' + ): + super(SeqTranslator1D, self).__init__() + + conv_layers = nn.ModuleList([]) + conv_layers.append(ConvNormRelu( + in_channels=C_in, + out_channels=C_out, + type='1d', + kernel_size=kernel_size, + stride=stride, + residual=residual, + norm=norm + )) + self.num_layers = 1 + if min_layers_num is not None and self.num_layers < min_layers_num: + while self.num_layers < min_layers_num: + conv_layers.append(ConvNormRelu( + in_channels=C_out, + out_channels=C_out, + type='1d', + kernel_size=kernel_size, + stride=stride, + residual=residual, + norm=norm + )) + self.num_layers += 1 + self.conv_layers = nn.Sequential(*conv_layers) + + def forward(self, x): + return self.conv_layers(x) + + +class SeqTranslatorRNN(nn.Module): + ''' + (B, C, T)->(B, C_out, T) + LSTM-FC + ''' + + def __init__(self, + C_in, + C_out, + hidden_size, + num_layers, + rnn_cell='gru' + ): + super(SeqTranslatorRNN, self).__init__() + + if rnn_cell == 'gru': + self.enc_cell = nn.GRU(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, + bidirectional=False) + self.dec_cell = nn.GRU(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, + bidirectional=False) + elif rnn_cell == 'lstm': + self.enc_cell = nn.LSTM(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, + 
bidirectional=False) + self.dec_cell = nn.LSTM(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, + bidirectional=False) + else: + raise ValueError('invalid rnn cell:%s' % (rnn_cell)) + + self.fc = nn.Linear(hidden_size, C_out) + + def forward(self, x, frame_0): + + num_steps = x.shape[-1] + x = x.permute(0, 2, 1) + frame_0 = frame_0.permute(0, 2, 1) + _, hidden = self.enc_cell(x, None) + + outputs = [] + for i in range(num_steps): + inputs = frame_0 + output_frame, hidden = self.dec_cell(inputs, hidden) + output_frame = self.fc(output_frame) + frame_0 = output_frame + outputs.append(output_frame) + outputs = torch.cat(outputs, dim=1) + return outputs.permute(0, 2, 1) + + +class ResBlock(nn.Module): + def __init__(self, + input_dim, + fc_dim, + afn, + nfn + ): + ''' + afn: activation fn + nfn: normalization fn + ''' + super(ResBlock, self).__init__() + + self.input_dim = input_dim + self.fc_dim = fc_dim + self.afn = afn + self.nfn = nfn + + if self.afn != 'relu': + raise ValueError('Wrong') + + if self.nfn == 'layer_norm': + raise ValueError('wrong') + + self.layers = nn.Sequential( + nn.Linear(self.input_dim, self.fc_dim // 2), + nn.ReLU(), + nn.Linear(self.fc_dim // 2, self.fc_dim // 2), + nn.ReLU(), + nn.Linear(self.fc_dim // 2, self.fc_dim), + nn.ReLU() + ) + + self.shortcut_layer = nn.Sequential( + nn.Linear(self.input_dim, self.fc_dim), + nn.ReLU(), + ) + + def forward(self, inputs): + return self.layers(inputs) + self.shortcut_layer(inputs) + + +class AudioEncoder(nn.Module): + def __init__(self, channels, padding=3, kernel_size=8, conv_stride=2, conv_pool=None, augmentation=False): + super(AudioEncoder, self).__init__() + self.in_channels = channels[0] + self.augmentation = augmentation + + model = [] + acti = nn.LeakyReLU(0.2) + + nr_layer = len(channels) - 1 + + for i in range(nr_layer): + if conv_pool is None: + model.append(nn.ReflectionPad1d(padding)) + model.append(nn.Conv1d(channels[i], channels[i + 1], kernel_size=kernel_size, stride=conv_stride)) + model.append(acti) + else: + model.append(nn.ReflectionPad1d(padding)) + model.append(nn.Conv1d(channels[i], channels[i + 1], kernel_size=kernel_size, stride=conv_stride)) + model.append(acti) + model.append(conv_pool(kernel_size=2, stride=2)) + + if self.augmentation: + model.append( + nn.Conv1d(channels[-1], channels[-1], kernel_size=kernel_size, stride=conv_stride) + ) + model.append(acti) + + self.model = nn.Sequential(*model) + + def forward(self, x): + + x = x[:, :self.in_channels, :] + x = self.model(x) + return x + + +class AudioDecoder(nn.Module): + def __init__(self, channels, kernel_size=7, ups=25): + super(AudioDecoder, self).__init__() + + model = [] + pad = (kernel_size - 1) // 2 + acti = nn.LeakyReLU(0.2) + + for i in range(len(channels) - 2): + model.append(nn.Upsample(scale_factor=2, mode='nearest')) + model.append(nn.ReflectionPad1d(pad)) + model.append(nn.Conv1d(channels[i], channels[i + 1], + kernel_size=kernel_size, stride=1)) + if i == 0 or i == 1: + model.append(nn.Dropout(p=0.2)) + if not i == len(channels) - 2: + model.append(acti) + + model.append(nn.Upsample(size=ups, mode='nearest')) + model.append(nn.ReflectionPad1d(pad)) + model.append(nn.Conv1d(channels[-2], channels[-1], + kernel_size=kernel_size, stride=1)) + + self.model = nn.Sequential(*model) + + def forward(self, x): + return self.model(x) + + +class Audio2Pose(nn.Module): + def __init__(self, pose_dim, embed_size, augmentation, ups=25): + super(Audio2Pose, self).__init__() + self.pose_dim = pose_dim + 
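+        # Descriptive sketch of this module's data flow (inferred from the layers
+        # defined below): aud_enc consumes 13-dim MFCC frames, aud_dec maps the
+        # resulting embedding back to pose_dim channels, and when `augmentation`
+        # is enabled a pose code of size embed_size // 2 is passed through
+        # pose_enc and concatenated onto the audio embedding before decoding.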
self.embed_size = embed_size + self.augmentation = augmentation + + self.aud_enc = AudioEncoder(channels=[13, 64, 128, 256], padding=2, kernel_size=7, conv_stride=1, + conv_pool=nn.AvgPool1d, augmentation=self.augmentation) + if self.augmentation: + self.aud_dec = AudioDecoder(channels=[512, 256, 128, pose_dim]) + else: + self.aud_dec = AudioDecoder(channels=[256, 256, 128, pose_dim], ups=ups) + + if self.augmentation: + self.pose_enc = nn.Sequential( + nn.Linear(self.embed_size // 2, 256), + nn.LayerNorm(256) + ) + + def forward(self, audio_feat, dec_input=None): + + B = audio_feat.shape[0] + + aud_embed = self.aud_enc.forward(audio_feat) + + if self.augmentation: + dec_input = dec_input.squeeze(0) + dec_embed = self.pose_enc(dec_input) + dec_embed = dec_embed.unsqueeze(2) + dec_embed = dec_embed.expand(dec_embed.shape[0], dec_embed.shape[1], aud_embed.shape[-1]) + aud_embed = torch.cat([aud_embed, dec_embed], dim=1) + + out = self.aud_dec.forward(aud_embed) + return out + + +if __name__ == '__main__': + import numpy as np + import os + import sys + + test_model = SeqEncoder2D( + C_in=2, + T_in=25, + C_out=512, + num_joints=54, + ) + print(test_model.num_layers) + + input = torch.randn((64, 108, 25)) + output = test_model(input) + print(output.shape) \ No newline at end of file diff --git a/nets/smplx_body_pixel.py b/nets/smplx_body_pixel.py new file mode 100644 index 0000000000000000000000000000000000000000..02bb6c672ecf18371f1ff0f16732c8c16db9f2a8 --- /dev/null +++ b/nets/smplx_body_pixel.py @@ -0,0 +1,326 @@ +import os +import sys + +import torch +from torch.optim.lr_scheduler import StepLR + +sys.path.append(os.getcwd()) + +from nets.layers import * +from nets.base import TrainWrapperBaseClass +from nets.spg.gated_pixelcnn_v2 import GatedPixelCNN as pixelcnn +from nets.spg.vqvae_1d import VQVAE as s2g_body, Wav2VecEncoder +from nets.spg.vqvae_1d import AudioEncoder +from nets.utils import parse_audio, denormalize +from data_utils import get_mfcc, get_melspec, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta +import numpy as np +import torch.optim as optim +import torch.nn.functional as F +from sklearn.preprocessing import normalize + +from data_utils.lower_body import c_index, c_index_3d, c_index_6d +from data_utils.utils import smooth_geom, get_mfcc_sepa + + +class TrainWrapper(TrainWrapperBaseClass): + ''' + a wrapper receving a batch from data_utils and calculate loss + ''' + + def __init__(self, args, config): + self.args = args + self.config = config + self.device = torch.device(self.args.gpu) + self.global_step = 0 + + self.convert_to_6d = self.config.Data.pose.convert_to_6d + self.expression = self.config.Data.pose.expression + self.epoch = 0 + self.init_params() + self.num_classes = 4 + self.audio = True + self.composition = self.config.Model.composition + self.bh_model = self.config.Model.bh_model + + if self.audio: + self.audioencoder = AudioEncoder(in_dim=64, num_hiddens=256, num_residual_layers=2, num_residual_hiddens=256).to(self.device) + else: + self.audioencoder = None + if self.convert_to_6d: + dim, layer = 512, 10 + else: + dim, layer = 256, 15 + self.generator = pixelcnn(2048, dim, layer, self.num_classes, self.audio, self.bh_model).to(self.device) + self.g_body = s2g_body(self.each_dim[1], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024, + num_residual_layers=2, num_residual_hiddens=512).to(self.device) + self.g_hand = s2g_body(self.each_dim[2], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024, + 
num_residual_layers=2, num_residual_hiddens=512).to(self.device) + + model_path = self.config.Model.vq_path + model_ckpt = torch.load(model_path, map_location=torch.device('cpu')) + self.g_body.load_state_dict(model_ckpt['generator']['g_body']) + self.g_hand.load_state_dict(model_ckpt['generator']['g_hand']) + + if torch.cuda.device_count() > 1: + self.g_body = torch.nn.DataParallel(self.g_body, device_ids=[0, 1]) + self.g_hand = torch.nn.DataParallel(self.g_hand, device_ids=[0, 1]) + self.generator = torch.nn.DataParallel(self.generator, device_ids=[0, 1]) + if self.audioencoder is not None: + self.audioencoder = torch.nn.DataParallel(self.audioencoder, device_ids=[0, 1]) + + self.discriminator = None + if self.convert_to_6d: + self.c_index = c_index_6d + else: + self.c_index = c_index_3d + + super().__init__(args, config) + + def init_optimizer(self): + + print('using Adam') + self.generator_optimizer = optim.Adam( + self.generator.parameters(), + lr=self.config.Train.learning_rate.generator_learning_rate, + betas=[0.9, 0.999] + ) + if self.audioencoder is not None: + opt = self.config.Model.AudioOpt + if opt == 'Adam': + self.audioencoder_optimizer = optim.Adam( + self.audioencoder.parameters(), + lr=self.config.Train.learning_rate.generator_learning_rate, + betas=[0.9, 0.999] + ) + else: + print('using SGD') + self.audioencoder_optimizer = optim.SGD( + filter(lambda p: p.requires_grad,self.audioencoder.parameters()), + lr=self.config.Train.learning_rate.generator_learning_rate*10, + momentum=0.9, + nesterov=False, + ) + + def state_dict(self): + model_state = { + 'generator': self.generator.state_dict(), + 'generator_optim': self.generator_optimizer.state_dict(), + 'audioencoder': self.audioencoder.state_dict() if self.audio else None, + 'audioencoder_optim': self.audioencoder_optimizer.state_dict() if self.audio else None, + 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None, + 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None + } + return model_state + + def load_state_dict(self, state_dict): + + from collections import OrderedDict + new_state_dict = OrderedDict() # create new OrderedDict that does not contain `module.` + for k, v in state_dict.items(): + sub_dict = OrderedDict() + if v is not None: + for k1, v1 in v.items(): + name = k1.replace('module.', '') + sub_dict[name] = v1 + new_state_dict[k] = sub_dict + state_dict = new_state_dict + if 'generator' in state_dict: + self.generator.load_state_dict(state_dict['generator']) + else: + self.generator.load_state_dict(state_dict) + + if 'generator_optim' in state_dict and self.generator_optimizer is not None: + self.generator_optimizer.load_state_dict(state_dict['generator_optim']) + + if self.discriminator is not None: + self.discriminator.load_state_dict(state_dict['discriminator']) + + if 'discriminator_optim' in state_dict and self.discriminator_optimizer is not None: + self.discriminator_optimizer.load_state_dict(state_dict['discriminator_optim']) + + if 'audioencoder' in state_dict and self.audioencoder is not None: + self.audioencoder.load_state_dict(state_dict['audioencoder']) + + def init_params(self): + if self.config.Data.pose.convert_to_6d: + scale = 2 + else: + scale = 1 + + global_orient = round(0 * scale) + leye_pose = reye_pose = round(0 * scale) + jaw_pose = round(0 * scale) + body_pose = round((63 - 24) * scale) + left_hand_pose = right_hand_pose = round(45 * scale) + if self.expression: + expression = 100 + else: 
+ expression = 0 + + b_j = 0 + jaw_dim = jaw_pose + b_e = b_j + jaw_dim + eye_dim = leye_pose + reye_pose + b_b = b_e + eye_dim + body_dim = global_orient + body_pose + b_h = b_b + body_dim + hand_dim = left_hand_pose + right_hand_pose + b_f = b_h + hand_dim + face_dim = expression + + self.dim_list = [b_j, b_e, b_b, b_h, b_f] + self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim + self.pose = int(self.full_dim / round(3 * scale)) + self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim] + + def __call__(self, bat): + # assert (not self.args.infer), "infer mode" + self.global_step += 1 + + total_loss = None + loss_dict = {} + + aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32) + + id = bat['speaker'].to(self.device) - 20 + # id = F.one_hot(id, self.num_classes) + + poses = poses[:, self.c_index, :] + + aud = aud.permute(0, 2, 1) + gt_poses = poses.permute(0, 2, 1) + + with torch.no_grad(): + self.g_body.eval() + self.g_hand.eval() + if torch.cuda.device_count() > 1: + _, body_latents = self.g_body.module.encode(gt_poses=gt_poses[..., :self.each_dim[1]], id=id) + _, hand_latents = self.g_hand.module.encode(gt_poses=gt_poses[..., self.each_dim[1]:], id=id) + else: + _, body_latents = self.g_body.encode(gt_poses=gt_poses[..., :self.each_dim[1]], id=id) + _, hand_latents = self.g_hand.encode(gt_poses=gt_poses[..., self.each_dim[1]:], id=id) + latents = torch.cat([body_latents.unsqueeze(dim=-1), hand_latents.unsqueeze(dim=-1)], dim=-1) + latents = latents.detach() + + if self.audio: + audio = self.audioencoder(aud[:, :].transpose(1, 2), frame_num=latents.shape[1]*4).unsqueeze(dim=-1).repeat(1, 1, 1, 2) + logits = self.generator(latents[:, :], id, audio) + else: + logits = self.generator(latents, id) + logits = logits.permute(0, 2, 3, 1).contiguous() + + self.generator_optimizer.zero_grad() + if self.audio: + self.audioencoder_optimizer.zero_grad() + + loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), latents.view(-1)) + loss.backward() + + grad = torch.nn.utils.clip_grad_norm(self.generator.parameters(), self.config.Train.max_gradient_norm) + + if torch.isnan(grad).sum() > 0: + print('fuck') + + loss_dict['grad'] = grad.item() + loss_dict['ce_loss'] = loss.item() + self.generator_optimizer.step() + if self.audio: + self.audioencoder_optimizer.step() + + return total_loss, loss_dict + + def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, exp=None, var=None, w_pre=False, rand=None, + continuity=False, id=None, fps=15, sr=22000, B=1, am=None, am_sr=None, frame=0,**kwargs): + ''' + initial_pose: (B, C, T), normalized + (aud_fn, txgfile) -> generated motion (B, T, C) + ''' + output = [] + + assert self.args.infer, "train mode" + self.generator.eval() + self.g_body.eval() + self.g_hand.eval() + + if continuity: + aud_feat, gap = get_mfcc_sepa(aud_fn, sr=sr, fps=fps) + else: + aud_feat = get_mfcc_ta(aud_fn, sr=sr, fps=fps, smlpx=True, type='mfcc', am=am) + aud_feat = aud_feat.transpose(1, 0) + aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0) + aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device) + + if id is None: + id = torch.tensor([0]).to(self.device) + else: + id = id.repeat(B) + + with torch.no_grad(): + aud_feat = aud_feat.permute(0, 2, 1) + if continuity: + self.audioencoder.eval() + pre_pose = {} + pre_pose['b'] = pre_pose['h'] = None + pre_latents, pre_audio, body_0, hand_0 = self.infer(aud_feat[:, :gap], frame, id, B, pre_pose=pre_pose) + pre_pose['b'] = body_0[:, :, 
-4:].transpose(1,2) + pre_pose['h'] = hand_0[:, :, -4:].transpose(1,2) + _, _, body_1, hand_1 = self.infer(aud_feat[:, gap:], frame, id, B, pre_latents, pre_audio, pre_pose) + body = torch.cat([body_0, body_1], dim=2) + hand = torch.cat([hand_0, hand_1], dim=2) + + else: + if self.audio: + self.audioencoder.eval() + audio = self.audioencoder(aud_feat.transpose(1, 2), frame_num=frame).unsqueeze(dim=-1).repeat(1, 1, 1, 2) + latents = self.generator.generate(id, shape=[audio.shape[2], 2], batch_size=B, aud_feat=audio) + else: + latents = self.generator.generate(id, shape=[aud_feat.shape[1]//4, 2], batch_size=B) + + body_latents = latents[..., 0] + hand_latents = latents[..., 1] + + body, _ = self.g_body.decode(b=body_latents.shape[0], w=body_latents.shape[1], latents=body_latents) + hand, _ = self.g_hand.decode(b=hand_latents.shape[0], w=hand_latents.shape[1], latents=hand_latents) + + pred_poses = torch.cat([body, hand], dim=1).transpose(1,2).cpu().numpy() + + output = pred_poses + + return output + + def infer(self, aud_feat, frame, id, B, pre_latents=None, pre_audio=None, pre_pose=None): + audio = self.audioencoder(aud_feat.transpose(1, 2), frame_num=frame).unsqueeze(dim=-1).repeat(1, 1, 1, 2) + latents = self.generator.generate(id, shape=[audio.shape[2], 2], batch_size=B, aud_feat=audio, + pre_latents=pre_latents, pre_audio=pre_audio) + + body_latents = latents[..., 0] + hand_latents = latents[..., 1] + + body, _ = self.g_body.decode(b=body_latents.shape[0], w=body_latents.shape[1], + latents=body_latents, pre_state=pre_pose['b']) + hand, _ = self.g_hand.decode(b=hand_latents.shape[0], w=hand_latents.shape[1], + latents=hand_latents, pre_state=pre_pose['h']) + + return latents, audio, body, hand + + def generate(self, aud, id, frame_num=0): + + self.generator.eval() + self.g_body.eval() + self.g_hand.eval() + aud_feat = aud.permute(0, 2, 1) + if self.audio: + self.audioencoder.eval() + audio = self.audioencoder(aud_feat.transpose(1, 2), frame_num=frame_num).unsqueeze(dim=-1).repeat(1, 1, 1, 2) + latents = self.generator.generate(id, shape=[audio.shape[2], 2], batch_size=aud.shape[0], aud_feat=audio) + else: + latents = self.generator.generate(id, shape=[aud_feat.shape[1] // 4, 2], batch_size=aud.shape[0]) + + body_latents = latents[..., 0] + hand_latents = latents[..., 1] + + body = self.g_body.decode(b=body_latents.shape[0], w=body_latents.shape[1], latents=body_latents) + hand = self.g_hand.decode(b=hand_latents.shape[0], w=hand_latents.shape[1], latents=hand_latents) + + pred_poses = torch.cat([body, hand], dim=1).transpose(1, 2) + return pred_poses diff --git a/nets/smplx_body_vq.py b/nets/smplx_body_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..95770ac251b873de26b4530f0a37fe43ed5e14f5 --- /dev/null +++ b/nets/smplx_body_vq.py @@ -0,0 +1,302 @@ +import os +import sys + +from torch.optim.lr_scheduler import StepLR + +sys.path.append(os.getcwd()) + +from nets.layers import * +from nets.base import TrainWrapperBaseClass +from nets.spg.s2glayers import Generator as G_S2G, Discriminator as D_S2G +from nets.spg.vqvae_1d import VQVAE as s2g_body +from nets.utils import parse_audio, denormalize +from data_utils import get_mfcc, get_melspec, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta +import numpy as np +import torch.optim as optim +import torch.nn.functional as F +from sklearn.preprocessing import normalize + +from data_utils.lower_body import c_index, c_index_3d, c_index_6d + + +class TrainWrapper(TrainWrapperBaseClass): + ''' + a wrapper receving 
a batch from data_utils and calculate loss + ''' + + def __init__(self, args, config): + self.args = args + self.config = config + self.device = torch.device(self.args.gpu) + self.global_step = 0 + + self.convert_to_6d = self.config.Data.pose.convert_to_6d + self.expression = self.config.Data.pose.expression + self.epoch = 0 + self.init_params() + self.num_classes = 4 + self.composition = self.config.Model.composition + if self.composition: + self.g_body = s2g_body(self.each_dim[1], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024, + num_residual_layers=2, num_residual_hiddens=512).to(self.device) + self.g_hand = s2g_body(self.each_dim[2], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024, + num_residual_layers=2, num_residual_hiddens=512).to(self.device) + else: + self.g = s2g_body(self.each_dim[1] + self.each_dim[2], embedding_dim=64, num_embeddings=config.Model.code_num, + num_hiddens=1024, num_residual_layers=2, num_residual_hiddens=512).to(self.device) + + self.discriminator = None + + if self.convert_to_6d: + self.c_index = c_index_6d + else: + self.c_index = c_index_3d + + super().__init__(args, config) + + def init_optimizer(self): + print('using Adam') + if self.composition: + self.g_body_optimizer = optim.Adam( + self.g_body.parameters(), + lr=self.config.Train.learning_rate.generator_learning_rate, + betas=[0.9, 0.999] + ) + self.g_hand_optimizer = optim.Adam( + self.g_hand.parameters(), + lr=self.config.Train.learning_rate.generator_learning_rate, + betas=[0.9, 0.999] + ) + else: + self.g_optimizer = optim.Adam( + self.g.parameters(), + lr=self.config.Train.learning_rate.generator_learning_rate, + betas=[0.9, 0.999] + ) + + def state_dict(self): + if self.composition: + model_state = { + 'g_body': self.g_body.state_dict(), + 'g_body_optim': self.g_body_optimizer.state_dict(), + 'g_hand': self.g_hand.state_dict(), + 'g_hand_optim': self.g_hand_optimizer.state_dict(), + 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None, + 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None + } + else: + model_state = { + 'g': self.g.state_dict(), + 'g_optim': self.g_optimizer.state_dict(), + 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None, + 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None + } + return model_state + + def init_params(self): + if self.config.Data.pose.convert_to_6d: + scale = 2 + else: + scale = 1 + + global_orient = round(0 * scale) + leye_pose = reye_pose = round(0 * scale) + jaw_pose = round(0 * scale) + body_pose = round((63 - 24) * scale) + left_hand_pose = right_hand_pose = round(45 * scale) + if self.expression: + expression = 100 + else: + expression = 0 + + b_j = 0 + jaw_dim = jaw_pose + b_e = b_j + jaw_dim + eye_dim = leye_pose + reye_pose + b_b = b_e + eye_dim + body_dim = global_orient + body_pose + b_h = b_b + body_dim + hand_dim = left_hand_pose + right_hand_pose + b_f = b_h + hand_dim + face_dim = expression + + self.dim_list = [b_j, b_e, b_b, b_h, b_f] + self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim + self.pose = int(self.full_dim / round(3 * scale)) + self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim] + + def __call__(self, bat): + # assert (not self.args.infer), "infer mode" + self.global_step += 1 + + total_loss = None + loss_dict = {} + + aud, poses = 
bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32) + + # id = bat['speaker'].to(self.device) - 20 + # id = F.one_hot(id, self.num_classes) + + poses = poses[:, self.c_index, :] + gt_poses = poses.permute(0, 2, 1) + b_poses = gt_poses[..., :self.each_dim[1]] + h_poses = gt_poses[..., self.each_dim[1]:] + + if self.composition: + loss = 0 + loss_dict, loss = self.vq_train(b_poses[:, :], 'b', self.g_body, loss_dict, loss) + loss_dict, loss = self.vq_train(h_poses[:, :], 'h', self.g_hand, loss_dict, loss) + else: + loss = 0 + loss_dict, loss = self.vq_train(gt_poses[:, :], 'g', self.g, loss_dict, loss) + + return total_loss, loss_dict + + def vq_train(self, gt, name, model, dict, total_loss, pre=None): + e_q_loss, x_recon = model(gt_poses=gt, pre_state=pre) + loss, loss_dict = self.get_loss(pred_poses=x_recon, gt_poses=gt, e_q_loss=e_q_loss, pre=pre) + # total_loss = total_loss + loss + + if name == 'b': + optimizer_name = 'g_body_optimizer' + elif name == 'h': + optimizer_name = 'g_hand_optimizer' + elif name == 'g': + optimizer_name = 'g_optimizer' + else: + raise ValueError("model's name must be b or h") + optimizer = getattr(self, optimizer_name) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + for key in list(loss_dict.keys()): + dict[name + key] = loss_dict.get(key, 0).item() + return dict, total_loss + + def get_loss(self, + pred_poses, + gt_poses, + e_q_loss, + pre=None + ): + loss_dict = {} + + + rec_loss = torch.mean(torch.abs(pred_poses - gt_poses)) + v_pr = pred_poses[:, 1:] - pred_poses[:, :-1] + v_gt = gt_poses[:, 1:] - gt_poses[:, :-1] + velocity_loss = torch.mean(torch.abs(v_pr - v_gt)) + + if pre is None: + f0_vel = 0 + else: + v0_pr = pred_poses[:, 0] - pre[:, -1] + v0_gt = gt_poses[:, 0] - pre[:, -1] + f0_vel = torch.mean(torch.abs(v0_pr - v0_gt)) + + gen_loss = rec_loss + e_q_loss + velocity_loss + f0_vel + + loss_dict['rec_loss'] = rec_loss + loss_dict['velocity_loss'] = velocity_loss + # loss_dict['e_q_loss'] = e_q_loss + if pre is not None: + loss_dict['f0_vel'] = f0_vel + + return gen_loss, loss_dict + + def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, exp=None, var=None, w_pre=False, continuity=False, + id=None, fps=15, sr=22000, smooth=False, **kwargs): + ''' + initial_pose: (B, C, T), normalized + (aud_fn, txgfile) -> generated motion (B, T, C) + ''' + output = [] + + assert self.args.infer, "train mode" + if self.composition: + self.g_body.eval() + self.g_hand.eval() + else: + self.g.eval() + + if self.config.Data.pose.normalization: + assert norm_stats is not None + data_mean = norm_stats[0] + data_std = norm_stats[1] + + # assert initial_pose.shape[-1] == pre_length + if initial_pose is not None: + gt = initial_pose[:, :, :].to(self.device).to(torch.float32) + pre_poses = initial_pose[:, :, :15].permute(0, 2, 1).to(self.device).to(torch.float32) + poses = initial_pose.permute(0, 2, 1).to(self.device).to(torch.float32) + B = pre_poses.shape[0] + else: + gt = None + pre_poses = None + B = 1 + + if type(aud_fn) == torch.Tensor: + aud_feat = torch.tensor(aud_fn, dtype=torch.float32).to(self.device) + num_poses_to_generate = aud_feat.shape[-1] + else: + aud_feat = get_mfcc_ta(aud_fn, sr=sr, fps=fps, smlpx=True, type='mfcc').transpose(1, 0) + aud_feat = aud_feat[:, :] + num_poses_to_generate = aud_feat.shape[-1] + aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0) + aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device) + + # pre_poses = 
torch.randn(pre_poses.shape).to(self.device).to(torch.float32) + if id is None: + id = F.one_hot(torch.tensor([[0]]), self.num_classes).to(self.device) + + with torch.no_grad(): + aud_feat = aud_feat.permute(0, 2, 1) + gt_poses = gt[:, self.c_index].permute(0, 2, 1) + if self.composition: + if continuity: + pred_poses_body = [] + pred_poses_hand = [] + pre_b = None + pre_h = None + for i in range(5): + _, pred_body = self.g_body(gt_poses=gt_poses[:, i*60:(i+1)*60, :self.each_dim[1]], pre_state=pre_b) + pre_b = pred_body[..., -1:].transpose(1,2) + pred_poses_body.append(pred_body) + _, pred_hand = self.g_hand(gt_poses=gt_poses[:, i*60:(i+1)*60, self.each_dim[1]:], pre_state=pre_h) + pre_h = pred_hand[..., -1:].transpose(1,2) + pred_poses_hand.append(pred_hand) + + pred_poses_body = torch.cat(pred_poses_body, dim=2) + pred_poses_hand = torch.cat(pred_poses_hand, dim=2) + else: + _, pred_poses_body = self.g_body(gt_poses=gt_poses[..., :self.each_dim[1]], id=id) + _, pred_poses_hand = self.g_hand(gt_poses=gt_poses[..., self.each_dim[1]:], id=id) + pred_poses = torch.cat([pred_poses_body, pred_poses_hand], dim=1) + else: + _, pred_poses = self.g(gt_poses=gt_poses, id=id) + pred_poses = pred_poses.transpose(1, 2).cpu().numpy() + output = pred_poses + + if self.config.Data.pose.normalization: + output = denormalize(output, data_mean, data_std) + + if smooth: + lamda = 0.8 + smooth_f = 10 + frame = 149 + for i in range(smooth_f): + f = frame + i + l = lamda * (i + 1) / smooth_f + output[0, f] = (1 - l) * output[0, f - 1] + l * output[0, f] + + output = np.concatenate(output, axis=1) + + return output + + def load_state_dict(self, state_dict): + if self.composition: + self.g_body.load_state_dict(state_dict['g_body']) + self.g_hand.load_state_dict(state_dict['g_hand']) + else: + self.g.load_state_dict(state_dict['g']) diff --git a/nets/smplx_face.py b/nets/smplx_face.py new file mode 100644 index 0000000000000000000000000000000000000000..e591b9dab674770b60655f607892b068f412d75a --- /dev/null +++ b/nets/smplx_face.py @@ -0,0 +1,238 @@ +import os +import sys + +sys.path.append(os.getcwd()) + +from nets.layers import * +from nets.base import TrainWrapperBaseClass +# from nets.spg.faceformer import Faceformer +from nets.spg.s2g_face import Generator as s2g_face +from losses import KeypointLoss +from nets.utils import denormalize +from data_utils import get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta +import numpy as np +import torch.optim as optim +import torch.nn.functional as F +from sklearn.preprocessing import normalize +import smplx + + +class TrainWrapper(TrainWrapperBaseClass): + ''' + a wrapper receving a batch from data_utils and calculate loss + ''' + + def __init__(self, args, config): + self.args = args + self.config = config + self.device = torch.device(self.args.gpu) + self.global_step = 0 + + self.convert_to_6d = self.config.Data.pose.convert_to_6d + self.expression = self.config.Data.pose.expression + self.epoch = 0 + self.init_params() + self.num_classes = 4 + + self.generator = s2g_face( + n_poses=self.config.Data.pose.generate_length, + each_dim=self.each_dim, + dim_list=self.dim_list, + training=not self.args.infer, + device=self.device, + identity=False if self.convert_to_6d else True, + num_classes=self.num_classes, + ).to(self.device) + + # self.generator = Faceformer().to(self.device) + + self.discriminator = None + self.am = None + + self.MSELoss = KeypointLoss().to(self.device) + super().__init__(args, config) + + def init_optimizer(self): + self.generator_optimizer = optim.SGD( + 
filter(lambda p: p.requires_grad,self.generator.parameters()), + lr=0.001, + momentum=0.9, + nesterov=False, + ) + + def init_params(self): + if self.convert_to_6d: + scale = 2 + else: + scale = 1 + + global_orient = round(3 * scale) + leye_pose = reye_pose = round(3 * scale) + jaw_pose = round(3 * scale) + body_pose = round(63 * scale) + left_hand_pose = right_hand_pose = round(45 * scale) + if self.expression: + expression = 100 + else: + expression = 0 + + b_j = 0 + jaw_dim = jaw_pose + b_e = b_j + jaw_dim + eye_dim = leye_pose + reye_pose + b_b = b_e + eye_dim + body_dim = global_orient + body_pose + b_h = b_b + body_dim + hand_dim = left_hand_pose + right_hand_pose + b_f = b_h + hand_dim + face_dim = expression + + self.dim_list = [b_j, b_e, b_b, b_h, b_f] + self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim + face_dim + self.pose = int(self.full_dim / round(3 * scale)) + self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim] + + def __call__(self, bat): + # assert (not self.args.infer), "infer mode" + self.global_step += 1 + + total_loss = None + loss_dict = {} + + aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32) + id = bat['speaker'].to(self.device) - 20 + id = F.one_hot(id, self.num_classes) + + aud = aud.permute(0, 2, 1) + gt_poses = poses.permute(0, 2, 1) + + if self.expression: + expression = bat['expression'].to(self.device).to(torch.float32) + gt_poses = torch.cat([gt_poses, expression.permute(0, 2, 1)], dim=2) + + pred_poses, _ = self.generator( + aud, + gt_poses, + id, + ) + + G_loss, G_loss_dict = self.get_loss( + pred_poses=pred_poses, + gt_poses=gt_poses, + pre_poses=None, + mode='training_G', + gt_conf=None, + aud=aud, + ) + + self.generator_optimizer.zero_grad() + G_loss.backward() + grad = torch.nn.utils.clip_grad_norm(self.generator.parameters(), self.config.Train.max_gradient_norm) + loss_dict['grad'] = grad.item() + self.generator_optimizer.step() + + for key in list(G_loss_dict.keys()): + loss_dict[key] = G_loss_dict.get(key, 0).item() + + return total_loss, loss_dict + + def get_loss(self, + pred_poses, + gt_poses, + pre_poses, + aud, + mode='training_G', + gt_conf=None, + exp=1, + gt_nzero=None, + pre_nzero=None, + ): + loss_dict = {} + + + [b_j, b_e, b_b, b_h, b_f] = self.dim_list + + MSELoss = torch.mean(torch.abs(pred_poses[:, :, :6] - gt_poses[:, :, :6])) + if self.expression: + expl = torch.mean((pred_poses[:, :, -100:] - gt_poses[:, :, -100:])**2) + else: + expl = 0 + + gen_loss = expl + MSELoss + + loss_dict['MSELoss'] = MSELoss + if self.expression: + loss_dict['exp_loss'] = expl + + return gen_loss, loss_dict + + def infer_on_audio(self, aud_fn, id=None, initial_pose=None, norm_stats=None, w_pre=False, frame=None, am=None, am_sr=16000, **kwargs): + ''' + initial_pose: (B, C, T), normalized + (aud_fn, txgfile) -> generated motion (B, T, C) + ''' + output = [] + + # assert self.args.infer, "train mode" + self.generator.eval() + + if self.config.Data.pose.normalization: + assert norm_stats is not None + data_mean = norm_stats[0] + data_std = norm_stats[1] + + # assert initial_pose.shape[-1] == pre_length + if initial_pose is not None: + gt = initial_pose[:,:,:].permute(0, 2, 1).to(self.generator.device).to(torch.float32) + pre_poses = initial_pose[:,:,:15].permute(0, 2, 1).to(self.generator.device).to(torch.float32) + poses = initial_pose.permute(0, 2, 1).to(self.generator.device).to(torch.float32) + B = pre_poses.shape[0] + else: + gt = None + pre_poses=None + B = 1 + + if 
type(aud_fn) == torch.Tensor: + aud_feat = torch.tensor(aud_fn, dtype=torch.float32).to(self.generator.device) + num_poses_to_generate = aud_feat.shape[-1] + else: + aud_feat = get_mfcc_ta(aud_fn, am=am, am_sr=am_sr, fps=30, encoder_choice='faceformer') + aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0) + aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.generator.device).transpose(1, 2) + if frame is None: + frame = aud_feat.shape[2]*30//16000 + # + if id is None: + id = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32, device=self.generator.device) + else: + id = F.one_hot(id, self.num_classes).to(self.generator.device) + + with torch.no_grad(): + pred_poses = self.generator(aud_feat, pre_poses, id, time_steps=frame)[0] + pred_poses = pred_poses.cpu().numpy() + output = pred_poses + + if self.config.Data.pose.normalization: + output = denormalize(output, data_mean, data_std) + + return output + + + def generate(self, wv2_feat, frame): + ''' + initial_pose: (B, C, T), normalized + (aud_fn, txgfile) -> generated motion (B, T, C) + ''' + output = [] + + # assert self.args.infer, "train mode" + self.generator.eval() + + B = 1 + + id = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32, device=self.generator.device) + id = id.repeat(wv2_feat.shape[0], 1) + + with torch.no_grad(): + pred_poses = self.generator(wv2_feat, None, id, time_steps=frame)[0] + return pred_poses diff --git a/nets/spg/__pycache__/gated_pixelcnn_v2.cpython-37.pyc b/nets/spg/__pycache__/gated_pixelcnn_v2.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ae3f74945d1e857902656a7f3938c5929845ab9 Binary files /dev/null and b/nets/spg/__pycache__/gated_pixelcnn_v2.cpython-37.pyc differ diff --git a/nets/spg/__pycache__/s2g_face.cpython-37.pyc b/nets/spg/__pycache__/s2g_face.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dc7edd5b4082fcfcf0ef4d9f868b0b37ab1bbd4 Binary files /dev/null and b/nets/spg/__pycache__/s2g_face.cpython-37.pyc differ diff --git a/nets/spg/__pycache__/s2glayers.cpython-37.pyc b/nets/spg/__pycache__/s2glayers.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cab33d662e58a9f0c9e8126e0d49b40d4760f74d Binary files /dev/null and b/nets/spg/__pycache__/s2glayers.cpython-37.pyc differ diff --git a/nets/spg/__pycache__/vqvae_1d.cpython-37.pyc b/nets/spg/__pycache__/vqvae_1d.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70727ba6ce575c5829db4e47a04bf8813428d96a Binary files /dev/null and b/nets/spg/__pycache__/vqvae_1d.cpython-37.pyc differ diff --git a/nets/spg/__pycache__/vqvae_modules.cpython-37.pyc b/nets/spg/__pycache__/vqvae_modules.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c8d4dd80c3387e10535349774171623f34c8a38 Binary files /dev/null and b/nets/spg/__pycache__/vqvae_modules.cpython-37.pyc differ diff --git a/nets/spg/__pycache__/wav2vec.cpython-37.pyc b/nets/spg/__pycache__/wav2vec.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e935f81110e0674b8f634fe7bfb2495acae33154 Binary files /dev/null and b/nets/spg/__pycache__/wav2vec.cpython-37.pyc differ diff --git a/nets/spg/gated_pixelcnn_v2.py b/nets/spg/gated_pixelcnn_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..7eb0ec8c0c1dd4bc1e2b6e9f55fd1104b547b05f --- /dev/null +++ b/nets/spg/gated_pixelcnn_v2.py @@ -0,0 +1,179 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + 
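+# Conditional Gated PixelCNN prior over the discrete VQ-VAE code maps. Judging from
+# the calls in nets/smplx_body_pixel.py, the "image" being modelled is a
+# (B, T_codes, 2) grid of code indices with body codes in column 0 and hand codes
+# in column 1; `label` is the speaker id and `aud` is the (B, 256, T_codes, 2)
+# feature map produced by the audio encoder. A minimal usage sketch (shapes and
+# argument values are taken from smplx_body_pixel.py, not the defaults of this class):
+#
+#     prior = GatedPixelCNN(input_dim=2048, dim=256, n_layers=15,
+#                           n_classes=4, audio=True, bh_model=True)
+#     codes = prior.generate(label=speaker_id, shape=(num_code_frames, 2),
+#                            batch_size=B, aud_feat=audio_feature_map)
+#     body_codes, hand_codes = codes[..., 0], codes[..., 1]
+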
+def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + try: + nn.init.xavier_uniform_(m.weight.data) + m.bias.data.fill_(0) + except AttributeError: + print("Skipping initialization of ", classname) + + +class GatedActivation(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x, y = x.chunk(2, dim=1) + return F.tanh(x) * F.sigmoid(y) + + +class GatedMaskedConv2d(nn.Module): + def __init__(self, mask_type, dim, kernel, residual=True, n_classes=10, bh_model=False): + super().__init__() + assert kernel % 2 == 1, print("Kernel size must be odd") + self.mask_type = mask_type + self.residual = residual + self.bh_model = bh_model + + self.class_cond_embedding = nn.Embedding(n_classes, 2 * dim) + self.class_cond_embedding = self.class_cond_embedding.to("cpu") + + kernel_shp = (kernel // 2 + 1, 3 if self.bh_model else 1) # (ceil(n/2), n) + padding_shp = (kernel // 2, 1 if self.bh_model else 0) + self.vert_stack = nn.Conv2d( + dim, dim * 2, + kernel_shp, 1, padding_shp + ) + + self.vert_to_horiz = nn.Conv2d(2 * dim, 2 * dim, 1) + + kernel_shp = (1, 2) + padding_shp = (0, 1) + self.horiz_stack = nn.Conv2d( + dim, dim * 2, + kernel_shp, 1, padding_shp + ) + + self.horiz_resid = nn.Conv2d(dim, dim, 1) + + self.gate = GatedActivation() + + def make_causal(self): + self.vert_stack.weight.data[:, :, -1].zero_() # Mask final row + self.horiz_stack.weight.data[:, :, :, -1].zero_() # Mask final column + + def forward(self, x_v, x_h, h): + if self.mask_type == 'A': + self.make_causal() + + h = h.to(self.class_cond_embedding.weight.device) + h = self.class_cond_embedding(h) + + h_vert = self.vert_stack(x_v) + h_vert = h_vert[:, :, :x_v.size(-2), :] + out_v = self.gate(h_vert + h[:, :, None, None]) + + if self.bh_model: + h_horiz = self.horiz_stack(x_h) + h_horiz = h_horiz[:, :, :, :x_h.size(-1)] + v2h = self.vert_to_horiz(h_vert) + + out = self.gate(v2h + h_horiz + h[:, :, None, None]) + if self.residual: + out_h = self.horiz_resid(out) + x_h + else: + out_h = self.horiz_resid(out) + else: + if self.residual: + out_v = self.horiz_resid(out_v) + x_v + else: + out_v = self.horiz_resid(out_v) + out_h = out_v + + return out_v, out_h + + +class GatedPixelCNN(nn.Module): + def __init__(self, input_dim=256, dim=64, n_layers=15, n_classes=10, audio=False, bh_model=False): + super().__init__() + self.dim = dim + self.audio = audio + self.bh_model = bh_model + + if self.audio: + self.embedding_aud = nn.Conv2d(256, dim, 1, 1, padding=0) + self.fusion_v = nn.Conv2d(dim * 2, dim, 1, 1, padding=0) + self.fusion_h = nn.Conv2d(dim * 2, dim, 1, 1, padding=0) + + # Create embedding layer to embed input + self.embedding = nn.Embedding(input_dim, dim) + + # Building the PixelCNN layer by layer + self.layers = nn.ModuleList() + + # Initial block with Mask-A convolution + # Rest with Mask-B convolutions + for i in range(n_layers): + mask_type = 'A' if i == 0 else 'B' + kernel = 7 if i == 0 else 3 + residual = False if i == 0 else True + + self.layers.append( + GatedMaskedConv2d(mask_type, dim, kernel, residual, n_classes, bh_model) + ) + + # Add the output layer + self.output_conv = nn.Sequential( + nn.Conv2d(dim, 512, 1), + nn.ReLU(True), + nn.Conv2d(512, input_dim, 1) + ) + + self.apply(weights_init) + + self.dp = nn.Dropout(0.1) + self.to("cpu") + + def forward(self, x, label, aud=None): + shp = x.size() + (-1,) + x = self.embedding(x.view(-1)).view(shp) # (B, H, W, C) + x = x.permute(0, 3, 1, 2) # (B, C, W, W) + + x_v, x_h = (x, x) + for i, layer in 
enumerate(self.layers): + if i == 1 and self.audio is True: + aud = self.embedding_aud(aud) + a = torch.ones(aud.shape[-2]).to(aud.device) + a = self.dp(a) + aud = (aud.transpose(-1, -2) * a).transpose(-1, -2) + x_v = self.fusion_v(torch.cat([x_v, aud], dim=1)) + if self.bh_model: + x_h = self.fusion_h(torch.cat([x_h, aud], dim=1)) + x_v, x_h = layer(x_v, x_h, label) + + if self.bh_model: + return self.output_conv(x_h) + else: + return self.output_conv(x_v) + + def generate(self, label, shape=(8, 8), batch_size=64, aud_feat=None, pre_latents=None, pre_audio=None): + param = next(self.parameters()) + x = torch.zeros( + (batch_size, *shape), + dtype=torch.int64, device=param.device + ) + if pre_latents is not None: + x = torch.cat([pre_latents, x], dim=1) + aud_feat = torch.cat([pre_audio, aud_feat], dim=2) + h0 = pre_latents.shape[1] + h = h0 + shape[0] + else: + h0 = 0 + h = shape[0] + + for i in range(h0, h): + for j in range(shape[1]): + if self.audio: + logits = self.forward(x, label, aud_feat) + else: + logits = self.forward(x, label) + probs = F.softmax(logits[:, :, i, j], -1) + x.data[:, i, j].copy_( + probs.multinomial(1).squeeze().data + ) + return x[:, h0:h] diff --git a/nets/spg/s2g_face.py b/nets/spg/s2g_face.py new file mode 100644 index 0000000000000000000000000000000000000000..b221df6c7bff0912640dfffe46267f8a131cc829 --- /dev/null +++ b/nets/spg/s2g_face.py @@ -0,0 +1,226 @@ +''' +not exactly the same as the official repo but the results are good +''' +import sys +import os + +from transformers import Wav2Vec2Processor + +from .wav2vec import Wav2Vec2Model +from torchaudio.sox_effects import apply_effects_tensor + +sys.path.append(os.getcwd()) + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio as ta +import math +from nets.layers import SeqEncoder1D, SeqTranslator1D, ConvNormRelu + + +""" from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """ + + +def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000): + """ + :param audio: 1 x T tensor containing a 16kHz audio signal + :param frame_rate: frame rate for video (we need one audio chunk per video frame) + :param chunk_size: number of audio samples per chunk + :return: num_chunks x chunk_size tensor containing sliced audio + """ + samples_per_frame = 16000 // frame_rate + padding = (chunk_size - samples_per_frame) // 2 + audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0) + anchor_points = list(range(chunk_size//2, audio.shape[-1]-chunk_size//2, samples_per_frame)) + audio = torch.cat([audio[:, i-chunk_size//2:i+chunk_size//2] for i in anchor_points], dim=0) + return audio + + +class MeshtalkEncoder(nn.Module): + def __init__(self, latent_dim: int = 128, model_name: str = 'audio_encoder'): + """ + :param latent_dim: size of the latent audio embedding + :param model_name: name of the model, used to load and save the model + """ + super().__init__() + + self.melspec = ta.transforms.MelSpectrogram( + sample_rate=16000, n_fft=2048, win_length=800, hop_length=160, n_mels=80 + ) + + conv_len = 5 + self.convert_dimensions = torch.nn.Conv1d(80, 128, kernel_size=conv_len) + self.weights_init(self.convert_dimensions) + self.receptive_field = conv_len + + convs = [] + for i in range(6): + dilation = 2 * (i % 3 + 1) + self.receptive_field += (conv_len - 1) * dilation + convs += [torch.nn.Conv1d(128, 128, kernel_size=conv_len, dilation=dilation)] + self.weights_init(convs[-1]) + self.convs = 
torch.nn.ModuleList(convs) + self.code = torch.nn.Linear(128, latent_dim) + + self.apply(lambda x: self.weights_init(x)) + + def weights_init(self, m): + if isinstance(m, torch.nn.Conv1d): + torch.nn.init.xavier_uniform_(m.weight) + try: + torch.nn.init.constant_(m.bias, .01) + except: + pass + + def forward(self, audio: torch.Tensor): + """ + :param audio: B x T x 16000 Tensor containing 1 sec of audio centered around the current time frame + :return: code: B x T x latent_dim Tensor containing a latent audio code/embedding + """ + B, T = audio.shape[0], audio.shape[1] + x = self.melspec(audio).squeeze(1) + x = torch.log(x.clamp(min=1e-10, max=None)) + if T == 1: + x = x.unsqueeze(1) + + # Convert to the right dimensionality + x = x.view(-1, x.shape[2], x.shape[3]) + x = F.leaky_relu(self.convert_dimensions(x), .2) + + # Process stacks + for conv in self.convs: + x_ = F.leaky_relu(conv(x), .2) + if self.training: + x_ = F.dropout(x_, .2) + l = (x.shape[2] - x_.shape[2]) // 2 + x = (x[:, :, l:-l] + x_) / 2 + + x = torch.mean(x, dim=-1) + x = x.view(B, T, x.shape[-1]) + x = self.code(x) + + return {"code": x} + + +class AudioEncoder(nn.Module): + def __init__(self, in_dim, out_dim, identity=False, num_classes=0): + super().__init__() + self.identity = identity + if self.identity: + in_dim = in_dim + 64 + self.id_mlp = nn.Conv1d(num_classes, 64, 1, 1) + self.first_net = SeqTranslator1D(in_dim, out_dim, + min_layers_num=3, + residual=True, + norm='ln' + ) + self.grus = nn.GRU(out_dim, out_dim, 1, batch_first=True) + self.dropout = nn.Dropout(0.1) + # self.att = nn.MultiheadAttention(out_dim, 4, dropout=0.1, batch_first=True) + + def forward(self, spectrogram, pre_state=None, id=None, time_steps=None): + + spectrogram = spectrogram + spectrogram = self.dropout(spectrogram) + if self.identity: + id = id.reshape(id.shape[0], -1, 1).repeat(1, 1, spectrogram.shape[2]).to(torch.float32) + id = self.id_mlp(id) + spectrogram = torch.cat([spectrogram, id], dim=1) + x1 = self.first_net(spectrogram)# .permute(0, 2, 1) + if time_steps is not None: + x1 = F.interpolate(x1, size=time_steps, align_corners=False, mode='linear') + # x1, _ = self.att(x1, x1, x1) + # x1, hidden_state = self.grus(x1) + # x1 = x1.permute(0, 2, 1) + hidden_state=None + + return x1, hidden_state + + +class Generator(nn.Module): + def __init__(self, + n_poses, + each_dim: list, + dim_list: list, + training=False, + device=None, + identity=True, + num_classes=0, + ): + super().__init__() + + self.training = training + self.device = device + self.gen_length = n_poses + self.identity = identity + + norm = 'ln' + in_dim = 256 + out_dim = 256 + + self.encoder_choice = 'faceformer' + + if self.encoder_choice == 'meshtalk': + self.audio_encoder = MeshtalkEncoder(latent_dim=in_dim) + elif self.encoder_choice == 'faceformer': + # wav2vec 2.0 weights initialization + self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") # "vitouphy/wav2vec2-xls-r-300m-phoneme""facebook/wav2vec2-base-960h" + self.audio_encoder.feature_extractor._freeze_parameters() + self.audio_feature_map = nn.Linear(768, in_dim) + else: + self.audio_encoder = AudioEncoder(in_dim=64, out_dim=out_dim) + + self.audio_middle = AudioEncoder(in_dim, out_dim, identity, num_classes) + + self.dim_list = dim_list + + self.decoder = nn.ModuleList() + self.final_out = nn.ModuleList() + + self.decoder.append(nn.Sequential( + ConvNormRelu(out_dim, 64, norm=norm), + ConvNormRelu(64, 64, norm=norm), + ConvNormRelu(64, 64, norm=norm), + )) + 
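+        # Two parallel heads: this first decoder / final_out pair regresses the jaw
+        # channels (each_dim[0]), while the second pair below regresses the
+        # expression channels (each_dim[3]).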
self.final_out.append(nn.Conv1d(64, each_dim[0], 1, 1)) + + self.decoder.append(nn.Sequential( + ConvNormRelu(out_dim, out_dim, norm=norm), + ConvNormRelu(out_dim, out_dim, norm=norm), + ConvNormRelu(out_dim, out_dim, norm=norm), + )) + self.final_out.append(nn.Conv1d(out_dim, each_dim[3], 1, 1)) + + def forward(self, in_spec, gt_poses=None, id=None, pre_state=None, time_steps=None): + if self.training: + time_steps = gt_poses.shape[1] + + # vector, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps) + if self.encoder_choice == 'meshtalk': + in_spec = audio_chunking(in_spec.squeeze(-1), frame_rate=30, chunk_size=16000) + feature = self.audio_encoder(in_spec.unsqueeze(0))["code"].transpose(1, 2) + elif self.encoder_choice == 'faceformer': + hidden_states = self.audio_encoder(in_spec.reshape(in_spec.shape[0], -1), frame_num=time_steps).last_hidden_state + feature = self.audio_feature_map(hidden_states).transpose(1, 2) + else: + feature, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps) + + # hidden_states = in_spec + + feature, _ = self.audio_middle(feature, id=id) + + out = [] + + for i in range(self.decoder.__len__()): + mid = self.decoder[i](feature) + mid = self.final_out[i](mid) + out.append(mid) + + out = torch.cat(out, dim=1) + out = out.transpose(1, 2) + + return out, None + + diff --git a/nets/spg/s2glayers.py b/nets/spg/s2glayers.py new file mode 100644 index 0000000000000000000000000000000000000000..2a439e6bc0c4973586d39f3b113aa3752ff077fa --- /dev/null +++ b/nets/spg/s2glayers.py @@ -0,0 +1,522 @@ +''' +not exactly the same as the official repo but the results are good +''' +import sys +import os + +sys.path.append(os.getcwd()) + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from nets.layers import SeqEncoder1D, SeqTranslator1D + +""" from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """ + + +class Conv2d_tf(nn.Conv2d): + """ + Conv2d with the padding behavior from TF + from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py + """ + + def __init__(self, *args, **kwargs): + super(Conv2d_tf, self).__init__(*args, **kwargs) + self.padding = kwargs.get("padding", "SAME") + + def _compute_padding(self, input, dim): + input_size = input.size(dim + 2) + filter_size = self.weight.size(dim + 2) + effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1 + out_size = (input_size + self.stride[dim] - 1) // self.stride[dim] + total_padding = max( + 0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size + ) + additional_padding = int(total_padding % 2 != 0) + + return additional_padding, total_padding + + def forward(self, input): + if self.padding == "VALID": + return F.conv2d( + input, + self.weight, + self.bias, + self.stride, + padding=0, + dilation=self.dilation, + groups=self.groups, + ) + rows_odd, padding_rows = self._compute_padding(input, dim=0) + cols_odd, padding_cols = self._compute_padding(input, dim=1) + if rows_odd or cols_odd: + input = F.pad(input, [0, cols_odd, 0, rows_odd]) + + return F.conv2d( + input, + self.weight, + self.bias, + self.stride, + padding=(padding_rows // 2, padding_cols // 2), + dilation=self.dilation, + groups=self.groups, + ) + + +class Conv1d_tf(nn.Conv1d): + """ + Conv1d with the padding behavior from TF + modified from 
https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py + """ + + def __init__(self, *args, **kwargs): + super(Conv1d_tf, self).__init__(*args, **kwargs) + self.padding = kwargs.get("padding") + + def _compute_padding(self, input, dim): + input_size = input.size(dim + 2) + filter_size = self.weight.size(dim + 2) + effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1 + out_size = (input_size + self.stride[dim] - 1) // self.stride[dim] + total_padding = max( + 0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size + ) + additional_padding = int(total_padding % 2 != 0) + + return additional_padding, total_padding + + def forward(self, input): + # if self.padding == "valid": + # return F.conv1d( + # input, + # self.weight, + # self.bias, + # self.stride, + # padding=0, + # dilation=self.dilation, + # groups=self.groups, + # ) + rows_odd, padding_rows = self._compute_padding(input, dim=0) + if rows_odd: + input = F.pad(input, [0, rows_odd]) + + return F.conv1d( + input, + self.weight, + self.bias, + self.stride, + padding=(padding_rows // 2), + dilation=self.dilation, + groups=self.groups, + ) + + +def ConvNormRelu(in_channels, out_channels, type='1d', downsample=False, k=None, s=None, padding='valid', groups=1, + nonlinear='lrelu', bn='bn'): + if k is None and s is None: + if not downsample: + k = 3 + s = 1 + padding = 'same' + else: + k = 4 + s = 2 + padding = 'valid' + + if type == '1d': + conv_block = Conv1d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding, groups=groups) + norm_block = nn.BatchNorm1d(out_channels) + elif type == '2d': + conv_block = Conv2d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding, groups=groups) + norm_block = nn.BatchNorm2d(out_channels) + else: + assert False + if bn != 'bn': + if bn == 'gn': + norm_block = nn.GroupNorm(1, out_channels) + elif bn == 'ln': + norm_block = nn.LayerNorm(out_channels) + else: + norm_block = nn.Identity() + if nonlinear == 'lrelu': + nlinear = nn.LeakyReLU(0.2, True) + elif nonlinear == 'tanh': + nlinear = nn.Tanh() + elif nonlinear == 'none': + nlinear = nn.Identity() + + return nn.Sequential( + conv_block, + norm_block, + nlinear + ) + + +class UnetUp(nn.Module): + def __init__(self, in_ch, out_ch): + super(UnetUp, self).__init__() + self.conv = ConvNormRelu(in_ch, out_ch) + + def forward(self, x1, x2): + # x1 = torch.repeat_interleave(x1, 2, dim=2) + # x1 = x1[:, :, :x2.shape[2]] + x1 = torch.nn.functional.interpolate(x1, size=x2.shape[2], mode='linear') + x = x1 + x2 + x = self.conv(x) + return x + + +class UNet(nn.Module): + def __init__(self, input_dim, dim): + super(UNet, self).__init__() + # dim = 512 + self.down1 = nn.Sequential( + ConvNormRelu(input_dim, input_dim, '1d', False), + ConvNormRelu(input_dim, dim, '1d', False), + ConvNormRelu(dim, dim, '1d', False) + ) + self.gru = nn.GRU(dim, dim, 1, batch_first=True) + self.down2 = ConvNormRelu(dim, dim, '1d', True) + self.down3 = ConvNormRelu(dim, dim, '1d', True) + self.down4 = ConvNormRelu(dim, dim, '1d', True) + self.down5 = ConvNormRelu(dim, dim, '1d', True) + self.down6 = ConvNormRelu(dim, dim, '1d', True) + self.up1 = UnetUp(dim, dim) + self.up2 = UnetUp(dim, dim) + self.up3 = UnetUp(dim, dim) + self.up4 = UnetUp(dim, dim) + self.up5 = UnetUp(dim, dim) + + def forward(self, x1, pre_pose=None, w_pre=False): + x2_0 = self.down1(x1) + if w_pre: + i = 1 + x2_pre = self.gru(x2_0[:,:,0:i].permute(0,2,1), 
pre_pose[:,:,-1:].permute(2,0,1).contiguous())[0].permute(0,2,1) + x2 = torch.cat([x2_pre, x2_0[:,:,i:]], dim=-1) + # x2 = torch.cat([pre_pose, x2_0], dim=2) # [B, 512, 15] + else: + # x2 = self.gru(x2_0.transpose(1, 2))[0].transpose(1,2) + x2 = x2_0 + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + x6 = self.down5(x5) + x7 = self.down6(x6) + x = self.up1(x7, x6) + x = self.up2(x, x5) + x = self.up3(x, x4) + x = self.up4(x, x3) + x = self.up5(x, x2) # [B, 512, 15] + return x, x2_0 + + +class AudioEncoder(nn.Module): + def __init__(self, n_frames, template_length, pose=False, common_dim=512): + super().__init__() + self.n_frames = n_frames + self.pose = pose + self.step = 0 + self.weight = 0 + if self.pose: + # self.first_net = nn.Sequential( + # ConvNormRelu(1, 64, '2d', False), + # ConvNormRelu(64, 64, '2d', True), + # ConvNormRelu(64, 128, '2d', False), + # ConvNormRelu(128, 128, '2d', True), + # ConvNormRelu(128, 256, '2d', False), + # ConvNormRelu(256, 256, '2d', True), + # ConvNormRelu(256, 256, '2d', False), + # ConvNormRelu(256, 256, '2d', False, padding='VALID') + # ) + # decoder_layer = nn.TransformerDecoderLayer(d_model=args.feature_dim, nhead=4, + # dim_feedforward=2 * args.feature_dim, batch_first=True) + # a = nn.TransformerDecoder + self.first_net = SeqTranslator1D(256, 256, + min_layers_num=4, + residual=True + ) + self.dropout_0 = nn.Dropout(0.1) + self.mu_fc = nn.Conv1d(256, 128, 1, 1) + self.var_fc = nn.Conv1d(256, 128, 1, 1) + self.trans_motion = SeqTranslator1D(common_dim, common_dim, + kernel_size=1, + stride=1, + min_layers_num=3, + residual=True + ) + # self.att = nn.MultiheadAttention(64 + template_length, 4, dropout=0.1) + self.unet = UNet(128 + template_length, common_dim) + + else: + self.first_net = SeqTranslator1D(256, 256, + min_layers_num=4, + residual=True + ) + self.dropout_0 = nn.Dropout(0.1) + # self.att = nn.MultiheadAttention(256, 4, dropout=0.1) + self.unet = UNet(256, 256) + self.dropout_1 = nn.Dropout(0.0) + + def forward(self, spectrogram, time_steps=None, template=None, pre_pose=None, w_pre=False): + self.step = self.step + 1 + if self.pose: + spect = spectrogram.transpose(1, 2) + if w_pre: + spect = spect[:, :, :] + + out = self.first_net(spect) + out = self.dropout_0(out) + + mu = self.mu_fc(out) + var = self.var_fc(out) + audio = self.__reparam(mu, var) + # audio = out + + # template = self.trans_motion(template) + x1 = torch.cat([audio, template], dim=1)#.permute(2,0,1) + # x1 = out + #x1, _ = self.att(x1, x1, x1) + #x1 = x1.permute(1,2,0) + x1, x2_0 = self.unet(x1, pre_pose=pre_pose, w_pre=w_pre) + else: + spectrogram = spectrogram.transpose(1, 2) + x1 = self.first_net(spectrogram)#.permute(2,0,1) + #out, _ = self.att(out, out, out) + #out = out.permute(1, 2, 0) + x1 = self.dropout_0(x1) + x1, x2_0 = self.unet(x1) + x1 = self.dropout_1(x1) + mu = None + var = None + + return x1, (mu, var), x2_0 + + def __reparam(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std, device='cuda') + z = eps * std + mu + return z + + +class Generator(nn.Module): + def __init__(self, + n_poses, + pose_dim, + pose, + n_pre_poses, + each_dim: list, + dim_list: list, + use_template=False, + template_length=0, + training=False, + device=None, + separate=False, + expression=False + ): + super().__init__() + + self.use_template = use_template + self.template_length = template_length + self.training = training + self.device = device + self.separate = separate + self.pose = pose + self.decoderf = True + self.expression = 
expression + + common_dim = 256 + + if self.use_template: + assert template_length > 0 + # self.KLLoss = KLLoss(kl_tolerance=self.config.Train.weights.kl_tolerance).to(self.device) + # self.pose_encoder = SeqEncoder1D( + # C_in=pose_dim, + # C_out=512, + # T_in=n_poses, + # min_layer_nums=6 + # + # ) + self.pose_encoder = SeqTranslator1D(pose_dim - 50, common_dim, + # kernel_size=1, + # stride=1, + min_layers_num=3, + residual=True + ) + self.mu_fc = nn.Conv1d(common_dim, template_length, kernel_size=1, stride=1) + self.var_fc = nn.Conv1d(common_dim, template_length, kernel_size=1, stride=1) + + else: + self.template_length = 0 + + self.gen_length = n_poses + + self.audio_encoder = AudioEncoder(n_poses, template_length, True, common_dim) + self.speech_encoder = AudioEncoder(n_poses, template_length, False) + + # self.pre_pose_encoder = SeqEncoder1D( + # C_in=pose_dim, + # C_out=128, + # T_in=15, + # min_layer_nums=3 + # + # ) + # self.pmu_fc = nn.Linear(128, 64) + # self.pvar_fc = nn.Linear(128, 64) + + self.pre_pose_encoder = SeqTranslator1D(pose_dim-50, common_dim, + min_layers_num=5, + residual=True + ) + self.decoder_in = 256 + 64 + self.dim_list = dim_list + + if self.separate: + self.decoder = nn.ModuleList() + self.final_out = nn.ModuleList() + + self.decoder.append(nn.Sequential( + ConvNormRelu(256, 64), + ConvNormRelu(64, 64), + ConvNormRelu(64, 64), + )) + self.final_out.append(nn.Conv1d(64, each_dim[0], 1, 1)) + + self.decoder.append(nn.Sequential( + ConvNormRelu(common_dim, common_dim), + ConvNormRelu(common_dim, common_dim), + ConvNormRelu(common_dim, common_dim), + )) + self.final_out.append(nn.Conv1d(common_dim, each_dim[1], 1, 1)) + + self.decoder.append(nn.Sequential( + ConvNormRelu(common_dim, common_dim), + ConvNormRelu(common_dim, common_dim), + ConvNormRelu(common_dim, common_dim), + )) + self.final_out.append(nn.Conv1d(common_dim, each_dim[2], 1, 1)) + + if self.expression: + self.decoder.append(nn.Sequential( + ConvNormRelu(256, 256), + ConvNormRelu(256, 256), + ConvNormRelu(256, 256), + )) + self.final_out.append(nn.Conv1d(256, each_dim[3], 1, 1)) + else: + self.decoder = nn.Sequential( + ConvNormRelu(self.decoder_in, 512), + ConvNormRelu(512, 512), + ConvNormRelu(512, 512), + ConvNormRelu(512, 512), + ConvNormRelu(512, 512), + ConvNormRelu(512, 512), + ) + self.final_out = nn.Conv1d(512, pose_dim, 1, 1) + + def __reparam(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std, device=self.device) + z = eps * std + mu + return z + + def forward(self, in_spec, pre_poses, gt_poses, template=None, time_steps=None, w_pre=False, norm=True): + if time_steps is not None: + self.gen_length = time_steps + + if self.use_template: + if self.training: + if w_pre: + in_spec = in_spec[:, 15:, :] + pre_pose = self.pre_pose_encoder(gt_poses[:, 14:15, :-50].permute(0, 2, 1)) + pose_enc = self.pose_encoder(gt_poses[:, 15:, :-50].permute(0, 2, 1)) + mu = self.mu_fc(pose_enc) + var = self.var_fc(pose_enc) + template = self.__reparam(mu, var) + else: + pre_pose = None + pose_enc = self.pose_encoder(gt_poses[:, :, :-50].permute(0, 2, 1)) + mu = self.mu_fc(pose_enc) + var = self.var_fc(pose_enc) + template = self.__reparam(mu, var) + elif pre_poses is not None: + if w_pre: + pre_pose = pre_poses[:, -1:, :-50] + if norm: + pre_pose = pre_pose.reshape(1, -1, 55, 5) + pre_pose = torch.cat([F.normalize(pre_pose[..., :3], dim=-1), + F.normalize(pre_pose[..., 3:5], dim=-1)], + dim=-1).reshape(1, -1, 275) + pre_pose = self.pre_pose_encoder(pre_pose.permute(0, 2, 1)) + 
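+                    # inference: the gesture template is sampled from a standard normal rather than encoded from ground-truth poses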
template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length ]).to( + in_spec.device) + else: + pre_pose = None + template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length]).to(in_spec.device) + elif gt_poses is not None: + template = self.pre_pose_encoder(gt_poses[:, :, :-50].permute(0, 2, 1)) + elif template is None: + pre_pose = None + template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length]).to(in_spec.device) + else: + template = None + mu = None + var = None + + a_t_f, (mu2, var2), x2_0 = self.audio_encoder(in_spec, time_steps=time_steps, template=template, pre_pose=pre_pose, w_pre=w_pre) + s_f, _, _ = self.speech_encoder(in_spec, time_steps=time_steps) + + out = [] + + if self.separate: + for i in range(self.decoder.__len__()): + if i == 0 or i == 3: + mid = self.decoder[i](s_f) + else: + mid = self.decoder[i](a_t_f) + mid = self.final_out[i](mid) + out.append(mid) + out = torch.cat(out, dim=1) + + else: + out = self.decoder(a_t_f) + out = self.final_out(out) + + out = out.transpose(1, 2) + + if self.training: + if w_pre: + return out, template, mu, var, (mu2, var2, x2_0, pre_pose) + else: + return out, template, mu, var, (mu2, var2, None, None) + else: + return out + + +class Discriminator(nn.Module): + def __init__(self, pose_dim, pose): + super().__init__() + self.net = nn.Sequential( + Conv1d_tf(pose_dim, 64, kernel_size=4, stride=2, padding='SAME'), + nn.LeakyReLU(0.2, True), + ConvNormRelu(64, 128, '1d', True), + ConvNormRelu(128, 256, '1d', k=4, s=1), + Conv1d_tf(256, 1, kernel_size=4, stride=1, padding='SAME'), + ) + + def forward(self, x): + x = x.transpose(1, 2) + + out = self.net(x) + return out + + +def main(): + d = Discriminator(275, 55) + x = torch.randn([8, 60, 275]) + result = d(x) + + +if __name__ == "__main__": + main() diff --git a/nets/spg/vqvae_1d.py b/nets/spg/vqvae_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..0cd15bd6439b949bf89098af274b3e7ccac9b5f5 --- /dev/null +++ b/nets/spg/vqvae_1d.py @@ -0,0 +1,235 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from .wav2vec import Wav2Vec2Model +from .vqvae_modules import VectorQuantizerEMA, ConvNormRelu, Res_CNR_Stack + + + +class AudioEncoder(nn.Module): + def __init__(self, in_dim, num_hiddens, num_residual_layers, num_residual_hiddens): + super(AudioEncoder, self).__init__() + self._num_hiddens = num_hiddens + self._num_residual_layers = num_residual_layers + self._num_residual_hiddens = num_residual_hiddens + + self.project = ConvNormRelu(in_dim, self._num_hiddens // 4, leaky=True) + + self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True) + self._down_1 = ConvNormRelu(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, residual=True, + sample='down') + self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True) + self._down_2 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens, leaky=True, residual=True, sample='down') + self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True) + + def forward(self, x, frame_num=0): + h = self.project(x) + h = self._enc_1(h) + h = self._down_1(h) + h = self._enc_2(h) + h = self._down_2(h) + h = self._enc_3(h) + return h + + +class Wav2VecEncoder(nn.Module): + def __init__(self, num_hiddens, num_residual_layers): + super(Wav2VecEncoder, self).__init__() + self._num_hiddens = num_hiddens + self._num_residual_layers = 
num_residual_layers + + self.audio_encoder = Wav2Vec2Model.from_pretrained( + "facebook/wav2vec2-base-960h") # "vitouphy/wav2vec2-xls-r-300m-phoneme""facebook/wav2vec2-base-960h" + self.audio_encoder.feature_extractor._freeze_parameters() + + self.project = ConvNormRelu(768, self._num_hiddens, leaky=True) + + self._enc_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True) + self._down_1 = ConvNormRelu(self._num_hiddens, self._num_hiddens, leaky=True, residual=True, sample='down') + self._enc_2 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True) + self._down_2 = ConvNormRelu(self._num_hiddens, self._num_hiddens, leaky=True, residual=True, sample='down') + self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True) + + def forward(self, x, frame_num): + h = self.audio_encoder(x.squeeze(), frame_num=frame_num).last_hidden_state.transpose(1, 2) + h = self.project(h) + h = self._enc_1(h) + h = self._down_1(h) + h = self._enc_2(h) + h = self._down_2(h) + h = self._enc_3(h) + return h + + +class Encoder(nn.Module): + def __init__(self, in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens): + super(Encoder, self).__init__() + self._num_hiddens = num_hiddens + self._num_residual_layers = num_residual_layers + self._num_residual_hiddens = num_residual_hiddens + + self.project = ConvNormRelu(in_dim, self._num_hiddens // 4, leaky=True) + + self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True) + self._down_1 = ConvNormRelu(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, residual=True, + sample='down') + self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True) + self._down_2 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens, leaky=True, residual=True, sample='down') + self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True) + + self.pre_vq_conv = nn.Conv1d(self._num_hiddens, embedding_dim, 1, 1) + + def forward(self, x): + h = self.project(x) + h = self._enc_1(h) + h = self._down_1(h) + h = self._enc_2(h) + h = self._down_2(h) + h = self._enc_3(h) + h = self.pre_vq_conv(h) + return h + + +class Frame_Enc(nn.Module): + def __init__(self, in_dim, num_hiddens): + super(Frame_Enc, self).__init__() + self.in_dim = in_dim + self.num_hiddens = num_hiddens + + # self.enc = transformer_Enc(in_dim, num_hiddens, 2, 8, 256, 256, 256, 256, 0, dropout=0.1, n_position=4) + self.proj = nn.Conv1d(in_dim, num_hiddens, 1, 1) + self.enc = Res_CNR_Stack(num_hiddens, 2, leaky=True) + self.proj_1 = nn.Conv1d(256*4, num_hiddens, 1, 1) + self.proj_2 = nn.Conv1d(256*4, num_hiddens*2, 1, 1) + + def forward(self, x): + # x = self.enc(x, None)[0].reshape(x.shape[0], -1, 1) + x = self.enc(self.proj(x)).reshape(x.shape[0], -1, 1) + second_last = self.proj_2(x) + last = self.proj_1(x) + return second_last, last + + + +class Decoder(nn.Module): + def __init__(self, out_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens, ae=False): + super(Decoder, self).__init__() + self._num_hiddens = num_hiddens + self._num_residual_layers = num_residual_layers + self._num_residual_hiddens = num_residual_hiddens + + self.aft_vq_conv = nn.Conv1d(embedding_dim, self._num_hiddens, 1, 1) + + self._dec_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True) + self._up_2 = ConvNormRelu(self._num_hiddens, self._num_hiddens // 2, leaky=True, residual=True, sample='up') + self._dec_2 = 
Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True) + self._up_3 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens // 4, leaky=True, residual=True, + sample='up') + self._dec_3 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True) + + if ae: + self.frame_enc = Frame_Enc(out_dim, self._num_hiddens // 4) + self.gru_sl = nn.GRU(self._num_hiddens // 2, self._num_hiddens // 2, 1, batch_first=True) + self.gru_l = nn.GRU(self._num_hiddens // 4, self._num_hiddens // 4, 1, batch_first=True) + + self.project = nn.Conv1d(self._num_hiddens // 4, out_dim, 1, 1) + + def forward(self, h, last_frame=None): + + h = self.aft_vq_conv(h) + h = self._dec_1(h) + h = self._up_2(h) + h = self._dec_2(h) + h = self._up_3(h) + h = self._dec_3(h) + + recon = self.project(h) + return recon, None + + +class Pre_VQ(nn.Module): + def __init__(self, num_hiddens, embedding_dim, num_chunks): + super(Pre_VQ, self).__init__() + self.conv = nn.Conv1d(num_hiddens, num_hiddens, 1, 1, 0, groups=num_chunks) + self.bn = nn.GroupNorm(num_chunks, num_hiddens) + self.relu = nn.ReLU() + self.proj = nn.Conv1d(num_hiddens, embedding_dim, 1, 1, 0, groups=num_chunks) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + x = self.proj(x) + return x + + +class VQVAE(nn.Module): + """VQ-VAE""" + + def __init__(self, in_dim, embedding_dim, num_embeddings, + num_hiddens, num_residual_layers, num_residual_hiddens, + commitment_cost=0.25, decay=0.99, share=False): + super().__init__() + self.in_dim = in_dim + self.embedding_dim = embedding_dim + self.num_embeddings = num_embeddings + self.share_code_vq = share + + self.encoder = Encoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens) + self.vq_layer = VectorQuantizerEMA(embedding_dim, num_embeddings, commitment_cost, decay) + self.decoder = Decoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens) + + def forward(self, gt_poses, id=None, pre_state=None): + z = self.encoder(gt_poses.transpose(1, 2)) + if not self.training: + e, _ = self.vq_layer(z) + x_recon, cur_state = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None) + return e, x_recon + + e, e_q_loss = self.vq_layer(z) + gt_recon, cur_state = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None) + + return e_q_loss, gt_recon.transpose(1, 2) + + def encode(self, gt_poses, id=None): + z = self.encoder(gt_poses.transpose(1, 2)) + e, latents = self.vq_layer(z) + return e, latents + + def decode(self, b, w, e=None, latents=None, pre_state=None): + if e is not None: + x = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None) + else: + e = self.vq_layer.quantize(latents) + e = e.view(b, w, -1).permute(0, 2, 1).contiguous() + x = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None) + return x + + +class AE(nn.Module): + """VQ-VAE""" + + def __init__(self, in_dim, embedding_dim, num_embeddings, + num_hiddens, num_residual_layers, num_residual_hiddens): + super().__init__() + self.in_dim = in_dim + self.embedding_dim = embedding_dim + self.num_embeddings = num_embeddings + + self.encoder = Encoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens) + self.decoder = Decoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens, True) + + def forward(self, gt_poses, id=None, pre_state=None): + z = self.encoder(gt_poses.transpose(1, 2)) + if not self.training: + 
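+            # eval path: decode straight from the continuous latent z (the plain AE has no quantizer)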
x_recon, cur_state = self.decoder(z, pre_state.transpose(1, 2) if pre_state is not None else None) + return z, x_recon + gt_recon, cur_state = self.decoder(z, pre_state.transpose(1, 2) if pre_state is not None else None) + + return gt_recon.transpose(1, 2) + + def encode(self, gt_poses, id=None): + z = self.encoder(gt_poses.transpose(1, 2)) + return z diff --git a/nets/spg/vqvae_modules.py b/nets/spg/vqvae_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..5c83bc0399bc3bc034881407ed49d223d8c86ba9 --- /dev/null +++ b/nets/spg/vqvae_modules.py @@ -0,0 +1,380 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import datasets, transforms +import matplotlib.pyplot as plt + + + + +class CasualCT(nn.Module): + def __init__(self, + in_channels, + out_channels, + leaky=False, + p=0, + groups=1, ): + ''' + conv-bn-relu + ''' + super(CasualCT, self).__init__() + padding = 0 + kernel_size = 2 + stride = 2 + in_channels = in_channels * groups + out_channels = out_channels * groups + + self.conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + groups=groups) + self.norm = nn.BatchNorm1d(out_channels) + self.dropout = nn.Dropout(p=p) + if leaky: + self.relu = nn.LeakyReLU(negative_slope=0.2) + else: + self.relu = nn.ReLU() + + def forward(self, x, **kwargs): + out = self.norm(self.dropout(self.conv(x))) + return self.relu(out) + + +class CasualConv(nn.Module): + def __init__(self, + in_channels, + out_channels, + leaky=False, + p=0, + groups=1, + downsample=False): + ''' + conv-bn-relu + ''' + super(CasualConv, self).__init__() + padding = 0 + kernel_size = 2 + stride = 1 + self.downsample = downsample + if self.downsample: + kernel_size = 2 + stride = 2 + + in_channels = in_channels * groups + out_channels = out_channels * groups + self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + groups=groups) + self.norm = nn.BatchNorm1d(out_channels) + self.dropout = nn.Dropout(p=p) + if leaky: + self.relu = nn.LeakyReLU(negative_slope=0.2) + else: + self.relu = nn.ReLU() + + def forward(self, x, pre_state=None): + if not self.downsample: + if pre_state is not None: + x = torch.cat([pre_state, x], dim=-1) + else: + zeros = torch.zeros([x.shape[0], x.shape[1], 1], device=x.device) + x = torch.cat([zeros, x], dim=-1) + out = self.norm(self.dropout(self.conv(x))) + return self.relu(out) + + +class ConvNormRelu(nn.Module): + ''' + (B,C_in,H,W) -> (B, C_out, H, W) + there exist some kernel size that makes the result is not H/s + #TODO: there might some problems with residual + ''' + + def __init__(self, + in_channels, + out_channels, + leaky=False, + sample='none', + p=0, + groups=1, + residual=False, + norm='bn'): + ''' + conv-bn-relu + ''' + super(ConvNormRelu, self).__init__() + self.residual = residual + self.norm_type = norm + padding = 1 + + if sample == 'none': + kernel_size = 3 + stride = 1 + elif sample == 'one': + padding = 0 + kernel_size = stride = 1 + else: + kernel_size = 4 + stride = 2 + + if self.residual: + if sample == 'down': + self.residual_layer = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding) + elif sample == 'up': + self.residual_layer = nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + 
stride=stride, + padding=padding) + else: + if in_channels == out_channels: + self.residual_layer = nn.Identity() + else: + self.residual_layer = nn.Sequential( + nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding + ) + ) + + in_channels = in_channels * groups + out_channels = out_channels * groups + if sample == 'up': + self.conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + groups=groups) + else: + self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + groups=groups) + self.norm = nn.BatchNorm1d(out_channels) + self.dropout = nn.Dropout(p=p) + if leaky: + self.relu = nn.LeakyReLU(negative_slope=0.2) + else: + self.relu = nn.ReLU() + + def forward(self, x, **kwargs): + out = self.norm(self.dropout(self.conv(x))) + if self.residual: + residual = self.residual_layer(x) + out += residual + return self.relu(out) + + +class Res_CNR_Stack(nn.Module): + def __init__(self, + channels, + layers, + sample='none', + leaky=False, + casual=False, + ): + super(Res_CNR_Stack, self).__init__() + + if casual: + kernal_size = 1 + padding = 0 + conv = CasualConv + else: + kernal_size = 3 + padding = 1 + conv = ConvNormRelu + + if sample == 'one': + kernal_size = 1 + padding = 0 + + self._layers = nn.ModuleList() + for i in range(layers): + self._layers.append(conv(channels, channels, leaky=leaky, sample=sample)) + self.conv = nn.Conv1d(channels, channels, kernal_size, 1, padding) + self.norm = nn.BatchNorm1d(channels) + self.relu = nn.ReLU() + + def forward(self, x, pre_state=None): + # cur_state = [] + h = x + for i in range(self._layers.__len__()): + # cur_state.append(h[..., -1:]) + h = self._layers[i](h, pre_state=pre_state[i] if pre_state is not None else None) + h = self.norm(self.conv(h)) + return self.relu(h + x) + + +class ExponentialMovingAverage(nn.Module): + """Maintains an exponential moving average for a value. + + This module keeps track of a hidden exponential moving average that is + initialized as a vector of zeros which is then normalized to give the average. + This gives us a moving average which isn't biased towards either zero or the + initial value. Reference (https://arxiv.org/pdf/1412.6980.pdf) + + Initially: + hidden_0 = 0 + Then iteratively: + hidden_i = hidden_{i-1} - (hidden_{i-1} - value) * (1 - decay) + average_i = hidden_i / (1 - decay^i) + """ + + def __init__(self, init_value, decay): + super().__init__() + + self.decay = decay + self.counter = 0 + self.register_buffer("hidden", torch.zeros_like(init_value)) + + def forward(self, value): + self.counter += 1 + self.hidden.sub_((self.hidden - value) * (1 - self.decay)) + average = self.hidden / (1 - self.decay ** self.counter) + return average + + +class VectorQuantizerEMA(nn.Module): + """ + VQ-VAE layer: Input any tensor to be quantized. Use EMA to update embeddings. + Args: + embedding_dim (int): the dimensionality of the tensors in the + quantized space. Inputs to the modules must be in this format as well. + num_embeddings (int): the number of vectors in the quantized space. + commitment_cost (float): scalar which controls the weighting of the loss terms (see + equation 4 in the paper - this variable is Beta). + decay (float): decay for the moving averages. + epsilon (float): small float constant to avoid numerical instability. 
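+        Note: the codebook vectors are updated with an exponential moving average of the
+        encoder outputs assigned to each code, so the training loss returned by forward()
+        only contains the commitment term scaled by commitment_cost.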
+ """ + + def __init__(self, embedding_dim, num_embeddings, commitment_cost, decay, + epsilon=1e-5): + super().__init__() + self.embedding_dim = embedding_dim + self.num_embeddings = num_embeddings + self.commitment_cost = commitment_cost + self.epsilon = epsilon + + # initialize embeddings as buffers + embeddings = torch.empty(self.num_embeddings, self.embedding_dim) + nn.init.xavier_uniform_(embeddings) + self.register_buffer("embeddings", embeddings) + self.ema_dw = ExponentialMovingAverage(self.embeddings, decay) + + # also maintain ema_cluster_size, which record the size of each embedding + self.ema_cluster_size = ExponentialMovingAverage(torch.zeros((self.num_embeddings,)), decay) + + def forward(self, x): + # [B, C, H, W] -> [B, H, W, C] + x = x.permute(0, 2, 1).contiguous() + # [B, H, W, C] -> [BHW, C] + flat_x = x.reshape(-1, self.embedding_dim) + + encoding_indices = self.get_code_indices(flat_x) + quantized = self.quantize(encoding_indices) + quantized = quantized.view_as(x) # [B, W, C] + + if not self.training: + quantized = quantized.permute(0, 2, 1).contiguous() + return quantized, encoding_indices.view(quantized.shape[0], quantized.shape[2]) + + # update embeddings with EMA + with torch.no_grad(): + encodings = F.one_hot(encoding_indices, self.num_embeddings).float() + updated_ema_cluster_size = self.ema_cluster_size(torch.sum(encodings, dim=0)) + n = torch.sum(updated_ema_cluster_size) + updated_ema_cluster_size = ((updated_ema_cluster_size + self.epsilon) / + (n + self.num_embeddings * self.epsilon) * n) + dw = torch.matmul(encodings.t(), flat_x) # sum encoding vectors of each cluster + updated_ema_dw = self.ema_dw(dw) + normalised_updated_ema_w = ( + updated_ema_dw / updated_ema_cluster_size.reshape(-1, 1)) + self.embeddings.data = normalised_updated_ema_w + + # commitment loss + e_latent_loss = F.mse_loss(x, quantized.detach()) + loss = self.commitment_cost * e_latent_loss + + # Straight Through Estimator + quantized = x + (quantized - x).detach() + + quantized = quantized.permute(0, 2, 1).contiguous() + return quantized, loss + + def get_code_indices(self, flat_x): + # compute L2 distance + distances = ( + torch.sum(flat_x ** 2, dim=1, keepdim=True) + + torch.sum(self.embeddings ** 2, dim=1) - + 2. 
* torch.matmul(flat_x, self.embeddings.t()) + ) # [N, M] + encoding_indices = torch.argmin(distances, dim=1) # [N,] + return encoding_indices + + def quantize(self, encoding_indices): + """Returns embedding tensor for a batch of indices.""" + return F.embedding(encoding_indices, self.embeddings) + + + +class Casual_Encoder(nn.Module): + def __init__(self, in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens): + super(Casual_Encoder, self).__init__() + self._num_hiddens = num_hiddens + self._num_residual_layers = num_residual_layers + self._num_residual_hiddens = num_residual_hiddens + + self.project = nn.Conv1d(in_dim, self._num_hiddens // 4, 1, 1) + self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True, casual=True) + self._down_1 = CasualConv(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, downsample=True) + self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True, casual=True) + self._down_2 = CasualConv(self._num_hiddens // 2, self._num_hiddens, leaky=True, downsample=True) + self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True, casual=True) + # self.pre_vq_conv = nn.Conv1d(self._num_hiddens, embedding_dim, 1, 1) + + def forward(self, x): + h = self.project(x) + h, _ = self._enc_1(h) + h = self._down_1(h) + h, _ = self._enc_2(h) + h = self._down_2(h) + h, _ = self._enc_3(h) + # h = self.pre_vq_conv(h) + return h + + +class Casual_Decoder(nn.Module): + def __init__(self, out_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens): + super(Casual_Decoder, self).__init__() + self._num_hiddens = num_hiddens + self._num_residual_layers = num_residual_layers + self._num_residual_hiddens = num_residual_hiddens + + # self.aft_vq_conv = nn.Conv1d(embedding_dim, self._num_hiddens, 1, 1) + self._dec_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True, casual=True) + self._up_2 = CasualCT(self._num_hiddens, self._num_hiddens // 2, leaky=True) + self._dec_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True, casual=True) + self._up_3 = CasualCT(self._num_hiddens // 2, self._num_hiddens // 4, leaky=True) + self._dec_3 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True, casual=True) + self.project = nn.Conv1d(self._num_hiddens//4, out_dim, 1, 1) + + def forward(self, h, pre_state=None): + cur_state = [] + # h = self.aft_vq_conv(x) + h, s = self._dec_1(h, pre_state[0] if pre_state is not None else None) + cur_state.append(s) + h = self._up_2(h) + h, s = self._dec_2(h, pre_state[1] if pre_state is not None else None) + cur_state.append(s) + h = self._up_3(h) + h, s = self._dec_3(h, pre_state[2] if pre_state is not None else None) + cur_state.append(s) + recon = self.project(h) + return recon, cur_state \ No newline at end of file diff --git a/nets/spg/wav2vec.py b/nets/spg/wav2vec.py new file mode 100644 index 0000000000000000000000000000000000000000..a5d0eff66e67de14ceba283fa6ce43f156c7ddc2 --- /dev/null +++ b/nets/spg/wav2vec.py @@ -0,0 +1,143 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import copy +import math +from transformers import Wav2Vec2Model,Wav2Vec2Config +from transformers.modeling_outputs import BaseModelOutput +from typing import Optional, Tuple +_CONFIG_FOR_DOC = "Wav2Vec2Config" + +# the implementation of Wav2Vec2Model is borrowed from 
https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model +# initialize our encoder with the pre-trained wav2vec 2.0 weights. +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.Tensor] = None, + min_masks: int = 0, +) -> np.ndarray: + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + all_num_mask = max(min_masks, all_num_mask) + mask_idcs = [] + padding_mask = attention_mask.ne(1) if attention_mask is not None else None + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + lengths = np.full(num_mask, mask_length) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]) + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + return mask + +# linear interpolation layer +def linear_interpolation(features, input_fps, output_fps, output_len=None): + features = features.transpose(1, 2) + seq_len = features.shape[2] / float(input_fps) + if output_len is None: + output_len = int(seq_len * output_fps) + output_features = F.interpolate(features,size=output_len,align_corners=False,mode='linear') + return output_features.transpose(1, 2) + + +class Wav2Vec2Model(Wav2Vec2Model): + def __init__(self, config): + super().__init__(config) + def forward( + self, + input_values, + attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + frame_num=None + ): + self.config.output_attentions = True + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.feature_extractor(input_values) + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = linear_interpolation(hidden_states, 50, 30,output_len=frame_num) + + if attention_mask is not None: + output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)) + attention_mask = torch.zeros( + hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device + ) + attention_mask[ + (torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1) + ] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + + hidden_states = self.feature_projection(hidden_states) + + if self.config.apply_spec_augment and self.training: + batch_size, sequence_length, hidden_size = hidden_states.size() + if self.config.mask_time_prob > 0: + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + self.config.mask_time_prob, + 
self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=2, + ) + hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype) + if self.config.mask_feature_prob > 0: + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + self.config.mask_feature_prob, + self.config.mask_feature_length, + ) + mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device) + hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 + encoder_outputs = self.encoder( + hidden_states[0], + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = encoder_outputs[0] + if not return_dict: + return (hidden_states,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) diff --git a/nets/utils.py b/nets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..41b03374e8023016b3bec4f66ab16cb421222ef1 --- /dev/null +++ b/nets/utils.py @@ -0,0 +1,122 @@ +import json +import textgrid as tg +import numpy as np + +def get_parameter_size(model): + total_num = sum(p.numel() for p in model.parameters()) + trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad) + return total_num, trainable_num + +def denormalize(kps, data_mean, data_std): + ''' + kps: (B, T, C) + ''' + data_std = data_std.reshape(1, 1, -1) + data_mean = data_mean.reshape(1, 1, -1) + return (kps * data_std) + data_mean + +def normalize(kps, data_mean, data_std): + ''' + kps: (B, T, C) + ''' + data_std = data_std.squeeze().reshape(1, 1, -1) + data_mean = data_mean.squeeze().reshape(1, 1, -1) + + return (kps-data_mean) / data_std + +def parse_audio(textgrid_file): + '''a demo implementation''' + words=['but', 'as', 'to', 'that', 'with', 'of', 'the', 'and', 'or', 'not', 'which', 'what', 'this', 'for', 'because', 'if', 'so', 'just', 'about', 'like', 'by', 'how', 'from', 'whats', 'now', 'very', 'that', 'also', 'actually', 'who', 'then', 'well', 'where', 'even', 'today', 'between', 'than', 'when'] + txt=tg.TextGrid.fromFile(textgrid_file) + + total_time=int(np.ceil(txt.maxTime)) + code_seq=np.zeros(total_time) + + word_level=txt[0] + + for i in range(len(word_level)): + start_time=word_level[i].minTime + end_time=word_level[i].maxTime + mark=word_level[i].mark + + if mark in words: + start=int(np.round(start_time)) + end=int(np.round(end_time)) + + if start >= len(code_seq) or end >= len(code_seq): + code_seq[-1] = 1 + else: + code_seq[start]=1 + + return code_seq + + +def get_path(model_name, model_type): + if model_name == 's2g_body_pixel': + if model_type == 'mfcc': + return './experiments/2022-10-09-smplx_S2G-body-pixel-aud-3p/ckpt-99.pth' + elif model_type == 'wv2': + return './experiments/2022-10-28-smplx_S2G-body-pixel-wv2-sg2/ckpt-99.pth' + elif model_type == 'random': + return './experiments/2022-10-09-smplx_S2G-body-pixel-random-3p/ckpt-99.pth' + elif model_type == 'wbhmodel': + return './experiments/2022-11-02-smplx_S2G-body-pixel-w-bhmodel/ckpt-99.pth' + elif model_type == 'wobhmodel': + return './experiments/2022-11-02-smplx_S2G-body-pixel-wo-bhmodel/ckpt-99.pth' + elif model_name == 's2g_body': + if model_type == 'a+m-vae': + return './experiments/2022-10-19-smplx_S2G-body-audio-motion-vae/ckpt-99.pth' + elif model_type == 'a-vae': + return 
'./experiments/2022-10-18-smplx_S2G-body-audiovae/ckpt-99.pth' + elif model_type == 'a-ed': + return './experiments/2022-10-18-smplx_S2G-body-audioae/ckpt-99.pth' + elif model_name == 's2g_LS3DCG': + return './experiments/2022-10-19-smplx_S2G-LS3DCG/ckpt-99.pth' + elif model_name == 's2g_body_vq': + if model_type == 'n_com_1024': + return './experiments/2022-10-29-smplx_S2G-body-vq-cn1024/ckpt-99.pth' + elif model_type == 'n_com_2048': + return './experiments/2022-10-29-smplx_S2G-body-vq-cn2048/ckpt-99.pth' + elif model_type == 'n_com_4096': + return './experiments/2022-10-29-smplx_S2G-body-vq-cn4096/ckpt-99.pth' + elif model_type == 'n_com_8192': + return './experiments/2022-11-02-smplx_S2G-body-vq-cn8192/ckpt-99.pth' + elif model_type == 'n_com_16384': + return './experiments/2022-11-02-smplx_S2G-body-vq-cn16384/ckpt-99.pth' + elif model_type == 'n_com_170000': + return './experiments/2022-10-30-smplx_S2G-body-vq-cn170000/ckpt-99.pth' + elif model_type == 'com_1024': + return './experiments/2022-10-29-smplx_S2G-body-vq-composition/ckpt-99.pth' + elif model_type == 'com_2048': + return './experiments/2022-10-31-smplx_S2G-body-vq-composition2048/ckpt-99.pth' + elif model_type == 'com_4096': + return './experiments/2022-10-31-smplx_S2G-body-vq-composition4096/ckpt-99.pth' + elif model_type == 'com_8192': + return './experiments/2022-11-02-smplx_S2G-body-vq-composition8192/ckpt-99.pth' + elif model_type == 'com_16384': + return './experiments/2022-11-02-smplx_S2G-body-vq-composition16384/ckpt-99.pth' + + +def get_dpath(model_name, model_type): + if model_name == 's2g_body_pixel': + if model_type == 'audio': + return './experiments/2022-10-26-smplx_S2G-d-pixel-aud/ckpt-9.pth' + elif model_type == 'wv2': + return './experiments/2022-11-04-smplx_S2G-d-pixel-wv2/ckpt-9.pth' + elif model_type == 'random': + return './experiments/2022-10-26-smplx_S2G-d-pixel-random/ckpt-9.pth' + elif model_type == 'wbhmodel': + return './experiments/2022-11-10-smplx_S2G-hD-wbhmodel/ckpt-9.pth' + # return './experiments/2022-11-05-smplx_S2G-d-pixel-wbhmodel/ckpt-9.pth' + elif model_type == 'wobhmodel': + return './experiments/2022-11-10-smplx_S2G-hD-wobhmodel/ckpt-9.pth' + # return './experiments/2022-11-05-smplx_S2G-d-pixel-wobhmodel/ckpt-9.pth' + elif model_name == 's2g_body': + if model_type == 'a+m-vae': + return './experiments/2022-10-26-smplx_S2G-d-audio+motion-vae/ckpt-9.pth' + elif model_type == 'a-vae': + return './experiments/2022-10-26-smplx_S2G-d-audio-vae/ckpt-9.pth' + elif model_type == 'a-ed': + return './experiments/2022-10-26-smplx_S2G-d-audio-ae/ckpt-9.pth' + elif model_name == 's2g_LS3DCG': + return './experiments/2022-10-26-smplx_S2G-d-ls3dcg/ckpt-9.pth' \ No newline at end of file diff --git a/scripts/.idea/__init__.py b/scripts/.idea/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/.idea/aws.xml b/scripts/.idea/aws.xml new file mode 100644 index 0000000000000000000000000000000000000000..b63b642cfb4254fc0f7058903abc5b481895c4ef --- /dev/null +++ b/scripts/.idea/aws.xml @@ -0,0 +1,11 @@ + + + + + + \ No newline at end of file diff --git a/scripts/.idea/deployment.xml b/scripts/.idea/deployment.xml new file mode 100644 index 0000000000000000000000000000000000000000..14f2c41a46d0210c6395ed0e7bfd3b630211f699 --- /dev/null +++ b/scripts/.idea/deployment.xml @@ -0,0 +1,70 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + \ No newline at end of file diff --git a/scripts/.idea/get_prevar.py b/scripts/.idea/get_prevar.py new file mode 100644 index 0000000000000000000000000000000000000000..d4b2dfb1892e9ff79c8074f35e84d897c64ff673 --- /dev/null +++ b/scripts/.idea/get_prevar.py @@ -0,0 +1,132 @@ +import os +import sys +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + +sys.path.append(os.getcwd()) +from glob import glob + +import numpy as np +import json +import smplx as smpl + +from nets import * +from repro_nets import * +from trainer.options import parse_args +from data_utils import torch_data +from trainer.config import load_JsonConfig + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils import data + +def init_model(model_name, model_path, args, config): + if model_name == 'freeMo': + # generator = freeMo_Generator(args) + # generator = freeMo_Generator(args) + generator = freeMo_dev(args, config) + # generator.load_state_dict(torch.load(model_path)['generator']) + elif model_name == 'smplx_S2G': + generator = smplx_S2G(args, config) + elif model_name == 'StyleGestures': + generator = StyleGesture_Generator( + args, + config + ) + elif model_name == 'Audio2Gestures': + config.Train.using_mspec_stat = False + generator = Audio2Gesture_Generator( + args, + config, + torch.zeros([1, 1, 108]), + torch.ones([1, 1, 108]) + ) + elif model_name == 'S2G': + generator = S2G_Generator( + args, + config, + ) + elif model_name == 'Tmpt': + generator = S2G_Generator( + args, + config, + ) + else: + raise NotImplementedError + + model_ckpt = torch.load(model_path, map_location=torch.device('cpu')) + if model_name == 'smplx_S2G': + generator.generator.load_state_dict(model_ckpt['generator']['generator']) + elif 'generator' in list(model_ckpt.keys()): + generator.load_state_dict(model_ckpt['generator']) + else: + model_ckpt = {'generator': model_ckpt} + generator.load_state_dict(model_ckpt) + + return generator + + + +def prevar_loader(data_root, speakers, args, config, model_path, device, generator): + path = model_path.split('ckpt')[0] + file = os.path.join(os.path.dirname(path), "pre_variable.npy") + data_base = torch_data( + data_root=data_root, + speakers=speakers, + split='pre', + limbscaling=False, + normalization=config.Data.pose.normalization, + norm_method=config.Data.pose.norm_method, + split_trans_zero=False, + num_pre_frames=config.Data.pose.pre_pose_length, + num_generate_length=config.Data.pose.generate_length, + num_frames=15, + aud_feat_win_size=config.Data.aud.aud_feat_win_size, + aud_feat_dim=config.Data.aud.aud_feat_dim, + feat_method=config.Data.aud.feat_method, + smplx=True, + audio_sr=22000, + convert_to_6d=config.Data.pose.convert_to_6d, + expression=config.Data.pose.expression + ) + + data_base.get_dataset() + pre_set = data_base.all_dataset + pre_loader = data.DataLoader(pre_set, batch_size=config.DataLoader.batch_size, shuffle=False, drop_last=True) + + total_pose = [] + + with torch.no_grad(): + for bat in pre_loader: + pose = bat['poses'].to(device).to(torch.float32) + expression = bat['expression'].to(device).to(torch.float32) + pose = pose.permute(0, 2, 1) + pose = torch.cat([pose[:, :15], pose[:, 15:30], pose[:, 30:45], pose[:, 45:60], pose[:, 60:]], dim=0) + expression = expression.permute(0, 2, 1) + expression = torch.cat([expression[:, :15], expression[:, 15:30], expression[:, 30:45], expression[:, 45:60], expression[:, 60:]], dim=0) + pose = torch.cat([pose, expression], dim=-1) + pose = pose.reshape(pose.shape[0], -1, 1) + pose_code = 
generator.generator.pre_pose_encoder(pose).squeeze().detach().cpu() + total_pose.append(np.asarray(pose_code)) + total_pose = np.concatenate(total_pose, axis=0) + mean = np.mean(total_pose, axis=0) + std = np.std(total_pose, axis=0) + prevar = (mean, std) + np.save(file, prevar, allow_pickle=True) + + return mean, std + +def main(): + parser = parse_args() + args = parser.parse_args() + device = torch.device(args.gpu) + torch.cuda.set_device(device) + + config = load_JsonConfig(args.config_file) + + print('init model...') + generator = init_model(config.Model.model_name, args.model_path, args, config) + print('init pre-pose vectors...') + mean, std = prevar_loader(config.Data.data_root, args.speakers, args, config, args.model_path, device, generator) + +main() \ No newline at end of file diff --git a/scripts/.idea/inspectionProfiles/Project_Default.xml b/scripts/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000000000000000000000000000000000000..3dce9c67a3cba33789d113124d53150ccca2370b --- /dev/null +++ b/scripts/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,12 @@ + + + + \ No newline at end of file diff --git a/scripts/.idea/inspectionProfiles/profiles_settings.xml b/scripts/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000000000000000000000000000000000000..105ce2da2d6447d11dfe32bfb846c3d5b199fc99 --- /dev/null +++ b/scripts/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/scripts/.idea/lower body b/scripts/.idea/lower body new file mode 100644 index 0000000000000000000000000000000000000000..1efda13cfb1455b382ced16ed1ddb16d1716ae7f --- /dev/null +++ b/scripts/.idea/lower body @@ -0,0 +1 @@ +0, 1, 3, 4, 6, 7, 9, 10, \ No newline at end of file diff --git a/scripts/.idea/modules.xml b/scripts/.idea/modules.xml new file mode 100644 index 0000000000000000000000000000000000000000..bb83e262159915cb1ea30b748c3123878bf4c341 --- /dev/null +++ b/scripts/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/scripts/.idea/scripts.iml b/scripts/.idea/scripts.iml new file mode 100644 index 0000000000000000000000000000000000000000..d0876a78d06ac03b5d78c8dcdb95570281c6f1d6 --- /dev/null +++ b/scripts/.idea/scripts.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/scripts/.idea/test.png b/scripts/.idea/test.png new file mode 100644 index 0000000000000000000000000000000000000000..b028ad2b8c4fb0b6a89571ae62d2654908c93d9c Binary files /dev/null and b/scripts/.idea/test.png differ diff --git a/scripts/.idea/testtext.py b/scripts/.idea/testtext.py new file mode 100644 index 0000000000000000000000000000000000000000..af0185442ae8b1e84c9ea64c671dde6da394046c --- /dev/null +++ b/scripts/.idea/testtext.py @@ -0,0 +1,24 @@ +import cv2 + +# path being defined from where the system will read the image +path = r'test.png' +# command used for reading an image from the disk disk, cv2.imread function is used +image1 = cv2.imread(path) +# Window name being specified where the image will be displayed +window_name1 = 'image' +# font for the text being specified +font1 = cv2.FONT_HERSHEY_SIMPLEX +# org for the text being specified +org1 = (50, 50) +# font scale for the text being specified +fontScale1 = 1 +# Blue color for the text being specified from BGR +color1 = (255, 255, 255) +# Line thickness for the text being specified at 2 px +thickness1 = 2 +# Using the cv2.putText() method for inserting text in the image of the 
specified path +image_1 = cv2.putText(image1, 'CAT IN BOX', org1, font1, fontScale1, color1, thickness1, cv2.LINE_AA) +# Displaying the output image +cv2.imshow(window_name1, image_1) +cv2.waitKey(0) +cv2.destroyAllWindows() diff --git a/scripts/.idea/vcs.xml b/scripts/.idea/vcs.xml new file mode 100644 index 0000000000000000000000000000000000000000..6c0b8635858dc7ad44b93df54b762707ce49eefc --- /dev/null +++ b/scripts/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/scripts/.idea/workspace.xml b/scripts/.idea/workspace.xml new file mode 100644 index 0000000000000000000000000000000000000000..e45519a6841bc50a93d2d3bdb05aaa935ff861a0 --- /dev/null +++ b/scripts/.idea/workspace.xml @@ -0,0 +1,75 @@ + + + + + + + + + + + + + + + + + + +
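A minimal usage sketch for the VQ-VAE defined in nets/spg/vqvae_1d.py, assuming the repo's dependencies (torch, transformers, torchvision) are installed. The hyper-parameters, batch size, and pose dimension below are illustrative assumptions rather than values taken from the training configs; the sequence length is chosen to be divisible by 4 because the encoder downsamples twice.

import torch
from nets.spg.vqvae_1d import VQVAE

# hypothetical hyper-parameters, for illustration only
model = VQVAE(in_dim=108, embedding_dim=64, num_embeddings=1024,
              num_hiddens=256, num_residual_layers=2, num_residual_hiddens=256)

poses = torch.randn(4, 88, 108)       # (B, T, C) pose sequences; T divisible by 4

model.train()
e_q_loss, recon = model(poses)        # training: commitment loss and a (B, T, C) reconstruction

model.eval()
with torch.no_grad():
    codes, recon_eval = model(poses)  # eval: quantized latents (B, 64, T//4) and a (B, C, T) reconstruction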