File size: 6,300 Bytes
5ac1897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from lib.kits.basic import *

import pickle
from smplx.vertex_joint_selector import VertexJointSelector
from smplx.vertex_ids import vertex_ids
from smplx.lbs import vertices2joints

from lib.body_models.skel.skel_model import SKEL, SKELOutput

class SKELWrapper(SKEL):
    '''
    SKEL model whose output joints are re-ordered and extended so that they can be consumed
    through the same interface as an SMPL model.

    The final `joints` of the output are assembled from three parts:
        1. 24 body joints in SMPL order (re-indexed SKEL joints, or regressed from the skin
           vertices when a custom regressor is given).
        2. Joints selected from the skin vertices (`VertexJointSelector`).
        3. Extra joints regressed from the skin vertices by `J_regressor_extra` (only when given).
    '''

    def __init__(
        self,
        *args,
        joint_regressor_custom: Optional[str] = None,
        joint_regressor_extra : Optional[str] = None,
        update_root : bool = False,
        **kwargs
    ):
        '''
        This wrapper extends the output joints of the SKEL model so that they fit SMPL's interface.

        ### Args
        - joint_regressor_custom: Optional path to a pickled joint regressor (dense or scipy sparse).
            When given, the 24 body joints are regressed from the skin vertices with it instead of
            being re-indexed from the SKEL joints.
        - joint_regressor_extra: Optional path to a pickled regressor for the extra joints appended
            at the end of the output joints. When omitted, no extra joints are appended.
        - update_root: Whether to shift the SKEL root joint in `forward()`.
        - *args, **kwargs: Forwarded to `SKEL.__init__()`.
        '''

        super().__init__(*args, **kwargs)

        # The final joints are combined from three parts:
        # 1. The joints from the standard output.
        #    Map selected joints of interests from SKEL to SMPL. (Not all 24 joints will be used finally.)
        #    Notes: Only these SMPL joints will be used: [0, 1, 2, 4, 5, 7, 8, 12, 16, 17, 18, 19, 20, 21, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 45, 46, 47, 48, 49, 50, 51, 52, 53]
        # Entry `i` is the index of the SKEL joint that is used as SMPL joint `i`.
        skel_to_smpl = [
             0,
             6,
             1,
            11, # not aligned well; not used
             7,
             2,
            11, # not aligned well; not used
             8, # or 9
             3, # or 4
            12, # not aligned well; not used
            10, # not used
             5, # not used
            12,
            19, # not aligned well; not used
            14, # not aligned well; not used
            13, # not used
            20, # or 19
            15, # or 14
            21, # or 22
            16, # or 17,
            23,
            18,
            23, # not aligned well; not used
            18, # not aligned well; not used
        ]

        # Entry `i` is the index (in the SMPL joints + vertex-selected joints) used as OpenPose joint `i`.
        smpl_to_openpose = [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]

        self.register_buffer('J_skel_to_smpl', torch.tensor(skel_to_smpl, dtype=torch.long))
        self.register_buffer('J_smpl_to_openpose', torch.tensor(smpl_to_openpose, dtype=torch.long))
        # (SKEL has the same topology as SMPL as well as SMPL-H, so perform the same operation for the other 2 parts.)
        # 2. Joints selected from skin vertices.
        self.vertex_joint_selector = VertexJointSelector(vertex_ids['smplh'])
        # 3. Extra joints from the J_regressor_extra.
        #    The buffer is only registered when a path is given; `forward()` checks for its presence.
        if joint_regressor_extra is not None:
            # NOTE: `pickle.load` can execute arbitrary code, only load regressor files from trusted sources.
            with open(joint_regressor_extra, 'rb') as f:
                J_regressor_extra = pickle.load(f, encoding='latin1')
            self.register_buffer(
                'J_regressor_extra',
                torch.tensor(J_regressor_extra, dtype=torch.float32)
            )

        self.custom_regress_joints = joint_regressor_custom is not None
        if self.custom_regress_joints:
            get_logger().info('Using customized joint regressor.')
            # NOTE: `pickle.load` can execute arbitrary code, only load regressor files from trusted sources.
            with open(joint_regressor_custom, 'rb') as f:
                J_regressor_custom = pickle.load(f, encoding='latin1')
            # The regressor may be stored as a scipy sparse matrix; convert it to a dense one.
            if 'scipy.sparse' in str(type(J_regressor_custom)):
                J_regressor_custom = J_regressor_custom.todense()  # (24, 6890)
            self.register_buffer(
                'J_regressor_custom',
                torch.tensor(
                        J_regressor_custom,
                        dtype=torch.float32
                    )
            )

        self.update_root = update_root

    def forward(self, **kwargs) -> SKELOutput:  # type: ignore
        '''
        Run SKEL and map the order of joints of SKEL to SMPL's (then to OpenPose's).

        ### Args
        - **kwargs: Forwarded to `SKEL.forward()`. If `trans` is missing, a zero translation is
            created from `poses`, so `poses` must be provided in that case.

        ### Returns
        - The `SKELOutput` of the parent model with these attributes updated:
            - joints: OpenPose-ordered joints, followed by the extra joints when `J_regressor_extra` is available.
            - joints_backup: The original joints returned by SKEL.
            - joints_custom: The 24 joints in SMPL order, before the vertex-selected joints are added.
        '''

        if 'trans' not in kwargs:
            kwargs['trans'] = kwargs['poses'].new_zeros((kwargs['poses'].shape[0], 3))  # (B, 3)

        skel_output = super().forward(**kwargs)
        verts = skel_output.skin_verts  # (B, 6890, 3)
        # Work on a copy so that `skel_output.joints` keeps the untouched SKEL joints.
        joints = skel_output.joints.clone()  # (B, 24, 3)

        # Update the root joint position (to avoid the root too forward).
        # NOTE(review): when a custom joint regressor is used, `joints` is re-computed below and
        # this adjustment is discarded -- confirm that this is intended.
        if self.update_root:
            # Pull root (joint 0) towards the line from joint 11 to the middle of joints 1 and 6.
            hips_middle = (joints[:, 1] + joints[:, 6]) / 2  # (B, 3)
            lumbar2middle = (hips_middle - joints[:, 11])  # (B, 3)
            # NOTE(review): no guard against a zero-length `lumbar2middle` here.
            lumbar2middle_unit = lumbar2middle / torch.norm(lumbar2middle, dim=1, keepdim=True)  # (B, 3)
            lumbar2root = joints[:, 0] - joints[:, 11]
            # Projection of `lumbar2root` onto the direction of `lumbar2middle`.
            lumbar2root_proj = \
                torch.einsum('bc,bc->b', lumbar2root, lumbar2middle_unit)[:, None] *\
                lumbar2middle_unit  # (B, 3)
            # Vector from the root to its projection; only 70% of it is applied.
            root2root_proj = lumbar2root_proj - lumbar2root  # (B, 3)
            joints[:, 0] += root2root_proj * 0.7

        # Combine the joints from three parts:
        if self.custom_regress_joints:
            # 1.x. Regress joints from the skin vertices using SMPL's regressor.
            joints = vertices2joints(self.J_regressor_custom, verts)  # (B, 24, 3)
        else:
            # 1.y. Map selected joints of interests from SKEL to SMPL.
            joints = joints[:, self.J_skel_to_smpl]  # (B, 24, 3)
        joints_custom = joints.clone()
        # 2. Concat joints selected from skin vertices.
        joints = self.vertex_joint_selector(verts, joints)  # (B, 45, 3)
        # 3. Map selected joints to OpenPose.
        joints = joints[:, self.J_smpl_to_openpose]  # (B, 25, 3)
        # 4. Add extra joints from the J_regressor_extra.
        #    The buffer only exists when `joint_regressor_extra` was passed to `__init__()`;
        #    without it the extra joints are skipped instead of raising an AttributeError.
        if hasattr(self, 'J_regressor_extra'):
            joints_extra = vertices2joints(self.J_regressor_extra, verts)  # (B, 19, 3)
            joints = torch.cat([joints, joints_extra], dim=1)  # (B, 44, 3)

        # Update the joints in the output.
        skel_output.joints_backup = skel_output.joints
        skel_output.joints_custom = joints_custom
        skel_output.joints = joints

        return skel_output


    @staticmethod
    def get_static_root_offset(skel_output):
        '''
        Background:
            By default, the orientation rotation is always around the original skel_root.
            In order to make the orientation rotation around the custom_root, we need to calculate the translation offset.
        This function calculates the translation offset in static pose. (From custom_root to skel_root.)

        ### Args
        - skel_output: An output of `SKELWrapper.forward()`; its `joints_custom` and `joints_backup` are read.

        ### Returns
        - The offset `skel_root - custom_root`, one 3D vector per batch element.
        '''
        custom_root = skel_output.joints_custom[:, 0]  # (B, 3)
        skel_root = skel_output.joints_backup[:, 0]  # (B, 3)
        offset = skel_root - custom_root  # (B, 3)
        return offset