# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

from torch import Tensor, nn

from mmseg.registry import MODELS
from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
                         OptSampleList, SampleList, add_prefix)
from .encoder_decoder import EncoderDecoder

@MODELS.register_module()
class CascadeEncoderDecoder(EncoderDecoder):
| """Cascade Encoder Decoder segmentors. | |
| CascadeEncoderDecoder almost the same as EncoderDecoder, while decoders of | |
| CascadeEncoderDecoder are cascaded. The output of previous decoder_head | |
| will be the input of next decoder_head. | |
| Args: | |
| num_stages (int): How many stages will be cascaded. | |
| backbone (ConfigType): The config for the backnone of segmentor. | |
| decode_head (ConfigType): The config for the decode head of segmentor. | |
| neck (OptConfigType): The config for the neck of segmentor. | |
| Defaults to None. | |
| auxiliary_head (OptConfigType): The config for the auxiliary head of | |
| segmentor. Defaults to None. | |
| train_cfg (OptConfigType): The config for training. Defaults to None. | |
| test_cfg (OptConfigType): The config for testing. Defaults to None. | |
| data_preprocessor (dict, optional): The pre-process config of | |
| :class:`BaseDataPreprocessor`. | |
| pretrained (str, optional): The path for pretrained model. | |
| Defaults to None. | |
| init_cfg (dict, optional): The weight initialized config for | |
| :class:`BaseModule`. | |
| """ | |

    def __init__(self,
                 num_stages: int,
                 backbone: ConfigType,
                 decode_head: ConfigType,
                 neck: OptConfigType = None,
                 auxiliary_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 pretrained: Optional[str] = None,
                 init_cfg: OptMultiConfig = None):
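        # ``num_stages`` must be set before calling ``super().__init__()``,
        # because the parent constructor invokes ``_init_decode_head``, which
        # reads ``self.num_stages`` (see the override below).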
        self.num_stages = num_stages
        super().__init__(
            backbone=backbone,
            decode_head=decode_head,
            neck=neck,
            auxiliary_head=auxiliary_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            pretrained=pretrained,
            init_cfg=init_cfg)

    def _init_decode_head(self, decode_head: ConfigType) -> None:
        """Initialize ``decode_head``"""
        assert isinstance(decode_head, list)
        assert len(decode_head) == self.num_stages
        self.decode_head = nn.ModuleList()
        for i in range(self.num_stages):
            self.decode_head.append(MODELS.build(decode_head[i]))
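        # The attributes used for resizing and post-processing are taken from
        # the last head, since it produces the final prediction.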
        self.align_corners = self.decode_head[-1].align_corners
        self.num_classes = self.decode_head[-1].num_classes
        self.out_channels = self.decode_head[-1].out_channels

    def encode_decode(self, inputs: Tensor,
                      batch_img_metas: List[dict]) -> Tensor:
        """Encode images with backbone and decode into a semantic segmentation
        map of the same size as input."""
        x = self.extract_feat(inputs)
        out = self.decode_head[0].forward(x)
        for i in range(1, self.num_stages - 1):
            out = self.decode_head[i].forward(x, out)
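        # Intermediate heads run in plain tensor mode; only the last head runs
        # ``predict``, which also resizes the logits to the input image size
        # using ``batch_img_metas`` and ``test_cfg``.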
        seg_logits_list = self.decode_head[-1].predict(x, out, batch_img_metas,
                                                       self.test_cfg)

        return seg_logits_list

    def _decode_head_forward_train(self, inputs: Tensor,
                                   data_samples: SampleList) -> dict:
        """Run forward function and calculate loss for decode head in
        training."""
        losses = dict()

        loss_decode = self.decode_head[0].loss(inputs, data_samples,
                                               self.train_cfg)
        losses.update(add_prefix(loss_decode, 'decode_0'))

        # get batch_img_metas
        batch_size = len(data_samples)
        batch_img_metas = []
        for batch_index in range(batch_size):
            metainfo = data_samples[batch_index].metainfo
            batch_img_metas.append(metainfo)

        for i in range(1, self.num_stages):
            # Re-run the previous head's forward pass to obtain its output for
            # this stage; this may be unnecessary for most methods.
            if i == 1:
                prev_outputs = self.decode_head[0].forward(inputs)
            else:
                prev_outputs = self.decode_head[i - 1].forward(
                    inputs, prev_outputs)
            loss_decode = self.decode_head[i].loss(inputs, prev_outputs,
                                                   data_samples,
                                                   self.train_cfg)
            losses.update(add_prefix(loss_decode, f'decode_{i}'))
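
        # The returned dict is keyed per stage via ``add_prefix``, e.g.
        # 'decode_0.loss_ce', 'decode_1.loss_ce', ...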
        return losses

    def _forward(self,
                 inputs: Tensor,
                 data_samples: OptSampleList = None) -> Tensor:
        """Network forward process.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`SegDataSample`]): The seg data samples.
                It usually includes information such as `metainfo` and
                `gt_semantic_seg`.

        Returns:
            Tensor: Forward output of model without any post-processing.
        """
        x = self.extract_feat(inputs)

        out = self.decode_head[0].forward(x)
        for i in range(1, self.num_stages):
            # TODO support PointRend tensor mode
            out = self.decode_head[i].forward(x, out)

        return out
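
# A minimal usage sketch (assumed, illustrative values; ``cfg`` stands for a
# model config dict such as the class-level example above):
#
#   import torch
#   from mmseg.registry import MODELS
#
#   model = MODELS.build(cfg)
#   dummy = torch.randn(1, 3, 512, 512)
#   out = model._forward(dummy)  # raw logits from the last cascade stage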