Spaces:
Running
on
L4
Running
on
L4
File size: 6,300 Bytes
5ac1897 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
from lib.kits.basic import *
import pickle
from smplx.vertex_joint_selector import VertexJointSelector
from smplx.vertex_ids import vertex_ids
from smplx.lbs import vertices2joints
from lib.body_models.skel.skel_model import SKEL, SKELOutput
def _load_joint_regressor(path: str) -> torch.Tensor:
    '''
    Load a pickled joint regressor from `path` and return it as a dense float32 tensor.

    Regressors stored as scipy sparse matrices are densified first, so both dense and sparse pickles are accepted.
    NOTE: `pickle.load` can execute arbitrary code; only pass regressor files from trusted sources.
    '''
    with open(path, 'rb') as f:
        regressor = pickle.load(f, encoding='latin1')
    if 'scipy.sparse' in str(type(regressor)):
        # `toarray()` yields a plain ndarray (unlike `todense()`, which yields np.matrix).
        regressor = regressor.toarray()
    return torch.tensor(regressor, dtype=torch.float32)


class SKELWrapper(SKEL):
    '''
    SKEL model whose `forward` re-arranges and extends the output joints into an SMPL/OpenPose-style layout.

    The output joints are assembled from:
      1. 24 SMPL-ordered joints: SKEL joints re-indexed through `J_skel_to_smpl`, or regressed from the
         skin vertices with `J_regressor_custom` when a custom regressor is given;
      2. joints picked from skin vertices (SMPL-H vertex ids), after which a subset is selected in OpenPose order;
      3. optional extra joints regressed from the skin vertices with `J_regressor_extra`.
    '''

    def __init__(
        self,
        *args,
        joint_regressor_custom: Optional[str] = None,
        joint_regressor_extra: Optional[str] = None,
        update_root: bool = False,
        **kwargs
    ):
        '''
        Args:
            *args, **kwargs: forwarded unchanged to `SKEL.__init__`.
            joint_regressor_custom: path to a pickled regressor. When given, the 24 SMPL-ordered joints are
                regressed from the skin vertices instead of being taken from SKEL's joints.
            joint_regressor_extra: path to a pickled regressor whose joints are appended to the output.
                When None, no extra joints are appended.
            update_root: if True, pull SKEL's root joint towards the line through joint 11 and the midpoint
                of joints 1 and 6 before mapping (see `forward`).
        '''
        super().__init__(*args, **kwargs)

        # 1. Map selected joints of interest from SKEL to SMPL order. (Not all 24 joints are used downstream.)
        # Notes: Only these SMPL joints will be used: [0, 1, 2, 4, 5, 7, 8, 12, 16, 17, 18, 19, 20, 21, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 45, 46, 47, 48, 49, 50, 51, 52, 53]
        skel_to_smpl = [
            0,
            6,
            1,
            11,  # not aligned well; not used
            7,
            2,
            11,  # not aligned well; not used
            8,   # or 9
            3,   # or 4
            12,  # not aligned well; not used
            10,  # not used
            5,   # not used
            12,
            19,  # not aligned well; not used
            14,  # not aligned well; not used
            13,  # not used
            20,  # or 19
            15,  # or 14
            21,  # or 22
            16,  # or 17,
            23,
            18,
            23,  # not aligned well; not used
            18,  # not aligned well; not used
        ]
        smpl_to_openpose = [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]
        self.register_buffer('J_skel_to_smpl', torch.tensor(skel_to_smpl, dtype=torch.long))
        self.register_buffer('J_smpl_to_openpose', torch.tensor(smpl_to_openpose, dtype=torch.long))

        # 2. Joints selected from skin vertices.
        #    SKEL shares SMPL / SMPL-H topology, so the SMPL-H vertex ids are used.
        self.vertex_joint_selector = VertexJointSelector(vertex_ids['smplh'])

        # 3. Extra joints regressed from the skin vertices.
        #    Always register the buffer (as None when absent) so that `forward` can test for it.
        self.register_buffer(
            'J_regressor_extra',
            _load_joint_regressor(joint_regressor_extra) if joint_regressor_extra is not None else None,
        )

        self.custom_regress_joints = joint_regressor_custom is not None
        if self.custom_regress_joints:
            get_logger().info('Using customized joint regressor.')
            # Presumably (24, 6890), matching SMPL's regressor; confirm against the file used.
            self.register_buffer('J_regressor_custom', _load_joint_regressor(joint_regressor_custom))

        self.update_root = update_root

    def forward(self, **kwargs) -> SKELOutput:  # type: ignore
        '''
        Run SKEL, then replace `joints` in its output with the re-mapped and extended joint set.

        If `trans` is not provided, a zero translation of shape (B, 3) is created from `poses`.

        Besides the updated `joints`, the returned output carries:
            joints_backup: SKEL's original joints, untouched by `update_root` or re-mapping;
            joints_custom: the 24 SMPL-ordered joints, (B, 24, 3);
            joints       : the 25 OpenPose-ordered joints, followed by the joints regressed with
                           `J_regressor_extra` when it is set. The original comments assume 19 extra joints,
                           giving (B, 44, 3).
        '''
        if 'trans' not in kwargs:
            poses = kwargs['poses']
            kwargs['trans'] = poses.new_zeros((poses.shape[0], 3))  # (B, 3)
        skel_output = super().forward(**kwargs)
        verts = skel_output.skin_verts  # (B, 6890, 3)
        # Cloned so the in-place root update below does not modify `skel_output.joints`.
        joints = skel_output.joints.clone()  # (B, 24, 3)

        # Update the root joint position (to avoid the root being too far forward).
        if self.update_root:
            # Move root (joint 0) 70% of the way towards its orthogonal projection onto the line that passes
            # through joint 11 ("lumbar") in the direction of the midpoint of joints 1 and 6 ("hips").
            hips_middle = (joints[:, 1] + joints[:, 6]) / 2  # (B, 3)
            lumbar2middle = hips_middle - joints[:, 11]  # (B, 3)
            lumbar2middle_unit = lumbar2middle / torch.norm(lumbar2middle, dim=1, keepdim=True)  # (B, 3)
            lumbar2root = joints[:, 0] - joints[:, 11]  # (B, 3)
            lumbar2root_proj = \
                torch.einsum('bc,bc->b', lumbar2root, lumbar2middle_unit)[:, None] * \
                lumbar2middle_unit  # (B, 3)
            root2root_proj = lumbar2root_proj - lumbar2root  # (B, 3)
            joints[:, 0] += root2root_proj * 0.7

        # Combine the joints from the three parts.
        if self.custom_regress_joints:
            # 1.x. Regress joints from the skin vertices using the custom regressor.
            #      NOTE: this discards `joints`, so `update_root` has no effect on this path.
            joints = vertices2joints(self.J_regressor_custom, verts)  # (B, 24, 3)
        else:
            # 1.y. Map selected joints of interest from SKEL to SMPL order.
            joints = joints[:, self.J_skel_to_smpl]  # (B, 24, 3)
        joints_custom = joints.clone()
        # 2. Concat joints selected from skin vertices.
        joints = self.vertex_joint_selector(verts, joints)  # (B, 45, 3)
        # 3. Select joints in OpenPose order.
        joints = joints[:, self.J_smpl_to_openpose]  # (B, 25, 3)
        # 4. Append extra joints regressed with J_regressor_extra, if one was provided.
        if self.J_regressor_extra is not None:
            joints_extra = vertices2joints(self.J_regressor_extra, verts)  # (B, K, 3)
            joints = torch.cat([joints, joints_extra], dim=1)  # (B, 25 + K, 3)

        # Update the joints in the output.
        skel_output.joints_backup = skel_output.joints
        skel_output.joints_custom = joints_custom
        skel_output.joints = joints
        return skel_output

    @staticmethod
    def get_static_root_offset(skel_output):
        '''
        Background:
            By default, the orientation rotation is always around the original skel_root.
            To make the orientation rotate around the custom_root instead, a translation offset is needed.
        This function computes that offset in the static pose, from custom_root to skel_root.

        Expects an output produced by `SKELWrapper.forward` (it reads `joints_custom` and `joints_backup`).
        Returns a (B, 3) tensor: skel_root - custom_root.
        '''
        custom_root = skel_output.joints_custom[:, 0]  # (B, 3)
        skel_root = skel_output.joints_backup[:, 0]  # (B, 3)
        offset = skel_root - custom_root  # (B, 3)
        return offset