Spaces:
Runtime error
Runtime error
from typing import List, Optional

import numpy as np

from detrsmpl.utils.path_utils import (
    Existence,
    check_path_existence,
    check_path_suffix,
)
from .human_data import HumanData
class HumanDataCacheReader():
    """Read a sliced HumanData cache written by ``HumanDataCacheWriter``.

    The npz file holds the metadata entries ``slice_size``, ``data_len``,
    ``keypoints_info`` and ``non_sliced_data`` plus one entry per slice,
    keyed by the slice index as a string (``'0'``, ``'1'``, ...).

    Security note: every entry is loaded with ``allow_pickle=True``, so
    only open cache files that come from a trusted source.

    Args:
        npz_path (str):
            Path to the cache npz file.
    """

    def __init__(self, npz_path: str):
        self.npz_path = npz_path
        # Read only the small metadata entries and close the archive
        # right away so no file handle outlives the constructor.
        with np.load(npz_path, allow_pickle=True) as npz_file:
            self.slice_size = npz_file['slice_size'].item()
            self.data_len = npz_file['data_len'].item()
            self.keypoints_info = npz_file['keypoints_info'].item()
        # Both are populated lazily on first use. NOTE(review): deferring
        # the long-lived handle presumably keeps it out of processes forked
        # after construction (e.g. dataloader workers) - confirm.
        self.non_sliced_data = None
        self.npz_file = None

    def __del__(self):
        # getattr: __init__ may have raised before npz_file was assigned,
        # and __del__ still runs on such a partially built instance.
        npz_file = getattr(self, 'npz_file', None)
        if npz_file is not None:
            npz_file.close()

    def get_item(self, index, required_keys: Optional[List[str]] = None):
        """Load the slice containing ``index`` as a HumanData instance.

        Note that the returned HumanData holds the whole slice stored under
        key ``str(index // slice_size)``, not only the sample at ``index``.

        Args:
            index (int):
                Non-negative sample index used to select the slice.
            required_keys (List[str], optional):
                Keys of non-sliced data to merge into the result. Dict
                values are merged into an existing dict under the same key;
                anything else replaces it. Defaults to None (no keys).

        Returns:
            HumanData:
                The slice data, marked as keypoint-compressed, with default
                values filled in by ``__set_default_values__``.

        Raises:
            KeyError:
                If a required key is absent from the non-sliced data, or the
                computed slice entry does not exist in the archive.
        """
        if required_keys is None:
            required_keys = []
        if self.npz_file is None:
            self.npz_file = np.load(self.npz_path, allow_pickle=True)
        # Exact integer floor division; int() first so numpy/tensor scalars
        # produce a plain '3' key instead of their repr.
        cache_key = str(int(index) // self.slice_size)
        base_data = self.npz_file[cache_key].item()
        base_data.update(self.keypoints_info)
        for key in required_keys:
            non_sliced_value = self.get_non_sliced_data(key)
            if isinstance(non_sliced_value, dict) and \
                    isinstance(base_data.get(key), dict):
                base_data[key].update(non_sliced_value)
            else:
                base_data[key] = non_sliced_value
        ret_human_data = HumanData.new(source_dict=base_data)
        # data in cache is compressed
        ret_human_data.__keypoints_compressed__ = True
        # set missing values and attributes by default method
        ret_human_data.__set_default_values__()
        return ret_human_data

    def get_non_sliced_data(self, key: str):
        """Return one value from the cached ``non_sliced_data`` dict.

        The whole dict is read from the archive once and then memoized on
        ``self.non_sliced_data``.

        Args:
            key (str):
                Key to look up.

        Returns:
            The stored value for ``key``.

        Raises:
            KeyError:
                If ``key`` is not present in the non-sliced data.
        """
        if self.non_sliced_data is None:
            if self.npz_file is None:
                # Short-lived handle, closed as soon as the dict is read.
                with np.load(self.npz_path, allow_pickle=True) as npz_file:
                    self.non_sliced_data = \
                        npz_file['non_sliced_data'].item()
            else:
                self.non_sliced_data = \
                    self.npz_file['non_sliced_data'].item()
        return self.non_sliced_data[key]
class HumanDataCacheWriter():
    """Collect sliced HumanData entries and dump them to one npz cache.

    Args:
        slice_size (int):
            Number of samples per slice; written as ``slice_size``.
        data_len (int):
            Total number of samples; written as ``data_len``.
        keypoints_info (dict):
            Written as ``keypoints_info``.
        non_sliced_data (dict):
            Written as ``non_sliced_data``.
        key_strict (bool, optional):
            Written as ``key_strict``. Defaults to True.
    """

    # Top-level npz entries written by dump(). Slice entries share the
    # same namespace, so they must never use one of these names.
    _RESERVED_KEYS = (
        'slice_size',
        'data_len',
        'keypoints_info',
        'non_sliced_data',
        'key_strict',
    )

    def __init__(self,
                 slice_size: int,
                 data_len: int,
                 keypoints_info: dict,
                 non_sliced_data: dict,
                 key_strict: bool = True):
        self.slice_size = slice_size
        self.data_len = data_len
        self.keypoints_info = keypoints_info
        self.non_sliced_data = non_sliced_data
        self.sliced_data = {}
        self.key_strict = key_strict

    def _check_sliced_keys(self, keys):
        """Raise ValueError if any key would overwrite a metadata entry."""
        clashes = sorted(set(keys).intersection(self._RESERVED_KEYS))
        if clashes:
            raise ValueError(
                f'Sliced keys collide with reserved cache keys: {clashes}')

    def update_sliced_dict(self, sliced_dict):
        """Merge slice entries into the data to be dumped.

        Args:
            sliced_dict (dict):
                Mapping from npz entry name to slice data.
                NOTE(review): HumanDataCacheReader looks slices up by
                ``str(slice_index)``, so keys are presumably '0', '1', ...

        Raises:
            ValueError:
                If a key collides with a metadata entry written by dump,
                which would otherwise silently overwrite it.
        """
        self._check_sliced_keys(sliced_dict)
        self.sliced_data.update(sliced_dict)

    def dump(self, npz_path: str, overwrite: bool = True):
        """Dump keys and items to an npz file.

        Args:
            npz_path (str):
                Path to a dumped npz file.
            overwrite (bool, optional):
                Whether to overwrite if there is already a file.
                Defaults to True.

        Raises:
            ValueError:
                npz_path does not end with '.npz', or a sliced key collides
                with a reserved metadata key.
            FileExistsError:
                When overwrite is False and file exists.
        """
        if not check_path_suffix(npz_path, ['.npz']):
            raise ValueError('Not an npz file.')
        if not overwrite:
            if check_path_existence(npz_path, 'file') == Existence.FileExist:
                raise FileExistsError(
                    f'{npz_path} already exists and overwrite is False.')
        # sliced_data is a public attribute and may have been filled
        # without update_sliced_dict, so re-check before merging.
        self._check_sliced_keys(self.sliced_data)
        dict_to_dump = {
            'slice_size': self.slice_size,
            'data_len': self.data_len,
            'keypoints_info': self.keypoints_info,
            'non_sliced_data': self.non_sliced_data,
            'key_strict': self.key_strict,
        }
        dict_to_dump.update(self.sliced_data)
        np.savez_compressed(npz_path, **dict_to_dump)