Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
| from smplx import MANO as _MANO | |
| from smplx import MANOLayer as _MANOLayer | |
| from detrsmpl.core.conventions.keypoints_mapping import ( | |
| convert_kps, | |
| get_keypoint_num, | |
| ) | |
| class MANO(_MANO): | |
| """Extension of the official MANO implementation.""" | |
| full_pose_keys = {'global_orient', 'hand_pose'} | |
| NUM_VERTS = 776 | |
| NUM_FACES = 9976 | |
| KpId2manokps = { | |
| 0: 0, # Wrist | |
| 1: 5, | |
| 2: 6, | |
| 3: 7, # Index | |
| 4: 9, | |
| 5: 10, | |
| 6: 11, # Middle | |
| 7: 17, | |
| 8: 18, | |
| 9: 19, # Pinky | |
| 10: 13, | |
| 11: 14, | |
| 12: 15, # Ring | |
| 13: 1, | |
| 14: 2, | |
| 15: 3 | |
| } # Thumb | |
| kpId2vertices = { | |
| 4: 744, # Thumb | |
| 8: 320, # Index | |
| 12: 443, # Middle | |
| 16: 555, # Ring | |
| 20: 672 # Pink | |
| } | |
| def __init__(self, | |
| *args, | |
| keypoint_src: str = 'mano', | |
| keypoint_dst: str = 'human_data', | |
| keypoint_approximate: bool = False, | |
| **kwargs): | |
| """ | |
| Args: | |
| *args: extra arguments for MANO initialization. | |
| keypoint_src: source convention of keypoints. This convention | |
| is used for keypoints obtained from joint regressors. | |
| Keypoints then undergo conversion into keypoint_dst | |
| convention. | |
| keypoint_dst: destination convention of keypoints. This convention | |
| is used for keypoints in the output. | |
| keypoint_approximate: whether to use approximate matching in | |
| convention conversion for keypoints. | |
| **kwargs: extra keyword arguments for MANO initialization. | |
| Returns: | |
| None | |
| """ | |
| super(MANO, self).__init__(*args, **kwargs) | |
| self.keypoint_src = keypoint_src | |
| self.keypoint_dst = keypoint_dst | |
| self.keypoint_approximate = keypoint_approximate | |
| self.num_verts = self.get_num_verts() | |
| self.num_faces = self.get_num_faces() | |
| self.num_joints = get_keypoint_num(convention=self.keypoint_dst) | |
| def forward(self, | |
| *args, | |
| return_verts: bool = True, | |
| return_full_pose: bool = False, | |
| **kwargs) -> dict: | |
| """Forward function. | |
| Args: | |
| *args: extra arguments for MANO | |
| return_verts: whether to return vertices | |
| return_full_pose: whether to return full pose parameters | |
| **kwargs: extra arguments for MANO | |
| Returns: | |
| output: contains output parameters and attributes | |
| """ | |
| if 'right_hand_pose' in kwargs: | |
| kwargs['hand_pose'] = kwargs['right_hand_pose'] | |
| mano_output = super(MANO, self).forward(*args, **kwargs) | |
| joints = mano_output.joints | |
| joints = self.get_keypoints_from_mesh(mano_output.vertices, joints) | |
| joints, joint_mask = convert_kps(joints, | |
| src=self.keypoint_src, | |
| dst=self.keypoint_dst, | |
| approximate=self.keypoint_approximate) | |
| if isinstance(joint_mask, np.ndarray): | |
| joint_mask = torch.tensor(joint_mask, | |
| dtype=torch.uint8, | |
| device=joints.device) | |
| batch_size = joints.shape[0] | |
| joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1) | |
| output = dict( | |
| global_orient=mano_output.global_orient, | |
| hand_pose=mano_output.hand_pose, | |
| joints=joints, | |
| joint_mask=joint_mask, | |
| keypoints=torch.cat([joints, joint_mask[:, :, None]], dim=-1), | |
| betas=mano_output.betas, | |
| ) | |
| if return_verts: | |
| output['vertices'] = mano_output.vertices | |
| if return_full_pose: | |
| output['full_pose'] = mano_output.full_pose | |
| return output | |
| def get_keypoints_from_mesh(self, mesh_vertices, keypoints_regressed): | |
| """Assembles the full 21 keypoint set from the 16 Mano Keypoints and 5 | |
| mesh vertices for the fingers.""" | |
| batch_size = keypoints_regressed.shape[0] | |
| keypoints = torch.zeros((batch_size, 21, 3)).cuda() | |
| # fill keypoints which are regressed | |
| for manoId, myId in self.KpId2manokps.items(): | |
| keypoints[:, myId, :] = keypoints_regressed[:, manoId, :] | |
| # get other keypoints from mesh | |
| for myId, meshId in self.kpId2vertices.items(): | |
| keypoints[:, myId, :] = mesh_vertices[:, meshId, :] | |
| return keypoints | |
| class MANOLayer(_MANOLayer): | |
| """Extension of the official MANO implementation.""" | |
| full_pose_keys = {'global_orient', 'hand_pose'} | |
| NUM_VERTS = 776 | |
| NUM_FACES = 9976 | |
| KpId2manokps = { | |
| 0: 0, # Wrist | |
| 1: 5, | |
| 2: 6, | |
| 3: 7, # Index | |
| 4: 9, | |
| 5: 10, | |
| 6: 11, # Middle | |
| 7: 17, | |
| 8: 18, | |
| 9: 19, # Pinky | |
| 10: 13, | |
| 11: 14, | |
| 12: 15, # Ring | |
| 13: 1, | |
| 14: 2, | |
| 15: 3 | |
| } # Thumb | |
| kpId2vertices = { | |
| 4: 744, # Thumb | |
| 8: 320, # Index | |
| 12: 443, # Middle | |
| 16: 555, # Ring | |
| 20: 672 # Pink | |
| } | |
| def __init__(self, | |
| *args, | |
| keypoint_src: str = 'mano', | |
| keypoint_dst: str = 'human_data', | |
| keypoint_approximate: bool = False, | |
| **kwargs): | |
| """ | |
| Args: | |
| *args: extra arguments for MANO initialization. | |
| keypoint_src: source convention of keypoints. This convention | |
| is used for keypoints obtained from joint regressors. | |
| Keypoints then undergo conversion into keypoint_dst | |
| convention. | |
| keypoint_dst: destination convention of keypoints. This convention | |
| is used for keypoints in the output. | |
| keypoint_approximate: whether to use approximate matching in | |
| convention conversion for keypoints. | |
| **kwargs: extra keyword arguments for MANO initialization. | |
| Returns: | |
| None | |
| """ | |
| super(MANOLayer, self).__init__(*args, **kwargs) | |
| self.keypoint_src = keypoint_src | |
| self.keypoint_dst = keypoint_dst | |
| self.keypoint_approximate = keypoint_approximate | |
| self.num_verts = self.get_num_verts() | |
| self.num_faces = self.get_num_faces() | |
| self.num_joints = get_keypoint_num(convention=self.keypoint_dst) | |
| def forward(self, | |
| *args, | |
| return_verts: bool = True, | |
| return_full_pose: bool = False, | |
| **kwargs) -> dict: | |
| """Forward function. | |
| Args: | |
| *args: extra arguments for MANO | |
| return_verts: whether to return vertices | |
| return_full_pose: whether to return full pose parameters | |
| **kwargs: extra arguments for MANO | |
| Returns: | |
| output: contains output parameters and attributes | |
| """ | |
| if 'right_hand_pose' in kwargs: | |
| kwargs['hand_pose'] = kwargs['right_hand_pose'] | |
| mano_output = super(MANOLayer, self).forward(*args, **kwargs) | |
| joints = mano_output.joints | |
| joints = self.get_keypoints_from_mesh(mano_output.vertices, joints) | |
| joints, joint_mask = convert_kps(joints, | |
| src=self.keypoint_src, | |
| dst=self.keypoint_dst, | |
| approximate=self.keypoint_approximate) | |
| if isinstance(joint_mask, np.ndarray): | |
| joint_mask = torch.tensor(joint_mask, | |
| dtype=torch.uint8, | |
| device=joints.device) | |
| batch_size = joints.shape[0] | |
| joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1) | |
| output = dict( | |
| global_orient=mano_output.global_orient, | |
| hand_pose=mano_output.hand_pose, | |
| joints=joints, | |
| joint_mask=joint_mask, | |
| keypoints=torch.cat([joints, joint_mask[:, :, None]], dim=-1), | |
| betas=mano_output.betas, | |
| ) | |
| if return_verts: | |
| output['vertices'] = mano_output.vertices | |
| if return_full_pose: | |
| output['full_pose'] = mano_output.full_pose | |
| return output | |
| def get_keypoints_from_mesh(self, mesh_vertices, keypoints_regressed): | |
| """Assembles the full 21 keypoint set from the 16 Mano Keypoints and 5 | |
| mesh vertices for the fingers.""" | |
| batch_size = keypoints_regressed.shape[0] | |
| keypoints = torch.zeros((batch_size, 21, 3)).cuda() | |
| # fill keypoints which are regressed | |
| for manoId, myId in self.KpId2manokps.items(): | |
| keypoints[:, myId, :] = keypoints_regressed[:, manoId, :] | |
| # get other keypoints from mesh | |
| for myId, meshId in self.kpId2vertices.items(): | |
| keypoints[:, myId, :] = mesh_vertices[:, meshId, :] | |
| return keypoints | |