import einops
import numpy as np
import torch
import pytorch_lightning as pl
from typing import Dict
from torchvision.utils import make_grid
from tqdm import tqdm
from yacs.config import CfgNode

from lib.datasets.track_dataset import TrackDatasetEval
from lib.models.modules import MANOTransformerDecoderHead, temporal_attention
from hawor.utils.pylogger import get_pylogger
from hawor.utils.render_openpose import render_openpose
from lib.utils.geometry import rot6d_to_rotmat_hmr2 as rot6d_to_rotmat
from lib.utils.geometry import perspective_projection
from hawor.utils.rotation import angle_axis_to_rotation_matrix
from torch.utils.data import default_collate

from .backbones import create_backbone
from .mano_wrapper import MANO

log = get_pylogger(__name__)

idx = 0

class HAWOR(pl.LightningModule):

    def __init__(self, cfg: CfgNode):
        """
        Setup HAWOR model
        Args:
            cfg (CfgNode): Config file as a yacs CfgNode
        """
        super().__init__()

        # Save hyperparameters
        self.save_hyperparameters(logger=False, ignore=['init_renderer'])

        self.cfg = cfg
        self.crop_size = cfg.MODEL.IMAGE_SIZE
        self.seq_len = 16
        self.pose_num = 16
        self.pose_dim = 6  # rot6d representation
        self.box_info_dim = 3

        # Create backbone feature extractor
        self.backbone = create_backbone(cfg)
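        # NOTE (added): the feature extractor is expected to return a 1280-channel
        # feature map on a 16x12 grid; this is inferred from the 1280+3 input dim of
        # the space-time module and the h=16, w=12 rearranges in forward_step below.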
        try:
            if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None):
                whole_state_dict = torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu')['state_dict']
                backbone_state_dict = {}
                for key in whole_state_dict:
                    if key.startswith('backbone.'):
                        backbone_state_dict[key[len('backbone.'):]] = whole_state_dict[key]
                self.backbone.load_state_dict(backbone_state_dict)
                print(f'Loaded backbone weights from {cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS}')
                # Freeze the pretrained backbone
                for param in self.backbone.parameters():
                    param.requires_grad = False
            else:
                print('WARNING: initializing backbone from scratch!')
        except Exception:
            print('WARNING: initializing backbone from scratch!')
        # Space-time memory
        if cfg.MODEL.ST_MODULE:
            hdim = cfg.MODEL.ST_HDIM
            nlayer = cfg.MODEL.ST_NLAYER
            self.st_module = temporal_attention(in_dim=1280 + 3,
                                                out_dim=1280,
                                                hdim=hdim,
                                                nlayer=nlayer,
                                                residual=True)
            print(f'Using Temporal Attention space-time: {nlayer} layers {hdim} dim.')
        else:
            self.st_module = None
        # Motion memory
        if cfg.MODEL.MOTION_MODULE:
            hdim = cfg.MODEL.MOTION_HDIM
            nlayer = cfg.MODEL.MOTION_NLAYER
            self.motion_module = temporal_attention(in_dim=self.pose_num * self.pose_dim + self.box_info_dim,
                                                    out_dim=self.pose_num * self.pose_dim,
                                                    hdim=hdim,
                                                    nlayer=nlayer,
                                                    residual=False)
            print(f'Using Temporal Attention motion layer: {nlayer} layers {hdim} dim.')
        else:
            self.motion_module = None
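        # NOTE (added): the motion module consumes the 16-joint rot6d pose
        # (16 * 6 = 96 values) concatenated with the 3-d bbox encoding per frame,
        # and returns a refined 96-d pose for each of the 16 frames in a window.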

        # Create MANO head
        # self.mano_head = build_mano_head(cfg)
        self.mano_head = MANOTransformerDecoderHead(cfg)

        # Optionally compile backbone and head with torch.compile
        if cfg.MODEL.BACKBONE.get('TORCH_COMPILE', 0):
            log.info("Model will use torch.compile")
            self.backbone = torch.compile(self.backbone)
            self.mano_head = torch.compile(self.mano_head)

        # Define loss functions
        # self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1')
        # self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1')
        # self.mano_parameter_loss = ParameterLoss()

        # Instantiate MANO model
        mano_cfg = {k.lower(): v for k, v in dict(cfg.MANO).items()}
        self.mano = MANO(**mano_cfg)

        # Buffer that shows whether we need to initialize ActNorm layers
        self.register_buffer('initialized', torch.tensor(False))

        # Disable automatic optimization; the optimizer is stepped manually in training_step
        self.automatic_optimization = False

        if cfg.MODEL.get('LOAD_WEIGHTS', None):
            whole_state_dict = torch.load(cfg.MODEL.LOAD_WEIGHTS, map_location='cpu')['state_dict']
            self.load_state_dict(whole_state_dict, strict=True)
            print(f"Loaded weights from {cfg.MODEL.LOAD_WEIGHTS}")
    def get_parameters(self):
        all_params = list(self.mano_head.parameters())
        if self.st_module is not None:
            all_params += list(self.st_module.parameters())
        if self.motion_module is not None:
            all_params += list(self.motion_module.parameters())
        all_params += list(self.backbone.parameters())
        return all_params
    def configure_optimizers(self) -> torch.optim.Optimizer:
        """
        Setup the model optimizer
        Returns:
            torch.optim.Optimizer: Model optimizer
        """
        param_groups = [{'params': filter(lambda p: p.requires_grad, self.get_parameters()), 'lr': self.cfg.TRAIN.LR}]

        optimizer = torch.optim.AdamW(params=param_groups,
                                      # lr=self.cfg.TRAIN.LR,
                                      weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)

        return optimizer
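
    # NOTE (added): a minimal sketch of the TRAIN config fields this class reads;
    # the values below are placeholders, not the authors' settings:
    #   TRAIN:
    #     LR: 1.0e-5
    #     WEIGHT_DECAY: 1.0e-4
    #     GRAD_CLIP_VAL: 1.0   # optional, see training_step
    #     RENDER_LOG: True     # optional, see training_step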
    def forward_step(self, batch: Dict, train: bool = False) -> Dict:
        """
        Run a forward step of the network
        Args:
            batch (Dict): Dictionary containing batch data
            train (bool): Flag indicating whether it is training or validation mode
        Returns:
            Dict: Dictionary containing the regression output
        """
        image = batch['img'].flatten(0, 1)
        center = batch['center'].flatten(0, 1)
        scale = batch['scale'].flatten(0, 1)
        img_focal = batch['img_focal'].flatten(0, 1)
        img_center = batch['img_center'].flatten(0, 1)
        bn = len(image)

        # Bbox encoding from crop center, scale, and camera intrinsics (see bbox_est)
        bbox_info = self.bbox_est(center, scale, img_focal, img_center)

        # Backbone: center-crop the width (e.g. 256x256 -> 256x192) before feature extraction
        feature = self.backbone(image[:, :, :, 32:-32])
        feature = feature.float()
        # Space-time module
        if self.st_module is not None:
            bb = einops.repeat(bbox_info, 'b c -> b c h w', h=16, w=12)
            feature = torch.cat([feature, bb], dim=1)
            feature = einops.rearrange(feature, '(b t) c h w -> (b h w) t c', t=16)
            feature = self.st_module(feature)
            feature = einops.rearrange(feature, '(b h w) t c -> (b t) c h w', h=16, w=12)
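            # NOTE (added): the rearranges above treat every spatial cell of the
            # 16x12 feature map as its own token sequence of length t=16, so the
            # temporal attention mixes information across frames independently at
            # each spatial location before the grid layout is restored.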

        # mano_head: transformer decoder + MANO parameter regression
        # pred_mano_params, pred_cam, pred_mano_params_list = self.mano_head(feature)
        # pred_shape = pred_mano_params_list['pred_shape']
        # pred_pose = pred_mano_params_list['pred_pose']
        pred_pose, pred_shape, pred_cam = self.mano_head(feature)
        pred_rotmat_0 = rot6d_to_rotmat(pred_pose).reshape(-1, self.pose_num, 3, 3)

        # MANO motion module
        if self.motion_module is not None:
            bb = einops.rearrange(bbox_info, '(b t) c -> b t c', t=16)
            pred_pose = einops.rearrange(pred_pose, '(b t) c -> b t c', t=16)
            pred_pose = torch.cat([pred_pose, bb], dim=2)
            pred_pose = self.motion_module(pred_pose)
            pred_pose = einops.rearrange(pred_pose, 'b t c -> (b t) c')
        out = {}

        if 'do_flip' in batch:
            pred_cam[..., 1] *= -1
            center[..., 0] = img_center[..., 0] * 2 - center[..., 0] - 1

        out['pred_cam'] = pred_cam
        out['pred_pose'] = pred_pose
        out['pred_shape'] = pred_shape
        out['pred_rotmat'] = rot6d_to_rotmat(out['pred_pose']).reshape(-1, self.pose_num, 3, 3)
        out['pred_rotmat_0'] = pred_rotmat_0

        s_out = self.mano.query(out)
        j3d = s_out.joints
        j2d = self.project(j3d, out['pred_cam'], center, scale, img_focal, img_center)
        j2d = j2d / self.crop_size - 0.5  # normalize to [-0.5, 0.5]

        trans_full = self.get_trans(out['pred_cam'], center, scale, img_focal, img_center)
        out['trans_full'] = trans_full
        output = {
            'pred_mano_params': {
                'global_orient': out['pred_rotmat'][:, :1].clone(),
                'hand_pose': out['pred_rotmat'][:, 1:].clone(),
                'betas': out['pred_shape'].clone(),
            },
            'pred_keypoints_3d': j3d.clone(),
            'pred_keypoints_2d': j2d.clone(),
            'out': out,
        }
        # print(output)
        # output['gt_project_j2d'] = self.project(batch['gt_j3d_wo_trans'].clone().flatten(0, 1), out['pred_cam'], center, scale, img_focal, img_center)
        # output['gt_project_j2d'] = output['gt_project_j2d'] / self.crop_size - 0.5
        return output
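
    # NOTE (added): shape summary for forward_step, assuming B sequences of T=16 frames
    # and J joints returned by the MANO wrapper (typically 21, i.e. 16 MANO joints plus
    # 5 fingertips, in HaMeR-style wrappers):
    #   pred_mano_params['global_orient']: (B*T, 1, 3, 3)
    #   pred_mano_params['hand_pose']:     (B*T, 15, 3, 3)
    #   pred_mano_params['betas']:         (B*T, 10)
    #   pred_keypoints_3d: (B*T, J, 3), pred_keypoints_2d: (B*T, J, 2)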
    def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor:
        """
        Compute losses given the input batch and the regression output
        Args:
            batch (Dict): Dictionary containing batch data
            output (Dict): Dictionary containing the regression output
            train (bool): Flag indicating whether it is training or validation mode
        Returns:
            torch.Tensor: Total loss for current batch
        """
        pred_mano_params = output['pred_mano_params']
        pred_keypoints_2d = output['pred_keypoints_2d']
        pred_keypoints_3d = output['pred_keypoints_3d']

        batch_size = pred_mano_params['hand_pose'].shape[0]
        device = pred_mano_params['hand_pose'].device
        dtype = pred_mano_params['hand_pose'].dtype
        # Get annotations
        gt_keypoints_2d = batch['gt_cam_j2d'].flatten(0, 1)
        gt_keypoints_2d = torch.cat([gt_keypoints_2d, torch.ones(*gt_keypoints_2d.shape[:-1], 1, device=gt_keypoints_2d.device)], dim=-1)
        gt_keypoints_3d = batch['gt_j3d_wo_trans'].flatten(0, 1)
        gt_keypoints_3d = torch.cat([gt_keypoints_3d, torch.ones(*gt_keypoints_3d.shape[:-1], 1, device=gt_keypoints_3d.device)], dim=-1)

        pose_gt = batch['gt_cam_full_pose'].flatten(0, 1).reshape(-1, 16, 3)
        rotmat_gt = angle_axis_to_rotation_matrix(pose_gt)
        gt_mano_params = {
            'global_orient': rotmat_gt[:, :1],
            'hand_pose': rotmat_gt[:, 1:],
            'betas': batch['gt_cam_betas'],
        }
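        # NOTE (added): the 16 axis-angle entries split into 1 global orientation plus
        # 15 hand-joint rotations, matching the global_orient / hand_pose split used
        # for the predictions in forward_step; the appended column of ones above acts
        # as a per-keypoint confidence weight for the keypoint losses.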

        # Compute 2D and 3D keypoint losses
        loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d)
        loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0)
        # Avoid NaN
        loss_keypoints_2d = torch.nan_to_num(loss_keypoints_2d)

        # Compute loss on MANO parameters
        loss_mano_params = {}
        for k, pred in pred_mano_params.items():
            gt = gt_mano_params[k].view(batch_size, -1)
            loss_mano_params[k] = self.mano_parameter_loss(pred.reshape(batch_size, -1), gt.reshape(batch_size, -1))

        loss = self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'] * loss_keypoints_3d + \
               self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'] * loss_keypoints_2d + \
               sum([loss_mano_params[k] * self.cfg.LOSS_WEIGHTS[k.upper()] for k in loss_mano_params])

        losses = dict(loss=loss.detach(),
                      loss_keypoints_2d=loss_keypoints_2d.detach() * self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'],
                      loss_keypoints_3d=loss_keypoints_3d.detach() * self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'])

        for k, v in loss_mano_params.items():
            losses['loss_' + k] = v.detach() * self.cfg.LOSS_WEIGHTS[k.upper()]

        output['losses'] = losses

        return loss
    # Tensorboard logging should run from first rank only
    def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True, write_to_summary_writer: bool = True, render_log: bool = True) -> None:
        """
        Log results to Tensorboard
        Args:
            batch (Dict): Dictionary containing batch data
            output (Dict): Dictionary containing the regression output
            step_count (int): Global training step count
            train (bool): Flag indicating whether it is training or validation mode
        """
        mode = 'train' if train else 'val'
        batch_size = output['pred_keypoints_2d'].shape[0]

        # Undo ImageNet normalization for visualization
        images = batch['img'].flatten(0, 1)
        images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
        images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)

        losses = output['losses']
        if write_to_summary_writer:
            summary_writer = self.logger.experiment
            for loss_name, val in losses.items():
                summary_writer.add_scalar(mode + '/' + loss_name, val.detach().item(), step_count)

        if render_log:
            gt_keypoints_2d = batch['gt_cam_j2d'].flatten(0, 1).clone()
            pred_keypoints_2d = output['pred_keypoints_2d'].clone().detach().reshape(batch_size, -1, 2)
            gt_project_j2d = pred_keypoints_2d
            # gt_project_j2d = output['gt_project_j2d'].clone().detach().reshape(batch_size, -1, 2)

            num_images = 4
            skip = 16
            predictions = self.visualize_tensorboard(images[:num_images * skip:skip].cpu().numpy(),
                                                     pred_keypoints_2d[:num_images * skip:skip].cpu().numpy(),
                                                     gt_project_j2d[:num_images * skip:skip].cpu().numpy(),
                                                     gt_keypoints_2d[:num_images * skip:skip].cpu().numpy(),
                                                     )
            summary_writer.add_image('%s/predictions' % mode, predictions, step_count)
    def forward(self, batch: Dict) -> Dict:
        """
        Run a forward step of the network in val mode
        Args:
            batch (Dict): Dictionary containing batch data
        Returns:
            Dict: Dictionary containing the regression output
        """
        return self.forward_step(batch, train=False)
    def training_step(self, joint_batch: Dict, batch_idx: int) -> Dict:
        """
        Run a full training step
        Args:
            joint_batch (Dict): Dictionary containing image and mocap batch data
            batch_idx (int): Unused.
        Returns:
            Dict: Dictionary containing regression output.
        """
        batch = joint_batch['img']
        optimizer = self.optimizers(use_pl_optimizer=True)

        batch_size = batch['img'].shape[0]
        output = self.forward_step(batch, train=True)
        # pred_mano_params = output['pred_mano_params']
        loss = self.compute_loss(batch, output, train=True)

        # Error if NaN
        if torch.isnan(loss):
            raise ValueError('Loss is NaN')

        optimizer.zero_grad()
        self.manual_backward(loss)

        # Clip gradient
        if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0:
            gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL, error_if_nonfinite=True)
            self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)

        optimizer.step()

        # if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0:
        if self.global_step > 0 and self.global_step % 100 == 0:
            self.tensorboard_logging(batch, output, self.global_step, train=True, render_log=self.cfg.TRAIN.get("RENDER_LOG", True))

        self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=False, batch_size=batch_size)

        return output
    def inference(self, imgfiles, boxes, img_focal, img_center, device='cuda', do_flip=False):
        db = TrackDatasetEval(imgfiles, boxes, img_focal=img_focal,
                              img_center=img_center, normalization=True, dilate=1.2, do_flip=do_flip)

        # Results
        pred_cam = []
        pred_pose = []
        pred_shape = []
        pred_rotmat = []
        pred_trans = []

        # TODO: efficient implementation with batching
        items = []
        for i in tqdm(range(len(db))):
            item = db[i]
            items.append(item)

            # Pad the last window to 16 frames by repeating the final item
            if i == len(db) - 1 and len(db) % 16 != 0:
                pad = 16 - len(db) % 16
                for _ in range(pad):
                    items.append(item)

            if len(items) < 16:
                continue
            elif len(items) == 16:
                batch = default_collate(items)
                items = []
            else:
                raise NotImplementedError
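            # NOTE (added): the model always runs on fixed 16-frame windows, so the last
            # partial window is filled with copies of the final frame; the duplicated
            # predictions are trimmed again right after the forward pass below.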
            with torch.no_grad():
                batch = {k: v.to(device).unsqueeze(0) for k, v in batch.items() if type(v) == torch.Tensor}
                # for image_i in range(16):
                #     hawor_input_cv2 = vis_tensor_cv2(batch['img'][:, image_i])
                #     cv2.imwrite(f'debug_vis_model.png', hawor_input_cv2)
                #     print("vis")
                output = self.forward(batch)
                out = output['out']

            if i == len(db) - 1 and len(db) % 16 != 0:
                out = {k: v[:len(db) % 16] for k, v in out.items()}
            else:
                out = {k: v for k, v in out.items()}

            pred_cam.append(out['pred_cam'].cpu())
            pred_pose.append(out['pred_pose'].cpu())
            pred_shape.append(out['pred_shape'].cpu())
            pred_rotmat.append(out['pred_rotmat'].cpu())
            pred_trans.append(out['trans_full'].cpu())

        results = {'pred_cam': torch.cat(pred_cam),
                   'pred_pose': torch.cat(pred_pose),
                   'pred_shape': torch.cat(pred_shape),
                   'pred_rotmat': torch.cat(pred_rotmat),
                   'pred_trans': torch.cat(pred_trans),
                   'img_focal': img_focal,
                   'img_center': img_center}

        return results
    def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict:
        """
        Run a validation step and log to Tensorboard
        Args:
            batch (Dict): Dictionary containing batch data
            batch_idx (int): Unused.
        Returns:
            Dict: Dictionary containing regression output.
        """
        # batch_size = batch['img'].shape[0]
        output = self.forward_step(batch, train=False)

        loss = self.compute_loss(batch, output, train=False)
        output['loss'] = loss
        self.tensorboard_logging(batch, output, self.global_step, train=False)

        return output
    def visualize_tensorboard(self, images, pred_keypoints, gt_project_j2d, gt_keypoints):
        # Map normalized keypoints back to crop-pixel coordinates and append a third
        # column (confidence = 1) for render_openpose
        pred_keypoints = 256 * (pred_keypoints + 0.5)
        gt_keypoints = 256 * (gt_keypoints + 0.5)
        gt_project_j2d = 256 * (gt_project_j2d + 0.5)
        pred_keypoints = np.concatenate((pred_keypoints, np.ones_like(pred_keypoints)[:, :, [0]]), axis=-1)
        gt_keypoints = np.concatenate((gt_keypoints, np.ones_like(gt_keypoints)[:, :, [0]]), axis=-1)
        gt_project_j2d = np.concatenate((gt_project_j2d, np.ones_like(gt_project_j2d)[:, :, [0]]), axis=-1)

        images_np = np.transpose(images, (0, 2, 3, 1))
        rend_imgs = []
        for i in range(images_np.shape[0]):
            pred_keypoints_img = render_openpose(255 * images_np[i].copy(), pred_keypoints[i]) / 255
            gt_project_j2d_img = render_openpose(255 * images_np[i].copy(), gt_project_j2d[i]) / 255
            gt_keypoints_img = render_openpose(255 * images_np[i].copy(), gt_keypoints[i]) / 255
            rend_imgs.append(torch.from_numpy(images[i]))
            rend_imgs.append(torch.from_numpy(pred_keypoints_img).permute(2, 0, 1))
            rend_imgs.append(torch.from_numpy(gt_project_j2d_img).permute(2, 0, 1))
            rend_imgs.append(torch.from_numpy(gt_keypoints_img).permute(2, 0, 1))
        rend_imgs = make_grid(rend_imgs, nrow=4, padding=2)

        return rend_imgs
    def project(self, points, pred_cam, center, scale, img_focal, img_center, return_full=False):
        trans_full = self.get_trans(pred_cam, center, scale, img_focal, img_center)

        # Projection in full-frame image coordinates
        points = points + trans_full
        points2d_full = perspective_projection(points, rotation=None, translation=None,
                                               focal_length=img_focal, camera_center=img_center)

        # Adjust projected points to crop image coordinates
        # (s.t. 1. we can calculate the loss in the crop image easily
        #       2. we can query its pixel in the crop)
        b = scale * 200
        points2d = points2d_full - (center - b[:, None] / 2)[:, None, :]
        points2d = points2d * (self.crop_size / b)[:, None, None]

        if return_full:
            return points2d_full, points2d
        else:
            return points2d
    def get_trans(self, pred_cam, center, scale, img_focal, img_center):
        b = scale * 200
        cx, cy = center[:, 0], center[:, 1]  # center of the crop
        s, tx, ty = pred_cam.unbind(-1)      # weak-perspective camera (scale, x, y)
        img_cx, img_cy = img_center[:, 0], img_center[:, 1]  # center of the original image

        bs = b * s
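        # NOTE (added): converts the weak-perspective camera predicted in the crop into a
        # full-image translation: depth follows from similar triangles, tz = 2 * f / (s * b),
        # and tx/ty are shifted by the offset of the crop center from the image center,
        # expressed in the same normalized units (2 * offset / (s * b)).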
        tx_full = tx + 2 * (cx - img_cx) / bs
        ty_full = ty + 2 * (cy - img_cy) / bs
        tz_full = 2 * img_focal / bs

        trans_full = torch.stack([tx_full, ty_full, tz_full], dim=-1)
        trans_full = trans_full.unsqueeze(1)

        return trans_full
    def bbox_est(self, center, scale, img_focal, img_center):
        # Original image center
        img_cx, img_cy = img_center[:, 0], img_center[:, 1]

        # Implement CLIFF (Li et al.) bbox feature
        cx, cy, b = center[:, 0], center[:, 1], scale * 200
        bbox_info = torch.stack([cx - img_cx, cy - img_cy, b], dim=-1)
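        # NOTE (added): the normalization below follows CLIFF's constants: the center
        # offsets are scaled by 2.8 / focal_length, and the box size is whitened with an
        # assumed mean of 0.24 * focal_length and std of 0.06 * focal_length.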
        bbox_info[:, :2] = bbox_info[:, :2] / img_focal.unsqueeze(-1) * 2.8
        bbox_info[:, 2] = (bbox_info[:, 2] - 0.24 * img_focal) / (0.06 * img_focal)

        return bbox_info
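

# NOTE (added): a minimal usage sketch, not part of the original file. It assumes a yacs
# config file and checkpoint whose paths are placeholders, plus per-frame image paths,
# hand bounding boxes, focal length, and image center prepared elsewhere in the pipeline:
#
#   from yacs.config import CfgNode
#
#   with open('path/to/hawor_config.yaml') as f:   # hypothetical path
#       cfg = CfgNode.load_cfg(f)                  # cfg.MODEL.LOAD_WEIGHTS points at a checkpoint
#   model = HAWOR(cfg).eval().to('cuda')
#   results = model.inference(imgfiles, boxes, img_focal, img_center)
#   # results['pred_rotmat']: per-frame MANO rotations, results['pred_trans']: camera-space translation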