Mounika256 committed
Commit 1238fa5 · verified · 1 Parent(s): 78027c4

Update scripts/demo.py

Files changed (1): scripts/demo.py (+72 / -161)
scripts/demo.py CHANGED
@@ -1,74 +1,57 @@
import os
import sys
- # os.environ["PYOPENGL_PLATFORM"] = "egl"
- os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+ # Force CPU-only for Hugging Face (no CUDA)
+ os.environ['CUDA_VISIBLE_DEVICES'] = ''
sys.path.append(os.getcwd())

+ import torch
+ import numpy as np
+ import smplx as smpl
+
from transformers import Wav2Vec2Processor
from glob import glob
-
- import numpy as np
import json
- import smplx as smpl

from nets import *
from trainer.options import parse_args
from data_utils import torch_data
from trainer.config import load_JsonConfig
-
- import torch
- import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
from visualise.rendering import RenderTool

- global device
+ # Global forced device
+ torch_device = torch.device('cpu')
device = 'cpu'

def init_model(model_name, model_path, args, config):
    if model_name == 's2g_face':
-         generator = s2g_face(
-             args,
-             config,
-         )
+         generator = s2g_face(args, config)
    elif model_name == 's2g_body_vq':
-         generator = s2g_body_vq(
-             args,
-             config,
-         )
+         generator = s2g_body_vq(args, config)
    elif model_name == 's2g_body_pixel':
-         generator = s2g_body_pixel(
-             args,
-             config,
-         )
+         generator = s2g_body_pixel(args, config)
    elif model_name == 's2g_LS3DCG':
-         generator = LS3DCG(
-             args,
-             config,
-         )
+         generator = LS3DCG(args, config)
    else:
        raise NotImplementedError

    model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
    if model_name == 'smplx_S2G':
        generator.generator.load_state_dict(model_ckpt['generator']['generator'])
-
    elif 'generator' in list(model_ckpt.keys()):
        generator.load_state_dict(model_ckpt['generator'])
    else:
        model_ckpt = {'generator': model_ckpt}
        generator.load_state_dict(model_ckpt)

-     return generator
-
+     return generator.to(torch_device)

def init_dataloader(data_root, speakers, args, config):
-     if data_root.endswith('.csv'):
-         raise NotImplementedError
-     else:
-         data_class = torch_data
+     data_class = torch_data
    if 'smplx' in config.Model.model_name or 's2g' in config.Model.model_name:
        data_base = torch_data(
            data_root=data_root,
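Note on the hunk above: the central change of this commit is pinning inference to the CPU by hiding all CUDA devices before torch is imported; the variable must be set before CUDA is initialized, which is why it now precedes the torch import. A minimal standalone sketch of the pattern:

# Minimal sketch of the CPU-forcing pattern used above.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''   # hide every GPU from CUDA

import torch

device = torch.device('cpu')
print(torch.cuda.is_available())          # False: no visible devices remain
x = torch.zeros(3, device=device)         # tensors land on the CPU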
@@ -91,26 +74,13 @@ def init_dataloader(data_root, speakers, args, config):
            config=config
        )
    else:
-         data_base = torch_data(
-             data_root=data_root,
-             speakers=speakers,
-             split='val',
-             limbscaling=False,
-             normalization=config.Data.pose.normalization,
-             norm_method=config.Data.pose.norm_method,
-             split_trans_zero=False,
-             num_pre_frames=config.Data.pose.pre_pose_length,
-             aud_feat_win_size=config.Data.aud.aud_feat_win_size,
-             aud_feat_dim=config.Data.aud.aud_feat_dim,
-             feat_method=config.Data.aud.feat_method
-         )
+         raise NotImplementedError
+
    if config.Data.pose.normalization:
        norm_stats_fn = os.path.join(os.path.dirname(args.model_path), "norm_stats.npy")
        norm_stats = np.load(norm_stats_fn, allow_pickle=True)
        data_base.data_mean = norm_stats[0]
        data_base.data_std = norm_stats[1]
-     else:
-         norm_stats = None

    data_base.get_dataset()
    infer_set = data_base.all_dataset
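Note on the hunk above: norm_stats.npy is read as a two-entry array holding the pose mean and standard deviation. Worth flagging: with the `else: norm_stats = None` branch deleted, norm_stats (returned below) is only bound when config.Data.pose.normalization is set. A hedged sketch of the file's round trip, assuming the usual z-score convention (the repo's data class may apply the stats differently):

# Hypothetical round trip for norm_stats.npy, under an assumed z-score convention.
import numpy as np

poses = np.random.randn(1000, 165).astype(np.float32)    # stand-in pose data
np.save('norm_stats.npy', np.stack([poses.mean(axis=0), poses.std(axis=0)]))

norm_stats = np.load('norm_stats.npy', allow_pickle=True)
mean, std = norm_stats[0], norm_stats[1]                  # as read in init_dataloader
normalized = (poses - mean) / (std + 1e-8)                # guard against zero variance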
@@ -118,69 +88,49 @@ def init_dataloader(data_root, speakers, args, config):

    return infer_set, infer_loader, norm_stats

-
def get_vertices(smplx_model, betas, result_list, exp, require_pose=False):
    vertices_list = []
-     poses_list = []
    expression = torch.zeros([1, 50])

    for i in result_list:
        vertices = []
-         poses = []
        for j in range(i.shape[0]):
-             output = smplx_model(betas=betas,
-                                  expression=i[j][165:265].unsqueeze_(dim=0) if exp else expression,
-                                  jaw_pose=i[j][0:3].unsqueeze_(dim=0),
-                                  leye_pose=i[j][3:6].unsqueeze_(dim=0),
-                                  reye_pose=i[j][6:9].unsqueeze_(dim=0),
-                                  global_orient=i[j][9:12].unsqueeze_(dim=0),
-                                  body_pose=i[j][12:75].unsqueeze_(dim=0),
-                                  left_hand_pose=i[j][75:120].unsqueeze_(dim=0),
-                                  right_hand_pose=i[j][120:165].unsqueeze_(dim=0),
-                                  return_verts=True)
+             output = smplx_model(
+                 betas=betas,
+                 expression=i[j][165:265].unsqueeze_(dim=0) if exp else expression,
+                 jaw_pose=i[j][0:3].unsqueeze_(dim=0),
+                 leye_pose=i[j][3:6].unsqueeze_(dim=0),
+                 reye_pose=i[j][6:9].unsqueeze_(dim=0),
+                 global_orient=i[j][9:12].unsqueeze_(dim=0),
+                 body_pose=i[j][12:75].unsqueeze_(dim=0),
+                 left_hand_pose=i[j][75:120].unsqueeze_(dim=0),
+                 right_hand_pose=i[j][120:165].unsqueeze_(dim=0),
+                 return_verts=True
+             )
            vertices.append(output.vertices.detach().cpu().numpy().squeeze())
-             # pose = torch.cat([output.body_pose, output.left_hand_pose, output.right_hand_pose], dim=1)
-             pose = output.body_pose
-             poses.append(pose.detach().cpu())
-         vertices = np.asarray(vertices)
-         vertices_list.append(vertices)
-         poses = torch.cat(poses, dim=0)
-         poses_list.append(poses)
-     if require_pose:
-         return vertices_list, poses_list
-     else:
-         return vertices_list, None
-
+         vertices_list.append(np.asarray(vertices))
+     return vertices_list, None

global_orient = torch.tensor([3.0747, -0.0158, -0.0152])

-
def infer(g_body, g_face, smplx_model, rendertool, config, args):
-     betas = torch.zeros([1, 300], dtype=torch.float64).to(device)
+     betas = torch.zeros([1, 300], dtype=torch.float64).to(torch_device)
    am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
    am_sr = 16000
-     num_sample = args.num_sample
    cur_wav_file = args.audio_file
    id = args.id
    face = args.only_face
    stand = args.stand
+     num_sample = args.num_sample
+
    if face:
-         body_static = torch.zeros([1, 162], device=device)
-         body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)
+         body_static = torch.zeros([1, 162], device=torch_device)
+         body_static[:, 6:9] = global_orient.reshape(1, 3).repeat(body_static.shape[0], 1)

    result_list = []

-     pred_face = g_face.infer_on_audio(cur_wav_file,
-                                       initial_pose=None,
-                                       norm_stats=None,
-                                       w_pre=False,
-                                       # id=id,
-                                       frame=None,
-                                       am=am,
-                                       am_sr=am_sr
-                                       )
-     pred_face = torch.tensor(pred_face).squeeze().to(device)
-     # pred_face = torch.zeros([gt.shape[0], 105])
+     pred_face = g_face.infer_on_audio(cur_wav_file, initial_pose=None, norm_stats=None, w_pre=False, frame=None, am=am, am_sr=am_sr)
+     pred_face = torch.tensor(pred_face).squeeze().to(torch_device)

    if config.Data.pose.convert_to_6d:
        pred_jaw = pred_face[:, :6].reshape(pred_face.shape[0], -1, 6)
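Note on the hunk above: the slicing in get_vertices implies a fixed layout for each flattened pose frame, axis-angle blocks first, then 100 expression coefficients. The same offsets as a named reference map (illustrative only, not a repo API; offsets are taken directly from the calls above):

# Illustrative reference for the flat pose vector sliced in get_vertices.
import torch

POSE_LAYOUT = {
    'jaw_pose':        slice(0, 3),
    'leye_pose':       slice(3, 6),
    'reye_pose':       slice(6, 9),
    'global_orient':   slice(9, 12),
    'body_pose':       slice(12, 75),    # 21 body joints x 3 axis-angle values
    'left_hand_pose':  slice(75, 120),   # 15 joints x 3
    'right_hand_pose': slice(120, 165),  # 15 joints x 3
    'expression':      slice(165, 265),  # 100 expression coefficients
}

frame = torch.zeros(265)                 # one flattened frame
jaw = frame[POSE_LAYOUT['jaw_pose']]     # equivalent to frame[0:3]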
@@ -190,19 +140,11 @@ def infer(g_body, g_face, smplx_model, rendertool, config, args):
        pred_jaw = pred_face[:, :3]
        pred_face = pred_face[:, 3:]

-     id = torch.tensor([id], device=device)
+     id = torch.tensor([id], device=torch_device)

    for i in range(num_sample):
-         pred_res = g_body.infer_on_audio(cur_wav_file,
-                                          initial_pose=None,
-                                          norm_stats=None,
-                                          txgfile=None,
-                                          id=id,
-                                          var=None,
-                                          fps=30,
-                                          w_pre=False
-                                          )
-         pred = torch.tensor(pred_res).squeeze().to(device)
+         pred_res = g_body.infer_on_audio(cur_wav_file, initial_pose=None, norm_stats=None, txgfile=None, id=id, var=None, fps=30, w_pre=False)
+         pred = torch.tensor(pred_res).squeeze().to(torch_device)

        if pred.shape[0] < pred_face.shape[0]:
            repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
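Note on the hunk above: the body and face tracks may disagree in frame count, so the loop pads the body track by repeating its last frame, or truncates it to match. A self-contained toy run of the same logic:

# Toy run of the length-alignment logic above.
import torch

pred = torch.arange(6.).reshape(3, 2)   # 3 body frames, 2 channels
target_len = 5                          # pretend the face track has 5 frames
if pred.shape[0] < target_len:
    repeat_frame = pred[-1].unsqueeze(dim=0).repeat(target_len - pred.shape[0], 1)
    pred = torch.cat([pred, repeat_frame], dim=0)
else:
    pred = pred[:target_len, :]
print(pred.shape)                       # torch.Size([5, 2]); last row repeated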
@@ -210,95 +152,64 @@
            pred = torch.cat([pred, repeat_frame], dim=0)
        else:
            pred = pred[:pred_face.shape[0], :]

-         body_or_face = False
-         if pred.shape[1] < 275:
-             body_or_face = True
        if config.Data.pose.convert_to_6d:
            pred = pred.reshape(pred.shape[0], -1, 6)
-             pred = matrix_to_axis_angle(rotation_6d_to_matrix(pred))
-             pred = pred.reshape(pred.shape[0], -1)
-
-         if config.Model.model_name == 's2g_LS3DCG':
-             pred = torch.cat([pred[:, :3], pred[:, 103:], pred[:, 3:103]], dim=-1)
-         else:
-             pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)
+             pred = matrix_to_axis_angle(rotation_6d_to_matrix(pred)).reshape(pred.shape[0], -1)

-         # pred[:, 9:12] = global_orient
+         pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)
        pred = part2full(pred, stand)
        if face:
            pred = torch.cat([pred[:, :3], body_static.repeat(pred.shape[0], 1), pred[:, -100:]], dim=-1)
-         # result_list[0] = poses2pred(result_list[0], stand)
-         # if gt_0 is None:
-         #     gt_0 = gt
-         # pred = pred2poses(pred, gt_0)
-         # result_list[0] = poses2poses(result_list[0], gt_0)

        result_list.append(pred)

-
    vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)
-
    result_list = [res.to('cpu') for res in result_list]
-     dict = np.concatenate(result_list[:], axis=0)
-     file_name = 'visualise/video/' + config.Log.name + '/' + \
-                 cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1]
+     dict = np.concatenate(result_list, axis=0)
+     file_name = 'visualise/video/' + config.Log.name + '/' + cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1]
    np.save(file_name, dict)

    rendertool._render_sequences(cur_wav_file, vertices_list, stand=stand, face=face, whole_body=args.whole_body)

-
def main():
    parser = parse_args()
    args = parser.parse_args()
-     args.config_file = './config/body_pixel.json'
-     # device = torch.device(args.gpu)
-     # torch.cuda.set_device(device)

+     # Force correct config file
+     args.config_file = './config/body_pixel.json'
+
    config = load_JsonConfig(args.config_file)

-     face_model_name = args.face_model_name
-     face_model_path = args.face_model_path
-     body_model_name = args.body_model_name
-     body_model_path = args.body_model_path
-     smplx_path = './visualise/'
-
-     os.environ['smplx_npz_path'] = config.smplx_npz_path
-     os.environ['extra_joint_path'] = config.extra_joint_path
-     os.environ['j14_regressor_path'] = config.j14_regressor_path
-
    print('init model...')
-     generator = init_model(body_model_name, body_model_path, args, config)
-     generator2 = None
-     generator_face = init_model(face_model_name, face_model_path, args, config)
-
-     print('init smlpx model...')
-     dtype = torch.float64
-     model_params = dict(model_path=smplx_path,
-                         model_type='smplx',
-                         create_global_orient=True,
-                         create_body_pose=True,
-                         create_betas=True,
-                         num_betas=300,
-                         create_left_hand_pose=True,
-                         create_right_hand_pose=True,
-                         use_pca=False,
-                         flat_hand_mean=False,
-                         create_expression=True,
-                         num_expression_coeffs=100,
-                         num_pca_comps=12,
-                         create_jaw_pose=True,
-                         create_leye_pose=True,
-                         create_reye_pose=True,
-                         create_transl=False,
-                         # gender='ne',
-                         dtype=dtype, )
-     smplx_model = smpl.create(**model_params).to(device)
+     generator = init_model(args.body_model_name, args.body_model_path, args, config)
+     generator_face = init_model(args.face_model_name, args.face_model_path, args, config)
+
+     print('init smplx model...')
+     smplx_model = smpl.create(
+         model_path='./visualise/',
+         model_type='smplx',
+         create_global_orient=True,
+         create_body_pose=True,
+         create_betas=True,
+         num_betas=300,
+         create_left_hand_pose=True,
+         create_right_hand_pose=True,
+         use_pca=False,
+         flat_hand_mean=False,
+         create_expression=True,
+         num_expression_coeffs=100,
+         num_pca_comps=12,
+         create_jaw_pose=True,
+         create_leye_pose=True,
+         create_reye_pose=True,
+         create_transl=False,
+         dtype=torch.float64
+     ).to(torch_device)
+
    print('init rendertool...')
    rendertool = RenderTool('visualise/video/' + config.Log.name)

    infer(generator, generator_face, smplx_model, rendertool, config, args)

-
if __name__ == '__main__':
    main()
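Note on the convert_to_6d branch kept above: the 6D rotation representation (Zhou et al., 2019) stores two 3-vectors per joint and rebuilds a full rotation matrix via Gram-Schmidt plus a cross product; matrix_to_axis_angle then yields the axis-angle form SMPL-X expects. A standalone sketch in the style of the imported rotation_6d_to_matrix helper (the repo's data_utils.rotation_conversion is authoritative):

# Standalone sketch of 6D -> rotation-matrix decoding (Zhou et al., 2019).
import torch
import torch.nn.functional as F

def rot6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    # d6: (..., 6) -> (..., 3, 3)
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)                                          # first basis vector
    b2 = F.normalize(a2 - (b1 * a2).sum(-1, keepdim=True) * b1, dim=-1)   # Gram-Schmidt step
    b3 = torch.cross(b1, b2, dim=-1)                                      # completes the frame
    return torch.stack((b1, b2, b3), dim=-2)

# A valid 6D vector decodes to an orthonormal matrix:
R = rot6d_to_matrix(torch.tensor([1., 0., 0., 0., 1., 0.]))
print(torch.allclose(R, torch.eye(3)))   # True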
 