Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from abc import ABCMeta, abstractmethod | |
| from typing import Dict, List, Tuple, Union | |
| from mmengine.model import BaseModel | |
| from torch import Tensor | |
| from mmdet.registry import MODELS | |
| from mmdet.structures import OptTrackSampleList, TrackSampleList | |
| from mmdet.utils import OptConfigType, OptMultiConfig | |
class BaseMOTModel(BaseModel, metaclass=ABCMeta):
    """Base class for multiple object tracking.

    Args:
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`TrackDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
        init_cfg (dict or list[dict]): Initialization config dict.
    """

    def __init__(self,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

    def freeze_module(self, module: Union[List[str], Tuple[str], str]) -> None:
        """Freeze one or several sub-modules during training.

        Each named attribute is switched to eval mode and every one of its
        parameters gets ``requires_grad = False``.

        Args:
            module (str or list[str] or tuple[str]): Attribute name(s) of the
                sub-module(s) of ``self`` to freeze.

        Raises:
            TypeError: If ``module`` is not a str, a list or a tuple.
            AttributeError: If ``self`` has no attribute with a given name.
        """
        if isinstance(module, str):
            names = [module]
        elif isinstance(module, (list, tuple)):
            names = module
        else:
            raise TypeError('module must be a str, a list or a tuple.')

        # A dedicated loop variable keeps the ``module`` argument intact.
        for name in names:
            submodule = getattr(self, name)
            submodule.eval()
            for param in submodule.parameters():
                param.requires_grad = False

    # NOTE(review): the ``bool: ...`` docstrings and attribute-like names
    # suggest the ``with_*`` helpers below were meant to be ``@property``
    # accessors. They are kept as plain methods here so existing call sites
    # keep working -- confirm against callers, since ``if obj.with_x:``
    # (without parentheses) is always truthy for a bound method.
    def with_detector(self) -> bool:
        """bool: whether the framework has a detector."""
        return getattr(self, 'detector', None) is not None

    def with_reid(self) -> bool:
        """bool: whether the framework has a reid model."""
        return getattr(self, 'reid', None) is not None

    def with_motion(self) -> bool:
        """bool: whether the framework has a motion model."""
        return getattr(self, 'motion', None) is not None

    def with_track_head(self) -> bool:
        """bool: whether the framework has a track_head."""
        return getattr(self, 'track_head', None) is not None

    def with_tracker(self) -> bool:
        """bool: whether the framework has a tracker."""
        return getattr(self, 'tracker', None) is not None

    def forward(self,
                inputs: Dict[str, Tensor],
                data_samples: OptTrackSampleList = None,
                mode: str = 'predict',
                **kwargs):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`TrackDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            inputs (Dict[str, Tensor]): of shape (N, T, C, H, W)
                encoding input images. Typically these should be mean centered
                and std scaled. The N denotes batch size. The T denotes the
                number of key/reference frames.

                - img (Tensor) : The key images.
                - ref_img (Tensor): The reference images.
            data_samples (list[:obj:`TrackDataSample`], optional): The
                annotation data of every sample. Defaults to None.
            mode (str): Return what kind of value. Defaults to 'predict'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of :obj:`TrackDataSample`.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            RuntimeError: If ``mode`` is not one of 'loss', 'predict' or
                'tensor'.
        """
        if mode == 'loss':
            return self.loss(inputs, data_samples, **kwargs)
        elif mode == 'predict':
            return self.predict(inputs, data_samples, **kwargs)
        elif mode == 'tensor':
            return self._forward(inputs, data_samples, **kwargs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')

    def loss(self, inputs: Dict[str, Tensor], data_samples: TrackSampleList,
             **kwargs) -> Union[dict, tuple]:
        """Calculate losses from a batch of inputs and data samples.

        Subclasses are expected to override this method.

        Args:
            inputs (Dict[str, Tensor]): The batched inputs, as passed to
                :meth:`forward`.
            data_samples (list[:obj:`TrackDataSample`]): The data samples,
                as passed to :meth:`forward`.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # NOTE(review): presumably intended to be an abstract method (the
        # class uses ``ABCMeta``). Raising, like ``_forward`` does, fails
        # loudly instead of silently returning ``None`` while keeping the
        # class instantiation rules unchanged.
        raise NotImplementedError(
            'loss must be implemented by subclasses of BaseMOTModel')

    def predict(self, inputs: Dict[str, Tensor], data_samples: TrackSampleList,
                **kwargs) -> TrackSampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Subclasses are expected to override this method.

        Args:
            inputs (Dict[str, Tensor]): The batched inputs, as passed to
                :meth:`forward`.
            data_samples (list[:obj:`TrackDataSample`]): The data samples,
                as passed to :meth:`forward`.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # See the note in ``loss``: fail loudly rather than return ``None``.
        raise NotImplementedError(
            'predict must be implemented by subclasses of BaseMOTModel')

    def _forward(self,
                 inputs: Dict[str, Tensor],
                 data_samples: OptTrackSampleList = None,
                 **kwargs):
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            inputs (Dict[str, Tensor]): of shape (N, T, C, H, W).
            data_samples (List[:obj:`TrackDataSample`], optional): The
                Data Samples. It usually includes information such as
                `gt_instance`.

        Returns:
            tuple[list]: A tuple of features from ``head`` forward.

        Raises:
            NotImplementedError: Always, in this base class; 'tensor' mode
                is not supported here.
        """
        raise NotImplementedError(
            "_forward function (namely 'tensor' mode) is not supported now")