import copy
import os
import numpy as np
import torch
from typing import Any, Dict, List, Optional, Union
from yacs.config import CfgNode
import braceexpand
import cv2
# NOTE(review): `set_trace` is unused in this module; this top-level import
# makes ipdb a hard runtime dependency.
from ipdb import set_trace

from .dataset import Dataset
from .utils import get_example, expand_to_aspect_ratio


def expand(s):
    """Expand ``~`` and environment variables in the path/URL string ``s``."""
    return os.path.expanduser(os.path.expandvars(s))


def expand_urls(urls: Union[str, List[str]]):
    """Expand one or more URL patterns into a flat list of URLs.

    Each pattern is passed through :func:`expand` (``~`` / env vars) and then
    through ``braceexpand`` so that brace patterns such as
    ``shard-{000..009}.tar`` produce one entry per expansion.
    """
    if isinstance(urls, str):
        urls = [urls]
    urls = [u for url in urls for u in braceexpand.braceexpand(expand(url))]
    return urls


# Sample keys of the AIC training tars that are known to be corrupt.
AIC_TRAIN_CORRUPT_KEYS = {
    '0a047f0124ae48f8eee15a9506ce1449ee1ba669',
    '1a703aa174450c02fbc9cfbf578a5435ef403689',
    '0394e6dc4df78042929b891dbc24f0fd7ffb6b6d',
    '5c032b9626e410441544c7669123ecc4ae077058',
    'ca018a7b4c5f53494006ebeeff9b4c0917a55f07',
    '4a77adb695bef75a5d34c04d589baf646fe2ba35',
    'a0689017b1065c664daef4ae2d14ea03d543217e',
    '39596a45cbd21bed4a5f9c2342505532f8ec5cbb',
    '3d33283b40610d87db660b62982f797d50a7366b',
}
# Webdataset ``__key__`` values dropped by `load_tars_as_webdataset`, for both
# AIC tar variants.
CORRUPT_KEYS = {
    *{f'aic-train/{k}' for k in AIC_TRAIN_CORRUPT_KEYS},
    *{f'aic-train-vitpose/{k}' for k in AIC_TRAIN_CORRUPT_KEYS},
}

# Keypoint index permutation used for flip augmentation. It swaps paired
# indices (e.g. 2-4 <-> 5-7) of the 25 body keypoints...
body_permutation = [0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21]
# ...and of the 19 extra keypoints, whose indices are offset by 25 below.
extra_permutation = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18]
FLIP_KEYPOINT_PERMUTATION = body_permutation + [25 + i for i in extra_permutation]

# Default normalization statistics, in 0-255 pixel units.
DEFAULT_MEAN = 255. * np.array([0.485, 0.456, 0.406])
DEFAULT_STD = 255. * np.array([0.229, 0.224, 0.225])
DEFAULT_IMG_SIZE = 256


