# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

from .typing_utils import SampleList


def add_prefix(inputs, prefix):
    """Add prefix for dict.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: The dict with keys updated with ``prefix``.
    """
    outputs = dict()
    for name, value in inputs.items():
        outputs[f'{prefix}.{name}'] = value
    return outputs
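
# Example usage (a minimal sketch, not part of the original module): merging
# loss dicts from two heads without key collisions.
#
#   losses = dict()
#   losses.update(add_prefix({'loss_ce': 0.3}, 'decode'))
#   losses.update(add_prefix({'loss_ce': 0.1}, 'aux'))
#   # -> {'decode.loss_ce': 0.3, 'aux.loss_ce': 0.1}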


def stack_batch(inputs: List[torch.Tensor],
                data_samples: Optional[SampleList] = None,
                size: Optional[tuple] = None,
                size_divisor: Optional[int] = None,
                pad_val: Union[int, float] = 0,
                seg_pad_val: Union[int, float] = 255
                ) -> Tuple[torch.Tensor, List]:
    """Stack multiple inputs to form a batch and pad the images and
    gt_sem_segs to the max shape using the right-bottom padding mode.

    Args:
        inputs (List[Tensor]): The input multiple tensors. Each is a
            CHW 3D-tensor.
        data_samples (list[:obj:`SegDataSample`], optional): The list of
            data samples. It usually includes information such as
            `gt_sem_seg`.
        size (tuple, optional): Fixed padding size.
        size_divisor (int, optional): The divisor of padded size.
        pad_val (int, float): The padding value of the image. Defaults to 0.
        seg_pad_val (int, float): The padding value of the segmentation map.
            Defaults to 255.

    Returns:
        Tensor: The stacked 4D-tensor.
        List[:obj:`SegDataSample`]: The data samples after padding of the
            ``gt_sem_seg``.
    """
    assert isinstance(inputs, list), \
        f'Expected input type to be list, but got {type(inputs)}'
    assert len({tensor.ndim for tensor in inputs}) == 1, \
        f'Expected the dimensions of all inputs to be the same, ' \
        f'but got {[tensor.ndim for tensor in inputs]}'
    assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \
        f'but got {inputs[0].ndim}'
    assert len({tensor.shape[0] for tensor in inputs}) == 1, \
        f'Expected the channels of all inputs to be the same, ' \
        f'but got {[tensor.shape[0] for tensor in inputs]}'

    # only one of size and size_divisor should be valid
    assert (size is not None) ^ (size_divisor is not None), \
        'only one of size and size_divisor should be valid'

    padded_inputs = []
    padded_samples = []
    inputs_sizes = [(img.shape[-2], img.shape[-1]) for img in inputs]
    max_size = np.stack(inputs_sizes).max(0)
    if size_divisor is not None and size_divisor > 1:
        # the last two dims are H,W, both subject to divisibility requirement
        max_size = (max_size +
                    (size_divisor - 1)) // size_divisor * size_divisor
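        # e.g. with size_divisor=32, an input of H=500, W=375 is padded
        # toward 512x384: (500 + 31) // 32 * 32 = 512 and
        # (375 + 31) // 32 * 32 = 384.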

    for i in range(len(inputs)):
        tensor = inputs[i]
        if size is not None:
            width = max(size[-1] - tensor.shape[-1], 0)
            height = max(size[-2] - tensor.shape[-2], 0)
            # (padding_left, padding_right, padding_top, padding_bottom)
            padding_size = (0, width, 0, height)
        elif size_divisor is not None:
            width = max(max_size[-1] - tensor.shape[-1], 0)
            height = max(max_size[-2] - tensor.shape[-2], 0)
            padding_size = (0, width, 0, height)
        else:
            padding_size = [0, 0, 0, 0]

        # pad img
        pad_img = F.pad(tensor, padding_size, value=pad_val)
        padded_inputs.append(pad_img)
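        # The dense targets below are padded with ``seg_pad_val`` so the
        # padded border can be ignored by the loss (255 is the conventional
        # ignore index for semantic segmentation labels).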
        # pad gt_sem_seg
        if data_samples is not None:
            data_sample = data_samples[i]
            pad_shape = None
            if 'gt_sem_seg' in data_sample:
                gt_sem_seg = data_sample.gt_sem_seg.data
                del data_sample.gt_sem_seg.data
                data_sample.gt_sem_seg.data = F.pad(
                    gt_sem_seg, padding_size, value=seg_pad_val)
                pad_shape = data_sample.gt_sem_seg.shape
            if 'gt_edge_map' in data_sample:
                gt_edge_map = data_sample.gt_edge_map.data
                del data_sample.gt_edge_map.data
                data_sample.gt_edge_map.data = F.pad(
                    gt_edge_map, padding_size, value=seg_pad_val)
                pad_shape = data_sample.gt_edge_map.shape
            if 'gt_depth_map' in data_sample:
                gt_depth_map = data_sample.gt_depth_map.data
                del data_sample.gt_depth_map.data
                data_sample.gt_depth_map.data = F.pad(
                    gt_depth_map, padding_size, value=seg_pad_val)
                pad_shape = data_sample.gt_depth_map.shape
            data_sample.set_metainfo({
                'img_shape': tensor.shape[-2:],
                'pad_shape': pad_shape,
                'padding_size': padding_size
            })
            padded_samples.append(data_sample)
        else:
            padded_samples.append(
                dict(
                    img_padding_size=padding_size,
                    pad_shape=pad_img.shape[-2:]))

    return torch.stack(padded_inputs, dim=0), padded_samples
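
# Example usage (a minimal sketch; the shapes below are illustrative and not
# part of the original module):
#
#   imgs = [torch.rand(3, 500, 375), torch.rand(3, 480, 512)]
#   batch, samples = stack_batch(imgs, size_divisor=32)
#   # batch.shape == torch.Size([2, 3, 512, 512]); each image is padded only
#   # on its right/bottom, so original pixel coordinates are preserved.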