Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import List, Sequence, Union | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from mmengine.utils import is_str | |
def _tensor_split_fallback(input: torch.Tensor,
                           indices: Sequence[int]) -> List[torch.Tensor]:
    """Split ``input`` along dim 0 at the given indices.

    A simple implementation of :func:`torch.tensor_split`, used only when the
    installed PyTorch does not provide it. It covers just the "indices" form
    along the first dimension: the integer ``sections`` form and the ``dim``
    argument of :func:`torch.tensor_split` are not implemented.

    Args:
        input (torch.Tensor): Tensor to split. The name mirrors the first
            parameter of :func:`torch.tensor_split`.
        indices (Sequence[int]): Positions along dim 0 at which to split.

    Returns:
        List[torch.Tensor]: ``len(indices) + 1`` slices of ``input``.
    """
    # Copy into a list so tuples (accepted by torch.tensor_split) also work;
    # ``[0] + indices`` would raise TypeError for a tuple.
    indices = list(indices)
    starts = [0] + indices
    ends = indices + [input.size(0)]
    return [input[start:end] for start, end in zip(starts, ends)]


# Prefer the native implementation when this PyTorch version provides it.
if hasattr(torch, 'tensor_split'):
    tensor_split = torch.tensor_split
else:
    tensor_split = _tensor_split_fallback
# Inputs accepted by `format_score`; strings are rejected at runtime even
# though they are Sequences.
SCORE_TYPE = Union[torch.Tensor, np.ndarray, Sequence]
# Inputs accepted by `format_label`: every score input type plus a bare int.
# typing flattens the nested Union, so the result is the same 4-member Union.
LABEL_TYPE = Union[SCORE_TYPE, int]
def format_label(value: LABEL_TYPE) -> torch.Tensor:
    """Convert various python types to label-format tensor.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence` (strings excluded), :class:`int` and numpy integer
    scalars such as ``np.int64``.

    Args:
        value (torch.Tensor | numpy.ndarray | Sequence | int): Label value.

    Returns:
        :obj:`torch.Tensor`: The formatted 1-D label tensor with dtype
        ``torch.long``. A single number becomes a 1-element tensor.

    Raises:
        TypeError: If ``value`` is not one of the supported types.
        AssertionError: If the converted tensor is not 1-D.
    """
    # A 0-dim tensor/array holds a single label; unwrap it so the integer
    # branch below turns it into a 1-element tensor.
    if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0:
        value = int(value.item())

    if isinstance(value, torch.Tensor):
        # Cast like the other branches so callers such as F.one_hot always
        # receive an index tensor. This is a no-op for long tensors.
        value = value.to(torch.long)
    elif isinstance(value, np.ndarray):
        # NOTE: torch.from_numpy shares memory with the array, and .to() is a
        # no-op when the dtype already matches, so an int64 array is aliased.
        value = torch.from_numpy(value).to(torch.long)
    elif isinstance(value, Sequence) and not is_str(value):
        value = torch.tensor(value).to(torch.long)
    elif isinstance(value, (int, np.integer)):
        # numpy integer scalars (e.g. np.int64 from indexing an array) are
        # not subclasses of int, so they need to be listed explicitly.
        value = torch.tensor([int(value)], dtype=torch.long)
    else:
        raise TypeError(f'Type {type(value)} is not an available label type.')
    assert value.ndim == 1, \
        f'The dims of value should be 1, but got {value.ndim}.'
    return value
