essential files only
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- evaluation/FGD.py +199 -0
- evaluation/__init__.py +0 -0
- evaluation/__pycache__/__init__.cpython-37.pyc +0 -0
- evaluation/__pycache__/metrics.cpython-37.pyc +0 -0
- evaluation/diversity_LVD.py +64 -0
- evaluation/get_quality_samples.py +62 -0
- evaluation/metrics.py +109 -0
- evaluation/mode_transition.py +60 -0
- evaluation/peak_velocity.py +65 -0
- evaluation/util.py +148 -0
- losses/__init__.py +1 -0
- losses/__pycache__/__init__.cpython-37.pyc +0 -0
- losses/__pycache__/losses.cpython-37.pyc +0 -0
- losses/losses.py +91 -0
- nets/LS3DCG.py +414 -0
- nets/__init__.py +8 -0
- nets/__pycache__/LS3DCG.cpython-37.pyc +0 -0
- nets/__pycache__/__init__.cpython-37.pyc +0 -0
- nets/__pycache__/base.cpython-37.pyc +0 -0
- nets/__pycache__/body_ae.cpython-37.pyc +0 -0
- nets/__pycache__/init_model.cpython-37.pyc +0 -0
- nets/__pycache__/layers.cpython-37.pyc +0 -0
- nets/__pycache__/smplx_body_pixel.cpython-37.pyc +0 -0
- nets/__pycache__/smplx_body_vq.cpython-37.pyc +0 -0
- nets/__pycache__/smplx_face.cpython-37.pyc +0 -0
- nets/__pycache__/utils.cpython-37.pyc +0 -0
- nets/base.py +89 -0
- nets/body_ae.py +152 -0
- nets/init_model.py +35 -0
- nets/layers.py +1052 -0
- nets/smplx_body_pixel.py +326 -0
- nets/smplx_body_vq.py +302 -0
- nets/smplx_face.py +238 -0
- nets/spg/__pycache__/gated_pixelcnn_v2.cpython-37.pyc +0 -0
- nets/spg/__pycache__/s2g_face.cpython-37.pyc +0 -0
- nets/spg/__pycache__/s2glayers.cpython-37.pyc +0 -0
- nets/spg/__pycache__/vqvae_1d.cpython-37.pyc +0 -0
- nets/spg/__pycache__/vqvae_modules.cpython-37.pyc +0 -0
- nets/spg/__pycache__/wav2vec.cpython-37.pyc +0 -0
- nets/spg/gated_pixelcnn_v2.py +179 -0
- nets/spg/s2g_face.py +226 -0
- nets/spg/s2glayers.py +522 -0
- nets/spg/vqvae_1d.py +235 -0
- nets/spg/vqvae_modules.py +380 -0
- nets/spg/wav2vec.py +143 -0
- nets/utils.py +122 -0
- scripts/.idea/__init__.py +0 -0
- scripts/.idea/aws.xml +11 -0
- scripts/.idea/deployment.xml +70 -0
- scripts/.idea/get_prevar.py +132 -0
evaluation/FGD.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from scipy import linalg
|
7 |
+
import math
|
8 |
+
from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d
|
9 |
+
|
10 |
+
import warnings
|
11 |
+
warnings.filterwarnings("ignore", category=RuntimeWarning) # ignore warnings
|
12 |
+
|
13 |
+
|
14 |
+
change_angle = torch.tensor([6.0181e-05, 5.1597e-05, 2.1344e-04, 2.1899e-04])
|
15 |
+
class EmbeddingSpaceEvaluator:
|
16 |
+
def __init__(self, ae, vae, device):
|
17 |
+
|
18 |
+
# init embed net
|
19 |
+
self.ae = ae
|
20 |
+
# self.vae = vae
|
21 |
+
|
22 |
+
# storage
|
23 |
+
self.real_feat_list = []
|
24 |
+
self.generated_feat_list = []
|
25 |
+
self.real_joints_list = []
|
26 |
+
self.generated_joints_list = []
|
27 |
+
self.real_6d_list = []
|
28 |
+
self.generated_6d_list = []
|
29 |
+
self.audio_beat_list = []
|
30 |
+
|
31 |
+
def reset(self):
|
32 |
+
self.real_feat_list = []
|
33 |
+
self.generated_feat_list = []
|
34 |
+
|
35 |
+
def get_no_of_samples(self):
|
36 |
+
return len(self.real_feat_list)
|
37 |
+
|
38 |
+
def push_samples(self, generated_poses, real_poses):
|
39 |
+
# self.net.eval()
|
40 |
+
# convert poses to latent features
|
41 |
+
real_feat, real_poses = self.ae.extract(real_poses)
|
42 |
+
generated_feat, generated_poses = self.ae.extract(generated_poses)
|
43 |
+
|
44 |
+
num_joints = real_poses.shape[2] // 3
|
45 |
+
|
46 |
+
real_feat = real_feat.squeeze()
|
47 |
+
generated_feat = generated_feat.reshape(generated_feat.shape[0]*generated_feat.shape[1], -1)
|
48 |
+
|
49 |
+
self.real_feat_list.append(real_feat.data.cpu().numpy())
|
50 |
+
self.generated_feat_list.append(generated_feat.data.cpu().numpy())
|
51 |
+
|
52 |
+
# real_poses = matrix_to_rotation_6d(axis_angle_to_matrix(real_poses.reshape(-1, 3))).reshape(-1, num_joints, 6)
|
53 |
+
# generated_poses = matrix_to_rotation_6d(axis_angle_to_matrix(generated_poses.reshape(-1, 3))).reshape(-1, num_joints, 6)
|
54 |
+
#
|
55 |
+
# self.real_feat_list.append(real_poses.data.cpu().numpy())
|
56 |
+
# self.generated_feat_list.append(generated_poses.data.cpu().numpy())
|
57 |
+
|
58 |
+
def push_joints(self, generated_poses, real_poses):
|
59 |
+
self.real_joints_list.append(real_poses.data.cpu())
|
60 |
+
self.generated_joints_list.append(generated_poses.squeeze().data.cpu())
|
61 |
+
|
62 |
+
def push_aud(self, aud):
|
63 |
+
self.audio_beat_list.append(aud.squeeze().data.cpu())
|
64 |
+
|
65 |
+
def get_MAAC(self):
|
66 |
+
ang_vel_list = []
|
67 |
+
for real_joints in self.real_joints_list:
|
68 |
+
real_joints[:, 15:21] = real_joints[:, 16:22]
|
69 |
+
vec = real_joints[:, 15:21] - real_joints[:, 13:19]
|
70 |
+
inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]])
|
71 |
+
inner_product = torch.clamp(inner_product, -1, 1, out=None)
|
72 |
+
angle = torch.acos(inner_product) / math.pi
|
73 |
+
ang_vel = (angle[1:] - angle[:-1]).abs().mean(dim=0)
|
74 |
+
ang_vel_list.append(ang_vel.unsqueeze(dim=0))
|
75 |
+
all_vel = torch.cat(ang_vel_list, dim=0)
|
76 |
+
MAAC = all_vel.mean(dim=0)
|
77 |
+
return MAAC
|
78 |
+
|
79 |
+
def get_BCscore(self):
|
80 |
+
thres = 0.01
|
81 |
+
sigma = 0.1
|
82 |
+
sum_1 = 0
|
83 |
+
total_beat = 0
|
84 |
+
for joints, audio_beat_time in zip(self.generated_joints_list, self.audio_beat_list):
|
85 |
+
motion_beat_time = []
|
86 |
+
if joints.dim() == 4:
|
87 |
+
joints = joints[0]
|
88 |
+
joints[:, 15:21] = joints[:, 16:22]
|
89 |
+
vec = joints[:, 15:21] - joints[:, 13:19]
|
90 |
+
inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]])
|
91 |
+
inner_product = torch.clamp(inner_product, -1, 1, out=None)
|
92 |
+
angle = torch.acos(inner_product) / math.pi
|
93 |
+
ang_vel = (angle[1:] - angle[:-1]).abs() / change_angle / len(change_angle)
|
94 |
+
|
95 |
+
angle_diff = torch.cat((torch.zeros(1, 4), ang_vel), dim=0)
|
96 |
+
|
97 |
+
sum_2 = 0
|
98 |
+
for i in range(angle_diff.shape[1]):
|
99 |
+
motion_beat_time = []
|
100 |
+
for t in range(1, joints.shape[0]-1):
|
101 |
+
if (angle_diff[t][i] < angle_diff[t - 1][i] and angle_diff[t][i] < angle_diff[t + 1][i]):
|
102 |
+
if (angle_diff[t - 1][i] - angle_diff[t][i] >= thres or angle_diff[t + 1][i] - angle_diff[
|
103 |
+
t][i] >= thres):
|
104 |
+
motion_beat_time.append(float(t) / 30.0)
|
105 |
+
if (len(motion_beat_time) == 0):
|
106 |
+
continue
|
107 |
+
motion_beat_time = torch.tensor(motion_beat_time)
|
108 |
+
sum = 0
|
109 |
+
for audio in audio_beat_time:
|
110 |
+
sum += np.power(math.e, -(np.power((audio.item() - motion_beat_time), 2)).min() / (2 * sigma * sigma))
|
111 |
+
sum_2 = sum_2 + sum
|
112 |
+
total_beat = total_beat + len(audio_beat_time)
|
113 |
+
sum_1 = sum_1 + sum_2
|
114 |
+
return sum_1/total_beat
|
115 |
+
|
116 |
+
|
117 |
+
def get_scores(self):
|
118 |
+
generated_feats = np.vstack(self.generated_feat_list)
|
119 |
+
real_feats = np.vstack(self.real_feat_list)
|
120 |
+
|
121 |
+
def frechet_distance(samples_A, samples_B):
|
122 |
+
A_mu = np.mean(samples_A, axis=0)
|
123 |
+
A_sigma = np.cov(samples_A, rowvar=False)
|
124 |
+
B_mu = np.mean(samples_B, axis=0)
|
125 |
+
B_sigma = np.cov(samples_B, rowvar=False)
|
126 |
+
try:
|
127 |
+
frechet_dist = self.calculate_frechet_distance(A_mu, A_sigma, B_mu, B_sigma)
|
128 |
+
except ValueError:
|
129 |
+
frechet_dist = 1e+10
|
130 |
+
return frechet_dist
|
131 |
+
|
132 |
+
####################################################################
|
133 |
+
# frechet distance
|
134 |
+
frechet_dist = frechet_distance(generated_feats, real_feats)
|
135 |
+
|
136 |
+
####################################################################
|
137 |
+
# distance between real and generated samples on the latent feature space
|
138 |
+
dists = []
|
139 |
+
for i in range(real_feats.shape[0]):
|
140 |
+
d = np.sum(np.absolute(real_feats[i] - generated_feats[i])) # MAE
|
141 |
+
dists.append(d)
|
142 |
+
feat_dist = np.mean(dists)
|
143 |
+
|
144 |
+
return frechet_dist, feat_dist
|
145 |
+
|
146 |
+
@staticmethod
|
147 |
+
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
148 |
+
""" from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py """
|
149 |
+
"""Numpy implementation of the Frechet Distance.
|
150 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
151 |
+
and X_2 ~ N(mu_2, C_2) is
|
152 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
153 |
+
Stable version by Dougal J. Sutherland.
|
154 |
+
Params:
|
155 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
156 |
+
inception net (like returned by the function 'get_predictions')
|
157 |
+
for generated samples.
|
158 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
159 |
+
representative data set.
|
160 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
161 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
162 |
+
representative data set.
|
163 |
+
Returns:
|
164 |
+
-- : The Frechet Distance.
|
165 |
+
"""
|
166 |
+
|
167 |
+
mu1 = np.atleast_1d(mu1)
|
168 |
+
mu2 = np.atleast_1d(mu2)
|
169 |
+
|
170 |
+
sigma1 = np.atleast_2d(sigma1)
|
171 |
+
sigma2 = np.atleast_2d(sigma2)
|
172 |
+
|
173 |
+
assert mu1.shape == mu2.shape, \
|
174 |
+
'Training and test mean vectors have different lengths'
|
175 |
+
assert sigma1.shape == sigma2.shape, \
|
176 |
+
'Training and test covariances have different dimensions'
|
177 |
+
|
178 |
+
diff = mu1 - mu2
|
179 |
+
|
180 |
+
# Product might be almost singular
|
181 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
182 |
+
if not np.isfinite(covmean).all():
|
183 |
+
msg = ('fid calculation produces singular product; '
|
184 |
+
'adding %s to diagonal of cov estimates') % eps
|
185 |
+
print(msg)
|
186 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
187 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
188 |
+
|
189 |
+
# Numerical error might give slight imaginary component
|
190 |
+
if np.iscomplexobj(covmean):
|
191 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
192 |
+
m = np.max(np.abs(covmean.imag))
|
193 |
+
raise ValueError('Imaginary component {}'.format(m))
|
194 |
+
covmean = covmean.real
|
195 |
+
|
196 |
+
tr_covmean = np.trace(covmean)
|
197 |
+
|
198 |
+
return (diff.dot(diff) + np.trace(sigma1) +
|
199 |
+
np.trace(sigma2) - 2 * tr_covmean)
|
evaluation/__init__.py
ADDED
File without changes
|
evaluation/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (181 Bytes). View file
|
|
evaluation/__pycache__/metrics.cpython-37.pyc
ADDED
Binary file (3.81 kB). View file
|
|
evaluation/diversity_LVD.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
LVD: different initial pose
|
3 |
+
diversity: same initial pose
|
4 |
+
'''
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
sys.path.append(os.getcwd())
|
8 |
+
|
9 |
+
from glob import glob
|
10 |
+
|
11 |
+
from argparse import ArgumentParser
|
12 |
+
import json
|
13 |
+
|
14 |
+
from evaluation.util import *
|
15 |
+
from evaluation.metrics import *
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
parser = ArgumentParser()
|
19 |
+
parser.add_argument('--speaker', required=True, type=str)
|
20 |
+
parser.add_argument('--post_fix', nargs='+', default=['base'], type=str)
|
21 |
+
args = parser.parse_args()
|
22 |
+
|
23 |
+
speaker = args.speaker
|
24 |
+
test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
|
25 |
+
|
26 |
+
LVD_list = []
|
27 |
+
diversity_list = []
|
28 |
+
|
29 |
+
for aud in tqdm(test_audios):
|
30 |
+
base_name = os.path.splitext(aud)[0]
|
31 |
+
gt_path = get_full_path(aud, speaker, 'val')
|
32 |
+
_, gt_poses, _ = get_gts(gt_path)
|
33 |
+
gt_poses = gt_poses[np.newaxis,...]
|
34 |
+
# print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
|
35 |
+
for post_fix in args.post_fix:
|
36 |
+
pred_path = base_name + '_'+post_fix+'.json'
|
37 |
+
pred_poses = np.array(json.load(open(pred_path)))
|
38 |
+
# print(pred_poses.shape)#(B, seq_len, 108)
|
39 |
+
pred_poses = cvt25(pred_poses, gt_poses)
|
40 |
+
# print(pred_poses.shape)#(B, seq, pose_dim)
|
41 |
+
|
42 |
+
gt_valid_points = hand_points(gt_poses)
|
43 |
+
pred_valid_points = hand_points(pred_poses)
|
44 |
+
|
45 |
+
lvd = LVD(gt_valid_points, pred_valid_points)
|
46 |
+
# div = diversity(pred_valid_points)
|
47 |
+
|
48 |
+
LVD_list.append(lvd)
|
49 |
+
# diversity_list.append(div)
|
50 |
+
|
51 |
+
# gt_velocity = peak_velocity(gt_valid_points, order=2)
|
52 |
+
# pred_velocity = peak_velocity(pred_valid_points, order=2)
|
53 |
+
|
54 |
+
# gt_consistency = velocity_consistency(gt_velocity, pred_velocity)
|
55 |
+
# pred_consistency = velocity_consistency(pred_velocity, gt_velocity)
|
56 |
+
|
57 |
+
# gt_consistency_list.append(gt_consistency)
|
58 |
+
# pred_consistency_list.append(pred_consistency)
|
59 |
+
|
60 |
+
lvd = np.mean(LVD_list)
|
61 |
+
# diversity_list = np.mean(diversity_list)
|
62 |
+
|
63 |
+
print('LVD:', lvd)
|
64 |
+
# print("diversity:", diversity_list)
|
evaluation/get_quality_samples.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
'''
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
sys.path.append(os.getcwd())
|
6 |
+
|
7 |
+
from glob import glob
|
8 |
+
|
9 |
+
from argparse import ArgumentParser
|
10 |
+
import json
|
11 |
+
|
12 |
+
from evaluation.util import *
|
13 |
+
from evaluation.metrics import *
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
parser = ArgumentParser()
|
17 |
+
parser.add_argument('--speaker', required=True, type=str)
|
18 |
+
parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
|
19 |
+
args = parser.parse_args()
|
20 |
+
|
21 |
+
speaker = args.speaker
|
22 |
+
test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
|
23 |
+
|
24 |
+
quality_samples={'gt':[]}
|
25 |
+
for post_fix in args.post_fix:
|
26 |
+
quality_samples[post_fix] = []
|
27 |
+
|
28 |
+
for aud in tqdm(test_audios):
|
29 |
+
base_name = os.path.splitext(aud)[0]
|
30 |
+
gt_path = get_full_path(aud, speaker, 'val')
|
31 |
+
_, gt_poses, _ = get_gts(gt_path)
|
32 |
+
gt_poses = gt_poses[np.newaxis,...]
|
33 |
+
gt_valid_points = valid_points(gt_poses)
|
34 |
+
# print(gt_valid_points.shape)
|
35 |
+
quality_samples['gt'].append(gt_valid_points)
|
36 |
+
|
37 |
+
for post_fix in args.post_fix:
|
38 |
+
pred_path = base_name + '_'+post_fix+'.json'
|
39 |
+
pred_poses = np.array(json.load(open(pred_path)))
|
40 |
+
# print(pred_poses.shape)#(B, seq_len, 108)
|
41 |
+
pred_poses = cvt25(pred_poses, gt_poses)
|
42 |
+
# print(pred_poses.shape)#(B, seq, pose_dim)
|
43 |
+
|
44 |
+
pred_valid_points = valid_points(pred_poses)[0:1]
|
45 |
+
quality_samples[post_fix].append(pred_valid_points)
|
46 |
+
|
47 |
+
quality_samples['gt'] = np.concatenate(quality_samples['gt'], axis=1)
|
48 |
+
for post_fix in args.post_fix:
|
49 |
+
quality_samples[post_fix] = np.concatenate(quality_samples[post_fix], axis=1)
|
50 |
+
|
51 |
+
print('gt:', quality_samples['gt'].shape)
|
52 |
+
quality_samples['gt'] = quality_samples['gt'].tolist()
|
53 |
+
for post_fix in args.post_fix:
|
54 |
+
print(post_fix, ':', quality_samples[post_fix].shape)
|
55 |
+
quality_samples[post_fix] = quality_samples[post_fix].tolist()
|
56 |
+
|
57 |
+
save_dir = '../../experiments/'
|
58 |
+
os.makedirs(save_dir, exist_ok=True)
|
59 |
+
save_name = os.path.join(save_dir, 'quality_samples_%s.json'%(speaker))
|
60 |
+
with open(save_name, 'w') as f:
|
61 |
+
json.dump(quality_samples, f)
|
62 |
+
|
evaluation/metrics.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Warning: metrics are for reference only, may have limited significance
|
3 |
+
'''
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
sys.path.append(os.getcwd())
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from data_utils.lower_body import rearrange, symmetry
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
def data_driven_baselines(gt_kps):
|
14 |
+
'''
|
15 |
+
gt_kps: T, D
|
16 |
+
'''
|
17 |
+
gt_velocity = np.abs(gt_kps[1:] - gt_kps[:-1])
|
18 |
+
|
19 |
+
mean= np.mean(gt_velocity, axis=0)[np.newaxis] #(1, D)
|
20 |
+
mean = np.mean(np.abs(gt_velocity-mean))
|
21 |
+
last_step = gt_kps[1] - gt_kps[0]
|
22 |
+
last_step = last_step[np.newaxis] #(1, D)
|
23 |
+
last_step = np.mean(np.abs(gt_velocity-last_step))
|
24 |
+
return last_step, mean
|
25 |
+
|
26 |
+
def Batch_LVD(gt_kps, pr_kps, symmetrical, weight):
|
27 |
+
if gt_kps.shape[0] > pr_kps.shape[1]:
|
28 |
+
length = pr_kps.shape[1]
|
29 |
+
else:
|
30 |
+
length = gt_kps.shape[0]
|
31 |
+
gt_kps = gt_kps[:length]
|
32 |
+
pr_kps = pr_kps[:, :length]
|
33 |
+
global symmetry
|
34 |
+
symmetry = torch.tensor(symmetry).bool()
|
35 |
+
|
36 |
+
if symmetrical:
|
37 |
+
# rearrange for compute symmetric. ns means non-symmetrical joints, ys means symmetrical joints.
|
38 |
+
gt_kps = gt_kps[:, rearrange]
|
39 |
+
ns_gt_kps = gt_kps[:, ~symmetry]
|
40 |
+
ys_gt_kps = gt_kps[:, symmetry]
|
41 |
+
ys_gt_kps = ys_gt_kps.reshape(ys_gt_kps.shape[0], -1, 2, 3)
|
42 |
+
ns_gt_velocity = (ns_gt_kps[1:] - ns_gt_kps[:-1]).norm(p=2, dim=-1)
|
43 |
+
ys_gt_velocity = (ys_gt_kps[1:] - ys_gt_kps[:-1]).norm(p=2, dim=-1)
|
44 |
+
left_gt_vel = ys_gt_velocity[:, :, 0].sum(dim=-1)
|
45 |
+
right_gt_vel = ys_gt_velocity[:, :, 1].sum(dim=-1)
|
46 |
+
move_side = torch.where(left_gt_vel>right_gt_vel, torch.ones(left_gt_vel.shape).cuda(), torch.zeros(left_gt_vel.shape).cuda())
|
47 |
+
ys_gt_velocity = torch.mul(ys_gt_velocity[:, :, 0].transpose(0,1), move_side) + torch.mul(ys_gt_velocity[:, :, 1].transpose(0,1), ~move_side.bool())
|
48 |
+
ys_gt_velocity = ys_gt_velocity.transpose(0,1)
|
49 |
+
gt_velocity = torch.cat([ns_gt_velocity, ys_gt_velocity], dim=1)
|
50 |
+
|
51 |
+
pr_kps = pr_kps[:, :, rearrange]
|
52 |
+
ns_pr_kps = pr_kps[:, :, ~symmetry]
|
53 |
+
ys_pr_kps = pr_kps[:, :, symmetry]
|
54 |
+
ys_pr_kps = ys_pr_kps.reshape(ys_pr_kps.shape[0], ys_pr_kps.shape[1], -1, 2, 3)
|
55 |
+
ns_pr_velocity = (ns_pr_kps[:, 1:] - ns_pr_kps[:, :-1]).norm(p=2, dim=-1)
|
56 |
+
ys_pr_velocity = (ys_pr_kps[:, 1:] - ys_pr_kps[:, :-1]).norm(p=2, dim=-1)
|
57 |
+
left_pr_vel = ys_pr_velocity[:, :, :, 0].sum(dim=-1)
|
58 |
+
right_pr_vel = ys_pr_velocity[:, :, :, 1].sum(dim=-1)
|
59 |
+
move_side = torch.where(left_pr_vel > right_pr_vel, torch.ones(left_pr_vel.shape).cuda(),
|
60 |
+
torch.zeros(left_pr_vel.shape).cuda())
|
61 |
+
ys_pr_velocity = torch.mul(ys_pr_velocity[..., 0].permute(2, 0, 1), move_side) + torch.mul(
|
62 |
+
ys_pr_velocity[..., 1].permute(2, 0, 1), ~move_side.long())
|
63 |
+
ys_pr_velocity = ys_pr_velocity.permute(1, 2, 0)
|
64 |
+
pr_velocity = torch.cat([ns_pr_velocity, ys_pr_velocity], dim=2)
|
65 |
+
else:
|
66 |
+
gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1)
|
67 |
+
pr_velocity = (pr_kps[:, 1:] - pr_kps[:, :-1]).norm(p=2, dim=-1)
|
68 |
+
|
69 |
+
if weight:
|
70 |
+
w = F.softmax(gt_velocity.sum(dim=1).normal_(), dim=0)
|
71 |
+
else:
|
72 |
+
w = 1 / gt_velocity.shape[0]
|
73 |
+
|
74 |
+
v_diff = ((pr_velocity - gt_velocity).abs().sum(dim=-1) * w).sum(dim=-1).mean()
|
75 |
+
|
76 |
+
return v_diff
|
77 |
+
|
78 |
+
|
79 |
+
def LVD(gt_kps, pr_kps, symmetrical=False, weight=False):
|
80 |
+
gt_kps = gt_kps.squeeze()
|
81 |
+
pr_kps = pr_kps.squeeze()
|
82 |
+
if len(pr_kps.shape) == 4:
|
83 |
+
return Batch_LVD(gt_kps, pr_kps, symmetrical, weight)
|
84 |
+
# length = np.minimum(gt_kps.shape[0], pr_kps.shape[0])
|
85 |
+
length = gt_kps.shape[0]-10
|
86 |
+
# gt_kps = gt_kps[25:length]
|
87 |
+
# pr_kps = pr_kps[25:length] #(T, D)
|
88 |
+
# if pr_kps.shape[0] < gt_kps.shape[0]:
|
89 |
+
# pr_kps = np.pad(pr_kps, [[0, int(gt_kps.shape[0]-pr_kps.shape[0])], [0, 0]], mode='constant')
|
90 |
+
|
91 |
+
gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1)
|
92 |
+
pr_velocity = (pr_kps[1:] - pr_kps[:-1]).norm(p=2, dim=-1)
|
93 |
+
|
94 |
+
return (pr_velocity-gt_velocity).abs().sum(dim=-1).mean()
|
95 |
+
|
96 |
+
def diversity(kps):
|
97 |
+
'''
|
98 |
+
kps: bs, seq, dim
|
99 |
+
'''
|
100 |
+
dis_list = []
|
101 |
+
#the distance between each pair
|
102 |
+
for i in range(kps.shape[0]):
|
103 |
+
for j in range(i+1, kps.shape[0]):
|
104 |
+
seq_i = kps[i]
|
105 |
+
seq_j = kps[j]
|
106 |
+
|
107 |
+
dis = np.mean(np.abs(seq_i - seq_j))
|
108 |
+
dis_list.append(dis)
|
109 |
+
return np.mean(dis_list)
|
evaluation/mode_transition.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
sys.path.append(os.getcwd())
|
4 |
+
|
5 |
+
from glob import glob
|
6 |
+
|
7 |
+
from argparse import ArgumentParser
|
8 |
+
import json
|
9 |
+
|
10 |
+
from evaluation.util import *
|
11 |
+
from evaluation.metrics import *
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
parser = ArgumentParser()
|
15 |
+
parser.add_argument('--speaker', required=True, type=str)
|
16 |
+
parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
|
17 |
+
args = parser.parse_args()
|
18 |
+
|
19 |
+
speaker = args.speaker
|
20 |
+
test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
|
21 |
+
|
22 |
+
precision_list=[]
|
23 |
+
recall_list=[]
|
24 |
+
accuracy_list=[]
|
25 |
+
|
26 |
+
for aud in tqdm(test_audios):
|
27 |
+
base_name = os.path.splitext(aud)[0]
|
28 |
+
gt_path = get_full_path(aud, speaker, 'val')
|
29 |
+
_, gt_poses, _ = get_gts(gt_path)
|
30 |
+
if gt_poses.shape[0] < 50:
|
31 |
+
continue
|
32 |
+
gt_poses = gt_poses[np.newaxis,...]
|
33 |
+
# print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
|
34 |
+
for post_fix in args.post_fix:
|
35 |
+
pred_path = base_name + '_'+post_fix+'.json'
|
36 |
+
pred_poses = np.array(json.load(open(pred_path)))
|
37 |
+
# print(pred_poses.shape)#(B, seq_len, 108)
|
38 |
+
pred_poses = cvt25(pred_poses, gt_poses)
|
39 |
+
# print(pred_poses.shape)#(B, seq, pose_dim)
|
40 |
+
|
41 |
+
gt_valid_points = valid_points(gt_poses)
|
42 |
+
pred_valid_points = valid_points(pred_poses)
|
43 |
+
|
44 |
+
# print(gt_valid_points.shape, pred_valid_points.shape)
|
45 |
+
|
46 |
+
gt_mode_transition_seq = mode_transition_seq(gt_valid_points, speaker)#(B, N)
|
47 |
+
pred_mode_transition_seq = mode_transition_seq(pred_valid_points, speaker)#(B, N)
|
48 |
+
|
49 |
+
# baseline = np.random.randint(0, 2, size=pred_mode_transition_seq.shape)
|
50 |
+
# pred_mode_transition_seq = baseline
|
51 |
+
precision, recall, accuracy = mode_transition_consistency(pred_mode_transition_seq, gt_mode_transition_seq)
|
52 |
+
precision_list.append(precision)
|
53 |
+
recall_list.append(recall)
|
54 |
+
accuracy_list.append(accuracy)
|
55 |
+
print(len(precision_list), len(recall_list), len(accuracy_list))
|
56 |
+
precision_list = np.mean(precision_list)
|
57 |
+
recall_list = np.mean(recall_list)
|
58 |
+
accuracy_list = np.mean(accuracy_list)
|
59 |
+
|
60 |
+
print('precision, recall, accu:', precision_list, recall_list, accuracy_list)
|
evaluation/peak_velocity.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
sys.path.append(os.getcwd())
|
4 |
+
|
5 |
+
from glob import glob
|
6 |
+
|
7 |
+
from argparse import ArgumentParser
|
8 |
+
import json
|
9 |
+
|
10 |
+
from evaluation.util import *
|
11 |
+
from evaluation.metrics import *
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
parser = ArgumentParser()
|
15 |
+
parser.add_argument('--speaker', required=True, type=str)
|
16 |
+
parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
|
17 |
+
args = parser.parse_args()
|
18 |
+
|
19 |
+
speaker = args.speaker
|
20 |
+
test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
|
21 |
+
|
22 |
+
gt_consistency_list=[]
|
23 |
+
pred_consistency_list=[]
|
24 |
+
|
25 |
+
for aud in tqdm(test_audios):
|
26 |
+
base_name = os.path.splitext(aud)[0]
|
27 |
+
gt_path = get_full_path(aud, speaker, 'val')
|
28 |
+
_, gt_poses, _ = get_gts(gt_path)
|
29 |
+
gt_poses = gt_poses[np.newaxis,...]
|
30 |
+
# print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
|
31 |
+
for post_fix in args.post_fix:
|
32 |
+
pred_path = base_name + '_'+post_fix+'.json'
|
33 |
+
pred_poses = np.array(json.load(open(pred_path)))
|
34 |
+
# print(pred_poses.shape)#(B, seq_len, 108)
|
35 |
+
pred_poses = cvt25(pred_poses, gt_poses)
|
36 |
+
# print(pred_poses.shape)#(B, seq, pose_dim)
|
37 |
+
|
38 |
+
gt_valid_points = hand_points(gt_poses)
|
39 |
+
pred_valid_points = hand_points(pred_poses)
|
40 |
+
|
41 |
+
gt_velocity = peak_velocity(gt_valid_points, order=2)
|
42 |
+
pred_velocity = peak_velocity(pred_valid_points, order=2)
|
43 |
+
|
44 |
+
gt_consistency = velocity_consistency(gt_velocity, pred_velocity)
|
45 |
+
pred_consistency = velocity_consistency(pred_velocity, gt_velocity)
|
46 |
+
|
47 |
+
gt_consistency_list.append(gt_consistency)
|
48 |
+
pred_consistency_list.append(pred_consistency)
|
49 |
+
|
50 |
+
gt_consistency_list = np.concatenate(gt_consistency_list)
|
51 |
+
pred_consistency_list = np.concatenate(pred_consistency_list)
|
52 |
+
|
53 |
+
print(gt_consistency_list.max(), gt_consistency_list.min())
|
54 |
+
print(pred_consistency_list.max(), pred_consistency_list.min())
|
55 |
+
print(np.mean(gt_consistency_list), np.mean(pred_consistency_list))
|
56 |
+
print(np.std(gt_consistency_list), np.std(pred_consistency_list))
|
57 |
+
|
58 |
+
draw_cdf(gt_consistency_list, save_name='%s_gt.jpg'%(speaker), color='slateblue')
|
59 |
+
draw_cdf(pred_consistency_list, save_name='%s_pred.jpg'%(speaker), color='lightskyblue')
|
60 |
+
|
61 |
+
to_excel(gt_consistency_list, '%s_gt.xlsx'%(speaker))
|
62 |
+
to_excel(pred_consistency_list, '%s_pred.xlsx'%(speaker))
|
63 |
+
|
64 |
+
np.save('%s_gt.npy'%(speaker), gt_consistency_list)
|
65 |
+
np.save('%s_pred.npy'%(speaker), pred_consistency_list)
|
evaluation/util.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from glob import glob
|
3 |
+
import numpy as np
|
4 |
+
import json
|
5 |
+
from matplotlib import pyplot as plt
|
6 |
+
import pandas as pd
|
7 |
+
def get_gts(clip):
|
8 |
+
'''
|
9 |
+
clip: abs path to the clip dir
|
10 |
+
'''
|
11 |
+
keypoints_files = sorted(glob(os.path.join(clip, 'keypoints_new/person_1')+'/*.json'))
|
12 |
+
|
13 |
+
upper_body_points = list(np.arange(0, 25))
|
14 |
+
poses = []
|
15 |
+
confs = []
|
16 |
+
neck_to_nose_len = []
|
17 |
+
mean_position = []
|
18 |
+
for kp_file in keypoints_files:
|
19 |
+
kp_load = json.load(open(kp_file, 'r'))['people'][0]
|
20 |
+
posepts = kp_load['pose_keypoints_2d']
|
21 |
+
lhandpts = kp_load['hand_left_keypoints_2d']
|
22 |
+
rhandpts = kp_load['hand_right_keypoints_2d']
|
23 |
+
facepts = kp_load['face_keypoints_2d']
|
24 |
+
|
25 |
+
neck = np.array(posepts).reshape(-1,3)[1]
|
26 |
+
nose = np.array(posepts).reshape(-1,3)[0]
|
27 |
+
x_offset = abs(neck[0]-nose[0])
|
28 |
+
y_offset = abs(neck[1]-nose[1])
|
29 |
+
neck_to_nose_len.append(y_offset)
|
30 |
+
mean_position.append([neck[0],neck[1]])
|
31 |
+
|
32 |
+
keypoints=np.array(posepts+lhandpts+rhandpts+facepts).reshape(-1,3)[:,:2]
|
33 |
+
|
34 |
+
upper_body = keypoints[upper_body_points, :]
|
35 |
+
hand_points = keypoints[25:, :]
|
36 |
+
keypoints = np.vstack([upper_body, hand_points])
|
37 |
+
|
38 |
+
poses.append(keypoints)
|
39 |
+
|
40 |
+
if len(neck_to_nose_len) > 0:
|
41 |
+
scale_factor = np.mean(neck_to_nose_len)
|
42 |
+
else:
|
43 |
+
raise ValueError(clip)
|
44 |
+
mean_position = np.mean(np.array(mean_position), axis=0)
|
45 |
+
|
46 |
+
unlocalized_poses = np.array(poses).copy()
|
47 |
+
localized_poses = []
|
48 |
+
for i in range(len(poses)):
|
49 |
+
keypoints = poses[i]
|
50 |
+
neck = keypoints[1].copy()
|
51 |
+
|
52 |
+
keypoints[:, 0] = (keypoints[:, 0] - neck[0]) / scale_factor
|
53 |
+
keypoints[:, 1] = (keypoints[:, 1] - neck[1]) / scale_factor
|
54 |
+
localized_poses.append(keypoints.reshape(-1))
|
55 |
+
|
56 |
+
localized_poses=np.array(localized_poses)
|
57 |
+
return unlocalized_poses, localized_poses, (scale_factor, mean_position)
|
58 |
+
|
59 |
+
def get_full_path(wav_name, speaker, split):
|
60 |
+
'''
|
61 |
+
get clip path from aud file
|
62 |
+
'''
|
63 |
+
wav_name = os.path.basename(wav_name)
|
64 |
+
wav_name = os.path.splitext(wav_name)[0]
|
65 |
+
clip_name, vid_name = wav_name[:10], wav_name[11:]
|
66 |
+
|
67 |
+
full_path = os.path.join('pose_dataset/videos/', speaker, 'clips', vid_name, 'images/half', split, clip_name)
|
68 |
+
|
69 |
+
assert os.path.isdir(full_path), full_path
|
70 |
+
|
71 |
+
return full_path
|
72 |
+
|
73 |
+
def smooth(res):
|
74 |
+
'''
|
75 |
+
res: (B, seq_len, pose_dim)
|
76 |
+
'''
|
77 |
+
window = [res[:, 7, :], res[:, 8, :], res[:, 9, :], res[:, 10, :], res[:, 11, :], res[:, 12, :]]
|
78 |
+
w_size=7
|
79 |
+
for i in range(10, res.shape[1]-3):
|
80 |
+
window.append(res[:, i+3, :])
|
81 |
+
if len(window) > w_size:
|
82 |
+
window = window[1:]
|
83 |
+
|
84 |
+
if (i%25) in [22, 23, 24, 0, 1, 2, 3]:
|
85 |
+
res[:, i, :] = np.mean(window, axis=1)
|
86 |
+
|
87 |
+
return res
|
88 |
+
|
89 |
+
def cvt25(pred_poses, gt_poses=None):
|
90 |
+
'''
|
91 |
+
gt_poses: (1, seq_len, 270), 135 *2
|
92 |
+
pred_poses: (B, seq_len, 108), 54 * 2
|
93 |
+
'''
|
94 |
+
if gt_poses is None:
|
95 |
+
gt_poses = np.zeros_like(pred_poses)
|
96 |
+
else:
|
97 |
+
gt_poses = gt_poses.repeat(pred_poses.shape[0], axis=0)
|
98 |
+
|
99 |
+
length = min(pred_poses.shape[1], gt_poses.shape[1])
|
100 |
+
pred_poses = pred_poses[:, :length, :]
|
101 |
+
gt_poses = gt_poses[:, :length, :]
|
102 |
+
gt_poses = gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1, 2)
|
103 |
+
pred_poses = pred_poses.reshape(pred_poses.shape[0], pred_poses.shape[1], -1, 2)
|
104 |
+
|
105 |
+
gt_poses[:, :, [1, 2, 3, 4, 5, 6, 7], :] = pred_poses[:, :, 1:8, :]
|
106 |
+
gt_poses[:, :, 25:25+21+21, :] = pred_poses[:, :, 12:, :]
|
107 |
+
|
108 |
+
return gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1)
|
109 |
+
|
110 |
+
def hand_points(seq):
|
111 |
+
'''
|
112 |
+
seq: (B, seq_len, 135*2)
|
113 |
+
hands only
|
114 |
+
'''
|
115 |
+
hand_idx = [1, 2, 3, 4,5 ,6,7] + list(range(25, 25+21+21))
|
116 |
+
seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2)
|
117 |
+
return seq[:, :, hand_idx, :].reshape(seq.shape[0], seq.shape[1], -1)
|
118 |
+
|
119 |
+
def valid_points(seq):
|
120 |
+
'''
|
121 |
+
hands with some head points
|
122 |
+
'''
|
123 |
+
valid_idx = [0, 1, 2, 3, 4,5 ,6,7, 8, 9, 10, 11] + list(range(25, 25+21+21))
|
124 |
+
seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2)
|
125 |
+
|
126 |
+
seq = seq[:, :, valid_idx, :].reshape(seq.shape[0], seq.shape[1], -1)
|
127 |
+
assert seq.shape[-1] == 108, seq.shape
|
128 |
+
return seq
|
129 |
+
|
130 |
+
def draw_cdf(seq, save_name='cdf.jpg', color='slatebule'):
|
131 |
+
plt.figure()
|
132 |
+
plt.hist(seq, bins=100, range=(0, 100), color=color)
|
133 |
+
plt.savefig(save_name)
|
134 |
+
|
135 |
+
def to_excel(seq, save_name='res.xlsx'):
|
136 |
+
'''
|
137 |
+
seq: (T)
|
138 |
+
'''
|
139 |
+
df = pd.DataFrame(seq)
|
140 |
+
writer = pd.ExcelWriter(save_name)
|
141 |
+
df.to_excel(writer, 'sheet1')
|
142 |
+
writer.save()
|
143 |
+
writer.close()
|
144 |
+
|
145 |
+
|
146 |
+
if __name__ == '__main__':
|
147 |
+
random_data = np.random.randint(0, 10, 100)
|
148 |
+
draw_cdf(random_data)
|
losses/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .losses import *
|
losses/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (140 Bytes). View file
|
|
losses/__pycache__/losses.cpython-37.pyc
ADDED
Binary file (3.5 kB). View file
|
|
losses/losses.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
sys.path.append(os.getcwd())
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
class KeypointLoss(nn.Module):
|
12 |
+
def __init__(self):
|
13 |
+
super(KeypointLoss, self).__init__()
|
14 |
+
|
15 |
+
def forward(self, pred_seq, gt_seq, gt_conf=None):
|
16 |
+
#pred_seq: (B, C, T)
|
17 |
+
if gt_conf is not None:
|
18 |
+
gt_conf = gt_conf >= 0.01
|
19 |
+
return F.mse_loss(pred_seq[gt_conf], gt_seq[gt_conf], reduction='mean')
|
20 |
+
else:
|
21 |
+
return F.mse_loss(pred_seq, gt_seq)
|
22 |
+
|
23 |
+
|
24 |
+
class KLLoss(nn.Module):
|
25 |
+
def __init__(self, kl_tolerance):
|
26 |
+
super(KLLoss, self).__init__()
|
27 |
+
self.kl_tolerance = kl_tolerance
|
28 |
+
|
29 |
+
def forward(self, mu, var, mul=1):
|
30 |
+
kl_tolerance = self.kl_tolerance * mul * var.shape[1] / 64
|
31 |
+
kld_loss = -0.5 * torch.sum(1 + var - mu**2 - var.exp(), dim=1)
|
32 |
+
# kld_loss = -0.5 * torch.sum(1 + (var-1) - (mu) ** 2 - (var-1).exp(), dim=1)
|
33 |
+
if self.kl_tolerance is not None:
|
34 |
+
# above_line = kld_loss[kld_loss > self.kl_tolerance]
|
35 |
+
# if len(above_line) > 0:
|
36 |
+
# kld_loss = torch.mean(kld_loss)
|
37 |
+
# else:
|
38 |
+
# kld_loss = 0
|
39 |
+
kld_loss = torch.where(kld_loss > kl_tolerance, kld_loss, torch.tensor(kl_tolerance, device='cuda'))
|
40 |
+
# else:
|
41 |
+
kld_loss = torch.mean(kld_loss)
|
42 |
+
return kld_loss
|
43 |
+
|
44 |
+
|
45 |
+
class L2KLLoss(nn.Module):
|
46 |
+
def __init__(self, kl_tolerance):
|
47 |
+
super(L2KLLoss, self).__init__()
|
48 |
+
self.kl_tolerance = kl_tolerance
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
# TODO: check
|
52 |
+
kld_loss = torch.sum(x ** 2, dim=1)
|
53 |
+
if self.kl_tolerance is not None:
|
54 |
+
above_line = kld_loss[kld_loss > self.kl_tolerance]
|
55 |
+
if len(above_line) > 0:
|
56 |
+
kld_loss = torch.mean(kld_loss)
|
57 |
+
else:
|
58 |
+
kld_loss = 0
|
59 |
+
else:
|
60 |
+
kld_loss = torch.mean(kld_loss)
|
61 |
+
return kld_loss
|
62 |
+
|
63 |
+
class L2RegLoss(nn.Module):
|
64 |
+
def __init__(self):
|
65 |
+
super(L2RegLoss, self).__init__()
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
#TODO: check
|
69 |
+
return torch.sum(x**2)
|
70 |
+
|
71 |
+
|
72 |
+
class L2Loss(nn.Module):
|
73 |
+
def __init__(self):
|
74 |
+
super(L2Loss, self).__init__()
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
# TODO: check
|
78 |
+
return torch.sum(x ** 2)
|
79 |
+
|
80 |
+
|
81 |
+
class AudioLoss(nn.Module):
|
82 |
+
def __init__(self):
|
83 |
+
super(AudioLoss, self).__init__()
|
84 |
+
|
85 |
+
def forward(self, dynamics, gt_poses):
|
86 |
+
#pay attention, normalized
|
87 |
+
mean = torch.mean(gt_poses, dim=-1).unsqueeze(-1)
|
88 |
+
gt = gt_poses - mean
|
89 |
+
return F.mse_loss(dynamics, gt)
|
90 |
+
|
91 |
+
L1Loss = nn.L1Loss
|
nets/LS3DCG.py
ADDED
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
not exactly the same as the official repo but the results are good
|
3 |
+
'''
|
4 |
+
import sys
|
5 |
+
import os
|
6 |
+
|
7 |
+
from data_utils.lower_body import c_index_3d, c_index_6d
|
8 |
+
|
9 |
+
sys.path.append(os.getcwd())
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.optim as optim
|
15 |
+
import torch.nn.functional as F
|
16 |
+
import math
|
17 |
+
|
18 |
+
from nets.base import TrainWrapperBaseClass
|
19 |
+
from nets.layers import SeqEncoder1D
|
20 |
+
from losses import KeypointLoss, L1Loss, KLLoss
|
21 |
+
from data_utils.utils import get_melspec, get_mfcc_psf, get_mfcc_ta
|
22 |
+
from nets.utils import denormalize
|
23 |
+
|
24 |
+
class Conv1d_tf(nn.Conv1d):
|
25 |
+
"""
|
26 |
+
Conv1d with the padding behavior from TF
|
27 |
+
modified from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(self, *args, **kwargs):
|
31 |
+
super(Conv1d_tf, self).__init__(*args, **kwargs)
|
32 |
+
self.padding = kwargs.get("padding", "same")
|
33 |
+
|
34 |
+
def _compute_padding(self, input, dim):
|
35 |
+
input_size = input.size(dim + 2)
|
36 |
+
filter_size = self.weight.size(dim + 2)
|
37 |
+
effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
|
38 |
+
out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
|
39 |
+
total_padding = max(
|
40 |
+
0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
|
41 |
+
)
|
42 |
+
additional_padding = int(total_padding % 2 != 0)
|
43 |
+
|
44 |
+
return additional_padding, total_padding
|
45 |
+
|
46 |
+
def forward(self, input):
|
47 |
+
if self.padding == "VALID":
|
48 |
+
return F.conv1d(
|
49 |
+
input,
|
50 |
+
self.weight,
|
51 |
+
self.bias,
|
52 |
+
self.stride,
|
53 |
+
padding=0,
|
54 |
+
dilation=self.dilation,
|
55 |
+
groups=self.groups,
|
56 |
+
)
|
57 |
+
rows_odd, padding_rows = self._compute_padding(input, dim=0)
|
58 |
+
if rows_odd:
|
59 |
+
input = F.pad(input, [0, rows_odd])
|
60 |
+
|
61 |
+
return F.conv1d(
|
62 |
+
input,
|
63 |
+
self.weight,
|
64 |
+
self.bias,
|
65 |
+
self.stride,
|
66 |
+
padding=(padding_rows // 2),
|
67 |
+
dilation=self.dilation,
|
68 |
+
groups=self.groups,
|
69 |
+
)
|
70 |
+
|
71 |
+
|
72 |
+
def ConvNormRelu(in_channels, out_channels, type='1d', downsample=False, k=None, s=None, norm='bn', padding='valid'):
|
73 |
+
if k is None and s is None:
|
74 |
+
if not downsample:
|
75 |
+
k = 3
|
76 |
+
s = 1
|
77 |
+
else:
|
78 |
+
k = 4
|
79 |
+
s = 2
|
80 |
+
|
81 |
+
if type == '1d':
|
82 |
+
conv_block = Conv1d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
|
83 |
+
if norm == 'bn':
|
84 |
+
norm_block = nn.BatchNorm1d(out_channels)
|
85 |
+
elif norm == 'ln':
|
86 |
+
norm_block = nn.LayerNorm(out_channels)
|
87 |
+
elif type == '2d':
|
88 |
+
conv_block = Conv2d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
|
89 |
+
norm_block = nn.BatchNorm2d(out_channels)
|
90 |
+
else:
|
91 |
+
assert False
|
92 |
+
|
93 |
+
return nn.Sequential(
|
94 |
+
conv_block,
|
95 |
+
norm_block,
|
96 |
+
nn.LeakyReLU(0.2, True)
|
97 |
+
)
|
98 |
+
|
99 |
+
class Decoder(nn.Module):
|
100 |
+
def __init__(self, in_ch, out_ch):
|
101 |
+
super(Decoder, self).__init__()
|
102 |
+
self.up1 = nn.Sequential(
|
103 |
+
ConvNormRelu(in_ch // 2 + in_ch, in_ch // 2),
|
104 |
+
ConvNormRelu(in_ch // 2, in_ch // 2),
|
105 |
+
nn.Upsample(scale_factor=2, mode='nearest')
|
106 |
+
)
|
107 |
+
self.up2 = nn.Sequential(
|
108 |
+
ConvNormRelu(in_ch // 4 + in_ch // 2, in_ch // 4),
|
109 |
+
ConvNormRelu(in_ch // 4, in_ch // 4),
|
110 |
+
nn.Upsample(scale_factor=2, mode='nearest')
|
111 |
+
)
|
112 |
+
self.up3 = nn.Sequential(
|
113 |
+
ConvNormRelu(in_ch // 8 + in_ch // 4, in_ch // 8),
|
114 |
+
ConvNormRelu(in_ch // 8, in_ch // 8),
|
115 |
+
nn.Conv1d(in_ch // 8, out_ch, 1, 1)
|
116 |
+
)
|
117 |
+
|
118 |
+
def forward(self, x, x1, x2, x3):
|
119 |
+
x = F.interpolate(x, x3.shape[2])
|
120 |
+
x = torch.cat([x, x3], dim=1)
|
121 |
+
x = self.up1(x)
|
122 |
+
x = F.interpolate(x, x2.shape[2])
|
123 |
+
x = torch.cat([x, x2], dim=1)
|
124 |
+
x = self.up2(x)
|
125 |
+
x = F.interpolate(x, x1.shape[2])
|
126 |
+
x = torch.cat([x, x1], dim=1)
|
127 |
+
x = self.up3(x)
|
128 |
+
return x
|
129 |
+
|
130 |
+
|
131 |
+
class EncoderDecoder(nn.Module):
|
132 |
+
def __init__(self, n_frames, each_dim):
|
133 |
+
super().__init__()
|
134 |
+
self.n_frames = n_frames
|
135 |
+
|
136 |
+
self.down1 = nn.Sequential(
|
137 |
+
ConvNormRelu(64, 64, '1d', False),
|
138 |
+
ConvNormRelu(64, 128, '1d', False),
|
139 |
+
)
|
140 |
+
self.down2 = nn.Sequential(
|
141 |
+
ConvNormRelu(128, 128, '1d', False),
|
142 |
+
ConvNormRelu(128, 256, '1d', False),
|
143 |
+
)
|
144 |
+
self.down3 = nn.Sequential(
|
145 |
+
ConvNormRelu(256, 256, '1d', False),
|
146 |
+
ConvNormRelu(256, 512, '1d', False),
|
147 |
+
)
|
148 |
+
self.down4 = nn.Sequential(
|
149 |
+
ConvNormRelu(512, 512, '1d', False),
|
150 |
+
ConvNormRelu(512, 1024, '1d', False),
|
151 |
+
)
|
152 |
+
|
153 |
+
self.down = nn.MaxPool1d(kernel_size=2)
|
154 |
+
self.up = nn.Upsample(scale_factor=2, mode='nearest')
|
155 |
+
|
156 |
+
self.face_decoder = Decoder(1024, each_dim[0] + each_dim[3])
|
157 |
+
self.body_decoder = Decoder(1024, each_dim[1])
|
158 |
+
self.hand_decoder = Decoder(1024, each_dim[2])
|
159 |
+
|
160 |
+
def forward(self, spectrogram, time_steps=None):
|
161 |
+
if time_steps is None:
|
162 |
+
time_steps = self.n_frames
|
163 |
+
|
164 |
+
x1 = self.down1(spectrogram)
|
165 |
+
x = self.down(x1)
|
166 |
+
x2 = self.down2(x)
|
167 |
+
x = self.down(x2)
|
168 |
+
x3 = self.down3(x)
|
169 |
+
x = self.down(x3)
|
170 |
+
x = self.down4(x)
|
171 |
+
x = self.up(x)
|
172 |
+
|
173 |
+
face = self.face_decoder(x, x1, x2, x3)
|
174 |
+
body = self.body_decoder(x, x1, x2, x3)
|
175 |
+
hand = self.hand_decoder(x, x1, x2, x3)
|
176 |
+
|
177 |
+
return face, body, hand
|
178 |
+
|
179 |
+
|
180 |
+
class Generator(nn.Module):
|
181 |
+
def __init__(self,
|
182 |
+
each_dim,
|
183 |
+
training=False,
|
184 |
+
device=None
|
185 |
+
):
|
186 |
+
super().__init__()
|
187 |
+
|
188 |
+
self.training = training
|
189 |
+
self.device = device
|
190 |
+
|
191 |
+
self.encoderdecoder = EncoderDecoder(15, each_dim)
|
192 |
+
|
193 |
+
def forward(self, in_spec, time_steps=None):
|
194 |
+
if time_steps is not None:
|
195 |
+
self.gen_length = time_steps
|
196 |
+
|
197 |
+
face, body, hand = self.encoderdecoder(in_spec)
|
198 |
+
out = torch.cat([face, body, hand], dim=1)
|
199 |
+
out = out.transpose(1, 2)
|
200 |
+
|
201 |
+
return out
|
202 |
+
|
203 |
+
|
204 |
+
class Discriminator(nn.Module):
|
205 |
+
def __init__(self, input_dim):
|
206 |
+
super().__init__()
|
207 |
+
self.net = nn.Sequential(
|
208 |
+
ConvNormRelu(input_dim, 128, '1d'),
|
209 |
+
ConvNormRelu(128, 256, '1d'),
|
210 |
+
nn.MaxPool1d(kernel_size=2),
|
211 |
+
ConvNormRelu(256, 256, '1d'),
|
212 |
+
ConvNormRelu(256, 512, '1d'),
|
213 |
+
nn.MaxPool1d(kernel_size=2),
|
214 |
+
ConvNormRelu(512, 512, '1d'),
|
215 |
+
ConvNormRelu(512, 1024, '1d'),
|
216 |
+
nn.MaxPool1d(kernel_size=2),
|
217 |
+
nn.Conv1d(1024, 1, 1, 1),
|
218 |
+
nn.Sigmoid()
|
219 |
+
)
|
220 |
+
|
221 |
+
def forward(self, x):
|
222 |
+
x = x.transpose(1, 2)
|
223 |
+
|
224 |
+
out = self.net(x)
|
225 |
+
return out
|
226 |
+
|
227 |
+
|
228 |
+
class TrainWrapper(TrainWrapperBaseClass):
|
229 |
+
def __init__(self, args, config) -> None:
|
230 |
+
self.args = args
|
231 |
+
self.config = config
|
232 |
+
self.device = torch.device(self.args.gpu)
|
233 |
+
self.global_step = 0
|
234 |
+
self.convert_to_6d = self.config.Data.pose.convert_to_6d
|
235 |
+
self.init_params()
|
236 |
+
|
237 |
+
self.generator = Generator(
|
238 |
+
each_dim=self.each_dim,
|
239 |
+
training=not self.args.infer,
|
240 |
+
device=self.device,
|
241 |
+
).to(self.device)
|
242 |
+
self.discriminator = Discriminator(
|
243 |
+
input_dim=self.each_dim[1] + self.each_dim[2] + 64
|
244 |
+
).to(self.device)
|
245 |
+
if self.convert_to_6d:
|
246 |
+
self.c_index = c_index_6d
|
247 |
+
else:
|
248 |
+
self.c_index = c_index_3d
|
249 |
+
self.MSELoss = KeypointLoss().to(self.device)
|
250 |
+
self.L1Loss = L1Loss().to(self.device)
|
251 |
+
super().__init__(args, config)
|
252 |
+
|
253 |
+
def init_params(self):
|
254 |
+
scale = 1
|
255 |
+
|
256 |
+
global_orient = round(0 * scale)
|
257 |
+
leye_pose = reye_pose = round(0 * scale)
|
258 |
+
jaw_pose = round(3 * scale)
|
259 |
+
body_pose = round((63 - 24) * scale)
|
260 |
+
left_hand_pose = right_hand_pose = round(45 * scale)
|
261 |
+
|
262 |
+
expression = 100
|
263 |
+
|
264 |
+
b_j = 0
|
265 |
+
jaw_dim = jaw_pose
|
266 |
+
b_e = b_j + jaw_dim
|
267 |
+
eye_dim = leye_pose + reye_pose
|
268 |
+
b_b = b_e + eye_dim
|
269 |
+
body_dim = global_orient + body_pose
|
270 |
+
b_h = b_b + body_dim
|
271 |
+
hand_dim = left_hand_pose + right_hand_pose
|
272 |
+
b_f = b_h + hand_dim
|
273 |
+
face_dim = expression
|
274 |
+
|
275 |
+
self.dim_list = [b_j, b_e, b_b, b_h, b_f]
|
276 |
+
self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
|
277 |
+
self.pose = int(self.full_dim / round(3 * scale))
|
278 |
+
self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
|
279 |
+
|
280 |
+
def __call__(self, bat):
|
281 |
+
assert (not self.args.infer), "infer mode"
|
282 |
+
self.global_step += 1
|
283 |
+
|
284 |
+
loss_dict = {}
|
285 |
+
|
286 |
+
aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
|
287 |
+
expression = bat['expression'].to(self.device).to(torch.float32)
|
288 |
+
jaw = poses[:, :3, :]
|
289 |
+
poses = poses[:, self.c_index, :]
|
290 |
+
|
291 |
+
pred = self.generator(in_spec=aud)
|
292 |
+
|
293 |
+
D_loss, D_loss_dict = self.get_loss(
|
294 |
+
pred_poses=pred.detach(),
|
295 |
+
gt_poses=poses,
|
296 |
+
aud=aud,
|
297 |
+
mode='training_D',
|
298 |
+
)
|
299 |
+
|
300 |
+
self.discriminator_optimizer.zero_grad()
|
301 |
+
D_loss.backward()
|
302 |
+
self.discriminator_optimizer.step()
|
303 |
+
|
304 |
+
G_loss, G_loss_dict = self.get_loss(
|
305 |
+
pred_poses=pred,
|
306 |
+
gt_poses=poses,
|
307 |
+
aud=aud,
|
308 |
+
expression=expression,
|
309 |
+
jaw=jaw,
|
310 |
+
mode='training_G',
|
311 |
+
)
|
312 |
+
self.generator_optimizer.zero_grad()
|
313 |
+
G_loss.backward()
|
314 |
+
self.generator_optimizer.step()
|
315 |
+
|
316 |
+
total_loss = None
|
317 |
+
loss_dict = {}
|
318 |
+
for key in list(D_loss_dict.keys()) + list(G_loss_dict.keys()):
|
319 |
+
loss_dict[key] = G_loss_dict.get(key, 0) + D_loss_dict.get(key, 0)
|
320 |
+
|
321 |
+
return total_loss, loss_dict
|
322 |
+
|
323 |
+
def get_loss(self,
|
324 |
+
pred_poses,
|
325 |
+
gt_poses,
|
326 |
+
aud=None,
|
327 |
+
jaw=None,
|
328 |
+
expression=None,
|
329 |
+
mode='training_G',
|
330 |
+
):
|
331 |
+
loss_dict = {}
|
332 |
+
aud = aud.transpose(1, 2)
|
333 |
+
gt_poses = gt_poses.transpose(1, 2)
|
334 |
+
gt_aud = torch.cat([gt_poses, aud], dim=2)
|
335 |
+
pred_aud = torch.cat([pred_poses[:, :, 103:], aud], dim=2)
|
336 |
+
|
337 |
+
if mode == 'training_D':
|
338 |
+
dis_real = self.discriminator(gt_aud)
|
339 |
+
dis_fake = self.discriminator(pred_aud)
|
340 |
+
dis_error = self.MSELoss(torch.ones_like(dis_real).to(self.device), dis_real) + self.MSELoss(
|
341 |
+
torch.zeros_like(dis_fake).to(self.device), dis_fake)
|
342 |
+
loss_dict['dis'] = dis_error
|
343 |
+
|
344 |
+
return dis_error, loss_dict
|
345 |
+
elif mode == 'training_G':
|
346 |
+
jaw_loss = self.L1Loss(pred_poses[:, :, :3], jaw.transpose(1, 2))
|
347 |
+
face_loss = self.MSELoss(pred_poses[:, :, 3:103], expression.transpose(1, 2))
|
348 |
+
body_loss = self.L1Loss(pred_poses[:, :, 103:142], gt_poses[:, :, :39])
|
349 |
+
hand_loss = self.L1Loss(pred_poses[:, :, 142:], gt_poses[:, :, 39:])
|
350 |
+
l1_loss = jaw_loss + face_loss + body_loss + hand_loss
|
351 |
+
|
352 |
+
dis_output = self.discriminator(pred_aud)
|
353 |
+
gen_error = self.MSELoss(torch.ones_like(dis_output).to(self.device), dis_output)
|
354 |
+
gen_loss = self.config.Train.weights.keypoint_loss_weight * l1_loss + self.config.Train.weights.gan_loss_weight * gen_error
|
355 |
+
|
356 |
+
loss_dict['gen'] = gen_error
|
357 |
+
loss_dict['jaw_loss'] = jaw_loss
|
358 |
+
loss_dict['face_loss'] = face_loss
|
359 |
+
loss_dict['body_loss'] = body_loss
|
360 |
+
loss_dict['hand_loss'] = hand_loss
|
361 |
+
return gen_loss, loss_dict
|
362 |
+
else:
|
363 |
+
raise ValueError(mode)
|
364 |
+
|
365 |
+
def infer_on_audio(self, aud_fn, fps=30, initial_pose=None, norm_stats=None, id=None, B=1, **kwargs):
|
366 |
+
output = []
|
367 |
+
assert self.args.infer, "train mode"
|
368 |
+
self.generator.eval()
|
369 |
+
|
370 |
+
if self.config.Data.pose.normalization:
|
371 |
+
assert norm_stats is not None
|
372 |
+
data_mean = norm_stats[0]
|
373 |
+
data_std = norm_stats[1]
|
374 |
+
|
375 |
+
pre_length = self.config.Data.pose.pre_pose_length
|
376 |
+
generate_length = self.config.Data.pose.generate_length
|
377 |
+
# assert pre_length == initial_pose.shape[-1]
|
378 |
+
# pre_poses = initial_pose.permute(0, 2, 1).to(self.device).to(torch.float32)
|
379 |
+
# B = pre_poses.shape[0]
|
380 |
+
|
381 |
+
aud_feat = get_mfcc_ta(aud_fn, sr=22000, fps=fps, smlpx=True, type='mfcc').transpose(1, 0)
|
382 |
+
num_poses_to_generate = aud_feat.shape[-1]
|
383 |
+
aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
|
384 |
+
aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)
|
385 |
+
|
386 |
+
with torch.no_grad():
|
387 |
+
pred_poses = self.generator(aud_feat)
|
388 |
+
pred_poses = pred_poses.cpu().numpy()
|
389 |
+
output = pred_poses.squeeze()
|
390 |
+
|
391 |
+
return output
|
392 |
+
|
393 |
+
def generate(self, aud, id):
|
394 |
+
self.generator.eval()
|
395 |
+
pred_poses = self.generator(aud)
|
396 |
+
return pred_poses
|
397 |
+
|
398 |
+
|
399 |
+
if __name__ == '__main__':
|
400 |
+
from trainer.options import parse_args
|
401 |
+
|
402 |
+
parser = parse_args()
|
403 |
+
args = parser.parse_args(
|
404 |
+
['--exp_name', '0', '--data_root', '0', '--speakers', '0', '--pre_pose_length', '4', '--generate_length', '64',
|
405 |
+
'--infer'])
|
406 |
+
|
407 |
+
generator = TrainWrapper(args)
|
408 |
+
|
409 |
+
aud_fn = '../sample_audio/jon.wav'
|
410 |
+
initial_pose = torch.randn(64, 108, 4)
|
411 |
+
norm_stats = (np.random.randn(108), np.random.randn(108))
|
412 |
+
output = generator.infer_on_audio(aud_fn, initial_pose, norm_stats)
|
413 |
+
|
414 |
+
print(output.shape)
|
nets/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .smplx_face import TrainWrapper as s2g_face
|
2 |
+
from .smplx_body_vq import TrainWrapper as s2g_body_vq
|
3 |
+
from .smplx_body_pixel import TrainWrapper as s2g_body_pixel
|
4 |
+
from .body_ae import TrainWrapper as s2g_body_ae
|
5 |
+
from .LS3DCG import TrainWrapper as LS3DCG
|
6 |
+
from .base import TrainWrapperBaseClass
|
7 |
+
|
8 |
+
from .utils import normalize, denormalize
|
nets/__pycache__/LS3DCG.cpython-37.pyc
ADDED
Binary file (11.2 kB). View file
|
|
nets/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (431 Bytes). View file
|
|
nets/__pycache__/base.cpython-37.pyc
ADDED
Binary file (2.98 kB). View file
|
|
nets/__pycache__/body_ae.cpython-37.pyc
ADDED
Binary file (4.48 kB). View file
|
|
nets/__pycache__/init_model.cpython-37.pyc
ADDED
Binary file (520 Bytes). View file
|
|
nets/__pycache__/layers.cpython-37.pyc
ADDED
Binary file (22.7 kB). View file
|
|
nets/__pycache__/smplx_body_pixel.cpython-37.pyc
ADDED
Binary file (9.51 kB). View file
|
|
nets/__pycache__/smplx_body_vq.cpython-37.pyc
ADDED
Binary file (7.86 kB). View file
|
|
nets/__pycache__/smplx_face.cpython-37.pyc
ADDED
Binary file (5.96 kB). View file
|
|
nets/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (4.81 kB). View file
|
|
nets/base.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.optim as optim
|
4 |
+
|
5 |
+
class TrainWrapperBaseClass():
|
6 |
+
def __init__(self, args, config) -> None:
|
7 |
+
self.init_optimizer()
|
8 |
+
|
9 |
+
def init_optimizer(self) -> None:
|
10 |
+
print('using Adam')
|
11 |
+
self.generator_optimizer = optim.Adam(
|
12 |
+
self.generator.parameters(),
|
13 |
+
lr = self.config.Train.learning_rate.generator_learning_rate,
|
14 |
+
betas=[0.9, 0.999]
|
15 |
+
)
|
16 |
+
if self.discriminator is not None:
|
17 |
+
self.discriminator_optimizer = optim.Adam(
|
18 |
+
self.discriminator.parameters(),
|
19 |
+
lr = self.config.Train.learning_rate.discriminator_learning_rate,
|
20 |
+
betas=[0.9, 0.999]
|
21 |
+
)
|
22 |
+
|
23 |
+
def __call__(self, bat):
|
24 |
+
raise NotImplementedError
|
25 |
+
|
26 |
+
def get_loss(self, **kwargs):
|
27 |
+
raise NotImplementedError
|
28 |
+
|
29 |
+
def state_dict(self):
|
30 |
+
model_state = {
|
31 |
+
'generator': self.generator.state_dict(),
|
32 |
+
'generator_optim': self.generator_optimizer.state_dict(),
|
33 |
+
'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
|
34 |
+
'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
|
35 |
+
}
|
36 |
+
return model_state
|
37 |
+
|
38 |
+
def parameters(self):
|
39 |
+
return self.generator.parameters()
|
40 |
+
|
41 |
+
def load_state_dict(self, state_dict):
|
42 |
+
if 'generator' in state_dict:
|
43 |
+
self.generator.load_state_dict(state_dict['generator'])
|
44 |
+
else:
|
45 |
+
self.generator.load_state_dict(state_dict)
|
46 |
+
|
47 |
+
if 'generator_optim' in state_dict and self.generator_optimizer is not None:
|
48 |
+
self.generator_optimizer.load_state_dict(state_dict['generator_optim'])
|
49 |
+
|
50 |
+
if self.discriminator is not None:
|
51 |
+
self.discriminator.load_state_dict(state_dict['discriminator'])
|
52 |
+
|
53 |
+
if 'discriminator_optim' in state_dict and self.discriminator_optimizer is not None:
|
54 |
+
self.discriminator_optimizer.load_state_dict(state_dict['discriminator_optim'])
|
55 |
+
|
56 |
+
def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, **kwargs):
|
57 |
+
raise NotImplementedError
|
58 |
+
|
59 |
+
def init_params(self):
|
60 |
+
if self.config.Data.pose.convert_to_6d:
|
61 |
+
scale = 2
|
62 |
+
else:
|
63 |
+
scale = 1
|
64 |
+
|
65 |
+
global_orient = round(0 * scale)
|
66 |
+
leye_pose = reye_pose = round(0 * scale)
|
67 |
+
jaw_pose = round(0 * scale)
|
68 |
+
body_pose = round((63 - 24) * scale)
|
69 |
+
left_hand_pose = right_hand_pose = round(45 * scale)
|
70 |
+
if self.expression:
|
71 |
+
expression = 100
|
72 |
+
else:
|
73 |
+
expression = 0
|
74 |
+
|
75 |
+
b_j = 0
|
76 |
+
jaw_dim = jaw_pose
|
77 |
+
b_e = b_j + jaw_dim
|
78 |
+
eye_dim = leye_pose + reye_pose
|
79 |
+
b_b = b_e + eye_dim
|
80 |
+
body_dim = global_orient + body_pose
|
81 |
+
b_h = b_b + body_dim
|
82 |
+
hand_dim = left_hand_pose + right_hand_pose
|
83 |
+
b_f = b_h + hand_dim
|
84 |
+
face_dim = expression
|
85 |
+
|
86 |
+
self.dim_list = [b_j, b_e, b_b, b_h, b_f]
|
87 |
+
self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
|
88 |
+
self.pose = int(self.full_dim / round(3 * scale))
|
89 |
+
self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
|
nets/body_ae.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
sys.path.append(os.getcwd())
|
5 |
+
|
6 |
+
from nets.base import TrainWrapperBaseClass
|
7 |
+
from nets.spg.s2glayers import Discriminator as D_S2G
|
8 |
+
from nets.spg.vqvae_1d import AE as s2g_body
|
9 |
+
import torch
|
10 |
+
import torch.optim as optim
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
from data_utils.lower_body import c_index, c_index_3d, c_index_6d
|
14 |
+
|
15 |
+
|
16 |
+
def separate_aa(aa):
|
17 |
+
aa = aa[:, :, :].reshape(aa.shape[0], aa.shape[1], -1, 5)
|
18 |
+
axis = F.normalize(aa[:, :, :, :3], dim=-1)
|
19 |
+
angle = F.normalize(aa[:, :, :, 3:5], dim=-1)
|
20 |
+
return axis, angle
|
21 |
+
|
22 |
+
|
23 |
+
class TrainWrapper(TrainWrapperBaseClass):
|
24 |
+
'''
|
25 |
+
a wrapper receving a batch from data_utils and calculate loss
|
26 |
+
'''
|
27 |
+
|
28 |
+
def __init__(self, args, config):
|
29 |
+
self.args = args
|
30 |
+
self.config = config
|
31 |
+
self.device = torch.device(self.args.gpu)
|
32 |
+
self.global_step = 0
|
33 |
+
|
34 |
+
self.gan = False
|
35 |
+
self.convert_to_6d = self.config.Data.pose.convert_to_6d
|
36 |
+
self.preleng = self.config.Data.pose.pre_pose_length
|
37 |
+
self.expression = self.config.Data.pose.expression
|
38 |
+
self.epoch = 0
|
39 |
+
self.init_params()
|
40 |
+
self.num_classes = 4
|
41 |
+
self.g = s2g_body(self.each_dim[1] + self.each_dim[2], embedding_dim=64, num_embeddings=0,
|
42 |
+
num_hiddens=1024, num_residual_layers=2, num_residual_hiddens=512).to(self.device)
|
43 |
+
if self.gan:
|
44 |
+
self.discriminator = D_S2G(
|
45 |
+
pose_dim=110 + 64, pose=self.pose
|
46 |
+
).to(self.device)
|
47 |
+
else:
|
48 |
+
self.discriminator = None
|
49 |
+
|
50 |
+
if self.convert_to_6d:
|
51 |
+
self.c_index = c_index_6d
|
52 |
+
else:
|
53 |
+
self.c_index = c_index_3d
|
54 |
+
|
55 |
+
super().__init__(args, config)
|
56 |
+
|
57 |
+
def init_optimizer(self):
|
58 |
+
|
59 |
+
self.g_optimizer = optim.Adam(
|
60 |
+
self.g.parameters(),
|
61 |
+
lr=self.config.Train.learning_rate.generator_learning_rate,
|
62 |
+
betas=[0.9, 0.999]
|
63 |
+
)
|
64 |
+
|
65 |
+
def state_dict(self):
|
66 |
+
model_state = {
|
67 |
+
'g': self.g.state_dict(),
|
68 |
+
'g_optim': self.g_optimizer.state_dict(),
|
69 |
+
'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
|
70 |
+
'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
|
71 |
+
}
|
72 |
+
return model_state
|
73 |
+
|
74 |
+
|
75 |
+
def __call__(self, bat):
|
76 |
+
# assert (not self.args.infer), "infer mode"
|
77 |
+
self.global_step += 1
|
78 |
+
|
79 |
+
total_loss = None
|
80 |
+
loss_dict = {}
|
81 |
+
|
82 |
+
aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
|
83 |
+
|
84 |
+
# id = bat['speaker'].to(self.device) - 20
|
85 |
+
# id = F.one_hot(id, self.num_classes)
|
86 |
+
|
87 |
+
poses = poses[:, self.c_index, :]
|
88 |
+
gt_poses = poses[:, :, self.preleng:].permute(0, 2, 1)
|
89 |
+
|
90 |
+
loss = 0
|
91 |
+
loss_dict, loss = self.vq_train(gt_poses[:, :], 'g', self.g, loss_dict, loss)
|
92 |
+
|
93 |
+
return total_loss, loss_dict
|
94 |
+
|
95 |
+
def vq_train(self, gt, name, model, dict, total_loss, pre=None):
|
96 |
+
x_recon = model(gt_poses=gt, pre_state=pre)
|
97 |
+
loss, loss_dict = self.get_loss(pred_poses=x_recon, gt_poses=gt, pre=pre)
|
98 |
+
# total_loss = total_loss + loss
|
99 |
+
|
100 |
+
if name == 'g':
|
101 |
+
optimizer_name = 'g_optimizer'
|
102 |
+
|
103 |
+
optimizer = getattr(self, optimizer_name)
|
104 |
+
optimizer.zero_grad()
|
105 |
+
loss.backward()
|
106 |
+
optimizer.step()
|
107 |
+
|
108 |
+
for key in list(loss_dict.keys()):
|
109 |
+
dict[name + key] = loss_dict.get(key, 0).item()
|
110 |
+
return dict, total_loss
|
111 |
+
|
112 |
+
def get_loss(self,
|
113 |
+
pred_poses,
|
114 |
+
gt_poses,
|
115 |
+
pre=None
|
116 |
+
):
|
117 |
+
loss_dict = {}
|
118 |
+
|
119 |
+
|
120 |
+
rec_loss = torch.mean(torch.abs(pred_poses - gt_poses))
|
121 |
+
v_pr = pred_poses[:, 1:] - pred_poses[:, :-1]
|
122 |
+
v_gt = gt_poses[:, 1:] - gt_poses[:, :-1]
|
123 |
+
velocity_loss = torch.mean(torch.abs(v_pr - v_gt))
|
124 |
+
|
125 |
+
if pre is None:
|
126 |
+
f0_vel = 0
|
127 |
+
else:
|
128 |
+
v0_pr = pred_poses[:, 0] - pre[:, -1]
|
129 |
+
v0_gt = gt_poses[:, 0] - pre[:, -1]
|
130 |
+
f0_vel = torch.mean(torch.abs(v0_pr - v0_gt))
|
131 |
+
|
132 |
+
gen_loss = rec_loss + velocity_loss + f0_vel
|
133 |
+
|
134 |
+
loss_dict['rec_loss'] = rec_loss
|
135 |
+
loss_dict['velocity_loss'] = velocity_loss
|
136 |
+
# loss_dict['e_q_loss'] = e_q_loss
|
137 |
+
if pre is not None:
|
138 |
+
loss_dict['f0_vel'] = f0_vel
|
139 |
+
|
140 |
+
return gen_loss, loss_dict
|
141 |
+
|
142 |
+
def load_state_dict(self, state_dict):
|
143 |
+
self.g.load_state_dict(state_dict['g'])
|
144 |
+
|
145 |
+
def extract(self, x):
|
146 |
+
self.g.eval()
|
147 |
+
if x.shape[2] > self.full_dim:
|
148 |
+
if x.shape[2] == 239:
|
149 |
+
x = x[:, :, 102:]
|
150 |
+
x = x[:, :, self.c_index]
|
151 |
+
feat = self.g.encode(x)
|
152 |
+
return feat.transpose(1, 2), x
|
nets/init_model.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from nets import *
|
2 |
+
|
3 |
+
|
4 |
+
def init_model(model_name, args, config):
|
5 |
+
|
6 |
+
if model_name == 's2g_face':
|
7 |
+
generator = s2g_face(
|
8 |
+
args,
|
9 |
+
config,
|
10 |
+
)
|
11 |
+
elif model_name == 's2g_body_vq':
|
12 |
+
generator = s2g_body_vq(
|
13 |
+
args,
|
14 |
+
config,
|
15 |
+
)
|
16 |
+
elif model_name == 's2g_body_pixel':
|
17 |
+
generator = s2g_body_pixel(
|
18 |
+
args,
|
19 |
+
config,
|
20 |
+
)
|
21 |
+
elif model_name == 's2g_body_ae':
|
22 |
+
generator = s2g_body_ae(
|
23 |
+
args,
|
24 |
+
config,
|
25 |
+
)
|
26 |
+
elif model_name == 's2g_LS3DCG':
|
27 |
+
generator = LS3DCG(
|
28 |
+
args,
|
29 |
+
config,
|
30 |
+
)
|
31 |
+
else:
|
32 |
+
raise ValueError
|
33 |
+
return generator
|
34 |
+
|
35 |
+
|
nets/layers.py
ADDED
@@ -0,0 +1,1052 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
sys.path.append(os.getcwd())
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
|
11 |
+
# TODO: be aware of the actual netork structures
|
12 |
+
|
13 |
+
def get_log(x):
|
14 |
+
log = 0
|
15 |
+
while x > 1:
|
16 |
+
if x % 2 == 0:
|
17 |
+
x = x // 2
|
18 |
+
log += 1
|
19 |
+
else:
|
20 |
+
raise ValueError('x is not a power of 2')
|
21 |
+
|
22 |
+
return log
|
23 |
+
|
24 |
+
|
25 |
+
class ConvNormRelu(nn.Module):
|
26 |
+
'''
|
27 |
+
(B,C_in,H,W) -> (B, C_out, H, W)
|
28 |
+
there exist some kernel size that makes the result is not H/s
|
29 |
+
#TODO: there might some problems with residual
|
30 |
+
'''
|
31 |
+
|
32 |
+
def __init__(self,
|
33 |
+
in_channels,
|
34 |
+
out_channels,
|
35 |
+
type='1d',
|
36 |
+
leaky=False,
|
37 |
+
downsample=False,
|
38 |
+
kernel_size=None,
|
39 |
+
stride=None,
|
40 |
+
padding=None,
|
41 |
+
p=0,
|
42 |
+
groups=1,
|
43 |
+
residual=False,
|
44 |
+
norm='bn'):
|
45 |
+
'''
|
46 |
+
conv-bn-relu
|
47 |
+
'''
|
48 |
+
super(ConvNormRelu, self).__init__()
|
49 |
+
self.residual = residual
|
50 |
+
self.norm_type = norm
|
51 |
+
# kernel_size = k
|
52 |
+
# stride = s
|
53 |
+
|
54 |
+
if kernel_size is None and stride is None:
|
55 |
+
if not downsample:
|
56 |
+
kernel_size = 3
|
57 |
+
stride = 1
|
58 |
+
else:
|
59 |
+
kernel_size = 4
|
60 |
+
stride = 2
|
61 |
+
|
62 |
+
if padding is None:
|
63 |
+
if isinstance(kernel_size, int) and isinstance(stride, tuple):
|
64 |
+
padding = tuple(int((kernel_size - st) / 2) for st in stride)
|
65 |
+
elif isinstance(kernel_size, tuple) and isinstance(stride, int):
|
66 |
+
padding = tuple(int((ks - stride) / 2) for ks in kernel_size)
|
67 |
+
elif isinstance(kernel_size, tuple) and isinstance(stride, tuple):
|
68 |
+
padding = tuple(int((ks - st) / 2) for ks, st in zip(kernel_size, stride))
|
69 |
+
else:
|
70 |
+
padding = int((kernel_size - stride) / 2)
|
71 |
+
|
72 |
+
if self.residual:
|
73 |
+
if downsample:
|
74 |
+
if type == '1d':
|
75 |
+
self.residual_layer = nn.Sequential(
|
76 |
+
nn.Conv1d(
|
77 |
+
in_channels=in_channels,
|
78 |
+
out_channels=out_channels,
|
79 |
+
kernel_size=kernel_size,
|
80 |
+
stride=stride,
|
81 |
+
padding=padding
|
82 |
+
)
|
83 |
+
)
|
84 |
+
elif type == '2d':
|
85 |
+
self.residual_layer = nn.Sequential(
|
86 |
+
nn.Conv2d(
|
87 |
+
in_channels=in_channels,
|
88 |
+
out_channels=out_channels,
|
89 |
+
kernel_size=kernel_size,
|
90 |
+
stride=stride,
|
91 |
+
padding=padding
|
92 |
+
)
|
93 |
+
)
|
94 |
+
else:
|
95 |
+
if in_channels == out_channels:
|
96 |
+
self.residual_layer = nn.Identity()
|
97 |
+
else:
|
98 |
+
if type == '1d':
|
99 |
+
self.residual_layer = nn.Sequential(
|
100 |
+
nn.Conv1d(
|
101 |
+
in_channels=in_channels,
|
102 |
+
out_channels=out_channels,
|
103 |
+
kernel_size=kernel_size,
|
104 |
+
stride=stride,
|
105 |
+
padding=padding
|
106 |
+
)
|
107 |
+
)
|
108 |
+
elif type == '2d':
|
109 |
+
self.residual_layer = nn.Sequential(
|
110 |
+
nn.Conv2d(
|
111 |
+
in_channels=in_channels,
|
112 |
+
out_channels=out_channels,
|
113 |
+
kernel_size=kernel_size,
|
114 |
+
stride=stride,
|
115 |
+
padding=padding
|
116 |
+
)
|
117 |
+
)
|
118 |
+
|
119 |
+
in_channels = in_channels * groups
|
120 |
+
out_channels = out_channels * groups
|
121 |
+
if type == '1d':
|
122 |
+
self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
|
123 |
+
kernel_size=kernel_size, stride=stride, padding=padding,
|
124 |
+
groups=groups)
|
125 |
+
self.norm = nn.BatchNorm1d(out_channels)
|
126 |
+
self.dropout = nn.Dropout(p=p)
|
127 |
+
elif type == '2d':
|
128 |
+
self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
|
129 |
+
kernel_size=kernel_size, stride=stride, padding=padding,
|
130 |
+
groups=groups)
|
131 |
+
self.norm = nn.BatchNorm2d(out_channels)
|
132 |
+
self.dropout = nn.Dropout2d(p=p)
|
133 |
+
if norm == 'gn':
|
134 |
+
self.norm = nn.GroupNorm(2, out_channels)
|
135 |
+
elif norm == 'ln':
|
136 |
+
self.norm = nn.LayerNorm(out_channels)
|
137 |
+
if leaky:
|
138 |
+
self.relu = nn.LeakyReLU(negative_slope=0.2)
|
139 |
+
else:
|
140 |
+
self.relu = nn.ReLU()
|
141 |
+
|
142 |
+
def forward(self, x, **kwargs):
|
143 |
+
if self.norm_type == 'ln':
|
144 |
+
out = self.dropout(self.conv(x))
|
145 |
+
out = self.norm(out.transpose(1,2)).transpose(1,2)
|
146 |
+
else:
|
147 |
+
out = self.norm(self.dropout(self.conv(x)))
|
148 |
+
if self.residual:
|
149 |
+
residual = self.residual_layer(x)
|
150 |
+
out += residual
|
151 |
+
return self.relu(out)
|
152 |
+
|
153 |
+
|
154 |
+
class UNet1D(nn.Module):
|
155 |
+
def __init__(self,
|
156 |
+
input_channels,
|
157 |
+
output_channels,
|
158 |
+
max_depth=5,
|
159 |
+
kernel_size=None,
|
160 |
+
stride=None,
|
161 |
+
p=0,
|
162 |
+
groups=1):
|
163 |
+
super(UNet1D, self).__init__()
|
164 |
+
self.pre_downsampling_conv = nn.ModuleList([])
|
165 |
+
self.conv1 = nn.ModuleList([])
|
166 |
+
self.conv2 = nn.ModuleList([])
|
167 |
+
self.upconv = nn.Upsample(scale_factor=2, mode='nearest')
|
168 |
+
self.max_depth = max_depth
|
169 |
+
self.groups = groups
|
170 |
+
|
171 |
+
self.pre_downsampling_conv.append(ConvNormRelu(input_channels, output_channels,
|
172 |
+
type='1d', leaky=True, downsample=False,
|
173 |
+
kernel_size=kernel_size, stride=stride, p=p, groups=groups))
|
174 |
+
self.pre_downsampling_conv.append(ConvNormRelu(output_channels, output_channels,
|
175 |
+
type='1d', leaky=True, downsample=False,
|
176 |
+
kernel_size=kernel_size, stride=stride, p=p, groups=groups))
|
177 |
+
|
178 |
+
for i in range(self.max_depth):
|
179 |
+
self.conv1.append(ConvNormRelu(output_channels, output_channels,
|
180 |
+
type='1d', leaky=True, downsample=True,
|
181 |
+
kernel_size=kernel_size, stride=stride, p=p, groups=groups))
|
182 |
+
|
183 |
+
for i in range(self.max_depth):
|
184 |
+
self.conv2.append(ConvNormRelu(output_channels, output_channels,
|
185 |
+
type='1d', leaky=True, downsample=False,
|
186 |
+
kernel_size=kernel_size, stride=stride, p=p, groups=groups))
|
187 |
+
|
188 |
+
def forward(self, x):
|
189 |
+
|
190 |
+
input_size = x.shape[-1]
|
191 |
+
|
192 |
+
assert get_log(
|
193 |
+
input_size) >= self.max_depth, 'num_frames must be a power of 2 and its power must be greater than max_depth'
|
194 |
+
|
195 |
+
x = nn.Sequential(*self.pre_downsampling_conv)(x)
|
196 |
+
|
197 |
+
residuals = []
|
198 |
+
residuals.append(x)
|
199 |
+
for i, conv1 in enumerate(self.conv1):
|
200 |
+
x = conv1(x)
|
201 |
+
if i < self.max_depth - 1:
|
202 |
+
residuals.append(x)
|
203 |
+
|
204 |
+
for i, conv2 in enumerate(self.conv2):
|
205 |
+
x = self.upconv(x) + residuals[self.max_depth - i - 1]
|
206 |
+
x = conv2(x)
|
207 |
+
|
208 |
+
return x
|
209 |
+
|
210 |
+
|
211 |
+
class UNet2D(nn.Module):
|
212 |
+
def __init__(self):
|
213 |
+
super(UNet2D, self).__init__()
|
214 |
+
raise NotImplementedError('2D Unet is wierd')
|
215 |
+
|
216 |
+
|
217 |
+
class AudioPoseEncoder1D(nn.Module):
|
218 |
+
'''
|
219 |
+
(B, C, T) -> (B, C*2, T) -> ... -> (B, C_out, T)
|
220 |
+
'''
|
221 |
+
|
222 |
+
def __init__(self,
|
223 |
+
C_in,
|
224 |
+
C_out,
|
225 |
+
kernel_size=None,
|
226 |
+
stride=None,
|
227 |
+
min_layer_nums=None
|
228 |
+
):
|
229 |
+
super(AudioPoseEncoder1D, self).__init__()
|
230 |
+
self.C_in = C_in
|
231 |
+
self.C_out = C_out
|
232 |
+
|
233 |
+
conv_layers = nn.ModuleList([])
|
234 |
+
cur_C = C_in
|
235 |
+
num_layers = 0
|
236 |
+
while cur_C < self.C_out:
|
237 |
+
conv_layers.append(ConvNormRelu(
|
238 |
+
in_channels=cur_C,
|
239 |
+
out_channels=cur_C * 2,
|
240 |
+
kernel_size=kernel_size,
|
241 |
+
stride=stride
|
242 |
+
))
|
243 |
+
cur_C *= 2
|
244 |
+
num_layers += 1
|
245 |
+
|
246 |
+
if (cur_C != C_out) or (min_layer_nums is not None and num_layers < min_layer_nums):
|
247 |
+
while (cur_C != C_out) or num_layers < min_layer_nums:
|
248 |
+
conv_layers.append(ConvNormRelu(
|
249 |
+
in_channels=cur_C,
|
250 |
+
out_channels=C_out,
|
251 |
+
kernel_size=kernel_size,
|
252 |
+
stride=stride
|
253 |
+
))
|
254 |
+
num_layers += 1
|
255 |
+
cur_C = C_out
|
256 |
+
|
257 |
+
self.conv_layers = nn.Sequential(*conv_layers)
|
258 |
+
|
259 |
+
def forward(self, x):
|
260 |
+
'''
|
261 |
+
x: (B, C, T)
|
262 |
+
'''
|
263 |
+
x = self.conv_layers(x)
|
264 |
+
return x
|
265 |
+
|
266 |
+
|
267 |
+
class AudioPoseEncoder2D(nn.Module):
|
268 |
+
'''
|
269 |
+
(B, C, T) -> (B, 1, C, T) -> ... -> (B, C_out, T)
|
270 |
+
'''
|
271 |
+
|
272 |
+
def __init__(self):
|
273 |
+
raise NotImplementedError
|
274 |
+
|
275 |
+
|
276 |
+
class AudioPoseEncoderRNN(nn.Module):
|
277 |
+
'''
|
278 |
+
(B, C, T)->(B, T, C)->(B, T, C_out)->(B, C_out, T)
|
279 |
+
'''
|
280 |
+
|
281 |
+
def __init__(self,
|
282 |
+
C_in,
|
283 |
+
hidden_size,
|
284 |
+
num_layers,
|
285 |
+
rnn_cell='gru',
|
286 |
+
bidirectional=False
|
287 |
+
):
|
288 |
+
super(AudioPoseEncoderRNN, self).__init__()
|
289 |
+
if rnn_cell == 'gru':
|
290 |
+
self.cell = nn.GRU(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
|
291 |
+
bidirectional=bidirectional)
|
292 |
+
elif rnn_cell == 'lstm':
|
293 |
+
self.cell = nn.LSTM(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
|
294 |
+
bidirectional=bidirectional)
|
295 |
+
else:
|
296 |
+
raise ValueError('invalid rnn cell:%s' % (rnn_cell))
|
297 |
+
|
298 |
+
def forward(self, x, state=None):
|
299 |
+
|
300 |
+
x = x.permute(0, 2, 1)
|
301 |
+
x, state = self.cell(x, state)
|
302 |
+
x = x.permute(0, 2, 1)
|
303 |
+
|
304 |
+
return x
|
305 |
+
|
306 |
+
|
307 |
+
class AudioPoseEncoderGraph(nn.Module):
|
308 |
+
'''
|
309 |
+
(B, C, T)->(B, 2, V, T)->(B, 2, T, V)->(B, D, T, V)
|
310 |
+
'''
|
311 |
+
|
312 |
+
def __init__(self,
|
313 |
+
layers_config, # 理应是(C_in, C_out, kernel_size)的list
|
314 |
+
A, # adjacent matrix (num_parts, V, V)
|
315 |
+
residual,
|
316 |
+
local_bn=False,
|
317 |
+
share_weights=False
|
318 |
+
) -> None:
|
319 |
+
super().__init__()
|
320 |
+
self.A = A
|
321 |
+
self.num_joints = A.shape[1]
|
322 |
+
self.num_parts = A.shape[0]
|
323 |
+
self.C_in = layers_config[0][0]
|
324 |
+
self.C_out = layers_config[-1][1]
|
325 |
+
|
326 |
+
self.conv_layers = nn.ModuleList([
|
327 |
+
GraphConvNormRelu(
|
328 |
+
C_in=c_in,
|
329 |
+
C_out=c_out,
|
330 |
+
A=self.A,
|
331 |
+
residual=residual,
|
332 |
+
local_bn=local_bn,
|
333 |
+
kernel_size=k,
|
334 |
+
share_weights=share_weights
|
335 |
+
) for (c_in, c_out, k) in layers_config
|
336 |
+
])
|
337 |
+
|
338 |
+
self.conv_layers = nn.Sequential(*self.conv_layers)
|
339 |
+
|
340 |
+
def forward(self, x):
|
341 |
+
'''
|
342 |
+
x: (B, C, T), C should be num_joints*D
|
343 |
+
output: (B, D, T, V)
|
344 |
+
'''
|
345 |
+
B, C, T = x.shape
|
346 |
+
x = x.view(B, self.num_joints, self.C_in, T) # (B, V, D, T),D:每个joint的特征维度,注意这里V在前面
|
347 |
+
x = x.permute(0, 2, 3, 1) # (B, D, T, V)
|
348 |
+
assert x.shape[1] == self.C_in
|
349 |
+
|
350 |
+
x_conved = self.conv_layers(x)
|
351 |
+
|
352 |
+
# x_conved = x_conved.permute(0, 3, 1, 2).contiguous().view(B, self.C_out*self.num_joints, T)#(B, V*C_out, T)
|
353 |
+
|
354 |
+
return x_conved
|
355 |
+
|
356 |
+
|
357 |
+
class SeqEncoder2D(nn.Module):
|
358 |
+
'''
|
359 |
+
seq_encoder, encoding a seq to a vector
|
360 |
+
(B, C, T)->(B, 2, V, T)->(B, 2, T, V) -> (B, 32, )->...->(B, C_out)
|
361 |
+
'''
|
362 |
+
|
363 |
+
def __init__(self,
|
364 |
+
C_in, # should be 2
|
365 |
+
T_in,
|
366 |
+
C_out,
|
367 |
+
num_joints,
|
368 |
+
min_layer_num=None,
|
369 |
+
residual=False
|
370 |
+
):
|
371 |
+
super(SeqEncoder2D, self).__init__()
|
372 |
+
self.C_in = C_in
|
373 |
+
self.C_out = C_out
|
374 |
+
self.T_in = T_in
|
375 |
+
self.num_joints = num_joints
|
376 |
+
|
377 |
+
conv_layers = nn.ModuleList([])
|
378 |
+
conv_layers.append(ConvNormRelu(
|
379 |
+
in_channels=C_in,
|
380 |
+
out_channels=32,
|
381 |
+
type='2d',
|
382 |
+
residual=residual
|
383 |
+
))
|
384 |
+
|
385 |
+
cur_C = 32
|
386 |
+
cur_H = T_in
|
387 |
+
cur_W = num_joints
|
388 |
+
num_layers = 1
|
389 |
+
while (cur_C < C_out) or (cur_H > 1) or (cur_W > 1):
|
390 |
+
ks = [3, 3]
|
391 |
+
st = [1, 1]
|
392 |
+
|
393 |
+
if cur_H > 1:
|
394 |
+
if cur_H > 4:
|
395 |
+
ks[0] = 4
|
396 |
+
st[0] = 2
|
397 |
+
else:
|
398 |
+
ks[0] = cur_H
|
399 |
+
st[0] = cur_H
|
400 |
+
if cur_W > 1:
|
401 |
+
if cur_W > 4:
|
402 |
+
ks[1] = 4
|
403 |
+
st[1] = 2
|
404 |
+
else:
|
405 |
+
ks[1] = cur_W
|
406 |
+
st[1] = cur_W
|
407 |
+
|
408 |
+
conv_layers.append(ConvNormRelu(
|
409 |
+
in_channels=cur_C,
|
410 |
+
out_channels=min(C_out, cur_C * 2),
|
411 |
+
type='2d',
|
412 |
+
kernel_size=tuple(ks),
|
413 |
+
stride=tuple(st),
|
414 |
+
residual=residual
|
415 |
+
))
|
416 |
+
cur_C = min(cur_C * 2, C_out)
|
417 |
+
if cur_H > 1:
|
418 |
+
if cur_H > 4:
|
419 |
+
cur_H //= 2
|
420 |
+
else:
|
421 |
+
cur_H = 1
|
422 |
+
if cur_W > 1:
|
423 |
+
if cur_W > 4:
|
424 |
+
cur_W //= 2
|
425 |
+
else:
|
426 |
+
cur_W = 1
|
427 |
+
num_layers += 1
|
428 |
+
|
429 |
+
if min_layer_num is not None and (num_layers < min_layer_num):
|
430 |
+
while num_layers < min_layer_num:
|
431 |
+
conv_layers.append(ConvNormRelu(
|
432 |
+
in_channels=C_out,
|
433 |
+
out_channels=C_out,
|
434 |
+
type='2d',
|
435 |
+
kernel_size=1,
|
436 |
+
stride=1,
|
437 |
+
residual=residual
|
438 |
+
))
|
439 |
+
num_layers += 1
|
440 |
+
|
441 |
+
self.conv_layers = nn.Sequential(*conv_layers)
|
442 |
+
self.num_layers = num_layers
|
443 |
+
|
444 |
+
def forward(self, x):
|
445 |
+
B, C, T = x.shape
|
446 |
+
x = x.view(B, self.num_joints, self.C_in, T) # (B, V, D, T) V in front
|
447 |
+
x = x.permute(0, 2, 3, 1) # (B, D, T, V)
|
448 |
+
assert x.shape[1] == self.C_in and x.shape[-1] == self.num_joints
|
449 |
+
|
450 |
+
x = self.conv_layers(x)
|
451 |
+
return x.squeeze()
|
452 |
+
|
453 |
+
|
454 |
+
class SeqEncoder1D(nn.Module):
|
455 |
+
'''
|
456 |
+
(B, C, T)->(B, D)
|
457 |
+
'''
|
458 |
+
|
459 |
+
def __init__(self,
|
460 |
+
C_in,
|
461 |
+
C_out,
|
462 |
+
T_in,
|
463 |
+
min_layer_nums=None
|
464 |
+
):
|
465 |
+
super(SeqEncoder1D, self).__init__()
|
466 |
+
conv_layers = nn.ModuleList([])
|
467 |
+
cur_C = C_in
|
468 |
+
cur_T = T_in
|
469 |
+
self.num_layers = 0
|
470 |
+
while (cur_C < C_out) or (cur_T > 1):
|
471 |
+
ks = 3
|
472 |
+
st = 1
|
473 |
+
if cur_T > 1:
|
474 |
+
if cur_T > 4:
|
475 |
+
ks = 4
|
476 |
+
st = 2
|
477 |
+
else:
|
478 |
+
ks = cur_T
|
479 |
+
st = cur_T
|
480 |
+
|
481 |
+
conv_layers.append(ConvNormRelu(
|
482 |
+
in_channels=cur_C,
|
483 |
+
out_channels=min(C_out, cur_C * 2),
|
484 |
+
type='1d',
|
485 |
+
kernel_size=ks,
|
486 |
+
stride=st
|
487 |
+
))
|
488 |
+
cur_C = min(cur_C * 2, C_out)
|
489 |
+
if cur_T > 1:
|
490 |
+
if cur_T > 4:
|
491 |
+
cur_T = cur_T // 2
|
492 |
+
else:
|
493 |
+
cur_T = 1
|
494 |
+
self.num_layers += 1
|
495 |
+
|
496 |
+
if min_layer_nums is not None and (self.num_layers < min_layer_nums):
|
497 |
+
while self.num_layers < min_layer_nums:
|
498 |
+
conv_layers.append(ConvNormRelu(
|
499 |
+
in_channels=C_out,
|
500 |
+
out_channels=C_out,
|
501 |
+
type='1d',
|
502 |
+
kernel_size=1,
|
503 |
+
stride=1
|
504 |
+
))
|
505 |
+
self.num_layers += 1
|
506 |
+
self.conv_layers = nn.Sequential(*conv_layers)
|
507 |
+
|
508 |
+
def forward(self, x):
|
509 |
+
x = self.conv_layers(x)
|
510 |
+
return x.squeeze()
|
511 |
+
|
512 |
+
|
513 |
+
class SeqEncoderRNN(nn.Module):
|
514 |
+
'''
|
515 |
+
(B, C, T) -> (B, T, C) -> (B, D)
|
516 |
+
LSTM/GRU-FC
|
517 |
+
'''
|
518 |
+
|
519 |
+
def __init__(self,
|
520 |
+
hidden_size,
|
521 |
+
in_size,
|
522 |
+
num_rnn_layers,
|
523 |
+
rnn_cell='gru',
|
524 |
+
bidirectional=False
|
525 |
+
):
|
526 |
+
super(SeqEncoderRNN, self).__init__()
|
527 |
+
self.hidden_size = hidden_size
|
528 |
+
self.in_size = in_size
|
529 |
+
self.num_rnn_layers = num_rnn_layers
|
530 |
+
self.bidirectional = bidirectional
|
531 |
+
|
532 |
+
if rnn_cell == 'gru':
|
533 |
+
self.cell = nn.GRU(input_size=self.in_size, hidden_size=self.hidden_size, num_layers=self.num_rnn_layers,
|
534 |
+
batch_first=True, bidirectional=bidirectional)
|
535 |
+
elif rnn_cell == 'lstm':
|
536 |
+
self.cell = nn.LSTM(input_size=self.in_size, hidden_size=self.hidden_size, num_layers=self.num_rnn_layers,
|
537 |
+
batch_first=True, bidirectional=bidirectional)
|
538 |
+
|
539 |
+
def forward(self, x, state=None):
|
540 |
+
|
541 |
+
x = x.permute(0, 2, 1)
|
542 |
+
B, T, C = x.shape
|
543 |
+
x, _ = self.cell(x, state)
|
544 |
+
if self.bidirectional:
|
545 |
+
out = torch.cat([x[:, -1, :self.hidden_size], x[:, 0, self.hidden_size:]], dim=-1)
|
546 |
+
else:
|
547 |
+
out = x[:, -1, :]
|
548 |
+
assert out.shape[0] == B
|
549 |
+
return out
|
550 |
+
|
551 |
+
|
552 |
+
class SeqEncoderGraph(nn.Module):
|
553 |
+
'''
|
554 |
+
'''
|
555 |
+
|
556 |
+
def __init__(self,
|
557 |
+
embedding_size,
|
558 |
+
layer_configs,
|
559 |
+
residual,
|
560 |
+
local_bn,
|
561 |
+
A,
|
562 |
+
T,
|
563 |
+
share_weights=False
|
564 |
+
) -> None:
|
565 |
+
super().__init__()
|
566 |
+
|
567 |
+
self.C_in = layer_configs[0][0]
|
568 |
+
self.C_out = embedding_size
|
569 |
+
|
570 |
+
self.num_joints = A.shape[1]
|
571 |
+
|
572 |
+
self.graph_encoder = AudioPoseEncoderGraph(
|
573 |
+
layers_config=layer_configs,
|
574 |
+
A=A,
|
575 |
+
residual=residual,
|
576 |
+
local_bn=local_bn,
|
577 |
+
share_weights=share_weights
|
578 |
+
)
|
579 |
+
|
580 |
+
cur_C = layer_configs[-1][1]
|
581 |
+
self.spatial_pool = ConvNormRelu(
|
582 |
+
in_channels=cur_C,
|
583 |
+
out_channels=cur_C,
|
584 |
+
type='2d',
|
585 |
+
kernel_size=(1, self.num_joints),
|
586 |
+
stride=(1, 1),
|
587 |
+
padding=(0, 0)
|
588 |
+
)
|
589 |
+
|
590 |
+
temporal_pool = nn.ModuleList([])
|
591 |
+
cur_H = T
|
592 |
+
num_layers = 0
|
593 |
+
self.temporal_conv_info = []
|
594 |
+
while cur_C < self.C_out or cur_H > 1:
|
595 |
+
self.temporal_conv_info.append(cur_C)
|
596 |
+
ks = [3, 1]
|
597 |
+
st = [1, 1]
|
598 |
+
|
599 |
+
if cur_H > 1:
|
600 |
+
if cur_H > 4:
|
601 |
+
ks[0] = 4
|
602 |
+
st[0] = 2
|
603 |
+
else:
|
604 |
+
ks[0] = cur_H
|
605 |
+
st[0] = cur_H
|
606 |
+
|
607 |
+
temporal_pool.append(ConvNormRelu(
|
608 |
+
in_channels=cur_C,
|
609 |
+
out_channels=min(self.C_out, cur_C * 2),
|
610 |
+
type='2d',
|
611 |
+
kernel_size=tuple(ks),
|
612 |
+
stride=tuple(st)
|
613 |
+
))
|
614 |
+
cur_C = min(cur_C * 2, self.C_out)
|
615 |
+
|
616 |
+
if cur_H > 1:
|
617 |
+
if cur_H > 4:
|
618 |
+
cur_H //= 2
|
619 |
+
else:
|
620 |
+
cur_H = 1
|
621 |
+
|
622 |
+
num_layers += 1
|
623 |
+
|
624 |
+
self.temporal_pool = nn.Sequential(*temporal_pool)
|
625 |
+
print("graph seq encoder info: temporal pool:", self.temporal_conv_info)
|
626 |
+
self.num_layers = num_layers
|
627 |
+
# need fc?
|
628 |
+
|
629 |
+
def forward(self, x):
|
630 |
+
'''
|
631 |
+
x: (B, C, T)
|
632 |
+
'''
|
633 |
+
B, C, T = x.shape
|
634 |
+
x = self.graph_encoder(x)
|
635 |
+
x = self.spatial_pool(x)
|
636 |
+
x = self.temporal_pool(x)
|
637 |
+
x = x.view(B, self.C_out)
|
638 |
+
|
639 |
+
return x
|
640 |
+
|
641 |
+
|
642 |
+
class SeqDecoder2D(nn.Module):
|
643 |
+
'''
|
644 |
+
(B, D)->(B, D, 1, 1)->(B, C_out, C, T)->(B, C_out, T)
|
645 |
+
'''
|
646 |
+
|
647 |
+
def __init__(self):
|
648 |
+
super(SeqDecoder2D, self).__init__()
|
649 |
+
raise NotImplementedError
|
650 |
+
|
651 |
+
|
652 |
+
class SeqDecoder1D(nn.Module):
|
653 |
+
'''
|
654 |
+
(B, D)->(B, D, 1)->...->(B, C_out, T)
|
655 |
+
'''
|
656 |
+
|
657 |
+
def __init__(self,
|
658 |
+
D_in,
|
659 |
+
C_out,
|
660 |
+
T_out,
|
661 |
+
min_layer_num=None
|
662 |
+
):
|
663 |
+
super(SeqDecoder1D, self).__init__()
|
664 |
+
self.T_out = T_out
|
665 |
+
self.min_layer_num = min_layer_num
|
666 |
+
|
667 |
+
cur_t = 1
|
668 |
+
|
669 |
+
self.pre_conv = ConvNormRelu(
|
670 |
+
in_channels=D_in,
|
671 |
+
out_channels=C_out,
|
672 |
+
type='1d'
|
673 |
+
)
|
674 |
+
self.num_layers = 1
|
675 |
+
self.upconv = nn.Upsample(scale_factor=2, mode='nearest')
|
676 |
+
self.conv_layers = nn.ModuleList([])
|
677 |
+
cur_t *= 2
|
678 |
+
while cur_t <= T_out:
|
679 |
+
self.conv_layers.append(ConvNormRelu(
|
680 |
+
in_channels=C_out,
|
681 |
+
out_channels=C_out,
|
682 |
+
type='1d'
|
683 |
+
))
|
684 |
+
cur_t *= 2
|
685 |
+
self.num_layers += 1
|
686 |
+
|
687 |
+
post_conv = nn.ModuleList([ConvNormRelu(
|
688 |
+
in_channels=C_out,
|
689 |
+
out_channels=C_out,
|
690 |
+
type='1d'
|
691 |
+
)])
|
692 |
+
self.num_layers += 1
|
693 |
+
if min_layer_num is not None and self.num_layers < min_layer_num:
|
694 |
+
while self.num_layers < min_layer_num:
|
695 |
+
post_conv.append(ConvNormRelu(
|
696 |
+
in_channels=C_out,
|
697 |
+
out_channels=C_out,
|
698 |
+
type='1d'
|
699 |
+
))
|
700 |
+
self.num_layers += 1
|
701 |
+
self.post_conv = nn.Sequential(*post_conv)
|
702 |
+
|
703 |
+
def forward(self, x):
|
704 |
+
|
705 |
+
x = x.unsqueeze(-1)
|
706 |
+
x = self.pre_conv(x)
|
707 |
+
for conv in self.conv_layers:
|
708 |
+
x = self.upconv(x)
|
709 |
+
x = conv(x)
|
710 |
+
|
711 |
+
x = torch.nn.functional.interpolate(x, size=self.T_out, mode='nearest')
|
712 |
+
x = self.post_conv(x)
|
713 |
+
return x
|
714 |
+
|
715 |
+
|
716 |
+
class SeqDecoderRNN(nn.Module):
|
717 |
+
'''
|
718 |
+
(B, D)->(B, C_out, T)
|
719 |
+
'''
|
720 |
+
|
721 |
+
def __init__(self,
|
722 |
+
hidden_size,
|
723 |
+
C_out,
|
724 |
+
T_out,
|
725 |
+
num_layers,
|
726 |
+
rnn_cell='gru'
|
727 |
+
):
|
728 |
+
super(SeqDecoderRNN, self).__init__()
|
729 |
+
self.num_steps = T_out
|
730 |
+
if rnn_cell == 'gru':
|
731 |
+
self.cell = nn.GRU(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
|
732 |
+
bidirectional=False)
|
733 |
+
elif rnn_cell == 'lstm':
|
734 |
+
self.cell = nn.LSTM(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
|
735 |
+
bidirectional=False)
|
736 |
+
else:
|
737 |
+
raise ValueError('invalid rnn cell:%s' % (rnn_cell))
|
738 |
+
|
739 |
+
self.fc = nn.Linear(hidden_size, C_out)
|
740 |
+
|
741 |
+
def forward(self, hidden, frame_0):
|
742 |
+
frame_0 = frame_0.permute(0, 2, 1)
|
743 |
+
dec_input = frame_0
|
744 |
+
outputs = []
|
745 |
+
for i in range(self.num_steps):
|
746 |
+
frame_out, hidden = self.cell(dec_input, hidden)
|
747 |
+
frame_out = self.fc(frame_out)
|
748 |
+
dec_input = frame_out
|
749 |
+
outputs.append(frame_out)
|
750 |
+
output = torch.cat(outputs, dim=1)
|
751 |
+
return output.permute(0, 2, 1)
|
752 |
+
|
753 |
+
|
754 |
+
class SeqTranslator2D(nn.Module):
|
755 |
+
'''
|
756 |
+
(B, C, T)->(B, 1, C, T)-> ... -> (B, 1, C_out, T_out)
|
757 |
+
'''
|
758 |
+
|
759 |
+
def __init__(self,
|
760 |
+
C_in=64,
|
761 |
+
C_out=108,
|
762 |
+
T_in=75,
|
763 |
+
T_out=25,
|
764 |
+
residual=True
|
765 |
+
):
|
766 |
+
super(SeqTranslator2D, self).__init__()
|
767 |
+
print("Warning: hard coded")
|
768 |
+
self.C_in = C_in
|
769 |
+
self.C_out = C_out
|
770 |
+
self.T_in = T_in
|
771 |
+
self.T_out = T_out
|
772 |
+
self.residual = residual
|
773 |
+
|
774 |
+
self.conv_layers = nn.Sequential(
|
775 |
+
ConvNormRelu(1, 32, '2d', kernel_size=5, stride=1),
|
776 |
+
ConvNormRelu(32, 32, '2d', kernel_size=5, stride=1, residual=self.residual),
|
777 |
+
ConvNormRelu(32, 32, '2d', kernel_size=5, stride=1, residual=self.residual),
|
778 |
+
|
779 |
+
ConvNormRelu(32, 64, '2d', kernel_size=5, stride=(4, 3)),
|
780 |
+
ConvNormRelu(64, 64, '2d', kernel_size=5, stride=1, residual=self.residual),
|
781 |
+
ConvNormRelu(64, 64, '2d', kernel_size=5, stride=1, residual=self.residual),
|
782 |
+
|
783 |
+
ConvNormRelu(64, 128, '2d', kernel_size=5, stride=(4, 1)),
|
784 |
+
ConvNormRelu(128, 108, '2d', kernel_size=3, stride=(4, 1)),
|
785 |
+
ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1, residual=self.residual),
|
786 |
+
|
787 |
+
ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1, residual=self.residual),
|
788 |
+
ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1),
|
789 |
+
)
|
790 |
+
|
791 |
+
def forward(self, x):
|
792 |
+
assert len(x.shape) == 3 and x.shape[1] == self.C_in and x.shape[2] == self.T_in
|
793 |
+
x = x.view(x.shape[0], 1, x.shape[1], x.shape[2])
|
794 |
+
x = self.conv_layers(x)
|
795 |
+
x = x.squeeze(2)
|
796 |
+
return x
|
797 |
+
|
798 |
+
|
799 |
+
class SeqTranslator1D(nn.Module):
|
800 |
+
'''
|
801 |
+
(B, C, T)->(B, C_out, T)
|
802 |
+
'''
|
803 |
+
|
804 |
+
def __init__(self,
|
805 |
+
C_in,
|
806 |
+
C_out,
|
807 |
+
kernel_size=None,
|
808 |
+
stride=None,
|
809 |
+
min_layers_num=None,
|
810 |
+
residual=True,
|
811 |
+
norm='bn'
|
812 |
+
):
|
813 |
+
super(SeqTranslator1D, self).__init__()
|
814 |
+
|
815 |
+
conv_layers = nn.ModuleList([])
|
816 |
+
conv_layers.append(ConvNormRelu(
|
817 |
+
in_channels=C_in,
|
818 |
+
out_channels=C_out,
|
819 |
+
type='1d',
|
820 |
+
kernel_size=kernel_size,
|
821 |
+
stride=stride,
|
822 |
+
residual=residual,
|
823 |
+
norm=norm
|
824 |
+
))
|
825 |
+
self.num_layers = 1
|
826 |
+
if min_layers_num is not None and self.num_layers < min_layers_num:
|
827 |
+
while self.num_layers < min_layers_num:
|
828 |
+
conv_layers.append(ConvNormRelu(
|
829 |
+
in_channels=C_out,
|
830 |
+
out_channels=C_out,
|
831 |
+
type='1d',
|
832 |
+
kernel_size=kernel_size,
|
833 |
+
stride=stride,
|
834 |
+
residual=residual,
|
835 |
+
norm=norm
|
836 |
+
))
|
837 |
+
self.num_layers += 1
|
838 |
+
self.conv_layers = nn.Sequential(*conv_layers)
|
839 |
+
|
840 |
+
def forward(self, x):
|
841 |
+
return self.conv_layers(x)
|
842 |
+
|
843 |
+
|
844 |
+
class SeqTranslatorRNN(nn.Module):
|
845 |
+
'''
|
846 |
+
(B, C, T)->(B, C_out, T)
|
847 |
+
LSTM-FC
|
848 |
+
'''
|
849 |
+
|
850 |
+
def __init__(self,
|
851 |
+
C_in,
|
852 |
+
C_out,
|
853 |
+
hidden_size,
|
854 |
+
num_layers,
|
855 |
+
rnn_cell='gru'
|
856 |
+
):
|
857 |
+
super(SeqTranslatorRNN, self).__init__()
|
858 |
+
|
859 |
+
if rnn_cell == 'gru':
|
860 |
+
self.enc_cell = nn.GRU(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
|
861 |
+
bidirectional=False)
|
862 |
+
self.dec_cell = nn.GRU(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
|
863 |
+
bidirectional=False)
|
864 |
+
elif rnn_cell == 'lstm':
|
865 |
+
self.enc_cell = nn.LSTM(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
|
866 |
+
bidirectional=False)
|
867 |
+
self.dec_cell = nn.LSTM(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
|
868 |
+
bidirectional=False)
|
869 |
+
else:
|
870 |
+
raise ValueError('invalid rnn cell:%s' % (rnn_cell))
|
871 |
+
|
872 |
+
self.fc = nn.Linear(hidden_size, C_out)
|
873 |
+
|
874 |
+
def forward(self, x, frame_0):
|
875 |
+
|
876 |
+
num_steps = x.shape[-1]
|
877 |
+
x = x.permute(0, 2, 1)
|
878 |
+
frame_0 = frame_0.permute(0, 2, 1)
|
879 |
+
_, hidden = self.enc_cell(x, None)
|
880 |
+
|
881 |
+
outputs = []
|
882 |
+
for i in range(num_steps):
|
883 |
+
inputs = frame_0
|
884 |
+
output_frame, hidden = self.dec_cell(inputs, hidden)
|
885 |
+
output_frame = self.fc(output_frame)
|
886 |
+
frame_0 = output_frame
|
887 |
+
outputs.append(output_frame)
|
888 |
+
outputs = torch.cat(outputs, dim=1)
|
889 |
+
return outputs.permute(0, 2, 1)
|
890 |
+
|
891 |
+
|
892 |
+
class ResBlock(nn.Module):
|
893 |
+
def __init__(self,
|
894 |
+
input_dim,
|
895 |
+
fc_dim,
|
896 |
+
afn,
|
897 |
+
nfn
|
898 |
+
):
|
899 |
+
'''
|
900 |
+
afn: activation fn
|
901 |
+
nfn: normalization fn
|
902 |
+
'''
|
903 |
+
super(ResBlock, self).__init__()
|
904 |
+
|
905 |
+
self.input_dim = input_dim
|
906 |
+
self.fc_dim = fc_dim
|
907 |
+
self.afn = afn
|
908 |
+
self.nfn = nfn
|
909 |
+
|
910 |
+
if self.afn != 'relu':
|
911 |
+
raise ValueError('Wrong')
|
912 |
+
|
913 |
+
if self.nfn == 'layer_norm':
|
914 |
+
raise ValueError('wrong')
|
915 |
+
|
916 |
+
self.layers = nn.Sequential(
|
917 |
+
nn.Linear(self.input_dim, self.fc_dim // 2),
|
918 |
+
nn.ReLU(),
|
919 |
+
nn.Linear(self.fc_dim // 2, self.fc_dim // 2),
|
920 |
+
nn.ReLU(),
|
921 |
+
nn.Linear(self.fc_dim // 2, self.fc_dim),
|
922 |
+
nn.ReLU()
|
923 |
+
)
|
924 |
+
|
925 |
+
self.shortcut_layer = nn.Sequential(
|
926 |
+
nn.Linear(self.input_dim, self.fc_dim),
|
927 |
+
nn.ReLU(),
|
928 |
+
)
|
929 |
+
|
930 |
+
def forward(self, inputs):
|
931 |
+
return self.layers(inputs) + self.shortcut_layer(inputs)
|
932 |
+
|
933 |
+
|
934 |
+
class AudioEncoder(nn.Module):
|
935 |
+
def __init__(self, channels, padding=3, kernel_size=8, conv_stride=2, conv_pool=None, augmentation=False):
|
936 |
+
super(AudioEncoder, self).__init__()
|
937 |
+
self.in_channels = channels[0]
|
938 |
+
self.augmentation = augmentation
|
939 |
+
|
940 |
+
model = []
|
941 |
+
acti = nn.LeakyReLU(0.2)
|
942 |
+
|
943 |
+
nr_layer = len(channels) - 1
|
944 |
+
|
945 |
+
for i in range(nr_layer):
|
946 |
+
if conv_pool is None:
|
947 |
+
model.append(nn.ReflectionPad1d(padding))
|
948 |
+
model.append(nn.Conv1d(channels[i], channels[i + 1], kernel_size=kernel_size, stride=conv_stride))
|
949 |
+
model.append(acti)
|
950 |
+
else:
|
951 |
+
model.append(nn.ReflectionPad1d(padding))
|
952 |
+
model.append(nn.Conv1d(channels[i], channels[i + 1], kernel_size=kernel_size, stride=conv_stride))
|
953 |
+
model.append(acti)
|
954 |
+
model.append(conv_pool(kernel_size=2, stride=2))
|
955 |
+
|
956 |
+
if self.augmentation:
|
957 |
+
model.append(
|
958 |
+
nn.Conv1d(channels[-1], channels[-1], kernel_size=kernel_size, stride=conv_stride)
|
959 |
+
)
|
960 |
+
model.append(acti)
|
961 |
+
|
962 |
+
self.model = nn.Sequential(*model)
|
963 |
+
|
964 |
+
def forward(self, x):
|
965 |
+
|
966 |
+
x = x[:, :self.in_channels, :]
|
967 |
+
x = self.model(x)
|
968 |
+
return x
|
969 |
+
|
970 |
+
|
971 |
+
class AudioDecoder(nn.Module):
|
972 |
+
def __init__(self, channels, kernel_size=7, ups=25):
|
973 |
+
super(AudioDecoder, self).__init__()
|
974 |
+
|
975 |
+
model = []
|
976 |
+
pad = (kernel_size - 1) // 2
|
977 |
+
acti = nn.LeakyReLU(0.2)
|
978 |
+
|
979 |
+
for i in range(len(channels) - 2):
|
980 |
+
model.append(nn.Upsample(scale_factor=2, mode='nearest'))
|
981 |
+
model.append(nn.ReflectionPad1d(pad))
|
982 |
+
model.append(nn.Conv1d(channels[i], channels[i + 1],
|
983 |
+
kernel_size=kernel_size, stride=1))
|
984 |
+
if i == 0 or i == 1:
|
985 |
+
model.append(nn.Dropout(p=0.2))
|
986 |
+
if not i == len(channels) - 2:
|
987 |
+
model.append(acti)
|
988 |
+
|
989 |
+
model.append(nn.Upsample(size=ups, mode='nearest'))
|
990 |
+
model.append(nn.ReflectionPad1d(pad))
|
991 |
+
model.append(nn.Conv1d(channels[-2], channels[-1],
|
992 |
+
kernel_size=kernel_size, stride=1))
|
993 |
+
|
994 |
+
self.model = nn.Sequential(*model)
|
995 |
+
|
996 |
+
def forward(self, x):
|
997 |
+
return self.model(x)
|
998 |
+
|
999 |
+
|
1000 |
+
class Audio2Pose(nn.Module):
|
1001 |
+
def __init__(self, pose_dim, embed_size, augmentation, ups=25):
|
1002 |
+
super(Audio2Pose, self).__init__()
|
1003 |
+
self.pose_dim = pose_dim
|
1004 |
+
self.embed_size = embed_size
|
1005 |
+
self.augmentation = augmentation
|
1006 |
+
|
1007 |
+
self.aud_enc = AudioEncoder(channels=[13, 64, 128, 256], padding=2, kernel_size=7, conv_stride=1,
|
1008 |
+
conv_pool=nn.AvgPool1d, augmentation=self.augmentation)
|
1009 |
+
if self.augmentation:
|
1010 |
+
self.aud_dec = AudioDecoder(channels=[512, 256, 128, pose_dim])
|
1011 |
+
else:
|
1012 |
+
self.aud_dec = AudioDecoder(channels=[256, 256, 128, pose_dim], ups=ups)
|
1013 |
+
|
1014 |
+
if self.augmentation:
|
1015 |
+
self.pose_enc = nn.Sequential(
|
1016 |
+
nn.Linear(self.embed_size // 2, 256),
|
1017 |
+
nn.LayerNorm(256)
|
1018 |
+
)
|
1019 |
+
|
1020 |
+
def forward(self, audio_feat, dec_input=None):
|
1021 |
+
|
1022 |
+
B = audio_feat.shape[0]
|
1023 |
+
|
1024 |
+
aud_embed = self.aud_enc.forward(audio_feat)
|
1025 |
+
|
1026 |
+
if self.augmentation:
|
1027 |
+
dec_input = dec_input.squeeze(0)
|
1028 |
+
dec_embed = self.pose_enc(dec_input)
|
1029 |
+
dec_embed = dec_embed.unsqueeze(2)
|
1030 |
+
dec_embed = dec_embed.expand(dec_embed.shape[0], dec_embed.shape[1], aud_embed.shape[-1])
|
1031 |
+
aud_embed = torch.cat([aud_embed, dec_embed], dim=1)
|
1032 |
+
|
1033 |
+
out = self.aud_dec.forward(aud_embed)
|
1034 |
+
return out
|
1035 |
+
|
1036 |
+
|
1037 |
+
if __name__ == '__main__':
|
1038 |
+
import numpy as np
|
1039 |
+
import os
|
1040 |
+
import sys
|
1041 |
+
|
1042 |
+
test_model = SeqEncoder2D(
|
1043 |
+
C_in=2,
|
1044 |
+
T_in=25,
|
1045 |
+
C_out=512,
|
1046 |
+
num_joints=54,
|
1047 |
+
)
|
1048 |
+
print(test_model.num_layers)
|
1049 |
+
|
1050 |
+
input = torch.randn((64, 108, 25))
|
1051 |
+
output = test_model(input)
|
1052 |
+
print(output.shape)
|
nets/smplx_body_pixel.py
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch.optim.lr_scheduler import StepLR
|
6 |
+
|
7 |
+
sys.path.append(os.getcwd())
|
8 |
+
|
9 |
+
from nets.layers import *
|
10 |
+
from nets.base import TrainWrapperBaseClass
|
11 |
+
from nets.spg.gated_pixelcnn_v2 import GatedPixelCNN as pixelcnn
|
12 |
+
from nets.spg.vqvae_1d import VQVAE as s2g_body, Wav2VecEncoder
|
13 |
+
from nets.spg.vqvae_1d import AudioEncoder
|
14 |
+
from nets.utils import parse_audio, denormalize
|
15 |
+
from data_utils import get_mfcc, get_melspec, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
|
16 |
+
import numpy as np
|
17 |
+
import torch.optim as optim
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from sklearn.preprocessing import normalize
|
20 |
+
|
21 |
+
from data_utils.lower_body import c_index, c_index_3d, c_index_6d
|
22 |
+
from data_utils.utils import smooth_geom, get_mfcc_sepa
|
23 |
+
|
24 |
+
|
25 |
+
class TrainWrapper(TrainWrapperBaseClass):
|
26 |
+
'''
|
27 |
+
a wrapper receving a batch from data_utils and calculate loss
|
28 |
+
'''
|
29 |
+
|
30 |
+
def __init__(self, args, config):
|
31 |
+
self.args = args
|
32 |
+
self.config = config
|
33 |
+
self.device = torch.device(self.args.gpu)
|
34 |
+
self.global_step = 0
|
35 |
+
|
36 |
+
self.convert_to_6d = self.config.Data.pose.convert_to_6d
|
37 |
+
self.expression = self.config.Data.pose.expression
|
38 |
+
self.epoch = 0
|
39 |
+
self.init_params()
|
40 |
+
self.num_classes = 4
|
41 |
+
self.audio = True
|
42 |
+
self.composition = self.config.Model.composition
|
43 |
+
self.bh_model = self.config.Model.bh_model
|
44 |
+
|
45 |
+
if self.audio:
|
46 |
+
self.audioencoder = AudioEncoder(in_dim=64, num_hiddens=256, num_residual_layers=2, num_residual_hiddens=256).to(self.device)
|
47 |
+
else:
|
48 |
+
self.audioencoder = None
|
49 |
+
if self.convert_to_6d:
|
50 |
+
dim, layer = 512, 10
|
51 |
+
else:
|
52 |
+
dim, layer = 256, 15
|
53 |
+
self.generator = pixelcnn(2048, dim, layer, self.num_classes, self.audio, self.bh_model).to(self.device)
|
54 |
+
self.g_body = s2g_body(self.each_dim[1], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
|
55 |
+
num_residual_layers=2, num_residual_hiddens=512).to(self.device)
|
56 |
+
self.g_hand = s2g_body(self.each_dim[2], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
|
57 |
+
num_residual_layers=2, num_residual_hiddens=512).to(self.device)
|
58 |
+
|
59 |
+
model_path = self.config.Model.vq_path
|
60 |
+
model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
|
61 |
+
self.g_body.load_state_dict(model_ckpt['generator']['g_body'])
|
62 |
+
self.g_hand.load_state_dict(model_ckpt['generator']['g_hand'])
|
63 |
+
|
64 |
+
if torch.cuda.device_count() > 1:
|
65 |
+
self.g_body = torch.nn.DataParallel(self.g_body, device_ids=[0, 1])
|
66 |
+
self.g_hand = torch.nn.DataParallel(self.g_hand, device_ids=[0, 1])
|
67 |
+
self.generator = torch.nn.DataParallel(self.generator, device_ids=[0, 1])
|
68 |
+
if self.audioencoder is not None:
|
69 |
+
self.audioencoder = torch.nn.DataParallel(self.audioencoder, device_ids=[0, 1])
|
70 |
+
|
71 |
+
self.discriminator = None
|
72 |
+
if self.convert_to_6d:
|
73 |
+
self.c_index = c_index_6d
|
74 |
+
else:
|
75 |
+
self.c_index = c_index_3d
|
76 |
+
|
77 |
+
super().__init__(args, config)
|
78 |
+
|
79 |
+
def init_optimizer(self):
|
80 |
+
|
81 |
+
print('using Adam')
|
82 |
+
self.generator_optimizer = optim.Adam(
|
83 |
+
self.generator.parameters(),
|
84 |
+
lr=self.config.Train.learning_rate.generator_learning_rate,
|
85 |
+
betas=[0.9, 0.999]
|
86 |
+
)
|
87 |
+
if self.audioencoder is not None:
|
88 |
+
opt = self.config.Model.AudioOpt
|
89 |
+
if opt == 'Adam':
|
90 |
+
self.audioencoder_optimizer = optim.Adam(
|
91 |
+
self.audioencoder.parameters(),
|
92 |
+
lr=self.config.Train.learning_rate.generator_learning_rate,
|
93 |
+
betas=[0.9, 0.999]
|
94 |
+
)
|
95 |
+
else:
|
96 |
+
print('using SGD')
|
97 |
+
self.audioencoder_optimizer = optim.SGD(
|
98 |
+
filter(lambda p: p.requires_grad,self.audioencoder.parameters()),
|
99 |
+
lr=self.config.Train.learning_rate.generator_learning_rate*10,
|
100 |
+
momentum=0.9,
|
101 |
+
nesterov=False,
|
102 |
+
)
|
103 |
+
|
104 |
+
def state_dict(self):
|
105 |
+
model_state = {
|
106 |
+
'generator': self.generator.state_dict(),
|
107 |
+
'generator_optim': self.generator_optimizer.state_dict(),
|
108 |
+
'audioencoder': self.audioencoder.state_dict() if self.audio else None,
|
109 |
+
'audioencoder_optim': self.audioencoder_optimizer.state_dict() if self.audio else None,
|
110 |
+
'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
|
111 |
+
'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
|
112 |
+
}
|
113 |
+
return model_state
|
114 |
+
|
115 |
+
def load_state_dict(self, state_dict):
|
116 |
+
|
117 |
+
from collections import OrderedDict
|
118 |
+
new_state_dict = OrderedDict() # create new OrderedDict that does not contain `module.`
|
119 |
+
for k, v in state_dict.items():
|
120 |
+
sub_dict = OrderedDict()
|
121 |
+
if v is not None:
|
122 |
+
for k1, v1 in v.items():
|
123 |
+
name = k1.replace('module.', '')
|
124 |
+
sub_dict[name] = v1
|
125 |
+
new_state_dict[k] = sub_dict
|
126 |
+
state_dict = new_state_dict
|
127 |
+
if 'generator' in state_dict:
|
128 |
+
self.generator.load_state_dict(state_dict['generator'])
|
129 |
+
else:
|
130 |
+
self.generator.load_state_dict(state_dict)
|
131 |
+
|
132 |
+
if 'generator_optim' in state_dict and self.generator_optimizer is not None:
|
133 |
+
self.generator_optimizer.load_state_dict(state_dict['generator_optim'])
|
134 |
+
|
135 |
+
if self.discriminator is not None:
|
136 |
+
self.discriminator.load_state_dict(state_dict['discriminator'])
|
137 |
+
|
138 |
+
if 'discriminator_optim' in state_dict and self.discriminator_optimizer is not None:
|
139 |
+
self.discriminator_optimizer.load_state_dict(state_dict['discriminator_optim'])
|
140 |
+
|
141 |
+
if 'audioencoder' in state_dict and self.audioencoder is not None:
|
142 |
+
self.audioencoder.load_state_dict(state_dict['audioencoder'])
|
143 |
+
|
144 |
+
def init_params(self):
|
145 |
+
if self.config.Data.pose.convert_to_6d:
|
146 |
+
scale = 2
|
147 |
+
else:
|
148 |
+
scale = 1
|
149 |
+
|
150 |
+
global_orient = round(0 * scale)
|
151 |
+
leye_pose = reye_pose = round(0 * scale)
|
152 |
+
jaw_pose = round(0 * scale)
|
153 |
+
body_pose = round((63 - 24) * scale)
|
154 |
+
left_hand_pose = right_hand_pose = round(45 * scale)
|
155 |
+
if self.expression:
|
156 |
+
expression = 100
|
157 |
+
else:
|
158 |
+
expression = 0
|
159 |
+
|
160 |
+
b_j = 0
|
161 |
+
jaw_dim = jaw_pose
|
162 |
+
b_e = b_j + jaw_dim
|
163 |
+
eye_dim = leye_pose + reye_pose
|
164 |
+
b_b = b_e + eye_dim
|
165 |
+
body_dim = global_orient + body_pose
|
166 |
+
b_h = b_b + body_dim
|
167 |
+
hand_dim = left_hand_pose + right_hand_pose
|
168 |
+
b_f = b_h + hand_dim
|
169 |
+
face_dim = expression
|
170 |
+
|
171 |
+
self.dim_list = [b_j, b_e, b_b, b_h, b_f]
|
172 |
+
self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
|
173 |
+
self.pose = int(self.full_dim / round(3 * scale))
|
174 |
+
self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
|
175 |
+
|
176 |
+
def __call__(self, bat):
|
177 |
+
# assert (not self.args.infer), "infer mode"
|
178 |
+
self.global_step += 1
|
179 |
+
|
180 |
+
total_loss = None
|
181 |
+
loss_dict = {}
|
182 |
+
|
183 |
+
aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
|
184 |
+
|
185 |
+
id = bat['speaker'].to(self.device) - 20
|
186 |
+
# id = F.one_hot(id, self.num_classes)
|
187 |
+
|
188 |
+
poses = poses[:, self.c_index, :]
|
189 |
+
|
190 |
+
aud = aud.permute(0, 2, 1)
|
191 |
+
gt_poses = poses.permute(0, 2, 1)
|
192 |
+
|
193 |
+
with torch.no_grad():
|
194 |
+
self.g_body.eval()
|
195 |
+
self.g_hand.eval()
|
196 |
+
if torch.cuda.device_count() > 1:
|
197 |
+
_, body_latents = self.g_body.module.encode(gt_poses=gt_poses[..., :self.each_dim[1]], id=id)
|
198 |
+
_, hand_latents = self.g_hand.module.encode(gt_poses=gt_poses[..., self.each_dim[1]:], id=id)
|
199 |
+
else:
|
200 |
+
_, body_latents = self.g_body.encode(gt_poses=gt_poses[..., :self.each_dim[1]], id=id)
|
201 |
+
_, hand_latents = self.g_hand.encode(gt_poses=gt_poses[..., self.each_dim[1]:], id=id)
|
202 |
+
latents = torch.cat([body_latents.unsqueeze(dim=-1), hand_latents.unsqueeze(dim=-1)], dim=-1)
|
203 |
+
latents = latents.detach()
|
204 |
+
|
205 |
+
if self.audio:
|
206 |
+
audio = self.audioencoder(aud[:, :].transpose(1, 2), frame_num=latents.shape[1]*4).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
|
207 |
+
logits = self.generator(latents[:, :], id, audio)
|
208 |
+
else:
|
209 |
+
logits = self.generator(latents, id)
|
210 |
+
logits = logits.permute(0, 2, 3, 1).contiguous()
|
211 |
+
|
212 |
+
self.generator_optimizer.zero_grad()
|
213 |
+
if self.audio:
|
214 |
+
self.audioencoder_optimizer.zero_grad()
|
215 |
+
|
216 |
+
loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), latents.view(-1))
|
217 |
+
loss.backward()
|
218 |
+
|
219 |
+
grad = torch.nn.utils.clip_grad_norm(self.generator.parameters(), self.config.Train.max_gradient_norm)
|
220 |
+
|
221 |
+
if torch.isnan(grad).sum() > 0:
|
222 |
+
print('fuck')
|
223 |
+
|
224 |
+
loss_dict['grad'] = grad.item()
|
225 |
+
loss_dict['ce_loss'] = loss.item()
|
226 |
+
self.generator_optimizer.step()
|
227 |
+
if self.audio:
|
228 |
+
self.audioencoder_optimizer.step()
|
229 |
+
|
230 |
+
return total_loss, loss_dict
|
231 |
+
|
232 |
+
def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, exp=None, var=None, w_pre=False, rand=None,
|
233 |
+
continuity=False, id=None, fps=15, sr=22000, B=1, am=None, am_sr=None, frame=0,**kwargs):
|
234 |
+
'''
|
235 |
+
initial_pose: (B, C, T), normalized
|
236 |
+
(aud_fn, txgfile) -> generated motion (B, T, C)
|
237 |
+
'''
|
238 |
+
output = []
|
239 |
+
|
240 |
+
assert self.args.infer, "train mode"
|
241 |
+
self.generator.eval()
|
242 |
+
self.g_body.eval()
|
243 |
+
self.g_hand.eval()
|
244 |
+
|
245 |
+
if continuity:
|
246 |
+
aud_feat, gap = get_mfcc_sepa(aud_fn, sr=sr, fps=fps)
|
247 |
+
else:
|
248 |
+
aud_feat = get_mfcc_ta(aud_fn, sr=sr, fps=fps, smlpx=True, type='mfcc', am=am)
|
249 |
+
aud_feat = aud_feat.transpose(1, 0)
|
250 |
+
aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
|
251 |
+
aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)
|
252 |
+
|
253 |
+
if id is None:
|
254 |
+
id = torch.tensor([0]).to(self.device)
|
255 |
+
else:
|
256 |
+
id = id.repeat(B)
|
257 |
+
|
258 |
+
with torch.no_grad():
|
259 |
+
aud_feat = aud_feat.permute(0, 2, 1)
|
260 |
+
if continuity:
|
261 |
+
self.audioencoder.eval()
|
262 |
+
pre_pose = {}
|
263 |
+
pre_pose['b'] = pre_pose['h'] = None
|
264 |
+
pre_latents, pre_audio, body_0, hand_0 = self.infer(aud_feat[:, :gap], frame, id, B, pre_pose=pre_pose)
|
265 |
+
pre_pose['b'] = body_0[:, :, -4:].transpose(1,2)
|
266 |
+
pre_pose['h'] = hand_0[:, :, -4:].transpose(1,2)
|
267 |
+
_, _, body_1, hand_1 = self.infer(aud_feat[:, gap:], frame, id, B, pre_latents, pre_audio, pre_pose)
|
268 |
+
body = torch.cat([body_0, body_1], dim=2)
|
269 |
+
hand = torch.cat([hand_0, hand_1], dim=2)
|
270 |
+
|
271 |
+
else:
|
272 |
+
if self.audio:
|
273 |
+
self.audioencoder.eval()
|
274 |
+
audio = self.audioencoder(aud_feat.transpose(1, 2), frame_num=frame).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
|
275 |
+
latents = self.generator.generate(id, shape=[audio.shape[2], 2], batch_size=B, aud_feat=audio)
|
276 |
+
else:
|
277 |
+
latents = self.generator.generate(id, shape=[aud_feat.shape[1]//4, 2], batch_size=B)
|
278 |
+
|
279 |
+
body_latents = latents[..., 0]
|
280 |
+
hand_latents = latents[..., 1]
|
281 |
+
|
282 |
+
body, _ = self.g_body.decode(b=body_latents.shape[0], w=body_latents.shape[1], latents=body_latents)
|
283 |
+
hand, _ = self.g_hand.decode(b=hand_latents.shape[0], w=hand_latents.shape[1], latents=hand_latents)
|
284 |
+
|
285 |
+
pred_poses = torch.cat([body, hand], dim=1).transpose(1,2).cpu().numpy()
|
286 |
+
|
287 |
+
output = pred_poses
|
288 |
+
|
289 |
+
return output
|
290 |
+
|
291 |
+
def infer(self, aud_feat, frame, id, B, pre_latents=None, pre_audio=None, pre_pose=None):
|
292 |
+
audio = self.audioencoder(aud_feat.transpose(1, 2), frame_num=frame).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
|
293 |
+
latents = self.generator.generate(id, shape=[audio.shape[2], 2], batch_size=B, aud_feat=audio,
|
294 |
+
pre_latents=pre_latents, pre_audio=pre_audio)
|
295 |
+
|
296 |
+
body_latents = latents[..., 0]
|
297 |
+
hand_latents = latents[..., 1]
|
298 |
+
|
299 |
+
body, _ = self.g_body.decode(b=body_latents.shape[0], w=body_latents.shape[1],
|
300 |
+
latents=body_latents, pre_state=pre_pose['b'])
|
301 |
+
hand, _ = self.g_hand.decode(b=hand_latents.shape[0], w=hand_latents.shape[1],
|
302 |
+
latents=hand_latents, pre_state=pre_pose['h'])
|
303 |
+
|
304 |
+
return latents, audio, body, hand
|
305 |
+
|
306 |
+
def generate(self, aud, id, frame_num=0):
|
307 |
+
|
308 |
+
self.generator.eval()
|
309 |
+
self.g_body.eval()
|
310 |
+
self.g_hand.eval()
|
311 |
+
aud_feat = aud.permute(0, 2, 1)
|
312 |
+
if self.audio:
|
313 |
+
self.audioencoder.eval()
|
314 |
+
audio = self.audioencoder(aud_feat.transpose(1, 2), frame_num=frame_num).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
|
315 |
+
latents = self.generator.generate(id, shape=[audio.shape[2], 2], batch_size=aud.shape[0], aud_feat=audio)
|
316 |
+
else:
|
317 |
+
latents = self.generator.generate(id, shape=[aud_feat.shape[1] // 4, 2], batch_size=aud.shape[0])
|
318 |
+
|
319 |
+
body_latents = latents[..., 0]
|
320 |
+
hand_latents = latents[..., 1]
|
321 |
+
|
322 |
+
body = self.g_body.decode(b=body_latents.shape[0], w=body_latents.shape[1], latents=body_latents)
|
323 |
+
hand = self.g_hand.decode(b=hand_latents.shape[0], w=hand_latents.shape[1], latents=hand_latents)
|
324 |
+
|
325 |
+
pred_poses = torch.cat([body, hand], dim=1).transpose(1, 2)
|
326 |
+
return pred_poses
|
nets/smplx_body_vq.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
from torch.optim.lr_scheduler import StepLR
|
5 |
+
|
6 |
+
sys.path.append(os.getcwd())
|
7 |
+
|
8 |
+
from nets.layers import *
|
9 |
+
from nets.base import TrainWrapperBaseClass
|
10 |
+
from nets.spg.s2glayers import Generator as G_S2G, Discriminator as D_S2G
|
11 |
+
from nets.spg.vqvae_1d import VQVAE as s2g_body
|
12 |
+
from nets.utils import parse_audio, denormalize
|
13 |
+
from data_utils import get_mfcc, get_melspec, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
|
14 |
+
import numpy as np
|
15 |
+
import torch.optim as optim
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from sklearn.preprocessing import normalize
|
18 |
+
|
19 |
+
from data_utils.lower_body import c_index, c_index_3d, c_index_6d
|
20 |
+
|
21 |
+
|
22 |
+
class TrainWrapper(TrainWrapperBaseClass):
|
23 |
+
'''
|
24 |
+
a wrapper receving a batch from data_utils and calculate loss
|
25 |
+
'''
|
26 |
+
|
27 |
+
def __init__(self, args, config):
|
28 |
+
self.args = args
|
29 |
+
self.config = config
|
30 |
+
self.device = torch.device(self.args.gpu)
|
31 |
+
self.global_step = 0
|
32 |
+
|
33 |
+
self.convert_to_6d = self.config.Data.pose.convert_to_6d
|
34 |
+
self.expression = self.config.Data.pose.expression
|
35 |
+
self.epoch = 0
|
36 |
+
self.init_params()
|
37 |
+
self.num_classes = 4
|
38 |
+
self.composition = self.config.Model.composition
|
39 |
+
if self.composition:
|
40 |
+
self.g_body = s2g_body(self.each_dim[1], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
|
41 |
+
num_residual_layers=2, num_residual_hiddens=512).to(self.device)
|
42 |
+
self.g_hand = s2g_body(self.each_dim[2], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
|
43 |
+
num_residual_layers=2, num_residual_hiddens=512).to(self.device)
|
44 |
+
else:
|
45 |
+
self.g = s2g_body(self.each_dim[1] + self.each_dim[2], embedding_dim=64, num_embeddings=config.Model.code_num,
|
46 |
+
num_hiddens=1024, num_residual_layers=2, num_residual_hiddens=512).to(self.device)
|
47 |
+
|
48 |
+
self.discriminator = None
|
49 |
+
|
50 |
+
if self.convert_to_6d:
|
51 |
+
self.c_index = c_index_6d
|
52 |
+
else:
|
53 |
+
self.c_index = c_index_3d
|
54 |
+
|
55 |
+
super().__init__(args, config)
|
56 |
+
|
57 |
+
def init_optimizer(self):
|
58 |
+
print('using Adam')
|
59 |
+
if self.composition:
|
60 |
+
self.g_body_optimizer = optim.Adam(
|
61 |
+
self.g_body.parameters(),
|
62 |
+
lr=self.config.Train.learning_rate.generator_learning_rate,
|
63 |
+
betas=[0.9, 0.999]
|
64 |
+
)
|
65 |
+
self.g_hand_optimizer = optim.Adam(
|
66 |
+
self.g_hand.parameters(),
|
67 |
+
lr=self.config.Train.learning_rate.generator_learning_rate,
|
68 |
+
betas=[0.9, 0.999]
|
69 |
+
)
|
70 |
+
else:
|
71 |
+
self.g_optimizer = optim.Adam(
|
72 |
+
self.g.parameters(),
|
73 |
+
lr=self.config.Train.learning_rate.generator_learning_rate,
|
74 |
+
betas=[0.9, 0.999]
|
75 |
+
)
|
76 |
+
|
77 |
+
def state_dict(self):
|
78 |
+
if self.composition:
|
79 |
+
model_state = {
|
80 |
+
'g_body': self.g_body.state_dict(),
|
81 |
+
'g_body_optim': self.g_body_optimizer.state_dict(),
|
82 |
+
'g_hand': self.g_hand.state_dict(),
|
83 |
+
'g_hand_optim': self.g_hand_optimizer.state_dict(),
|
84 |
+
'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
|
85 |
+
'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
|
86 |
+
}
|
87 |
+
else:
|
88 |
+
model_state = {
|
89 |
+
'g': self.g.state_dict(),
|
90 |
+
'g_optim': self.g_optimizer.state_dict(),
|
91 |
+
'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
|
92 |
+
'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
|
93 |
+
}
|
94 |
+
return model_state
|
95 |
+
|
96 |
+
def init_params(self):
|
97 |
+
if self.config.Data.pose.convert_to_6d:
|
98 |
+
scale = 2
|
99 |
+
else:
|
100 |
+
scale = 1
|
101 |
+
|
102 |
+
global_orient = round(0 * scale)
|
103 |
+
leye_pose = reye_pose = round(0 * scale)
|
104 |
+
jaw_pose = round(0 * scale)
|
105 |
+
body_pose = round((63 - 24) * scale)
|
106 |
+
left_hand_pose = right_hand_pose = round(45 * scale)
|
107 |
+
if self.expression:
|
108 |
+
expression = 100
|
109 |
+
else:
|
110 |
+
expression = 0
|
111 |
+
|
112 |
+
b_j = 0
|
113 |
+
jaw_dim = jaw_pose
|
114 |
+
b_e = b_j + jaw_dim
|
115 |
+
eye_dim = leye_pose + reye_pose
|
116 |
+
b_b = b_e + eye_dim
|
117 |
+
body_dim = global_orient + body_pose
|
118 |
+
b_h = b_b + body_dim
|
119 |
+
hand_dim = left_hand_pose + right_hand_pose
|
120 |
+
b_f = b_h + hand_dim
|
121 |
+
face_dim = expression
|
122 |
+
|
123 |
+
self.dim_list = [b_j, b_e, b_b, b_h, b_f]
|
124 |
+
self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
|
125 |
+
self.pose = int(self.full_dim / round(3 * scale))
|
126 |
+
self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
|
127 |
+
|
128 |
+
def __call__(self, bat):
|
129 |
+
# assert (not self.args.infer), "infer mode"
|
130 |
+
self.global_step += 1
|
131 |
+
|
132 |
+
total_loss = None
|
133 |
+
loss_dict = {}
|
134 |
+
|
135 |
+
aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
|
136 |
+
|
137 |
+
# id = bat['speaker'].to(self.device) - 20
|
138 |
+
# id = F.one_hot(id, self.num_classes)
|
139 |
+
|
140 |
+
poses = poses[:, self.c_index, :]
|
141 |
+
gt_poses = poses.permute(0, 2, 1)
|
142 |
+
b_poses = gt_poses[..., :self.each_dim[1]]
|
143 |
+
h_poses = gt_poses[..., self.each_dim[1]:]
|
144 |
+
|
145 |
+
if self.composition:
|
146 |
+
loss = 0
|
147 |
+
loss_dict, loss = self.vq_train(b_poses[:, :], 'b', self.g_body, loss_dict, loss)
|
148 |
+
loss_dict, loss = self.vq_train(h_poses[:, :], 'h', self.g_hand, loss_dict, loss)
|
149 |
+
else:
|
150 |
+
loss = 0
|
151 |
+
loss_dict, loss = self.vq_train(gt_poses[:, :], 'g', self.g, loss_dict, loss)
|
152 |
+
|
153 |
+
return total_loss, loss_dict
|
154 |
+
|
155 |
+
def vq_train(self, gt, name, model, dict, total_loss, pre=None):
|
156 |
+
e_q_loss, x_recon = model(gt_poses=gt, pre_state=pre)
|
157 |
+
loss, loss_dict = self.get_loss(pred_poses=x_recon, gt_poses=gt, e_q_loss=e_q_loss, pre=pre)
|
158 |
+
# total_loss = total_loss + loss
|
159 |
+
|
160 |
+
if name == 'b':
|
161 |
+
optimizer_name = 'g_body_optimizer'
|
162 |
+
elif name == 'h':
|
163 |
+
optimizer_name = 'g_hand_optimizer'
|
164 |
+
elif name == 'g':
|
165 |
+
optimizer_name = 'g_optimizer'
|
166 |
+
else:
|
167 |
+
raise ValueError("model's name must be b or h")
|
168 |
+
optimizer = getattr(self, optimizer_name)
|
169 |
+
optimizer.zero_grad()
|
170 |
+
loss.backward()
|
171 |
+
optimizer.step()
|
172 |
+
|
173 |
+
for key in list(loss_dict.keys()):
|
174 |
+
dict[name + key] = loss_dict.get(key, 0).item()
|
175 |
+
return dict, total_loss
|
176 |
+
|
177 |
+
def get_loss(self,
|
178 |
+
pred_poses,
|
179 |
+
gt_poses,
|
180 |
+
e_q_loss,
|
181 |
+
pre=None
|
182 |
+
):
|
183 |
+
loss_dict = {}
|
184 |
+
|
185 |
+
|
186 |
+
rec_loss = torch.mean(torch.abs(pred_poses - gt_poses))
|
187 |
+
v_pr = pred_poses[:, 1:] - pred_poses[:, :-1]
|
188 |
+
v_gt = gt_poses[:, 1:] - gt_poses[:, :-1]
|
189 |
+
velocity_loss = torch.mean(torch.abs(v_pr - v_gt))
|
190 |
+
|
191 |
+
if pre is None:
|
192 |
+
f0_vel = 0
|
193 |
+
else:
|
194 |
+
v0_pr = pred_poses[:, 0] - pre[:, -1]
|
195 |
+
v0_gt = gt_poses[:, 0] - pre[:, -1]
|
196 |
+
f0_vel = torch.mean(torch.abs(v0_pr - v0_gt))
|
197 |
+
|
198 |
+
gen_loss = rec_loss + e_q_loss + velocity_loss + f0_vel
|
199 |
+
|
200 |
+
loss_dict['rec_loss'] = rec_loss
|
201 |
+
loss_dict['velocity_loss'] = velocity_loss
|
202 |
+
# loss_dict['e_q_loss'] = e_q_loss
|
203 |
+
if pre is not None:
|
204 |
+
loss_dict['f0_vel'] = f0_vel
|
205 |
+
|
206 |
+
return gen_loss, loss_dict
|
207 |
+
|
208 |
+
def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, exp=None, var=None, w_pre=False, continuity=False,
|
209 |
+
id=None, fps=15, sr=22000, smooth=False, **kwargs):
|
210 |
+
'''
|
211 |
+
initial_pose: (B, C, T), normalized
|
212 |
+
(aud_fn, txgfile) -> generated motion (B, T, C)
|
213 |
+
'''
|
214 |
+
output = []
|
215 |
+
|
216 |
+
assert self.args.infer, "train mode"
|
217 |
+
if self.composition:
|
218 |
+
self.g_body.eval()
|
219 |
+
self.g_hand.eval()
|
220 |
+
else:
|
221 |
+
self.g.eval()
|
222 |
+
|
223 |
+
if self.config.Data.pose.normalization:
|
224 |
+
assert norm_stats is not None
|
225 |
+
data_mean = norm_stats[0]
|
226 |
+
data_std = norm_stats[1]
|
227 |
+
|
228 |
+
# assert initial_pose.shape[-1] == pre_length
|
229 |
+
if initial_pose is not None:
|
230 |
+
gt = initial_pose[:, :, :].to(self.device).to(torch.float32)
|
231 |
+
pre_poses = initial_pose[:, :, :15].permute(0, 2, 1).to(self.device).to(torch.float32)
|
232 |
+
poses = initial_pose.permute(0, 2, 1).to(self.device).to(torch.float32)
|
233 |
+
B = pre_poses.shape[0]
|
234 |
+
else:
|
235 |
+
gt = None
|
236 |
+
pre_poses = None
|
237 |
+
B = 1
|
238 |
+
|
239 |
+
if type(aud_fn) == torch.Tensor:
|
240 |
+
aud_feat = torch.tensor(aud_fn, dtype=torch.float32).to(self.device)
|
241 |
+
num_poses_to_generate = aud_feat.shape[-1]
|
242 |
+
else:
|
243 |
+
aud_feat = get_mfcc_ta(aud_fn, sr=sr, fps=fps, smlpx=True, type='mfcc').transpose(1, 0)
|
244 |
+
aud_feat = aud_feat[:, :]
|
245 |
+
num_poses_to_generate = aud_feat.shape[-1]
|
246 |
+
aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
|
247 |
+
aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)
|
248 |
+
|
249 |
+
# pre_poses = torch.randn(pre_poses.shape).to(self.device).to(torch.float32)
|
250 |
+
if id is None:
|
251 |
+
id = F.one_hot(torch.tensor([[0]]), self.num_classes).to(self.device)
|
252 |
+
|
253 |
+
with torch.no_grad():
|
254 |
+
aud_feat = aud_feat.permute(0, 2, 1)
|
255 |
+
gt_poses = gt[:, self.c_index].permute(0, 2, 1)
|
256 |
+
if self.composition:
|
257 |
+
if continuity:
|
258 |
+
pred_poses_body = []
|
259 |
+
pred_poses_hand = []
|
260 |
+
pre_b = None
|
261 |
+
pre_h = None
|
262 |
+
for i in range(5):
|
263 |
+
_, pred_body = self.g_body(gt_poses=gt_poses[:, i*60:(i+1)*60, :self.each_dim[1]], pre_state=pre_b)
|
264 |
+
pre_b = pred_body[..., -1:].transpose(1,2)
|
265 |
+
pred_poses_body.append(pred_body)
|
266 |
+
_, pred_hand = self.g_hand(gt_poses=gt_poses[:, i*60:(i+1)*60, self.each_dim[1]:], pre_state=pre_h)
|
267 |
+
pre_h = pred_hand[..., -1:].transpose(1,2)
|
268 |
+
pred_poses_hand.append(pred_hand)
|
269 |
+
|
270 |
+
pred_poses_body = torch.cat(pred_poses_body, dim=2)
|
271 |
+
pred_poses_hand = torch.cat(pred_poses_hand, dim=2)
|
272 |
+
else:
|
273 |
+
_, pred_poses_body = self.g_body(gt_poses=gt_poses[..., :self.each_dim[1]], id=id)
|
274 |
+
_, pred_poses_hand = self.g_hand(gt_poses=gt_poses[..., self.each_dim[1]:], id=id)
|
275 |
+
pred_poses = torch.cat([pred_poses_body, pred_poses_hand], dim=1)
|
276 |
+
else:
|
277 |
+
_, pred_poses = self.g(gt_poses=gt_poses, id=id)
|
278 |
+
pred_poses = pred_poses.transpose(1, 2).cpu().numpy()
|
279 |
+
output = pred_poses
|
280 |
+
|
281 |
+
if self.config.Data.pose.normalization:
|
282 |
+
output = denormalize(output, data_mean, data_std)
|
283 |
+
|
284 |
+
if smooth:
|
285 |
+
lamda = 0.8
|
286 |
+
smooth_f = 10
|
287 |
+
frame = 149
|
288 |
+
for i in range(smooth_f):
|
289 |
+
f = frame + i
|
290 |
+
l = lamda * (i + 1) / smooth_f
|
291 |
+
output[0, f] = (1 - l) * output[0, f - 1] + l * output[0, f]
|
292 |
+
|
293 |
+
output = np.concatenate(output, axis=1)
|
294 |
+
|
295 |
+
return output
|
296 |
+
|
297 |
+
def load_state_dict(self, state_dict):
|
298 |
+
if self.composition:
|
299 |
+
self.g_body.load_state_dict(state_dict['g_body'])
|
300 |
+
self.g_hand.load_state_dict(state_dict['g_hand'])
|
301 |
+
else:
|
302 |
+
self.g.load_state_dict(state_dict['g'])
|
nets/smplx_face.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
sys.path.append(os.getcwd())
|
5 |
+
|
6 |
+
from nets.layers import *
|
7 |
+
from nets.base import TrainWrapperBaseClass
|
8 |
+
# from nets.spg.faceformer import Faceformer
|
9 |
+
from nets.spg.s2g_face import Generator as s2g_face
|
10 |
+
from losses import KeypointLoss
|
11 |
+
from nets.utils import denormalize
|
12 |
+
from data_utils import get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
|
13 |
+
import numpy as np
|
14 |
+
import torch.optim as optim
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from sklearn.preprocessing import normalize
|
17 |
+
import smplx
|
18 |
+
|
19 |
+
|
20 |
+
class TrainWrapper(TrainWrapperBaseClass):
|
21 |
+
'''
|
22 |
+
a wrapper receving a batch from data_utils and calculate loss
|
23 |
+
'''
|
24 |
+
|
25 |
+
def __init__(self, args, config):
|
26 |
+
self.args = args
|
27 |
+
self.config = config
|
28 |
+
self.device = torch.device(self.args.gpu)
|
29 |
+
self.global_step = 0
|
30 |
+
|
31 |
+
self.convert_to_6d = self.config.Data.pose.convert_to_6d
|
32 |
+
self.expression = self.config.Data.pose.expression
|
33 |
+
self.epoch = 0
|
34 |
+
self.init_params()
|
35 |
+
self.num_classes = 4
|
36 |
+
|
37 |
+
self.generator = s2g_face(
|
38 |
+
n_poses=self.config.Data.pose.generate_length,
|
39 |
+
each_dim=self.each_dim,
|
40 |
+
dim_list=self.dim_list,
|
41 |
+
training=not self.args.infer,
|
42 |
+
device=self.device,
|
43 |
+
identity=False if self.convert_to_6d else True,
|
44 |
+
num_classes=self.num_classes,
|
45 |
+
).to(self.device)
|
46 |
+
|
47 |
+
# self.generator = Faceformer().to(self.device)
|
48 |
+
|
49 |
+
self.discriminator = None
|
50 |
+
self.am = None
|
51 |
+
|
52 |
+
self.MSELoss = KeypointLoss().to(self.device)
|
53 |
+
super().__init__(args, config)
|
54 |
+
|
55 |
+
def init_optimizer(self):
|
56 |
+
self.generator_optimizer = optim.SGD(
|
57 |
+
filter(lambda p: p.requires_grad,self.generator.parameters()),
|
58 |
+
lr=0.001,
|
59 |
+
momentum=0.9,
|
60 |
+
nesterov=False,
|
61 |
+
)
|
62 |
+
|
63 |
+
def init_params(self):
|
64 |
+
if self.convert_to_6d:
|
65 |
+
scale = 2
|
66 |
+
else:
|
67 |
+
scale = 1
|
68 |
+
|
69 |
+
global_orient = round(3 * scale)
|
70 |
+
leye_pose = reye_pose = round(3 * scale)
|
71 |
+
jaw_pose = round(3 * scale)
|
72 |
+
body_pose = round(63 * scale)
|
73 |
+
left_hand_pose = right_hand_pose = round(45 * scale)
|
74 |
+
if self.expression:
|
75 |
+
expression = 100
|
76 |
+
else:
|
77 |
+
expression = 0
|
78 |
+
|
79 |
+
b_j = 0
|
80 |
+
jaw_dim = jaw_pose
|
81 |
+
b_e = b_j + jaw_dim
|
82 |
+
eye_dim = leye_pose + reye_pose
|
83 |
+
b_b = b_e + eye_dim
|
84 |
+
body_dim = global_orient + body_pose
|
85 |
+
b_h = b_b + body_dim
|
86 |
+
hand_dim = left_hand_pose + right_hand_pose
|
87 |
+
b_f = b_h + hand_dim
|
88 |
+
face_dim = expression
|
89 |
+
|
90 |
+
self.dim_list = [b_j, b_e, b_b, b_h, b_f]
|
91 |
+
self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim + face_dim
|
92 |
+
self.pose = int(self.full_dim / round(3 * scale))
|
93 |
+
self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
|
94 |
+
|
95 |
+
def __call__(self, bat):
|
96 |
+
# assert (not self.args.infer), "infer mode"
|
97 |
+
self.global_step += 1
|
98 |
+
|
99 |
+
total_loss = None
|
100 |
+
loss_dict = {}
|
101 |
+
|
102 |
+
aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
|
103 |
+
id = bat['speaker'].to(self.device) - 20
|
104 |
+
id = F.one_hot(id, self.num_classes)
|
105 |
+
|
106 |
+
aud = aud.permute(0, 2, 1)
|
107 |
+
gt_poses = poses.permute(0, 2, 1)
|
108 |
+
|
109 |
+
if self.expression:
|
110 |
+
expression = bat['expression'].to(self.device).to(torch.float32)
|
111 |
+
gt_poses = torch.cat([gt_poses, expression.permute(0, 2, 1)], dim=2)
|
112 |
+
|
113 |
+
pred_poses, _ = self.generator(
|
114 |
+
aud,
|
115 |
+
gt_poses,
|
116 |
+
id,
|
117 |
+
)
|
118 |
+
|
119 |
+
G_loss, G_loss_dict = self.get_loss(
|
120 |
+
pred_poses=pred_poses,
|
121 |
+
gt_poses=gt_poses,
|
122 |
+
pre_poses=None,
|
123 |
+
mode='training_G',
|
124 |
+
gt_conf=None,
|
125 |
+
aud=aud,
|
126 |
+
)
|
127 |
+
|
128 |
+
self.generator_optimizer.zero_grad()
|
129 |
+
G_loss.backward()
|
130 |
+
grad = torch.nn.utils.clip_grad_norm(self.generator.parameters(), self.config.Train.max_gradient_norm)
|
131 |
+
loss_dict['grad'] = grad.item()
|
132 |
+
self.generator_optimizer.step()
|
133 |
+
|
134 |
+
for key in list(G_loss_dict.keys()):
|
135 |
+
loss_dict[key] = G_loss_dict.get(key, 0).item()
|
136 |
+
|
137 |
+
return total_loss, loss_dict
|
138 |
+
|
139 |
+
def get_loss(self,
|
140 |
+
pred_poses,
|
141 |
+
gt_poses,
|
142 |
+
pre_poses,
|
143 |
+
aud,
|
144 |
+
mode='training_G',
|
145 |
+
gt_conf=None,
|
146 |
+
exp=1,
|
147 |
+
gt_nzero=None,
|
148 |
+
pre_nzero=None,
|
149 |
+
):
|
150 |
+
loss_dict = {}
|
151 |
+
|
152 |
+
|
153 |
+
[b_j, b_e, b_b, b_h, b_f] = self.dim_list
|
154 |
+
|
155 |
+
MSELoss = torch.mean(torch.abs(pred_poses[:, :, :6] - gt_poses[:, :, :6]))
|
156 |
+
if self.expression:
|
157 |
+
expl = torch.mean((pred_poses[:, :, -100:] - gt_poses[:, :, -100:])**2)
|
158 |
+
else:
|
159 |
+
expl = 0
|
160 |
+
|
161 |
+
gen_loss = expl + MSELoss
|
162 |
+
|
163 |
+
loss_dict['MSELoss'] = MSELoss
|
164 |
+
if self.expression:
|
165 |
+
loss_dict['exp_loss'] = expl
|
166 |
+
|
167 |
+
return gen_loss, loss_dict
|
168 |
+
|
169 |
+
def infer_on_audio(self, aud_fn, id=None, initial_pose=None, norm_stats=None, w_pre=False, frame=None, am=None, am_sr=16000, **kwargs):
|
170 |
+
'''
|
171 |
+
initial_pose: (B, C, T), normalized
|
172 |
+
(aud_fn, txgfile) -> generated motion (B, T, C)
|
173 |
+
'''
|
174 |
+
output = []
|
175 |
+
|
176 |
+
# assert self.args.infer, "train mode"
|
177 |
+
self.generator.eval()
|
178 |
+
|
179 |
+
if self.config.Data.pose.normalization:
|
180 |
+
assert norm_stats is not None
|
181 |
+
data_mean = norm_stats[0]
|
182 |
+
data_std = norm_stats[1]
|
183 |
+
|
184 |
+
# assert initial_pose.shape[-1] == pre_length
|
185 |
+
if initial_pose is not None:
|
186 |
+
gt = initial_pose[:,:,:].permute(0, 2, 1).to(self.generator.device).to(torch.float32)
|
187 |
+
pre_poses = initial_pose[:,:,:15].permute(0, 2, 1).to(self.generator.device).to(torch.float32)
|
188 |
+
poses = initial_pose.permute(0, 2, 1).to(self.generator.device).to(torch.float32)
|
189 |
+
B = pre_poses.shape[0]
|
190 |
+
else:
|
191 |
+
gt = None
|
192 |
+
pre_poses=None
|
193 |
+
B = 1
|
194 |
+
|
195 |
+
if type(aud_fn) == torch.Tensor:
|
196 |
+
aud_feat = torch.tensor(aud_fn, dtype=torch.float32).to(self.generator.device)
|
197 |
+
num_poses_to_generate = aud_feat.shape[-1]
|
198 |
+
else:
|
199 |
+
aud_feat = get_mfcc_ta(aud_fn, am=am, am_sr=am_sr, fps=30, encoder_choice='faceformer')
|
200 |
+
aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
|
201 |
+
aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.generator.device).transpose(1, 2)
|
202 |
+
if frame is None:
|
203 |
+
frame = aud_feat.shape[2]*30//16000
|
204 |
+
#
|
205 |
+
if id is None:
|
206 |
+
id = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32, device=self.generator.device)
|
207 |
+
else:
|
208 |
+
id = F.one_hot(id, self.num_classes).to(self.generator.device)
|
209 |
+
|
210 |
+
with torch.no_grad():
|
211 |
+
pred_poses = self.generator(aud_feat, pre_poses, id, time_steps=frame)[0]
|
212 |
+
pred_poses = pred_poses.cpu().numpy()
|
213 |
+
output = pred_poses
|
214 |
+
|
215 |
+
if self.config.Data.pose.normalization:
|
216 |
+
output = denormalize(output, data_mean, data_std)
|
217 |
+
|
218 |
+
return output
|
219 |
+
|
220 |
+
|
221 |
+
def generate(self, wv2_feat, frame):
|
222 |
+
'''
|
223 |
+
initial_pose: (B, C, T), normalized
|
224 |
+
(aud_fn, txgfile) -> generated motion (B, T, C)
|
225 |
+
'''
|
226 |
+
output = []
|
227 |
+
|
228 |
+
# assert self.args.infer, "train mode"
|
229 |
+
self.generator.eval()
|
230 |
+
|
231 |
+
B = 1
|
232 |
+
|
233 |
+
id = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32, device=self.generator.device)
|
234 |
+
id = id.repeat(wv2_feat.shape[0], 1)
|
235 |
+
|
236 |
+
with torch.no_grad():
|
237 |
+
pred_poses = self.generator(wv2_feat, None, id, time_steps=frame)[0]
|
238 |
+
return pred_poses
|
nets/spg/__pycache__/gated_pixelcnn_v2.cpython-37.pyc
ADDED
Binary file (5.16 kB). View file
|
|
nets/spg/__pycache__/s2g_face.cpython-37.pyc
ADDED
Binary file (6.96 kB). View file
|
|
nets/spg/__pycache__/s2glayers.cpython-37.pyc
ADDED
Binary file (11.5 kB). View file
|
|
nets/spg/__pycache__/vqvae_1d.cpython-37.pyc
ADDED
Binary file (8.07 kB). View file
|
|
nets/spg/__pycache__/vqvae_modules.cpython-37.pyc
ADDED
Binary file (10.6 kB). View file
|
|
nets/spg/__pycache__/wav2vec.cpython-37.pyc
ADDED
Binary file (3.89 kB). View file
|
|
nets/spg/gated_pixelcnn_v2.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
def weights_init(m):
|
7 |
+
classname = m.__class__.__name__
|
8 |
+
if classname.find('Conv') != -1:
|
9 |
+
try:
|
10 |
+
nn.init.xavier_uniform_(m.weight.data)
|
11 |
+
m.bias.data.fill_(0)
|
12 |
+
except AttributeError:
|
13 |
+
print("Skipping initialization of ", classname)
|
14 |
+
|
15 |
+
|
16 |
+
class GatedActivation(nn.Module):
|
17 |
+
def __init__(self):
|
18 |
+
super().__init__()
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
x, y = x.chunk(2, dim=1)
|
22 |
+
return F.tanh(x) * F.sigmoid(y)
|
23 |
+
|
24 |
+
|
25 |
+
class GatedMaskedConv2d(nn.Module):
|
26 |
+
def __init__(self, mask_type, dim, kernel, residual=True, n_classes=10, bh_model=False):
|
27 |
+
super().__init__()
|
28 |
+
assert kernel % 2 == 1, print("Kernel size must be odd")
|
29 |
+
self.mask_type = mask_type
|
30 |
+
self.residual = residual
|
31 |
+
self.bh_model = bh_model
|
32 |
+
|
33 |
+
self.class_cond_embedding = nn.Embedding(n_classes, 2 * dim)
|
34 |
+
self.class_cond_embedding = self.class_cond_embedding.to("cpu")
|
35 |
+
|
36 |
+
kernel_shp = (kernel // 2 + 1, 3 if self.bh_model else 1) # (ceil(n/2), n)
|
37 |
+
padding_shp = (kernel // 2, 1 if self.bh_model else 0)
|
38 |
+
self.vert_stack = nn.Conv2d(
|
39 |
+
dim, dim * 2,
|
40 |
+
kernel_shp, 1, padding_shp
|
41 |
+
)
|
42 |
+
|
43 |
+
self.vert_to_horiz = nn.Conv2d(2 * dim, 2 * dim, 1)
|
44 |
+
|
45 |
+
kernel_shp = (1, 2)
|
46 |
+
padding_shp = (0, 1)
|
47 |
+
self.horiz_stack = nn.Conv2d(
|
48 |
+
dim, dim * 2,
|
49 |
+
kernel_shp, 1, padding_shp
|
50 |
+
)
|
51 |
+
|
52 |
+
self.horiz_resid = nn.Conv2d(dim, dim, 1)
|
53 |
+
|
54 |
+
self.gate = GatedActivation()
|
55 |
+
|
56 |
+
def make_causal(self):
|
57 |
+
self.vert_stack.weight.data[:, :, -1].zero_() # Mask final row
|
58 |
+
self.horiz_stack.weight.data[:, :, :, -1].zero_() # Mask final column
|
59 |
+
|
60 |
+
def forward(self, x_v, x_h, h):
|
61 |
+
if self.mask_type == 'A':
|
62 |
+
self.make_causal()
|
63 |
+
|
64 |
+
h = h.to(self.class_cond_embedding.weight.device)
|
65 |
+
h = self.class_cond_embedding(h)
|
66 |
+
|
67 |
+
h_vert = self.vert_stack(x_v)
|
68 |
+
h_vert = h_vert[:, :, :x_v.size(-2), :]
|
69 |
+
out_v = self.gate(h_vert + h[:, :, None, None])
|
70 |
+
|
71 |
+
if self.bh_model:
|
72 |
+
h_horiz = self.horiz_stack(x_h)
|
73 |
+
h_horiz = h_horiz[:, :, :, :x_h.size(-1)]
|
74 |
+
v2h = self.vert_to_horiz(h_vert)
|
75 |
+
|
76 |
+
out = self.gate(v2h + h_horiz + h[:, :, None, None])
|
77 |
+
if self.residual:
|
78 |
+
out_h = self.horiz_resid(out) + x_h
|
79 |
+
else:
|
80 |
+
out_h = self.horiz_resid(out)
|
81 |
+
else:
|
82 |
+
if self.residual:
|
83 |
+
out_v = self.horiz_resid(out_v) + x_v
|
84 |
+
else:
|
85 |
+
out_v = self.horiz_resid(out_v)
|
86 |
+
out_h = out_v
|
87 |
+
|
88 |
+
return out_v, out_h
|
89 |
+
|
90 |
+
|
91 |
+
class GatedPixelCNN(nn.Module):
|
92 |
+
def __init__(self, input_dim=256, dim=64, n_layers=15, n_classes=10, audio=False, bh_model=False):
|
93 |
+
super().__init__()
|
94 |
+
self.dim = dim
|
95 |
+
self.audio = audio
|
96 |
+
self.bh_model = bh_model
|
97 |
+
|
98 |
+
if self.audio:
|
99 |
+
self.embedding_aud = nn.Conv2d(256, dim, 1, 1, padding=0)
|
100 |
+
self.fusion_v = nn.Conv2d(dim * 2, dim, 1, 1, padding=0)
|
101 |
+
self.fusion_h = nn.Conv2d(dim * 2, dim, 1, 1, padding=0)
|
102 |
+
|
103 |
+
# Create embedding layer to embed input
|
104 |
+
self.embedding = nn.Embedding(input_dim, dim)
|
105 |
+
|
106 |
+
# Building the PixelCNN layer by layer
|
107 |
+
self.layers = nn.ModuleList()
|
108 |
+
|
109 |
+
# Initial block with Mask-A convolution
|
110 |
+
# Rest with Mask-B convolutions
|
111 |
+
for i in range(n_layers):
|
112 |
+
mask_type = 'A' if i == 0 else 'B'
|
113 |
+
kernel = 7 if i == 0 else 3
|
114 |
+
residual = False if i == 0 else True
|
115 |
+
|
116 |
+
self.layers.append(
|
117 |
+
GatedMaskedConv2d(mask_type, dim, kernel, residual, n_classes, bh_model)
|
118 |
+
)
|
119 |
+
|
120 |
+
# Add the output layer
|
121 |
+
self.output_conv = nn.Sequential(
|
122 |
+
nn.Conv2d(dim, 512, 1),
|
123 |
+
nn.ReLU(True),
|
124 |
+
nn.Conv2d(512, input_dim, 1)
|
125 |
+
)
|
126 |
+
|
127 |
+
self.apply(weights_init)
|
128 |
+
|
129 |
+
self.dp = nn.Dropout(0.1)
|
130 |
+
self.to("cpu")
|
131 |
+
|
132 |
+
def forward(self, x, label, aud=None):
|
133 |
+
shp = x.size() + (-1,)
|
134 |
+
x = self.embedding(x.view(-1)).view(shp) # (B, H, W, C)
|
135 |
+
x = x.permute(0, 3, 1, 2) # (B, C, W, W)
|
136 |
+
|
137 |
+
x_v, x_h = (x, x)
|
138 |
+
for i, layer in enumerate(self.layers):
|
139 |
+
if i == 1 and self.audio is True:
|
140 |
+
aud = self.embedding_aud(aud)
|
141 |
+
a = torch.ones(aud.shape[-2]).to(aud.device)
|
142 |
+
a = self.dp(a)
|
143 |
+
aud = (aud.transpose(-1, -2) * a).transpose(-1, -2)
|
144 |
+
x_v = self.fusion_v(torch.cat([x_v, aud], dim=1))
|
145 |
+
if self.bh_model:
|
146 |
+
x_h = self.fusion_h(torch.cat([x_h, aud], dim=1))
|
147 |
+
x_v, x_h = layer(x_v, x_h, label)
|
148 |
+
|
149 |
+
if self.bh_model:
|
150 |
+
return self.output_conv(x_h)
|
151 |
+
else:
|
152 |
+
return self.output_conv(x_v)
|
153 |
+
|
154 |
+
def generate(self, label, shape=(8, 8), batch_size=64, aud_feat=None, pre_latents=None, pre_audio=None):
|
155 |
+
param = next(self.parameters())
|
156 |
+
x = torch.zeros(
|
157 |
+
(batch_size, *shape),
|
158 |
+
dtype=torch.int64, device=param.device
|
159 |
+
)
|
160 |
+
if pre_latents is not None:
|
161 |
+
x = torch.cat([pre_latents, x], dim=1)
|
162 |
+
aud_feat = torch.cat([pre_audio, aud_feat], dim=2)
|
163 |
+
h0 = pre_latents.shape[1]
|
164 |
+
h = h0 + shape[0]
|
165 |
+
else:
|
166 |
+
h0 = 0
|
167 |
+
h = shape[0]
|
168 |
+
|
169 |
+
for i in range(h0, h):
|
170 |
+
for j in range(shape[1]):
|
171 |
+
if self.audio:
|
172 |
+
logits = self.forward(x, label, aud_feat)
|
173 |
+
else:
|
174 |
+
logits = self.forward(x, label)
|
175 |
+
probs = F.softmax(logits[:, :, i, j], -1)
|
176 |
+
x.data[:, i, j].copy_(
|
177 |
+
probs.multinomial(1).squeeze().data
|
178 |
+
)
|
179 |
+
return x[:, h0:h]
|
nets/spg/s2g_face.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
not exactly the same as the official repo but the results are good
|
3 |
+
'''
|
4 |
+
import sys
|
5 |
+
import os
|
6 |
+
|
7 |
+
from transformers import Wav2Vec2Processor
|
8 |
+
|
9 |
+
from .wav2vec import Wav2Vec2Model
|
10 |
+
from torchaudio.sox_effects import apply_effects_tensor
|
11 |
+
|
12 |
+
sys.path.append(os.getcwd())
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
import torchaudio as ta
|
19 |
+
import math
|
20 |
+
from nets.layers import SeqEncoder1D, SeqTranslator1D, ConvNormRelu
|
21 |
+
|
22 |
+
|
23 |
+
""" from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """
|
24 |
+
|
25 |
+
|
26 |
+
def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000):
|
27 |
+
"""
|
28 |
+
:param audio: 1 x T tensor containing a 16kHz audio signal
|
29 |
+
:param frame_rate: frame rate for video (we need one audio chunk per video frame)
|
30 |
+
:param chunk_size: number of audio samples per chunk
|
31 |
+
:return: num_chunks x chunk_size tensor containing sliced audio
|
32 |
+
"""
|
33 |
+
samples_per_frame = 16000 // frame_rate
|
34 |
+
padding = (chunk_size - samples_per_frame) // 2
|
35 |
+
audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
|
36 |
+
anchor_points = list(range(chunk_size//2, audio.shape[-1]-chunk_size//2, samples_per_frame))
|
37 |
+
audio = torch.cat([audio[:, i-chunk_size//2:i+chunk_size//2] for i in anchor_points], dim=0)
|
38 |
+
return audio
|
39 |
+
|
40 |
+
|
41 |
+
class MeshtalkEncoder(nn.Module):
|
42 |
+
def __init__(self, latent_dim: int = 128, model_name: str = 'audio_encoder'):
|
43 |
+
"""
|
44 |
+
:param latent_dim: size of the latent audio embedding
|
45 |
+
:param model_name: name of the model, used to load and save the model
|
46 |
+
"""
|
47 |
+
super().__init__()
|
48 |
+
|
49 |
+
self.melspec = ta.transforms.MelSpectrogram(
|
50 |
+
sample_rate=16000, n_fft=2048, win_length=800, hop_length=160, n_mels=80
|
51 |
+
)
|
52 |
+
|
53 |
+
conv_len = 5
|
54 |
+
self.convert_dimensions = torch.nn.Conv1d(80, 128, kernel_size=conv_len)
|
55 |
+
self.weights_init(self.convert_dimensions)
|
56 |
+
self.receptive_field = conv_len
|
57 |
+
|
58 |
+
convs = []
|
59 |
+
for i in range(6):
|
60 |
+
dilation = 2 * (i % 3 + 1)
|
61 |
+
self.receptive_field += (conv_len - 1) * dilation
|
62 |
+
convs += [torch.nn.Conv1d(128, 128, kernel_size=conv_len, dilation=dilation)]
|
63 |
+
self.weights_init(convs[-1])
|
64 |
+
self.convs = torch.nn.ModuleList(convs)
|
65 |
+
self.code = torch.nn.Linear(128, latent_dim)
|
66 |
+
|
67 |
+
self.apply(lambda x: self.weights_init(x))
|
68 |
+
|
69 |
+
def weights_init(self, m):
|
70 |
+
if isinstance(m, torch.nn.Conv1d):
|
71 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
72 |
+
try:
|
73 |
+
torch.nn.init.constant_(m.bias, .01)
|
74 |
+
except:
|
75 |
+
pass
|
76 |
+
|
77 |
+
def forward(self, audio: torch.Tensor):
|
78 |
+
"""
|
79 |
+
:param audio: B x T x 16000 Tensor containing 1 sec of audio centered around the current time frame
|
80 |
+
:return: code: B x T x latent_dim Tensor containing a latent audio code/embedding
|
81 |
+
"""
|
82 |
+
B, T = audio.shape[0], audio.shape[1]
|
83 |
+
x = self.melspec(audio).squeeze(1)
|
84 |
+
x = torch.log(x.clamp(min=1e-10, max=None))
|
85 |
+
if T == 1:
|
86 |
+
x = x.unsqueeze(1)
|
87 |
+
|
88 |
+
# Convert to the right dimensionality
|
89 |
+
x = x.view(-1, x.shape[2], x.shape[3])
|
90 |
+
x = F.leaky_relu(self.convert_dimensions(x), .2)
|
91 |
+
|
92 |
+
# Process stacks
|
93 |
+
for conv in self.convs:
|
94 |
+
x_ = F.leaky_relu(conv(x), .2)
|
95 |
+
if self.training:
|
96 |
+
x_ = F.dropout(x_, .2)
|
97 |
+
l = (x.shape[2] - x_.shape[2]) // 2
|
98 |
+
x = (x[:, :, l:-l] + x_) / 2
|
99 |
+
|
100 |
+
x = torch.mean(x, dim=-1)
|
101 |
+
x = x.view(B, T, x.shape[-1])
|
102 |
+
x = self.code(x)
|
103 |
+
|
104 |
+
return {"code": x}
|
105 |
+
|
106 |
+
|
107 |
+
class AudioEncoder(nn.Module):
|
108 |
+
def __init__(self, in_dim, out_dim, identity=False, num_classes=0):
|
109 |
+
super().__init__()
|
110 |
+
self.identity = identity
|
111 |
+
if self.identity:
|
112 |
+
in_dim = in_dim + 64
|
113 |
+
self.id_mlp = nn.Conv1d(num_classes, 64, 1, 1)
|
114 |
+
self.first_net = SeqTranslator1D(in_dim, out_dim,
|
115 |
+
min_layers_num=3,
|
116 |
+
residual=True,
|
117 |
+
norm='ln'
|
118 |
+
)
|
119 |
+
self.grus = nn.GRU(out_dim, out_dim, 1, batch_first=True)
|
120 |
+
self.dropout = nn.Dropout(0.1)
|
121 |
+
# self.att = nn.MultiheadAttention(out_dim, 4, dropout=0.1, batch_first=True)
|
122 |
+
|
123 |
+
def forward(self, spectrogram, pre_state=None, id=None, time_steps=None):
|
124 |
+
|
125 |
+
spectrogram = spectrogram
|
126 |
+
spectrogram = self.dropout(spectrogram)
|
127 |
+
if self.identity:
|
128 |
+
id = id.reshape(id.shape[0], -1, 1).repeat(1, 1, spectrogram.shape[2]).to(torch.float32)
|
129 |
+
id = self.id_mlp(id)
|
130 |
+
spectrogram = torch.cat([spectrogram, id], dim=1)
|
131 |
+
x1 = self.first_net(spectrogram)# .permute(0, 2, 1)
|
132 |
+
if time_steps is not None:
|
133 |
+
x1 = F.interpolate(x1, size=time_steps, align_corners=False, mode='linear')
|
134 |
+
# x1, _ = self.att(x1, x1, x1)
|
135 |
+
# x1, hidden_state = self.grus(x1)
|
136 |
+
# x1 = x1.permute(0, 2, 1)
|
137 |
+
hidden_state=None
|
138 |
+
|
139 |
+
return x1, hidden_state
|
140 |
+
|
141 |
+
|
142 |
+
class Generator(nn.Module):
|
143 |
+
def __init__(self,
|
144 |
+
n_poses,
|
145 |
+
each_dim: list,
|
146 |
+
dim_list: list,
|
147 |
+
training=False,
|
148 |
+
device=None,
|
149 |
+
identity=True,
|
150 |
+
num_classes=0,
|
151 |
+
):
|
152 |
+
super().__init__()
|
153 |
+
|
154 |
+
self.training = training
|
155 |
+
self.device = device
|
156 |
+
self.gen_length = n_poses
|
157 |
+
self.identity = identity
|
158 |
+
|
159 |
+
norm = 'ln'
|
160 |
+
in_dim = 256
|
161 |
+
out_dim = 256
|
162 |
+
|
163 |
+
self.encoder_choice = 'faceformer'
|
164 |
+
|
165 |
+
if self.encoder_choice == 'meshtalk':
|
166 |
+
self.audio_encoder = MeshtalkEncoder(latent_dim=in_dim)
|
167 |
+
elif self.encoder_choice == 'faceformer':
|
168 |
+
# wav2vec 2.0 weights initialization
|
169 |
+
self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") # "vitouphy/wav2vec2-xls-r-300m-phoneme""facebook/wav2vec2-base-960h"
|
170 |
+
self.audio_encoder.feature_extractor._freeze_parameters()
|
171 |
+
self.audio_feature_map = nn.Linear(768, in_dim)
|
172 |
+
else:
|
173 |
+
self.audio_encoder = AudioEncoder(in_dim=64, out_dim=out_dim)
|
174 |
+
|
175 |
+
self.audio_middle = AudioEncoder(in_dim, out_dim, identity, num_classes)
|
176 |
+
|
177 |
+
self.dim_list = dim_list
|
178 |
+
|
179 |
+
self.decoder = nn.ModuleList()
|
180 |
+
self.final_out = nn.ModuleList()
|
181 |
+
|
182 |
+
self.decoder.append(nn.Sequential(
|
183 |
+
ConvNormRelu(out_dim, 64, norm=norm),
|
184 |
+
ConvNormRelu(64, 64, norm=norm),
|
185 |
+
ConvNormRelu(64, 64, norm=norm),
|
186 |
+
))
|
187 |
+
self.final_out.append(nn.Conv1d(64, each_dim[0], 1, 1))
|
188 |
+
|
189 |
+
self.decoder.append(nn.Sequential(
|
190 |
+
ConvNormRelu(out_dim, out_dim, norm=norm),
|
191 |
+
ConvNormRelu(out_dim, out_dim, norm=norm),
|
192 |
+
ConvNormRelu(out_dim, out_dim, norm=norm),
|
193 |
+
))
|
194 |
+
self.final_out.append(nn.Conv1d(out_dim, each_dim[3], 1, 1))
|
195 |
+
|
196 |
+
def forward(self, in_spec, gt_poses=None, id=None, pre_state=None, time_steps=None):
|
197 |
+
if self.training:
|
198 |
+
time_steps = gt_poses.shape[1]
|
199 |
+
|
200 |
+
# vector, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
|
201 |
+
if self.encoder_choice == 'meshtalk':
|
202 |
+
in_spec = audio_chunking(in_spec.squeeze(-1), frame_rate=30, chunk_size=16000)
|
203 |
+
feature = self.audio_encoder(in_spec.unsqueeze(0))["code"].transpose(1, 2)
|
204 |
+
elif self.encoder_choice == 'faceformer':
|
205 |
+
hidden_states = self.audio_encoder(in_spec.reshape(in_spec.shape[0], -1), frame_num=time_steps).last_hidden_state
|
206 |
+
feature = self.audio_feature_map(hidden_states).transpose(1, 2)
|
207 |
+
else:
|
208 |
+
feature, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
|
209 |
+
|
210 |
+
# hidden_states = in_spec
|
211 |
+
|
212 |
+
feature, _ = self.audio_middle(feature, id=id)
|
213 |
+
|
214 |
+
out = []
|
215 |
+
|
216 |
+
for i in range(self.decoder.__len__()):
|
217 |
+
mid = self.decoder[i](feature)
|
218 |
+
mid = self.final_out[i](mid)
|
219 |
+
out.append(mid)
|
220 |
+
|
221 |
+
out = torch.cat(out, dim=1)
|
222 |
+
out = out.transpose(1, 2)
|
223 |
+
|
224 |
+
return out, None
|
225 |
+
|
226 |
+
|
nets/spg/s2glayers.py
ADDED
@@ -0,0 +1,522 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
not exactly the same as the official repo but the results are good
|
3 |
+
'''
|
4 |
+
import sys
|
5 |
+
import os
|
6 |
+
|
7 |
+
sys.path.append(os.getcwd())
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
import math
|
14 |
+
from nets.layers import SeqEncoder1D, SeqTranslator1D
|
15 |
+
|
16 |
+
""" from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """
|
17 |
+
|
18 |
+
|
19 |
+
class Conv2d_tf(nn.Conv2d):
|
20 |
+
"""
|
21 |
+
Conv2d with the padding behavior from TF
|
22 |
+
from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, *args, **kwargs):
|
26 |
+
super(Conv2d_tf, self).__init__(*args, **kwargs)
|
27 |
+
self.padding = kwargs.get("padding", "SAME")
|
28 |
+
|
29 |
+
def _compute_padding(self, input, dim):
|
30 |
+
input_size = input.size(dim + 2)
|
31 |
+
filter_size = self.weight.size(dim + 2)
|
32 |
+
effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
|
33 |
+
out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
|
34 |
+
total_padding = max(
|
35 |
+
0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
|
36 |
+
)
|
37 |
+
additional_padding = int(total_padding % 2 != 0)
|
38 |
+
|
39 |
+
return additional_padding, total_padding
|
40 |
+
|
41 |
+
def forward(self, input):
|
42 |
+
if self.padding == "VALID":
|
43 |
+
return F.conv2d(
|
44 |
+
input,
|
45 |
+
self.weight,
|
46 |
+
self.bias,
|
47 |
+
self.stride,
|
48 |
+
padding=0,
|
49 |
+
dilation=self.dilation,
|
50 |
+
groups=self.groups,
|
51 |
+
)
|
52 |
+
rows_odd, padding_rows = self._compute_padding(input, dim=0)
|
53 |
+
cols_odd, padding_cols = self._compute_padding(input, dim=1)
|
54 |
+
if rows_odd or cols_odd:
|
55 |
+
input = F.pad(input, [0, cols_odd, 0, rows_odd])
|
56 |
+
|
57 |
+
return F.conv2d(
|
58 |
+
input,
|
59 |
+
self.weight,
|
60 |
+
self.bias,
|
61 |
+
self.stride,
|
62 |
+
padding=(padding_rows // 2, padding_cols // 2),
|
63 |
+
dilation=self.dilation,
|
64 |
+
groups=self.groups,
|
65 |
+
)
|
66 |
+
|
67 |
+
|
68 |
+
class Conv1d_tf(nn.Conv1d):
|
69 |
+
"""
|
70 |
+
Conv1d with the padding behavior from TF
|
71 |
+
modified from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
|
72 |
+
"""
|
73 |
+
|
74 |
+
def __init__(self, *args, **kwargs):
|
75 |
+
super(Conv1d_tf, self).__init__(*args, **kwargs)
|
76 |
+
self.padding = kwargs.get("padding")
|
77 |
+
|
78 |
+
def _compute_padding(self, input, dim):
|
79 |
+
input_size = input.size(dim + 2)
|
80 |
+
filter_size = self.weight.size(dim + 2)
|
81 |
+
effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
|
82 |
+
out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
|
83 |
+
total_padding = max(
|
84 |
+
0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
|
85 |
+
)
|
86 |
+
additional_padding = int(total_padding % 2 != 0)
|
87 |
+
|
88 |
+
return additional_padding, total_padding
|
89 |
+
|
90 |
+
def forward(self, input):
|
91 |
+
# if self.padding == "valid":
|
92 |
+
# return F.conv1d(
|
93 |
+
# input,
|
94 |
+
# self.weight,
|
95 |
+
# self.bias,
|
96 |
+
# self.stride,
|
97 |
+
# padding=0,
|
98 |
+
# dilation=self.dilation,
|
99 |
+
# groups=self.groups,
|
100 |
+
# )
|
101 |
+
rows_odd, padding_rows = self._compute_padding(input, dim=0)
|
102 |
+
if rows_odd:
|
103 |
+
input = F.pad(input, [0, rows_odd])
|
104 |
+
|
105 |
+
return F.conv1d(
|
106 |
+
input,
|
107 |
+
self.weight,
|
108 |
+
self.bias,
|
109 |
+
self.stride,
|
110 |
+
padding=(padding_rows // 2),
|
111 |
+
dilation=self.dilation,
|
112 |
+
groups=self.groups,
|
113 |
+
)
|
114 |
+
|
115 |
+
|
116 |
+
def ConvNormRelu(in_channels, out_channels, type='1d', downsample=False, k=None, s=None, padding='valid', groups=1,
|
117 |
+
nonlinear='lrelu', bn='bn'):
|
118 |
+
if k is None and s is None:
|
119 |
+
if not downsample:
|
120 |
+
k = 3
|
121 |
+
s = 1
|
122 |
+
padding = 'same'
|
123 |
+
else:
|
124 |
+
k = 4
|
125 |
+
s = 2
|
126 |
+
padding = 'valid'
|
127 |
+
|
128 |
+
if type == '1d':
|
129 |
+
conv_block = Conv1d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding, groups=groups)
|
130 |
+
norm_block = nn.BatchNorm1d(out_channels)
|
131 |
+
elif type == '2d':
|
132 |
+
conv_block = Conv2d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding, groups=groups)
|
133 |
+
norm_block = nn.BatchNorm2d(out_channels)
|
134 |
+
else:
|
135 |
+
assert False
|
136 |
+
if bn != 'bn':
|
137 |
+
if bn == 'gn':
|
138 |
+
norm_block = nn.GroupNorm(1, out_channels)
|
139 |
+
elif bn == 'ln':
|
140 |
+
norm_block = nn.LayerNorm(out_channels)
|
141 |
+
else:
|
142 |
+
norm_block = nn.Identity()
|
143 |
+
if nonlinear == 'lrelu':
|
144 |
+
nlinear = nn.LeakyReLU(0.2, True)
|
145 |
+
elif nonlinear == 'tanh':
|
146 |
+
nlinear = nn.Tanh()
|
147 |
+
elif nonlinear == 'none':
|
148 |
+
nlinear = nn.Identity()
|
149 |
+
|
150 |
+
return nn.Sequential(
|
151 |
+
conv_block,
|
152 |
+
norm_block,
|
153 |
+
nlinear
|
154 |
+
)
|
155 |
+
|
156 |
+
|
157 |
+
class UnetUp(nn.Module):
|
158 |
+
def __init__(self, in_ch, out_ch):
|
159 |
+
super(UnetUp, self).__init__()
|
160 |
+
self.conv = ConvNormRelu(in_ch, out_ch)
|
161 |
+
|
162 |
+
def forward(self, x1, x2):
|
163 |
+
# x1 = torch.repeat_interleave(x1, 2, dim=2)
|
164 |
+
# x1 = x1[:, :, :x2.shape[2]]
|
165 |
+
x1 = torch.nn.functional.interpolate(x1, size=x2.shape[2], mode='linear')
|
166 |
+
x = x1 + x2
|
167 |
+
x = self.conv(x)
|
168 |
+
return x
|
169 |
+
|
170 |
+
|
171 |
+
class UNet(nn.Module):
|
172 |
+
def __init__(self, input_dim, dim):
|
173 |
+
super(UNet, self).__init__()
|
174 |
+
# dim = 512
|
175 |
+
self.down1 = nn.Sequential(
|
176 |
+
ConvNormRelu(input_dim, input_dim, '1d', False),
|
177 |
+
ConvNormRelu(input_dim, dim, '1d', False),
|
178 |
+
ConvNormRelu(dim, dim, '1d', False)
|
179 |
+
)
|
180 |
+
self.gru = nn.GRU(dim, dim, 1, batch_first=True)
|
181 |
+
self.down2 = ConvNormRelu(dim, dim, '1d', True)
|
182 |
+
self.down3 = ConvNormRelu(dim, dim, '1d', True)
|
183 |
+
self.down4 = ConvNormRelu(dim, dim, '1d', True)
|
184 |
+
self.down5 = ConvNormRelu(dim, dim, '1d', True)
|
185 |
+
self.down6 = ConvNormRelu(dim, dim, '1d', True)
|
186 |
+
self.up1 = UnetUp(dim, dim)
|
187 |
+
self.up2 = UnetUp(dim, dim)
|
188 |
+
self.up3 = UnetUp(dim, dim)
|
189 |
+
self.up4 = UnetUp(dim, dim)
|
190 |
+
self.up5 = UnetUp(dim, dim)
|
191 |
+
|
192 |
+
def forward(self, x1, pre_pose=None, w_pre=False):
|
193 |
+
x2_0 = self.down1(x1)
|
194 |
+
if w_pre:
|
195 |
+
i = 1
|
196 |
+
x2_pre = self.gru(x2_0[:,:,0:i].permute(0,2,1), pre_pose[:,:,-1:].permute(2,0,1).contiguous())[0].permute(0,2,1)
|
197 |
+
x2 = torch.cat([x2_pre, x2_0[:,:,i:]], dim=-1)
|
198 |
+
# x2 = torch.cat([pre_pose, x2_0], dim=2) # [B, 512, 15]
|
199 |
+
else:
|
200 |
+
# x2 = self.gru(x2_0.transpose(1, 2))[0].transpose(1,2)
|
201 |
+
x2 = x2_0
|
202 |
+
x3 = self.down2(x2)
|
203 |
+
x4 = self.down3(x3)
|
204 |
+
x5 = self.down4(x4)
|
205 |
+
x6 = self.down5(x5)
|
206 |
+
x7 = self.down6(x6)
|
207 |
+
x = self.up1(x7, x6)
|
208 |
+
x = self.up2(x, x5)
|
209 |
+
x = self.up3(x, x4)
|
210 |
+
x = self.up4(x, x3)
|
211 |
+
x = self.up5(x, x2) # [B, 512, 15]
|
212 |
+
return x, x2_0
|
213 |
+
|
214 |
+
|
215 |
+
class AudioEncoder(nn.Module):
|
216 |
+
def __init__(self, n_frames, template_length, pose=False, common_dim=512):
|
217 |
+
super().__init__()
|
218 |
+
self.n_frames = n_frames
|
219 |
+
self.pose = pose
|
220 |
+
self.step = 0
|
221 |
+
self.weight = 0
|
222 |
+
if self.pose:
|
223 |
+
# self.first_net = nn.Sequential(
|
224 |
+
# ConvNormRelu(1, 64, '2d', False),
|
225 |
+
# ConvNormRelu(64, 64, '2d', True),
|
226 |
+
# ConvNormRelu(64, 128, '2d', False),
|
227 |
+
# ConvNormRelu(128, 128, '2d', True),
|
228 |
+
# ConvNormRelu(128, 256, '2d', False),
|
229 |
+
# ConvNormRelu(256, 256, '2d', True),
|
230 |
+
# ConvNormRelu(256, 256, '2d', False),
|
231 |
+
# ConvNormRelu(256, 256, '2d', False, padding='VALID')
|
232 |
+
# )
|
233 |
+
# decoder_layer = nn.TransformerDecoderLayer(d_model=args.feature_dim, nhead=4,
|
234 |
+
# dim_feedforward=2 * args.feature_dim, batch_first=True)
|
235 |
+
# a = nn.TransformerDecoder
|
236 |
+
self.first_net = SeqTranslator1D(256, 256,
|
237 |
+
min_layers_num=4,
|
238 |
+
residual=True
|
239 |
+
)
|
240 |
+
self.dropout_0 = nn.Dropout(0.1)
|
241 |
+
self.mu_fc = nn.Conv1d(256, 128, 1, 1)
|
242 |
+
self.var_fc = nn.Conv1d(256, 128, 1, 1)
|
243 |
+
self.trans_motion = SeqTranslator1D(common_dim, common_dim,
|
244 |
+
kernel_size=1,
|
245 |
+
stride=1,
|
246 |
+
min_layers_num=3,
|
247 |
+
residual=True
|
248 |
+
)
|
249 |
+
# self.att = nn.MultiheadAttention(64 + template_length, 4, dropout=0.1)
|
250 |
+
self.unet = UNet(128 + template_length, common_dim)
|
251 |
+
|
252 |
+
else:
|
253 |
+
self.first_net = SeqTranslator1D(256, 256,
|
254 |
+
min_layers_num=4,
|
255 |
+
residual=True
|
256 |
+
)
|
257 |
+
self.dropout_0 = nn.Dropout(0.1)
|
258 |
+
# self.att = nn.MultiheadAttention(256, 4, dropout=0.1)
|
259 |
+
self.unet = UNet(256, 256)
|
260 |
+
self.dropout_1 = nn.Dropout(0.0)
|
261 |
+
|
262 |
+
def forward(self, spectrogram, time_steps=None, template=None, pre_pose=None, w_pre=False):
|
263 |
+
self.step = self.step + 1
|
264 |
+
if self.pose:
|
265 |
+
spect = spectrogram.transpose(1, 2)
|
266 |
+
if w_pre:
|
267 |
+
spect = spect[:, :, :]
|
268 |
+
|
269 |
+
out = self.first_net(spect)
|
270 |
+
out = self.dropout_0(out)
|
271 |
+
|
272 |
+
mu = self.mu_fc(out)
|
273 |
+
var = self.var_fc(out)
|
274 |
+
audio = self.__reparam(mu, var)
|
275 |
+
# audio = out
|
276 |
+
|
277 |
+
# template = self.trans_motion(template)
|
278 |
+
x1 = torch.cat([audio, template], dim=1)#.permute(2,0,1)
|
279 |
+
# x1 = out
|
280 |
+
#x1, _ = self.att(x1, x1, x1)
|
281 |
+
#x1 = x1.permute(1,2,0)
|
282 |
+
x1, x2_0 = self.unet(x1, pre_pose=pre_pose, w_pre=w_pre)
|
283 |
+
else:
|
284 |
+
spectrogram = spectrogram.transpose(1, 2)
|
285 |
+
x1 = self.first_net(spectrogram)#.permute(2,0,1)
|
286 |
+
#out, _ = self.att(out, out, out)
|
287 |
+
#out = out.permute(1, 2, 0)
|
288 |
+
x1 = self.dropout_0(x1)
|
289 |
+
x1, x2_0 = self.unet(x1)
|
290 |
+
x1 = self.dropout_1(x1)
|
291 |
+
mu = None
|
292 |
+
var = None
|
293 |
+
|
294 |
+
return x1, (mu, var), x2_0
|
295 |
+
|
296 |
+
def __reparam(self, mu, log_var):
|
297 |
+
std = torch.exp(0.5 * log_var)
|
298 |
+
eps = torch.randn_like(std, device='cuda')
|
299 |
+
z = eps * std + mu
|
300 |
+
return z
|
301 |
+
|
302 |
+
|
303 |
+
class Generator(nn.Module):
|
304 |
+
def __init__(self,
|
305 |
+
n_poses,
|
306 |
+
pose_dim,
|
307 |
+
pose,
|
308 |
+
n_pre_poses,
|
309 |
+
each_dim: list,
|
310 |
+
dim_list: list,
|
311 |
+
use_template=False,
|
312 |
+
template_length=0,
|
313 |
+
training=False,
|
314 |
+
device=None,
|
315 |
+
separate=False,
|
316 |
+
expression=False
|
317 |
+
):
|
318 |
+
super().__init__()
|
319 |
+
|
320 |
+
self.use_template = use_template
|
321 |
+
self.template_length = template_length
|
322 |
+
self.training = training
|
323 |
+
self.device = device
|
324 |
+
self.separate = separate
|
325 |
+
self.pose = pose
|
326 |
+
self.decoderf = True
|
327 |
+
self.expression = expression
|
328 |
+
|
329 |
+
common_dim = 256
|
330 |
+
|
331 |
+
if self.use_template:
|
332 |
+
assert template_length > 0
|
333 |
+
# self.KLLoss = KLLoss(kl_tolerance=self.config.Train.weights.kl_tolerance).to(self.device)
|
334 |
+
# self.pose_encoder = SeqEncoder1D(
|
335 |
+
# C_in=pose_dim,
|
336 |
+
# C_out=512,
|
337 |
+
# T_in=n_poses,
|
338 |
+
# min_layer_nums=6
|
339 |
+
#
|
340 |
+
# )
|
341 |
+
self.pose_encoder = SeqTranslator1D(pose_dim - 50, common_dim,
|
342 |
+
# kernel_size=1,
|
343 |
+
# stride=1,
|
344 |
+
min_layers_num=3,
|
345 |
+
residual=True
|
346 |
+
)
|
347 |
+
self.mu_fc = nn.Conv1d(common_dim, template_length, kernel_size=1, stride=1)
|
348 |
+
self.var_fc = nn.Conv1d(common_dim, template_length, kernel_size=1, stride=1)
|
349 |
+
|
350 |
+
else:
|
351 |
+
self.template_length = 0
|
352 |
+
|
353 |
+
self.gen_length = n_poses
|
354 |
+
|
355 |
+
self.audio_encoder = AudioEncoder(n_poses, template_length, True, common_dim)
|
356 |
+
self.speech_encoder = AudioEncoder(n_poses, template_length, False)
|
357 |
+
|
358 |
+
# self.pre_pose_encoder = SeqEncoder1D(
|
359 |
+
# C_in=pose_dim,
|
360 |
+
# C_out=128,
|
361 |
+
# T_in=15,
|
362 |
+
# min_layer_nums=3
|
363 |
+
#
|
364 |
+
# )
|
365 |
+
# self.pmu_fc = nn.Linear(128, 64)
|
366 |
+
# self.pvar_fc = nn.Linear(128, 64)
|
367 |
+
|
368 |
+
self.pre_pose_encoder = SeqTranslator1D(pose_dim-50, common_dim,
|
369 |
+
min_layers_num=5,
|
370 |
+
residual=True
|
371 |
+
)
|
372 |
+
self.decoder_in = 256 + 64
|
373 |
+
self.dim_list = dim_list
|
374 |
+
|
375 |
+
if self.separate:
|
376 |
+
self.decoder = nn.ModuleList()
|
377 |
+
self.final_out = nn.ModuleList()
|
378 |
+
|
379 |
+
self.decoder.append(nn.Sequential(
|
380 |
+
ConvNormRelu(256, 64),
|
381 |
+
ConvNormRelu(64, 64),
|
382 |
+
ConvNormRelu(64, 64),
|
383 |
+
))
|
384 |
+
self.final_out.append(nn.Conv1d(64, each_dim[0], 1, 1))
|
385 |
+
|
386 |
+
self.decoder.append(nn.Sequential(
|
387 |
+
ConvNormRelu(common_dim, common_dim),
|
388 |
+
ConvNormRelu(common_dim, common_dim),
|
389 |
+
ConvNormRelu(common_dim, common_dim),
|
390 |
+
))
|
391 |
+
self.final_out.append(nn.Conv1d(common_dim, each_dim[1], 1, 1))
|
392 |
+
|
393 |
+
self.decoder.append(nn.Sequential(
|
394 |
+
ConvNormRelu(common_dim, common_dim),
|
395 |
+
ConvNormRelu(common_dim, common_dim),
|
396 |
+
ConvNormRelu(common_dim, common_dim),
|
397 |
+
))
|
398 |
+
self.final_out.append(nn.Conv1d(common_dim, each_dim[2], 1, 1))
|
399 |
+
|
400 |
+
if self.expression:
|
401 |
+
self.decoder.append(nn.Sequential(
|
402 |
+
ConvNormRelu(256, 256),
|
403 |
+
ConvNormRelu(256, 256),
|
404 |
+
ConvNormRelu(256, 256),
|
405 |
+
))
|
406 |
+
self.final_out.append(nn.Conv1d(256, each_dim[3], 1, 1))
|
407 |
+
else:
|
408 |
+
self.decoder = nn.Sequential(
|
409 |
+
ConvNormRelu(self.decoder_in, 512),
|
410 |
+
ConvNormRelu(512, 512),
|
411 |
+
ConvNormRelu(512, 512),
|
412 |
+
ConvNormRelu(512, 512),
|
413 |
+
ConvNormRelu(512, 512),
|
414 |
+
ConvNormRelu(512, 512),
|
415 |
+
)
|
416 |
+
self.final_out = nn.Conv1d(512, pose_dim, 1, 1)
|
417 |
+
|
418 |
+
def __reparam(self, mu, log_var):
|
419 |
+
std = torch.exp(0.5 * log_var)
|
420 |
+
eps = torch.randn_like(std, device=self.device)
|
421 |
+
z = eps * std + mu
|
422 |
+
return z
|
423 |
+
|
424 |
+
def forward(self, in_spec, pre_poses, gt_poses, template=None, time_steps=None, w_pre=False, norm=True):
|
425 |
+
if time_steps is not None:
|
426 |
+
self.gen_length = time_steps
|
427 |
+
|
428 |
+
if self.use_template:
|
429 |
+
if self.training:
|
430 |
+
if w_pre:
|
431 |
+
in_spec = in_spec[:, 15:, :]
|
432 |
+
pre_pose = self.pre_pose_encoder(gt_poses[:, 14:15, :-50].permute(0, 2, 1))
|
433 |
+
pose_enc = self.pose_encoder(gt_poses[:, 15:, :-50].permute(0, 2, 1))
|
434 |
+
mu = self.mu_fc(pose_enc)
|
435 |
+
var = self.var_fc(pose_enc)
|
436 |
+
template = self.__reparam(mu, var)
|
437 |
+
else:
|
438 |
+
pre_pose = None
|
439 |
+
pose_enc = self.pose_encoder(gt_poses[:, :, :-50].permute(0, 2, 1))
|
440 |
+
mu = self.mu_fc(pose_enc)
|
441 |
+
var = self.var_fc(pose_enc)
|
442 |
+
template = self.__reparam(mu, var)
|
443 |
+
elif pre_poses is not None:
|
444 |
+
if w_pre:
|
445 |
+
pre_pose = pre_poses[:, -1:, :-50]
|
446 |
+
if norm:
|
447 |
+
pre_pose = pre_pose.reshape(1, -1, 55, 5)
|
448 |
+
pre_pose = torch.cat([F.normalize(pre_pose[..., :3], dim=-1),
|
449 |
+
F.normalize(pre_pose[..., 3:5], dim=-1)],
|
450 |
+
dim=-1).reshape(1, -1, 275)
|
451 |
+
pre_pose = self.pre_pose_encoder(pre_pose.permute(0, 2, 1))
|
452 |
+
template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length ]).to(
|
453 |
+
in_spec.device)
|
454 |
+
else:
|
455 |
+
pre_pose = None
|
456 |
+
template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length]).to(in_spec.device)
|
457 |
+
elif gt_poses is not None:
|
458 |
+
template = self.pre_pose_encoder(gt_poses[:, :, :-50].permute(0, 2, 1))
|
459 |
+
elif template is None:
|
460 |
+
pre_pose = None
|
461 |
+
template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length]).to(in_spec.device)
|
462 |
+
else:
|
463 |
+
template = None
|
464 |
+
mu = None
|
465 |
+
var = None
|
466 |
+
|
467 |
+
a_t_f, (mu2, var2), x2_0 = self.audio_encoder(in_spec, time_steps=time_steps, template=template, pre_pose=pre_pose, w_pre=w_pre)
|
468 |
+
s_f, _, _ = self.speech_encoder(in_spec, time_steps=time_steps)
|
469 |
+
|
470 |
+
out = []
|
471 |
+
|
472 |
+
if self.separate:
|
473 |
+
for i in range(self.decoder.__len__()):
|
474 |
+
if i == 0 or i == 3:
|
475 |
+
mid = self.decoder[i](s_f)
|
476 |
+
else:
|
477 |
+
mid = self.decoder[i](a_t_f)
|
478 |
+
mid = self.final_out[i](mid)
|
479 |
+
out.append(mid)
|
480 |
+
out = torch.cat(out, dim=1)
|
481 |
+
|
482 |
+
else:
|
483 |
+
out = self.decoder(a_t_f)
|
484 |
+
out = self.final_out(out)
|
485 |
+
|
486 |
+
out = out.transpose(1, 2)
|
487 |
+
|
488 |
+
if self.training:
|
489 |
+
if w_pre:
|
490 |
+
return out, template, mu, var, (mu2, var2, x2_0, pre_pose)
|
491 |
+
else:
|
492 |
+
return out, template, mu, var, (mu2, var2, None, None)
|
493 |
+
else:
|
494 |
+
return out
|
495 |
+
|
496 |
+
|
497 |
+
class Discriminator(nn.Module):
|
498 |
+
def __init__(self, pose_dim, pose):
|
499 |
+
super().__init__()
|
500 |
+
self.net = nn.Sequential(
|
501 |
+
Conv1d_tf(pose_dim, 64, kernel_size=4, stride=2, padding='SAME'),
|
502 |
+
nn.LeakyReLU(0.2, True),
|
503 |
+
ConvNormRelu(64, 128, '1d', True),
|
504 |
+
ConvNormRelu(128, 256, '1d', k=4, s=1),
|
505 |
+
Conv1d_tf(256, 1, kernel_size=4, stride=1, padding='SAME'),
|
506 |
+
)
|
507 |
+
|
508 |
+
def forward(self, x):
|
509 |
+
x = x.transpose(1, 2)
|
510 |
+
|
511 |
+
out = self.net(x)
|
512 |
+
return out
|
513 |
+
|
514 |
+
|
515 |
+
def main():
|
516 |
+
d = Discriminator(275, 55)
|
517 |
+
x = torch.randn([8, 60, 275])
|
518 |
+
result = d(x)
|
519 |
+
|
520 |
+
|
521 |
+
if __name__ == "__main__":
|
522 |
+
main()
|
nets/spg/vqvae_1d.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from .wav2vec import Wav2Vec2Model
|
7 |
+
from .vqvae_modules import VectorQuantizerEMA, ConvNormRelu, Res_CNR_Stack
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
class AudioEncoder(nn.Module):
|
12 |
+
def __init__(self, in_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
|
13 |
+
super(AudioEncoder, self).__init__()
|
14 |
+
self._num_hiddens = num_hiddens
|
15 |
+
self._num_residual_layers = num_residual_layers
|
16 |
+
self._num_residual_hiddens = num_residual_hiddens
|
17 |
+
|
18 |
+
self.project = ConvNormRelu(in_dim, self._num_hiddens // 4, leaky=True)
|
19 |
+
|
20 |
+
self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True)
|
21 |
+
self._down_1 = ConvNormRelu(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, residual=True,
|
22 |
+
sample='down')
|
23 |
+
self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True)
|
24 |
+
self._down_2 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens, leaky=True, residual=True, sample='down')
|
25 |
+
self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
|
26 |
+
|
27 |
+
def forward(self, x, frame_num=0):
|
28 |
+
h = self.project(x)
|
29 |
+
h = self._enc_1(h)
|
30 |
+
h = self._down_1(h)
|
31 |
+
h = self._enc_2(h)
|
32 |
+
h = self._down_2(h)
|
33 |
+
h = self._enc_3(h)
|
34 |
+
return h
|
35 |
+
|
36 |
+
|
37 |
+
class Wav2VecEncoder(nn.Module):
|
38 |
+
def __init__(self, num_hiddens, num_residual_layers):
|
39 |
+
super(Wav2VecEncoder, self).__init__()
|
40 |
+
self._num_hiddens = num_hiddens
|
41 |
+
self._num_residual_layers = num_residual_layers
|
42 |
+
|
43 |
+
self.audio_encoder = Wav2Vec2Model.from_pretrained(
|
44 |
+
"facebook/wav2vec2-base-960h") # "vitouphy/wav2vec2-xls-r-300m-phoneme""facebook/wav2vec2-base-960h"
|
45 |
+
self.audio_encoder.feature_extractor._freeze_parameters()
|
46 |
+
|
47 |
+
self.project = ConvNormRelu(768, self._num_hiddens, leaky=True)
|
48 |
+
|
49 |
+
self._enc_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
|
50 |
+
self._down_1 = ConvNormRelu(self._num_hiddens, self._num_hiddens, leaky=True, residual=True, sample='down')
|
51 |
+
self._enc_2 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
|
52 |
+
self._down_2 = ConvNormRelu(self._num_hiddens, self._num_hiddens, leaky=True, residual=True, sample='down')
|
53 |
+
self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
|
54 |
+
|
55 |
+
def forward(self, x, frame_num):
|
56 |
+
h = self.audio_encoder(x.squeeze(), frame_num=frame_num).last_hidden_state.transpose(1, 2)
|
57 |
+
h = self.project(h)
|
58 |
+
h = self._enc_1(h)
|
59 |
+
h = self._down_1(h)
|
60 |
+
h = self._enc_2(h)
|
61 |
+
h = self._down_2(h)
|
62 |
+
h = self._enc_3(h)
|
63 |
+
return h
|
64 |
+
|
65 |
+
|
66 |
+
class Encoder(nn.Module):
|
67 |
+
def __init__(self, in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
|
68 |
+
super(Encoder, self).__init__()
|
69 |
+
self._num_hiddens = num_hiddens
|
70 |
+
self._num_residual_layers = num_residual_layers
|
71 |
+
self._num_residual_hiddens = num_residual_hiddens
|
72 |
+
|
73 |
+
self.project = ConvNormRelu(in_dim, self._num_hiddens // 4, leaky=True)
|
74 |
+
|
75 |
+
self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True)
|
76 |
+
self._down_1 = ConvNormRelu(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, residual=True,
|
77 |
+
sample='down')
|
78 |
+
self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True)
|
79 |
+
self._down_2 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens, leaky=True, residual=True, sample='down')
|
80 |
+
self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
|
81 |
+
|
82 |
+
self.pre_vq_conv = nn.Conv1d(self._num_hiddens, embedding_dim, 1, 1)
|
83 |
+
|
84 |
+
def forward(self, x):
|
85 |
+
h = self.project(x)
|
86 |
+
h = self._enc_1(h)
|
87 |
+
h = self._down_1(h)
|
88 |
+
h = self._enc_2(h)
|
89 |
+
h = self._down_2(h)
|
90 |
+
h = self._enc_3(h)
|
91 |
+
h = self.pre_vq_conv(h)
|
92 |
+
return h
|
93 |
+
|
94 |
+
|
95 |
+
class Frame_Enc(nn.Module):
|
96 |
+
def __init__(self, in_dim, num_hiddens):
|
97 |
+
super(Frame_Enc, self).__init__()
|
98 |
+
self.in_dim = in_dim
|
99 |
+
self.num_hiddens = num_hiddens
|
100 |
+
|
101 |
+
# self.enc = transformer_Enc(in_dim, num_hiddens, 2, 8, 256, 256, 256, 256, 0, dropout=0.1, n_position=4)
|
102 |
+
self.proj = nn.Conv1d(in_dim, num_hiddens, 1, 1)
|
103 |
+
self.enc = Res_CNR_Stack(num_hiddens, 2, leaky=True)
|
104 |
+
self.proj_1 = nn.Conv1d(256*4, num_hiddens, 1, 1)
|
105 |
+
self.proj_2 = nn.Conv1d(256*4, num_hiddens*2, 1, 1)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
# x = self.enc(x, None)[0].reshape(x.shape[0], -1, 1)
|
109 |
+
x = self.enc(self.proj(x)).reshape(x.shape[0], -1, 1)
|
110 |
+
second_last = self.proj_2(x)
|
111 |
+
last = self.proj_1(x)
|
112 |
+
return second_last, last
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
class Decoder(nn.Module):
|
117 |
+
def __init__(self, out_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens, ae=False):
|
118 |
+
super(Decoder, self).__init__()
|
119 |
+
self._num_hiddens = num_hiddens
|
120 |
+
self._num_residual_layers = num_residual_layers
|
121 |
+
self._num_residual_hiddens = num_residual_hiddens
|
122 |
+
|
123 |
+
self.aft_vq_conv = nn.Conv1d(embedding_dim, self._num_hiddens, 1, 1)
|
124 |
+
|
125 |
+
self._dec_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
|
126 |
+
self._up_2 = ConvNormRelu(self._num_hiddens, self._num_hiddens // 2, leaky=True, residual=True, sample='up')
|
127 |
+
self._dec_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True)
|
128 |
+
self._up_3 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens // 4, leaky=True, residual=True,
|
129 |
+
sample='up')
|
130 |
+
self._dec_3 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True)
|
131 |
+
|
132 |
+
if ae:
|
133 |
+
self.frame_enc = Frame_Enc(out_dim, self._num_hiddens // 4)
|
134 |
+
self.gru_sl = nn.GRU(self._num_hiddens // 2, self._num_hiddens // 2, 1, batch_first=True)
|
135 |
+
self.gru_l = nn.GRU(self._num_hiddens // 4, self._num_hiddens // 4, 1, batch_first=True)
|
136 |
+
|
137 |
+
self.project = nn.Conv1d(self._num_hiddens // 4, out_dim, 1, 1)
|
138 |
+
|
139 |
+
def forward(self, h, last_frame=None):
|
140 |
+
|
141 |
+
h = self.aft_vq_conv(h)
|
142 |
+
h = self._dec_1(h)
|
143 |
+
h = self._up_2(h)
|
144 |
+
h = self._dec_2(h)
|
145 |
+
h = self._up_3(h)
|
146 |
+
h = self._dec_3(h)
|
147 |
+
|
148 |
+
recon = self.project(h)
|
149 |
+
return recon, None
|
150 |
+
|
151 |
+
|
152 |
+
class Pre_VQ(nn.Module):
|
153 |
+
def __init__(self, num_hiddens, embedding_dim, num_chunks):
|
154 |
+
super(Pre_VQ, self).__init__()
|
155 |
+
self.conv = nn.Conv1d(num_hiddens, num_hiddens, 1, 1, 0, groups=num_chunks)
|
156 |
+
self.bn = nn.GroupNorm(num_chunks, num_hiddens)
|
157 |
+
self.relu = nn.ReLU()
|
158 |
+
self.proj = nn.Conv1d(num_hiddens, embedding_dim, 1, 1, 0, groups=num_chunks)
|
159 |
+
|
160 |
+
def forward(self, x):
|
161 |
+
x = self.conv(x)
|
162 |
+
x = self.bn(x)
|
163 |
+
x = self.relu(x)
|
164 |
+
x = self.proj(x)
|
165 |
+
return x
|
166 |
+
|
167 |
+
|
168 |
+
class VQVAE(nn.Module):
|
169 |
+
"""VQ-VAE"""
|
170 |
+
|
171 |
+
def __init__(self, in_dim, embedding_dim, num_embeddings,
|
172 |
+
num_hiddens, num_residual_layers, num_residual_hiddens,
|
173 |
+
commitment_cost=0.25, decay=0.99, share=False):
|
174 |
+
super().__init__()
|
175 |
+
self.in_dim = in_dim
|
176 |
+
self.embedding_dim = embedding_dim
|
177 |
+
self.num_embeddings = num_embeddings
|
178 |
+
self.share_code_vq = share
|
179 |
+
|
180 |
+
self.encoder = Encoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
|
181 |
+
self.vq_layer = VectorQuantizerEMA(embedding_dim, num_embeddings, commitment_cost, decay)
|
182 |
+
self.decoder = Decoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
|
183 |
+
|
184 |
+
def forward(self, gt_poses, id=None, pre_state=None):
|
185 |
+
z = self.encoder(gt_poses.transpose(1, 2))
|
186 |
+
if not self.training:
|
187 |
+
e, _ = self.vq_layer(z)
|
188 |
+
x_recon, cur_state = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
|
189 |
+
return e, x_recon
|
190 |
+
|
191 |
+
e, e_q_loss = self.vq_layer(z)
|
192 |
+
gt_recon, cur_state = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
|
193 |
+
|
194 |
+
return e_q_loss, gt_recon.transpose(1, 2)
|
195 |
+
|
196 |
+
def encode(self, gt_poses, id=None):
|
197 |
+
z = self.encoder(gt_poses.transpose(1, 2))
|
198 |
+
e, latents = self.vq_layer(z)
|
199 |
+
return e, latents
|
200 |
+
|
201 |
+
def decode(self, b, w, e=None, latents=None, pre_state=None):
|
202 |
+
if e is not None:
|
203 |
+
x = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
|
204 |
+
else:
|
205 |
+
e = self.vq_layer.quantize(latents)
|
206 |
+
e = e.view(b, w, -1).permute(0, 2, 1).contiguous()
|
207 |
+
x = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
|
208 |
+
return x
|
209 |
+
|
210 |
+
|
211 |
+
class AE(nn.Module):
|
212 |
+
"""VQ-VAE"""
|
213 |
+
|
214 |
+
def __init__(self, in_dim, embedding_dim, num_embeddings,
|
215 |
+
num_hiddens, num_residual_layers, num_residual_hiddens):
|
216 |
+
super().__init__()
|
217 |
+
self.in_dim = in_dim
|
218 |
+
self.embedding_dim = embedding_dim
|
219 |
+
self.num_embeddings = num_embeddings
|
220 |
+
|
221 |
+
self.encoder = Encoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
|
222 |
+
self.decoder = Decoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens, True)
|
223 |
+
|
224 |
+
def forward(self, gt_poses, id=None, pre_state=None):
|
225 |
+
z = self.encoder(gt_poses.transpose(1, 2))
|
226 |
+
if not self.training:
|
227 |
+
x_recon, cur_state = self.decoder(z, pre_state.transpose(1, 2) if pre_state is not None else None)
|
228 |
+
return z, x_recon
|
229 |
+
gt_recon, cur_state = self.decoder(z, pre_state.transpose(1, 2) if pre_state is not None else None)
|
230 |
+
|
231 |
+
return gt_recon.transpose(1, 2)
|
232 |
+
|
233 |
+
def encode(self, gt_poses, id=None):
|
234 |
+
z = self.encoder(gt_poses.transpose(1, 2))
|
235 |
+
return z
|
nets/spg/vqvae_modules.py
ADDED
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torchvision import datasets, transforms
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
class CasualCT(nn.Module):
|
13 |
+
def __init__(self,
|
14 |
+
in_channels,
|
15 |
+
out_channels,
|
16 |
+
leaky=False,
|
17 |
+
p=0,
|
18 |
+
groups=1, ):
|
19 |
+
'''
|
20 |
+
conv-bn-relu
|
21 |
+
'''
|
22 |
+
super(CasualCT, self).__init__()
|
23 |
+
padding = 0
|
24 |
+
kernel_size = 2
|
25 |
+
stride = 2
|
26 |
+
in_channels = in_channels * groups
|
27 |
+
out_channels = out_channels * groups
|
28 |
+
|
29 |
+
self.conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
|
30 |
+
kernel_size=kernel_size, stride=stride, padding=padding,
|
31 |
+
groups=groups)
|
32 |
+
self.norm = nn.BatchNorm1d(out_channels)
|
33 |
+
self.dropout = nn.Dropout(p=p)
|
34 |
+
if leaky:
|
35 |
+
self.relu = nn.LeakyReLU(negative_slope=0.2)
|
36 |
+
else:
|
37 |
+
self.relu = nn.ReLU()
|
38 |
+
|
39 |
+
def forward(self, x, **kwargs):
|
40 |
+
out = self.norm(self.dropout(self.conv(x)))
|
41 |
+
return self.relu(out)
|
42 |
+
|
43 |
+
|
44 |
+
class CasualConv(nn.Module):
|
45 |
+
def __init__(self,
|
46 |
+
in_channels,
|
47 |
+
out_channels,
|
48 |
+
leaky=False,
|
49 |
+
p=0,
|
50 |
+
groups=1,
|
51 |
+
downsample=False):
|
52 |
+
'''
|
53 |
+
conv-bn-relu
|
54 |
+
'''
|
55 |
+
super(CasualConv, self).__init__()
|
56 |
+
padding = 0
|
57 |
+
kernel_size = 2
|
58 |
+
stride = 1
|
59 |
+
self.downsample = downsample
|
60 |
+
if self.downsample:
|
61 |
+
kernel_size = 2
|
62 |
+
stride = 2
|
63 |
+
|
64 |
+
in_channels = in_channels * groups
|
65 |
+
out_channels = out_channels * groups
|
66 |
+
self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
|
67 |
+
kernel_size=kernel_size, stride=stride, padding=padding,
|
68 |
+
groups=groups)
|
69 |
+
self.norm = nn.BatchNorm1d(out_channels)
|
70 |
+
self.dropout = nn.Dropout(p=p)
|
71 |
+
if leaky:
|
72 |
+
self.relu = nn.LeakyReLU(negative_slope=0.2)
|
73 |
+
else:
|
74 |
+
self.relu = nn.ReLU()
|
75 |
+
|
76 |
+
def forward(self, x, pre_state=None):
|
77 |
+
if not self.downsample:
|
78 |
+
if pre_state is not None:
|
79 |
+
x = torch.cat([pre_state, x], dim=-1)
|
80 |
+
else:
|
81 |
+
zeros = torch.zeros([x.shape[0], x.shape[1], 1], device=x.device)
|
82 |
+
x = torch.cat([zeros, x], dim=-1)
|
83 |
+
out = self.norm(self.dropout(self.conv(x)))
|
84 |
+
return self.relu(out)
|
85 |
+
|
86 |
+
|
87 |
+
class ConvNormRelu(nn.Module):
|
88 |
+
'''
|
89 |
+
(B,C_in,H,W) -> (B, C_out, H, W)
|
90 |
+
there exist some kernel size that makes the result is not H/s
|
91 |
+
#TODO: there might some problems with residual
|
92 |
+
'''
|
93 |
+
|
94 |
+
def __init__(self,
|
95 |
+
in_channels,
|
96 |
+
out_channels,
|
97 |
+
leaky=False,
|
98 |
+
sample='none',
|
99 |
+
p=0,
|
100 |
+
groups=1,
|
101 |
+
residual=False,
|
102 |
+
norm='bn'):
|
103 |
+
'''
|
104 |
+
conv-bn-relu
|
105 |
+
'''
|
106 |
+
super(ConvNormRelu, self).__init__()
|
107 |
+
self.residual = residual
|
108 |
+
self.norm_type = norm
|
109 |
+
padding = 1
|
110 |
+
|
111 |
+
if sample == 'none':
|
112 |
+
kernel_size = 3
|
113 |
+
stride = 1
|
114 |
+
elif sample == 'one':
|
115 |
+
padding = 0
|
116 |
+
kernel_size = stride = 1
|
117 |
+
else:
|
118 |
+
kernel_size = 4
|
119 |
+
stride = 2
|
120 |
+
|
121 |
+
if self.residual:
|
122 |
+
if sample == 'down':
|
123 |
+
self.residual_layer = nn.Conv1d(
|
124 |
+
in_channels=in_channels,
|
125 |
+
out_channels=out_channels,
|
126 |
+
kernel_size=kernel_size,
|
127 |
+
stride=stride,
|
128 |
+
padding=padding)
|
129 |
+
elif sample == 'up':
|
130 |
+
self.residual_layer = nn.ConvTranspose1d(
|
131 |
+
in_channels=in_channels,
|
132 |
+
out_channels=out_channels,
|
133 |
+
kernel_size=kernel_size,
|
134 |
+
stride=stride,
|
135 |
+
padding=padding)
|
136 |
+
else:
|
137 |
+
if in_channels == out_channels:
|
138 |
+
self.residual_layer = nn.Identity()
|
139 |
+
else:
|
140 |
+
self.residual_layer = nn.Sequential(
|
141 |
+
nn.Conv1d(
|
142 |
+
in_channels=in_channels,
|
143 |
+
out_channels=out_channels,
|
144 |
+
kernel_size=kernel_size,
|
145 |
+
stride=stride,
|
146 |
+
padding=padding
|
147 |
+
)
|
148 |
+
)
|
149 |
+
|
150 |
+
in_channels = in_channels * groups
|
151 |
+
out_channels = out_channels * groups
|
152 |
+
if sample == 'up':
|
153 |
+
self.conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
|
154 |
+
kernel_size=kernel_size, stride=stride, padding=padding,
|
155 |
+
groups=groups)
|
156 |
+
else:
|
157 |
+
self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
|
158 |
+
kernel_size=kernel_size, stride=stride, padding=padding,
|
159 |
+
groups=groups)
|
160 |
+
self.norm = nn.BatchNorm1d(out_channels)
|
161 |
+
self.dropout = nn.Dropout(p=p)
|
162 |
+
if leaky:
|
163 |
+
self.relu = nn.LeakyReLU(negative_slope=0.2)
|
164 |
+
else:
|
165 |
+
self.relu = nn.ReLU()
|
166 |
+
|
167 |
+
def forward(self, x, **kwargs):
|
168 |
+
out = self.norm(self.dropout(self.conv(x)))
|
169 |
+
if self.residual:
|
170 |
+
residual = self.residual_layer(x)
|
171 |
+
out += residual
|
172 |
+
return self.relu(out)
|
173 |
+
|
174 |
+
|
175 |
+
class Res_CNR_Stack(nn.Module):
|
176 |
+
def __init__(self,
|
177 |
+
channels,
|
178 |
+
layers,
|
179 |
+
sample='none',
|
180 |
+
leaky=False,
|
181 |
+
casual=False,
|
182 |
+
):
|
183 |
+
super(Res_CNR_Stack, self).__init__()
|
184 |
+
|
185 |
+
if casual:
|
186 |
+
kernal_size = 1
|
187 |
+
padding = 0
|
188 |
+
conv = CasualConv
|
189 |
+
else:
|
190 |
+
kernal_size = 3
|
191 |
+
padding = 1
|
192 |
+
conv = ConvNormRelu
|
193 |
+
|
194 |
+
if sample == 'one':
|
195 |
+
kernal_size = 1
|
196 |
+
padding = 0
|
197 |
+
|
198 |
+
self._layers = nn.ModuleList()
|
199 |
+
for i in range(layers):
|
200 |
+
self._layers.append(conv(channels, channels, leaky=leaky, sample=sample))
|
201 |
+
self.conv = nn.Conv1d(channels, channels, kernal_size, 1, padding)
|
202 |
+
self.norm = nn.BatchNorm1d(channels)
|
203 |
+
self.relu = nn.ReLU()
|
204 |
+
|
205 |
+
def forward(self, x, pre_state=None):
|
206 |
+
# cur_state = []
|
207 |
+
h = x
|
208 |
+
for i in range(self._layers.__len__()):
|
209 |
+
# cur_state.append(h[..., -1:])
|
210 |
+
h = self._layers[i](h, pre_state=pre_state[i] if pre_state is not None else None)
|
211 |
+
h = self.norm(self.conv(h))
|
212 |
+
return self.relu(h + x)
|
213 |
+
|
214 |
+
|
215 |
+
class ExponentialMovingAverage(nn.Module):
|
216 |
+
"""Maintains an exponential moving average for a value.
|
217 |
+
|
218 |
+
This module keeps track of a hidden exponential moving average that is
|
219 |
+
initialized as a vector of zeros which is then normalized to give the average.
|
220 |
+
This gives us a moving average which isn't biased towards either zero or the
|
221 |
+
initial value. Reference (https://arxiv.org/pdf/1412.6980.pdf)
|
222 |
+
|
223 |
+
Initially:
|
224 |
+
hidden_0 = 0
|
225 |
+
Then iteratively:
|
226 |
+
hidden_i = hidden_{i-1} - (hidden_{i-1} - value) * (1 - decay)
|
227 |
+
average_i = hidden_i / (1 - decay^i)
|
228 |
+
"""
|
229 |
+
|
230 |
+
def __init__(self, init_value, decay):
|
231 |
+
super().__init__()
|
232 |
+
|
233 |
+
self.decay = decay
|
234 |
+
self.counter = 0
|
235 |
+
self.register_buffer("hidden", torch.zeros_like(init_value))
|
236 |
+
|
237 |
+
def forward(self, value):
|
238 |
+
self.counter += 1
|
239 |
+
self.hidden.sub_((self.hidden - value) * (1 - self.decay))
|
240 |
+
average = self.hidden / (1 - self.decay ** self.counter)
|
241 |
+
return average
|
242 |
+
|
243 |
+
|
244 |
+
class VectorQuantizerEMA(nn.Module):
|
245 |
+
"""
|
246 |
+
VQ-VAE layer: Input any tensor to be quantized. Use EMA to update embeddings.
|
247 |
+
Args:
|
248 |
+
embedding_dim (int): the dimensionality of the tensors in the
|
249 |
+
quantized space. Inputs to the modules must be in this format as well.
|
250 |
+
num_embeddings (int): the number of vectors in the quantized space.
|
251 |
+
commitment_cost (float): scalar which controls the weighting of the loss terms (see
|
252 |
+
equation 4 in the paper - this variable is Beta).
|
253 |
+
decay (float): decay for the moving averages.
|
254 |
+
epsilon (float): small float constant to avoid numerical instability.
|
255 |
+
"""
|
256 |
+
|
257 |
+
def __init__(self, embedding_dim, num_embeddings, commitment_cost, decay,
|
258 |
+
epsilon=1e-5):
|
259 |
+
super().__init__()
|
260 |
+
self.embedding_dim = embedding_dim
|
261 |
+
self.num_embeddings = num_embeddings
|
262 |
+
self.commitment_cost = commitment_cost
|
263 |
+
self.epsilon = epsilon
|
264 |
+
|
265 |
+
# initialize embeddings as buffers
|
266 |
+
embeddings = torch.empty(self.num_embeddings, self.embedding_dim)
|
267 |
+
nn.init.xavier_uniform_(embeddings)
|
268 |
+
self.register_buffer("embeddings", embeddings)
|
269 |
+
self.ema_dw = ExponentialMovingAverage(self.embeddings, decay)
|
270 |
+
|
271 |
+
# also maintain ema_cluster_size, which record the size of each embedding
|
272 |
+
self.ema_cluster_size = ExponentialMovingAverage(torch.zeros((self.num_embeddings,)), decay)
|
273 |
+
|
274 |
+
def forward(self, x):
|
275 |
+
# [B, C, H, W] -> [B, H, W, C]
|
276 |
+
x = x.permute(0, 2, 1).contiguous()
|
277 |
+
# [B, H, W, C] -> [BHW, C]
|
278 |
+
flat_x = x.reshape(-1, self.embedding_dim)
|
279 |
+
|
280 |
+
encoding_indices = self.get_code_indices(flat_x)
|
281 |
+
quantized = self.quantize(encoding_indices)
|
282 |
+
quantized = quantized.view_as(x) # [B, W, C]
|
283 |
+
|
284 |
+
if not self.training:
|
285 |
+
quantized = quantized.permute(0, 2, 1).contiguous()
|
286 |
+
return quantized, encoding_indices.view(quantized.shape[0], quantized.shape[2])
|
287 |
+
|
288 |
+
# update embeddings with EMA
|
289 |
+
with torch.no_grad():
|
290 |
+
encodings = F.one_hot(encoding_indices, self.num_embeddings).float()
|
291 |
+
updated_ema_cluster_size = self.ema_cluster_size(torch.sum(encodings, dim=0))
|
292 |
+
n = torch.sum(updated_ema_cluster_size)
|
293 |
+
updated_ema_cluster_size = ((updated_ema_cluster_size + self.epsilon) /
|
294 |
+
(n + self.num_embeddings * self.epsilon) * n)
|
295 |
+
dw = torch.matmul(encodings.t(), flat_x) # sum encoding vectors of each cluster
|
296 |
+
updated_ema_dw = self.ema_dw(dw)
|
297 |
+
normalised_updated_ema_w = (
|
298 |
+
updated_ema_dw / updated_ema_cluster_size.reshape(-1, 1))
|
299 |
+
self.embeddings.data = normalised_updated_ema_w
|
300 |
+
|
301 |
+
# commitment loss
|
302 |
+
e_latent_loss = F.mse_loss(x, quantized.detach())
|
303 |
+
loss = self.commitment_cost * e_latent_loss
|
304 |
+
|
305 |
+
# Straight Through Estimator
|
306 |
+
quantized = x + (quantized - x).detach()
|
307 |
+
|
308 |
+
quantized = quantized.permute(0, 2, 1).contiguous()
|
309 |
+
return quantized, loss
|
310 |
+
|
311 |
+
def get_code_indices(self, flat_x):
|
312 |
+
# compute L2 distance
|
313 |
+
distances = (
|
314 |
+
torch.sum(flat_x ** 2, dim=1, keepdim=True) +
|
315 |
+
torch.sum(self.embeddings ** 2, dim=1) -
|
316 |
+
2. * torch.matmul(flat_x, self.embeddings.t())
|
317 |
+
) # [N, M]
|
318 |
+
encoding_indices = torch.argmin(distances, dim=1) # [N,]
|
319 |
+
return encoding_indices
|
320 |
+
|
321 |
+
def quantize(self, encoding_indices):
|
322 |
+
"""Returns embedding tensor for a batch of indices."""
|
323 |
+
return F.embedding(encoding_indices, self.embeddings)
|
324 |
+
|
325 |
+
|
326 |
+
|
327 |
+
class Casual_Encoder(nn.Module):
|
328 |
+
def __init__(self, in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
|
329 |
+
super(Casual_Encoder, self).__init__()
|
330 |
+
self._num_hiddens = num_hiddens
|
331 |
+
self._num_residual_layers = num_residual_layers
|
332 |
+
self._num_residual_hiddens = num_residual_hiddens
|
333 |
+
|
334 |
+
self.project = nn.Conv1d(in_dim, self._num_hiddens // 4, 1, 1)
|
335 |
+
self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True, casual=True)
|
336 |
+
self._down_1 = CasualConv(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, downsample=True)
|
337 |
+
self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True, casual=True)
|
338 |
+
self._down_2 = CasualConv(self._num_hiddens // 2, self._num_hiddens, leaky=True, downsample=True)
|
339 |
+
self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True, casual=True)
|
340 |
+
# self.pre_vq_conv = nn.Conv1d(self._num_hiddens, embedding_dim, 1, 1)
|
341 |
+
|
342 |
+
def forward(self, x):
|
343 |
+
h = self.project(x)
|
344 |
+
h, _ = self._enc_1(h)
|
345 |
+
h = self._down_1(h)
|
346 |
+
h, _ = self._enc_2(h)
|
347 |
+
h = self._down_2(h)
|
348 |
+
h, _ = self._enc_3(h)
|
349 |
+
# h = self.pre_vq_conv(h)
|
350 |
+
return h
|
351 |
+
|
352 |
+
|
353 |
+
class Casual_Decoder(nn.Module):
|
354 |
+
def __init__(self, out_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
|
355 |
+
super(Casual_Decoder, self).__init__()
|
356 |
+
self._num_hiddens = num_hiddens
|
357 |
+
self._num_residual_layers = num_residual_layers
|
358 |
+
self._num_residual_hiddens = num_residual_hiddens
|
359 |
+
|
360 |
+
# self.aft_vq_conv = nn.Conv1d(embedding_dim, self._num_hiddens, 1, 1)
|
361 |
+
self._dec_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True, casual=True)
|
362 |
+
self._up_2 = CasualCT(self._num_hiddens, self._num_hiddens // 2, leaky=True)
|
363 |
+
self._dec_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True, casual=True)
|
364 |
+
self._up_3 = CasualCT(self._num_hiddens // 2, self._num_hiddens // 4, leaky=True)
|
365 |
+
self._dec_3 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True, casual=True)
|
366 |
+
self.project = nn.Conv1d(self._num_hiddens//4, out_dim, 1, 1)
|
367 |
+
|
368 |
+
def forward(self, h, pre_state=None):
|
369 |
+
cur_state = []
|
370 |
+
# h = self.aft_vq_conv(x)
|
371 |
+
h, s = self._dec_1(h, pre_state[0] if pre_state is not None else None)
|
372 |
+
cur_state.append(s)
|
373 |
+
h = self._up_2(h)
|
374 |
+
h, s = self._dec_2(h, pre_state[1] if pre_state is not None else None)
|
375 |
+
cur_state.append(s)
|
376 |
+
h = self._up_3(h)
|
377 |
+
h, s = self._dec_3(h, pre_state[2] if pre_state is not None else None)
|
378 |
+
cur_state.append(s)
|
379 |
+
recon = self.project(h)
|
380 |
+
return recon, cur_state
|
nets/spg/wav2vec.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
import copy
|
6 |
+
import math
|
7 |
+
from transformers import Wav2Vec2Model,Wav2Vec2Config
|
8 |
+
from transformers.modeling_outputs import BaseModelOutput
|
9 |
+
from typing import Optional, Tuple
|
10 |
+
_CONFIG_FOR_DOC = "Wav2Vec2Config"
|
11 |
+
|
12 |
+
# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model
|
13 |
+
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
|
14 |
+
def _compute_mask_indices(
|
15 |
+
shape: Tuple[int, int],
|
16 |
+
mask_prob: float,
|
17 |
+
mask_length: int,
|
18 |
+
attention_mask: Optional[torch.Tensor] = None,
|
19 |
+
min_masks: int = 0,
|
20 |
+
) -> np.ndarray:
|
21 |
+
bsz, all_sz = shape
|
22 |
+
mask = np.full((bsz, all_sz), False)
|
23 |
+
|
24 |
+
all_num_mask = int(
|
25 |
+
mask_prob * all_sz / float(mask_length)
|
26 |
+
+ np.random.rand()
|
27 |
+
)
|
28 |
+
all_num_mask = max(min_masks, all_num_mask)
|
29 |
+
mask_idcs = []
|
30 |
+
padding_mask = attention_mask.ne(1) if attention_mask is not None else None
|
31 |
+
for i in range(bsz):
|
32 |
+
if padding_mask is not None:
|
33 |
+
sz = all_sz - padding_mask[i].long().sum().item()
|
34 |
+
num_mask = int(
|
35 |
+
mask_prob * sz / float(mask_length)
|
36 |
+
+ np.random.rand()
|
37 |
+
)
|
38 |
+
num_mask = max(min_masks, num_mask)
|
39 |
+
else:
|
40 |
+
sz = all_sz
|
41 |
+
num_mask = all_num_mask
|
42 |
+
|
43 |
+
lengths = np.full(num_mask, mask_length)
|
44 |
+
|
45 |
+
if sum(lengths) == 0:
|
46 |
+
lengths[0] = min(mask_length, sz - 1)
|
47 |
+
|
48 |
+
min_len = min(lengths)
|
49 |
+
if sz - min_len <= num_mask:
|
50 |
+
min_len = sz - num_mask - 1
|
51 |
+
|
52 |
+
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
53 |
+
mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
|
54 |
+
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
55 |
+
|
56 |
+
min_len = min([len(m) for m in mask_idcs])
|
57 |
+
for i, mask_idc in enumerate(mask_idcs):
|
58 |
+
if len(mask_idc) > min_len:
|
59 |
+
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
60 |
+
mask[i, mask_idc] = True
|
61 |
+
return mask
|
62 |
+
|
63 |
+
# linear interpolation layer
|
64 |
+
def linear_interpolation(features, input_fps, output_fps, output_len=None):
|
65 |
+
features = features.transpose(1, 2)
|
66 |
+
seq_len = features.shape[2] / float(input_fps)
|
67 |
+
if output_len is None:
|
68 |
+
output_len = int(seq_len * output_fps)
|
69 |
+
output_features = F.interpolate(features,size=output_len,align_corners=False,mode='linear')
|
70 |
+
return output_features.transpose(1, 2)
|
71 |
+
|
72 |
+
|
73 |
+
class Wav2Vec2Model(Wav2Vec2Model):
|
74 |
+
def __init__(self, config):
|
75 |
+
super().__init__(config)
|
76 |
+
def forward(
|
77 |
+
self,
|
78 |
+
input_values,
|
79 |
+
attention_mask=None,
|
80 |
+
output_attentions=None,
|
81 |
+
output_hidden_states=None,
|
82 |
+
return_dict=None,
|
83 |
+
frame_num=None
|
84 |
+
):
|
85 |
+
self.config.output_attentions = True
|
86 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
87 |
+
output_hidden_states = (
|
88 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
89 |
+
)
|
90 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
91 |
+
|
92 |
+
hidden_states = self.feature_extractor(input_values)
|
93 |
+
hidden_states = hidden_states.transpose(1, 2)
|
94 |
+
|
95 |
+
hidden_states = linear_interpolation(hidden_states, 50, 30,output_len=frame_num)
|
96 |
+
|
97 |
+
if attention_mask is not None:
|
98 |
+
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
|
99 |
+
attention_mask = torch.zeros(
|
100 |
+
hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device
|
101 |
+
)
|
102 |
+
attention_mask[
|
103 |
+
(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)
|
104 |
+
] = 1
|
105 |
+
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
106 |
+
|
107 |
+
hidden_states = self.feature_projection(hidden_states)
|
108 |
+
|
109 |
+
if self.config.apply_spec_augment and self.training:
|
110 |
+
batch_size, sequence_length, hidden_size = hidden_states.size()
|
111 |
+
if self.config.mask_time_prob > 0:
|
112 |
+
mask_time_indices = _compute_mask_indices(
|
113 |
+
(batch_size, sequence_length),
|
114 |
+
self.config.mask_time_prob,
|
115 |
+
self.config.mask_time_length,
|
116 |
+
attention_mask=attention_mask,
|
117 |
+
min_masks=2,
|
118 |
+
)
|
119 |
+
hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype)
|
120 |
+
if self.config.mask_feature_prob > 0:
|
121 |
+
mask_feature_indices = _compute_mask_indices(
|
122 |
+
(batch_size, hidden_size),
|
123 |
+
self.config.mask_feature_prob,
|
124 |
+
self.config.mask_feature_length,
|
125 |
+
)
|
126 |
+
mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
|
127 |
+
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
|
128 |
+
encoder_outputs = self.encoder(
|
129 |
+
hidden_states[0],
|
130 |
+
attention_mask=attention_mask,
|
131 |
+
output_attentions=output_attentions,
|
132 |
+
output_hidden_states=output_hidden_states,
|
133 |
+
return_dict=return_dict,
|
134 |
+
)
|
135 |
+
hidden_states = encoder_outputs[0]
|
136 |
+
if not return_dict:
|
137 |
+
return (hidden_states,) + encoder_outputs[1:]
|
138 |
+
|
139 |
+
return BaseModelOutput(
|
140 |
+
last_hidden_state=hidden_states,
|
141 |
+
hidden_states=encoder_outputs.hidden_states,
|
142 |
+
attentions=encoder_outputs.attentions,
|
143 |
+
)
|
nets/utils.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import textgrid as tg
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
def get_parameter_size(model):
|
6 |
+
total_num = sum(p.numel() for p in model.parameters())
|
7 |
+
trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
8 |
+
return total_num, trainable_num
|
9 |
+
|
10 |
+
def denormalize(kps, data_mean, data_std):
|
11 |
+
'''
|
12 |
+
kps: (B, T, C)
|
13 |
+
'''
|
14 |
+
data_std = data_std.reshape(1, 1, -1)
|
15 |
+
data_mean = data_mean.reshape(1, 1, -1)
|
16 |
+
return (kps * data_std) + data_mean
|
17 |
+
|
18 |
+
def normalize(kps, data_mean, data_std):
|
19 |
+
'''
|
20 |
+
kps: (B, T, C)
|
21 |
+
'''
|
22 |
+
data_std = data_std.squeeze().reshape(1, 1, -1)
|
23 |
+
data_mean = data_mean.squeeze().reshape(1, 1, -1)
|
24 |
+
|
25 |
+
return (kps-data_mean) / data_std
|
26 |
+
|
27 |
+
def parse_audio(textgrid_file):
|
28 |
+
'''a demo implementation'''
|
29 |
+
words=['but', 'as', 'to', 'that', 'with', 'of', 'the', 'and', 'or', 'not', 'which', 'what', 'this', 'for', 'because', 'if', 'so', 'just', 'about', 'like', 'by', 'how', 'from', 'whats', 'now', 'very', 'that', 'also', 'actually', 'who', 'then', 'well', 'where', 'even', 'today', 'between', 'than', 'when']
|
30 |
+
txt=tg.TextGrid.fromFile(textgrid_file)
|
31 |
+
|
32 |
+
total_time=int(np.ceil(txt.maxTime))
|
33 |
+
code_seq=np.zeros(total_time)
|
34 |
+
|
35 |
+
word_level=txt[0]
|
36 |
+
|
37 |
+
for i in range(len(word_level)):
|
38 |
+
start_time=word_level[i].minTime
|
39 |
+
end_time=word_level[i].maxTime
|
40 |
+
mark=word_level[i].mark
|
41 |
+
|
42 |
+
if mark in words:
|
43 |
+
start=int(np.round(start_time))
|
44 |
+
end=int(np.round(end_time))
|
45 |
+
|
46 |
+
if start >= len(code_seq) or end >= len(code_seq):
|
47 |
+
code_seq[-1] = 1
|
48 |
+
else:
|
49 |
+
code_seq[start]=1
|
50 |
+
|
51 |
+
return code_seq
|
52 |
+
|
53 |
+
|
54 |
+
def get_path(model_name, model_type):
|
55 |
+
if model_name == 's2g_body_pixel':
|
56 |
+
if model_type == 'mfcc':
|
57 |
+
return './experiments/2022-10-09-smplx_S2G-body-pixel-aud-3p/ckpt-99.pth'
|
58 |
+
elif model_type == 'wv2':
|
59 |
+
return './experiments/2022-10-28-smplx_S2G-body-pixel-wv2-sg2/ckpt-99.pth'
|
60 |
+
elif model_type == 'random':
|
61 |
+
return './experiments/2022-10-09-smplx_S2G-body-pixel-random-3p/ckpt-99.pth'
|
62 |
+
elif model_type == 'wbhmodel':
|
63 |
+
return './experiments/2022-11-02-smplx_S2G-body-pixel-w-bhmodel/ckpt-99.pth'
|
64 |
+
elif model_type == 'wobhmodel':
|
65 |
+
return './experiments/2022-11-02-smplx_S2G-body-pixel-wo-bhmodel/ckpt-99.pth'
|
66 |
+
elif model_name == 's2g_body':
|
67 |
+
if model_type == 'a+m-vae':
|
68 |
+
return './experiments/2022-10-19-smplx_S2G-body-audio-motion-vae/ckpt-99.pth'
|
69 |
+
elif model_type == 'a-vae':
|
70 |
+
return './experiments/2022-10-18-smplx_S2G-body-audiovae/ckpt-99.pth'
|
71 |
+
elif model_type == 'a-ed':
|
72 |
+
return './experiments/2022-10-18-smplx_S2G-body-audioae/ckpt-99.pth'
|
73 |
+
elif model_name == 's2g_LS3DCG':
|
74 |
+
return './experiments/2022-10-19-smplx_S2G-LS3DCG/ckpt-99.pth'
|
75 |
+
elif model_name == 's2g_body_vq':
|
76 |
+
if model_type == 'n_com_1024':
|
77 |
+
return './experiments/2022-10-29-smplx_S2G-body-vq-cn1024/ckpt-99.pth'
|
78 |
+
elif model_type == 'n_com_2048':
|
79 |
+
return './experiments/2022-10-29-smplx_S2G-body-vq-cn2048/ckpt-99.pth'
|
80 |
+
elif model_type == 'n_com_4096':
|
81 |
+
return './experiments/2022-10-29-smplx_S2G-body-vq-cn4096/ckpt-99.pth'
|
82 |
+
elif model_type == 'n_com_8192':
|
83 |
+
return './experiments/2022-11-02-smplx_S2G-body-vq-cn8192/ckpt-99.pth'
|
84 |
+
elif model_type == 'n_com_16384':
|
85 |
+
return './experiments/2022-11-02-smplx_S2G-body-vq-cn16384/ckpt-99.pth'
|
86 |
+
elif model_type == 'n_com_170000':
|
87 |
+
return './experiments/2022-10-30-smplx_S2G-body-vq-cn170000/ckpt-99.pth'
|
88 |
+
elif model_type == 'com_1024':
|
89 |
+
return './experiments/2022-10-29-smplx_S2G-body-vq-composition/ckpt-99.pth'
|
90 |
+
elif model_type == 'com_2048':
|
91 |
+
return './experiments/2022-10-31-smplx_S2G-body-vq-composition2048/ckpt-99.pth'
|
92 |
+
elif model_type == 'com_4096':
|
93 |
+
return './experiments/2022-10-31-smplx_S2G-body-vq-composition4096/ckpt-99.pth'
|
94 |
+
elif model_type == 'com_8192':
|
95 |
+
return './experiments/2022-11-02-smplx_S2G-body-vq-composition8192/ckpt-99.pth'
|
96 |
+
elif model_type == 'com_16384':
|
97 |
+
return './experiments/2022-11-02-smplx_S2G-body-vq-composition16384/ckpt-99.pth'
|
98 |
+
|
99 |
+
|
100 |
+
def get_dpath(model_name, model_type):
|
101 |
+
if model_name == 's2g_body_pixel':
|
102 |
+
if model_type == 'audio':
|
103 |
+
return './experiments/2022-10-26-smplx_S2G-d-pixel-aud/ckpt-9.pth'
|
104 |
+
elif model_type == 'wv2':
|
105 |
+
return './experiments/2022-11-04-smplx_S2G-d-pixel-wv2/ckpt-9.pth'
|
106 |
+
elif model_type == 'random':
|
107 |
+
return './experiments/2022-10-26-smplx_S2G-d-pixel-random/ckpt-9.pth'
|
108 |
+
elif model_type == 'wbhmodel':
|
109 |
+
return './experiments/2022-11-10-smplx_S2G-hD-wbhmodel/ckpt-9.pth'
|
110 |
+
# return './experiments/2022-11-05-smplx_S2G-d-pixel-wbhmodel/ckpt-9.pth'
|
111 |
+
elif model_type == 'wobhmodel':
|
112 |
+
return './experiments/2022-11-10-smplx_S2G-hD-wobhmodel/ckpt-9.pth'
|
113 |
+
# return './experiments/2022-11-05-smplx_S2G-d-pixel-wobhmodel/ckpt-9.pth'
|
114 |
+
elif model_name == 's2g_body':
|
115 |
+
if model_type == 'a+m-vae':
|
116 |
+
return './experiments/2022-10-26-smplx_S2G-d-audio+motion-vae/ckpt-9.pth'
|
117 |
+
elif model_type == 'a-vae':
|
118 |
+
return './experiments/2022-10-26-smplx_S2G-d-audio-vae/ckpt-9.pth'
|
119 |
+
elif model_type == 'a-ed':
|
120 |
+
return './experiments/2022-10-26-smplx_S2G-d-audio-ae/ckpt-9.pth'
|
121 |
+
elif model_name == 's2g_LS3DCG':
|
122 |
+
return './experiments/2022-10-26-smplx_S2G-d-ls3dcg/ckpt-9.pth'
|
scripts/.idea/__init__.py
ADDED
File without changes
|
scripts/.idea/aws.xml
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="accountSettings">
|
4 |
+
<option name="activeRegion" value="us-east-1" />
|
5 |
+
<option name="recentlyUsedRegions">
|
6 |
+
<list>
|
7 |
+
<option value="us-east-1" />
|
8 |
+
</list>
|
9 |
+
</option>
|
10 |
+
</component>
|
11 |
+
</project>
|
scripts/.idea/deployment.xml
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
|
4 |
+
<serverData>
|
5 |
+
<paths name="80ti">
|
6 |
+
<serverdata>
|
7 |
+
<mappings>
|
8 |
+
<mapping local="$PROJECT_DIR$" web="/" />
|
9 |
+
</mappings>
|
10 |
+
</serverdata>
|
11 |
+
</paths>
|
12 |
+
<paths name="80ti (1)">
|
13 |
+
<serverdata>
|
14 |
+
<mappings>
|
15 |
+
<mapping local="$PROJECT_DIR$" web="/" />
|
16 |
+
</mappings>
|
17 |
+
</serverdata>
|
18 |
+
</paths>
|
19 |
+
<paths name="80ti (2)">
|
20 |
+
<serverdata>
|
21 |
+
<mappings>
|
22 |
+
<mapping local="$PROJECT_DIR$" web="/" />
|
23 |
+
</mappings>
|
24 |
+
</serverdata>
|
25 |
+
</paths>
|
26 |
+
<paths name="80ti (3)">
|
27 |
+
<serverdata>
|
28 |
+
<mappings>
|
29 |
+
<mapping local="$PROJECT_DIR$" web="/" />
|
30 |
+
</mappings>
|
31 |
+
</serverdata>
|
32 |
+
</paths>
|
33 |
+
<paths name="[email protected]:22">
|
34 |
+
<serverdata>
|
35 |
+
<mappings>
|
36 |
+
<mapping local="$PROJECT_DIR$" web="/" />
|
37 |
+
</mappings>
|
38 |
+
</serverdata>
|
39 |
+
</paths>
|
40 |
+
<paths name="titan">
|
41 |
+
<serverdata>
|
42 |
+
<mappings>
|
43 |
+
<mapping local="$PROJECT_DIR$" web="/" />
|
44 |
+
</mappings>
|
45 |
+
</serverdata>
|
46 |
+
</paths>
|
47 |
+
<paths name="titan (1)">
|
48 |
+
<serverdata>
|
49 |
+
<mappings>
|
50 |
+
<mapping local="$PROJECT_DIR$" web="/" />
|
51 |
+
</mappings>
|
52 |
+
</serverdata>
|
53 |
+
</paths>
|
54 |
+
<paths name="titan (2)">
|
55 |
+
<serverdata>
|
56 |
+
<mappings>
|
57 |
+
<mapping local="$PROJECT_DIR$" web="/" />
|
58 |
+
</mappings>
|
59 |
+
</serverdata>
|
60 |
+
</paths>
|
61 |
+
<paths name="titan (3)">
|
62 |
+
<serverdata>
|
63 |
+
<mappings>
|
64 |
+
<mapping local="$PROJECT_DIR$" web="/" />
|
65 |
+
</mappings>
|
66 |
+
</serverdata>
|
67 |
+
</paths>
|
68 |
+
</serverData>
|
69 |
+
</component>
|
70 |
+
</project>
|
scripts/.idea/get_prevar.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
4 |
+
|
5 |
+
sys.path.append(os.getcwd())
|
6 |
+
from glob import glob
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import json
|
10 |
+
import smplx as smpl
|
11 |
+
|
12 |
+
from nets import *
|
13 |
+
from repro_nets import *
|
14 |
+
from trainer.options import parse_args
|
15 |
+
from data_utils import torch_data
|
16 |
+
from trainer.config import load_JsonConfig
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from torch.utils import data
|
22 |
+
|
23 |
+
def init_model(model_name, model_path, args, config):
|
24 |
+
if model_name == 'freeMo':
|
25 |
+
# generator = freeMo_Generator(args)
|
26 |
+
# generator = freeMo_Generator(args)
|
27 |
+
generator = freeMo_dev(args, config)
|
28 |
+
# generator.load_state_dict(torch.load(model_path)['generator'])
|
29 |
+
elif model_name == 'smplx_S2G':
|
30 |
+
generator = smplx_S2G(args, config)
|
31 |
+
elif model_name == 'StyleGestures':
|
32 |
+
generator = StyleGesture_Generator(
|
33 |
+
args,
|
34 |
+
config
|
35 |
+
)
|
36 |
+
elif model_name == 'Audio2Gestures':
|
37 |
+
config.Train.using_mspec_stat = False
|
38 |
+
generator = Audio2Gesture_Generator(
|
39 |
+
args,
|
40 |
+
config,
|
41 |
+
torch.zeros([1, 1, 108]),
|
42 |
+
torch.ones([1, 1, 108])
|
43 |
+
)
|
44 |
+
elif model_name == 'S2G':
|
45 |
+
generator = S2G_Generator(
|
46 |
+
args,
|
47 |
+
config,
|
48 |
+
)
|
49 |
+
elif model_name == 'Tmpt':
|
50 |
+
generator = S2G_Generator(
|
51 |
+
args,
|
52 |
+
config,
|
53 |
+
)
|
54 |
+
else:
|
55 |
+
raise NotImplementedError
|
56 |
+
|
57 |
+
model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
|
58 |
+
if model_name == 'smplx_S2G':
|
59 |
+
generator.generator.load_state_dict(model_ckpt['generator']['generator'])
|
60 |
+
elif 'generator' in list(model_ckpt.keys()):
|
61 |
+
generator.load_state_dict(model_ckpt['generator'])
|
62 |
+
else:
|
63 |
+
model_ckpt = {'generator': model_ckpt}
|
64 |
+
generator.load_state_dict(model_ckpt)
|
65 |
+
|
66 |
+
return generator
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
def prevar_loader(data_root, speakers, args, config, model_path, device, generator):
|
71 |
+
path = model_path.split('ckpt')[0]
|
72 |
+
file = os.path.join(os.path.dirname(path), "pre_variable.npy")
|
73 |
+
data_base = torch_data(
|
74 |
+
data_root=data_root,
|
75 |
+
speakers=speakers,
|
76 |
+
split='pre',
|
77 |
+
limbscaling=False,
|
78 |
+
normalization=config.Data.pose.normalization,
|
79 |
+
norm_method=config.Data.pose.norm_method,
|
80 |
+
split_trans_zero=False,
|
81 |
+
num_pre_frames=config.Data.pose.pre_pose_length,
|
82 |
+
num_generate_length=config.Data.pose.generate_length,
|
83 |
+
num_frames=15,
|
84 |
+
aud_feat_win_size=config.Data.aud.aud_feat_win_size,
|
85 |
+
aud_feat_dim=config.Data.aud.aud_feat_dim,
|
86 |
+
feat_method=config.Data.aud.feat_method,
|
87 |
+
smplx=True,
|
88 |
+
audio_sr=22000,
|
89 |
+
convert_to_6d=config.Data.pose.convert_to_6d,
|
90 |
+
expression=config.Data.pose.expression
|
91 |
+
)
|
92 |
+
|
93 |
+
data_base.get_dataset()
|
94 |
+
pre_set = data_base.all_dataset
|
95 |
+
pre_loader = data.DataLoader(pre_set, batch_size=config.DataLoader.batch_size, shuffle=False, drop_last=True)
|
96 |
+
|
97 |
+
total_pose = []
|
98 |
+
|
99 |
+
with torch.no_grad():
|
100 |
+
for bat in pre_loader:
|
101 |
+
pose = bat['poses'].to(device).to(torch.float32)
|
102 |
+
expression = bat['expression'].to(device).to(torch.float32)
|
103 |
+
pose = pose.permute(0, 2, 1)
|
104 |
+
pose = torch.cat([pose[:, :15], pose[:, 15:30], pose[:, 30:45], pose[:, 45:60], pose[:, 60:]], dim=0)
|
105 |
+
expression = expression.permute(0, 2, 1)
|
106 |
+
expression = torch.cat([expression[:, :15], expression[:, 15:30], expression[:, 30:45], expression[:, 45:60], expression[:, 60:]], dim=0)
|
107 |
+
pose = torch.cat([pose, expression], dim=-1)
|
108 |
+
pose = pose.reshape(pose.shape[0], -1, 1)
|
109 |
+
pose_code = generator.generator.pre_pose_encoder(pose).squeeze().detach().cpu()
|
110 |
+
total_pose.append(np.asarray(pose_code))
|
111 |
+
total_pose = np.concatenate(total_pose, axis=0)
|
112 |
+
mean = np.mean(total_pose, axis=0)
|
113 |
+
std = np.std(total_pose, axis=0)
|
114 |
+
prevar = (mean, std)
|
115 |
+
np.save(file, prevar, allow_pickle=True)
|
116 |
+
|
117 |
+
return mean, std
|
118 |
+
|
119 |
+
def main():
|
120 |
+
parser = parse_args()
|
121 |
+
args = parser.parse_args()
|
122 |
+
device = torch.device(args.gpu)
|
123 |
+
torch.cuda.set_device(device)
|
124 |
+
|
125 |
+
config = load_JsonConfig(args.config_file)
|
126 |
+
|
127 |
+
print('init model...')
|
128 |
+
generator = init_model(config.Model.model_name, args.model_path, args, config)
|
129 |
+
print('init pre-pose vectors...')
|
130 |
+
mean, std = prevar_loader(config.Data.data_root, args.speakers, args, config, args.model_path, device, generator)
|
131 |
+
|
132 |
+
main()
|