Spaces:
Runtime error
Runtime error
| # Copyright (c) Open-CD. All rights reserved. | |
| import os.path as osp | |
| import warnings | |
| from typing import Optional, Sequence | |
| import mmcv | |
| import mmengine.fileio as fileio | |
| import numpy as np | |
| from mmengine.runner import Runner | |
| from mmseg.engine import SegVisualizationHook | |
| from mmseg.structures import SegDataSample | |
| from opencd.registry import HOOKS | |
| from opencd.visualization import CDLocalVisualizer | |
@HOOKS.register_module()
class CDVisualizationHook(SegVisualizationHook):
    """Change Detection Visualization Hook. Used to visualize validation and
    testing process prediction results.

    Predictions are drawn on a black canvas. When ``draw_on_from_to_img`` is
    True, the images listed in each sample's ``img_path`` are also loaded and
    passed to the visualizer.

    Args:
        img_shape (tuple, optional): Canvas shape ``(H, W, C)``. If it is
            given and ``draw_on_from_to_img`` is False, the original images
            will not be read. If None, the canvas shape is taken from each
            sample's first image. Defaults to None.
        draw_on_from_to_img (bool): Whether to draw semantic prediction
            results on the original images. If it is False, results are drawn
            on a black board only. Defaults to False.
        draw (bool): Whether to draw prediction results at all. If it is
            False, the hook has no effect. Defaults to False.
        interval (int): Visualize every ``interval`` inner (val/test)
            iterations. Defaults to 50.
        show (bool): Whether to display the drawn results. When True, the
            visualizer's backends are cleared, so nothing is stored.
            Defaults to False.
        wait_time (float): Forwarded as ``wait_time`` to the visualizer's
            ``add_datasample``. Defaults to 0.
        backend_args (dict, optional): Arguments forwarded to
            ``mmengine.fileio.get`` when reading images. Defaults to None.

    Raises:
        ValueError: If ``img_shape`` is given but does not have 3 elements.
    """

    def __init__(self,
                 img_shape: Optional[tuple] = None,
                 draw_on_from_to_img: bool = False,
                 draw: bool = False,
                 interval: int = 50,
                 show: bool = False,
                 wait_time: float = 0.,
                 backend_args: Optional[dict] = None):
        # NOTE(review): SegVisualizationHook.__init__ is intentionally not
        # called; this hook sets up its own state and uses CDLocalVisualizer.
        # Validate once here (not with ``assert`` inside the loop, which is
        # stripped under ``python -O``).
        if img_shape is not None and len(img_shape) != 3:
            raise ValueError('`img_shape` should be (H, W, C)')
        self.img_shape = img_shape
        self.draw_on_from_to_img = draw_on_from_to_img
        if self.draw_on_from_to_img:
            warnings.warn('`draw_on_from_to_img` works only in '
                          'semantic change detection.')
        self._visualizer: CDLocalVisualizer = \
            CDLocalVisualizer.get_current_instance()
        self.interval = interval
        self.show = show
        if self.show:
            # No need to think about vis backends.
            self._visualizer._vis_backends = {}
            warnings.warn('The show is True, it means that only '
                          'the prediction results are visualized '
                          'without storing data, so vis_backends '
                          'needs to be excluded.')
        self.wait_time = wait_time
        self.backend_args = backend_args.copy() if backend_args else None
        self.draw = draw
        if not self.draw:
            warnings.warn('The draw is False, it means that the '
                          'hook for visualization will not take '
                          'effect. The results will NOT be '
                          'visualized or stored.')

    def _read_rgb(self, img_path: str) -> np.ndarray:
        """Read and decode the image at ``img_path`` in RGB channel order."""
        img_bytes = fileio.get(img_path, backend_args=self.backend_args)
        return mmcv.imfrombytes(img_bytes, channel_order='rgb')

    def _after_iter(self,
                    runner: Runner,
                    batch_idx: int,
                    data_batch: dict,
                    outputs: Sequence[SegDataSample],
                    mode: str = 'val') -> None:
        """Run after every ``self.interval`` validation iterations.

        Args:
            runner (:obj:`Runner`): The runner of the validation process.
            batch_idx (int): The index of the current batch in the val loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`SegDataSample`]): Outputs from model.
            mode (str): Current mode of runner. Defaults to 'val'.
        """
        if not self.draw or mode == 'train':
            return
        if not self.every_n_inner_iters(batch_idx, self.interval):
            return

        for output in outputs:
            # ``output.img_path`` is indexed and iterated as a list of paths
            # (presumably the from/to image pair - confirm with the dataset).
            img_path = output.img_path[0]
            # splitext only strips the final extension, so names such as
            # "x.1.png" and "x.2.png" stay distinct.
            window_name = osp.splitext(osp.basename(img_path))[0]

            img_from_to = []
            if self.draw_on_from_to_img:
                # for semantic change detection
                img_from_to = [self._read_rgb(p) for p in output.img_path]

            # Resolve the canvas shape per sample. Never cache a shape read
            # from one image on ``self``: later samples of a different size
            # would otherwise be drawn on a wrongly sized canvas.
            if self.img_shape is not None:
                img_shape = self.img_shape
            elif img_from_to:
                # img_from_to[0] was decoded from ``img_path``; reuse it.
                img_shape = img_from_to[0].shape
            else:
                img_shape = self._read_rgb(img_path).shape

            img = np.zeros(img_shape)
            self._visualizer.add_datasample(
                window_name,
                img,
                img_from_to,
                data_sample=output,
                show=self.show,
                wait_time=self.wait_time,
                step=runner.iter,
                draw_gt=False)