class ImageDataset(Dataset):

    def __init__(self,
                 cfg: CfgNode,
                 dataset_file: str,
                 img_dir: str,
                 train: bool = True,
                 prune: Optional[Dict[str, Any]] = None,
                 **kwargs):
        """
        Dataset class used for loading images and corresponding annotations.
        Args:
            cfg (CfgNode): Model config file.
            dataset_file (str): Path to npz file containing dataset info.
            img_dir (str): Path to image folder.
            train (bool): Whether it is for training or not (enables data augmentation).
            prune (dict, optional): Currently unused; accepted for interface compatibility.
            **kwargs: Ignored.
        """
        super(ImageDataset, self).__init__()
        self.train = train
        self.cfg = cfg

        self.img_size = cfg['IMAGE_SIZE']
        self.mean = 255. * np.array(self.cfg['IMAGE_MEAN'])
        self.std = 255. * np.array(self.cfg['IMAGE_STD'])

        self.img_dir = img_dir
        # NOTE: allow_pickle=True can execute arbitrary code from a malicious
        # file; only load trusted dataset files.
        self.data = np.load(dataset_file, allow_pickle=True)

        self.imgname = self.data['imgname']
        self.personid = np.zeros(len(self.imgname), dtype=np.int32)
        self.extra_info = self.data.get('extra_info', [{} for _ in range(len(self.imgname))])

        self.flip_keypoint_permutation = copy.copy(FLIP_KEYPOINT_PERMUTATION)

        # Length of the flat pose vector used when no poses are stored.
        num_pose = 3 * 24

        # Bounding boxes are assumed to be in the center and scale format.
        # Scale is divided by 200 here; __getitem__ multiplies it back by 200.
        self.center = self.data['center']
        self.scale = self.data['scale'].reshape(len(self.center), -1) / 200.0
        if self.scale.shape[1] == 1:
            # A single scalar scale is broadcast to (w, h).
            self.scale = np.tile(self.scale, (1, 2))
        assert self.scale.shape == (len(self.center), 2)

        # Get gt SMPLX parameters, if available
        try:
            self.body_pose = self.data['body_pose'].astype(np.float32)
            self.has_body_pose = self.data['has_body_pose'].astype(np.float32)
        except KeyError:
            self.body_pose = np.zeros((len(self.imgname), num_pose), dtype=np.float32)
            self.has_body_pose = np.zeros(len(self.imgname), dtype=np.float32)
        try:
            self.betas = self.data['betas'].astype(np.float32)
            self.has_betas = self.data['has_betas'].astype(np.float32)
        except KeyError:
            self.betas = np.zeros((len(self.imgname), 10), dtype=np.float32)
            self.has_betas = np.zeros(len(self.imgname), dtype=np.float32)

        # Try to get 2d keypoints, if available
        try:
            body_keypoints_2d = self.data['body_keypoints_2d']
        except KeyError:
            body_keypoints_2d = np.zeros((len(self.center), 25, 3))
        # Try to get extra 2d keypoints, if available
        try:
            extra_keypoints_2d = self.data['extra_keypoints_2d']
        except KeyError:
            extra_keypoints_2d = np.zeros((len(self.center), 19, 3))

        self.keypoints_2d = np.concatenate((body_keypoints_2d, extra_keypoints_2d), axis=1).astype(np.float32)

        # Try to get 3d keypoints, if available
        try:
            body_keypoints_3d = self.data['body_keypoints_3d'].astype(np.float32)
        except KeyError:
            body_keypoints_3d = np.zeros((len(self.center), 25, 4), dtype=np.float32)
        # Try to get extra 3d keypoints, if available
        try:
            extra_keypoints_3d = self.data['extra_keypoints_3d'].astype(np.float32)
        except KeyError:
            extra_keypoints_3d = np.zeros((len(self.center), 19, 4), dtype=np.float32)

        # Zero the last channel of body keypoints 1-14. `astype` above made a
        # copy, so self.data is not modified.
        body_keypoints_3d[:, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], -1] = 0

        self.keypoints_3d = np.concatenate((body_keypoints_3d, extra_keypoints_3d), axis=1).astype(np.float32)

    def __len__(self) -> int:
        return len(self.scale)

    def __getitem__(self, idx: int) -> Dict:
        """
        Returns an example from the dataset.

        The image is cropped around the stored box (expanded to ``BBOX_SHAPE``)
        via ``get_example``. Augmentation is applied when ``self.train`` is set.
        """
        try:
            image_file_rel = self.imgname[idx].decode('utf-8')
        except AttributeError:
            # Already a str rather than bytes.
            image_file_rel = self.imgname[idx]
        image_file = os.path.join(self.img_dir, image_file_rel)
        keypoints_2d = self.keypoints_2d[idx].copy()
        keypoints_3d = self.keypoints_3d[idx].copy()

        center = self.center[idx].copy()
        center_x = center[0]
        center_y = center[1]
        scale = self.scale[idx]
        BBOX_SHAPE = self.cfg['BBOX_SHAPE']
        bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max()
        bbox_expand_factor = bbox_size / ((scale*200).max())
        body_pose = self.body_pose[idx].copy().astype(np.float32)
        betas = self.betas[idx].copy().astype(np.float32)

        has_body_pose = self.has_body_pose[idx].copy()
        has_betas = self.has_betas[idx].copy()

        # The first 3 pose values are the global orientation; the rest is body pose.
        smpl_params = {'global_orient': body_pose[:3],
                       'body_pose': body_pose[3:],
                       'betas': betas,
                       }

        has_smpl_params = {'global_orient': has_body_pose,
                           'body_pose': has_body_pose,
                           'betas': has_betas,
                           }

        smpl_params_is_axis_angle = {'global_orient': True,
                                     'body_pose': True,
                                     'betas': False,
                                     }

        augm_config = self.cfg['augm']
        # Crop image and (possibly) perform data augmentation
        (img_patch, keypoints_2d, keypoints_3d, smpl_params, has_smpl_params,
         img_size, augm_record) = get_example(image_file,
                                              center_x, center_y,
                                              bbox_size, bbox_size,
                                              keypoints_2d, keypoints_3d,
                                              smpl_params, has_smpl_params,
                                              self.flip_keypoint_permutation,
                                              self.img_size, self.img_size,
                                              self.mean, self.std, self.train, augm_config)

        item = {}
        # These are the keypoints in the original image coordinates (before cropping)
        orig_keypoints_2d = self.keypoints_2d[idx].copy()

        item['img_patch'] = img_patch
        item['keypoints_2d'] = keypoints_2d.astype(np.float32)
        item['keypoints_3d'] = keypoints_3d.astype(np.float32)
        item['orig_keypoints_2d'] = orig_keypoints_2d
        item['box_center'] = self.center[idx].copy()
        item['box_size'] = bbox_size
        item['bbox_expand_factor'] = bbox_expand_factor
        item['img_size'] = 1.0 * img_size[::-1].copy()
        item['smpl_params'] = smpl_params
        item['has_smpl_params'] = has_smpl_params
        item['smpl_params_is_axis_angle'] = smpl_params_is_axis_angle
        item['imgname'] = image_file
        item['imgname_rel'] = image_file_rel
        item['personid'] = int(self.personid[idx])
        item['extra_info'] = copy.deepcopy(self.extra_info[idx])
        item['idx'] = idx
        item['_scale'] = scale
        item['augm_record'] = augm_record  # Augmentation record for recovery in self-improvement process.
        return item

    @staticmethod
    def load_tars_as_webdataset(cfg: CfgNode, urls: Union[str, List[str]], train: bool,
                                resampled=False,
                                epoch_size=None,
                                cache_dir=None,
                                **kwargs) -> Dataset:
        """
        Loads the dataset from a webdataset tar file.

        Each tar sample's ``data.pyd`` list is split into one sample per person.
        Optional cleaning/filtering stages are then applied, controlled by cfg
        keys such as SUPPRESS_* / FILTER_* / POSES_BETAS_SIMULTANEOUS /
        BETAS_REG. Finally every sample is converted with
        `process_webdataset_tar_item`. When ``epoch_size`` is given, the
        dataset is resampled and its epoch length is fixed to that value.
        """
        from .smplh_prob_filter import poses_check_probable, load_amass_hist_smooth

        IMG_SIZE = cfg['IMAGE_SIZE']
        BBOX_SHAPE = cfg['BBOX_SHAPE']
        MEAN = 255. * np.array(cfg['IMAGE_MEAN'])
        STD = 255. * np.array(cfg['IMAGE_STD'])

        def split_data(source):
            """Yield one sample per person entry of each item's ``data.pyd`` list.

            Adds ``pid`` and per-side (orig/flip) has_body_pose / has_betas
            flags, which are thresholded from the stored PVE statistics. Also
            attaches a person mask: the ``detection.npz`` mask when present,
            otherwise an all-True mask of the image size.
            """
            for item in source:
                datas = item['data.pyd']
                for pid, data in enumerate(datas):
                    data['pid'] = pid
                    data['orig_has_body_pose'] = data['orig_pve_max'] < 0.1
                    data['orig_has_betas'] = data['orig_pve_max'] < 0.1
                    if 'flip_pve_mean' in data:
                        # TODO: fix this problem
                        data['flip_has_body_pose'] = data['flip_pve_mean'] < 0.05
                        data['flip_has_betas'] = data['flip_has_body_pose']
                    else:
                        data['flip_has_body_pose'] = data['flip_pve_max'] < 0.65
                        data['flip_has_betas'] = data['flip_has_body_pose']

                    if 'detection.npz' in item:
                        det_idx = data['extra_info']['detection_npz_idx']
                        mask = item['detection.npz']['masks'][det_idx]
                    else:
                        mask = np.ones_like(item['jpg'][:, :, 0], dtype=bool)
                    yield {
                        '__key__': item['__key__'],
                        'jpg': item['jpg'],
                        'data.pyd': data,
                        'mask': mask,
                    }

        def suppress_bad_kps(item, thresh=0.0):
            """Zero the confidence of 2D keypoints whose confidence is below ``thresh``."""
            if thresh > 0:
                kp2d = item['data.pyd']['keypoints_2d']
                kp2d_conf = np.where(kp2d[:, 2] < thresh, 0.0, kp2d[:, 2])
                item['data.pyd']['keypoints_2d'] = np.concatenate([kp2d[:, :2], kp2d_conf[:, None]], axis=1)
            return item

        def filter_numkp(item, numkp=4, thresh=0.0):
            """Keep samples with more than ``numkp`` keypoints of confidence > ``thresh``."""
            kp_conf = item['data.pyd']['keypoints_2d'][:, 2]
            return (kp_conf > thresh).sum() > numkp

        def filter_reproj_error(item, thresh=10**4.5):
            """Keep samples with no recorded reprojection loss, or one below ``thresh``."""
            losses = item['data.pyd'].get('extra_info', {}).get('fitting_loss', np.array({})).item()
            reproj_loss = losses.get('reprojection_loss', None)
            return reproj_loss is None or reproj_loss < thresh

        def filter_bbox_size(item, thresh=1):
            """Keep samples whose smaller box side (scale * 200) exceeds ``thresh``."""
            bbox_size_min = item['data.pyd']['scale'].min().item() * 200.
            return bbox_size_min > thresh

        def filter_no_poses(item):
            """Keep samples that carry a valid body-pose target."""
            data = item['data.pyd']
            if 'has_body_pose' in data:
                return data['has_body_pose'] > 0
            # split_data only creates per-side flags, so fall back to them
            # instead of raising KeyError.
            return bool(data['orig_has_body_pose'] > 0 or data['flip_has_body_pose'] > 0)

        def supress_bad_betas(item, thresh=3):
            """Invalidate a side's betas if any |beta| exceeds ``thresh``."""
            for side in ['orig', 'flip']:
                has_betas = item['data.pyd'][f'{side}_has_betas']
                if thresh > 0 and has_betas:
                    betas_abs = np.abs(item['data.pyd'][f'{side}_betas'])
                    if (betas_abs > thresh).any():
                        item['data.pyd'][f'{side}_has_betas'] = False
            return item

        amass_poses_hist100_smooth = load_amass_hist_smooth()

        def supress_bad_poses(item):
            """Invalidate a side's body pose if `poses_check_probable` rejects it."""
            data = item['data.pyd']
            for side in ['orig', 'flip']:
                has_body_pose = data[f'{side}_has_body_pose']
                if has_body_pose:
                    # Check the same array that process_webdataset_tar_item
                    # trains on ('{side}_poses'); fall back to the legacy key.
                    pose_key = f'{side}_poses' if f'{side}_poses' in data else f'{side}_body_pose'
                    body_pose = data[pose_key]
                    pose_is_probable = poses_check_probable(torch.from_numpy(body_pose)[None, 3:], amass_poses_hist100_smooth).item()
                    if not pose_is_probable:
                        data[f'{side}_has_body_pose'] = False
            return item

        def poses_betas_simultaneous(item):
            # We either have both body_pose and betas, or neither
            for side in ['orig', 'flip']:
                has_betas = item['data.pyd'][f'{side}_has_betas']
                has_body_pose = item['data.pyd'][f'{side}_has_body_pose']
                item['data.pyd'][f'{side}_has_betas'] = item['data.pyd'][f'{side}_has_body_pose'] = np.array(float((has_body_pose > 0) and (has_betas > 0)))
            return item

        def set_betas_for_reg(item):
            """Mark betas valid on both sides; missing betas are replaced with zeros."""
            for side in ['orig', 'flip']:
                # Always have betas set to true
                has_betas = item['data.pyd'][f'{side}_has_betas']
                betas = item['data.pyd'][f'{side}_betas']

                if not (has_betas > 0):
                    item['data.pyd'][f'{side}_has_betas'] = np.array(float((True)))
                    item['data.pyd'][f'{side}_betas'] = betas * 0
            return item

        # Load the dataset
        if epoch_size is not None:
            resampled = True
        corrupt_filter = lambda sample: (sample['__key__'] not in CORRUPT_KEYS)
        import webdataset as wds
        dataset = wds.WebDataset(expand_urls(urls),
                                 nodesplitter=wds.split_by_node,
                                 shardshuffle=True,
                                 resampled=resampled,
                                 cache_dir=cache_dir,
                                 ).select(corrupt_filter)
        if train:
            dataset = dataset.shuffle(100)
        dataset = dataset.decode('rgb8').rename(jpg='jpg;jpeg;png')

        # Process the dataset
        dataset = dataset.compose(split_data)

        # Filter/clean the dataset
        SUPPRESS_KP_CONF_THRESH = cfg.get('SUPPRESS_KP_CONF_THRESH', 0.0)
        SUPPRESS_BETAS_THRESH = cfg.get('SUPPRESS_BETAS_THRESH', 0.0)
        SUPPRESS_BAD_POSES = cfg.get('SUPPRESS_BAD_POSES', False)
        POSES_BETAS_SIMULTANEOUS = cfg.get('POSES_BETAS_SIMULTANEOUS', False)
        BETAS_REG = cfg.get('BETAS_REG', False)
        FILTER_NO_POSES = cfg.get('FILTER_NO_POSES', False)
        FILTER_NUM_KP = cfg.get('FILTER_NUM_KP', 4)
        FILTER_NUM_KP_THRESH = cfg.get('FILTER_NUM_KP_THRESH', 0.0)
        FILTER_REPROJ_THRESH = cfg.get('FILTER_REPROJ_THRESH', 0.0)
        FILTER_MIN_BBOX_SIZE = cfg.get('FILTER_MIN_BBOX_SIZE', 0.0)
        if SUPPRESS_KP_CONF_THRESH > 0:
            dataset = dataset.map(lambda x: suppress_bad_kps(x, thresh=SUPPRESS_KP_CONF_THRESH))
        if SUPPRESS_BETAS_THRESH > 0:
            dataset = dataset.map(lambda x: supress_bad_betas(x, thresh=SUPPRESS_BETAS_THRESH))
        if SUPPRESS_BAD_POSES:
            dataset = dataset.map(lambda x: supress_bad_poses(x))
        if POSES_BETAS_SIMULTANEOUS:
            dataset = dataset.map(lambda x: poses_betas_simultaneous(x))
        if FILTER_NO_POSES:
            dataset = dataset.select(lambda x: filter_no_poses(x))
        if FILTER_NUM_KP > 0:
            dataset = dataset.select(lambda x: filter_numkp(x, numkp=FILTER_NUM_KP, thresh=FILTER_NUM_KP_THRESH))
        if FILTER_REPROJ_THRESH > 0:
            dataset = dataset.select(lambda x: filter_reproj_error(x, thresh=FILTER_REPROJ_THRESH))
        if FILTER_MIN_BBOX_SIZE > 0:
            dataset = dataset.select(lambda x: filter_bbox_size(x, thresh=FILTER_MIN_BBOX_SIZE))
        if BETAS_REG:
            dataset = dataset.map(lambda x: set_betas_for_reg(x))  # NOTE: Must be at the end

        use_skimage_antialias = cfg.get('USE_SKIMAGE_ANTIALIAS', False)
        border_mode = {
            'constant': cv2.BORDER_CONSTANT,
            'replicate': cv2.BORDER_REPLICATE,
        }[cfg.get('BORDER_MODE', 'constant')]

        # Process the dataset further
        dataset = dataset.map(lambda x: ImageDataset.process_webdataset_tar_item(x, train,
                                                                                 augm_config=cfg['augm'],
                                                                                 MEAN=MEAN, STD=STD, IMG_SIZE=IMG_SIZE,
                                                                                 BBOX_SHAPE=BBOX_SHAPE,
                                                                                 use_skimage_antialias=use_skimage_antialias,
                                                                                 border_mode=border_mode,
                                                                                 ))
        if epoch_size is not None:
            dataset = dataset.with_epoch(epoch_size)

        return dataset

    @staticmethod
    def process_webdataset_tar_item(item, train,
                                    augm_config=None,
                                    MEAN=DEFAULT_MEAN,
                                    STD=DEFAULT_STD,
                                    IMG_SIZE=DEFAULT_IMG_SIZE,
                                    BBOX_SHAPE=None,
                                    use_skimage_antialias=False,
                                    border_mode=cv2.BORDER_CONSTANT,
                                    ):
        """Convert one per-person webdataset sample into a training item.

        The person mask is appended to the image as a fourth channel (0/255) so
        that it is cropped/augmented together with the image by ``get_example``.
        It is then split back out as a [0, 1] mask. A crop mask that is
        entirely below 0.5 is replaced with all ones. The orig/flip SMPL
        parameters are passed to ``get_example`` as ``(orig, flip)`` tuples.

        Raises:
            ValueError: if the expanded bounding box is smaller than 1 pixel
                (filter such samples with ``FILTER_MIN_BBOX_SIZE``).
        """
        # Read data from item
        key = item['__key__']
        image = item['jpg']
        data = item['data.pyd']
        mask = item['mask']
        pid = data['pid']

        keypoints_2d = data['keypoints_2d']
        keypoints_3d = data['keypoints_3d']
        center = data['center']
        scale = data['scale']
        body_pose = (data['orig_poses'], data['flip_poses'])
        betas = (data['orig_betas'], data['flip_betas'])
        has_body_pose = (data['orig_has_body_pose'], data['flip_has_body_pose'])
        has_betas = (data['orig_has_betas'], data['flip_has_betas'])

        # Process data
        orig_keypoints_2d = keypoints_2d.copy()
        center_x = center[0]
        center_y = center[1]
        bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max()
        if bbox_size < 1:
            # Previously a leftover `breakpoint()`, which stalls or kills
            # data-loader workers; fail loudly instead.
            raise ValueError(
                f'Degenerate bounding box (size {bbox_size}) for sample {key!r}, pid {pid}; '
                'consider setting FILTER_MIN_BBOX_SIZE.'
            )

        # The first 3 pose values are the global orientation; the rest is body pose.
        smpl_params = {'global_orient': (body_pose[0][:3], body_pose[1][:3]),
                       'body_pose': (body_pose[0][3:], body_pose[1][3:]),
                       'betas': betas,
                       }

        has_smpl_params = {'global_orient': has_body_pose,
                           'body_pose': has_body_pose,
                           'betas': has_betas,
                           }

        smpl_params_is_axis_angle = {'global_orient': True,
                                     'body_pose': True,
                                     'betas': False,
                                     }

        # Copy so that downstream mutation cannot leak into the shared config.
        augm_config = copy.deepcopy(augm_config)
        # Crop image and (possibly) perform data augmentation
        img_rgba = np.concatenate([image, mask.astype(np.uint8)[:, :, None]*255], axis=2)
        (img_patch_rgba, keypoints_2d, keypoints_3d, smpl_params, has_smpl_params,
         img_size, trans, augm_record) = get_example(img_rgba,
                                                     center_x, center_y,
                                                     bbox_size, bbox_size,
                                                     keypoints_2d, keypoints_3d,
                                                     smpl_params, has_smpl_params,
                                                     FLIP_KEYPOINT_PERMUTATION,
                                                     IMG_SIZE, IMG_SIZE,
                                                     MEAN, STD, train, augm_config,
                                                     is_bgr=False, return_trans=True,
                                                     use_skimage_antialias=use_skimage_antialias,
                                                     border_mode=border_mode,
                                                     )
        img_patch = img_patch_rgba[:3, :, :]
        mask_patch = (img_patch_rgba[3, :, :] / 255.0).clip(0, 1)
        if (mask_patch < 0.5).all():
            # The person mask vanished from the crop; fall back to an all-ones mask.
            mask_patch = np.ones_like(mask_patch)

        item = {}

        item['img'] = img_patch
        item['mask'] = mask_patch
        item['keypoints_2d'] = keypoints_2d.astype(np.float32)
        item['keypoints_3d'] = keypoints_3d.astype(np.float32)
        item['orig_keypoints_2d'] = orig_keypoints_2d
        item['box_center'] = center.copy()
        item['box_size'] = bbox_size
        item['img_size'] = 1.0 * img_size[::-1].copy()
        item['smpl_params'] = smpl_params
        item['has_smpl_params'] = has_smpl_params
        item['smpl_params_is_axis_angle'] = smpl_params_is_axis_angle
        item['_scale'] = scale
        item['_trans'] = trans
        item['imgname'] = key
        item['pid'] = pid
        item['augm_record'] = augm_record  # Augmentation record for recovery in self-improvement process.
        return item