Spaces:
Runtime error
Runtime error
| import cv2 | |
| import numpy as np | |
| from spiga.data.visualize.inspect_dataset import DatasetInspector, inspect_parser | |
| class HeatmapInspector(DatasetInspector): | |
| def __init__(self, database, anns_type, data_aug=True, image_shape=(256,256)): | |
| super().__init__(database, anns_type, data_aug=data_aug, pose=False, image_shape=image_shape) | |
| self.data_config.aug_names.append('heatmaps2D') | |
| self.data_config.heatmap2D_norm = False | |
| self.data_config.aug_names.append('boundaries') | |
| self.data_config.shuffle = False | |
| self.reload_dataset() | |
| def show_dataset(self, ids_list=None): | |
| if ids_list is None: | |
| ids = self.get_idx(shuffle=self.data_config.shuffle) | |
| else: | |
| ids = ids_list | |
| for img_id in ids: | |
| data_dict = self.dataset[img_id] | |
| crop_imgs, _ = self.plot_features(data_dict) | |
| # Plot landmark crop | |
| cv2.imshow('crop', crop_imgs['lnd']) | |
| # Plot landmarks 2D (group) | |
| crop_allheats = self._plot_heatmaps2D(data_dict) | |
| # Plot boundaries shape | |
| cv2.imshow('boundary', np.max(data_dict['boundary'], axis=0)) | |
| for lnd_idx in range(self.data_config.database.num_landmarks): | |
| # Heatmaps 2D | |
| crop_heats = self._plot_heatmaps2D(data_dict, lnd_idx) | |
| maps = cv2.hconcat([crop_allheats['heatmaps2D'], crop_heats['heatmaps2D']]) | |
| cv2.imshow('heatmaps', maps) | |
| key = cv2.waitKey() | |
| if key == ord('q'): | |
| break | |
| if key == ord('n'): | |
| break | |
| if key == ord('q'): | |
| break | |
| def _plot_heatmaps2D(self, data_dict, heatmap_id=None): | |
| # Variables | |
| heatmaps = {} | |
| image = data_dict['image'] | |
| if heatmap_id is None: | |
| heatmaps2D = data_dict['heatmap2D'] | |
| heatmaps2D = np.max(heatmaps2D, axis=0) | |
| else: | |
| heatmaps2D = data_dict['heatmap2D'][heatmap_id] | |
| # Plot maps | |
| heatmaps['heatmaps2D'] = self._merge_imgmap(image, heatmaps2D) | |
| return heatmaps | |
| def _merge_imgmap(self, image, maps): | |
| crop_maps = cv2.applyColorMap(np.uint8(255 * maps), cv2.COLORMAP_JET) | |
| return cv2.addWeighted(image, 0.7, crop_maps, 0.3, 0) | |
| if __name__ == '__main__': | |
| args = inspect_parser() | |
| data_aug = True | |
| database = args.database | |
| anns_type = args.anns | |
| select_img = args.img | |
| if args.clean: | |
| data_aug = False | |
| if len(args.shape) != 2: | |
| raise ValueError('--shape requires two values: width and height. Ej: --shape 256 256') | |
| else: | |
| img_shape = tuple(args.shape) | |
| visualizer = HeatmapInspector(database, anns_type, data_aug, image_shape=img_shape) | |
| visualizer.show_dataset(ids_list=select_img) | |