sshravani committed
Commit fb4330d · 1 Parent(s): d414280

essential files only

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. evaluation/FGD.py +199 -0
  2. evaluation/__init__.py +0 -0
  3. evaluation/__pycache__/__init__.cpython-37.pyc +0 -0
  4. evaluation/__pycache__/metrics.cpython-37.pyc +0 -0
  5. evaluation/diversity_LVD.py +64 -0
  6. evaluation/get_quality_samples.py +62 -0
  7. evaluation/metrics.py +109 -0
  8. evaluation/mode_transition.py +60 -0
  9. evaluation/peak_velocity.py +65 -0
  10. evaluation/util.py +148 -0
  11. losses/__init__.py +1 -0
  12. losses/__pycache__/__init__.cpython-37.pyc +0 -0
  13. losses/__pycache__/losses.cpython-37.pyc +0 -0
  14. losses/losses.py +91 -0
  15. nets/LS3DCG.py +414 -0
  16. nets/__init__.py +8 -0
  17. nets/__pycache__/LS3DCG.cpython-37.pyc +0 -0
  18. nets/__pycache__/__init__.cpython-37.pyc +0 -0
  19. nets/__pycache__/base.cpython-37.pyc +0 -0
  20. nets/__pycache__/body_ae.cpython-37.pyc +0 -0
  21. nets/__pycache__/init_model.cpython-37.pyc +0 -0
  22. nets/__pycache__/layers.cpython-37.pyc +0 -0
  23. nets/__pycache__/smplx_body_pixel.cpython-37.pyc +0 -0
  24. nets/__pycache__/smplx_body_vq.cpython-37.pyc +0 -0
  25. nets/__pycache__/smplx_face.cpython-37.pyc +0 -0
  26. nets/__pycache__/utils.cpython-37.pyc +0 -0
  27. nets/base.py +89 -0
  28. nets/body_ae.py +152 -0
  29. nets/init_model.py +35 -0
  30. nets/layers.py +1052 -0
  31. nets/smplx_body_pixel.py +326 -0
  32. nets/smplx_body_vq.py +302 -0
  33. nets/smplx_face.py +238 -0
  34. nets/spg/__pycache__/gated_pixelcnn_v2.cpython-37.pyc +0 -0
  35. nets/spg/__pycache__/s2g_face.cpython-37.pyc +0 -0
  36. nets/spg/__pycache__/s2glayers.cpython-37.pyc +0 -0
  37. nets/spg/__pycache__/vqvae_1d.cpython-37.pyc +0 -0
  38. nets/spg/__pycache__/vqvae_modules.cpython-37.pyc +0 -0
  39. nets/spg/__pycache__/wav2vec.cpython-37.pyc +0 -0
  40. nets/spg/gated_pixelcnn_v2.py +179 -0
  41. nets/spg/s2g_face.py +226 -0
  42. nets/spg/s2glayers.py +522 -0
  43. nets/spg/vqvae_1d.py +235 -0
  44. nets/spg/vqvae_modules.py +380 -0
  45. nets/spg/wav2vec.py +143 -0
  46. nets/utils.py +122 -0
  47. scripts/.idea/__init__.py +0 -0
  48. scripts/.idea/aws.xml +11 -0
  49. scripts/.idea/deployment.xml +70 -0
  50. scripts/.idea/get_prevar.py +132 -0
evaluation/FGD.py ADDED
@@ -0,0 +1,199 @@
+ import time
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from scipy import linalg
+ import math
+ from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d
+
+ import warnings
+ warnings.filterwarnings("ignore", category=RuntimeWarning)  # ignore warnings
+
+
+ change_angle = torch.tensor([6.0181e-05, 5.1597e-05, 2.1344e-04, 2.1899e-04])
+
+
+ class EmbeddingSpaceEvaluator:
+     def __init__(self, ae, vae, device):
+
+         # init embed net
+         self.ae = ae
+         # self.vae = vae
+
+         # storage
+         self.real_feat_list = []
+         self.generated_feat_list = []
+         self.real_joints_list = []
+         self.generated_joints_list = []
+         self.real_6d_list = []
+         self.generated_6d_list = []
+         self.audio_beat_list = []
+
+     def reset(self):
+         self.real_feat_list = []
+         self.generated_feat_list = []
+
+     def get_no_of_samples(self):
+         return len(self.real_feat_list)
+
+     def push_samples(self, generated_poses, real_poses):
+         # self.net.eval()
+         # convert poses to latent features
+         real_feat, real_poses = self.ae.extract(real_poses)
+         generated_feat, generated_poses = self.ae.extract(generated_poses)
+
+         num_joints = real_poses.shape[2] // 3
+
+         real_feat = real_feat.squeeze()
+         generated_feat = generated_feat.reshape(generated_feat.shape[0] * generated_feat.shape[1], -1)
+
+         self.real_feat_list.append(real_feat.data.cpu().numpy())
+         self.generated_feat_list.append(generated_feat.data.cpu().numpy())
+
+         # real_poses = matrix_to_rotation_6d(axis_angle_to_matrix(real_poses.reshape(-1, 3))).reshape(-1, num_joints, 6)
+         # generated_poses = matrix_to_rotation_6d(axis_angle_to_matrix(generated_poses.reshape(-1, 3))).reshape(-1, num_joints, 6)
+         #
+         # self.real_feat_list.append(real_poses.data.cpu().numpy())
+         # self.generated_feat_list.append(generated_poses.data.cpu().numpy())
+
+     def push_joints(self, generated_poses, real_poses):
+         self.real_joints_list.append(real_poses.data.cpu())
+         self.generated_joints_list.append(generated_poses.squeeze().data.cpu())
+
+     def push_aud(self, aud):
+         self.audio_beat_list.append(aud.squeeze().data.cpu())
+
+     def get_MAAC(self):
+         ang_vel_list = []
+         for real_joints in self.real_joints_list:
+             real_joints[:, 15:21] = real_joints[:, 16:22]
+             vec = real_joints[:, 15:21] - real_joints[:, 13:19]
+             inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]])
+             inner_product = torch.clamp(inner_product, -1, 1, out=None)
+             angle = torch.acos(inner_product) / math.pi
+             ang_vel = (angle[1:] - angle[:-1]).abs().mean(dim=0)
+             ang_vel_list.append(ang_vel.unsqueeze(dim=0))
+         all_vel = torch.cat(ang_vel_list, dim=0)
+         MAAC = all_vel.mean(dim=0)
+         return MAAC
+
+     def get_BCscore(self):
+         thres = 0.01
+         sigma = 0.1
+         sum_1 = 0
+         total_beat = 0
+         for joints, audio_beat_time in zip(self.generated_joints_list, self.audio_beat_list):
+             motion_beat_time = []
+             if joints.dim() == 4:
+                 joints = joints[0]
+             joints[:, 15:21] = joints[:, 16:22]
+             vec = joints[:, 15:21] - joints[:, 13:19]
+             inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]])
+             inner_product = torch.clamp(inner_product, -1, 1, out=None)
+             angle = torch.acos(inner_product) / math.pi
+             ang_vel = (angle[1:] - angle[:-1]).abs() / change_angle / len(change_angle)
+
+             angle_diff = torch.cat((torch.zeros(1, 4), ang_vel), dim=0)
+
+             sum_2 = 0
+             for i in range(angle_diff.shape[1]):
+                 motion_beat_time = []
+                 for t in range(1, joints.shape[0] - 1):
+                     if (angle_diff[t][i] < angle_diff[t - 1][i] and angle_diff[t][i] < angle_diff[t + 1][i]):
+                         if (angle_diff[t - 1][i] - angle_diff[t][i] >= thres or
+                                 angle_diff[t + 1][i] - angle_diff[t][i] >= thres):
+                             motion_beat_time.append(float(t) / 30.0)
+                 if (len(motion_beat_time) == 0):
+                     continue
+                 motion_beat_time = torch.tensor(motion_beat_time)
+                 sum = 0
+                 for audio in audio_beat_time:
+                     sum += np.power(math.e, -(np.power((audio.item() - motion_beat_time), 2)).min() / (2 * sigma * sigma))
+                 sum_2 = sum_2 + sum
+                 total_beat = total_beat + len(audio_beat_time)
+             sum_1 = sum_1 + sum_2
+         return sum_1 / total_beat
+
+     def get_scores(self):
+         generated_feats = np.vstack(self.generated_feat_list)
+         real_feats = np.vstack(self.real_feat_list)
+
+         def frechet_distance(samples_A, samples_B):
+             A_mu = np.mean(samples_A, axis=0)
+             A_sigma = np.cov(samples_A, rowvar=False)
+             B_mu = np.mean(samples_B, axis=0)
+             B_sigma = np.cov(samples_B, rowvar=False)
+             try:
+                 frechet_dist = self.calculate_frechet_distance(A_mu, A_sigma, B_mu, B_sigma)
+             except ValueError:
+                 frechet_dist = 1e+10
+             return frechet_dist
+
+         ####################################################################
+         # frechet distance
+         frechet_dist = frechet_distance(generated_feats, real_feats)
+
+         ####################################################################
+         # distance between real and generated samples on the latent feature space
+         dists = []
+         for i in range(real_feats.shape[0]):
+             d = np.sum(np.absolute(real_feats[i] - generated_feats[i]))  # MAE
+             dists.append(d)
+         feat_dist = np.mean(dists)
+
+         return frechet_dist, feat_dist
+
+     @staticmethod
+     def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
+         """ from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py """
+         """Numpy implementation of the Frechet Distance.
+         The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+         and X_2 ~ N(mu_2, C_2) is
+                 d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+         Stable version by Dougal J. Sutherland.
+         Params:
+         -- mu1   : Numpy array containing the activations of a layer of the
+                    inception net (like returned by the function 'get_predictions')
+                    for generated samples.
+         -- mu2   : The sample mean over activations, precalculated on a
+                    representative data set.
+         -- sigma1: The covariance matrix over activations for generated samples.
+         -- sigma2: The covariance matrix over activations, precalculated on a
+                    representative data set.
+         Returns:
+         --   : The Frechet Distance.
+         """
+
+         mu1 = np.atleast_1d(mu1)
+         mu2 = np.atleast_1d(mu2)
+
+         sigma1 = np.atleast_2d(sigma1)
+         sigma2 = np.atleast_2d(sigma2)
+
+         assert mu1.shape == mu2.shape, \
+             'Training and test mean vectors have different lengths'
+         assert sigma1.shape == sigma2.shape, \
+             'Training and test covariances have different dimensions'
+
+         diff = mu1 - mu2
+
+         # Product might be almost singular
+         covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+         if not np.isfinite(covmean).all():
+             msg = ('fid calculation produces singular product; '
+                    'adding %s to diagonal of cov estimates') % eps
+             print(msg)
+             offset = np.eye(sigma1.shape[0]) * eps
+             covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+         # Numerical error might give slight imaginary component
+         if np.iscomplexobj(covmean):
+             if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+                 m = np.max(np.abs(covmean.imag))
+                 raise ValueError('Imaginary component {}'.format(m))
+             covmean = covmean.real
+
+         tr_covmean = np.trace(covmean)
+
+         return (diff.dot(diff) + np.trace(sigma1) +
+                 np.trace(sigma2) - 2 * tr_covmean)
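A rough usage sketch (not part of this commit) of how the evaluator above is typically driven; it assumes `ae` is a pretrained body autoencoder whose `extract()` returns a latent feature tensor plus the trimmed poses, as in nets/body_ae.py below, and that `pose_batches` is a hypothetical iterable of pose tensors:

# Illustrative sketch only; `ae` and `pose_batches` are assumptions.
evaluator = EmbeddingSpaceEvaluator(ae, vae=None, device='cuda')
for pred_batch, gt_batch in pose_batches:
    evaluator.push_samples(pred_batch, gt_batch)
fgd, feat_dist = evaluator.get_scores()   # Frechet distance + mean L1 in latent space
print('FGD:', fgd, 'feature distance:', feat_dist)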
evaluation/__init__.py ADDED
File without changes
evaluation/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (181 Bytes).
 
evaluation/__pycache__/metrics.cpython-37.pyc ADDED
Binary file (3.81 kB).
 
evaluation/diversity_LVD.py ADDED
@@ -0,0 +1,64 @@
+ '''
+ LVD: different initial pose
+ diversity: same initial pose
+ '''
+ import os
+ import sys
+ sys.path.append(os.getcwd())
+
+ from glob import glob
+
+ from argparse import ArgumentParser
+ import json
+
+ from evaluation.util import *
+ from evaluation.metrics import *
+ from tqdm import tqdm
+
+ parser = ArgumentParser()
+ parser.add_argument('--speaker', required=True, type=str)
+ parser.add_argument('--post_fix', nargs='+', default=['base'], type=str)
+ args = parser.parse_args()
+
+ speaker = args.speaker
+ test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav' % (speaker)))
+
+ LVD_list = []
+ diversity_list = []
+
+ for aud in tqdm(test_audios):
+     base_name = os.path.splitext(aud)[0]
+     gt_path = get_full_path(aud, speaker, 'val')
+     _, gt_poses, _ = get_gts(gt_path)
+     gt_poses = gt_poses[np.newaxis, ...]
+     # print(gt_poses.shape)  # (seq_len, 135*2) pose, lhand, rhand, face
+     for post_fix in args.post_fix:
+         pred_path = base_name + '_' + post_fix + '.json'
+         pred_poses = np.array(json.load(open(pred_path)))
+         # print(pred_poses.shape)  # (B, seq_len, 108)
+         pred_poses = cvt25(pred_poses, gt_poses)
+         # print(pred_poses.shape)  # (B, seq, pose_dim)
+
+         gt_valid_points = hand_points(gt_poses)
+         pred_valid_points = hand_points(pred_poses)
+
+         lvd = LVD(gt_valid_points, pred_valid_points)
+         # div = diversity(pred_valid_points)
+
+         LVD_list.append(lvd)
+         # diversity_list.append(div)
+
+         # gt_velocity = peak_velocity(gt_valid_points, order=2)
+         # pred_velocity = peak_velocity(pred_valid_points, order=2)
+
+         # gt_consistency = velocity_consistency(gt_velocity, pred_velocity)
+         # pred_consistency = velocity_consistency(pred_velocity, gt_velocity)
+
+         # gt_consistency_list.append(gt_consistency)
+         # pred_consistency_list.append(pred_consistency)
+
+ lvd = np.mean(LVD_list)
+ # diversity_list = np.mean(diversity_list)
+
+ print('LVD:', lvd)
+ # print("diversity:", diversity_list)
evaluation/get_quality_samples.py ADDED
@@ -0,0 +1,62 @@
+ '''
+ '''
+ import os
+ import sys
+ sys.path.append(os.getcwd())
+
+ from glob import glob
+
+ from argparse import ArgumentParser
+ import json
+
+ from evaluation.util import *
+ from evaluation.metrics import *
+ from tqdm import tqdm
+
+ parser = ArgumentParser()
+ parser.add_argument('--speaker', required=True, type=str)
+ parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
+ args = parser.parse_args()
+
+ speaker = args.speaker
+ test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav' % (speaker)))
+
+ quality_samples = {'gt': []}
+ for post_fix in args.post_fix:
+     quality_samples[post_fix] = []
+
+ for aud in tqdm(test_audios):
+     base_name = os.path.splitext(aud)[0]
+     gt_path = get_full_path(aud, speaker, 'val')
+     _, gt_poses, _ = get_gts(gt_path)
+     gt_poses = gt_poses[np.newaxis, ...]
+     gt_valid_points = valid_points(gt_poses)
+     # print(gt_valid_points.shape)
+     quality_samples['gt'].append(gt_valid_points)
+
+     for post_fix in args.post_fix:
+         pred_path = base_name + '_' + post_fix + '.json'
+         pred_poses = np.array(json.load(open(pred_path)))
+         # print(pred_poses.shape)  # (B, seq_len, 108)
+         pred_poses = cvt25(pred_poses, gt_poses)
+         # print(pred_poses.shape)  # (B, seq, pose_dim)
+
+         pred_valid_points = valid_points(pred_poses)[0:1]
+         quality_samples[post_fix].append(pred_valid_points)
+
+ quality_samples['gt'] = np.concatenate(quality_samples['gt'], axis=1)
+ for post_fix in args.post_fix:
+     quality_samples[post_fix] = np.concatenate(quality_samples[post_fix], axis=1)
+
+ print('gt:', quality_samples['gt'].shape)
+ quality_samples['gt'] = quality_samples['gt'].tolist()
+ for post_fix in args.post_fix:
+     print(post_fix, ':', quality_samples[post_fix].shape)
+     quality_samples[post_fix] = quality_samples[post_fix].tolist()
+
+ save_dir = '../../experiments/'
+ os.makedirs(save_dir, exist_ok=True)
+ save_name = os.path.join(save_dir, 'quality_samples_%s.json' % (speaker))
+ with open(save_name, 'w') as f:
+     json.dump(quality_samples, f)
evaluation/metrics.py ADDED
@@ -0,0 +1,109 @@
+ '''
+ Warning: metrics are for reference only and may have limited significance
+ '''
+ import os
+ import sys
+ sys.path.append(os.getcwd())
+ import numpy as np
+ import torch
+
+ from data_utils.lower_body import rearrange, symmetry
+ import torch.nn.functional as F
+
+ def data_driven_baselines(gt_kps):
+     '''
+     gt_kps: T, D
+     '''
+     gt_velocity = np.abs(gt_kps[1:] - gt_kps[:-1])
+
+     mean = np.mean(gt_velocity, axis=0)[np.newaxis]  # (1, D)
+     mean = np.mean(np.abs(gt_velocity - mean))
+     last_step = gt_kps[1] - gt_kps[0]
+     last_step = last_step[np.newaxis]  # (1, D)
+     last_step = np.mean(np.abs(gt_velocity - last_step))
+     return last_step, mean
+
+ def Batch_LVD(gt_kps, pr_kps, symmetrical, weight):
+     if gt_kps.shape[0] > pr_kps.shape[1]:
+         length = pr_kps.shape[1]
+     else:
+         length = gt_kps.shape[0]
+     gt_kps = gt_kps[:length]
+     pr_kps = pr_kps[:, :length]
+     global symmetry
+     symmetry = torch.tensor(symmetry).bool()
+
+     if symmetrical:
+         # rearrange to compute symmetry; ns = non-symmetrical joints, ys = symmetrical joints
+         gt_kps = gt_kps[:, rearrange]
+         ns_gt_kps = gt_kps[:, ~symmetry]
+         ys_gt_kps = gt_kps[:, symmetry]
+         ys_gt_kps = ys_gt_kps.reshape(ys_gt_kps.shape[0], -1, 2, 3)
+         ns_gt_velocity = (ns_gt_kps[1:] - ns_gt_kps[:-1]).norm(p=2, dim=-1)
+         ys_gt_velocity = (ys_gt_kps[1:] - ys_gt_kps[:-1]).norm(p=2, dim=-1)
+         left_gt_vel = ys_gt_velocity[:, :, 0].sum(dim=-1)
+         right_gt_vel = ys_gt_velocity[:, :, 1].sum(dim=-1)
+         move_side = torch.where(left_gt_vel > right_gt_vel, torch.ones(left_gt_vel.shape).cuda(), torch.zeros(left_gt_vel.shape).cuda())
+         ys_gt_velocity = torch.mul(ys_gt_velocity[:, :, 0].transpose(0, 1), move_side) + torch.mul(ys_gt_velocity[:, :, 1].transpose(0, 1), ~move_side.bool())
+         ys_gt_velocity = ys_gt_velocity.transpose(0, 1)
+         gt_velocity = torch.cat([ns_gt_velocity, ys_gt_velocity], dim=1)
+
+         pr_kps = pr_kps[:, :, rearrange]
+         ns_pr_kps = pr_kps[:, :, ~symmetry]
+         ys_pr_kps = pr_kps[:, :, symmetry]
+         ys_pr_kps = ys_pr_kps.reshape(ys_pr_kps.shape[0], ys_pr_kps.shape[1], -1, 2, 3)
+         ns_pr_velocity = (ns_pr_kps[:, 1:] - ns_pr_kps[:, :-1]).norm(p=2, dim=-1)
+         ys_pr_velocity = (ys_pr_kps[:, 1:] - ys_pr_kps[:, :-1]).norm(p=2, dim=-1)
+         left_pr_vel = ys_pr_velocity[:, :, :, 0].sum(dim=-1)
+         right_pr_vel = ys_pr_velocity[:, :, :, 1].sum(dim=-1)
+         move_side = torch.where(left_pr_vel > right_pr_vel, torch.ones(left_pr_vel.shape).cuda(),
+                                 torch.zeros(left_pr_vel.shape).cuda())
+         ys_pr_velocity = torch.mul(ys_pr_velocity[..., 0].permute(2, 0, 1), move_side) + torch.mul(
+             ys_pr_velocity[..., 1].permute(2, 0, 1), ~move_side.long())
+         ys_pr_velocity = ys_pr_velocity.permute(1, 2, 0)
+         pr_velocity = torch.cat([ns_pr_velocity, ys_pr_velocity], dim=2)
+     else:
+         gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1)
+         pr_velocity = (pr_kps[:, 1:] - pr_kps[:, :-1]).norm(p=2, dim=-1)
+
+     if weight:
+         w = F.softmax(gt_velocity.sum(dim=1).normal_(), dim=0)
+     else:
+         w = 1 / gt_velocity.shape[0]
+
+     v_diff = ((pr_velocity - gt_velocity).abs().sum(dim=-1) * w).sum(dim=-1).mean()
+
+     return v_diff
+
+
+ def LVD(gt_kps, pr_kps, symmetrical=False, weight=False):
+     gt_kps = gt_kps.squeeze()
+     pr_kps = pr_kps.squeeze()
+     if len(pr_kps.shape) == 4:
+         return Batch_LVD(gt_kps, pr_kps, symmetrical, weight)
+     # length = np.minimum(gt_kps.shape[0], pr_kps.shape[0])
+     length = gt_kps.shape[0] - 10
+     # gt_kps = gt_kps[25:length]
+     # pr_kps = pr_kps[25:length]  # (T, D)
+     # if pr_kps.shape[0] < gt_kps.shape[0]:
+     #     pr_kps = np.pad(pr_kps, [[0, int(gt_kps.shape[0] - pr_kps.shape[0])], [0, 0]], mode='constant')
+
+     gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1)
+     pr_velocity = (pr_kps[1:] - pr_kps[:-1]).norm(p=2, dim=-1)
+
+     return (pr_velocity - gt_velocity).abs().sum(dim=-1).mean()
+
+ def diversity(kps):
+     '''
+     kps: bs, seq, dim
+     '''
+     dis_list = []
+     # the distance between each pair
+     for i in range(kps.shape[0]):
+         for j in range(i + 1, kps.shape[0]):
+             seq_i = kps[i]
+             seq_j = kps[j]
+
+             dis = np.mean(np.abs(seq_i - seq_j))
+             dis_list.append(dis)
+     return np.mean(dis_list)
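As a minimal sketch (not part of this commit), LVD above reduces to the mean absolute difference of per-frame L2 velocity magnitudes, so it can be exercised directly with random torch tensors; all shapes below are hypothetical:

# Illustrative only; shapes are assumptions.
import torch
gt = torch.randn(1, 64, 54, 2)   # ground-truth keypoints (B, T, joints, 2)
pr = torch.randn(1, 64, 54, 2)   # predicted keypoints
print(LVD(gt, pr))               # scalar velocity difference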
evaluation/mode_transition.py ADDED
@@ -0,0 +1,60 @@
+ import os
+ import sys
+ sys.path.append(os.getcwd())
+
+ from glob import glob
+
+ from argparse import ArgumentParser
+ import json
+
+ from evaluation.util import *
+ from evaluation.metrics import *
+ from tqdm import tqdm
+
+ parser = ArgumentParser()
+ parser.add_argument('--speaker', required=True, type=str)
+ parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
+ args = parser.parse_args()
+
+ speaker = args.speaker
+ test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav' % (speaker)))
+
+ precision_list = []
+ recall_list = []
+ accuracy_list = []
+
+ for aud in tqdm(test_audios):
+     base_name = os.path.splitext(aud)[0]
+     gt_path = get_full_path(aud, speaker, 'val')
+     _, gt_poses, _ = get_gts(gt_path)
+     if gt_poses.shape[0] < 50:
+         continue
+     gt_poses = gt_poses[np.newaxis, ...]
+     # print(gt_poses.shape)  # (seq_len, 135*2) pose, lhand, rhand, face
+     for post_fix in args.post_fix:
+         pred_path = base_name + '_' + post_fix + '.json'
+         pred_poses = np.array(json.load(open(pred_path)))
+         # print(pred_poses.shape)  # (B, seq_len, 108)
+         pred_poses = cvt25(pred_poses, gt_poses)
+         # print(pred_poses.shape)  # (B, seq, pose_dim)
+
+         gt_valid_points = valid_points(gt_poses)
+         pred_valid_points = valid_points(pred_poses)
+
+         # print(gt_valid_points.shape, pred_valid_points.shape)
+
+         gt_mode_transition_seq = mode_transition_seq(gt_valid_points, speaker)  # (B, N)
+         pred_mode_transition_seq = mode_transition_seq(pred_valid_points, speaker)  # (B, N)
+
+         # baseline = np.random.randint(0, 2, size=pred_mode_transition_seq.shape)
+         # pred_mode_transition_seq = baseline
+         precision, recall, accuracy = mode_transition_consistency(pred_mode_transition_seq, gt_mode_transition_seq)
+         precision_list.append(precision)
+         recall_list.append(recall)
+         accuracy_list.append(accuracy)
+ print(len(precision_list), len(recall_list), len(accuracy_list))
+ precision_list = np.mean(precision_list)
+ recall_list = np.mean(recall_list)
+ accuracy_list = np.mean(accuracy_list)
+
+ print('precision, recall, accu:', precision_list, recall_list, accuracy_list)
evaluation/peak_velocity.py ADDED
@@ -0,0 +1,65 @@
+ import os
+ import sys
+ sys.path.append(os.getcwd())
+
+ from glob import glob
+
+ from argparse import ArgumentParser
+ import json
+
+ from evaluation.util import *
+ from evaluation.metrics import *
+ from tqdm import tqdm
+
+ parser = ArgumentParser()
+ parser.add_argument('--speaker', required=True, type=str)
+ parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
+ args = parser.parse_args()
+
+ speaker = args.speaker
+ test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav' % (speaker)))
+
+ gt_consistency_list = []
+ pred_consistency_list = []
+
+ for aud in tqdm(test_audios):
+     base_name = os.path.splitext(aud)[0]
+     gt_path = get_full_path(aud, speaker, 'val')
+     _, gt_poses, _ = get_gts(gt_path)
+     gt_poses = gt_poses[np.newaxis, ...]
+     # print(gt_poses.shape)  # (seq_len, 135*2) pose, lhand, rhand, face
+     for post_fix in args.post_fix:
+         pred_path = base_name + '_' + post_fix + '.json'
+         pred_poses = np.array(json.load(open(pred_path)))
+         # print(pred_poses.shape)  # (B, seq_len, 108)
+         pred_poses = cvt25(pred_poses, gt_poses)
+         # print(pred_poses.shape)  # (B, seq, pose_dim)
+
+         gt_valid_points = hand_points(gt_poses)
+         pred_valid_points = hand_points(pred_poses)
+
+         gt_velocity = peak_velocity(gt_valid_points, order=2)
+         pred_velocity = peak_velocity(pred_valid_points, order=2)
+
+         gt_consistency = velocity_consistency(gt_velocity, pred_velocity)
+         pred_consistency = velocity_consistency(pred_velocity, gt_velocity)
+
+         gt_consistency_list.append(gt_consistency)
+         pred_consistency_list.append(pred_consistency)
+
+ gt_consistency_list = np.concatenate(gt_consistency_list)
+ pred_consistency_list = np.concatenate(pred_consistency_list)
+
+ print(gt_consistency_list.max(), gt_consistency_list.min())
+ print(pred_consistency_list.max(), pred_consistency_list.min())
+ print(np.mean(gt_consistency_list), np.mean(pred_consistency_list))
+ print(np.std(gt_consistency_list), np.std(pred_consistency_list))
+
+ draw_cdf(gt_consistency_list, save_name='%s_gt.jpg' % (speaker), color='slateblue')
+ draw_cdf(pred_consistency_list, save_name='%s_pred.jpg' % (speaker), color='lightskyblue')
+
+ to_excel(gt_consistency_list, '%s_gt.xlsx' % (speaker))
+ to_excel(pred_consistency_list, '%s_pred.xlsx' % (speaker))
+
+ np.save('%s_gt.npy' % (speaker), gt_consistency_list)
+ np.save('%s_pred.npy' % (speaker), pred_consistency_list)
evaluation/util.py ADDED
@@ -0,0 +1,148 @@
+ import os
+ from glob import glob
+ import numpy as np
+ import json
+ from matplotlib import pyplot as plt
+ import pandas as pd
+
+ def get_gts(clip):
+     '''
+     clip: abs path to the clip dir
+     '''
+     keypoints_files = sorted(glob(os.path.join(clip, 'keypoints_new/person_1') + '/*.json'))
+
+     upper_body_points = list(np.arange(0, 25))
+     poses = []
+     confs = []
+     neck_to_nose_len = []
+     mean_position = []
+     for kp_file in keypoints_files:
+         kp_load = json.load(open(kp_file, 'r'))['people'][0]
+         posepts = kp_load['pose_keypoints_2d']
+         lhandpts = kp_load['hand_left_keypoints_2d']
+         rhandpts = kp_load['hand_right_keypoints_2d']
+         facepts = kp_load['face_keypoints_2d']
+
+         neck = np.array(posepts).reshape(-1, 3)[1]
+         nose = np.array(posepts).reshape(-1, 3)[0]
+         x_offset = abs(neck[0] - nose[0])
+         y_offset = abs(neck[1] - nose[1])
+         neck_to_nose_len.append(y_offset)
+         mean_position.append([neck[0], neck[1]])
+
+         keypoints = np.array(posepts + lhandpts + rhandpts + facepts).reshape(-1, 3)[:, :2]
+
+         upper_body = keypoints[upper_body_points, :]
+         hand_points = keypoints[25:, :]
+         keypoints = np.vstack([upper_body, hand_points])
+
+         poses.append(keypoints)
+
+     if len(neck_to_nose_len) > 0:
+         scale_factor = np.mean(neck_to_nose_len)
+     else:
+         raise ValueError(clip)
+     mean_position = np.mean(np.array(mean_position), axis=0)
+
+     unlocalized_poses = np.array(poses).copy()
+     localized_poses = []
+     for i in range(len(poses)):
+         keypoints = poses[i]
+         neck = keypoints[1].copy()
+
+         keypoints[:, 0] = (keypoints[:, 0] - neck[0]) / scale_factor
+         keypoints[:, 1] = (keypoints[:, 1] - neck[1]) / scale_factor
+         localized_poses.append(keypoints.reshape(-1))
+
+     localized_poses = np.array(localized_poses)
+     return unlocalized_poses, localized_poses, (scale_factor, mean_position)
+
+ def get_full_path(wav_name, speaker, split):
+     '''
+     get clip path from aud file
+     '''
+     wav_name = os.path.basename(wav_name)
+     wav_name = os.path.splitext(wav_name)[0]
+     clip_name, vid_name = wav_name[:10], wav_name[11:]
+
+     full_path = os.path.join('pose_dataset/videos/', speaker, 'clips', vid_name, 'images/half', split, clip_name)
+
+     assert os.path.isdir(full_path), full_path
+
+     return full_path
+
+ def smooth(res):
+     '''
+     res: (B, seq_len, pose_dim)
+     '''
+     window = [res[:, 7, :], res[:, 8, :], res[:, 9, :], res[:, 10, :], res[:, 11, :], res[:, 12, :]]
+     w_size = 7
+     for i in range(10, res.shape[1] - 3):
+         window.append(res[:, i + 3, :])
+         if len(window) > w_size:
+             window = window[1:]
+
+         if (i % 25) in [22, 23, 24, 0, 1, 2, 3]:
+             res[:, i, :] = np.mean(window, axis=1)
+
+     return res
+
+ def cvt25(pred_poses, gt_poses=None):
+     '''
+     gt_poses: (1, seq_len, 270), 135 * 2
+     pred_poses: (B, seq_len, 108), 54 * 2
+     '''
+     if gt_poses is None:
+         gt_poses = np.zeros_like(pred_poses)
+     else:
+         gt_poses = gt_poses.repeat(pred_poses.shape[0], axis=0)
+
+     length = min(pred_poses.shape[1], gt_poses.shape[1])
+     pred_poses = pred_poses[:, :length, :]
+     gt_poses = gt_poses[:, :length, :]
+     gt_poses = gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1, 2)
+     pred_poses = pred_poses.reshape(pred_poses.shape[0], pred_poses.shape[1], -1, 2)
+
+     gt_poses[:, :, [1, 2, 3, 4, 5, 6, 7], :] = pred_poses[:, :, 1:8, :]
+     gt_poses[:, :, 25:25 + 21 + 21, :] = pred_poses[:, :, 12:, :]
+
+     return gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1)
+
+ def hand_points(seq):
+     '''
+     seq: (B, seq_len, 135*2)
+     hands only
+     '''
+     hand_idx = [1, 2, 3, 4, 5, 6, 7] + list(range(25, 25 + 21 + 21))
+     seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2)
+     return seq[:, :, hand_idx, :].reshape(seq.shape[0], seq.shape[1], -1)
+
+ def valid_points(seq):
+     '''
+     hands with some head points
+     '''
+     valid_idx = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + list(range(25, 25 + 21 + 21))
+     seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2)
+
+     seq = seq[:, :, valid_idx, :].reshape(seq.shape[0], seq.shape[1], -1)
+     assert seq.shape[-1] == 108, seq.shape
+     return seq
+
+ def draw_cdf(seq, save_name='cdf.jpg', color='slateblue'):
+     plt.figure()
+     plt.hist(seq, bins=100, range=(0, 100), color=color)
+     plt.savefig(save_name)
+
+ def to_excel(seq, save_name='res.xlsx'):
+     '''
+     seq: (T)
+     '''
+     df = pd.DataFrame(seq)
+     writer = pd.ExcelWriter(save_name)
+     df.to_excel(writer, 'sheet1')
+     writer.save()
+     writer.close()
+
+
+ if __name__ == '__main__':
+     random_data = np.random.randint(0, 10, 100)
+     draw_cdf(random_data)
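A minimal sketch (not part of this commit) of the conversion helpers above; the shapes follow the docstrings and the arrays are hypothetical:

# Illustrative only: cvt25 writes 54-joint predictions back into the
# 135-joint ground-truth layout before the point selectors are applied.
import numpy as np
gt = np.random.randn(1, 88, 270)     # (1, seq_len, 135*2)
pred = np.random.randn(4, 88, 108)   # (B, seq_len, 54*2)
full = cvt25(pred, gt)               # -> (4, 88, 270)
print(hand_points(full).shape, valid_points(full).shape)   # (4, 88, 98), (4, 88, 108)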
losses/__init__.py ADDED
@@ -0,0 +1 @@
+ from .losses import *
losses/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (140 Bytes).
 
losses/__pycache__/losses.cpython-37.pyc ADDED
Binary file (3.5 kB).
 
losses/losses.py ADDED
@@ -0,0 +1,91 @@
+ import os
+ import sys
+
+ sys.path.append(os.getcwd())
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+
+ class KeypointLoss(nn.Module):
+     def __init__(self):
+         super(KeypointLoss, self).__init__()
+
+     def forward(self, pred_seq, gt_seq, gt_conf=None):
+         # pred_seq: (B, C, T)
+         if gt_conf is not None:
+             gt_conf = gt_conf >= 0.01
+             return F.mse_loss(pred_seq[gt_conf], gt_seq[gt_conf], reduction='mean')
+         else:
+             return F.mse_loss(pred_seq, gt_seq)
+
+
+ class KLLoss(nn.Module):
+     def __init__(self, kl_tolerance):
+         super(KLLoss, self).__init__()
+         self.kl_tolerance = kl_tolerance
+
+     def forward(self, mu, var, mul=1):
+         kl_tolerance = self.kl_tolerance * mul * var.shape[1] / 64
+         kld_loss = -0.5 * torch.sum(1 + var - mu**2 - var.exp(), dim=1)
+         # kld_loss = -0.5 * torch.sum(1 + (var-1) - (mu) ** 2 - (var-1).exp(), dim=1)
+         if self.kl_tolerance is not None:
+             # above_line = kld_loss[kld_loss > self.kl_tolerance]
+             # if len(above_line) > 0:
+             #     kld_loss = torch.mean(kld_loss)
+             # else:
+             #     kld_loss = 0
+             kld_loss = torch.where(kld_loss > kl_tolerance, kld_loss, torch.tensor(kl_tolerance, device='cuda'))
+         # else:
+         kld_loss = torch.mean(kld_loss)
+         return kld_loss
+
+
+ class L2KLLoss(nn.Module):
+     def __init__(self, kl_tolerance):
+         super(L2KLLoss, self).__init__()
+         self.kl_tolerance = kl_tolerance
+
+     def forward(self, x):
+         # TODO: check
+         kld_loss = torch.sum(x ** 2, dim=1)
+         if self.kl_tolerance is not None:
+             above_line = kld_loss[kld_loss > self.kl_tolerance]
+             if len(above_line) > 0:
+                 kld_loss = torch.mean(kld_loss)
+             else:
+                 kld_loss = 0
+         else:
+             kld_loss = torch.mean(kld_loss)
+         return kld_loss
+
+ class L2RegLoss(nn.Module):
+     def __init__(self):
+         super(L2RegLoss, self).__init__()
+
+     def forward(self, x):
+         # TODO: check
+         return torch.sum(x**2)
+
+
+ class L2Loss(nn.Module):
+     def __init__(self):
+         super(L2Loss, self).__init__()
+
+     def forward(self, x):
+         # TODO: check
+         return torch.sum(x ** 2)
+
+
+ class AudioLoss(nn.Module):
+     def __init__(self):
+         super(AudioLoss, self).__init__()
+
+     def forward(self, dynamics, gt_poses):
+         # pay attention, normalized
+         mean = torch.mean(gt_poses, dim=-1).unsqueeze(-1)
+         gt = gt_poses - mean
+         return F.mse_loss(dynamics, gt)
+
+ L1Loss = nn.L1Loss
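A minimal sketch (not part of this commit) of KeypointLoss with a confidence mask; all tensors and shapes below are hypothetical:

# Illustrative only: keypoints with confidence below 0.01 are masked out of the MSE.
import torch
pred = torch.randn(8, 108, 88)   # (B, C, T)
gt = torch.randn(8, 108, 88)
conf = torch.rand(8, 108, 88)    # per-keypoint confidence
loss = KeypointLoss()(pred, gt, gt_conf=conf)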
nets/LS3DCG.py ADDED
@@ -0,0 +1,414 @@
+ '''
+ not exactly the same as the official repo but the results are good
+ '''
+ import sys
+ import os
+
+ from data_utils.lower_body import c_index_3d, c_index_6d
+
+ sys.path.append(os.getcwd())
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import torch.nn.functional as F
+ import math
+
+ from nets.base import TrainWrapperBaseClass
+ from nets.layers import SeqEncoder1D
+ from losses import KeypointLoss, L1Loss, KLLoss
+ from data_utils.utils import get_melspec, get_mfcc_psf, get_mfcc_ta
+ from nets.utils import denormalize
+
+ class Conv1d_tf(nn.Conv1d):
+     """
+     Conv1d with the padding behavior from TF
+     modified from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
+     """
+
+     def __init__(self, *args, **kwargs):
+         super(Conv1d_tf, self).__init__(*args, **kwargs)
+         self.padding = kwargs.get("padding", "same")
+
+     def _compute_padding(self, input, dim):
+         input_size = input.size(dim + 2)
+         filter_size = self.weight.size(dim + 2)
+         effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
+         out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
+         total_padding = max(
+             0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
+         )
+         additional_padding = int(total_padding % 2 != 0)
+
+         return additional_padding, total_padding
+
+     def forward(self, input):
+         if self.padding == "VALID":
+             return F.conv1d(
+                 input,
+                 self.weight,
+                 self.bias,
+                 self.stride,
+                 padding=0,
+                 dilation=self.dilation,
+                 groups=self.groups,
+             )
+         rows_odd, padding_rows = self._compute_padding(input, dim=0)
+         if rows_odd:
+             input = F.pad(input, [0, rows_odd])
+
+         return F.conv1d(
+             input,
+             self.weight,
+             self.bias,
+             self.stride,
+             padding=(padding_rows // 2),
+             dilation=self.dilation,
+             groups=self.groups,
+         )
+
+
+ def ConvNormRelu(in_channels, out_channels, type='1d', downsample=False, k=None, s=None, norm='bn', padding='valid'):
+     if k is None and s is None:
+         if not downsample:
+             k = 3
+             s = 1
+         else:
+             k = 4
+             s = 2
+
+     if type == '1d':
+         conv_block = Conv1d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
+         if norm == 'bn':
+             norm_block = nn.BatchNorm1d(out_channels)
+         elif norm == 'ln':
+             norm_block = nn.LayerNorm(out_channels)
+     elif type == '2d':
+         conv_block = Conv2d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
+         norm_block = nn.BatchNorm2d(out_channels)
+     else:
+         assert False
+
+     return nn.Sequential(
+         conv_block,
+         norm_block,
+         nn.LeakyReLU(0.2, True)
+     )
+
+ class Decoder(nn.Module):
+     def __init__(self, in_ch, out_ch):
+         super(Decoder, self).__init__()
+         self.up1 = nn.Sequential(
+             ConvNormRelu(in_ch // 2 + in_ch, in_ch // 2),
+             ConvNormRelu(in_ch // 2, in_ch // 2),
+             nn.Upsample(scale_factor=2, mode='nearest')
+         )
+         self.up2 = nn.Sequential(
+             ConvNormRelu(in_ch // 4 + in_ch // 2, in_ch // 4),
+             ConvNormRelu(in_ch // 4, in_ch // 4),
+             nn.Upsample(scale_factor=2, mode='nearest')
+         )
+         self.up3 = nn.Sequential(
+             ConvNormRelu(in_ch // 8 + in_ch // 4, in_ch // 8),
+             ConvNormRelu(in_ch // 8, in_ch // 8),
+             nn.Conv1d(in_ch // 8, out_ch, 1, 1)
+         )
+
+     def forward(self, x, x1, x2, x3):
+         x = F.interpolate(x, x3.shape[2])
+         x = torch.cat([x, x3], dim=1)
+         x = self.up1(x)
+         x = F.interpolate(x, x2.shape[2])
+         x = torch.cat([x, x2], dim=1)
+         x = self.up2(x)
+         x = F.interpolate(x, x1.shape[2])
+         x = torch.cat([x, x1], dim=1)
+         x = self.up3(x)
+         return x
+
+
+ class EncoderDecoder(nn.Module):
+     def __init__(self, n_frames, each_dim):
+         super().__init__()
+         self.n_frames = n_frames
+
+         self.down1 = nn.Sequential(
+             ConvNormRelu(64, 64, '1d', False),
+             ConvNormRelu(64, 128, '1d', False),
+         )
+         self.down2 = nn.Sequential(
+             ConvNormRelu(128, 128, '1d', False),
+             ConvNormRelu(128, 256, '1d', False),
+         )
+         self.down3 = nn.Sequential(
+             ConvNormRelu(256, 256, '1d', False),
+             ConvNormRelu(256, 512, '1d', False),
+         )
+         self.down4 = nn.Sequential(
+             ConvNormRelu(512, 512, '1d', False),
+             ConvNormRelu(512, 1024, '1d', False),
+         )
+
+         self.down = nn.MaxPool1d(kernel_size=2)
+         self.up = nn.Upsample(scale_factor=2, mode='nearest')
+
+         self.face_decoder = Decoder(1024, each_dim[0] + each_dim[3])
+         self.body_decoder = Decoder(1024, each_dim[1])
+         self.hand_decoder = Decoder(1024, each_dim[2])
+
+     def forward(self, spectrogram, time_steps=None):
+         if time_steps is None:
+             time_steps = self.n_frames
+
+         x1 = self.down1(spectrogram)
+         x = self.down(x1)
+         x2 = self.down2(x)
+         x = self.down(x2)
+         x3 = self.down3(x)
+         x = self.down(x3)
+         x = self.down4(x)
+         x = self.up(x)
+
+         face = self.face_decoder(x, x1, x2, x3)
+         body = self.body_decoder(x, x1, x2, x3)
+         hand = self.hand_decoder(x, x1, x2, x3)
+
+         return face, body, hand
+
+
+ class Generator(nn.Module):
+     def __init__(self,
+                  each_dim,
+                  training=False,
+                  device=None
+                  ):
+         super().__init__()
+
+         self.training = training
+         self.device = device
+
+         self.encoderdecoder = EncoderDecoder(15, each_dim)
+
+     def forward(self, in_spec, time_steps=None):
+         if time_steps is not None:
+             self.gen_length = time_steps
+
+         face, body, hand = self.encoderdecoder(in_spec)
+         out = torch.cat([face, body, hand], dim=1)
+         out = out.transpose(1, 2)
+
+         return out
+
+
+ class Discriminator(nn.Module):
+     def __init__(self, input_dim):
+         super().__init__()
+         self.net = nn.Sequential(
+             ConvNormRelu(input_dim, 128, '1d'),
+             ConvNormRelu(128, 256, '1d'),
+             nn.MaxPool1d(kernel_size=2),
+             ConvNormRelu(256, 256, '1d'),
+             ConvNormRelu(256, 512, '1d'),
+             nn.MaxPool1d(kernel_size=2),
+             ConvNormRelu(512, 512, '1d'),
+             ConvNormRelu(512, 1024, '1d'),
+             nn.MaxPool1d(kernel_size=2),
+             nn.Conv1d(1024, 1, 1, 1),
+             nn.Sigmoid()
+         )
+
+     def forward(self, x):
+         x = x.transpose(1, 2)
+
+         out = self.net(x)
+         return out
+
+
+ class TrainWrapper(TrainWrapperBaseClass):
+     def __init__(self, args, config) -> None:
+         self.args = args
+         self.config = config
+         self.device = torch.device(self.args.gpu)
+         self.global_step = 0
+         self.convert_to_6d = self.config.Data.pose.convert_to_6d
+         self.init_params()
+
+         self.generator = Generator(
+             each_dim=self.each_dim,
+             training=not self.args.infer,
+             device=self.device,
+         ).to(self.device)
+         self.discriminator = Discriminator(
+             input_dim=self.each_dim[1] + self.each_dim[2] + 64
+         ).to(self.device)
+         if self.convert_to_6d:
+             self.c_index = c_index_6d
+         else:
+             self.c_index = c_index_3d
+         self.MSELoss = KeypointLoss().to(self.device)
+         self.L1Loss = L1Loss().to(self.device)
+         super().__init__(args, config)
+
+     def init_params(self):
+         scale = 1
+
+         global_orient = round(0 * scale)
+         leye_pose = reye_pose = round(0 * scale)
+         jaw_pose = round(3 * scale)
+         body_pose = round((63 - 24) * scale)
+         left_hand_pose = right_hand_pose = round(45 * scale)
+
+         expression = 100
+
+         b_j = 0
+         jaw_dim = jaw_pose
+         b_e = b_j + jaw_dim
+         eye_dim = leye_pose + reye_pose
+         b_b = b_e + eye_dim
+         body_dim = global_orient + body_pose
+         b_h = b_b + body_dim
+         hand_dim = left_hand_pose + right_hand_pose
+         b_f = b_h + hand_dim
+         face_dim = expression
+
+         self.dim_list = [b_j, b_e, b_b, b_h, b_f]
+         self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
+         self.pose = int(self.full_dim / round(3 * scale))
+         self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
+
+     def __call__(self, bat):
+         assert (not self.args.infer), "infer mode"
+         self.global_step += 1
+
+         loss_dict = {}
+
+         aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
+         expression = bat['expression'].to(self.device).to(torch.float32)
+         jaw = poses[:, :3, :]
+         poses = poses[:, self.c_index, :]
+
+         pred = self.generator(in_spec=aud)
+
+         D_loss, D_loss_dict = self.get_loss(
+             pred_poses=pred.detach(),
+             gt_poses=poses,
+             aud=aud,
+             mode='training_D',
+         )
+
+         self.discriminator_optimizer.zero_grad()
+         D_loss.backward()
+         self.discriminator_optimizer.step()
+
+         G_loss, G_loss_dict = self.get_loss(
+             pred_poses=pred,
+             gt_poses=poses,
+             aud=aud,
+             expression=expression,
+             jaw=jaw,
+             mode='training_G',
+         )
+         self.generator_optimizer.zero_grad()
+         G_loss.backward()
+         self.generator_optimizer.step()
+
+         total_loss = None
+         loss_dict = {}
+         for key in list(D_loss_dict.keys()) + list(G_loss_dict.keys()):
+             loss_dict[key] = G_loss_dict.get(key, 0) + D_loss_dict.get(key, 0)
+
+         return total_loss, loss_dict
+
+     def get_loss(self,
+                  pred_poses,
+                  gt_poses,
+                  aud=None,
+                  jaw=None,
+                  expression=None,
+                  mode='training_G',
+                  ):
+         loss_dict = {}
+         aud = aud.transpose(1, 2)
+         gt_poses = gt_poses.transpose(1, 2)
+         gt_aud = torch.cat([gt_poses, aud], dim=2)
+         pred_aud = torch.cat([pred_poses[:, :, 103:], aud], dim=2)
+
+         if mode == 'training_D':
+             dis_real = self.discriminator(gt_aud)
+             dis_fake = self.discriminator(pred_aud)
+             dis_error = self.MSELoss(torch.ones_like(dis_real).to(self.device), dis_real) + self.MSELoss(
+                 torch.zeros_like(dis_fake).to(self.device), dis_fake)
+             loss_dict['dis'] = dis_error
+
+             return dis_error, loss_dict
+         elif mode == 'training_G':
+             jaw_loss = self.L1Loss(pred_poses[:, :, :3], jaw.transpose(1, 2))
+             face_loss = self.MSELoss(pred_poses[:, :, 3:103], expression.transpose(1, 2))
+             body_loss = self.L1Loss(pred_poses[:, :, 103:142], gt_poses[:, :, :39])
+             hand_loss = self.L1Loss(pred_poses[:, :, 142:], gt_poses[:, :, 39:])
+             l1_loss = jaw_loss + face_loss + body_loss + hand_loss
+
+             dis_output = self.discriminator(pred_aud)
+             gen_error = self.MSELoss(torch.ones_like(dis_output).to(self.device), dis_output)
+             gen_loss = self.config.Train.weights.keypoint_loss_weight * l1_loss + self.config.Train.weights.gan_loss_weight * gen_error
+
+             loss_dict['gen'] = gen_error
+             loss_dict['jaw_loss'] = jaw_loss
+             loss_dict['face_loss'] = face_loss
+             loss_dict['body_loss'] = body_loss
+             loss_dict['hand_loss'] = hand_loss
+             return gen_loss, loss_dict
+         else:
+             raise ValueError(mode)
+
+     def infer_on_audio(self, aud_fn, fps=30, initial_pose=None, norm_stats=None, id=None, B=1, **kwargs):
+         output = []
+         assert self.args.infer, "train mode"
+         self.generator.eval()
+
+         if self.config.Data.pose.normalization:
+             assert norm_stats is not None
+             data_mean = norm_stats[0]
+             data_std = norm_stats[1]
+
+         pre_length = self.config.Data.pose.pre_pose_length
+         generate_length = self.config.Data.pose.generate_length
+         # assert pre_length == initial_pose.shape[-1]
+         # pre_poses = initial_pose.permute(0, 2, 1).to(self.device).to(torch.float32)
+         # B = pre_poses.shape[0]
+
+         aud_feat = get_mfcc_ta(aud_fn, sr=22000, fps=fps, smlpx=True, type='mfcc').transpose(1, 0)
+         num_poses_to_generate = aud_feat.shape[-1]
+         aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
+         aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)
+
+         with torch.no_grad():
+             pred_poses = self.generator(aud_feat)
+             pred_poses = pred_poses.cpu().numpy()
+         output = pred_poses.squeeze()
+
+         return output
+
+     def generate(self, aud, id):
+         self.generator.eval()
+         pred_poses = self.generator(aud)
+         return pred_poses
+
+
+ if __name__ == '__main__':
+     from trainer.options import parse_args
+
+     parser = parse_args()
+     args = parser.parse_args(
+         ['--exp_name', '0', '--data_root', '0', '--speakers', '0', '--pre_pose_length', '4', '--generate_length', '64',
+          '--infer'])
+
+     generator = TrainWrapper(args)
+
+     aud_fn = '../sample_audio/jon.wav'
+     initial_pose = torch.randn(64, 108, 4)
+     norm_stats = (np.random.randn(108), np.random.randn(108))
+     output = generator.infer_on_audio(aud_fn, initial_pose, norm_stats)
+
+     print(output.shape)
nets/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .smplx_face import TrainWrapper as s2g_face
+ from .smplx_body_vq import TrainWrapper as s2g_body_vq
+ from .smplx_body_pixel import TrainWrapper as s2g_body_pixel
+ from .body_ae import TrainWrapper as s2g_body_ae
+ from .LS3DCG import TrainWrapper as LS3DCG
+ from .base import TrainWrapperBaseClass
+
+ from .utils import normalize, denormalize
nets/__pycache__/LS3DCG.cpython-37.pyc ADDED
Binary file (11.2 kB).
 
nets/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (431 Bytes).
 
nets/__pycache__/base.cpython-37.pyc ADDED
Binary file (2.98 kB).
 
nets/__pycache__/body_ae.cpython-37.pyc ADDED
Binary file (4.48 kB).
 
nets/__pycache__/init_model.cpython-37.pyc ADDED
Binary file (520 Bytes).
 
nets/__pycache__/layers.cpython-37.pyc ADDED
Binary file (22.7 kB).
 
nets/__pycache__/smplx_body_pixel.cpython-37.pyc ADDED
Binary file (9.51 kB).
 
nets/__pycache__/smplx_body_vq.cpython-37.pyc ADDED
Binary file (7.86 kB).
 
nets/__pycache__/smplx_face.cpython-37.pyc ADDED
Binary file (5.96 kB).
 
nets/__pycache__/utils.cpython-37.pyc ADDED
Binary file (4.81 kB).
 
nets/base.py ADDED
@@ -0,0 +1,89 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+
+ class TrainWrapperBaseClass():
+     def __init__(self, args, config) -> None:
+         self.init_optimizer()
+
+     def init_optimizer(self) -> None:
+         print('using Adam')
+         self.generator_optimizer = optim.Adam(
+             self.generator.parameters(),
+             lr=self.config.Train.learning_rate.generator_learning_rate,
+             betas=[0.9, 0.999]
+         )
+         if self.discriminator is not None:
+             self.discriminator_optimizer = optim.Adam(
+                 self.discriminator.parameters(),
+                 lr=self.config.Train.learning_rate.discriminator_learning_rate,
+                 betas=[0.9, 0.999]
+             )
+
+     def __call__(self, bat):
+         raise NotImplementedError
+
+     def get_loss(self, **kwargs):
+         raise NotImplementedError
+
+     def state_dict(self):
+         model_state = {
+             'generator': self.generator.state_dict(),
+             'generator_optim': self.generator_optimizer.state_dict(),
+             'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
+             'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
+         }
+         return model_state
+
+     def parameters(self):
+         return self.generator.parameters()
+
+     def load_state_dict(self, state_dict):
+         if 'generator' in state_dict:
+             self.generator.load_state_dict(state_dict['generator'])
+         else:
+             self.generator.load_state_dict(state_dict)
+
+         if 'generator_optim' in state_dict and self.generator_optimizer is not None:
+             self.generator_optimizer.load_state_dict(state_dict['generator_optim'])
+
+         if self.discriminator is not None:
+             self.discriminator.load_state_dict(state_dict['discriminator'])
+
+             if 'discriminator_optim' in state_dict and self.discriminator_optimizer is not None:
+                 self.discriminator_optimizer.load_state_dict(state_dict['discriminator_optim'])
+
+     def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, **kwargs):
+         raise NotImplementedError
+
+     def init_params(self):
+         if self.config.Data.pose.convert_to_6d:
+             scale = 2
+         else:
+             scale = 1
+
+         global_orient = round(0 * scale)
+         leye_pose = reye_pose = round(0 * scale)
+         jaw_pose = round(0 * scale)
+         body_pose = round((63 - 24) * scale)
+         left_hand_pose = right_hand_pose = round(45 * scale)
+         if self.expression:
+             expression = 100
+         else:
+             expression = 0
+
+         b_j = 0
+         jaw_dim = jaw_pose
+         b_e = b_j + jaw_dim
+         eye_dim = leye_pose + reye_pose
+         b_b = b_e + eye_dim
+         body_dim = global_orient + body_pose
+         b_h = b_b + body_dim
+         hand_dim = left_hand_pose + right_hand_pose
+         b_f = b_h + hand_dim
+         face_dim = expression
+
+         self.dim_list = [b_j, b_e, b_b, b_h, b_f]
+         self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
+         self.pose = int(self.full_dim / round(3 * scale))
+         self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
nets/body_ae.py ADDED
@@ -0,0 +1,152 @@
+ import os
+ import sys
+
+ sys.path.append(os.getcwd())
+
+ from nets.base import TrainWrapperBaseClass
+ from nets.spg.s2glayers import Discriminator as D_S2G
+ from nets.spg.vqvae_1d import AE as s2g_body
+ import torch
+ import torch.optim as optim
+ import torch.nn.functional as F
+
+ from data_utils.lower_body import c_index, c_index_3d, c_index_6d
+
+
+ def separate_aa(aa):
+     aa = aa[:, :, :].reshape(aa.shape[0], aa.shape[1], -1, 5)
+     axis = F.normalize(aa[:, :, :, :3], dim=-1)
+     angle = F.normalize(aa[:, :, :, 3:5], dim=-1)
+     return axis, angle
+
+
+ class TrainWrapper(TrainWrapperBaseClass):
+     '''
+     a wrapper receiving a batch from data_utils and calculating the loss
+     '''
+
+     def __init__(self, args, config):
+         self.args = args
+         self.config = config
+         self.device = torch.device(self.args.gpu)
+         self.global_step = 0
+
+         self.gan = False
+         self.convert_to_6d = self.config.Data.pose.convert_to_6d
+         self.preleng = self.config.Data.pose.pre_pose_length
+         self.expression = self.config.Data.pose.expression
+         self.epoch = 0
+         self.init_params()
+         self.num_classes = 4
+         self.g = s2g_body(self.each_dim[1] + self.each_dim[2], embedding_dim=64, num_embeddings=0,
+                           num_hiddens=1024, num_residual_layers=2, num_residual_hiddens=512).to(self.device)
+         if self.gan:
+             self.discriminator = D_S2G(
+                 pose_dim=110 + 64, pose=self.pose
+             ).to(self.device)
+         else:
+             self.discriminator = None
+
+         if self.convert_to_6d:
+             self.c_index = c_index_6d
+         else:
+             self.c_index = c_index_3d
+
+         super().__init__(args, config)
+
+     def init_optimizer(self):
+
+         self.g_optimizer = optim.Adam(
+             self.g.parameters(),
+             lr=self.config.Train.learning_rate.generator_learning_rate,
+             betas=[0.9, 0.999]
+         )
+
+     def state_dict(self):
+         model_state = {
+             'g': self.g.state_dict(),
+             'g_optim': self.g_optimizer.state_dict(),
+             'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
+             'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
+         }
+         return model_state
+
+
+     def __call__(self, bat):
+         # assert (not self.args.infer), "infer mode"
+         self.global_step += 1
+
+         total_loss = None
+         loss_dict = {}
+
+         aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
+
+         # id = bat['speaker'].to(self.device) - 20
+         # id = F.one_hot(id, self.num_classes)
+
+         poses = poses[:, self.c_index, :]
+         gt_poses = poses[:, :, self.preleng:].permute(0, 2, 1)
+
+         loss = 0
+         loss_dict, loss = self.vq_train(gt_poses[:, :], 'g', self.g, loss_dict, loss)
+
+         return total_loss, loss_dict
+
+     def vq_train(self, gt, name, model, dict, total_loss, pre=None):
+         x_recon = model(gt_poses=gt, pre_state=pre)
+         loss, loss_dict = self.get_loss(pred_poses=x_recon, gt_poses=gt, pre=pre)
+         # total_loss = total_loss + loss
+
+         if name == 'g':
+             optimizer_name = 'g_optimizer'
+
+         optimizer = getattr(self, optimizer_name)
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+
+         for key in list(loss_dict.keys()):
+             dict[name + key] = loss_dict.get(key, 0).item()
+         return dict, total_loss
+
+     def get_loss(self,
+                  pred_poses,
+                  gt_poses,
+                  pre=None
+                  ):
+         loss_dict = {}
+
+         rec_loss = torch.mean(torch.abs(pred_poses - gt_poses))
+         v_pr = pred_poses[:, 1:] - pred_poses[:, :-1]
+         v_gt = gt_poses[:, 1:] - gt_poses[:, :-1]
+         velocity_loss = torch.mean(torch.abs(v_pr - v_gt))
+
+         if pre is None:
+             f0_vel = 0
+         else:
+             v0_pr = pred_poses[:, 0] - pre[:, -1]
+             v0_gt = gt_poses[:, 0] - pre[:, -1]
+             f0_vel = torch.mean(torch.abs(v0_pr - v0_gt))
+
+         gen_loss = rec_loss + velocity_loss + f0_vel
+
+         loss_dict['rec_loss'] = rec_loss
+         loss_dict['velocity_loss'] = velocity_loss
+         # loss_dict['e_q_loss'] = e_q_loss
+         if pre is not None:
+             loss_dict['f0_vel'] = f0_vel
+
+         return gen_loss, loss_dict
+
+     def load_state_dict(self, state_dict):
+         self.g.load_state_dict(state_dict['g'])
+
+     def extract(self, x):
+         self.g.eval()
+         if x.shape[2] > self.full_dim:
+             if x.shape[2] == 239:
+                 x = x[:, :, 102:]
+             x = x[:, :, self.c_index]
+         feat = self.g.encode(x)
+         return feat.transpose(1, 2), x
nets/init_model.py ADDED
@@ -0,0 +1,35 @@
1
+ from nets import *
2
+
3
+
4
+ def init_model(model_name, args, config):
5
+
6
+ if model_name == 's2g_face':
7
+ generator = s2g_face(
8
+ args,
9
+ config,
10
+ )
11
+ elif model_name == 's2g_body_vq':
12
+ generator = s2g_body_vq(
13
+ args,
14
+ config,
15
+ )
16
+ elif model_name == 's2g_body_pixel':
17
+ generator = s2g_body_pixel(
18
+ args,
19
+ config,
20
+ )
21
+ elif model_name == 's2g_body_ae':
22
+ generator = s2g_body_ae(
23
+ args,
24
+ config,
25
+ )
26
+ elif model_name == 's2g_LS3DCG':
27
+ generator = LS3DCG(
28
+ args,
29
+ config,
30
+ )
31
+ else:
32
+ raise ValueError
33
+ return generator
34
+
35
+
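init_model is a plain name-to-constructor factory: every branch forwards (args, config) to the chosen wrapper, and unknown names raise ValueError. An equivalent dict-based dispatch is sketched below for comparison; this is an editor's sketch, not the committed implementation, and it relies on the same `from nets import *` names used by init_model.py itself.

from nets import *   # exposes s2g_face, s2g_body_vq, s2g_body_pixel, s2g_body_ae, LS3DCG

_MODEL_REGISTRY = {
    's2g_face': s2g_face,
    's2g_body_vq': s2g_body_vq,
    's2g_body_pixel': s2g_body_pixel,
    's2g_body_ae': s2g_body_ae,
    's2g_LS3DCG': LS3DCG,
}

def init_model_by_registry(model_name, args, config):
    # same behaviour as init_model: construct the wrapper or fail loudly
    if model_name not in _MODEL_REGISTRY:
        raise ValueError('unknown model name: %s' % model_name)
    return _MODEL_REGISTRY[model_name](args, config)

With a registry, adding a new model becomes a one-line entry instead of another elif branch.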
nets/layers.py ADDED
@@ -0,0 +1,1052 @@
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append(os.getcwd())
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import numpy as np
9
+
10
+
11
+ # TODO: be aware of the actual network structures
12
+
13
+ def get_log(x):
14
+ log = 0
15
+ while x > 1:
16
+ if x % 2 == 0:
17
+ x = x // 2
18
+ log += 1
19
+ else:
20
+ raise ValueError('x is not a power of 2')
21
+
22
+ return log
23
+
24
+
25
+ class ConvNormRelu(nn.Module):
26
+ '''
27
+ (B,C_in,H,W) -> (B, C_out, H, W)
28
+ note: some kernel sizes do not yield an output length of exactly H/s
29
+ #TODO: there might be some problems with the residual path
30
+ '''
31
+
32
+ def __init__(self,
33
+ in_channels,
34
+ out_channels,
35
+ type='1d',
36
+ leaky=False,
37
+ downsample=False,
38
+ kernel_size=None,
39
+ stride=None,
40
+ padding=None,
41
+ p=0,
42
+ groups=1,
43
+ residual=False,
44
+ norm='bn'):
45
+ '''
46
+ conv-bn-relu
47
+ '''
48
+ super(ConvNormRelu, self).__init__()
49
+ self.residual = residual
50
+ self.norm_type = norm
51
+ # kernel_size = k
52
+ # stride = s
53
+
54
+ if kernel_size is None and stride is None:
55
+ if not downsample:
56
+ kernel_size = 3
57
+ stride = 1
58
+ else:
59
+ kernel_size = 4
60
+ stride = 2
61
+
62
+ if padding is None:
63
+ if isinstance(kernel_size, int) and isinstance(stride, tuple):
64
+ padding = tuple(int((kernel_size - st) / 2) for st in stride)
65
+ elif isinstance(kernel_size, tuple) and isinstance(stride, int):
66
+ padding = tuple(int((ks - stride) / 2) for ks in kernel_size)
67
+ elif isinstance(kernel_size, tuple) and isinstance(stride, tuple):
68
+ padding = tuple(int((ks - st) / 2) for ks, st in zip(kernel_size, stride))
69
+ else:
70
+ padding = int((kernel_size - stride) / 2)
71
+
72
+ if self.residual:
73
+ if downsample:
74
+ if type == '1d':
75
+ self.residual_layer = nn.Sequential(
76
+ nn.Conv1d(
77
+ in_channels=in_channels,
78
+ out_channels=out_channels,
79
+ kernel_size=kernel_size,
80
+ stride=stride,
81
+ padding=padding
82
+ )
83
+ )
84
+ elif type == '2d':
85
+ self.residual_layer = nn.Sequential(
86
+ nn.Conv2d(
87
+ in_channels=in_channels,
88
+ out_channels=out_channels,
89
+ kernel_size=kernel_size,
90
+ stride=stride,
91
+ padding=padding
92
+ )
93
+ )
94
+ else:
95
+ if in_channels == out_channels:
96
+ self.residual_layer = nn.Identity()
97
+ else:
98
+ if type == '1d':
99
+ self.residual_layer = nn.Sequential(
100
+ nn.Conv1d(
101
+ in_channels=in_channels,
102
+ out_channels=out_channels,
103
+ kernel_size=kernel_size,
104
+ stride=stride,
105
+ padding=padding
106
+ )
107
+ )
108
+ elif type == '2d':
109
+ self.residual_layer = nn.Sequential(
110
+ nn.Conv2d(
111
+ in_channels=in_channels,
112
+ out_channels=out_channels,
113
+ kernel_size=kernel_size,
114
+ stride=stride,
115
+ padding=padding
116
+ )
117
+ )
118
+
119
+ in_channels = in_channels * groups
120
+ out_channels = out_channels * groups
121
+ if type == '1d':
122
+ self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
123
+ kernel_size=kernel_size, stride=stride, padding=padding,
124
+ groups=groups)
125
+ self.norm = nn.BatchNorm1d(out_channels)
126
+ self.dropout = nn.Dropout(p=p)
127
+ elif type == '2d':
128
+ self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
129
+ kernel_size=kernel_size, stride=stride, padding=padding,
130
+ groups=groups)
131
+ self.norm = nn.BatchNorm2d(out_channels)
132
+ self.dropout = nn.Dropout2d(p=p)
133
+ if norm == 'gn':
134
+ self.norm = nn.GroupNorm(2, out_channels)
135
+ elif norm == 'ln':
136
+ self.norm = nn.LayerNorm(out_channels)
137
+ if leaky:
138
+ self.relu = nn.LeakyReLU(negative_slope=0.2)
139
+ else:
140
+ self.relu = nn.ReLU()
141
+
142
+ def forward(self, x, **kwargs):
143
+ if self.norm_type == 'ln':
144
+ out = self.dropout(self.conv(x))
145
+ out = self.norm(out.transpose(1,2)).transpose(1,2)
146
+ else:
147
+ out = self.norm(self.dropout(self.conv(x)))
148
+ if self.residual:
149
+ residual = self.residual_layer(x)
150
+ out += residual
151
+ return self.relu(out)
152
+
153
+
154
+ class UNet1D(nn.Module):
155
+ def __init__(self,
156
+ input_channels,
157
+ output_channels,
158
+ max_depth=5,
159
+ kernel_size=None,
160
+ stride=None,
161
+ p=0,
162
+ groups=1):
163
+ super(UNet1D, self).__init__()
164
+ self.pre_downsampling_conv = nn.ModuleList([])
165
+ self.conv1 = nn.ModuleList([])
166
+ self.conv2 = nn.ModuleList([])
167
+ self.upconv = nn.Upsample(scale_factor=2, mode='nearest')
168
+ self.max_depth = max_depth
169
+ self.groups = groups
170
+
171
+ self.pre_downsampling_conv.append(ConvNormRelu(input_channels, output_channels,
172
+ type='1d', leaky=True, downsample=False,
173
+ kernel_size=kernel_size, stride=stride, p=p, groups=groups))
174
+ self.pre_downsampling_conv.append(ConvNormRelu(output_channels, output_channels,
175
+ type='1d', leaky=True, downsample=False,
176
+ kernel_size=kernel_size, stride=stride, p=p, groups=groups))
177
+
178
+ for i in range(self.max_depth):
179
+ self.conv1.append(ConvNormRelu(output_channels, output_channels,
180
+ type='1d', leaky=True, downsample=True,
181
+ kernel_size=kernel_size, stride=stride, p=p, groups=groups))
182
+
183
+ for i in range(self.max_depth):
184
+ self.conv2.append(ConvNormRelu(output_channels, output_channels,
185
+ type='1d', leaky=True, downsample=False,
186
+ kernel_size=kernel_size, stride=stride, p=p, groups=groups))
187
+
188
+ def forward(self, x):
189
+
190
+ input_size = x.shape[-1]
191
+
192
+ assert get_log(
193
+ input_size) >= self.max_depth, 'num_frames must be a power of 2 and its exponent must be at least max_depth'
194
+
195
+ x = nn.Sequential(*self.pre_downsampling_conv)(x)
196
+
197
+ residuals = []
198
+ residuals.append(x)
199
+ for i, conv1 in enumerate(self.conv1):
200
+ x = conv1(x)
201
+ if i < self.max_depth - 1:
202
+ residuals.append(x)
203
+
204
+ for i, conv2 in enumerate(self.conv2):
205
+ x = self.upconv(x) + residuals[self.max_depth - i - 1]
206
+ x = conv2(x)
207
+
208
+ return x
209
+
210
+
211
+ class UNet2D(nn.Module):
212
+ def __init__(self):
213
+ super(UNet2D, self).__init__()
214
+ raise NotImplementedError('2D UNet is weird')
215
+
216
+
217
+ class AudioPoseEncoder1D(nn.Module):
218
+ '''
219
+ (B, C, T) -> (B, C*2, T) -> ... -> (B, C_out, T)
220
+ '''
221
+
222
+ def __init__(self,
223
+ C_in,
224
+ C_out,
225
+ kernel_size=None,
226
+ stride=None,
227
+ min_layer_nums=None
228
+ ):
229
+ super(AudioPoseEncoder1D, self).__init__()
230
+ self.C_in = C_in
231
+ self.C_out = C_out
232
+
233
+ conv_layers = nn.ModuleList([])
234
+ cur_C = C_in
235
+ num_layers = 0
236
+ while cur_C < self.C_out:
237
+ conv_layers.append(ConvNormRelu(
238
+ in_channels=cur_C,
239
+ out_channels=cur_C * 2,
240
+ kernel_size=kernel_size,
241
+ stride=stride
242
+ ))
243
+ cur_C *= 2
244
+ num_layers += 1
245
+
246
+ if (cur_C != C_out) or (min_layer_nums is not None and num_layers < min_layer_nums):
247
+ while (cur_C != C_out) or num_layers < min_layer_nums:
248
+ conv_layers.append(ConvNormRelu(
249
+ in_channels=cur_C,
250
+ out_channels=C_out,
251
+ kernel_size=kernel_size,
252
+ stride=stride
253
+ ))
254
+ num_layers += 1
255
+ cur_C = C_out
256
+
257
+ self.conv_layers = nn.Sequential(*conv_layers)
258
+
259
+ def forward(self, x):
260
+ '''
261
+ x: (B, C, T)
262
+ '''
263
+ x = self.conv_layers(x)
264
+ return x
265
+
266
+
267
+ class AudioPoseEncoder2D(nn.Module):
268
+ '''
269
+ (B, C, T) -> (B, 1, C, T) -> ... -> (B, C_out, T)
270
+ '''
271
+
272
+ def __init__(self):
273
+ raise NotImplementedError
274
+
275
+
276
+ class AudioPoseEncoderRNN(nn.Module):
277
+ '''
278
+ (B, C, T)->(B, T, C)->(B, T, C_out)->(B, C_out, T)
279
+ '''
280
+
281
+ def __init__(self,
282
+ C_in,
283
+ hidden_size,
284
+ num_layers,
285
+ rnn_cell='gru',
286
+ bidirectional=False
287
+ ):
288
+ super(AudioPoseEncoderRNN, self).__init__()
289
+ if rnn_cell == 'gru':
290
+ self.cell = nn.GRU(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
291
+ bidirectional=bidirectional)
292
+ elif rnn_cell == 'lstm':
293
+ self.cell = nn.LSTM(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
294
+ bidirectional=bidirectional)
295
+ else:
296
+ raise ValueError('invalid rnn cell:%s' % (rnn_cell))
297
+
298
+ def forward(self, x, state=None):
299
+
300
+ x = x.permute(0, 2, 1)
301
+ x, state = self.cell(x, state)
302
+ x = x.permute(0, 2, 1)
303
+
304
+ return x
305
+
306
+
307
+ class AudioPoseEncoderGraph(nn.Module):
308
+ '''
309
+ (B, C, T)->(B, 2, V, T)->(B, 2, T, V)->(B, D, T, V)
310
+ '''
311
+
312
+ def __init__(self,
313
+ layers_config, # should be a list of (C_in, C_out, kernel_size) tuples
314
+ A, # adjacency matrix (num_parts, V, V)
315
+ residual,
316
+ local_bn=False,
317
+ share_weights=False
318
+ ) -> None:
319
+ super().__init__()
320
+ self.A = A
321
+ self.num_joints = A.shape[1]
322
+ self.num_parts = A.shape[0]
323
+ self.C_in = layers_config[0][0]
324
+ self.C_out = layers_config[-1][1]
325
+
326
+ self.conv_layers = nn.ModuleList([
327
+ GraphConvNormRelu(
328
+ C_in=c_in,
329
+ C_out=c_out,
330
+ A=self.A,
331
+ residual=residual,
332
+ local_bn=local_bn,
333
+ kernel_size=k,
334
+ share_weights=share_weights
335
+ ) for (c_in, c_out, k) in layers_config
336
+ ])
337
+
338
+ self.conv_layers = nn.Sequential(*self.conv_layers)
339
+
340
+ def forward(self, x):
341
+ '''
342
+ x: (B, C, T), C should be num_joints*D
343
+ output: (B, D, T, V)
344
+ '''
345
+ B, C, T = x.shape
346
+ x = x.view(B, self.num_joints, self.C_in, T) # (B, V, D, T), D: per-joint feature dim; note that V comes before D here
347
+ x = x.permute(0, 2, 3, 1) # (B, D, T, V)
348
+ assert x.shape[1] == self.C_in
349
+
350
+ x_conved = self.conv_layers(x)
351
+
352
+ # x_conved = x_conved.permute(0, 3, 1, 2).contiguous().view(B, self.C_out*self.num_joints, T)#(B, V*C_out, T)
353
+
354
+ return x_conved
355
+
356
+
357
+ class SeqEncoder2D(nn.Module):
358
+ '''
359
+ seq_encoder, encoding a seq to a vector
360
+ (B, C, T)->(B, 2, V, T)->(B, 2, T, V) -> (B, 32, )->...->(B, C_out)
361
+ '''
362
+
363
+ def __init__(self,
364
+ C_in, # should be 2
365
+ T_in,
366
+ C_out,
367
+ num_joints,
368
+ min_layer_num=None,
369
+ residual=False
370
+ ):
371
+ super(SeqEncoder2D, self).__init__()
372
+ self.C_in = C_in
373
+ self.C_out = C_out
374
+ self.T_in = T_in
375
+ self.num_joints = num_joints
376
+
377
+ conv_layers = nn.ModuleList([])
378
+ conv_layers.append(ConvNormRelu(
379
+ in_channels=C_in,
380
+ out_channels=32,
381
+ type='2d',
382
+ residual=residual
383
+ ))
384
+
385
+ cur_C = 32
386
+ cur_H = T_in
387
+ cur_W = num_joints
388
+ num_layers = 1
389
+ while (cur_C < C_out) or (cur_H > 1) or (cur_W > 1):
390
+ ks = [3, 3]
391
+ st = [1, 1]
392
+
393
+ if cur_H > 1:
394
+ if cur_H > 4:
395
+ ks[0] = 4
396
+ st[0] = 2
397
+ else:
398
+ ks[0] = cur_H
399
+ st[0] = cur_H
400
+ if cur_W > 1:
401
+ if cur_W > 4:
402
+ ks[1] = 4
403
+ st[1] = 2
404
+ else:
405
+ ks[1] = cur_W
406
+ st[1] = cur_W
407
+
408
+ conv_layers.append(ConvNormRelu(
409
+ in_channels=cur_C,
410
+ out_channels=min(C_out, cur_C * 2),
411
+ type='2d',
412
+ kernel_size=tuple(ks),
413
+ stride=tuple(st),
414
+ residual=residual
415
+ ))
416
+ cur_C = min(cur_C * 2, C_out)
417
+ if cur_H > 1:
418
+ if cur_H > 4:
419
+ cur_H //= 2
420
+ else:
421
+ cur_H = 1
422
+ if cur_W > 1:
423
+ if cur_W > 4:
424
+ cur_W //= 2
425
+ else:
426
+ cur_W = 1
427
+ num_layers += 1
428
+
429
+ if min_layer_num is not None and (num_layers < min_layer_num):
430
+ while num_layers < min_layer_num:
431
+ conv_layers.append(ConvNormRelu(
432
+ in_channels=C_out,
433
+ out_channels=C_out,
434
+ type='2d',
435
+ kernel_size=1,
436
+ stride=1,
437
+ residual=residual
438
+ ))
439
+ num_layers += 1
440
+
441
+ self.conv_layers = nn.Sequential(*conv_layers)
442
+ self.num_layers = num_layers
443
+
444
+ def forward(self, x):
445
+ B, C, T = x.shape
446
+ x = x.view(B, self.num_joints, self.C_in, T) # (B, V, D, T) V in front
447
+ x = x.permute(0, 2, 3, 1) # (B, D, T, V)
448
+ assert x.shape[1] == self.C_in and x.shape[-1] == self.num_joints
449
+
450
+ x = self.conv_layers(x)
451
+ return x.squeeze()
452
+
453
+
454
+ class SeqEncoder1D(nn.Module):
455
+ '''
456
+ (B, C, T)->(B, D)
457
+ '''
458
+
459
+ def __init__(self,
460
+ C_in,
461
+ C_out,
462
+ T_in,
463
+ min_layer_nums=None
464
+ ):
465
+ super(SeqEncoder1D, self).__init__()
466
+ conv_layers = nn.ModuleList([])
467
+ cur_C = C_in
468
+ cur_T = T_in
469
+ self.num_layers = 0
470
+ while (cur_C < C_out) or (cur_T > 1):
471
+ ks = 3
472
+ st = 1
473
+ if cur_T > 1:
474
+ if cur_T > 4:
475
+ ks = 4
476
+ st = 2
477
+ else:
478
+ ks = cur_T
479
+ st = cur_T
480
+
481
+ conv_layers.append(ConvNormRelu(
482
+ in_channels=cur_C,
483
+ out_channels=min(C_out, cur_C * 2),
484
+ type='1d',
485
+ kernel_size=ks,
486
+ stride=st
487
+ ))
488
+ cur_C = min(cur_C * 2, C_out)
489
+ if cur_T > 1:
490
+ if cur_T > 4:
491
+ cur_T = cur_T // 2
492
+ else:
493
+ cur_T = 1
494
+ self.num_layers += 1
495
+
496
+ if min_layer_nums is not None and (self.num_layers < min_layer_nums):
497
+ while self.num_layers < min_layer_nums:
498
+ conv_layers.append(ConvNormRelu(
499
+ in_channels=C_out,
500
+ out_channels=C_out,
501
+ type='1d',
502
+ kernel_size=1,
503
+ stride=1
504
+ ))
505
+ self.num_layers += 1
506
+ self.conv_layers = nn.Sequential(*conv_layers)
507
+
508
+ def forward(self, x):
509
+ x = self.conv_layers(x)
510
+ return x.squeeze()
511
+
512
+
513
+ class SeqEncoderRNN(nn.Module):
514
+ '''
515
+ (B, C, T) -> (B, T, C) -> (B, D)
516
+ LSTM/GRU-FC
517
+ '''
518
+
519
+ def __init__(self,
520
+ hidden_size,
521
+ in_size,
522
+ num_rnn_layers,
523
+ rnn_cell='gru',
524
+ bidirectional=False
525
+ ):
526
+ super(SeqEncoderRNN, self).__init__()
527
+ self.hidden_size = hidden_size
528
+ self.in_size = in_size
529
+ self.num_rnn_layers = num_rnn_layers
530
+ self.bidirectional = bidirectional
531
+
532
+ if rnn_cell == 'gru':
533
+ self.cell = nn.GRU(input_size=self.in_size, hidden_size=self.hidden_size, num_layers=self.num_rnn_layers,
534
+ batch_first=True, bidirectional=bidirectional)
535
+ elif rnn_cell == 'lstm':
536
+ self.cell = nn.LSTM(input_size=self.in_size, hidden_size=self.hidden_size, num_layers=self.num_rnn_layers,
537
+ batch_first=True, bidirectional=bidirectional)
538
+
539
+ def forward(self, x, state=None):
540
+
541
+ x = x.permute(0, 2, 1)
542
+ B, T, C = x.shape
543
+ x, _ = self.cell(x, state)
544
+ if self.bidirectional:
545
+ out = torch.cat([x[:, -1, :self.hidden_size], x[:, 0, self.hidden_size:]], dim=-1)
546
+ else:
547
+ out = x[:, -1, :]
548
+ assert out.shape[0] == B
549
+ return out
550
+
551
+
552
+ class SeqEncoderGraph(nn.Module):
553
+ '''
554
+ '''
555
+
556
+ def __init__(self,
557
+ embedding_size,
558
+ layer_configs,
559
+ residual,
560
+ local_bn,
561
+ A,
562
+ T,
563
+ share_weights=False
564
+ ) -> None:
565
+ super().__init__()
566
+
567
+ self.C_in = layer_configs[0][0]
568
+ self.C_out = embedding_size
569
+
570
+ self.num_joints = A.shape[1]
571
+
572
+ self.graph_encoder = AudioPoseEncoderGraph(
573
+ layers_config=layer_configs,
574
+ A=A,
575
+ residual=residual,
576
+ local_bn=local_bn,
577
+ share_weights=share_weights
578
+ )
579
+
580
+ cur_C = layer_configs[-1][1]
581
+ self.spatial_pool = ConvNormRelu(
582
+ in_channels=cur_C,
583
+ out_channels=cur_C,
584
+ type='2d',
585
+ kernel_size=(1, self.num_joints),
586
+ stride=(1, 1),
587
+ padding=(0, 0)
588
+ )
589
+
590
+ temporal_pool = nn.ModuleList([])
591
+ cur_H = T
592
+ num_layers = 0
593
+ self.temporal_conv_info = []
594
+ while cur_C < self.C_out or cur_H > 1:
595
+ self.temporal_conv_info.append(cur_C)
596
+ ks = [3, 1]
597
+ st = [1, 1]
598
+
599
+ if cur_H > 1:
600
+ if cur_H > 4:
601
+ ks[0] = 4
602
+ st[0] = 2
603
+ else:
604
+ ks[0] = cur_H
605
+ st[0] = cur_H
606
+
607
+ temporal_pool.append(ConvNormRelu(
608
+ in_channels=cur_C,
609
+ out_channels=min(self.C_out, cur_C * 2),
610
+ type='2d',
611
+ kernel_size=tuple(ks),
612
+ stride=tuple(st)
613
+ ))
614
+ cur_C = min(cur_C * 2, self.C_out)
615
+
616
+ if cur_H > 1:
617
+ if cur_H > 4:
618
+ cur_H //= 2
619
+ else:
620
+ cur_H = 1
621
+
622
+ num_layers += 1
623
+
624
+ self.temporal_pool = nn.Sequential(*temporal_pool)
625
+ print("graph seq encoder info: temporal pool:", self.temporal_conv_info)
626
+ self.num_layers = num_layers
627
+ # need fc?
628
+
629
+ def forward(self, x):
630
+ '''
631
+ x: (B, C, T)
632
+ '''
633
+ B, C, T = x.shape
634
+ x = self.graph_encoder(x)
635
+ x = self.spatial_pool(x)
636
+ x = self.temporal_pool(x)
637
+ x = x.view(B, self.C_out)
638
+
639
+ return x
640
+
641
+
642
+ class SeqDecoder2D(nn.Module):
643
+ '''
644
+ (B, D)->(B, D, 1, 1)->(B, C_out, C, T)->(B, C_out, T)
645
+ '''
646
+
647
+ def __init__(self):
648
+ super(SeqDecoder2D, self).__init__()
649
+ raise NotImplementedError
650
+
651
+
652
+ class SeqDecoder1D(nn.Module):
653
+ '''
654
+ (B, D)->(B, D, 1)->...->(B, C_out, T)
655
+ '''
656
+
657
+ def __init__(self,
658
+ D_in,
659
+ C_out,
660
+ T_out,
661
+ min_layer_num=None
662
+ ):
663
+ super(SeqDecoder1D, self).__init__()
664
+ self.T_out = T_out
665
+ self.min_layer_num = min_layer_num
666
+
667
+ cur_t = 1
668
+
669
+ self.pre_conv = ConvNormRelu(
670
+ in_channels=D_in,
671
+ out_channels=C_out,
672
+ type='1d'
673
+ )
674
+ self.num_layers = 1
675
+ self.upconv = nn.Upsample(scale_factor=2, mode='nearest')
676
+ self.conv_layers = nn.ModuleList([])
677
+ cur_t *= 2
678
+ while cur_t <= T_out:
679
+ self.conv_layers.append(ConvNormRelu(
680
+ in_channels=C_out,
681
+ out_channels=C_out,
682
+ type='1d'
683
+ ))
684
+ cur_t *= 2
685
+ self.num_layers += 1
686
+
687
+ post_conv = nn.ModuleList([ConvNormRelu(
688
+ in_channels=C_out,
689
+ out_channels=C_out,
690
+ type='1d'
691
+ )])
692
+ self.num_layers += 1
693
+ if min_layer_num is not None and self.num_layers < min_layer_num:
694
+ while self.num_layers < min_layer_num:
695
+ post_conv.append(ConvNormRelu(
696
+ in_channels=C_out,
697
+ out_channels=C_out,
698
+ type='1d'
699
+ ))
700
+ self.num_layers += 1
701
+ self.post_conv = nn.Sequential(*post_conv)
702
+
703
+ def forward(self, x):
704
+
705
+ x = x.unsqueeze(-1)
706
+ x = self.pre_conv(x)
707
+ for conv in self.conv_layers:
708
+ x = self.upconv(x)
709
+ x = conv(x)
710
+
711
+ x = torch.nn.functional.interpolate(x, size=self.T_out, mode='nearest')
712
+ x = self.post_conv(x)
713
+ return x
714
+
715
+
716
+ class SeqDecoderRNN(nn.Module):
717
+ '''
718
+ (B, D)->(B, C_out, T)
719
+ '''
720
+
721
+ def __init__(self,
722
+ hidden_size,
723
+ C_out,
724
+ T_out,
725
+ num_layers,
726
+ rnn_cell='gru'
727
+ ):
728
+ super(SeqDecoderRNN, self).__init__()
729
+ self.num_steps = T_out
730
+ if rnn_cell == 'gru':
731
+ self.cell = nn.GRU(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
732
+ bidirectional=False)
733
+ elif rnn_cell == 'lstm':
734
+ self.cell = nn.LSTM(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
735
+ bidirectional=False)
736
+ else:
737
+ raise ValueError('invalid rnn cell:%s' % (rnn_cell))
738
+
739
+ self.fc = nn.Linear(hidden_size, C_out)
740
+
741
+ def forward(self, hidden, frame_0):
742
+ frame_0 = frame_0.permute(0, 2, 1)
743
+ dec_input = frame_0
744
+ outputs = []
745
+ for i in range(self.num_steps):
746
+ frame_out, hidden = self.cell(dec_input, hidden)
747
+ frame_out = self.fc(frame_out)
748
+ dec_input = frame_out
749
+ outputs.append(frame_out)
750
+ output = torch.cat(outputs, dim=1)
751
+ return output.permute(0, 2, 1)
752
+
753
+
754
+ class SeqTranslator2D(nn.Module):
755
+ '''
756
+ (B, C, T)->(B, 1, C, T)-> ... -> (B, 1, C_out, T_out)
757
+ '''
758
+
759
+ def __init__(self,
760
+ C_in=64,
761
+ C_out=108,
762
+ T_in=75,
763
+ T_out=25,
764
+ residual=True
765
+ ):
766
+ super(SeqTranslator2D, self).__init__()
767
+ print("Warning: SeqTranslator2D layer shapes are hard-coded")
768
+ self.C_in = C_in
769
+ self.C_out = C_out
770
+ self.T_in = T_in
771
+ self.T_out = T_out
772
+ self.residual = residual
773
+
774
+ self.conv_layers = nn.Sequential(
775
+ ConvNormRelu(1, 32, '2d', kernel_size=5, stride=1),
776
+ ConvNormRelu(32, 32, '2d', kernel_size=5, stride=1, residual=self.residual),
777
+ ConvNormRelu(32, 32, '2d', kernel_size=5, stride=1, residual=self.residual),
778
+
779
+ ConvNormRelu(32, 64, '2d', kernel_size=5, stride=(4, 3)),
780
+ ConvNormRelu(64, 64, '2d', kernel_size=5, stride=1, residual=self.residual),
781
+ ConvNormRelu(64, 64, '2d', kernel_size=5, stride=1, residual=self.residual),
782
+
783
+ ConvNormRelu(64, 128, '2d', kernel_size=5, stride=(4, 1)),
784
+ ConvNormRelu(128, 108, '2d', kernel_size=3, stride=(4, 1)),
785
+ ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1, residual=self.residual),
786
+
787
+ ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1, residual=self.residual),
788
+ ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1),
789
+ )
790
+
791
+ def forward(self, x):
792
+ assert len(x.shape) == 3 and x.shape[1] == self.C_in and x.shape[2] == self.T_in
793
+ x = x.view(x.shape[0], 1, x.shape[1], x.shape[2])
794
+ x = self.conv_layers(x)
795
+ x = x.squeeze(2)
796
+ return x
797
+
798
+
799
+ class SeqTranslator1D(nn.Module):
800
+ '''
801
+ (B, C, T)->(B, C_out, T)
802
+ '''
803
+
804
+ def __init__(self,
805
+ C_in,
806
+ C_out,
807
+ kernel_size=None,
808
+ stride=None,
809
+ min_layers_num=None,
810
+ residual=True,
811
+ norm='bn'
812
+ ):
813
+ super(SeqTranslator1D, self).__init__()
814
+
815
+ conv_layers = nn.ModuleList([])
816
+ conv_layers.append(ConvNormRelu(
817
+ in_channels=C_in,
818
+ out_channels=C_out,
819
+ type='1d',
820
+ kernel_size=kernel_size,
821
+ stride=stride,
822
+ residual=residual,
823
+ norm=norm
824
+ ))
825
+ self.num_layers = 1
826
+ if min_layers_num is not None and self.num_layers < min_layers_num:
827
+ while self.num_layers < min_layers_num:
828
+ conv_layers.append(ConvNormRelu(
829
+ in_channels=C_out,
830
+ out_channels=C_out,
831
+ type='1d',
832
+ kernel_size=kernel_size,
833
+ stride=stride,
834
+ residual=residual,
835
+ norm=norm
836
+ ))
837
+ self.num_layers += 1
838
+ self.conv_layers = nn.Sequential(*conv_layers)
839
+
840
+ def forward(self, x):
841
+ return self.conv_layers(x)
842
+
843
+
844
+ class SeqTranslatorRNN(nn.Module):
845
+ '''
846
+ (B, C, T)->(B, C_out, T)
847
+ LSTM-FC
848
+ '''
849
+
850
+ def __init__(self,
851
+ C_in,
852
+ C_out,
853
+ hidden_size,
854
+ num_layers,
855
+ rnn_cell='gru'
856
+ ):
857
+ super(SeqTranslatorRNN, self).__init__()
858
+
859
+ if rnn_cell == 'gru':
860
+ self.enc_cell = nn.GRU(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
861
+ bidirectional=False)
862
+ self.dec_cell = nn.GRU(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
863
+ bidirectional=False)
864
+ elif rnn_cell == 'lstm':
865
+ self.enc_cell = nn.LSTM(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
866
+ bidirectional=False)
867
+ self.dec_cell = nn.LSTM(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
868
+ bidirectional=False)
869
+ else:
870
+ raise ValueError('invalid rnn cell:%s' % (rnn_cell))
871
+
872
+ self.fc = nn.Linear(hidden_size, C_out)
873
+
874
+ def forward(self, x, frame_0):
875
+
876
+ num_steps = x.shape[-1]
877
+ x = x.permute(0, 2, 1)
878
+ frame_0 = frame_0.permute(0, 2, 1)
879
+ _, hidden = self.enc_cell(x, None)
880
+
881
+ outputs = []
882
+ for i in range(num_steps):
883
+ inputs = frame_0
884
+ output_frame, hidden = self.dec_cell(inputs, hidden)
885
+ output_frame = self.fc(output_frame)
886
+ frame_0 = output_frame
887
+ outputs.append(output_frame)
888
+ outputs = torch.cat(outputs, dim=1)
889
+ return outputs.permute(0, 2, 1)
890
+
891
+
892
+ class ResBlock(nn.Module):
893
+ def __init__(self,
894
+ input_dim,
895
+ fc_dim,
896
+ afn,
897
+ nfn
898
+ ):
899
+ '''
900
+ afn: activation fn
901
+ nfn: normalization fn
902
+ '''
903
+ super(ResBlock, self).__init__()
904
+
905
+ self.input_dim = input_dim
906
+ self.fc_dim = fc_dim
907
+ self.afn = afn
908
+ self.nfn = nfn
909
+
910
+ if self.afn != 'relu':
911
+ raise ValueError('Wrong')
912
+
913
+ if self.nfn == 'layer_norm':
914
+ raise ValueError('wrong')
915
+
916
+ self.layers = nn.Sequential(
917
+ nn.Linear(self.input_dim, self.fc_dim // 2),
918
+ nn.ReLU(),
919
+ nn.Linear(self.fc_dim // 2, self.fc_dim // 2),
920
+ nn.ReLU(),
921
+ nn.Linear(self.fc_dim // 2, self.fc_dim),
922
+ nn.ReLU()
923
+ )
924
+
925
+ self.shortcut_layer = nn.Sequential(
926
+ nn.Linear(self.input_dim, self.fc_dim),
927
+ nn.ReLU(),
928
+ )
929
+
930
+ def forward(self, inputs):
931
+ return self.layers(inputs) + self.shortcut_layer(inputs)
932
+
933
+
934
+ class AudioEncoder(nn.Module):
935
+ def __init__(self, channels, padding=3, kernel_size=8, conv_stride=2, conv_pool=None, augmentation=False):
936
+ super(AudioEncoder, self).__init__()
937
+ self.in_channels = channels[0]
938
+ self.augmentation = augmentation
939
+
940
+ model = []
941
+ acti = nn.LeakyReLU(0.2)
942
+
943
+ nr_layer = len(channels) - 1
944
+
945
+ for i in range(nr_layer):
946
+ if conv_pool is None:
947
+ model.append(nn.ReflectionPad1d(padding))
948
+ model.append(nn.Conv1d(channels[i], channels[i + 1], kernel_size=kernel_size, stride=conv_stride))
949
+ model.append(acti)
950
+ else:
951
+ model.append(nn.ReflectionPad1d(padding))
952
+ model.append(nn.Conv1d(channels[i], channels[i + 1], kernel_size=kernel_size, stride=conv_stride))
953
+ model.append(acti)
954
+ model.append(conv_pool(kernel_size=2, stride=2))
955
+
956
+ if self.augmentation:
957
+ model.append(
958
+ nn.Conv1d(channels[-1], channels[-1], kernel_size=kernel_size, stride=conv_stride)
959
+ )
960
+ model.append(acti)
961
+
962
+ self.model = nn.Sequential(*model)
963
+
964
+ def forward(self, x):
965
+
966
+ x = x[:, :self.in_channels, :]
967
+ x = self.model(x)
968
+ return x
969
+
970
+
971
+ class AudioDecoder(nn.Module):
972
+ def __init__(self, channels, kernel_size=7, ups=25):
973
+ super(AudioDecoder, self).__init__()
974
+
975
+ model = []
976
+ pad = (kernel_size - 1) // 2
977
+ acti = nn.LeakyReLU(0.2)
978
+
979
+ for i in range(len(channels) - 2):
980
+ model.append(nn.Upsample(scale_factor=2, mode='nearest'))
981
+ model.append(nn.ReflectionPad1d(pad))
982
+ model.append(nn.Conv1d(channels[i], channels[i + 1],
983
+ kernel_size=kernel_size, stride=1))
984
+ if i == 0 or i == 1:
985
+ model.append(nn.Dropout(p=0.2))
986
+ if not i == len(channels) - 2:
987
+ model.append(acti)
988
+
989
+ model.append(nn.Upsample(size=ups, mode='nearest'))
990
+ model.append(nn.ReflectionPad1d(pad))
991
+ model.append(nn.Conv1d(channels[-2], channels[-1],
992
+ kernel_size=kernel_size, stride=1))
993
+
994
+ self.model = nn.Sequential(*model)
995
+
996
+ def forward(self, x):
997
+ return self.model(x)
998
+
999
+
1000
+ class Audio2Pose(nn.Module):
1001
+ def __init__(self, pose_dim, embed_size, augmentation, ups=25):
1002
+ super(Audio2Pose, self).__init__()
1003
+ self.pose_dim = pose_dim
1004
+ self.embed_size = embed_size
1005
+ self.augmentation = augmentation
1006
+
1007
+ self.aud_enc = AudioEncoder(channels=[13, 64, 128, 256], padding=2, kernel_size=7, conv_stride=1,
1008
+ conv_pool=nn.AvgPool1d, augmentation=self.augmentation)
1009
+ if self.augmentation:
1010
+ self.aud_dec = AudioDecoder(channels=[512, 256, 128, pose_dim])
1011
+ else:
1012
+ self.aud_dec = AudioDecoder(channels=[256, 256, 128, pose_dim], ups=ups)
1013
+
1014
+ if self.augmentation:
1015
+ self.pose_enc = nn.Sequential(
1016
+ nn.Linear(self.embed_size // 2, 256),
1017
+ nn.LayerNorm(256)
1018
+ )
1019
+
1020
+ def forward(self, audio_feat, dec_input=None):
1021
+
1022
+ B = audio_feat.shape[0]
1023
+
1024
+ aud_embed = self.aud_enc.forward(audio_feat)
1025
+
1026
+ if self.augmentation:
1027
+ dec_input = dec_input.squeeze(0)
1028
+ dec_embed = self.pose_enc(dec_input)
1029
+ dec_embed = dec_embed.unsqueeze(2)
1030
+ dec_embed = dec_embed.expand(dec_embed.shape[0], dec_embed.shape[1], aud_embed.shape[-1])
1031
+ aud_embed = torch.cat([aud_embed, dec_embed], dim=1)
1032
+
1033
+ out = self.aud_dec.forward(aud_embed)
1034
+ return out
1035
+
1036
+
1037
+ if __name__ == '__main__':
1038
+ import numpy as np
1039
+ import os
1040
+ import sys
1041
+
1042
+ test_model = SeqEncoder2D(
1043
+ C_in=2,
1044
+ T_in=25,
1045
+ C_out=512,
1046
+ num_joints=54,
1047
+ )
1048
+ print(test_model.num_layers)
1049
+
1050
+ input = torch.randn((64, 108, 25))
1051
+ output = test_model(input)
1052
+ print(output.shape)
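The __main__ block above only exercises SeqEncoder2D. A few additional shape checks in the same spirit are sketched below; the batch/channel/time sizes are arbitrary, and the expected output shapes follow from the constructors defined in this file.

import torch
from nets.layers import ConvNormRelu, SeqEncoder1D, SeqTranslator1D

# stride-2 downsampling block: (8, 64, 88) -> (8, 128, 44)
cnr = ConvNormRelu(in_channels=64, out_channels=128, type='1d', downsample=True)
print(cnr(torch.randn(8, 64, 88)).shape)

# sequence-to-vector encoder: (8, 108, 64) -> (8, 512)
enc = SeqEncoder1D(C_in=108, C_out=512, T_in=64)
print(enc(torch.randn(8, 108, 64)).shape)

# per-frame translator, time length preserved: (8, 64, 88) -> (8, 108, 88)
tr = SeqTranslator1D(C_in=64, C_out=108, kernel_size=3, stride=1, min_layers_num=4)
print(tr(torch.randn(8, 64, 88)).shape)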
nets/smplx_body_pixel.py ADDED
@@ -0,0 +1,326 @@
1
+ import os
2
+ import sys
3
+
4
+ import torch
5
+ from torch.optim.lr_scheduler import StepLR
6
+
7
+ sys.path.append(os.getcwd())
8
+
9
+ from nets.layers import *
10
+ from nets.base import TrainWrapperBaseClass
11
+ from nets.spg.gated_pixelcnn_v2 import GatedPixelCNN as pixelcnn
12
+ from nets.spg.vqvae_1d import VQVAE as s2g_body, Wav2VecEncoder
13
+ from nets.spg.vqvae_1d import AudioEncoder
14
+ from nets.utils import parse_audio, denormalize
15
+ from data_utils import get_mfcc, get_melspec, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
16
+ import numpy as np
17
+ import torch.optim as optim
18
+ import torch.nn.functional as F
19
+ from sklearn.preprocessing import normalize
20
+
21
+ from data_utils.lower_body import c_index, c_index_3d, c_index_6d
22
+ from data_utils.utils import smooth_geom, get_mfcc_sepa
23
+
24
+
25
+ class TrainWrapper(TrainWrapperBaseClass):
26
+ '''
27
+ a wrapper receiving a batch from data_utils and calculating the loss
28
+ '''
29
+
30
+ def __init__(self, args, config):
31
+ self.args = args
32
+ self.config = config
33
+ self.device = torch.device(self.args.gpu)
34
+ self.global_step = 0
35
+
36
+ self.convert_to_6d = self.config.Data.pose.convert_to_6d
37
+ self.expression = self.config.Data.pose.expression
38
+ self.epoch = 0
39
+ self.init_params()
40
+ self.num_classes = 4
41
+ self.audio = True
42
+ self.composition = self.config.Model.composition
43
+ self.bh_model = self.config.Model.bh_model
44
+
45
+ if self.audio:
46
+ self.audioencoder = AudioEncoder(in_dim=64, num_hiddens=256, num_residual_layers=2, num_residual_hiddens=256).to(self.device)
47
+ else:
48
+ self.audioencoder = None
49
+ if self.convert_to_6d:
50
+ dim, layer = 512, 10
51
+ else:
52
+ dim, layer = 256, 15
53
+ self.generator = pixelcnn(2048, dim, layer, self.num_classes, self.audio, self.bh_model).to(self.device)
54
+ self.g_body = s2g_body(self.each_dim[1], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
55
+ num_residual_layers=2, num_residual_hiddens=512).to(self.device)
56
+ self.g_hand = s2g_body(self.each_dim[2], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
57
+ num_residual_layers=2, num_residual_hiddens=512).to(self.device)
58
+
59
+ model_path = self.config.Model.vq_path
60
+ model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
61
+ self.g_body.load_state_dict(model_ckpt['generator']['g_body'])
62
+ self.g_hand.load_state_dict(model_ckpt['generator']['g_hand'])
63
+
64
+ if torch.cuda.device_count() > 1:
65
+ self.g_body = torch.nn.DataParallel(self.g_body, device_ids=[0, 1])
66
+ self.g_hand = torch.nn.DataParallel(self.g_hand, device_ids=[0, 1])
67
+ self.generator = torch.nn.DataParallel(self.generator, device_ids=[0, 1])
68
+ if self.audioencoder is not None:
69
+ self.audioencoder = torch.nn.DataParallel(self.audioencoder, device_ids=[0, 1])
70
+
71
+ self.discriminator = None
72
+ if self.convert_to_6d:
73
+ self.c_index = c_index_6d
74
+ else:
75
+ self.c_index = c_index_3d
76
+
77
+ super().__init__(args, config)
78
+
79
+ def init_optimizer(self):
80
+
81
+ print('using Adam')
82
+ self.generator_optimizer = optim.Adam(
83
+ self.generator.parameters(),
84
+ lr=self.config.Train.learning_rate.generator_learning_rate,
85
+ betas=[0.9, 0.999]
86
+ )
87
+ if self.audioencoder is not None:
88
+ opt = self.config.Model.AudioOpt
89
+ if opt == 'Adam':
90
+ self.audioencoder_optimizer = optim.Adam(
91
+ self.audioencoder.parameters(),
92
+ lr=self.config.Train.learning_rate.generator_learning_rate,
93
+ betas=[0.9, 0.999]
94
+ )
95
+ else:
96
+ print('using SGD')
97
+ self.audioencoder_optimizer = optim.SGD(
98
+ filter(lambda p: p.requires_grad,self.audioencoder.parameters()),
99
+ lr=self.config.Train.learning_rate.generator_learning_rate*10,
100
+ momentum=0.9,
101
+ nesterov=False,
102
+ )
103
+
104
+ def state_dict(self):
105
+ model_state = {
106
+ 'generator': self.generator.state_dict(),
107
+ 'generator_optim': self.generator_optimizer.state_dict(),
108
+ 'audioencoder': self.audioencoder.state_dict() if self.audio else None,
109
+ 'audioencoder_optim': self.audioencoder_optimizer.state_dict() if self.audio else None,
110
+ 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
111
+ 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
112
+ }
113
+ return model_state
114
+
115
+ def load_state_dict(self, state_dict):
116
+
117
+ from collections import OrderedDict
118
+ new_state_dict = OrderedDict() # create new OrderedDict that does not contain `module.`
119
+ for k, v in state_dict.items():
120
+ sub_dict = OrderedDict()
121
+ if v is not None:
122
+ for k1, v1 in v.items():
123
+ name = k1.replace('module.', '')
124
+ sub_dict[name] = v1
125
+ new_state_dict[k] = sub_dict
126
+ state_dict = new_state_dict
127
+ if 'generator' in state_dict:
128
+ self.generator.load_state_dict(state_dict['generator'])
129
+ else:
130
+ self.generator.load_state_dict(state_dict)
131
+
132
+ if 'generator_optim' in state_dict and self.generator_optimizer is not None:
133
+ self.generator_optimizer.load_state_dict(state_dict['generator_optim'])
134
+
135
+ if self.discriminator is not None:
136
+ self.discriminator.load_state_dict(state_dict['discriminator'])
137
+
138
+ if 'discriminator_optim' in state_dict and self.discriminator_optimizer is not None:
139
+ self.discriminator_optimizer.load_state_dict(state_dict['discriminator_optim'])
140
+
141
+ if 'audioencoder' in state_dict and self.audioencoder is not None:
142
+ self.audioencoder.load_state_dict(state_dict['audioencoder'])
143
+
144
+ def init_params(self):
145
+ if self.config.Data.pose.convert_to_6d:
146
+ scale = 2
147
+ else:
148
+ scale = 1
149
+
150
+ global_orient = round(0 * scale)
151
+ leye_pose = reye_pose = round(0 * scale)
152
+ jaw_pose = round(0 * scale)
153
+ body_pose = round((63 - 24) * scale)
154
+ left_hand_pose = right_hand_pose = round(45 * scale)
155
+ if self.expression:
156
+ expression = 100
157
+ else:
158
+ expression = 0
159
+
160
+ b_j = 0
161
+ jaw_dim = jaw_pose
162
+ b_e = b_j + jaw_dim
163
+ eye_dim = leye_pose + reye_pose
164
+ b_b = b_e + eye_dim
165
+ body_dim = global_orient + body_pose
166
+ b_h = b_b + body_dim
167
+ hand_dim = left_hand_pose + right_hand_pose
168
+ b_f = b_h + hand_dim
169
+ face_dim = expression
170
+
171
+ self.dim_list = [b_j, b_e, b_b, b_h, b_f]
172
+ self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
173
+ self.pose = int(self.full_dim / round(3 * scale))
174
+ self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
175
+
176
+ def __call__(self, bat):
177
+ # assert (not self.args.infer), "infer mode"
178
+ self.global_step += 1
179
+
180
+ total_loss = None
181
+ loss_dict = {}
182
+
183
+ aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
184
+
185
+ id = bat['speaker'].to(self.device) - 20
186
+ # id = F.one_hot(id, self.num_classes)
187
+
188
+ poses = poses[:, self.c_index, :]
189
+
190
+ aud = aud.permute(0, 2, 1)
191
+ gt_poses = poses.permute(0, 2, 1)
192
+
193
+ with torch.no_grad():
194
+ self.g_body.eval()
195
+ self.g_hand.eval()
196
+ if torch.cuda.device_count() > 1:
197
+ _, body_latents = self.g_body.module.encode(gt_poses=gt_poses[..., :self.each_dim[1]], id=id)
198
+ _, hand_latents = self.g_hand.module.encode(gt_poses=gt_poses[..., self.each_dim[1]:], id=id)
199
+ else:
200
+ _, body_latents = self.g_body.encode(gt_poses=gt_poses[..., :self.each_dim[1]], id=id)
201
+ _, hand_latents = self.g_hand.encode(gt_poses=gt_poses[..., self.each_dim[1]:], id=id)
202
+ latents = torch.cat([body_latents.unsqueeze(dim=-1), hand_latents.unsqueeze(dim=-1)], dim=-1)
203
+ latents = latents.detach()
204
+
205
+ if self.audio:
206
+ audio = self.audioencoder(aud[:, :].transpose(1, 2), frame_num=latents.shape[1]*4).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
207
+ logits = self.generator(latents[:, :], id, audio)
208
+ else:
209
+ logits = self.generator(latents, id)
210
+ logits = logits.permute(0, 2, 3, 1).contiguous()
211
+
212
+ self.generator_optimizer.zero_grad()
213
+ if self.audio:
214
+ self.audioencoder_optimizer.zero_grad()
215
+
216
+ loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), latents.view(-1))
217
+ loss.backward()
218
+
219
+ grad = torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.config.Train.max_gradient_norm)
220
+
221
+ if torch.isnan(grad).sum() > 0:
222
+ print('warning: NaN gradient norm encountered')
223
+
224
+ loss_dict['grad'] = grad.item()
225
+ loss_dict['ce_loss'] = loss.item()
226
+ self.generator_optimizer.step()
227
+ if self.audio:
228
+ self.audioencoder_optimizer.step()
229
+
230
+ return total_loss, loss_dict
231
+
232
+ def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, exp=None, var=None, w_pre=False, rand=None,
233
+ continuity=False, id=None, fps=15, sr=22000, B=1, am=None, am_sr=None, frame=0,**kwargs):
234
+ '''
235
+ initial_pose: (B, C, T), normalized
236
+ (aud_fn, txgfile) -> generated motion (B, T, C)
237
+ '''
238
+ output = []
239
+
240
+ assert self.args.infer, "train mode"
241
+ self.generator.eval()
242
+ self.g_body.eval()
243
+ self.g_hand.eval()
244
+
245
+ if continuity:
246
+ aud_feat, gap = get_mfcc_sepa(aud_fn, sr=sr, fps=fps)
247
+ else:
248
+ aud_feat = get_mfcc_ta(aud_fn, sr=sr, fps=fps, smlpx=True, type='mfcc', am=am)
249
+ aud_feat = aud_feat.transpose(1, 0)
250
+ aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
251
+ aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)
252
+
253
+ if id is None:
254
+ id = torch.tensor([0]).to(self.device)
255
+ else:
256
+ id = id.repeat(B)
257
+
258
+ with torch.no_grad():
259
+ aud_feat = aud_feat.permute(0, 2, 1)
260
+ if continuity:
261
+ self.audioencoder.eval()
262
+ pre_pose = {}
263
+ pre_pose['b'] = pre_pose['h'] = None
264
+ pre_latents, pre_audio, body_0, hand_0 = self.infer(aud_feat[:, :gap], frame, id, B, pre_pose=pre_pose)
265
+ pre_pose['b'] = body_0[:, :, -4:].transpose(1,2)
266
+ pre_pose['h'] = hand_0[:, :, -4:].transpose(1,2)
267
+ _, _, body_1, hand_1 = self.infer(aud_feat[:, gap:], frame, id, B, pre_latents, pre_audio, pre_pose)
268
+ body = torch.cat([body_0, body_1], dim=2)
269
+ hand = torch.cat([hand_0, hand_1], dim=2)
270
+
271
+ else:
272
+ if self.audio:
273
+ self.audioencoder.eval()
274
+ audio = self.audioencoder(aud_feat.transpose(1, 2), frame_num=frame).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
275
+ latents = self.generator.generate(id, shape=[audio.shape[2], 2], batch_size=B, aud_feat=audio)
276
+ else:
277
+ latents = self.generator.generate(id, shape=[aud_feat.shape[1]//4, 2], batch_size=B)
278
+
279
+ body_latents = latents[..., 0]
280
+ hand_latents = latents[..., 1]
281
+
282
+ body, _ = self.g_body.decode(b=body_latents.shape[0], w=body_latents.shape[1], latents=body_latents)
283
+ hand, _ = self.g_hand.decode(b=hand_latents.shape[0], w=hand_latents.shape[1], latents=hand_latents)
284
+
285
+ pred_poses = torch.cat([body, hand], dim=1).transpose(1,2).cpu().numpy()
286
+
287
+ output = pred_poses
288
+
289
+ return output
290
+
291
+ def infer(self, aud_feat, frame, id, B, pre_latents=None, pre_audio=None, pre_pose=None):
292
+ audio = self.audioencoder(aud_feat.transpose(1, 2), frame_num=frame).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
293
+ latents = self.generator.generate(id, shape=[audio.shape[2], 2], batch_size=B, aud_feat=audio,
294
+ pre_latents=pre_latents, pre_audio=pre_audio)
295
+
296
+ body_latents = latents[..., 0]
297
+ hand_latents = latents[..., 1]
298
+
299
+ body, _ = self.g_body.decode(b=body_latents.shape[0], w=body_latents.shape[1],
300
+ latents=body_latents, pre_state=pre_pose['b'])
301
+ hand, _ = self.g_hand.decode(b=hand_latents.shape[0], w=hand_latents.shape[1],
302
+ latents=hand_latents, pre_state=pre_pose['h'])
303
+
304
+ return latents, audio, body, hand
305
+
306
+ def generate(self, aud, id, frame_num=0):
307
+
308
+ self.generator.eval()
309
+ self.g_body.eval()
310
+ self.g_hand.eval()
311
+ aud_feat = aud.permute(0, 2, 1)
312
+ if self.audio:
313
+ self.audioencoder.eval()
314
+ audio = self.audioencoder(aud_feat.transpose(1, 2), frame_num=frame_num).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
315
+ latents = self.generator.generate(id, shape=[audio.shape[2], 2], batch_size=aud.shape[0], aud_feat=audio)
316
+ else:
317
+ latents = self.generator.generate(id, shape=[aud_feat.shape[1] // 4, 2], batch_size=aud.shape[0])
318
+
319
+ body_latents = latents[..., 0]
320
+ hand_latents = latents[..., 1]
321
+
322
+ body = self.g_body.decode(b=body_latents.shape[0], w=body_latents.shape[1], latents=body_latents)
323
+ hand = self.g_hand.decode(b=hand_latents.shape[0], w=hand_latents.shape[1], latents=hand_latents)
324
+
325
+ pred_poses = torch.cat([body, hand], dim=1).transpose(1, 2)
326
+ return pred_poses
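The training objective in __call__ above is a conditional categorical prior over frozen VQ code indices: the pretrained body and hand VQ-VAEs produce discrete latents, and the PixelCNN (plus audio encoder) is fit with cross-entropy over those indices. The toy sketch below reproduces just that loss, shape for shape, with stand-in modules; the sizes and the 1x1-conv "prior" are illustrative assumptions, not the repository's classes.

import torch
import torch.nn as nn
import torch.nn.functional as F

num_codes, T = 2048, 22                              # codebook size, latent sequence length
latents = torch.randint(num_codes, (4, T, 2))        # frozen (body, hand) code indices
audio = torch.randn(4, 256, T, 2)                    # conditioning features

prior = nn.Conv2d(256, num_codes, kernel_size=1)     # toy stand-in for the PixelCNN prior
logits = prior(audio).permute(0, 2, 3, 1)            # (B, T, 2, num_codes)

loss = F.cross_entropy(logits.reshape(-1, num_codes), latents.reshape(-1))
loss.backward()
print(float(loss))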
nets/smplx_body_vq.py ADDED
@@ -0,0 +1,302 @@
1
+ import os
2
+ import sys
3
+
4
+ from torch.optim.lr_scheduler import StepLR
5
+
6
+ sys.path.append(os.getcwd())
7
+
8
+ from nets.layers import *
9
+ from nets.base import TrainWrapperBaseClass
10
+ from nets.spg.s2glayers import Generator as G_S2G, Discriminator as D_S2G
11
+ from nets.spg.vqvae_1d import VQVAE as s2g_body
12
+ from nets.utils import parse_audio, denormalize
13
+ from data_utils import get_mfcc, get_melspec, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
14
+ import numpy as np
15
+ import torch.optim as optim
16
+ import torch.nn.functional as F
17
+ from sklearn.preprocessing import normalize
18
+
19
+ from data_utils.lower_body import c_index, c_index_3d, c_index_6d
20
+
21
+
22
+ class TrainWrapper(TrainWrapperBaseClass):
23
+ '''
24
+ a wrapper receiving a batch from data_utils and calculating the loss
25
+ '''
26
+
27
+ def __init__(self, args, config):
28
+ self.args = args
29
+ self.config = config
30
+ self.device = torch.device(self.args.gpu)
31
+ self.global_step = 0
32
+
33
+ self.convert_to_6d = self.config.Data.pose.convert_to_6d
34
+ self.expression = self.config.Data.pose.expression
35
+ self.epoch = 0
36
+ self.init_params()
37
+ self.num_classes = 4
38
+ self.composition = self.config.Model.composition
39
+ if self.composition:
40
+ self.g_body = s2g_body(self.each_dim[1], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
41
+ num_residual_layers=2, num_residual_hiddens=512).to(self.device)
42
+ self.g_hand = s2g_body(self.each_dim[2], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
43
+ num_residual_layers=2, num_residual_hiddens=512).to(self.device)
44
+ else:
45
+ self.g = s2g_body(self.each_dim[1] + self.each_dim[2], embedding_dim=64, num_embeddings=config.Model.code_num,
46
+ num_hiddens=1024, num_residual_layers=2, num_residual_hiddens=512).to(self.device)
47
+
48
+ self.discriminator = None
49
+
50
+ if self.convert_to_6d:
51
+ self.c_index = c_index_6d
52
+ else:
53
+ self.c_index = c_index_3d
54
+
55
+ super().__init__(args, config)
56
+
57
+ def init_optimizer(self):
58
+ print('using Adam')
59
+ if self.composition:
60
+ self.g_body_optimizer = optim.Adam(
61
+ self.g_body.parameters(),
62
+ lr=self.config.Train.learning_rate.generator_learning_rate,
63
+ betas=[0.9, 0.999]
64
+ )
65
+ self.g_hand_optimizer = optim.Adam(
66
+ self.g_hand.parameters(),
67
+ lr=self.config.Train.learning_rate.generator_learning_rate,
68
+ betas=[0.9, 0.999]
69
+ )
70
+ else:
71
+ self.g_optimizer = optim.Adam(
72
+ self.g.parameters(),
73
+ lr=self.config.Train.learning_rate.generator_learning_rate,
74
+ betas=[0.9, 0.999]
75
+ )
76
+
77
+ def state_dict(self):
78
+ if self.composition:
79
+ model_state = {
80
+ 'g_body': self.g_body.state_dict(),
81
+ 'g_body_optim': self.g_body_optimizer.state_dict(),
82
+ 'g_hand': self.g_hand.state_dict(),
83
+ 'g_hand_optim': self.g_hand_optimizer.state_dict(),
84
+ 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
85
+ 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
86
+ }
87
+ else:
88
+ model_state = {
89
+ 'g': self.g.state_dict(),
90
+ 'g_optim': self.g_optimizer.state_dict(),
91
+ 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
92
+ 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
93
+ }
94
+ return model_state
95
+
96
+ def init_params(self):
97
+ if self.config.Data.pose.convert_to_6d:
98
+ scale = 2
99
+ else:
100
+ scale = 1
101
+
102
+ global_orient = round(0 * scale)
103
+ leye_pose = reye_pose = round(0 * scale)
104
+ jaw_pose = round(0 * scale)
105
+ body_pose = round((63 - 24) * scale)
106
+ left_hand_pose = right_hand_pose = round(45 * scale)
107
+ if self.expression:
108
+ expression = 100
109
+ else:
110
+ expression = 0
111
+
112
+ b_j = 0
113
+ jaw_dim = jaw_pose
114
+ b_e = b_j + jaw_dim
115
+ eye_dim = leye_pose + reye_pose
116
+ b_b = b_e + eye_dim
117
+ body_dim = global_orient + body_pose
118
+ b_h = b_b + body_dim
119
+ hand_dim = left_hand_pose + right_hand_pose
120
+ b_f = b_h + hand_dim
121
+ face_dim = expression
122
+
123
+ self.dim_list = [b_j, b_e, b_b, b_h, b_f]
124
+ self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
125
+ self.pose = int(self.full_dim / round(3 * scale))
126
+ self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
127
+
128
+ def __call__(self, bat):
129
+ # assert (not self.args.infer), "infer mode"
130
+ self.global_step += 1
131
+
132
+ total_loss = None
133
+ loss_dict = {}
134
+
135
+ aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
136
+
137
+ # id = bat['speaker'].to(self.device) - 20
138
+ # id = F.one_hot(id, self.num_classes)
139
+
140
+ poses = poses[:, self.c_index, :]
141
+ gt_poses = poses.permute(0, 2, 1)
142
+ b_poses = gt_poses[..., :self.each_dim[1]]
143
+ h_poses = gt_poses[..., self.each_dim[1]:]
144
+
145
+ if self.composition:
146
+ loss = 0
147
+ loss_dict, loss = self.vq_train(b_poses[:, :], 'b', self.g_body, loss_dict, loss)
148
+ loss_dict, loss = self.vq_train(h_poses[:, :], 'h', self.g_hand, loss_dict, loss)
149
+ else:
150
+ loss = 0
151
+ loss_dict, loss = self.vq_train(gt_poses[:, :], 'g', self.g, loss_dict, loss)
152
+
153
+ return total_loss, loss_dict
154
+
155
+ def vq_train(self, gt, name, model, dict, total_loss, pre=None):
156
+ e_q_loss, x_recon = model(gt_poses=gt, pre_state=pre)
157
+ loss, loss_dict = self.get_loss(pred_poses=x_recon, gt_poses=gt, e_q_loss=e_q_loss, pre=pre)
158
+ # total_loss = total_loss + loss
159
+
160
+ if name == 'b':
161
+ optimizer_name = 'g_body_optimizer'
162
+ elif name == 'h':
163
+ optimizer_name = 'g_hand_optimizer'
164
+ elif name == 'g':
165
+ optimizer_name = 'g_optimizer'
166
+ else:
167
+ raise ValueError("model's name must be 'b', 'h' or 'g'")
168
+ optimizer = getattr(self, optimizer_name)
169
+ optimizer.zero_grad()
170
+ loss.backward()
171
+ optimizer.step()
172
+
173
+ for key in list(loss_dict.keys()):
174
+ dict[name + key] = loss_dict.get(key, 0).item()
175
+ return dict, total_loss
176
+
177
+ def get_loss(self,
178
+ pred_poses,
179
+ gt_poses,
180
+ e_q_loss,
181
+ pre=None
182
+ ):
183
+ loss_dict = {}
184
+
185
+
186
+ rec_loss = torch.mean(torch.abs(pred_poses - gt_poses))
187
+ v_pr = pred_poses[:, 1:] - pred_poses[:, :-1]
188
+ v_gt = gt_poses[:, 1:] - gt_poses[:, :-1]
189
+ velocity_loss = torch.mean(torch.abs(v_pr - v_gt))
190
+
191
+ if pre is None:
192
+ f0_vel = 0
193
+ else:
194
+ v0_pr = pred_poses[:, 0] - pre[:, -1]
195
+ v0_gt = gt_poses[:, 0] - pre[:, -1]
196
+ f0_vel = torch.mean(torch.abs(v0_pr - v0_gt))
197
+
198
+ gen_loss = rec_loss + e_q_loss + velocity_loss + f0_vel
199
+
200
+ loss_dict['rec_loss'] = rec_loss
201
+ loss_dict['velocity_loss'] = velocity_loss
202
+ # loss_dict['e_q_loss'] = e_q_loss
203
+ if pre is not None:
204
+ loss_dict['f0_vel'] = f0_vel
205
+
206
+ return gen_loss, loss_dict
207
+
208
+ def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, exp=None, var=None, w_pre=False, continuity=False,
209
+ id=None, fps=15, sr=22000, smooth=False, **kwargs):
210
+ '''
211
+ initial_pose: (B, C, T), normalized
212
+ (aud_fn, txgfile) -> generated motion (B, T, C)
213
+ '''
214
+ output = []
215
+
216
+ assert self.args.infer, "train mode"
217
+ if self.composition:
218
+ self.g_body.eval()
219
+ self.g_hand.eval()
220
+ else:
221
+ self.g.eval()
222
+
223
+ if self.config.Data.pose.normalization:
224
+ assert norm_stats is not None
225
+ data_mean = norm_stats[0]
226
+ data_std = norm_stats[1]
227
+
228
+ # assert initial_pose.shape[-1] == pre_length
229
+ if initial_pose is not None:
230
+ gt = initial_pose[:, :, :].to(self.device).to(torch.float32)
231
+ pre_poses = initial_pose[:, :, :15].permute(0, 2, 1).to(self.device).to(torch.float32)
232
+ poses = initial_pose.permute(0, 2, 1).to(self.device).to(torch.float32)
233
+ B = pre_poses.shape[0]
234
+ else:
235
+ gt = None
236
+ pre_poses = None
237
+ B = 1
238
+
239
+ if type(aud_fn) == torch.Tensor:
240
+ aud_feat = torch.tensor(aud_fn, dtype=torch.float32).to(self.device)
241
+ num_poses_to_generate = aud_feat.shape[-1]
242
+ else:
243
+ aud_feat = get_mfcc_ta(aud_fn, sr=sr, fps=fps, smlpx=True, type='mfcc').transpose(1, 0)
244
+ aud_feat = aud_feat[:, :]
245
+ num_poses_to_generate = aud_feat.shape[-1]
246
+ aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
247
+ aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)
248
+
249
+ # pre_poses = torch.randn(pre_poses.shape).to(self.device).to(torch.float32)
250
+ if id is None:
251
+ id = F.one_hot(torch.tensor([[0]]), self.num_classes).to(self.device)
252
+
253
+ with torch.no_grad():
254
+ aud_feat = aud_feat.permute(0, 2, 1)
255
+ gt_poses = gt[:, self.c_index].permute(0, 2, 1)
256
+ if self.composition:
257
+ if continuity:
258
+ pred_poses_body = []
259
+ pred_poses_hand = []
260
+ pre_b = None
261
+ pre_h = None
262
+ for i in range(5):
263
+ _, pred_body = self.g_body(gt_poses=gt_poses[:, i*60:(i+1)*60, :self.each_dim[1]], pre_state=pre_b)
264
+ pre_b = pred_body[..., -1:].transpose(1,2)
265
+ pred_poses_body.append(pred_body)
266
+ _, pred_hand = self.g_hand(gt_poses=gt_poses[:, i*60:(i+1)*60, self.each_dim[1]:], pre_state=pre_h)
267
+ pre_h = pred_hand[..., -1:].transpose(1,2)
268
+ pred_poses_hand.append(pred_hand)
269
+
270
+ pred_poses_body = torch.cat(pred_poses_body, dim=2)
271
+ pred_poses_hand = torch.cat(pred_poses_hand, dim=2)
272
+ else:
273
+ _, pred_poses_body = self.g_body(gt_poses=gt_poses[..., :self.each_dim[1]], id=id)
274
+ _, pred_poses_hand = self.g_hand(gt_poses=gt_poses[..., self.each_dim[1]:], id=id)
275
+ pred_poses = torch.cat([pred_poses_body, pred_poses_hand], dim=1)
276
+ else:
277
+ _, pred_poses = self.g(gt_poses=gt_poses, id=id)
278
+ pred_poses = pred_poses.transpose(1, 2).cpu().numpy()
279
+ output = pred_poses
280
+
281
+ if self.config.Data.pose.normalization:
282
+ output = denormalize(output, data_mean, data_std)
283
+
284
+ if smooth:
285
+ lamda = 0.8
286
+ smooth_f = 10
287
+ frame = 149
288
+ for i in range(smooth_f):
289
+ f = frame + i
290
+ l = lamda * (i + 1) / smooth_f
291
+ output[0, f] = (1 - l) * output[0, f - 1] + l * output[0, f]
292
+
293
+ output = np.concatenate(output, axis=1)
294
+
295
+ return output
296
+
297
+ def load_state_dict(self, state_dict):
298
+ if self.composition:
299
+ self.g_body.load_state_dict(state_dict['g_body'])
300
+ self.g_hand.load_state_dict(state_dict['g_hand'])
301
+ else:
302
+ self.g.load_state_dict(state_dict['g'])
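For orientation, the body/hand split that both s2g_body_vq and s2g_body_pixel rely on (each_dim[1] and each_dim[2]) comes from the arithmetic in init_params. Restated as plain numbers for the 3D case (scale=1), with expressions enabled:

scale = 1
jaw_dim = 0 * scale
eye_dim = 0 * scale + 0 * scale
body_dim = 0 * scale + (63 - 24) * scale      # global_orient + body_pose = 39
hand_dim = 45 * scale + 45 * scale            # left_hand + right_hand = 90
face_dim = 100                                # expression coefficients

full_dim = jaw_dim + eye_dim + body_dim + hand_dim
each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
print(full_dim, each_dim)                     # 129 [0, 39, 90, 100]

In the 6D rotation case (scale=2) the same bookkeeping gives full_dim = 258.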
nets/smplx_face.py ADDED
@@ -0,0 +1,238 @@
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append(os.getcwd())
5
+
6
+ from nets.layers import *
7
+ from nets.base import TrainWrapperBaseClass
8
+ # from nets.spg.faceformer import Faceformer
9
+ from nets.spg.s2g_face import Generator as s2g_face
10
+ from losses import KeypointLoss
11
+ from nets.utils import denormalize
12
+ from data_utils import get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
13
+ import numpy as np
14
+ import torch.optim as optim
15
+ import torch.nn.functional as F
16
+ from sklearn.preprocessing import normalize
17
+ import smplx
18
+
19
+
20
+ class TrainWrapper(TrainWrapperBaseClass):
21
+ '''
22
+ a wrapper receiving a batch from data_utils and calculating the loss
23
+ '''
24
+
25
+ def __init__(self, args, config):
26
+ self.args = args
27
+ self.config = config
28
+ self.device = torch.device(self.args.gpu)
29
+ self.global_step = 0
30
+
31
+ self.convert_to_6d = self.config.Data.pose.convert_to_6d
32
+ self.expression = self.config.Data.pose.expression
33
+ self.epoch = 0
34
+ self.init_params()
35
+ self.num_classes = 4
36
+
37
+ self.generator = s2g_face(
38
+ n_poses=self.config.Data.pose.generate_length,
39
+ each_dim=self.each_dim,
40
+ dim_list=self.dim_list,
41
+ training=not self.args.infer,
42
+ device=self.device,
43
+ identity=not self.convert_to_6d,
44
+ num_classes=self.num_classes,
45
+ ).to(self.device)
46
+
47
+ # self.generator = Faceformer().to(self.device)
48
+
49
+ self.discriminator = None
50
+ self.am = None
51
+
52
+ self.MSELoss = KeypointLoss().to(self.device)
53
+ super().__init__(args, config)
54
+
55
+ def init_optimizer(self):
56
+ self.generator_optimizer = optim.SGD(
57
+ filter(lambda p: p.requires_grad,self.generator.parameters()),
58
+ lr=0.001,
59
+ momentum=0.9,
60
+ nesterov=False,
61
+ )
62
+
63
+ def init_params(self):
64
+ if self.convert_to_6d:
65
+ scale = 2
66
+ else:
67
+ scale = 1
68
+
69
+ global_orient = round(3 * scale)
70
+ leye_pose = reye_pose = round(3 * scale)
71
+ jaw_pose = round(3 * scale)
72
+ body_pose = round(63 * scale)
73
+ left_hand_pose = right_hand_pose = round(45 * scale)
74
+ if self.expression:
75
+ expression = 100
76
+ else:
77
+ expression = 0
78
+
79
+ b_j = 0
80
+ jaw_dim = jaw_pose
81
+ b_e = b_j + jaw_dim
82
+ eye_dim = leye_pose + reye_pose
83
+ b_b = b_e + eye_dim
84
+ body_dim = global_orient + body_pose
85
+ b_h = b_b + body_dim
86
+ hand_dim = left_hand_pose + right_hand_pose
87
+ b_f = b_h + hand_dim
88
+ face_dim = expression
89
+
90
+ self.dim_list = [b_j, b_e, b_b, b_h, b_f]
91
+ self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim + face_dim
92
+ self.pose = int(self.full_dim / round(3 * scale))
93
+ self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
94
+
95
+ def __call__(self, bat):
96
+ # assert (not self.args.infer), "infer mode"
97
+ self.global_step += 1
98
+
99
+ total_loss = None
100
+ loss_dict = {}
101
+
102
+ aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
103
+ id = bat['speaker'].to(self.device) - 20
104
+ id = F.one_hot(id, self.num_classes)
105
+
106
+ aud = aud.permute(0, 2, 1)
107
+ gt_poses = poses.permute(0, 2, 1)
108
+
109
+ if self.expression:
110
+ expression = bat['expression'].to(self.device).to(torch.float32)
111
+ gt_poses = torch.cat([gt_poses, expression.permute(0, 2, 1)], dim=2)
112
+
113
+ pred_poses, _ = self.generator(
114
+ aud,
115
+ gt_poses,
116
+ id,
117
+ )
118
+
119
+ G_loss, G_loss_dict = self.get_loss(
120
+ pred_poses=pred_poses,
121
+ gt_poses=gt_poses,
122
+ pre_poses=None,
123
+ mode='training_G',
124
+ gt_conf=None,
125
+ aud=aud,
126
+ )
127
+
128
+ self.generator_optimizer.zero_grad()
129
+ G_loss.backward()
130
+ grad = torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.config.Train.max_gradient_norm)
131
+ loss_dict['grad'] = grad.item()
132
+ self.generator_optimizer.step()
133
+
134
+ for key in list(G_loss_dict.keys()):
135
+ loss_dict[key] = G_loss_dict.get(key, 0).item()
136
+
137
+ return total_loss, loss_dict
138
+
139
+ def get_loss(self,
140
+ pred_poses,
141
+ gt_poses,
142
+ pre_poses,
143
+ aud,
144
+ mode='training_G',
145
+ gt_conf=None,
146
+ exp=1,
147
+ gt_nzero=None,
148
+ pre_nzero=None,
149
+ ):
150
+ loss_dict = {}
151
+
152
+
153
+ [b_j, b_e, b_b, b_h, b_f] = self.dim_list
154
+
155
+ MSELoss = torch.mean(torch.abs(pred_poses[:, :, :6] - gt_poses[:, :, :6]))
156
+ if self.expression:
157
+ expl = torch.mean((pred_poses[:, :, -100:] - gt_poses[:, :, -100:])**2)
158
+ else:
159
+ expl = 0
160
+
161
+ gen_loss = expl + MSELoss
162
+
163
+ loss_dict['MSELoss'] = MSELoss
164
+ if self.expression:
165
+ loss_dict['exp_loss'] = expl
166
+
167
+ return gen_loss, loss_dict
168
+
169
+ def infer_on_audio(self, aud_fn, id=None, initial_pose=None, norm_stats=None, w_pre=False, frame=None, am=None, am_sr=16000, **kwargs):
170
+ '''
171
+ initial_pose: (B, C, T), normalized
172
+ (aud_fn, txgfile) -> generated motion (B, T, C)
173
+ '''
174
+ output = []
175
+
176
+ # assert self.args.infer, "train mode"
177
+ self.generator.eval()
178
+
179
+ if self.config.Data.pose.normalization:
180
+ assert norm_stats is not None
181
+ data_mean = norm_stats[0]
182
+ data_std = norm_stats[1]
183
+
184
+ # assert initial_pose.shape[-1] == pre_length
185
+ if initial_pose is not None:
186
+ gt = initial_pose[:,:,:].permute(0, 2, 1).to(self.generator.device).to(torch.float32)
187
+ pre_poses = initial_pose[:,:,:15].permute(0, 2, 1).to(self.generator.device).to(torch.float32)
188
+ poses = initial_pose.permute(0, 2, 1).to(self.generator.device).to(torch.float32)
189
+ B = pre_poses.shape[0]
190
+ else:
191
+ gt = None
192
+ pre_poses=None
193
+ B = 1
194
+
195
+ if type(aud_fn) == torch.Tensor:
196
+ aud_feat = torch.tensor(aud_fn, dtype=torch.float32).to(self.generator.device)
197
+ num_poses_to_generate = aud_feat.shape[-1]
198
+ else:
199
+ aud_feat = get_mfcc_ta(aud_fn, am=am, am_sr=am_sr, fps=30, encoder_choice='faceformer')
200
+ aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
201
+ aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.generator.device).transpose(1, 2)
202
+ if frame is None:
203
+ frame = aud_feat.shape[2]*30//16000
204
+ #
205
+ if id is None:
206
+ id = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32, device=self.generator.device)
207
+ else:
208
+ id = F.one_hot(id, self.num_classes).to(self.generator.device)
209
+
210
+ with torch.no_grad():
211
+ pred_poses = self.generator(aud_feat, pre_poses, id, time_steps=frame)[0]
212
+ pred_poses = pred_poses.cpu().numpy()
213
+ output = pred_poses
214
+
215
+ if self.config.Data.pose.normalization:
216
+ output = denormalize(output, data_mean, data_std)
217
+
218
+ return output
219
+
220
+
221
+ def generate(self, wv2_feat, frame):
222
+ '''
223
+ initial_pose: (B, C, T), normalized
224
+ (aud_fn, txgfile) -> generated motion (B, T, C)
225
+ '''
226
+ output = []
227
+
228
+ # assert self.args.infer, "train mode"
229
+ self.generator.eval()
230
+
231
+ B = 1
232
+
233
+ id = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32, device=self.generator.device)
234
+ id = id.repeat(wv2_feat.shape[0], 1)
235
+
236
+ with torch.no_grad():
237
+ pred_poses = self.generator(wv2_feat, None, id, time_steps=frame)[0]
238
+ return pred_poses
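As a worked example (not part of the diff), the bookkeeping in `init_params` above resolves to the following when `convert_to_6d` is disabled (`scale = 1`) and `expression` is enabled:

```python
# scale = 1: axis-angle parameters, plus 100 expression coefficients
jaw_dim  = 3
eye_dim  = 3 + 3            # leye_pose + reye_pose
body_dim = 3 + 63           # global_orient + body_pose
hand_dim = 45 + 45          # left_hand_pose + right_hand_pose
face_dim = 100              # expression

dim_list = [0, 3, 9, 75, 165]                                   # [b_j, b_e, b_b, b_h, b_f]
full_dim = jaw_dim + eye_dim + body_dim + hand_dim + face_dim   # 265
each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]    # [3, 72, 90, 100]
```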
nets/spg/__pycache__/gated_pixelcnn_v2.cpython-37.pyc ADDED
Binary file (5.16 kB).
nets/spg/__pycache__/s2g_face.cpython-37.pyc ADDED
Binary file (6.96 kB).
nets/spg/__pycache__/s2glayers.cpython-37.pyc ADDED
Binary file (11.5 kB).
nets/spg/__pycache__/vqvae_1d.cpython-37.pyc ADDED
Binary file (8.07 kB).
nets/spg/__pycache__/vqvae_modules.cpython-37.pyc ADDED
Binary file (10.6 kB).
nets/spg/__pycache__/wav2vec.cpython-37.pyc ADDED
Binary file (3.89 kB).
nets/spg/gated_pixelcnn_v2.py ADDED
@@ -0,0 +1,179 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def weights_init(m):
7
+ classname = m.__class__.__name__
8
+ if classname.find('Conv') != -1:
9
+ try:
10
+ nn.init.xavier_uniform_(m.weight.data)
11
+ m.bias.data.fill_(0)
12
+ except AttributeError:
13
+ print("Skipping initialization of ", classname)
14
+
15
+
16
+ class GatedActivation(nn.Module):
17
+ def __init__(self):
18
+ super().__init__()
19
+
20
+ def forward(self, x):
21
+ x, y = x.chunk(2, dim=1)
22
+ return torch.tanh(x) * torch.sigmoid(y)
23
+
24
+
25
+ class GatedMaskedConv2d(nn.Module):
26
+ def __init__(self, mask_type, dim, kernel, residual=True, n_classes=10, bh_model=False):
27
+ super().__init__()
28
+ assert kernel % 2 == 1, "Kernel size must be odd"
29
+ self.mask_type = mask_type
30
+ self.residual = residual
31
+ self.bh_model = bh_model
32
+
33
+ self.class_cond_embedding = nn.Embedding(n_classes, 2 * dim)
34
+ self.class_cond_embedding = self.class_cond_embedding.to("cpu")
35
+
36
+ kernel_shp = (kernel // 2 + 1, 3 if self.bh_model else 1) # (ceil(k/2), 3) or (ceil(k/2), 1)
37
+ padding_shp = (kernel // 2, 1 if self.bh_model else 0)
38
+ self.vert_stack = nn.Conv2d(
39
+ dim, dim * 2,
40
+ kernel_shp, 1, padding_shp
41
+ )
42
+
43
+ self.vert_to_horiz = nn.Conv2d(2 * dim, 2 * dim, 1)
44
+
45
+ kernel_shp = (1, 2)
46
+ padding_shp = (0, 1)
47
+ self.horiz_stack = nn.Conv2d(
48
+ dim, dim * 2,
49
+ kernel_shp, 1, padding_shp
50
+ )
51
+
52
+ self.horiz_resid = nn.Conv2d(dim, dim, 1)
53
+
54
+ self.gate = GatedActivation()
55
+
56
+ def make_causal(self):
57
+ self.vert_stack.weight.data[:, :, -1].zero_() # Mask final row
58
+ self.horiz_stack.weight.data[:, :, :, -1].zero_() # Mask final column
59
+
60
+ def forward(self, x_v, x_h, h):
61
+ if self.mask_type == 'A':
62
+ self.make_causal()
63
+
64
+ h = h.to(self.class_cond_embedding.weight.device)
65
+ h = self.class_cond_embedding(h)
66
+
67
+ h_vert = self.vert_stack(x_v)
68
+ h_vert = h_vert[:, :, :x_v.size(-2), :]
69
+ out_v = self.gate(h_vert + h[:, :, None, None])
70
+
71
+ if self.bh_model:
72
+ h_horiz = self.horiz_stack(x_h)
73
+ h_horiz = h_horiz[:, :, :, :x_h.size(-1)]
74
+ v2h = self.vert_to_horiz(h_vert)
75
+
76
+ out = self.gate(v2h + h_horiz + h[:, :, None, None])
77
+ if self.residual:
78
+ out_h = self.horiz_resid(out) + x_h
79
+ else:
80
+ out_h = self.horiz_resid(out)
81
+ else:
82
+ if self.residual:
83
+ out_v = self.horiz_resid(out_v) + x_v
84
+ else:
85
+ out_v = self.horiz_resid(out_v)
86
+ out_h = out_v
87
+
88
+ return out_v, out_h
89
+
90
+
91
+ class GatedPixelCNN(nn.Module):
92
+ def __init__(self, input_dim=256, dim=64, n_layers=15, n_classes=10, audio=False, bh_model=False):
93
+ super().__init__()
94
+ self.dim = dim
95
+ self.audio = audio
96
+ self.bh_model = bh_model
97
+
98
+ if self.audio:
99
+ self.embedding_aud = nn.Conv2d(256, dim, 1, 1, padding=0)
100
+ self.fusion_v = nn.Conv2d(dim * 2, dim, 1, 1, padding=0)
101
+ self.fusion_h = nn.Conv2d(dim * 2, dim, 1, 1, padding=0)
102
+
103
+ # Create embedding layer to embed input
104
+ self.embedding = nn.Embedding(input_dim, dim)
105
+
106
+ # Building the PixelCNN layer by layer
107
+ self.layers = nn.ModuleList()
108
+
109
+ # Initial block with Mask-A convolution
110
+ # Rest with Mask-B convolutions
111
+ for i in range(n_layers):
112
+ mask_type = 'A' if i == 0 else 'B'
113
+ kernel = 7 if i == 0 else 3
114
+ residual = False if i == 0 else True
115
+
116
+ self.layers.append(
117
+ GatedMaskedConv2d(mask_type, dim, kernel, residual, n_classes, bh_model)
118
+ )
119
+
120
+ # Add the output layer
121
+ self.output_conv = nn.Sequential(
122
+ nn.Conv2d(dim, 512, 1),
123
+ nn.ReLU(True),
124
+ nn.Conv2d(512, input_dim, 1)
125
+ )
126
+
127
+ self.apply(weights_init)
128
+
129
+ self.dp = nn.Dropout(0.1)
130
+ self.to("cpu")
131
+
132
+ def forward(self, x, label, aud=None):
133
+ shp = x.size() + (-1,)
134
+ x = self.embedding(x.view(-1)).view(shp) # (B, H, W, C)
135
+ x = x.permute(0, 3, 1, 2) # (B, C, W, W)
136
+
137
+ x_v, x_h = (x, x)
138
+ for i, layer in enumerate(self.layers):
139
+ if i == 1 and self.audio is True:
140
+ aud = self.embedding_aud(aud)
141
+ a = torch.ones(aud.shape[-2]).to(aud.device)
142
+ a = self.dp(a)
143
+ aud = (aud.transpose(-1, -2) * a).transpose(-1, -2)
144
+ x_v = self.fusion_v(torch.cat([x_v, aud], dim=1))
145
+ if self.bh_model:
146
+ x_h = self.fusion_h(torch.cat([x_h, aud], dim=1))
147
+ x_v, x_h = layer(x_v, x_h, label)
148
+
149
+ if self.bh_model:
150
+ return self.output_conv(x_h)
151
+ else:
152
+ return self.output_conv(x_v)
153
+
154
+ def generate(self, label, shape=(8, 8), batch_size=64, aud_feat=None, pre_latents=None, pre_audio=None):
155
+ param = next(self.parameters())
156
+ x = torch.zeros(
157
+ (batch_size, *shape),
158
+ dtype=torch.int64, device=param.device
159
+ )
160
+ if pre_latents is not None:
161
+ x = torch.cat([pre_latents, x], dim=1)
162
+ aud_feat = torch.cat([pre_audio, aud_feat], dim=2)
163
+ h0 = pre_latents.shape[1]
164
+ h = h0 + shape[0]
165
+ else:
166
+ h0 = 0
167
+ h = shape[0]
168
+
169
+ for i in range(h0, h):
170
+ for j in range(shape[1]):
171
+ if self.audio:
172
+ logits = self.forward(x, label, aud_feat)
173
+ else:
174
+ logits = self.forward(x, label)
175
+ probs = F.softmax(logits[:, :, i, j], -1)
176
+ x.data[:, i, j].copy_(
177
+ probs.multinomial(1).squeeze().data
178
+ )
179
+ return x[:, h0:h]
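A hedged usage sketch for the `GatedPixelCNN` prior above, sampling a small latent grid on CPU. The hyperparameters below (codebook size 512, 4 classes, a 4-layer prior, an 8x8 grid) are illustrative assumptions, not values taken from this repository's configs.

```python
import torch
from nets.spg.gated_pixelcnn_v2 import GatedPixelCNN

prior = GatedPixelCNN(input_dim=512, dim=64, n_layers=4,
                      n_classes=4, audio=False, bh_model=True)
prior.eval()

label = torch.zeros(2, dtype=torch.long)   # one class id per sample
with torch.no_grad():
    codes = prior.generate(label, shape=(8, 8), batch_size=2)
print(codes.shape)                         # torch.Size([2, 8, 8]) of code indices
```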
nets/spg/s2g_face.py ADDED
@@ -0,0 +1,226 @@
1
+ '''
2
+ not exactly the same as the official repo, but the results are good
3
+ '''
4
+ import sys
5
+ import os
6
+
7
+ from transformers import Wav2Vec2Processor
8
+
9
+ from .wav2vec import Wav2Vec2Model
10
+ from torchaudio.sox_effects import apply_effects_tensor
11
+
12
+ sys.path.append(os.getcwd())
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torchaudio as ta
19
+ import math
20
+ from nets.layers import SeqEncoder1D, SeqTranslator1D, ConvNormRelu
21
+
22
+
23
+ """ from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """
24
+
25
+
26
+ def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000):
27
+ """
28
+ :param audio: 1 x T tensor containing a 16kHz audio signal
29
+ :param frame_rate: frame rate for video (we need one audio chunk per video frame)
30
+ :param chunk_size: number of audio samples per chunk
31
+ :return: num_chunks x chunk_size tensor containing sliced audio
32
+ """
33
+ samples_per_frame = 16000 // frame_rate
34
+ padding = (chunk_size - samples_per_frame) // 2
35
+ audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
36
+ anchor_points = list(range(chunk_size//2, audio.shape[-1]-chunk_size//2, samples_per_frame))
37
+ audio = torch.cat([audio[:, i-chunk_size//2:i+chunk_size//2] for i in anchor_points], dim=0)
38
+ return audio
39
+
40
+
41
+ class MeshtalkEncoder(nn.Module):
42
+ def __init__(self, latent_dim: int = 128, model_name: str = 'audio_encoder'):
43
+ """
44
+ :param latent_dim: size of the latent audio embedding
45
+ :param model_name: name of the model, used to load and save the model
46
+ """
47
+ super().__init__()
48
+
49
+ self.melspec = ta.transforms.MelSpectrogram(
50
+ sample_rate=16000, n_fft=2048, win_length=800, hop_length=160, n_mels=80
51
+ )
52
+
53
+ conv_len = 5
54
+ self.convert_dimensions = torch.nn.Conv1d(80, 128, kernel_size=conv_len)
55
+ self.weights_init(self.convert_dimensions)
56
+ self.receptive_field = conv_len
57
+
58
+ convs = []
59
+ for i in range(6):
60
+ dilation = 2 * (i % 3 + 1)
61
+ self.receptive_field += (conv_len - 1) * dilation
62
+ convs += [torch.nn.Conv1d(128, 128, kernel_size=conv_len, dilation=dilation)]
63
+ self.weights_init(convs[-1])
64
+ self.convs = torch.nn.ModuleList(convs)
65
+ self.code = torch.nn.Linear(128, latent_dim)
66
+
67
+ self.apply(lambda x: self.weights_init(x))
68
+
69
+ def weights_init(self, m):
70
+ if isinstance(m, torch.nn.Conv1d):
71
+ torch.nn.init.xavier_uniform_(m.weight)
72
+ try:
73
+ torch.nn.init.constant_(m.bias, .01)
74
+ except:
75
+ pass
76
+
77
+ def forward(self, audio: torch.Tensor):
78
+ """
79
+ :param audio: B x T x 16000 Tensor containing 1 sec of audio centered around the current time frame
80
+ :return: code: B x T x latent_dim Tensor containing a latent audio code/embedding
81
+ """
82
+ B, T = audio.shape[0], audio.shape[1]
83
+ x = self.melspec(audio).squeeze(1)
84
+ x = torch.log(x.clamp(min=1e-10, max=None))
85
+ if T == 1:
86
+ x = x.unsqueeze(1)
87
+
88
+ # Convert to the right dimensionality
89
+ x = x.view(-1, x.shape[2], x.shape[3])
90
+ x = F.leaky_relu(self.convert_dimensions(x), .2)
91
+
92
+ # Process stacks
93
+ for conv in self.convs:
94
+ x_ = F.leaky_relu(conv(x), .2)
95
+ if self.training:
96
+ x_ = F.dropout(x_, .2)
97
+ l = (x.shape[2] - x_.shape[2]) // 2
98
+ x = (x[:, :, l:-l] + x_) / 2
99
+
100
+ x = torch.mean(x, dim=-1)
101
+ x = x.view(B, T, x.shape[-1])
102
+ x = self.code(x)
103
+
104
+ return {"code": x}
105
+
106
+
107
+ class AudioEncoder(nn.Module):
108
+ def __init__(self, in_dim, out_dim, identity=False, num_classes=0):
109
+ super().__init__()
110
+ self.identity = identity
111
+ if self.identity:
112
+ in_dim = in_dim + 64
113
+ self.id_mlp = nn.Conv1d(num_classes, 64, 1, 1)
114
+ self.first_net = SeqTranslator1D(in_dim, out_dim,
115
+ min_layers_num=3,
116
+ residual=True,
117
+ norm='ln'
118
+ )
119
+ self.grus = nn.GRU(out_dim, out_dim, 1, batch_first=True)
120
+ self.dropout = nn.Dropout(0.1)
121
+ # self.att = nn.MultiheadAttention(out_dim, 4, dropout=0.1, batch_first=True)
122
+
123
+ def forward(self, spectrogram, pre_state=None, id=None, time_steps=None):
124
+
125
+ spectrogram = spectrogram
126
+ spectrogram = self.dropout(spectrogram)
127
+ if self.identity:
128
+ id = id.reshape(id.shape[0], -1, 1).repeat(1, 1, spectrogram.shape[2]).to(torch.float32)
129
+ id = self.id_mlp(id)
130
+ spectrogram = torch.cat([spectrogram, id], dim=1)
131
+ x1 = self.first_net(spectrogram)# .permute(0, 2, 1)
132
+ if time_steps is not None:
133
+ x1 = F.interpolate(x1, size=time_steps, align_corners=False, mode='linear')
134
+ # x1, _ = self.att(x1, x1, x1)
135
+ # x1, hidden_state = self.grus(x1)
136
+ # x1 = x1.permute(0, 2, 1)
137
+ hidden_state=None
138
+
139
+ return x1, hidden_state
140
+
141
+
142
+ class Generator(nn.Module):
143
+ def __init__(self,
144
+ n_poses,
145
+ each_dim: list,
146
+ dim_list: list,
147
+ training=False,
148
+ device=None,
149
+ identity=True,
150
+ num_classes=0,
151
+ ):
152
+ super().__init__()
153
+
154
+ self.training = training
155
+ self.device = device
156
+ self.gen_length = n_poses
157
+ self.identity = identity
158
+
159
+ norm = 'ln'
160
+ in_dim = 256
161
+ out_dim = 256
162
+
163
+ self.encoder_choice = 'faceformer'
164
+
165
+ if self.encoder_choice == 'meshtalk':
166
+ self.audio_encoder = MeshtalkEncoder(latent_dim=in_dim)
167
+ elif self.encoder_choice == 'faceformer':
168
+ # wav2vec 2.0 weights initialization
169
+ self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") # "vitouphy/wav2vec2-xls-r-300m-phoneme""facebook/wav2vec2-base-960h"
170
+ self.audio_encoder.feature_extractor._freeze_parameters()
171
+ self.audio_feature_map = nn.Linear(768, in_dim)
172
+ else:
173
+ self.audio_encoder = AudioEncoder(in_dim=64, out_dim=out_dim)
174
+
175
+ self.audio_middle = AudioEncoder(in_dim, out_dim, identity, num_classes)
176
+
177
+ self.dim_list = dim_list
178
+
179
+ self.decoder = nn.ModuleList()
180
+ self.final_out = nn.ModuleList()
181
+
182
+ self.decoder.append(nn.Sequential(
183
+ ConvNormRelu(out_dim, 64, norm=norm),
184
+ ConvNormRelu(64, 64, norm=norm),
185
+ ConvNormRelu(64, 64, norm=norm),
186
+ ))
187
+ self.final_out.append(nn.Conv1d(64, each_dim[0], 1, 1))
188
+
189
+ self.decoder.append(nn.Sequential(
190
+ ConvNormRelu(out_dim, out_dim, norm=norm),
191
+ ConvNormRelu(out_dim, out_dim, norm=norm),
192
+ ConvNormRelu(out_dim, out_dim, norm=norm),
193
+ ))
194
+ self.final_out.append(nn.Conv1d(out_dim, each_dim[3], 1, 1))
195
+
196
+ def forward(self, in_spec, gt_poses=None, id=None, pre_state=None, time_steps=None):
197
+ if self.training:
198
+ time_steps = gt_poses.shape[1]
199
+
200
+ # vector, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
201
+ if self.encoder_choice == 'meshtalk':
202
+ in_spec = audio_chunking(in_spec.squeeze(-1), frame_rate=30, chunk_size=16000)
203
+ feature = self.audio_encoder(in_spec.unsqueeze(0))["code"].transpose(1, 2)
204
+ elif self.encoder_choice == 'faceformer':
205
+ hidden_states = self.audio_encoder(in_spec.reshape(in_spec.shape[0], -1), frame_num=time_steps).last_hidden_state
206
+ feature = self.audio_feature_map(hidden_states).transpose(1, 2)
207
+ else:
208
+ feature, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
209
+
210
+ # hidden_states = in_spec
211
+
212
+ feature, _ = self.audio_middle(feature, id=id)
213
+
214
+ out = []
215
+
216
+ for i in range(self.decoder.__len__()):
217
+ mid = self.decoder[i](feature)
218
+ mid = self.final_out[i](mid)
219
+ out.append(mid)
220
+
221
+ out = torch.cat(out, dim=1)
222
+ out = out.transpose(1, 2)
223
+
224
+ return out, None
225
+
226
+
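A quick sanity check for `audio_chunking` above: at 30 fps, one second of 16 kHz audio should yield roughly one 16000-sample chunk per video frame (the random waveform below is only for illustration):

```python
import torch
from nets.spg.s2g_face import audio_chunking

wav = torch.randn(1, 16000)                         # 1 s of 16 kHz audio
chunks = audio_chunking(wav, frame_rate=30, chunk_size=16000)
# samples_per_frame = 16000 // 30 = 533 -> 30 overlapping chunks
print(chunks.shape)                                 # torch.Size([30, 16000])
```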
nets/spg/s2glayers.py ADDED
@@ -0,0 +1,522 @@
1
+ '''
2
+ not exactly the same as the official repo, but the results are good
3
+ '''
4
+ import sys
5
+ import os
6
+
7
+ sys.path.append(os.getcwd())
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import math
14
+ from nets.layers import SeqEncoder1D, SeqTranslator1D
15
+
16
+ """ from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """
17
+
18
+
19
+ class Conv2d_tf(nn.Conv2d):
20
+ """
21
+ Conv2d with the padding behavior from TF
22
+ from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
23
+ """
24
+
25
+ def __init__(self, *args, **kwargs):
26
+ super(Conv2d_tf, self).__init__(*args, **kwargs)
27
+ self.padding = kwargs.get("padding", "SAME")
28
+
29
+ def _compute_padding(self, input, dim):
30
+ input_size = input.size(dim + 2)
31
+ filter_size = self.weight.size(dim + 2)
32
+ effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
33
+ out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
34
+ total_padding = max(
35
+ 0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
36
+ )
37
+ additional_padding = int(total_padding % 2 != 0)
38
+
39
+ return additional_padding, total_padding
40
+
41
+ def forward(self, input):
42
+ if self.padding == "VALID":
43
+ return F.conv2d(
44
+ input,
45
+ self.weight,
46
+ self.bias,
47
+ self.stride,
48
+ padding=0,
49
+ dilation=self.dilation,
50
+ groups=self.groups,
51
+ )
52
+ rows_odd, padding_rows = self._compute_padding(input, dim=0)
53
+ cols_odd, padding_cols = self._compute_padding(input, dim=1)
54
+ if rows_odd or cols_odd:
55
+ input = F.pad(input, [0, cols_odd, 0, rows_odd])
56
+
57
+ return F.conv2d(
58
+ input,
59
+ self.weight,
60
+ self.bias,
61
+ self.stride,
62
+ padding=(padding_rows // 2, padding_cols // 2),
63
+ dilation=self.dilation,
64
+ groups=self.groups,
65
+ )
66
+
67
+
68
+ class Conv1d_tf(nn.Conv1d):
69
+ """
70
+ Conv1d with the padding behavior from TF
71
+ modified from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
72
+ """
73
+
74
+ def __init__(self, *args, **kwargs):
75
+ super(Conv1d_tf, self).__init__(*args, **kwargs)
76
+ self.padding = kwargs.get("padding")
77
+
78
+ def _compute_padding(self, input, dim):
79
+ input_size = input.size(dim + 2)
80
+ filter_size = self.weight.size(dim + 2)
81
+ effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
82
+ out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
83
+ total_padding = max(
84
+ 0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
85
+ )
86
+ additional_padding = int(total_padding % 2 != 0)
87
+
88
+ return additional_padding, total_padding
89
+
90
+ def forward(self, input):
91
+ # if self.padding == "valid":
92
+ # return F.conv1d(
93
+ # input,
94
+ # self.weight,
95
+ # self.bias,
96
+ # self.stride,
97
+ # padding=0,
98
+ # dilation=self.dilation,
99
+ # groups=self.groups,
100
+ # )
101
+ rows_odd, padding_rows = self._compute_padding(input, dim=0)
102
+ if rows_odd:
103
+ input = F.pad(input, [0, rows_odd])
104
+
105
+ return F.conv1d(
106
+ input,
107
+ self.weight,
108
+ self.bias,
109
+ self.stride,
110
+ padding=(padding_rows // 2),
111
+ dilation=self.dilation,
112
+ groups=self.groups,
113
+ )
114
+
115
+
116
+ def ConvNormRelu(in_channels, out_channels, type='1d', downsample=False, k=None, s=None, padding='valid', groups=1,
117
+ nonlinear='lrelu', bn='bn'):
118
+ if k is None and s is None:
119
+ if not downsample:
120
+ k = 3
121
+ s = 1
122
+ padding = 'same'
123
+ else:
124
+ k = 4
125
+ s = 2
126
+ padding = 'valid'
127
+
128
+ if type == '1d':
129
+ conv_block = Conv1d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding, groups=groups)
130
+ norm_block = nn.BatchNorm1d(out_channels)
131
+ elif type == '2d':
132
+ conv_block = Conv2d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding, groups=groups)
133
+ norm_block = nn.BatchNorm2d(out_channels)
134
+ else:
135
+ assert False
136
+ if bn != 'bn':
137
+ if bn == 'gn':
138
+ norm_block = nn.GroupNorm(1, out_channels)
139
+ elif bn == 'ln':
140
+ norm_block = nn.LayerNorm(out_channels)
141
+ else:
142
+ norm_block = nn.Identity()
143
+ if nonlinear == 'lrelu':
144
+ nlinear = nn.LeakyReLU(0.2, True)
145
+ elif nonlinear == 'tanh':
146
+ nlinear = nn.Tanh()
147
+ elif nonlinear == 'none':
148
+ nlinear = nn.Identity()
149
+
150
+ return nn.Sequential(
151
+ conv_block,
152
+ norm_block,
153
+ nlinear
154
+ )
155
+
156
+
157
+ class UnetUp(nn.Module):
158
+ def __init__(self, in_ch, out_ch):
159
+ super(UnetUp, self).__init__()
160
+ self.conv = ConvNormRelu(in_ch, out_ch)
161
+
162
+ def forward(self, x1, x2):
163
+ # x1 = torch.repeat_interleave(x1, 2, dim=2)
164
+ # x1 = x1[:, :, :x2.shape[2]]
165
+ x1 = torch.nn.functional.interpolate(x1, size=x2.shape[2], mode='linear')
166
+ x = x1 + x2
167
+ x = self.conv(x)
168
+ return x
169
+
170
+
171
+ class UNet(nn.Module):
172
+ def __init__(self, input_dim, dim):
173
+ super(UNet, self).__init__()
174
+ # dim = 512
175
+ self.down1 = nn.Sequential(
176
+ ConvNormRelu(input_dim, input_dim, '1d', False),
177
+ ConvNormRelu(input_dim, dim, '1d', False),
178
+ ConvNormRelu(dim, dim, '1d', False)
179
+ )
180
+ self.gru = nn.GRU(dim, dim, 1, batch_first=True)
181
+ self.down2 = ConvNormRelu(dim, dim, '1d', True)
182
+ self.down3 = ConvNormRelu(dim, dim, '1d', True)
183
+ self.down4 = ConvNormRelu(dim, dim, '1d', True)
184
+ self.down5 = ConvNormRelu(dim, dim, '1d', True)
185
+ self.down6 = ConvNormRelu(dim, dim, '1d', True)
186
+ self.up1 = UnetUp(dim, dim)
187
+ self.up2 = UnetUp(dim, dim)
188
+ self.up3 = UnetUp(dim, dim)
189
+ self.up4 = UnetUp(dim, dim)
190
+ self.up5 = UnetUp(dim, dim)
191
+
192
+ def forward(self, x1, pre_pose=None, w_pre=False):
193
+ x2_0 = self.down1(x1)
194
+ if w_pre:
195
+ i = 1
196
+ x2_pre = self.gru(x2_0[:,:,0:i].permute(0,2,1), pre_pose[:,:,-1:].permute(2,0,1).contiguous())[0].permute(0,2,1)
197
+ x2 = torch.cat([x2_pre, x2_0[:,:,i:]], dim=-1)
198
+ # x2 = torch.cat([pre_pose, x2_0], dim=2) # [B, 512, 15]
199
+ else:
200
+ # x2 = self.gru(x2_0.transpose(1, 2))[0].transpose(1,2)
201
+ x2 = x2_0
202
+ x3 = self.down2(x2)
203
+ x4 = self.down3(x3)
204
+ x5 = self.down4(x4)
205
+ x6 = self.down5(x5)
206
+ x7 = self.down6(x6)
207
+ x = self.up1(x7, x6)
208
+ x = self.up2(x, x5)
209
+ x = self.up3(x, x4)
210
+ x = self.up4(x, x3)
211
+ x = self.up5(x, x2) # [B, 512, 15]
212
+ return x, x2_0
213
+
214
+
215
+ class AudioEncoder(nn.Module):
216
+ def __init__(self, n_frames, template_length, pose=False, common_dim=512):
217
+ super().__init__()
218
+ self.n_frames = n_frames
219
+ self.pose = pose
220
+ self.step = 0
221
+ self.weight = 0
222
+ if self.pose:
223
+ # self.first_net = nn.Sequential(
224
+ # ConvNormRelu(1, 64, '2d', False),
225
+ # ConvNormRelu(64, 64, '2d', True),
226
+ # ConvNormRelu(64, 128, '2d', False),
227
+ # ConvNormRelu(128, 128, '2d', True),
228
+ # ConvNormRelu(128, 256, '2d', False),
229
+ # ConvNormRelu(256, 256, '2d', True),
230
+ # ConvNormRelu(256, 256, '2d', False),
231
+ # ConvNormRelu(256, 256, '2d', False, padding='VALID')
232
+ # )
233
+ # decoder_layer = nn.TransformerDecoderLayer(d_model=args.feature_dim, nhead=4,
234
+ # dim_feedforward=2 * args.feature_dim, batch_first=True)
235
+ # a = nn.TransformerDecoder
236
+ self.first_net = SeqTranslator1D(256, 256,
237
+ min_layers_num=4,
238
+ residual=True
239
+ )
240
+ self.dropout_0 = nn.Dropout(0.1)
241
+ self.mu_fc = nn.Conv1d(256, 128, 1, 1)
242
+ self.var_fc = nn.Conv1d(256, 128, 1, 1)
243
+ self.trans_motion = SeqTranslator1D(common_dim, common_dim,
244
+ kernel_size=1,
245
+ stride=1,
246
+ min_layers_num=3,
247
+ residual=True
248
+ )
249
+ # self.att = nn.MultiheadAttention(64 + template_length, 4, dropout=0.1)
250
+ self.unet = UNet(128 + template_length, common_dim)
251
+
252
+ else:
253
+ self.first_net = SeqTranslator1D(256, 256,
254
+ min_layers_num=4,
255
+ residual=True
256
+ )
257
+ self.dropout_0 = nn.Dropout(0.1)
258
+ # self.att = nn.MultiheadAttention(256, 4, dropout=0.1)
259
+ self.unet = UNet(256, 256)
260
+ self.dropout_1 = nn.Dropout(0.0)
261
+
262
+ def forward(self, spectrogram, time_steps=None, template=None, pre_pose=None, w_pre=False):
263
+ self.step = self.step + 1
264
+ if self.pose:
265
+ spect = spectrogram.transpose(1, 2)
266
+ if w_pre:
267
+ spect = spect[:, :, :]
268
+
269
+ out = self.first_net(spect)
270
+ out = self.dropout_0(out)
271
+
272
+ mu = self.mu_fc(out)
273
+ var = self.var_fc(out)
274
+ audio = self.__reparam(mu, var)
275
+ # audio = out
276
+
277
+ # template = self.trans_motion(template)
278
+ x1 = torch.cat([audio, template], dim=1)#.permute(2,0,1)
279
+ # x1 = out
280
+ #x1, _ = self.att(x1, x1, x1)
281
+ #x1 = x1.permute(1,2,0)
282
+ x1, x2_0 = self.unet(x1, pre_pose=pre_pose, w_pre=w_pre)
283
+ else:
284
+ spectrogram = spectrogram.transpose(1, 2)
285
+ x1 = self.first_net(spectrogram)#.permute(2,0,1)
286
+ #out, _ = self.att(out, out, out)
287
+ #out = out.permute(1, 2, 0)
288
+ x1 = self.dropout_0(x1)
289
+ x1, x2_0 = self.unet(x1)
290
+ x1 = self.dropout_1(x1)
291
+ mu = None
292
+ var = None
293
+
294
+ return x1, (mu, var), x2_0
295
+
296
+ def __reparam(self, mu, log_var):
297
+ std = torch.exp(0.5 * log_var)
298
+ eps = torch.randn_like(std, device='cuda')
299
+ z = eps * std + mu
300
+ return z
301
+
302
+
303
+ class Generator(nn.Module):
304
+ def __init__(self,
305
+ n_poses,
306
+ pose_dim,
307
+ pose,
308
+ n_pre_poses,
309
+ each_dim: list,
310
+ dim_list: list,
311
+ use_template=False,
312
+ template_length=0,
313
+ training=False,
314
+ device=None,
315
+ separate=False,
316
+ expression=False
317
+ ):
318
+ super().__init__()
319
+
320
+ self.use_template = use_template
321
+ self.template_length = template_length
322
+ self.training = training
323
+ self.device = device
324
+ self.separate = separate
325
+ self.pose = pose
326
+ self.decoderf = True
327
+ self.expression = expression
328
+
329
+ common_dim = 256
330
+
331
+ if self.use_template:
332
+ assert template_length > 0
333
+ # self.KLLoss = KLLoss(kl_tolerance=self.config.Train.weights.kl_tolerance).to(self.device)
334
+ # self.pose_encoder = SeqEncoder1D(
335
+ # C_in=pose_dim,
336
+ # C_out=512,
337
+ # T_in=n_poses,
338
+ # min_layer_nums=6
339
+ #
340
+ # )
341
+ self.pose_encoder = SeqTranslator1D(pose_dim - 50, common_dim,
342
+ # kernel_size=1,
343
+ # stride=1,
344
+ min_layers_num=3,
345
+ residual=True
346
+ )
347
+ self.mu_fc = nn.Conv1d(common_dim, template_length, kernel_size=1, stride=1)
348
+ self.var_fc = nn.Conv1d(common_dim, template_length, kernel_size=1, stride=1)
349
+
350
+ else:
351
+ self.template_length = 0
352
+
353
+ self.gen_length = n_poses
354
+
355
+ self.audio_encoder = AudioEncoder(n_poses, template_length, True, common_dim)
356
+ self.speech_encoder = AudioEncoder(n_poses, template_length, False)
357
+
358
+ # self.pre_pose_encoder = SeqEncoder1D(
359
+ # C_in=pose_dim,
360
+ # C_out=128,
361
+ # T_in=15,
362
+ # min_layer_nums=3
363
+ #
364
+ # )
365
+ # self.pmu_fc = nn.Linear(128, 64)
366
+ # self.pvar_fc = nn.Linear(128, 64)
367
+
368
+ self.pre_pose_encoder = SeqTranslator1D(pose_dim-50, common_dim,
369
+ min_layers_num=5,
370
+ residual=True
371
+ )
372
+ self.decoder_in = 256 + 64
373
+ self.dim_list = dim_list
374
+
375
+ if self.separate:
376
+ self.decoder = nn.ModuleList()
377
+ self.final_out = nn.ModuleList()
378
+
379
+ self.decoder.append(nn.Sequential(
380
+ ConvNormRelu(256, 64),
381
+ ConvNormRelu(64, 64),
382
+ ConvNormRelu(64, 64),
383
+ ))
384
+ self.final_out.append(nn.Conv1d(64, each_dim[0], 1, 1))
385
+
386
+ self.decoder.append(nn.Sequential(
387
+ ConvNormRelu(common_dim, common_dim),
388
+ ConvNormRelu(common_dim, common_dim),
389
+ ConvNormRelu(common_dim, common_dim),
390
+ ))
391
+ self.final_out.append(nn.Conv1d(common_dim, each_dim[1], 1, 1))
392
+
393
+ self.decoder.append(nn.Sequential(
394
+ ConvNormRelu(common_dim, common_dim),
395
+ ConvNormRelu(common_dim, common_dim),
396
+ ConvNormRelu(common_dim, common_dim),
397
+ ))
398
+ self.final_out.append(nn.Conv1d(common_dim, each_dim[2], 1, 1))
399
+
400
+ if self.expression:
401
+ self.decoder.append(nn.Sequential(
402
+ ConvNormRelu(256, 256),
403
+ ConvNormRelu(256, 256),
404
+ ConvNormRelu(256, 256),
405
+ ))
406
+ self.final_out.append(nn.Conv1d(256, each_dim[3], 1, 1))
407
+ else:
408
+ self.decoder = nn.Sequential(
409
+ ConvNormRelu(self.decoder_in, 512),
410
+ ConvNormRelu(512, 512),
411
+ ConvNormRelu(512, 512),
412
+ ConvNormRelu(512, 512),
413
+ ConvNormRelu(512, 512),
414
+ ConvNormRelu(512, 512),
415
+ )
416
+ self.final_out = nn.Conv1d(512, pose_dim, 1, 1)
417
+
418
+ def __reparam(self, mu, log_var):
419
+ std = torch.exp(0.5 * log_var)
420
+ eps = torch.randn_like(std, device=self.device)
421
+ z = eps * std + mu
422
+ return z
423
+
424
+ def forward(self, in_spec, pre_poses, gt_poses, template=None, time_steps=None, w_pre=False, norm=True):
425
+ if time_steps is not None:
426
+ self.gen_length = time_steps
427
+
428
+ if self.use_template:
429
+ if self.training:
430
+ if w_pre:
431
+ in_spec = in_spec[:, 15:, :]
432
+ pre_pose = self.pre_pose_encoder(gt_poses[:, 14:15, :-50].permute(0, 2, 1))
433
+ pose_enc = self.pose_encoder(gt_poses[:, 15:, :-50].permute(0, 2, 1))
434
+ mu = self.mu_fc(pose_enc)
435
+ var = self.var_fc(pose_enc)
436
+ template = self.__reparam(mu, var)
437
+ else:
438
+ pre_pose = None
439
+ pose_enc = self.pose_encoder(gt_poses[:, :, :-50].permute(0, 2, 1))
440
+ mu = self.mu_fc(pose_enc)
441
+ var = self.var_fc(pose_enc)
442
+ template = self.__reparam(mu, var)
443
+ elif pre_poses is not None:
444
+ if w_pre:
445
+ pre_pose = pre_poses[:, -1:, :-50]
446
+ if norm:
447
+ pre_pose = pre_pose.reshape(1, -1, 55, 5)
448
+ pre_pose = torch.cat([F.normalize(pre_pose[..., :3], dim=-1),
449
+ F.normalize(pre_pose[..., 3:5], dim=-1)],
450
+ dim=-1).reshape(1, -1, 275)
451
+ pre_pose = self.pre_pose_encoder(pre_pose.permute(0, 2, 1))
452
+ template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length ]).to(
453
+ in_spec.device)
454
+ else:
455
+ pre_pose = None
456
+ template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length]).to(in_spec.device)
457
+ elif gt_poses is not None:
458
+ template = self.pre_pose_encoder(gt_poses[:, :, :-50].permute(0, 2, 1))
459
+ elif template is None:
460
+ pre_pose = None
461
+ template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length]).to(in_spec.device)
462
+ else:
463
+ template = None
464
+ mu = None
465
+ var = None
466
+
467
+ a_t_f, (mu2, var2), x2_0 = self.audio_encoder(in_spec, time_steps=time_steps, template=template, pre_pose=pre_pose, w_pre=w_pre)
468
+ s_f, _, _ = self.speech_encoder(in_spec, time_steps=time_steps)
469
+
470
+ out = []
471
+
472
+ if self.separate:
473
+ for i in range(self.decoder.__len__()):
474
+ if i == 0 or i == 3:
475
+ mid = self.decoder[i](s_f)
476
+ else:
477
+ mid = self.decoder[i](a_t_f)
478
+ mid = self.final_out[i](mid)
479
+ out.append(mid)
480
+ out = torch.cat(out, dim=1)
481
+
482
+ else:
483
+ out = self.decoder(a_t_f)
484
+ out = self.final_out(out)
485
+
486
+ out = out.transpose(1, 2)
487
+
488
+ if self.training:
489
+ if w_pre:
490
+ return out, template, mu, var, (mu2, var2, x2_0, pre_pose)
491
+ else:
492
+ return out, template, mu, var, (mu2, var2, None, None)
493
+ else:
494
+ return out
495
+
496
+
497
+ class Discriminator(nn.Module):
498
+ def __init__(self, pose_dim, pose):
499
+ super().__init__()
500
+ self.net = nn.Sequential(
501
+ Conv1d_tf(pose_dim, 64, kernel_size=4, stride=2, padding='SAME'),
502
+ nn.LeakyReLU(0.2, True),
503
+ ConvNormRelu(64, 128, '1d', True),
504
+ ConvNormRelu(128, 256, '1d', k=4, s=1),
505
+ Conv1d_tf(256, 1, kernel_size=4, stride=1, padding='SAME'),
506
+ )
507
+
508
+ def forward(self, x):
509
+ x = x.transpose(1, 2)
510
+
511
+ out = self.net(x)
512
+ return out
513
+
514
+
515
+ def main():
516
+ d = Discriminator(275, 55)
517
+ x = torch.randn([8, 60, 275])
518
+ result = d(x)
519
+
520
+
521
+ if __name__ == "__main__":
522
+ main()
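Both `__reparam` helpers above implement the standard VAE reparameterization trick, z = mu + eps * exp(0.5 * log_var). Below is a device-agnostic sketch (the `reparam` name and the shapes are arbitrary; `randn_like` keeps the noise on the inputs' device instead of hard-coding 'cuda'):

```python
import torch

def reparam(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    # Sample z ~ N(mu, sigma^2) while keeping gradients w.r.t. mu and log_var.
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + eps * std

mu = torch.zeros(2, 64, 60)
log_var = torch.zeros(2, 64, 60)
z = reparam(mu, log_var)        # shape (2, 64, 60)
```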
nets/spg/vqvae_1d.py ADDED
@@ -0,0 +1,235 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from .wav2vec import Wav2Vec2Model
7
+ from .vqvae_modules import VectorQuantizerEMA, ConvNormRelu, Res_CNR_Stack
8
+
9
+
10
+
11
+ class AudioEncoder(nn.Module):
12
+ def __init__(self, in_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
13
+ super(AudioEncoder, self).__init__()
14
+ self._num_hiddens = num_hiddens
15
+ self._num_residual_layers = num_residual_layers
16
+ self._num_residual_hiddens = num_residual_hiddens
17
+
18
+ self.project = ConvNormRelu(in_dim, self._num_hiddens // 4, leaky=True)
19
+
20
+ self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True)
21
+ self._down_1 = ConvNormRelu(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, residual=True,
22
+ sample='down')
23
+ self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True)
24
+ self._down_2 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens, leaky=True, residual=True, sample='down')
25
+ self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
26
+
27
+ def forward(self, x, frame_num=0):
28
+ h = self.project(x)
29
+ h = self._enc_1(h)
30
+ h = self._down_1(h)
31
+ h = self._enc_2(h)
32
+ h = self._down_2(h)
33
+ h = self._enc_3(h)
34
+ return h
35
+
36
+
37
+ class Wav2VecEncoder(nn.Module):
38
+ def __init__(self, num_hiddens, num_residual_layers):
39
+ super(Wav2VecEncoder, self).__init__()
40
+ self._num_hiddens = num_hiddens
41
+ self._num_residual_layers = num_residual_layers
42
+
43
+ self.audio_encoder = Wav2Vec2Model.from_pretrained(
44
+ "facebook/wav2vec2-base-960h") # "vitouphy/wav2vec2-xls-r-300m-phoneme""facebook/wav2vec2-base-960h"
45
+ self.audio_encoder.feature_extractor._freeze_parameters()
46
+
47
+ self.project = ConvNormRelu(768, self._num_hiddens, leaky=True)
48
+
49
+ self._enc_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
50
+ self._down_1 = ConvNormRelu(self._num_hiddens, self._num_hiddens, leaky=True, residual=True, sample='down')
51
+ self._enc_2 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
52
+ self._down_2 = ConvNormRelu(self._num_hiddens, self._num_hiddens, leaky=True, residual=True, sample='down')
53
+ self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
54
+
55
+ def forward(self, x, frame_num):
56
+ h = self.audio_encoder(x.squeeze(), frame_num=frame_num).last_hidden_state.transpose(1, 2)
57
+ h = self.project(h)
58
+ h = self._enc_1(h)
59
+ h = self._down_1(h)
60
+ h = self._enc_2(h)
61
+ h = self._down_2(h)
62
+ h = self._enc_3(h)
63
+ return h
64
+
65
+
66
+ class Encoder(nn.Module):
67
+ def __init__(self, in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
68
+ super(Encoder, self).__init__()
69
+ self._num_hiddens = num_hiddens
70
+ self._num_residual_layers = num_residual_layers
71
+ self._num_residual_hiddens = num_residual_hiddens
72
+
73
+ self.project = ConvNormRelu(in_dim, self._num_hiddens // 4, leaky=True)
74
+
75
+ self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True)
76
+ self._down_1 = ConvNormRelu(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, residual=True,
77
+ sample='down')
78
+ self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True)
79
+ self._down_2 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens, leaky=True, residual=True, sample='down')
80
+ self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
81
+
82
+ self.pre_vq_conv = nn.Conv1d(self._num_hiddens, embedding_dim, 1, 1)
83
+
84
+ def forward(self, x):
85
+ h = self.project(x)
86
+ h = self._enc_1(h)
87
+ h = self._down_1(h)
88
+ h = self._enc_2(h)
89
+ h = self._down_2(h)
90
+ h = self._enc_3(h)
91
+ h = self.pre_vq_conv(h)
92
+ return h
93
+
94
+
95
+ class Frame_Enc(nn.Module):
96
+ def __init__(self, in_dim, num_hiddens):
97
+ super(Frame_Enc, self).__init__()
98
+ self.in_dim = in_dim
99
+ self.num_hiddens = num_hiddens
100
+
101
+ # self.enc = transformer_Enc(in_dim, num_hiddens, 2, 8, 256, 256, 256, 256, 0, dropout=0.1, n_position=4)
102
+ self.proj = nn.Conv1d(in_dim, num_hiddens, 1, 1)
103
+ self.enc = Res_CNR_Stack(num_hiddens, 2, leaky=True)
104
+ self.proj_1 = nn.Conv1d(256*4, num_hiddens, 1, 1)
105
+ self.proj_2 = nn.Conv1d(256*4, num_hiddens*2, 1, 1)
106
+
107
+ def forward(self, x):
108
+ # x = self.enc(x, None)[0].reshape(x.shape[0], -1, 1)
109
+ x = self.enc(self.proj(x)).reshape(x.shape[0], -1, 1)
110
+ second_last = self.proj_2(x)
111
+ last = self.proj_1(x)
112
+ return second_last, last
113
+
114
+
115
+
116
+ class Decoder(nn.Module):
117
+ def __init__(self, out_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens, ae=False):
118
+ super(Decoder, self).__init__()
119
+ self._num_hiddens = num_hiddens
120
+ self._num_residual_layers = num_residual_layers
121
+ self._num_residual_hiddens = num_residual_hiddens
122
+
123
+ self.aft_vq_conv = nn.Conv1d(embedding_dim, self._num_hiddens, 1, 1)
124
+
125
+ self._dec_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
126
+ self._up_2 = ConvNormRelu(self._num_hiddens, self._num_hiddens // 2, leaky=True, residual=True, sample='up')
127
+ self._dec_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True)
128
+ self._up_3 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens // 4, leaky=True, residual=True,
129
+ sample='up')
130
+ self._dec_3 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True)
131
+
132
+ if ae:
133
+ self.frame_enc = Frame_Enc(out_dim, self._num_hiddens // 4)
134
+ self.gru_sl = nn.GRU(self._num_hiddens // 2, self._num_hiddens // 2, 1, batch_first=True)
135
+ self.gru_l = nn.GRU(self._num_hiddens // 4, self._num_hiddens // 4, 1, batch_first=True)
136
+
137
+ self.project = nn.Conv1d(self._num_hiddens // 4, out_dim, 1, 1)
138
+
139
+ def forward(self, h, last_frame=None):
140
+
141
+ h = self.aft_vq_conv(h)
142
+ h = self._dec_1(h)
143
+ h = self._up_2(h)
144
+ h = self._dec_2(h)
145
+ h = self._up_3(h)
146
+ h = self._dec_3(h)
147
+
148
+ recon = self.project(h)
149
+ return recon, None
150
+
151
+
152
+ class Pre_VQ(nn.Module):
153
+ def __init__(self, num_hiddens, embedding_dim, num_chunks):
154
+ super(Pre_VQ, self).__init__()
155
+ self.conv = nn.Conv1d(num_hiddens, num_hiddens, 1, 1, 0, groups=num_chunks)
156
+ self.bn = nn.GroupNorm(num_chunks, num_hiddens)
157
+ self.relu = nn.ReLU()
158
+ self.proj = nn.Conv1d(num_hiddens, embedding_dim, 1, 1, 0, groups=num_chunks)
159
+
160
+ def forward(self, x):
161
+ x = self.conv(x)
162
+ x = self.bn(x)
163
+ x = self.relu(x)
164
+ x = self.proj(x)
165
+ return x
166
+
167
+
168
+ class VQVAE(nn.Module):
169
+ """VQ-VAE"""
170
+
171
+ def __init__(self, in_dim, embedding_dim, num_embeddings,
172
+ num_hiddens, num_residual_layers, num_residual_hiddens,
173
+ commitment_cost=0.25, decay=0.99, share=False):
174
+ super().__init__()
175
+ self.in_dim = in_dim
176
+ self.embedding_dim = embedding_dim
177
+ self.num_embeddings = num_embeddings
178
+ self.share_code_vq = share
179
+
180
+ self.encoder = Encoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
181
+ self.vq_layer = VectorQuantizerEMA(embedding_dim, num_embeddings, commitment_cost, decay)
182
+ self.decoder = Decoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
183
+
184
+ def forward(self, gt_poses, id=None, pre_state=None):
185
+ z = self.encoder(gt_poses.transpose(1, 2))
186
+ if not self.training:
187
+ e, _ = self.vq_layer(z)
188
+ x_recon, cur_state = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
189
+ return e, x_recon
190
+
191
+ e, e_q_loss = self.vq_layer(z)
192
+ gt_recon, cur_state = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
193
+
194
+ return e_q_loss, gt_recon.transpose(1, 2)
195
+
196
+ def encode(self, gt_poses, id=None):
197
+ z = self.encoder(gt_poses.transpose(1, 2))
198
+ e, latents = self.vq_layer(z)
199
+ return e, latents
200
+
201
+ def decode(self, b, w, e=None, latents=None, pre_state=None):
202
+ if e is not None:
203
+ x = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
204
+ else:
205
+ e = self.vq_layer.quantize(latents)
206
+ e = e.view(b, w, -1).permute(0, 2, 1).contiguous()
207
+ x = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
208
+ return x
209
+
210
+
211
+ class AE(nn.Module):
212
+ """VQ-VAE"""
213
+
214
+ def __init__(self, in_dim, embedding_dim, num_embeddings,
215
+ num_hiddens, num_residual_layers, num_residual_hiddens):
216
+ super().__init__()
217
+ self.in_dim = in_dim
218
+ self.embedding_dim = embedding_dim
219
+ self.num_embeddings = num_embeddings
220
+
221
+ self.encoder = Encoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
222
+ self.decoder = Decoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens, True)
223
+
224
+ def forward(self, gt_poses, id=None, pre_state=None):
225
+ z = self.encoder(gt_poses.transpose(1, 2))
226
+ if not self.training:
227
+ x_recon, cur_state = self.decoder(z, pre_state.transpose(1, 2) if pre_state is not None else None)
228
+ return z, x_recon
229
+ gt_recon, cur_state = self.decoder(z, pre_state.transpose(1, 2) if pre_state is not None else None)
230
+
231
+ return gt_recon.transpose(1, 2)
232
+
233
+ def encode(self, gt_poses, id=None):
234
+ z = self.encoder(gt_poses.transpose(1, 2))
235
+ return z
nets/spg/vqvae_modules.py ADDED
@@ -0,0 +1,380 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torchvision import datasets, transforms
7
+ import matplotlib.pyplot as plt
8
+
9
+
10
+
11
+
12
+ class CasualCT(nn.Module):
13
+ def __init__(self,
14
+ in_channels,
15
+ out_channels,
16
+ leaky=False,
17
+ p=0,
18
+ groups=1, ):
19
+ '''
20
+ conv-bn-relu
21
+ '''
22
+ super(CasualCT, self).__init__()
23
+ padding = 0
24
+ kernel_size = 2
25
+ stride = 2
26
+ in_channels = in_channels * groups
27
+ out_channels = out_channels * groups
28
+
29
+ self.conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
30
+ kernel_size=kernel_size, stride=stride, padding=padding,
31
+ groups=groups)
32
+ self.norm = nn.BatchNorm1d(out_channels)
33
+ self.dropout = nn.Dropout(p=p)
34
+ if leaky:
35
+ self.relu = nn.LeakyReLU(negative_slope=0.2)
36
+ else:
37
+ self.relu = nn.ReLU()
38
+
39
+ def forward(self, x, **kwargs):
40
+ out = self.norm(self.dropout(self.conv(x)))
41
+ return self.relu(out)
42
+
43
+
44
+ class CasualConv(nn.Module):
45
+ def __init__(self,
46
+ in_channels,
47
+ out_channels,
48
+ leaky=False,
49
+ p=0,
50
+ groups=1,
51
+ downsample=False):
52
+ '''
53
+ conv-bn-relu
54
+ '''
55
+ super(CasualConv, self).__init__()
56
+ padding = 0
57
+ kernel_size = 2
58
+ stride = 1
59
+ self.downsample = downsample
60
+ if self.downsample:
61
+ kernel_size = 2
62
+ stride = 2
63
+
64
+ in_channels = in_channels * groups
65
+ out_channels = out_channels * groups
66
+ self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
67
+ kernel_size=kernel_size, stride=stride, padding=padding,
68
+ groups=groups)
69
+ self.norm = nn.BatchNorm1d(out_channels)
70
+ self.dropout = nn.Dropout(p=p)
71
+ if leaky:
72
+ self.relu = nn.LeakyReLU(negative_slope=0.2)
73
+ else:
74
+ self.relu = nn.ReLU()
75
+
76
+ def forward(self, x, pre_state=None):
77
+ if not self.downsample:
78
+ if pre_state is not None:
79
+ x = torch.cat([pre_state, x], dim=-1)
80
+ else:
81
+ zeros = torch.zeros([x.shape[0], x.shape[1], 1], device=x.device)
82
+ x = torch.cat([zeros, x], dim=-1)
83
+ out = self.norm(self.dropout(self.conv(x)))
84
+ return self.relu(out)
85
+
86
+
87
+ class ConvNormRelu(nn.Module):
88
+ '''
89
+ (B,C_in,H,W) -> (B, C_out, H, W)
90
+ some kernel sizes make the output length differ from H/s
91
+ # TODO: there might be some problems with the residual path
92
+ '''
93
+
94
+ def __init__(self,
95
+ in_channels,
96
+ out_channels,
97
+ leaky=False,
98
+ sample='none',
99
+ p=0,
100
+ groups=1,
101
+ residual=False,
102
+ norm='bn'):
103
+ '''
104
+ conv-bn-relu
105
+ '''
106
+ super(ConvNormRelu, self).__init__()
107
+ self.residual = residual
108
+ self.norm_type = norm
109
+ padding = 1
110
+
111
+ if sample == 'none':
112
+ kernel_size = 3
113
+ stride = 1
114
+ elif sample == 'one':
115
+ padding = 0
116
+ kernel_size = stride = 1
117
+ else:
118
+ kernel_size = 4
119
+ stride = 2
120
+
121
+ if self.residual:
122
+ if sample == 'down':
123
+ self.residual_layer = nn.Conv1d(
124
+ in_channels=in_channels,
125
+ out_channels=out_channels,
126
+ kernel_size=kernel_size,
127
+ stride=stride,
128
+ padding=padding)
129
+ elif sample == 'up':
130
+ self.residual_layer = nn.ConvTranspose1d(
131
+ in_channels=in_channels,
132
+ out_channels=out_channels,
133
+ kernel_size=kernel_size,
134
+ stride=stride,
135
+ padding=padding)
136
+ else:
137
+ if in_channels == out_channels:
138
+ self.residual_layer = nn.Identity()
139
+ else:
140
+ self.residual_layer = nn.Sequential(
141
+ nn.Conv1d(
142
+ in_channels=in_channels,
143
+ out_channels=out_channels,
144
+ kernel_size=kernel_size,
145
+ stride=stride,
146
+ padding=padding
147
+ )
148
+ )
149
+
150
+ in_channels = in_channels * groups
151
+ out_channels = out_channels * groups
152
+ if sample == 'up':
153
+ self.conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
154
+ kernel_size=kernel_size, stride=stride, padding=padding,
155
+ groups=groups)
156
+ else:
157
+ self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
158
+ kernel_size=kernel_size, stride=stride, padding=padding,
159
+ groups=groups)
160
+ self.norm = nn.BatchNorm1d(out_channels)
161
+ self.dropout = nn.Dropout(p=p)
162
+ if leaky:
163
+ self.relu = nn.LeakyReLU(negative_slope=0.2)
164
+ else:
165
+ self.relu = nn.ReLU()
166
+
167
+ def forward(self, x, **kwargs):
168
+ out = self.norm(self.dropout(self.conv(x)))
169
+ if self.residual:
170
+ residual = self.residual_layer(x)
171
+ out += residual
172
+ return self.relu(out)
173
+
174
+
175
+ class Res_CNR_Stack(nn.Module):
176
+ def __init__(self,
177
+ channels,
178
+ layers,
179
+ sample='none',
180
+ leaky=False,
181
+ casual=False,
182
+ ):
183
+ super(Res_CNR_Stack, self).__init__()
184
+
185
+ if casual:
186
+ kernal_size = 1
187
+ padding = 0
188
+ conv = CasualConv
189
+ else:
190
+ kernal_size = 3
191
+ padding = 1
192
+ conv = ConvNormRelu
193
+
194
+ if sample == 'one':
195
+ kernal_size = 1
196
+ padding = 0
197
+
198
+ self._layers = nn.ModuleList()
199
+ for i in range(layers):
200
+ self._layers.append(conv(channels, channels, leaky=leaky, sample=sample))
201
+ self.conv = nn.Conv1d(channels, channels, kernal_size, 1, padding)
202
+ self.norm = nn.BatchNorm1d(channels)
203
+ self.relu = nn.ReLU()
204
+
205
+ def forward(self, x, pre_state=None):
206
+ # cur_state = []
207
+ h = x
208
+ for i in range(self._layers.__len__()):
209
+ # cur_state.append(h[..., -1:])
210
+ h = self._layers[i](h, pre_state=pre_state[i] if pre_state is not None else None)
211
+ h = self.norm(self.conv(h))
212
+ return self.relu(h + x)
213
+
214
+
215
+ class ExponentialMovingAverage(nn.Module):
216
+ """Maintains an exponential moving average for a value.
217
+
218
+ This module keeps track of a hidden exponential moving average that is
219
+ initialized as a vector of zeros which is then normalized to give the average.
220
+ This gives us a moving average which isn't biased towards either zero or the
221
+ initial value. Reference (https://arxiv.org/pdf/1412.6980.pdf)
222
+
223
+ Initially:
224
+ hidden_0 = 0
225
+ Then iteratively:
226
+ hidden_i = hidden_{i-1} - (hidden_{i-1} - value) * (1 - decay)
227
+ average_i = hidden_i / (1 - decay^i)
228
+ """
229
+
230
+ def __init__(self, init_value, decay):
231
+ super().__init__()
232
+
233
+ self.decay = decay
234
+ self.counter = 0
235
+ self.register_buffer("hidden", torch.zeros_like(init_value))
236
+
237
+ def forward(self, value):
238
+ self.counter += 1
239
+ self.hidden.sub_((self.hidden - value) * (1 - self.decay))
240
+ average = self.hidden / (1 - self.decay ** self.counter)
241
+ return average
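As a sanity check on the update rule quoted in the docstring above, a tiny numeric sketch (illustration only): for a constant input, the bias-corrected average equals that input from the very first step, even though the hidden state starts at zero.

    import torch
    decay, hidden = 0.99, torch.zeros(3)
    value = torch.tensor([1.0, 2.0, 3.0])
    for step in range(1, 4):
        hidden = hidden - (hidden - value) * (1 - decay)  # hidden_i
        average = hidden / (1 - decay ** step)            # equals value at every step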
242
+
243
+
244
+ class VectorQuantizerEMA(nn.Module):
245
+ """
246
+ VQ-VAE quantization layer: quantizes the input tensor and updates the codebook with an exponential moving average (EMA) instead of a codebook loss.
247
+ Args:
248
+ embedding_dim (int): the dimensionality of the tensors in the
249
+ quantized space. Inputs to the modules must be in this format as well.
250
+ num_embeddings (int): the number of vectors in the quantized space.
251
+ commitment_cost (float): scalar which controls the weighting of the loss terms (see
252
+ equation 4 in the paper - this variable is Beta).
253
+ decay (float): decay for the moving averages.
254
+ epsilon (float): small float constant to avoid numerical instability.
255
+ """
256
+
257
+ def __init__(self, embedding_dim, num_embeddings, commitment_cost, decay,
258
+ epsilon=1e-5):
259
+ super().__init__()
260
+ self.embedding_dim = embedding_dim
261
+ self.num_embeddings = num_embeddings
262
+ self.commitment_cost = commitment_cost
263
+ self.epsilon = epsilon
264
+
265
+ # initialize embeddings as buffers
266
+ embeddings = torch.empty(self.num_embeddings, self.embedding_dim)
267
+ nn.init.xavier_uniform_(embeddings)
268
+ self.register_buffer("embeddings", embeddings)
269
+ self.ema_dw = ExponentialMovingAverage(self.embeddings, decay)
270
+
271
+ # also maintain ema_cluster_size, which records how many encoder outputs are assigned to each code
272
+ self.ema_cluster_size = ExponentialMovingAverage(torch.zeros((self.num_embeddings,)), decay)
273
+
274
+ def forward(self, x):
275
+ # [B, C, T] -> [B, T, C]
276
+ x = x.permute(0, 2, 1).contiguous()
277
+ # [B, T, C] -> [B*T, C]
278
+ flat_x = x.reshape(-1, self.embedding_dim)
279
+
280
+ encoding_indices = self.get_code_indices(flat_x)
281
+ quantized = self.quantize(encoding_indices)
282
+ quantized = quantized.view_as(x) # [B, T, C]
283
+
284
+ if not self.training:
285
+ quantized = quantized.permute(0, 2, 1).contiguous()
286
+ return quantized, encoding_indices.view(quantized.shape[0], quantized.shape[2])
287
+
288
+ # update embeddings with EMA
289
+ with torch.no_grad():
290
+ encodings = F.one_hot(encoding_indices, self.num_embeddings).float()
291
+ updated_ema_cluster_size = self.ema_cluster_size(torch.sum(encodings, dim=0))
292
+ n = torch.sum(updated_ema_cluster_size)
293
+ updated_ema_cluster_size = ((updated_ema_cluster_size + self.epsilon) /
294
+ (n + self.num_embeddings * self.epsilon) * n)
295
+ dw = torch.matmul(encodings.t(), flat_x) # sum encoding vectors of each cluster
296
+ updated_ema_dw = self.ema_dw(dw)
297
+ normalised_updated_ema_w = (
298
+ updated_ema_dw / updated_ema_cluster_size.reshape(-1, 1))
299
+ self.embeddings.data = normalised_updated_ema_w
300
+
301
+ # commitment loss
302
+ e_latent_loss = F.mse_loss(x, quantized.detach())
303
+ loss = self.commitment_cost * e_latent_loss
304
+
305
+ # Straight Through Estimator
306
+ quantized = x + (quantized - x).detach()
307
+
308
+ quantized = quantized.permute(0, 2, 1).contiguous()
309
+ return quantized, loss
310
+
311
+ def get_code_indices(self, flat_x):
312
+ # compute L2 distance
313
+ distances = (
314
+ torch.sum(flat_x ** 2, dim=1, keepdim=True) +
315
+ torch.sum(self.embeddings ** 2, dim=1) -
316
+ 2. * torch.matmul(flat_x, self.embeddings.t())
317
+ ) # [N, M]
318
+ encoding_indices = torch.argmin(distances, dim=1) # [N,]
319
+ return encoding_indices
320
+
321
+ def quantize(self, encoding_indices):
322
+ """Returns embedding tensor for a batch of indices."""
323
+ return F.embedding(encoding_indices, self.embeddings)
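A minimal usage sketch of the quantizer (illustration only; the sizes are made up). In training mode it returns the straight-through quantized latents and the commitment loss; in eval mode it returns the quantized latents together with the selected code indices:

    import torch
    vq = VectorQuantizerEMA(embedding_dim=64, num_embeddings=512,
                            commitment_cost=0.25, decay=0.99)
    z = torch.randn(8, 64, 30)        # [batch, embedding_dim, time]
    vq.train()
    z_q, commit_loss = vq(z)          # z_q: [8, 64, 30], scalar loss
    vq.eval()
    z_q, indices = vq(z)              # indices: [8, 30] codebook ids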
324
+
325
+
326
+
327
+ class Casual_Encoder(nn.Module):
328
+ def __init__(self, in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
329
+ super(Casual_Encoder, self).__init__()
330
+ self._num_hiddens = num_hiddens
331
+ self._num_residual_layers = num_residual_layers
332
+ self._num_residual_hiddens = num_residual_hiddens
333
+
334
+ self.project = nn.Conv1d(in_dim, self._num_hiddens // 4, 1, 1)
335
+ self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True, casual=True)
336
+ self._down_1 = CasualConv(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, downsample=True)
337
+ self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True, casual=True)
338
+ self._down_2 = CasualConv(self._num_hiddens // 2, self._num_hiddens, leaky=True, downsample=True)
339
+ self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True, casual=True)
340
+ # self.pre_vq_conv = nn.Conv1d(self._num_hiddens, embedding_dim, 1, 1)
341
+
342
+ def forward(self, x):
343
+ h = self.project(x)
344
+ h, _ = self._enc_1(h)
345
+ h = self._down_1(h)
346
+ h, _ = self._enc_2(h)
347
+ h = self._down_2(h)
348
+ h, _ = self._enc_3(h)
349
+ # h = self.pre_vq_conv(h)
350
+ return h
351
+
352
+
353
+ class Casual_Decoder(nn.Module):
354
+ def __init__(self, out_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
355
+ super(Casual_Decoder, self).__init__()
356
+ self._num_hiddens = num_hiddens
357
+ self._num_residual_layers = num_residual_layers
358
+ self._num_residual_hiddens = num_residual_hiddens
359
+
360
+ # self.aft_vq_conv = nn.Conv1d(embedding_dim, self._num_hiddens, 1, 1)
361
+ self._dec_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True, casual=True)
362
+ self._up_2 = CasualCT(self._num_hiddens, self._num_hiddens // 2, leaky=True)
363
+ self._dec_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True, casual=True)
364
+ self._up_3 = CasualCT(self._num_hiddens // 2, self._num_hiddens // 4, leaky=True)
365
+ self._dec_3 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True, casual=True)
366
+ self.project = nn.Conv1d(self._num_hiddens//4, out_dim, 1, 1)
367
+
368
+ def forward(self, h, pre_state=None):
369
+ cur_state = []
370
+ # h = self.aft_vq_conv(x)
371
+ h, s = self._dec_1(h, pre_state[0] if pre_state is not None else None)
372
+ cur_state.append(s)
373
+ h = self._up_2(h)
374
+ h, s = self._dec_2(h, pre_state[1] if pre_state is not None else None)
375
+ cur_state.append(s)
376
+ h = self._up_3(h)
377
+ h, s = self._dec_3(h, pre_state[2] if pre_state is not None else None)
378
+ cur_state.append(s)
379
+ recon = self.project(h)
380
+ return recon, cur_state
nets/spg/wav2vec.py ADDED
@@ -0,0 +1,143 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import copy
6
+ import math
7
+ from transformers import Wav2Vec2Model,Wav2Vec2Config
8
+ from transformers.modeling_outputs import BaseModelOutput
9
+ from typing import Optional, Tuple
10
+ _CONFIG_FOR_DOC = "Wav2Vec2Config"
11
+
12
+ # The implementation of Wav2Vec2Model is adapted from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model
13
+ # so that the audio encoder can be initialized with pre-trained wav2vec 2.0 weights.
14
+ def _compute_mask_indices(
15
+ shape: Tuple[int, int],
16
+ mask_prob: float,
17
+ mask_length: int,
18
+ attention_mask: Optional[torch.Tensor] = None,
19
+ min_masks: int = 0,
20
+ ) -> np.ndarray:
21
+ bsz, all_sz = shape
22
+ mask = np.full((bsz, all_sz), False)
23
+
24
+ all_num_mask = int(
25
+ mask_prob * all_sz / float(mask_length)
26
+ + np.random.rand()
27
+ )
28
+ all_num_mask = max(min_masks, all_num_mask)
29
+ mask_idcs = []
30
+ padding_mask = attention_mask.ne(1) if attention_mask is not None else None
31
+ for i in range(bsz):
32
+ if padding_mask is not None:
33
+ sz = all_sz - padding_mask[i].long().sum().item()
34
+ num_mask = int(
35
+ mask_prob * sz / float(mask_length)
36
+ + np.random.rand()
37
+ )
38
+ num_mask = max(min_masks, num_mask)
39
+ else:
40
+ sz = all_sz
41
+ num_mask = all_num_mask
42
+
43
+ lengths = np.full(num_mask, mask_length)
44
+
45
+ if sum(lengths) == 0:
46
+ lengths[0] = min(mask_length, sz - 1)
47
+
48
+ min_len = min(lengths)
49
+ if sz - min_len <= num_mask:
50
+ min_len = sz - num_mask - 1
51
+
52
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
53
+ mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
54
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
55
+
56
+ min_len = min([len(m) for m in mask_idcs])
57
+ for i, mask_idc in enumerate(mask_idcs):
58
+ if len(mask_idc) > min_len:
59
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
60
+ mask[i, mask_idc] = True
61
+ return mask
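A small sketch of the mask generator above (illustration only), mirroring how the forward pass below calls it with min_masks=2:

    mask = _compute_mask_indices((2, 100), mask_prob=0.065, mask_length=10, min_masks=2)
    # boolean np.ndarray of shape (2, 100); each row masks spans of 10 consecutive steps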
62
+
63
+ # linear interpolation layer: resamples features along the time axis from input_fps to output_fps
64
+ def linear_interpolation(features, input_fps, output_fps, output_len=None):
65
+ features = features.transpose(1, 2)
66
+ seq_len = features.shape[2] / float(input_fps)
67
+ if output_len is None:
68
+ output_len = int(seq_len * output_fps)
69
+ output_features = F.interpolate(features,size=output_len,align_corners=False,mode='linear')
70
+ return output_features.transpose(1, 2)
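A quick sketch of the resampling (illustration only): 100 frames of 50 Hz wav2vec features become 60 frames at the 30-fps motion rate.

    import torch
    feats = torch.randn(2, 100, 768)   # [batch, frames at 50 Hz, channels]
    out = linear_interpolation(feats, input_fps=50, output_fps=30)
    print(out.shape)                   # torch.Size([2, 60, 768])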
71
+
72
+
73
+ class Wav2Vec2Model(Wav2Vec2Model):
74
+ def __init__(self, config):
75
+ super().__init__(config)
76
+ def forward(
77
+ self,
78
+ input_values,
79
+ attention_mask=None,
80
+ output_attentions=None,
81
+ output_hidden_states=None,
82
+ return_dict=None,
83
+ frame_num=None
84
+ ):
85
+ self.config.output_attentions = True
86
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
87
+ output_hidden_states = (
88
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
89
+ )
90
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
91
+
92
+ hidden_states = self.feature_extractor(input_values)
93
+ hidden_states = hidden_states.transpose(1, 2)
94
+
95
+ hidden_states = linear_interpolation(hidden_states, 50, 30,output_len=frame_num)
96
+
97
+ if attention_mask is not None:
98
+ output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
99
+ attention_mask = torch.zeros(
100
+ hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device
101
+ )
102
+ attention_mask[
103
+ (torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)
104
+ ] = 1
105
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
106
+
107
+ hidden_states = self.feature_projection(hidden_states)
108
+
109
+ if self.config.apply_spec_augment and self.training:
110
+ batch_size, sequence_length, hidden_size = hidden_states.size()
111
+ if self.config.mask_time_prob > 0:
112
+ mask_time_indices = _compute_mask_indices(
113
+ (batch_size, sequence_length),
114
+ self.config.mask_time_prob,
115
+ self.config.mask_time_length,
116
+ attention_mask=attention_mask,
117
+ min_masks=2,
118
+ )
119
+ hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype)
120
+ if self.config.mask_feature_prob > 0:
121
+ mask_feature_indices = _compute_mask_indices(
122
+ (batch_size, hidden_size),
123
+ self.config.mask_feature_prob,
124
+ self.config.mask_feature_length,
125
+ )
126
+ mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
127
+ hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
128
+ encoder_outputs = self.encoder(
129
+ hidden_states[0],
130
+ attention_mask=attention_mask,
131
+ output_attentions=output_attentions,
132
+ output_hidden_states=output_hidden_states,
133
+ return_dict=return_dict,
134
+ )
135
+ hidden_states = encoder_outputs[0]
136
+ if not return_dict:
137
+ return (hidden_states,) + encoder_outputs[1:]
138
+
139
+ return BaseModelOutput(
140
+ last_hidden_state=hidden_states,
141
+ hidden_states=encoder_outputs.hidden_states,
142
+ attentions=encoder_outputs.attentions,
143
+ )
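A minimal usage sketch of the encoder defined in this file (illustration only; the checkpoint name is an assumption, and it presumes a transformers version in which feature_projection returns a (hidden_states, extract_features) pair, which the indexing in the forward above relies on):

    import torch
    encoder = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')
    encoder.eval()
    wav = torch.randn(1, 2 * 16000)       # 2 s of 16 kHz audio
    with torch.no_grad():
        out = encoder(wav, frame_num=60)  # 30 fps * 2 s
    feats = out.last_hidden_state         # [1, 60, hidden_size]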
nets/utils.py ADDED
@@ -0,0 +1,122 @@
1
+ import json
2
+ import textgrid as tg
3
+ import numpy as np
4
+
5
+ def get_parameter_size(model):
6
+ total_num = sum(p.numel() for p in model.parameters())
7
+ trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
8
+ return total_num, trainable_num
9
+
10
+ def denormalize(kps, data_mean, data_std):
11
+ '''
12
+ kps: (B, T, C)
13
+ '''
14
+ data_std = data_std.reshape(1, 1, -1)
15
+ data_mean = data_mean.reshape(1, 1, -1)
16
+ return (kps * data_std) + data_mean
17
+
18
+ def normalize(kps, data_mean, data_std):
19
+ '''
20
+ kps: (B, T, C)
21
+ '''
22
+ data_std = data_std.squeeze().reshape(1, 1, -1)
23
+ data_mean = data_mean.squeeze().reshape(1, 1, -1)
24
+
25
+ return (kps-data_mean) / data_std
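A small sketch (illustration only) of the intended round trip: with the same statistics, denormalize inverts normalize.

    import numpy as np
    kps = np.random.randn(4, 88, 129)                      # (B, T, C)
    mean, std = kps.mean(axis=(0, 1)), kps.std(axis=(0, 1)) + 1e-8
    recon = denormalize(normalize(kps, mean, std), mean, std)
    assert np.allclose(recon, kps)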
26
+
27
+ def parse_audio(textgrid_file):
28
+ '''A demo implementation: returns a per-second 0/1 code marking where the listed function words start.'''
29
+ words=['but', 'as', 'to', 'that', 'with', 'of', 'the', 'and', 'or', 'not', 'which', 'what', 'this', 'for', 'because', 'if', 'so', 'just', 'about', 'like', 'by', 'how', 'from', 'whats', 'now', 'very', 'that', 'also', 'actually', 'who', 'then', 'well', 'where', 'even', 'today', 'between', 'than', 'when']
30
+ txt=tg.TextGrid.fromFile(textgrid_file)
31
+
32
+ total_time=int(np.ceil(txt.maxTime))
33
+ code_seq=np.zeros(total_time)
34
+
35
+ word_level=txt[0]
36
+
37
+ for i in range(len(word_level)):
38
+ start_time=word_level[i].minTime
39
+ end_time=word_level[i].maxTime
40
+ mark=word_level[i].mark
41
+
42
+ if mark in words:
43
+ start=int(np.round(start_time))
44
+ end=int(np.round(end_time))
45
+
46
+ if start >= len(code_seq) or end >= len(code_seq):
47
+ code_seq[-1] = 1
48
+ else:
49
+ code_seq[start]=1
50
+
51
+ return code_seq
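Hypothetical usage (the path below is a placeholder, not a file shipped with the repo):

    code_seq = parse_audio('./demo/speech.TextGrid')  # hypothetical TextGrid path
    # code_seq[t] == 1 when one of the listed function words starts in second t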
52
+
53
+
54
+ def get_path(model_name, model_type):
55
+ if model_name == 's2g_body_pixel':
56
+ if model_type == 'mfcc':
57
+ return './experiments/2022-10-09-smplx_S2G-body-pixel-aud-3p/ckpt-99.pth'
58
+ elif model_type == 'wv2':
59
+ return './experiments/2022-10-28-smplx_S2G-body-pixel-wv2-sg2/ckpt-99.pth'
60
+ elif model_type == 'random':
61
+ return './experiments/2022-10-09-smplx_S2G-body-pixel-random-3p/ckpt-99.pth'
62
+ elif model_type == 'wbhmodel':
63
+ return './experiments/2022-11-02-smplx_S2G-body-pixel-w-bhmodel/ckpt-99.pth'
64
+ elif model_type == 'wobhmodel':
65
+ return './experiments/2022-11-02-smplx_S2G-body-pixel-wo-bhmodel/ckpt-99.pth'
66
+ elif model_name == 's2g_body':
67
+ if model_type == 'a+m-vae':
68
+ return './experiments/2022-10-19-smplx_S2G-body-audio-motion-vae/ckpt-99.pth'
69
+ elif model_type == 'a-vae':
70
+ return './experiments/2022-10-18-smplx_S2G-body-audiovae/ckpt-99.pth'
71
+ elif model_type == 'a-ed':
72
+ return './experiments/2022-10-18-smplx_S2G-body-audioae/ckpt-99.pth'
73
+ elif model_name == 's2g_LS3DCG':
74
+ return './experiments/2022-10-19-smplx_S2G-LS3DCG/ckpt-99.pth'
75
+ elif model_name == 's2g_body_vq':
76
+ if model_type == 'n_com_1024':
77
+ return './experiments/2022-10-29-smplx_S2G-body-vq-cn1024/ckpt-99.pth'
78
+ elif model_type == 'n_com_2048':
79
+ return './experiments/2022-10-29-smplx_S2G-body-vq-cn2048/ckpt-99.pth'
80
+ elif model_type == 'n_com_4096':
81
+ return './experiments/2022-10-29-smplx_S2G-body-vq-cn4096/ckpt-99.pth'
82
+ elif model_type == 'n_com_8192':
83
+ return './experiments/2022-11-02-smplx_S2G-body-vq-cn8192/ckpt-99.pth'
84
+ elif model_type == 'n_com_16384':
85
+ return './experiments/2022-11-02-smplx_S2G-body-vq-cn16384/ckpt-99.pth'
86
+ elif model_type == 'n_com_170000':
87
+ return './experiments/2022-10-30-smplx_S2G-body-vq-cn170000/ckpt-99.pth'
88
+ elif model_type == 'com_1024':
89
+ return './experiments/2022-10-29-smplx_S2G-body-vq-composition/ckpt-99.pth'
90
+ elif model_type == 'com_2048':
91
+ return './experiments/2022-10-31-smplx_S2G-body-vq-composition2048/ckpt-99.pth'
92
+ elif model_type == 'com_4096':
93
+ return './experiments/2022-10-31-smplx_S2G-body-vq-composition4096/ckpt-99.pth'
94
+ elif model_type == 'com_8192':
95
+ return './experiments/2022-11-02-smplx_S2G-body-vq-composition8192/ckpt-99.pth'
96
+ elif model_type == 'com_16384':
97
+ return './experiments/2022-11-02-smplx_S2G-body-vq-composition16384/ckpt-99.pth'
98
+
99
+
100
+ def get_dpath(model_name, model_type):
101
+ if model_name == 's2g_body_pixel':
102
+ if model_type == 'audio':
103
+ return './experiments/2022-10-26-smplx_S2G-d-pixel-aud/ckpt-9.pth'
104
+ elif model_type == 'wv2':
105
+ return './experiments/2022-11-04-smplx_S2G-d-pixel-wv2/ckpt-9.pth'
106
+ elif model_type == 'random':
107
+ return './experiments/2022-10-26-smplx_S2G-d-pixel-random/ckpt-9.pth'
108
+ elif model_type == 'wbhmodel':
109
+ return './experiments/2022-11-10-smplx_S2G-hD-wbhmodel/ckpt-9.pth'
110
+ # return './experiments/2022-11-05-smplx_S2G-d-pixel-wbhmodel/ckpt-9.pth'
111
+ elif model_type == 'wobhmodel':
112
+ return './experiments/2022-11-10-smplx_S2G-hD-wobhmodel/ckpt-9.pth'
113
+ # return './experiments/2022-11-05-smplx_S2G-d-pixel-wobhmodel/ckpt-9.pth'
114
+ elif model_name == 's2g_body':
115
+ if model_type == 'a+m-vae':
116
+ return './experiments/2022-10-26-smplx_S2G-d-audio+motion-vae/ckpt-9.pth'
117
+ elif model_type == 'a-vae':
118
+ return './experiments/2022-10-26-smplx_S2G-d-audio-vae/ckpt-9.pth'
119
+ elif model_type == 'a-ed':
120
+ return './experiments/2022-10-26-smplx_S2G-d-audio-ae/ckpt-9.pth'
121
+ elif model_name == 's2g_LS3DCG':
122
+ return './experiments/2022-10-26-smplx_S2G-d-ls3dcg/ckpt-9.pth'
scripts/.idea/__init__.py ADDED
File without changes
scripts/.idea/aws.xml ADDED
@@ -0,0 +1,11 @@
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="accountSettings">
4
+ <option name="activeRegion" value="us-east-1" />
5
+ <option name="recentlyUsedRegions">
6
+ <list>
7
+ <option value="us-east-1" />
8
+ </list>
9
+ </option>
10
+ </component>
11
+ </project>
scripts/.idea/deployment.xml ADDED
@@ -0,0 +1,70 @@
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
4
+ <serverData>
5
+ <paths name="80ti">
6
+ <serverdata>
7
+ <mappings>
8
+ <mapping local="$PROJECT_DIR$" web="/" />
9
+ </mappings>
10
+ </serverdata>
11
+ </paths>
12
+ <paths name="80ti (1)">
13
+ <serverdata>
14
+ <mappings>
15
+ <mapping local="$PROJECT_DIR$" web="/" />
16
+ </mappings>
17
+ </serverdata>
18
+ </paths>
19
+ <paths name="80ti (2)">
20
+ <serverdata>
21
+ <mappings>
22
+ <mapping local="$PROJECT_DIR$" web="/" />
23
+ </mappings>
24
+ </serverdata>
25
+ </paths>
26
+ <paths name="80ti (3)">
27
+ <serverdata>
28
+ <mappings>
29
+ <mapping local="$PROJECT_DIR$" web="/" />
30
+ </mappings>
31
+ </serverdata>
32
+ </paths>
33
+ <paths name="[email protected]:22">
34
+ <serverdata>
35
+ <mappings>
36
+ <mapping local="$PROJECT_DIR$" web="/" />
37
+ </mappings>
38
+ </serverdata>
39
+ </paths>
40
+ <paths name="titan">
41
+ <serverdata>
42
+ <mappings>
43
+ <mapping local="$PROJECT_DIR$" web="/" />
44
+ </mappings>
45
+ </serverdata>
46
+ </paths>
47
+ <paths name="titan (1)">
48
+ <serverdata>
49
+ <mappings>
50
+ <mapping local="$PROJECT_DIR$" web="/" />
51
+ </mappings>
52
+ </serverdata>
53
+ </paths>
54
+ <paths name="titan (2)">
55
+ <serverdata>
56
+ <mappings>
57
+ <mapping local="$PROJECT_DIR$" web="/" />
58
+ </mappings>
59
+ </serverdata>
60
+ </paths>
61
+ <paths name="titan (3)">
62
+ <serverdata>
63
+ <mappings>
64
+ <mapping local="$PROJECT_DIR$" web="/" />
65
+ </mappings>
66
+ </serverdata>
67
+ </paths>
68
+ </serverData>
69
+ </component>
70
+ </project>
scripts/.idea/get_prevar.py ADDED
@@ -0,0 +1,132 @@
1
+ import os
2
+ import sys
3
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
4
+
5
+ sys.path.append(os.getcwd())
6
+ from glob import glob
7
+
8
+ import numpy as np
9
+ import json
10
+ import smplx as smpl
11
+
12
+ from nets import *
13
+ from repro_nets import *
14
+ from trainer.options import parse_args
15
+ from data_utils import torch_data
16
+ from trainer.config import load_JsonConfig
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from torch.utils import data
22
+
23
+ def init_model(model_name, model_path, args, config):
24
+ if model_name == 'freeMo':
25
+ # generator = freeMo_Generator(args)
26
+ # generator = freeMo_Generator(args)
27
+ generator = freeMo_dev(args, config)
28
+ # generator.load_state_dict(torch.load(model_path)['generator'])
29
+ elif model_name == 'smplx_S2G':
30
+ generator = smplx_S2G(args, config)
31
+ elif model_name == 'StyleGestures':
32
+ generator = StyleGesture_Generator(
33
+ args,
34
+ config
35
+ )
36
+ elif model_name == 'Audio2Gestures':
37
+ config.Train.using_mspec_stat = False
38
+ generator = Audio2Gesture_Generator(
39
+ args,
40
+ config,
41
+ torch.zeros([1, 1, 108]),
42
+ torch.ones([1, 1, 108])
43
+ )
44
+ elif model_name == 'S2G':
45
+ generator = S2G_Generator(
46
+ args,
47
+ config,
48
+ )
49
+ elif model_name == 'Tmpt':
50
+ generator = S2G_Generator(
51
+ args,
52
+ config,
53
+ )
54
+ else:
55
+ raise NotImplementedError
56
+
57
+ model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
58
+ if model_name == 'smplx_S2G':
59
+ generator.generator.load_state_dict(model_ckpt['generator']['generator'])
60
+ elif 'generator' in list(model_ckpt.keys()):
61
+ generator.load_state_dict(model_ckpt['generator'])
62
+ else:
63
+ model_ckpt = {'generator': model_ckpt}
64
+ generator.load_state_dict(model_ckpt)
65
+
66
+ return generator
67
+
68
+
69
+
70
+ def prevar_loader(data_root, speakers, args, config, model_path, device, generator):
71
+ path = model_path.split('ckpt')[0]
72
+ file = os.path.join(os.path.dirname(path), "pre_variable.npy")
73
+ data_base = torch_data(
74
+ data_root=data_root,
75
+ speakers=speakers,
76
+ split='pre',
77
+ limbscaling=False,
78
+ normalization=config.Data.pose.normalization,
79
+ norm_method=config.Data.pose.norm_method,
80
+ split_trans_zero=False,
81
+ num_pre_frames=config.Data.pose.pre_pose_length,
82
+ num_generate_length=config.Data.pose.generate_length,
83
+ num_frames=15,
84
+ aud_feat_win_size=config.Data.aud.aud_feat_win_size,
85
+ aud_feat_dim=config.Data.aud.aud_feat_dim,
86
+ feat_method=config.Data.aud.feat_method,
87
+ smplx=True,
88
+ audio_sr=22000,
89
+ convert_to_6d=config.Data.pose.convert_to_6d,
90
+ expression=config.Data.pose.expression
91
+ )
92
+
93
+ data_base.get_dataset()
94
+ pre_set = data_base.all_dataset
95
+ pre_loader = data.DataLoader(pre_set, batch_size=config.DataLoader.batch_size, shuffle=False, drop_last=True)
96
+
97
+ total_pose = []
98
+
99
+ with torch.no_grad():
100
+ for bat in pre_loader:
101
+ pose = bat['poses'].to(device).to(torch.float32)
102
+ expression = bat['expression'].to(device).to(torch.float32)
103
+ pose = pose.permute(0, 2, 1)
104
+ pose = torch.cat([pose[:, :15], pose[:, 15:30], pose[:, 30:45], pose[:, 45:60], pose[:, 60:]], dim=0)
105
+ expression = expression.permute(0, 2, 1)
106
+ expression = torch.cat([expression[:, :15], expression[:, 15:30], expression[:, 30:45], expression[:, 45:60], expression[:, 60:]], dim=0)
107
+ pose = torch.cat([pose, expression], dim=-1)
108
+ pose = pose.reshape(pose.shape[0], -1, 1)
109
+ pose_code = generator.generator.pre_pose_encoder(pose).squeeze().detach().cpu()
110
+ total_pose.append(np.asarray(pose_code))
111
+ total_pose = np.concatenate(total_pose, axis=0)
112
+ mean = np.mean(total_pose, axis=0)
113
+ std = np.std(total_pose, axis=0)
114
+ prevar = (mean, std)
115
+ np.save(file, prevar, allow_pickle=True)
116
+
117
+ return mean, std
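For reference, a small sketch (illustration only) of reading the statistics that prevar_loader caches next to the checkpoint; the relative path is assumed:

    import numpy as np
    mean, std = np.load('pre_variable.npy', allow_pickle=True)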
118
+
119
+ def main():
120
+ parser = parse_args()
121
+ args = parser.parse_args()
122
+ device = torch.device(args.gpu)
123
+ torch.cuda.set_device(device)
124
+
125
+ config = load_JsonConfig(args.config_file)
126
+
127
+ print('init model...')
128
+ generator = init_model(config.Model.model_name, args.model_path, args, config)
129
+ print('init pre-pose vectors...')
130
+ mean, std = prevar_loader(config.Data.data_root, args.speakers, args, config, args.model_path, device, generator)
131
+
132
+ if __name__ == '__main__':
+     main()