from lib.kits.basic import *

from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import einops

from omegaconf import OmegaConf

from lib.platform import PM
from lib.body_models.skel_utils.transforms import params_q2rep, params_rep2q

from .utils.pose_transformer import TransformerDecoder


class SKELTransformerDecoderHead(nn.Module):
    """ Cross-attention based SKEL Transformer decoder
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        if cfg.pd_poses_repr == 'rotation_6d':
            n_poses = 24 * 6  # 24 rotational units, 6 values each
        elif cfg.pd_poses_repr == 'euler_angle':
            n_poses = 46  # SKEL's native Euler-angle pose parameters
        else:
            raise ValueError(f"Unknown pose representation: {cfg.pd_poses_repr}")

        n_betas = 10
        n_cam   = 3
        # If True, the mean parameters (rather than a zero token) are fed as the input token.
        self.input_is_mean_shape = False

        # Build transformer decoder.
        transformer_args = {
                'num_tokens' : 1,
                'token_dim'  : (n_poses + n_betas + n_cam) if self.input_is_mean_shape else 1,
                'dim'        : 1024,
            }
        transformer_args.update(OmegaConf.to_container(cfg.transformer_decoder, resolve=True))  # type: ignore
        self.transformer = TransformerDecoder(**transformer_args)

        # Build decoders for parameters.
        dim = transformer_args['dim']
        self.poses_decoder = nn.Linear(dim, n_poses)
        self.betas_decoder = nn.Linear(dim, n_betas)
        self.cam_decoder   = nn.Linear(dim, n_cam)

        # Load mean shape parameters as initial values.
        skel_mean_path = Path(__file__).parent / 'SKEL_mean.npz'
        skel_mean_params = np.load(skel_mean_path)

        init_poses = torch.from_numpy(skel_mean_params['poses'].astype(np.float32)).unsqueeze(0) # (1, 46)
        if cfg.pd_poses_repr == 'rotation_6d':
            init_poses = params_q2rep(init_poses).reshape(1, 24*6)  # (1, 24*6)
        init_betas = torch.from_numpy(skel_mean_params['betas'].astype(np.float32)).unsqueeze(0)
        init_cam = torch.from_numpy(skel_mean_params['cam'].astype(np.float32)).unsqueeze(0)

        self.register_buffer('init_poses', init_poses)
        self.register_buffer('init_betas', init_betas)
        self.register_buffer('init_cam', init_cam)

    def forward(self, x, **kwargs):
        B = x.shape[0]
        # The pretrained ViT backbone outputs channel-first features; flatten the
        # spatial grid into a token sequence.
        x = einops.rearrange(x, 'b c h w -> b (h w) c')  # (B, H*W, C)

        # Broadcast the registered mean parameters over the batch.
        init_poses = self.init_poses.expand(B, -1)  # (B, n_poses)
        init_betas = self.init_betas.expand(B, -1)  # (B, 10)
        init_cam   = self.init_cam.expand(B, -1)    # (B, 3)

        # The input token is either the mean-shape parameters or a single zero token.
        with PM.time_monitor('init_token'):
            if self.input_is_mean_shape:
                token = torch.cat([init_poses, init_betas, init_cam], dim=1)[:, None, :]  # (B, 1, C)
            else:
                token = x.new_zeros(B, 1, 1)

        # Pass through transformer.
        with PM.time_monitor('transformer'):
            token_out = self.transformer(token, context=x)
            token_out = token_out.squeeze(1)  # (B, C)

        # Decode the SKEL parameters from the output token as residuals on the means.
        with PM.time_monitor('decode'):
            pd_poses = self.poses_decoder(token_out) + init_poses
            pd_betas = self.betas_decoder(token_out) + init_betas
            pd_cam = self.cam_decoder(token_out) + init_cam

        with PM.time_monitor('rot_repr_transform'):
            if self.cfg.pd_poses_repr == 'rotation_6d':
                pd_poses = params_rep2q(pd_poses.reshape(-1, 24, 6))  # (B, 46)
            elif self.cfg.pd_poses_repr == 'euler_angle':
                pd_poses = pd_poses.reshape(-1, 46)  # (B, 46)
            else:
                raise ValueError(f"Unknown pose representation: {self.cfg.pd_poses_repr}")

        pd_skel_params = {
                'poses'        : pd_poses,
                'poses_orient' : pd_poses[:, :3],
                'poses_body'   : pd_poses[:, 3:],
                'betas'        : pd_betas
            }
        return pd_skel_params, pd_cam
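
The `einops.rearrange(x, 'b c h w -> b (h w) c')` step in `forward` flattens the backbone's spatial grid into a token sequence. As a minimal standalone sketch (using numpy with a toy tensor, since the actual backbone features are not available here), the same reshuffle is a transpose followed by a reshape:

```python
import numpy as np

# Toy feature map standing in for the backbone output: (B, C, H, W).
x = np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)

# 'b c h w -> b (h w) c': move channels last, then merge the spatial grid
# into a single token axis of length H*W.
tokens = x.transpose(0, 2, 3, 1).reshape(2, 4 * 5, 3)
```

Each of the `H*W` rows of `tokens` is one spatial location's `C`-dimensional feature vector, which is what the cross-attention context in `TransformerDecoder` expects.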
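
The head converts between SKEL's Euler-angle pose parameters and a 6D rotation representation via `params_q2rep` / `params_rep2q`. Those helpers are not reproduced here; the sketch below is an assumption about the underlying recipe, following the standard 6D representation (two rotation-matrix columns, recovered by Gram-Schmidt), written in plain numpy so it runs standalone:

```python
import numpy as np

def rot6d_to_matrix(d6: np.ndarray) -> np.ndarray:
    """Recover a 3x3 rotation matrix from a 6D representation (..., 6).

    The 6 values are the first two columns of the matrix; the full matrix
    is rebuilt by Gram-Schmidt orthonormalization plus a cross product.
    """
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = a1 / np.linalg.norm(a1, axis=-1, keepdims=True)
    b2 = a2 - np.sum(b1 * a2, axis=-1, keepdims=True) * b1
    b2 = b2 / np.linalg.norm(b2, axis=-1, keepdims=True)
    b3 = np.cross(b1, b2)
    return np.stack([b1, b2, b3], axis=-1)

# The identity rotation corresponds to the 6D vector [1,0,0, 0,1,0].
R = rot6d_to_matrix(np.array([1., 0., 0., 0., 1., 0.]))
```

This continuity-friendly parameterization is why the head predicts `24 * 6` values in `rotation_6d` mode before converting back to the 46 native pose parameters.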