Mounika256 commited on
Commit
3a17336
·
verified ·
1 Parent(s): 1238fa5

Update nets/smplx_body_pixel.py

Browse files
Files changed (1) hide show
  1. nets/smplx_body_pixel.py +21 -326
nets/smplx_body_pixel.py CHANGED
@@ -1,326 +1,21 @@
1
- import os
2
- import sys
3
-
4
- import torch
5
- from torch.optim.lr_scheduler import StepLR
6
-
7
- sys.path.append(os.getcwd())
8
-
9
- from nets.layers import *
10
- from nets.base import TrainWrapperBaseClass
11
- from nets.spg.gated_pixelcnn_v2 import GatedPixelCNN as pixelcnn
12
- from nets.spg.vqvae_1d import VQVAE as s2g_body, Wav2VecEncoder
13
- from nets.spg.vqvae_1d import AudioEncoder
14
- from nets.utils import parse_audio, denormalize
15
- from data_utils import get_mfcc, get_melspec, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
16
- import numpy as np
17
- import torch.optim as optim
18
- import torch.nn.functional as F
19
- from sklearn.preprocessing import normalize
20
-
21
- from data_utils.lower_body import c_index, c_index_3d, c_index_6d
22
- from data_utils.utils import smooth_geom, get_mfcc_sepa
23
-
24
-
25
- class TrainWrapper(TrainWrapperBaseClass):
26
- '''
27
- a wrapper receving a batch from data_utils and calculate loss
28
- '''
29
-
30
- def __init__(self, args, config):
31
- self.args = args
32
- self.config = config
33
- self.device = torch.device(self.args.gpu)
34
- self.global_step = 0
35
-
36
- self.convert_to_6d = self.config.Data.pose.convert_to_6d
37
- self.expression = self.config.Data.pose.expression
38
- self.epoch = 0
39
- self.init_params()
40
- self.num_classes = 4
41
- self.audio = True
42
- self.composition = self.config.Model.composition
43
- self.bh_model = self.config.Model.bh_model
44
-
45
- if self.audio:
46
- self.audioencoder = AudioEncoder(in_dim=64, num_hiddens=256, num_residual_layers=2, num_residual_hiddens=256).to(self.device)
47
- else:
48
- self.audioencoder = None
49
- if self.convert_to_6d:
50
- dim, layer = 512, 10
51
- else:
52
- dim, layer = 256, 15
53
- self.generator = pixelcnn(2048, dim, layer, self.num_classes, self.audio, self.bh_model).to(self.device)
54
- self.g_body = s2g_body(self.each_dim[1], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
55
- num_residual_layers=2, num_residual_hiddens=512).to(self.device)
56
- self.g_hand = s2g_body(self.each_dim[2], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
57
- num_residual_layers=2, num_residual_hiddens=512).to(self.device)
58
-
59
- model_path = self.config.Model.vq_path
60
- model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
61
- self.g_body.load_state_dict(model_ckpt['generator']['g_body'])
62
- self.g_hand.load_state_dict(model_ckpt['generator']['g_hand'])
63
-
64
- if torch.cuda.device_count() > 1:
65
- self.g_body = torch.nn.DataParallel(self.g_body, device_ids=[0, 1])
66
- self.g_hand = torch.nn.DataParallel(self.g_hand, device_ids=[0, 1])
67
- self.generator = torch.nn.DataParallel(self.generator, device_ids=[0, 1])
68
- if self.audioencoder is not None:
69
- self.audioencoder = torch.nn.DataParallel(self.audioencoder, device_ids=[0, 1])
70
-
71
- self.discriminator = None
72
- if self.convert_to_6d:
73
- self.c_index = c_index_6d
74
- else:
75
- self.c_index = c_index_3d
76
-
77
- super().__init__(args, config)
78
-
79
- def init_optimizer(self):
80
-
81
- print('using Adam')
82
- self.generator_optimizer = optim.Adam(
83
- self.generator.parameters(),
84
- lr=self.config.Train.learning_rate.generator_learning_rate,
85
- betas=[0.9, 0.999]
86
- )
87
- if self.audioencoder is not None:
88
- opt = self.config.Model.AudioOpt
89
- if opt == 'Adam':
90
- self.audioencoder_optimizer = optim.Adam(
91
- self.audioencoder.parameters(),
92
- lr=self.config.Train.learning_rate.generator_learning_rate,
93
- betas=[0.9, 0.999]
94
- )
95
- else:
96
- print('using SGD')
97
- self.audioencoder_optimizer = optim.SGD(
98
- filter(lambda p: p.requires_grad,self.audioencoder.parameters()),
99
- lr=self.config.Train.learning_rate.generator_learning_rate*10,
100
- momentum=0.9,
101
- nesterov=False,
102
- )
103
-
104
- def state_dict(self):
105
- model_state = {
106
- 'generator': self.generator.state_dict(),
107
- 'generator_optim': self.generator_optimizer.state_dict(),
108
- 'audioencoder': self.audioencoder.state_dict() if self.audio else None,
109
- 'audioencoder_optim': self.audioencoder_optimizer.state_dict() if self.audio else None,
110
- 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
111
- 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
112
- }
113
- return model_state
114
-
115
- def load_state_dict(self, state_dict):
116
-
117
- from collections import OrderedDict
118
- new_state_dict = OrderedDict() # create new OrderedDict that does not contain `module.`
119
- for k, v in state_dict.items():
120
- sub_dict = OrderedDict()
121
- if v is not None:
122
- for k1, v1 in v.items():
123
- name = k1.replace('module.', '')
124
- sub_dict[name] = v1
125
- new_state_dict[k] = sub_dict
126
- state_dict = new_state_dict
127
- if 'generator' in state_dict:
128
- self.generator.load_state_dict(state_dict['generator'])
129
- else:
130
- self.generator.load_state_dict(state_dict)
131
-
132
- if 'generator_optim' in state_dict and self.generator_optimizer is not None:
133
- self.generator_optimizer.load_state_dict(state_dict['generator_optim'])
134
-
135
- if self.discriminator is not None:
136
- self.discriminator.load_state_dict(state_dict['discriminator'])
137
-
138
- if 'discriminator_optim' in state_dict and self.discriminator_optimizer is not None:
139
- self.discriminator_optimizer.load_state_dict(state_dict['discriminator_optim'])
140
-
141
- if 'audioencoder' in state_dict and self.audioencoder is not None:
142
- self.audioencoder.load_state_dict(state_dict['audioencoder'])
143
-
144
- def init_params(self):
145
- if self.config.Data.pose.convert_to_6d:
146
- scale = 2
147
- else:
148
- scale = 1
149
-
150
- global_orient = round(0 * scale)
151
- leye_pose = reye_pose = round(0 * scale)
152
- jaw_pose = round(0 * scale)
153
- body_pose = round((63 - 24) * scale)
154
- left_hand_pose = right_hand_pose = round(45 * scale)
155
- if self.expression:
156
- expression = 100
157
- else:
158
- expression = 0
159
-
160
- b_j = 0
161
- jaw_dim = jaw_pose
162
- b_e = b_j + jaw_dim
163
- eye_dim = leye_pose + reye_pose
164
- b_b = b_e + eye_dim
165
- body_dim = global_orient + body_pose
166
- b_h = b_b + body_dim
167
- hand_dim = left_hand_pose + right_hand_pose
168
- b_f = b_h + hand_dim
169
- face_dim = expression
170
-
171
- self.dim_list = [b_j, b_e, b_b, b_h, b_f]
172
- self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
173
- self.pose = int(self.full_dim / round(3 * scale))
174
- self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
175
-
176
- def __call__(self, bat):
177
- # assert (not self.args.infer), "infer mode"
178
- self.global_step += 1
179
-
180
- total_loss = None
181
- loss_dict = {}
182
-
183
- aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
184
-
185
- id = bat['speaker'].to(self.device) - 20
186
- # id = F.one_hot(id, self.num_classes)
187
-
188
- poses = poses[:, self.c_index, :]
189
-
190
- aud = aud.permute(0, 2, 1)
191
- gt_poses = poses.permute(0, 2, 1)
192
-
193
- with torch.no_grad():
194
- self.g_body.eval()
195
- self.g_hand.eval()
196
- if torch.cuda.device_count() > 1:
197
- _, body_latents = self.g_body.module.encode(gt_poses=gt_poses[..., :self.each_dim[1]], id=id)
198
- _, hand_latents = self.g_hand.module.encode(gt_poses=gt_poses[..., self.each_dim[1]:], id=id)
199
- else:
200
- _, body_latents = self.g_body.encode(gt_poses=gt_poses[..., :self.each_dim[1]], id=id)
201
- _, hand_latents = self.g_hand.encode(gt_poses=gt_poses[..., self.each_dim[1]:], id=id)
202
- latents = torch.cat([body_latents.unsqueeze(dim=-1), hand_latents.unsqueeze(dim=-1)], dim=-1)
203
- latents = latents.detach()
204
-
205
- if self.audio:
206
- audio = self.audioencoder(aud[:, :].transpose(1, 2), frame_num=latents.shape[1]*4).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
207
- logits = self.generator(latents[:, :], id, audio)
208
- else:
209
- logits = self.generator(latents, id)
210
- logits = logits.permute(0, 2, 3, 1).contiguous()
211
-
212
- self.generator_optimizer.zero_grad()
213
- if self.audio:
214
- self.audioencoder_optimizer.zero_grad()
215
-
216
- loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), latents.view(-1))
217
- loss.backward()
218
-
219
- grad = torch.nn.utils.clip_grad_norm(self.generator.parameters(), self.config.Train.max_gradient_norm)
220
-
221
- if torch.isnan(grad).sum() > 0:
222
- print('fuck')
223
-
224
- loss_dict['grad'] = grad.item()
225
- loss_dict['ce_loss'] = loss.item()
226
- self.generator_optimizer.step()
227
- if self.audio:
228
- self.audioencoder_optimizer.step()
229
-
230
- return total_loss, loss_dict
231
-
232
- def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, exp=None, var=None, w_pre=False, rand=None,
233
- continuity=False, id=None, fps=15, sr=22000, B=1, am=None, am_sr=None, frame=0,**kwargs):
234
- '''
235
- initial_pose: (B, C, T), normalized
236
- (aud_fn, txgfile) -> generated motion (B, T, C)
237
- '''
238
- output = []
239
-
240
- assert self.args.infer, "train mode"
241
- self.generator.eval()
242
- self.g_body.eval()
243
- self.g_hand.eval()
244
-
245
- if continuity:
246
- aud_feat, gap = get_mfcc_sepa(aud_fn, sr=sr, fps=fps)
247
- else:
248
- aud_feat = get_mfcc_ta(aud_fn, sr=sr, fps=fps, smlpx=True, type='mfcc', am=am)
249
- aud_feat = aud_feat.transpose(1, 0)
250
- aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
251
- aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)
252
-
253
- if id is None:
254
- id = torch.tensor([0]).to(self.device)
255
- else:
256
- id = id.repeat(B)
257
-
258
- with torch.no_grad():
259
- aud_feat = aud_feat.permute(0, 2, 1)
260
- if continuity:
261
- self.audioencoder.eval()
262
- pre_pose = {}
263
- pre_pose['b'] = pre_pose['h'] = None
264
- pre_latents, pre_audio, body_0, hand_0 = self.infer(aud_feat[:, :gap], frame, id, B, pre_pose=pre_pose)
265
- pre_pose['b'] = body_0[:, :, -4:].transpose(1,2)
266
- pre_pose['h'] = hand_0[:, :, -4:].transpose(1,2)
267
- _, _, body_1, hand_1 = self.infer(aud_feat[:, gap:], frame, id, B, pre_latents, pre_audio, pre_pose)
268
- body = torch.cat([body_0, body_1], dim=2)
269
- hand = torch.cat([hand_0, hand_1], dim=2)
270
-
271
- else:
272
- if self.audio:
273
- self.audioencoder.eval()
274
- audio = self.audioencoder(aud_feat.transpose(1, 2), frame_num=frame).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
275
- latents = self.generator.generate(id, shape=[audio.shape[2], 2], batch_size=B, aud_feat=audio)
276
- else:
277
- latents = self.generator.generate(id, shape=[aud_feat.shape[1]//4, 2], batch_size=B)
278
-
279
- body_latents = latents[..., 0]
280
- hand_latents = latents[..., 1]
281
-
282
- body, _ = self.g_body.decode(b=body_latents.shape[0], w=body_latents.shape[1], latents=body_latents)
283
- hand, _ = self.g_hand.decode(b=hand_latents.shape[0], w=hand_latents.shape[1], latents=hand_latents)
284
-
285
- pred_poses = torch.cat([body, hand], dim=1).transpose(1,2).cpu().numpy()
286
-
287
- output = pred_poses
288
-
289
- return output
290
-
291
- def infer(self, aud_feat, frame, id, B, pre_latents=None, pre_audio=None, pre_pose=None):
292
- audio = self.audioencoder(aud_feat.transpose(1, 2), frame_num=frame).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
293
- latents = self.generator.generate(id, shape=[audio.shape[2], 2], batch_size=B, aud_feat=audio,
294
- pre_latents=pre_latents, pre_audio=pre_audio)
295
-
296
- body_latents = latents[..., 0]
297
- hand_latents = latents[..., 1]
298
-
299
- body, _ = self.g_body.decode(b=body_latents.shape[0], w=body_latents.shape[1],
300
- latents=body_latents, pre_state=pre_pose['b'])
301
- hand, _ = self.g_hand.decode(b=hand_latents.shape[0], w=hand_latents.shape[1],
302
- latents=hand_latents, pre_state=pre_pose['h'])
303
-
304
- return latents, audio, body, hand
305
-
306
- def generate(self, aud, id, frame_num=0):
307
-
308
- self.generator.eval()
309
- self.g_body.eval()
310
- self.g_hand.eval()
311
- aud_feat = aud.permute(0, 2, 1)
312
- if self.audio:
313
- self.audioencoder.eval()
314
- audio = self.audioencoder(aud_feat.transpose(1, 2), frame_num=frame_num).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
315
- latents = self.generator.generate(id, shape=[audio.shape[2], 2], batch_size=aud.shape[0], aud_feat=audio)
316
- else:
317
- latents = self.generator.generate(id, shape=[aud_feat.shape[1] // 4, 2], batch_size=aud.shape[0])
318
-
319
- body_latents = latents[..., 0]
320
- hand_latents = latents[..., 1]
321
-
322
- body = self.g_body.decode(b=body_latents.shape[0], w=body_latents.shape[1], latents=body_latents)
323
- hand = self.g_hand.decode(b=hand_latents.shape[0], w=hand_latents.shape[1], latents=hand_latents)
324
-
325
- pred_poses = torch.cat([body, hand], dim=1).transpose(1, 2)
326
- return pred_poses
 
1
+ def __init__(self, args, config):
2
+ self.args = args
3
+ self.config = config
4
+ self.global_step = 0
5
+
6
+ # Force CPU device
7
+ self.device = torch.device('cpu')
8
+
9
+ self.convert_to_6d = self.config.Data.pose.convert_to_6d
10
+ self.expression = self.config.Data.pose.expression
11
+ self.epoch = 0
12
+ self.init_params()
13
+ self.num_classes = 4
14
+ self.audio = True
15
+ self.composition = self.config.Model.composition
16
+ self.bh_model = self.config.Model.bh_model
17
+
18
+ if self.audio:
19
+ self.audioencoder = AudioEncoder(in_dim=64, num_hiddens=256, num_residual_layers=2, num_residual_hiddens=256).to(self.device)
20
+ else:
21
+ self.audioencoder = None