File size: 1,720 Bytes
5ac1897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch

from typing import Optional

from .transforms import real_orient_mat2q, real_orient_q2mat


def update_params_after_orient_rotation(
    poses       : torch.Tensor,           # (B, 46)
    rot_mat     : torch.Tensor,           # the rotation orientation matrix
    root_offset : Optional[torch.Tensor] = None, # the offset from custom root to model root
):
    '''
    Left-multiply the orientation stored in `poses[:, :3]` by `rot_mat`, and optionally
    compute the translation offset that makes the rotation pivot around a custom root.

    ### Args
    - `poses`: torch.Tensor, shape = (B, 46)
        - Only `poses[:, :3]` (the orientation parameters) are changed. The input tensor
          is not mutated; a cloned copy is returned.
    - `rot_mat`: torch.Tensor, shape = (B, 3, 3) or (3, 3)
        - A single (3, 3) matrix is broadcast to every sample in the batch.
    - `root_offset`: torch.Tensor or None, shape = (B, 3)
        - If None, the function won't update the translation.
        - If not None, the offset from the custom root to the model root. The function then
          returns the root translation offset that makes the model rotate around the custom
          root instead of the model root.

    ### Returns
    - If `root_offset` is None:
        - `poses`: torch.Tensor, shape = (B, 46)
    - If `root_offset` is not None:
        - `poses`: torch.Tensor, shape = (B, 46)
        - `trans_offset`: torch.Tensor, shape = (B, 3), equal to
          `rot_mat @ root_offset - root_offset`.
    '''
    poses = poses.clone()  # never mutate the caller's tensor

    # 1. Convert the stored orientation parameters to a rotation matrix, apply
    #    `rot_mat` on the left, then convert back. `torch.matmul` broadcasts over
    #    the batch dimension, so `rot_mat` may be (B, 3, 3) or one shared (3, 3).
    orient_mat = real_orient_q2mat(poses[:, :3])    # (B, 3, 3)
    orient_mat = torch.matmul(rot_mat, orient_mat)  # (B, 3, 3)
    poses[:, :3] = real_orient_mat2q(orient_mat)    # (B, 3)

    # 2. Without a custom root, the translation is left as is.
    if root_offset is None:
        return poses

    # With d = root_offset (custom root -> model root), pivoting about the custom
    # root moves the model root by `R d - d`. This assumes the orientation rotates
    # the body about the model root.
    root_after = torch.matmul(rot_mat, root_offset.unsqueeze(-1)).squeeze(-1)  # (B, 3)
    trans_offset = root_after - root_offset  # (B, 3)
    return poses, trans_offset