YoonaAI committed
Commit c31e128 · Parent(s): 46da5aa

Create models/pymaf_net.py

Files changed (1):
  lib/pymaf/models/pymaf_net.py (+362 -0)
lib/pymaf/models/pymaf_net.py ADDED
import torch
import torch.nn as nn
import numpy as np

from lib.pymaf.utils.geometry import rot6d_to_rotmat, projection, rotation_matrix_to_angle_axis
from .maf_extractor import MAF_Extractor
from .smpl import SMPL, SMPL_MODEL_DIR, SMPL_MEAN_PARAMS, H36M_TO_J14
from .hmr import ResNet_Backbone
from .res_module import IUV_predict_layer
from lib.common.config import cfg
import logging

logger = logging.getLogger(__name__)

BN_MOMENTUM = 0.1

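# Regressor: iterative error feedback (IEF) head. Each update concatenates the
# image feature with the current pose (24 joints in 6D rotation form, 144-d),
# shape (10-d) and weak-perspective camera (3-d) estimates and predicts
# residual corrections for all three.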
class Regressor(nn.Module):
    def __init__(self, feat_dim, smpl_mean_params):
        super().__init__()

        npose = 24 * 6

        self.fc1 = nn.Linear(feat_dim + npose + 13, 1024)
        self.drop1 = nn.Dropout()
        self.fc2 = nn.Linear(1024, 1024)
        self.drop2 = nn.Dropout()
        self.decpose = nn.Linear(1024, npose)
        self.decshape = nn.Linear(1024, 10)
        self.deccam = nn.Linear(1024, 3)
        nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
        nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
        nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)

        self.smpl = SMPL(SMPL_MODEL_DIR, batch_size=64, create_transl=False)

        mean_params = np.load(smpl_mean_params)
        init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
        init_shape = torch.from_numpy(
            mean_params['shape'][:].astype('float32')).unsqueeze(0)
        init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)
        self.register_buffer('init_pose', init_pose)
        self.register_buffer('init_shape', init_shape)
        self.register_buffer('init_cam', init_cam)

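    # Runs n_iter IEF updates starting from init_pose/init_shape/init_cam
    # (the registered mean parameters by default), then decodes the result
    # with SMPL to obtain vertices, 3D joints and projected 2D keypoints.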
    def forward(self,
                x,
                init_pose=None,
                init_shape=None,
                init_cam=None,
                n_iter=1,
                J_regressor=None):
        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        for i in range(n_iter):
            xc = torch.cat([x, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)

        pred_output = self.smpl(betas=pred_shape,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)

        pred_vertices = pred_output.vertices
        pred_joints = pred_output.joints
        pred_smpl_joints = pred_output.smpl_joints
        pred_keypoints_2d = projection(pred_joints, pred_cam)
        pose = rotation_matrix_to_angle_axis(
            pred_rotmat.reshape(-1, 3, 3)).reshape(-1, 72)

        if J_regressor is not None:
            pred_joints = torch.matmul(J_regressor, pred_vertices)
            pred_pelvis = pred_joints[:, [0], :].clone()
            pred_joints = pred_joints[:, H36M_TO_J14, :]
            pred_joints = pred_joints - pred_pelvis

        output = {
            'theta': torch.cat([pred_cam, pred_shape, pose], dim=1),
            'verts': pred_vertices,
            'kp_2d': pred_keypoints_2d,
            'kp_3d': pred_joints,
            'smpl_kp_3d': pred_smpl_joints,
            'rotmat': pred_rotmat,
            'pred_cam': pred_cam,
            'pred_shape': pred_shape,
            'pred_pose': pred_pose,
        }
        return output

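    # Like forward(), but skips the IEF updates and decodes the initial
    # parameters directly; PyMAF uses this once per batch to produce the
    # starting mesh for the first mesh-alignment iteration.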
    def forward_init(self,
                     x,
                     init_pose=None,
                     init_shape=None,
                     init_cam=None,
                     n_iter=1,
                     J_regressor=None):
        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam

        pred_rotmat = rot6d_to_rotmat(pred_pose.contiguous()).view(
            batch_size, 24, 3, 3)

        pred_output = self.smpl(betas=pred_shape,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)

        pred_vertices = pred_output.vertices
        pred_joints = pred_output.joints
        pred_smpl_joints = pred_output.smpl_joints
        pred_keypoints_2d = projection(pred_joints, pred_cam)
        pose = rotation_matrix_to_angle_axis(
            pred_rotmat.reshape(-1, 3, 3)).reshape(-1, 72)

        if J_regressor is not None:
            pred_joints = torch.matmul(J_regressor, pred_vertices)
            pred_pelvis = pred_joints[:, [0], :].clone()
            pred_joints = pred_joints[:, H36M_TO_J14, :]
            pred_joints = pred_joints - pred_pelvis

        output = {
            'theta': torch.cat([pred_cam, pred_shape, pose], dim=1),
            'verts': pred_vertices,
            'kp_2d': pred_keypoints_2d,
            'kp_3d': pred_joints,
            'smpl_kp_3d': pred_smpl_joints,
            'rotmat': pred_rotmat,
            'pred_cam': pred_cam,
            'pred_shape': pred_shape,
            'pred_pose': pred_pose,
        }
        return output


class PyMAF(nn.Module):
    """ PyMAF based Deep Regressor for Human Mesh Recovery
    PyMAF: 3D Human Pose and Shape Regression with Pyramidal Mesh Alignment Feedback Loop, in ICCV, 2021
    """

    def __init__(self, smpl_mean_params=SMPL_MEAN_PARAMS, pretrained=True):
        super().__init__()
        self.feature_extractor = ResNet_Backbone(
            model=cfg.MODEL.PyMAF.BACKBONE, pretrained=pretrained)

        # deconv layers
        self.inplanes = self.feature_extractor.inplanes
        self.deconv_with_bias = cfg.RES_MODEL.DECONV_WITH_BIAS
        self.deconv_layers = self._make_deconv_layer(
            cfg.RES_MODEL.NUM_DECONV_LAYERS,
            cfg.RES_MODEL.NUM_DECONV_FILTERS,
            cfg.RES_MODEL.NUM_DECONV_KERNELS,
        )

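        # one mesh-aligned feature (MAF) extractor per feedback iteration;
        # its Dmap matrix maps the full SMPL vertex set to 431 downsampled
        # points, each contributing MLP_DIM[-1] channels to the regressor input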
        self.maf_extractor = nn.ModuleList()
        for _ in range(cfg.MODEL.PyMAF.N_ITER):
            self.maf_extractor.append(MAF_Extractor())
        ma_feat_len = self.maf_extractor[-1].Dmap.shape[0] * cfg.MODEL.PyMAF.MLP_DIM[-1]

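        # the first iteration has no mesh estimate yet, so features are
        # sampled on a fixed 21 x 21 grid over the feature map instead of at
        # projected mesh vertices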
        grid_size = 21
        xv, yv = torch.meshgrid([
            torch.linspace(-1, 1, grid_size),
            torch.linspace(-1, 1, grid_size)
        ])
        points_grid = torch.stack([xv.reshape(-1),
                                   yv.reshape(-1)]).unsqueeze(0)
        self.register_buffer('points_grid', points_grid)
        grid_feat_len = grid_size * grid_size * cfg.MODEL.PyMAF.MLP_DIM[-1]

        self.regressor = nn.ModuleList()
        for i in range(cfg.MODEL.PyMAF.N_ITER):
            if i == 0:
                ref_infeat_dim = grid_feat_len
            else:
                ref_infeat_dim = ma_feat_len
            self.regressor.append(
                Regressor(feat_dim=ref_infeat_dim,
                          smpl_mean_params=smpl_mean_params))

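        # auxiliary dense-correspondence (IUV) head; only queried for
        # intermediate supervision during training when AUX_SUPV_ON is set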
        dp_feat_dim = 256
        self.with_uv = cfg.LOSS.POINT_REGRESSION_WEIGHTS > 0
        if cfg.MODEL.PyMAF.AUX_SUPV_ON:
            self.dp_head = IUV_predict_layer(feat_dim=dp_feat_dim)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes,
                          planes * block.expansion,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """
        Deconv_layer used in Simple Baselines:
        Xiao et al. Simple Baselines for Human Pose Estimation and Tracking
        https://github.com/microsoft/human-pose-estimation.pytorch
        """
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'

        def _get_deconv_cfg(deconv_kernel, index):
            if deconv_kernel == 4:
                padding = 1
                output_padding = 0
            elif deconv_kernel == 3:
                padding = 1
                output_padding = 1
            elif deconv_kernel == 2:
                padding = 0
                output_padding = 0

            return deconv_kernel, padding, output_padding

        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = _get_deconv_cfg(
                num_kernels[i], i)

            planes = num_filters[i]
            layers.append(
                nn.ConvTranspose2d(in_channels=self.inplanes,
                                   out_channels=planes,
                                   kernel_size=kernel,
                                   stride=2,
                                   padding=padding,
                                   output_padding=output_padding,
                                   bias=self.deconv_with_bias))
            layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
            layers.append(nn.ReLU(inplace=True))
            self.inplanes = planes

        return nn.Sequential(*layers)

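    # Extract image features once, then run N_ITER mesh-alignment feedback
    # steps: each step upsamples the spatial feature map, samples features
    # aligned with the current mesh estimate, and regresses a refined
    # SMPL/camera estimate from them.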
    def forward(self, x, J_regressor=None):

        batch_size = x.shape[0]

        # spatial features and global features
        s_feat, g_feat = self.feature_extractor(x)

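        # with the default config (NUM_DECONV_LAYERS = 3) deconv_layers holds
        # 9 modules (3 x [ConvTranspose2d, BatchNorm2d, ReLU]); split them so
        # that each feedback iteration consumes one upsampling block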
        assert cfg.MODEL.PyMAF.N_ITER >= 0 and cfg.MODEL.PyMAF.N_ITER <= 3
        if cfg.MODEL.PyMAF.N_ITER == 1:
            deconv_blocks = [self.deconv_layers]
        elif cfg.MODEL.PyMAF.N_ITER == 2:
            deconv_blocks = [self.deconv_layers[0:6], self.deconv_layers[6:9]]
        elif cfg.MODEL.PyMAF.N_ITER == 3:
            deconv_blocks = [
                self.deconv_layers[0:3], self.deconv_layers[3:6],
                self.deconv_layers[6:9]
            ]

        out_list = {}

        # initial parameters
        # TODO: remove the initial mesh generation during forward to reduce runtime
        # by generating the initial mesh beforehand: smpl_output = self.init_smpl
        smpl_output = self.regressor[0].forward_init(g_feat,
                                                     J_regressor=J_regressor)

        out_list['smpl_out'] = [smpl_output]
        out_list['dp_out'] = []

        # for visualization
        vis_feat_list = [s_feat.detach()]

        # parameter predictions
        for rf_i in range(cfg.MODEL.PyMAF.N_ITER):
            pred_cam = smpl_output['pred_cam']
            pred_shape = smpl_output['pred_shape']
            pred_pose = smpl_output['pred_pose']

            pred_cam = pred_cam.detach()
            pred_shape = pred_shape.detach()
            pred_pose = pred_pose.detach()

            s_feat_i = deconv_blocks[rf_i](s_feat)
            s_feat = s_feat_i
            vis_feat_list.append(s_feat_i.detach())

            self.maf_extractor[rf_i].im_feat = s_feat_i
            self.maf_extractor[rf_i].cam = pred_cam

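            # iteration 0: sample features on the fixed 21 x 21 grid;
            # later iterations: sample features at the downsampled, projected
            # vertices of the mesh predicted in the previous iteration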
            if rf_i == 0:
                sample_points = torch.transpose(
                    self.points_grid.expand(batch_size, -1, -1), 1, 2)
                ref_feature = self.maf_extractor[rf_i].sampling(sample_points)
            else:
                pred_smpl_verts = smpl_output['verts'].detach()
                # TODO: use a more sparse SMPL implementation (with 431 vertices) for acceleration
                pred_smpl_verts_ds = torch.matmul(
                    self.maf_extractor[rf_i].Dmap.unsqueeze(0),
                    pred_smpl_verts)  # [B, 431, 3]
                ref_feature = self.maf_extractor[rf_i](
                    pred_smpl_verts_ds)  # [B, 431 * n_feat]

            smpl_output = self.regressor[rf_i](ref_feature,
                                               pred_pose,
                                               pred_shape,
                                               pred_cam,
                                               n_iter=1,
                                               J_regressor=J_regressor)
            out_list['smpl_out'].append(smpl_output)

        if self.training and cfg.MODEL.PyMAF.AUX_SUPV_ON:
            iuv_out_dict = self.dp_head(s_feat)
            out_list['dp_out'].append(iuv_out_dict)

        return out_list


def pymaf_net(smpl_mean_params, pretrained=True):
    """ Constructs a PyMAF model with ResNet50 backbone.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = PyMAF(smpl_mean_params, pretrained)
    return model
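
# Usage sketch (not part of this commit): assuming cfg, the SMPL model files
# and SMPL_MEAN_PARAMS are configured as elsewhere in this repo, and that the
# input is a batch of 224 x 224 image crops, the model can be exercised as:
#
#   model = pymaf_net(SMPL_MEAN_PARAMS, pretrained=True).eval()
#   with torch.no_grad():
#       out = model(torch.randn(1, 3, 224, 224))
#   final = out['smpl_out'][-1]   # estimate after the last feedback iteration
#   verts, rotmat = final['verts'], final['rotmat']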