def format_score(value: SCORE_TYPE) -> torch.Tensor:
    """Convert various python types to score-format tensor.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence` (strings excluded).

    Args:
        value (torch.Tensor | numpy.ndarray | Sequence): Score values.

    Returns:
        :obj:`torch.Tensor`: The formatted 1-D score tensor. Arrays and
        sequences are converted to ``torch.float32``; a tensor input is
        returned unchanged, keeping its original dtype.

    Raises:
        TypeError: If ``value`` is not one of the supported types.
        AssertionError: If the resulting tensor is not 1-D.
    """
    if isinstance(value, np.ndarray):
        # NOTE: torch.from_numpy shares memory with the array, and .float()
        # is a no-op for float32, so a float32 array is aliased.
        value = torch.from_numpy(value).float()
    elif isinstance(value, Sequence) and not is_str(value):
        value = torch.tensor(value).float()
    elif not isinstance(value, torch.Tensor):
        raise TypeError(f'Type {type(value)} is not an available score type.')
    assert value.ndim == 1, \
        f'The dims of value should be 1, but got {value.ndim}.'
    return value
def cat_batch_labels(elements: List[torch.Tensor]):
    """Concatenate a batch of label tensors into a single tensor.

    Args:
        elements (List[tensor]): A batch of labels, one tensor per sample.

    Returns:
        Tuple[torch.Tensor, List[int]]: The concatenated label tensor, and
        the positions at which it must be split to recover each sample. These
        are the running label count after every sample except the last.
    """
    tensors = list(elements)
    # Cumulative number of labels seen after each sample.
    boundaries = []
    total = 0
    for tensor in tensors:
        total += tensor.size(0)
        boundaries.append(total)
    # The last boundary is the total length, which is not a split point.
    return torch.cat(tensors), boundaries[:-1]
def batch_label_to_onehot(batch_label, split_indices, num_classes):
    """Build a per-sample multi-hot matrix from a concatenated label tensor.

    Args:
        batch_label (torch.Tensor): Labels of several samples concatenated
            into one tensor.
        split_indices (List[int]): Positions in ``batch_label`` where each
            new sample starts.
        num_classes (int): The number of classes.

    Returns:
        torch.Tensor: The onehot format label tensor, with
        ``len(split_indices) + 1`` rows (one per sample) and ``num_classes``
        columns.

    Note:
        Each row is the sum of the one-hot vectors of that sample's labels,
        so a class listed more than once in a sample yields a value > 1.

    Examples:
        >>> import torch
        >>> from mmpretrain.structures import batch_label_to_onehot
        >>> # Assume a concated label from 3 samples.
        >>> # label 1: [0, 1], label 2: [0, 2, 4], label 3: [3, 1]
        >>> batch_label = torch.tensor([0, 1, 0, 2, 4, 3, 1])
        >>> split_indices = [2, 5]
        >>> batch_label_to_onehot(batch_label, split_indices, num_classes=5)
        tensor([[1, 1, 0, 0, 0],
                [1, 0, 1, 0, 1],
                [0, 1, 0, 1, 0]])
    """
    # One one-hot row per individual label entry.
    per_entry = F.one_hot(batch_label, num_classes)
    # Group the rows by sample, then collapse each group into one vector.
    per_sample = [
        group.sum(dim=0) for group in tensor_split(per_entry, split_indices)
    ]
    return torch.stack(per_sample)
def label_to_onehot(label: LABEL_TYPE, num_classes: int):
    """Turn a single sample's label into a multi-hot vector.

    The label is first normalized with ``format_label``. Each label index is
    then one-hot encoded, and the rows are summed.

    Args:
        label (LABEL_TYPE): Label value.
        num_classes (int): The number of classes.

    Returns:
        torch.Tensor: A 1-D tensor of length ``num_classes``. A class listed
        more than once in ``label`` yields a value > 1.

    Examples:
        >>> import torch
        >>> from mmpretrain.structures import label_to_onehot
        >>> # Single-label
        >>> label_to_onehot(1, num_classes=5)
        tensor([0, 1, 0, 0, 0])
        >>> # Multi-label
        >>> label_to_onehot([0, 2, 3], num_classes=5)
        tensor([1, 0, 1, 1, 0])
    """
    indices = format_label(label)
    # Sum the per-index one-hot rows into a single vector.
    return F.one_hot(indices, num_classes).sum(dim=0)