Mounika256 commited on
Commit
1b1b150
·
verified ·
1 Parent(s): 3a17336

Update nets/smplx_body_pixel.py

Browse files
Files changed (1) hide show
  1. nets/smplx_body_pixel.py +231 -21
nets/smplx_body_pixel.py CHANGED
@@ -1,21 +1,231 @@
1
- def __init__(self, args, config):
2
- self.args = args
3
- self.config = config
4
- self.global_step = 0
5
-
6
- # Force CPU device
7
- self.device = torch.device('cpu')
8
-
9
- self.convert_to_6d = self.config.Data.pose.convert_to_6d
10
- self.expression = self.config.Data.pose.expression
11
- self.epoch = 0
12
- self.init_params()
13
- self.num_classes = 4
14
- self.audio = True
15
- self.composition = self.config.Model.composition
16
- self.bh_model = self.config.Model.bh_model
17
-
18
- if self.audio:
19
- self.audioencoder = AudioEncoder(in_dim=64, num_hiddens=256, num_residual_layers=2, num_residual_hiddens=256).to(self.device)
20
- else:
21
- self.audioencoder = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ import torch.nn.functional as F
7
+ from torch.optim.lr_scheduler import StepLR
8
+
9
+ sys.path.append(os.getcwd())
10
+
11
+ from nets.layers import *
12
+ from nets.base import TrainWrapperBaseClass
13
+ from nets.spg.gated_pixelcnn_v2 import GatedPixelCNN as pixelcnn
14
+ from nets.spg.vqvae_1d import VQVAE as s2g_body, Wav2VecEncoder, AudioEncoder
15
+ from nets.utils import parse_audio, denormalize
16
+ from data_utils import get_mfcc, get_melspec, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
17
+ from data_utils.lower_body import c_index, c_index_3d, c_index_6d
18
+ from data_utils.utils import smooth_geom, get_mfcc_sepa
19
+ import numpy as np
20
+ from sklearn.preprocessing import normalize
21
+
22
+
23
+ class TrainWrapper(TrainWrapperBaseClass):
24
+ '''
25
+ a wrapper receiving a batch from data_utils and calculate loss
26
+ '''
27
+
28
+ def __init__(self, args, config):
29
+ self.args = args
30
+ self.config = config
31
+ self.global_step = 0
32
+
33
+ # Force CPU device
34
+ self.device = torch.device('cpu')
35
+
36
+ self.convert_to_6d = self.config.Data.pose.convert_to_6d
37
+ self.expression = self.config.Data.pose.expression
38
+ self.epoch = 0
39
+ self.init_params()
40
+ self.num_classes = 4
41
+ self.audio = True
42
+ self.composition = self.config.Model.composition
43
+ self.bh_model = self.config.Model.bh_model
44
+
45
+ if self.audio:
46
+ self.audioencoder = AudioEncoder(
47
+ in_dim=64,
48
+ num_hiddens=256,
49
+ num_residual_layers=2,
50
+ num_residual_hiddens=256
51
+ ).to(self.device)
52
+ else:
53
+ self.audioencoder = None
54
+
55
+ if self.convert_to_6d:
56
+ dim, layer = 512, 10
57
+ else:
58
+ dim, layer = 256, 15
59
+
60
+ self.generator = pixelcnn(2048, dim, layer, self.num_classes, self.audio, self.bh_model).to(self.device)
61
+ self.g_body = s2g_body(self.each_dim[1], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
62
+ num_residual_layers=2, num_residual_hiddens=512).to(self.device)
63
+ self.g_hand = s2g_body(self.each_dim[2], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
64
+ num_residual_layers=2, num_residual_hiddens=512).to(self.device)
65
+
66
+ model_path = self.config.Model.vq_path
67
+ model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
68
+ self.g_body.load_state_dict(model_ckpt['generator']['g_body'])
69
+ self.g_hand.load_state_dict(model_ckpt['generator']['g_hand'])
70
+
71
+ self.discriminator = None
72
+ if self.convert_to_6d:
73
+ self.c_index = c_index_6d
74
+ else:
75
+ self.c_index = c_index_3d
76
+
77
+ super().__init__(args, config)
78
+
79
+ def init_optimizer(self):
80
+ print('using Adam')
81
+ self.generator_optimizer = optim.Adam(
82
+ self.generator.parameters(),
83
+ lr=self.config.Train.learning_rate.generator_learning_rate,
84
+ betas=[0.9, 0.999]
85
+ )
86
+ if self.audioencoder is not None:
87
+ opt = self.config.Model.AudioOpt
88
+ if opt == 'Adam':
89
+ self.audioencoder_optimizer = optim.Adam(
90
+ self.audioencoder.parameters(),
91
+ lr=self.config.Train.learning_rate.generator_learning_rate,
92
+ betas=[0.9, 0.999]
93
+ )
94
+ else:
95
+ print('using SGD')
96
+ self.audioencoder_optimizer = optim.SGD(
97
+ filter(lambda p: p.requires_grad, self.audioencoder.parameters()),
98
+ lr=self.config.Train.learning_rate.generator_learning_rate * 10,
99
+ momentum=0.9,
100
+ nesterov=False
101
+ )
102
+
103
+ def state_dict(self):
104
+ return {
105
+ 'generator': self.generator.state_dict(),
106
+ 'generator_optim': self.generator_optimizer.state_dict(),
107
+ 'audioencoder': self.audioencoder.state_dict() if self.audio else None,
108
+ 'audioencoder_optim': self.audioencoder_optimizer.state_dict() if self.audio else None,
109
+ 'discriminator': self.discriminator.state_dict() if self.discriminator else None,
110
+ 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator else None
111
+ }
112
+
113
+ def load_state_dict(self, state_dict):
114
+ from collections import OrderedDict
115
+ new_state_dict = OrderedDict()
116
+ for k, v in state_dict.items():
117
+ sub_dict = OrderedDict()
118
+ if v is not None:
119
+ for k1, v1 in v.items():
120
+ name = k1.replace('module.', '')
121
+ sub_dict[name] = v1
122
+ new_state_dict[k] = sub_dict
123
+ state_dict = new_state_dict
124
+
125
+ if 'generator' in state_dict:
126
+ self.generator.load_state_dict(state_dict['generator'])
127
+ else:
128
+ self.generator.load_state_dict(state_dict)
129
+
130
+ if 'generator_optim' in state_dict and self.generator_optimizer is not None:
131
+ self.generator_optimizer.load_state_dict(state_dict['generator_optim'])
132
+
133
+ if self.discriminator is not None:
134
+ self.discriminator.load_state_dict(state_dict['discriminator'])
135
+ if 'discriminator_optim' in state_dict and self.discriminator_optimizer is not None:
136
+ self.discriminator_optimizer.load_state_dict(state_dict['discriminator_optim'])
137
+
138
+ if 'audioencoder' in state_dict and self.audioencoder is not None:
139
+ self.audioencoder.load_state_dict(state_dict['audioencoder'])
140
+
141
+ def init_params(self):
142
+ if self.config.Data.pose.convert_to_6d:
143
+ scale = 2
144
+ else:
145
+ scale = 1
146
+
147
+ global_orient = round(0 * scale)
148
+ leye_pose = reye_pose = round(0 * scale)
149
+ jaw_pose = round(0 * scale)
150
+ body_pose = round((63 - 24) * scale)
151
+ left_hand_pose = right_hand_pose = round(45 * scale)
152
+ if self.expression:
153
+ expression = 100
154
+ else:
155
+ expression = 0
156
+
157
+ b_j = 0
158
+ jaw_dim = jaw_pose
159
+ b_e = b_j + jaw_dim
160
+ eye_dim = leye_pose + reye_pose
161
+ b_b = b_e + eye_dim
162
+ body_dim = global_orient + body_pose
163
+ b_h = b_b + body_dim
164
+ hand_dim = left_hand_pose + right_hand_pose
165
+ b_f = b_h + hand_dim
166
+ face_dim = expression
167
+
168
+ self.dim_list = [b_j, b_e, b_b, b_h, b_f]
169
+ self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
170
+ self.pose = int(self.full_dim / round(3 * scale))
171
+ self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
172
+
173
+ def __call__(self, bat):
174
+ self.global_step += 1
175
+ total_loss = None
176
+ loss_dict = {}
177
+
178
+ aud, poses = bat['aud_feat'].to(self.device).float(), bat['poses'].to(self.device).float()
179
+ id = bat['speaker'].to(self.device) - 20
180
+ poses = poses[:, self.c_index, :]
181
+ aud = aud.permute(0, 2, 1)
182
+ gt_poses = poses.permute(0, 2, 1)
183
+
184
+ with torch.no_grad():
185
+ self.g_body.eval()
186
+ self.g_hand.eval()
187
+ _, body_latents = self.g_body.encode(gt_poses=gt_poses[..., :self.each_dim[1]], id=id)
188
+ _, hand_latents = self.g_hand.encode(gt_poses=gt_poses[..., self.each_dim[1]:], id=id)
189
+ latents = torch.cat([body_latents.unsqueeze(-1), hand_latents.unsqueeze(-1)], dim=-1).detach()
190
+
191
+ if self.audio:
192
+ audio = self.audioencoder(aud.transpose(1, 2), frame_num=latents.shape[1]*4).unsqueeze(-1).repeat(1, 1, 1, 2)
193
+ logits = self.generator(latents, id, audio)
194
+ else:
195
+ logits = self.generator(latents, id)
196
+ logits = logits.permute(0, 2, 3, 1).contiguous()
197
+
198
+ self.generator_optimizer.zero_grad()
199
+ if self.audio:
200
+ self.audioencoder_optimizer.zero_grad()
201
+
202
+ loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), latents.view(-1))
203
+ loss.backward()
204
+
205
+ grad = torch.nn.utils.clip_grad_norm(self.generator.parameters(), self.config.Train.max_gradient_norm)
206
+
207
+ loss_dict['grad'] = grad.item()
208
+ loss_dict['ce_loss'] = loss.item()
209
+ self.generator_optimizer.step()
210
+ if self.audio:
211
+ self.audioencoder_optimizer.step()
212
+
213
+ return total_loss, loss_dict
214
+
215
+ # ----------------------------------------
216
+ # 🚀 NEW SIMPLE WRAPPER CLASS for inference
217
+ # ----------------------------------------
218
+
219
+ class s2g_body_pixel(nn.Module):
220
+ def __init__(self, args, config):
221
+ super().__init__()
222
+ self.wrapper = TrainWrapper(args, config)
223
+
224
+ def infer_on_audio(self, *args, **kwargs):
225
+ return self.wrapper.infer_on_audio(*args, **kwargs)
226
+
227
+ def forward(self, *args, **kwargs):
228
+ return self.wrapper(*args, **kwargs)
229
+
230
+ def load_state_dict(self, *args, **kwargs):
231
+ return self.wrapper.load_state_dict(*args, **kwargs)