Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| import torch | |
| from mmcv.runner import build_optimizer | |
| from detrsmpl.core.conventions.keypoints_mapping import ( | |
| get_keypoint_idx, | |
| get_keypoint_idxs_by_part, | |
| ) | |
| from .smplify import OptimizableParameters, SMPLify | |
| class SMPLifyX(SMPLify): | |
| """Re-implementation of SMPLify-X with extended features. | |
| - video input | |
| - 3D keypoints | |
| """ | |
| def __call__(self, | |
| keypoints2d: torch.Tensor = None, | |
| keypoints2d_conf: torch.Tensor = None, | |
| keypoints3d: torch.Tensor = None, | |
| keypoints3d_conf: torch.Tensor = None, | |
| init_global_orient: torch.Tensor = None, | |
| init_transl: torch.Tensor = None, | |
| init_body_pose: torch.Tensor = None, | |
| init_betas: torch.Tensor = None, | |
| init_left_hand_pose: torch.Tensor = None, | |
| init_right_hand_pose: torch.Tensor = None, | |
| init_expression: torch.Tensor = None, | |
| init_jaw_pose: torch.Tensor = None, | |
| init_leye_pose: torch.Tensor = None, | |
| init_reye_pose: torch.Tensor = None, | |
| return_verts: bool = False, | |
| return_joints: bool = False, | |
| return_full_pose: bool = False, | |
| return_losses: bool = False) -> dict: | |
| """Run registration. | |
| Notes: | |
| B: batch size | |
| K: number of keypoints | |
| D: body shape dimension | |
| D_H: hand pose dimension | |
| D_E: expression dimension | |
| Provide only keypoints2d or keypoints3d, not both. | |
| Args: | |
| keypoints2d: 2D keypoints of shape (B, K, 2) | |
| keypoints2d_conf: 2D keypoint confidence of shape (B, K) | |
| keypoints3d: 3D keypoints of shape (B, K, 3). | |
| keypoints3d_conf: 3D keypoint confidence of shape (B, K) | |
| init_global_orient: initial global_orient of shape (B, 3) | |
| init_transl: initial transl of shape (B, 3) | |
| init_body_pose: initial body_pose of shape (B, 69) | |
| init_betas: initial betas of shape (B, D) | |
| init_left_hand_pose: initial left hand pose of shape (B, D_H) | |
| init_right_hand_pose: initial right hand pose of shape (B, D_H) | |
| init_expression: initial left hand pose of shape (B, D_E) | |
| init_jaw_pose: initial jaw pose of shape (B, 3) | |
| init_leye_pose: initial left eye pose of shape (B, 3) | |
| init_reye_pose: initial right eye pose of shape (B, 3) | |
| return_verts: whether to return vertices | |
| return_joints: whether to return joints | |
| return_full_pose: whether to return full pose | |
| return_losses: whether to return loss dict | |
| Returns: | |
| ret: a dictionary that includes body model parameters, | |
| and optional attributes such as vertices and joints | |
| """ | |
| assert keypoints2d is not None or keypoints3d is not None, \ | |
| 'Neither of 2D nor 3D keypoints are provided.' | |
| assert not (keypoints2d is not None and keypoints3d is not None), \ | |
| 'Do not provide both 2D and 3D keypoints.' | |
| batch_size = keypoints2d.shape[0] if keypoints2d is not None \ | |
| else keypoints3d.shape[0] | |
| global_orient = self._match_init_batch_size( | |
| init_global_orient, self.body_model.global_orient, batch_size) | |
| transl = self._match_init_batch_size(init_transl, | |
| self.body_model.transl, | |
| batch_size) | |
| body_pose = self._match_init_batch_size(init_body_pose, | |
| self.body_model.body_pose, | |
| batch_size) | |
| left_hand_pose = self._match_init_batch_size( | |
| init_left_hand_pose, self.body_model.left_hand_pose, batch_size) | |
| right_hand_pose = self._match_init_batch_size( | |
| init_right_hand_pose, self.body_model.right_hand_pose, batch_size) | |
| expression = self._match_init_batch_size(init_expression, | |
| self.body_model.expression, | |
| batch_size) | |
| jaw_pose = self._match_init_batch_size(init_jaw_pose, | |
| self.body_model.jaw_pose, | |
| batch_size) | |
| leye_pose = self._match_init_batch_size(init_leye_pose, | |
| self.body_model.leye_pose, | |
| batch_size) | |
| reye_pose = self._match_init_batch_size(init_reye_pose, | |
| self.body_model.reye_pose, | |
| batch_size) | |
| if init_betas is None and self.use_one_betas_per_video: | |
| betas = torch.zeros(1, self.body_model.betas.shape[-1]).to( | |
| self.device) | |
| else: | |
| betas = self._match_init_batch_size(init_betas, | |
| self.body_model.betas, | |
| batch_size) | |
| for i in range(self.num_epochs): | |
| for stage_idx, stage_config in enumerate(self.stage_config): | |
| # print(stage_name) | |
| self._optimize_stage( | |
| global_orient=global_orient, | |
| transl=transl, | |
| body_pose=body_pose, | |
| betas=betas, | |
| left_hand_pose=left_hand_pose, | |
| right_hand_pose=right_hand_pose, | |
| expression=expression, | |
| jaw_pose=jaw_pose, | |
| leye_pose=leye_pose, | |
| reye_pose=reye_pose, | |
| keypoints2d=keypoints2d, | |
| keypoints2d_conf=keypoints2d_conf, | |
| keypoints3d=keypoints3d, | |
| keypoints3d_conf=keypoints3d_conf, | |
| **stage_config, | |
| ) | |
| return { | |
| 'global_orient': global_orient, | |
| 'transl': transl, | |
| 'body_pose': body_pose, | |
| 'betas': betas, | |
| 'left_hand_pose': left_hand_pose, | |
| 'right_hand_pose': right_hand_pose, | |
| 'expression': expression, | |
| 'jaw_pose': jaw_pose, | |
| 'leye_pose': leye_pose, | |
| 'reye_pose': reye_pose | |
| } | |
| def _optimize_stage(self, | |
| betas: torch.Tensor, | |
| body_pose: torch.Tensor, | |
| global_orient: torch.Tensor, | |
| transl: torch.Tensor, | |
| left_hand_pose: torch.Tensor, | |
| right_hand_pose: torch.Tensor, | |
| expression: torch.Tensor, | |
| jaw_pose: torch.Tensor, | |
| leye_pose: torch.Tensor, | |
| reye_pose: torch.Tensor, | |
| fit_global_orient: bool = True, | |
| fit_transl: bool = True, | |
| fit_body_pose: bool = True, | |
| fit_betas: bool = True, | |
| fit_left_hand_pose: bool = True, | |
| fit_right_hand_pose: bool = True, | |
| fit_expression: bool = True, | |
| fit_jaw_pose: bool = True, | |
| fit_leye_pose: bool = True, | |
| fit_reye_pose: bool = True, | |
| keypoints2d: torch.Tensor = None, | |
| keypoints2d_conf: torch.Tensor = None, | |
| keypoints2d_weight: float = None, | |
| keypoints3d: torch.Tensor = None, | |
| keypoints3d_conf: torch.Tensor = None, | |
| keypoints3d_weight: float = None, | |
| shape_prior_weight: float = None, | |
| joint_prior_weight: float = None, | |
| smooth_loss_weight: float = None, | |
| pose_prior_weight: float = None, | |
| pose_reg_weight: float = None, | |
| limb_length_weight: float = None, | |
| joint_weights: dict = {}, | |
| ftol: float = 1e-4, | |
| num_iter: int = 1) -> None: | |
| """Optimize a stage of body model parameters according to | |
| configuration. | |
| Notes: | |
| B: batch size | |
| K: number of keypoints | |
| D: shape dimension | |
| Args: | |
| betas: shape (B, D) | |
| body_pose: shape (B, 69) | |
| global_orient: shape (B, 3) | |
| transl: shape (B, 3) | |
| fit_global_orient: whether to optimize global_orient | |
| fit_transl: whether to optimize transl | |
| fit_body_pose: whether to optimize body_pose | |
| fit_betas: whether to optimize betas | |
| fit_left_hand_pose: whether to optimize left hand pose | |
| fit_right_hand_pose: whether to optimize right hand pose | |
| fit_expression: whether to optimize expression | |
| fit_jaw_pose: whether to optimize jaw pose | |
| fit_leye_pose: whether to optimize left eye pose | |
| fit_reye_pose: whether to optimize right eye pose | |
| keypoints2d: 2D keypoints of shape (B, K, 2) | |
| keypoints2d_conf: 2D keypoint confidence of shape (B, K) | |
| keypoints2d_weight: weight of 2D keypoint loss | |
| keypoints3d: 3D keypoints of shape (B, K, 3). | |
| keypoints3d_conf: 3D keypoint confidence of shape (B, K) | |
| keypoints3d_weight: weight of 3D keypoint loss | |
| shape_prior_weight: weight of shape prior loss | |
| joint_prior_weight: weight of joint prior loss | |
| smooth_loss_weight: weight of smooth loss | |
| pose_prior_weight: weight of pose prior loss | |
| pose_reg_weight: weight of pose regularization loss | |
| limb_length_weight: weight of limb length loss | |
| joint_weights: per joint weight of shape (K, ) | |
| num_iter: number of iterations | |
| ftol: early stop tolerance for relative change in loss | |
| Returns: | |
| None | |
| """ | |
| parameters = OptimizableParameters() | |
| parameters.set_param(fit_global_orient, global_orient) | |
| parameters.set_param(fit_transl, transl) | |
| parameters.set_param(fit_body_pose, body_pose) | |
| parameters.set_param(fit_betas, betas) | |
| parameters.set_param(fit_left_hand_pose, left_hand_pose) | |
| parameters.set_param(fit_right_hand_pose, right_hand_pose) | |
| parameters.set_param(fit_expression, expression) | |
| parameters.set_param(fit_jaw_pose, jaw_pose) | |
| parameters.set_param(fit_leye_pose, leye_pose) | |
| parameters.set_param(fit_reye_pose, reye_pose) | |
| optimizer = build_optimizer(parameters, self.optimizer) | |
| pre_loss = None | |
| for iter_idx in range(num_iter): | |
| def closure(): | |
| # body_pose_fixed = use_reference_spine(body_pose, | |
| # init_body_pose) | |
| optimizer.zero_grad() | |
| betas_video = self._expand_betas(body_pose.shape[0], betas) | |
| loss_dict = self.evaluate( | |
| global_orient=global_orient, | |
| body_pose=body_pose, | |
| betas=betas_video, | |
| transl=transl, | |
| left_hand_pose=left_hand_pose, | |
| right_hand_pose=right_hand_pose, | |
| expression=expression, | |
| jaw_pose=jaw_pose, | |
| leye_pose=leye_pose, | |
| reye_pose=reye_pose, | |
| keypoints2d=keypoints2d, | |
| keypoints2d_conf=keypoints2d_conf, | |
| keypoints2d_weight=keypoints2d_weight, | |
| keypoints3d=keypoints3d, | |
| keypoints3d_conf=keypoints3d_conf, | |
| keypoints3d_weight=keypoints3d_weight, | |
| joint_prior_weight=joint_prior_weight, | |
| shape_prior_weight=shape_prior_weight, | |
| smooth_loss_weight=smooth_loss_weight, | |
| pose_prior_weight=pose_prior_weight, | |
| pose_reg_weight=pose_reg_weight, | |
| limb_length_weight=limb_length_weight, | |
| joint_weights=joint_weights) | |
| loss = loss_dict['total_loss'] | |
| loss.backward() | |
| return loss | |
| loss = optimizer.step(closure) | |
| if iter_idx > 0 and pre_loss is not None and ftol > 0: | |
| loss_rel_change = self._compute_relative_change( | |
| pre_loss, loss.item()) | |
| if loss_rel_change < ftol: | |
| print(f'[ftol={ftol}] Early stop at {iter_idx} iter!') | |
| break | |
| pre_loss = loss.item() | |
| def evaluate( | |
| self, | |
| betas: torch.Tensor = None, | |
| body_pose: torch.Tensor = None, | |
| global_orient: torch.Tensor = None, | |
| transl: torch.Tensor = None, | |
| left_hand_pose: torch.Tensor = None, | |
| right_hand_pose: torch.Tensor = None, | |
| expression: torch.Tensor = None, | |
| jaw_pose: torch.Tensor = None, | |
| leye_pose: torch.Tensor = None, | |
| reye_pose: torch.Tensor = None, | |
| keypoints2d: torch.Tensor = None, | |
| keypoints2d_conf: torch.Tensor = None, | |
| keypoints2d_weight: float = None, | |
| keypoints3d: torch.Tensor = None, | |
| keypoints3d_conf: torch.Tensor = None, | |
| keypoints3d_weight: float = None, | |
| shape_prior_weight: float = None, | |
| joint_prior_weight: float = None, | |
| smooth_loss_weight: float = None, | |
| pose_prior_weight: float = None, | |
| pose_reg_weight: float = None, | |
| limb_length_weight: float = None, | |
| joint_weights: dict = {}, | |
| return_verts: bool = False, | |
| return_full_pose: bool = False, | |
| return_joints: bool = False, | |
| reduction_override: str = None, | |
| ): | |
| """Evaluate fitted parameters through loss computation. This function | |
| serves two purposes: 1) internally, for loss backpropagation 2) | |
| externally, for fitting quality evaluation. | |
| Notes: | |
| B: batch size | |
| K: number of keypoints | |
| D: body shape dimension | |
| D_H: hand pose dimension | |
| D_E: expression dimension | |
| Args: | |
| betas: shape (B, D) | |
| body_pose: shape (B, 69) | |
| global_orient: shape (B, 3) | |
| transl: shape (B, 3) | |
| left_hand_pose: shape (B, D_H) | |
| right_hand_pose: shape (B, D_H) | |
| expression: shape (B, D_E) | |
| jaw_pose: shape (B, 3) | |
| leye_pose: shape (B, 3) | |
| reye_pose: shape (B, 3) | |
| keypoints2d: 2D keypoints of shape (B, K, 2) | |
| keypoints2d_conf: 2D keypoint confidence of shape (B, K) | |
| keypoints2d_weight: weight of 2D keypoint loss | |
| keypoints3d: 3D keypoints of shape (B, K, 3). | |
| keypoints3d_conf: 3D keypoint confidence of shape (B, K) | |
| keypoints3d_weight: weight of 3D keypoint loss | |
| shape_prior_weight: weight of shape prior loss | |
| joint_prior_weight: weight of joint prior loss | |
| smooth_loss_weight: weight of smooth loss | |
| pose_prior_weight: weight of pose prior loss | |
| pose_reg_weight: weight of pose regularization loss | |
| limb_length_weight: weight of limb length loss | |
| joint_weights: per joint weight of shape (K, ) | |
| return_verts: whether to return vertices | |
| return_joints: whether to return joints | |
| return_full_pose: whether to return full pose | |
| reduction_override: reduction method, e.g., 'none', 'sum', 'mean' | |
| Returns: | |
| ret: a dictionary that includes body model parameters, | |
| and optional attributes such as vertices and joints | |
| """ | |
| ret = {} | |
| body_model_output = self.body_model(global_orient=global_orient, | |
| body_pose=body_pose, | |
| betas=betas, | |
| transl=transl, | |
| left_hand_pose=left_hand_pose, | |
| right_hand_pose=right_hand_pose, | |
| expression=expression, | |
| jaw_pose=jaw_pose, | |
| leye_pose=leye_pose, | |
| reye_pose=reye_pose, | |
| return_verts=return_verts, | |
| return_full_pose=return_full_pose) | |
| model_joints = body_model_output['joints'] | |
| model_joint_mask = body_model_output['joint_mask'] | |
| loss_dict = self._compute_loss(model_joints, | |
| model_joint_mask, | |
| keypoints2d=keypoints2d, | |
| keypoints2d_conf=keypoints2d_conf, | |
| keypoints2d_weight=keypoints2d_weight, | |
| keypoints3d=keypoints3d, | |
| keypoints3d_conf=keypoints3d_conf, | |
| keypoints3d_weight=keypoints3d_weight, | |
| joint_prior_weight=joint_prior_weight, | |
| shape_prior_weight=shape_prior_weight, | |
| smooth_loss_weight=smooth_loss_weight, | |
| pose_prior_weight=pose_prior_weight, | |
| pose_reg_weight=pose_reg_weight, | |
| limb_length_weight=limb_length_weight, | |
| joint_weights=joint_weights, | |
| reduction_override=reduction_override, | |
| body_pose=body_pose, | |
| betas=betas) | |
| ret.update(loss_dict) | |
| if return_verts: | |
| ret['vertices'] = body_model_output['vertices'] | |
| if return_full_pose: | |
| ret['full_pose'] = body_model_output['full_pose'] | |
| if return_joints: | |
| ret['joints'] = model_joints | |
| return ret | |
| def _set_keypoint_idxs(self): | |
| """Set keypoint indices to 1) body parts to be assigned different | |
| weights 2) be ignored for keypoint loss computation. | |
| Returns: | |
| None | |
| """ | |
| convention = self.body_model.keypoint_dst | |
| # obtain ignore keypoint indices | |
| if self.ignore_keypoints is not None: | |
| self.ignore_keypoint_idxs = [] | |
| for keypoint_name in self.ignore_keypoints: | |
| keypoint_idx = get_keypoint_idx(keypoint_name, | |
| convention=convention) | |
| if keypoint_idx != -1: | |
| self.ignore_keypoint_idxs.append(keypoint_idx) | |
| # obtain body part keypoint indices | |
| shoulder_keypoint_idxs = get_keypoint_idxs_by_part( | |
| 'shoulder', convention=convention) | |
| hip_keypoint_idxs = get_keypoint_idxs_by_part('hip', | |
| convention=convention) | |
| self.shoulder_hip_keypoint_idxs = [ | |
| *shoulder_keypoint_idxs, *hip_keypoint_idxs | |
| ] | |
| # head keypoints include all facial landmarks | |
| self.face_keypoint_idxs = get_keypoint_idxs_by_part( | |
| 'head', convention=convention) | |
| left_hand_keypoint_idxs = get_keypoint_idxs_by_part( | |
| 'left_hand', convention=convention) | |
| right_hand_keypoint_idxs = get_keypoint_idxs_by_part( | |
| 'right_hand', convention=convention) | |
| self.hand_keypoint_idxs = [ | |
| *left_hand_keypoint_idxs, *right_hand_keypoint_idxs | |
| ] | |
| self.body_keypoint_idxs = get_keypoint_idxs_by_part( | |
| 'body', convention=convention) | |
| def _get_weight(self, | |
| use_shoulder_hip_only: bool = False, | |
| body_weight: float = 1.0, | |
| hand_weight: float = 1.0, | |
| face_weight: float = 1.0): | |
| """Get per keypoint weight. | |
| Notes: | |
| K: number of keypoints | |
| Args: | |
| use_shoulder_hip_only: whether to use only shoulder and hip | |
| keypoints for loss computation. This is useful in the | |
| warming-up stage to find a reasonably good initialization. | |
| body_weight: weight of body keypoints. Body part segmentation | |
| definition is included in the HumanData convention. | |
| hand_weight: weight of hand keypoints. | |
| face_weight: weight of face keypoints. | |
| Returns: | |
| weight: per keypoint weight tensor of shape (K) | |
| """ | |
| num_keypoint = self.body_model.num_joints | |
| if use_shoulder_hip_only: | |
| weight = torch.zeros([num_keypoint]).to(self.device) | |
| weight[self.shoulder_hip_keypoint_idxs] = 1.0 | |
| else: | |
| weight = torch.ones([num_keypoint]).to(self.device) | |
| weight[self.body_keypoint_idxs] = \ | |
| weight[self.body_keypoint_idxs] * body_weight | |
| weight[self.hand_keypoint_idxs] = \ | |
| weight[self.hand_keypoint_idxs] * hand_weight | |
| weight[self.face_keypoint_idxs] = \ | |
| weight[self.face_keypoint_idxs] * face_weight | |
| if hasattr(self, 'ignore_keypoint_idxs'): | |
| weight[self.ignore_keypoint_idxs] = 0.0 | |
| return weight | |