# -------------------------------------------------------- # InternImage # Copyright (c) 2022 OpenGVLab # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- import argparse import copy import os import os.path as osp import time import warnings import mmcv import mmcv_custom # noqa: F401,F403 import mmseg_custom # noqa: F401,F403 import torch import torch.distributed as dist from mmcv.cnn.utils import revert_sync_batchnorm from mmcv.runner import get_dist_info, init_dist from mmcv.utils import Config, DictAction, get_git_hash from mmseg import __version__ from mmseg.apis import init_random_seed, set_random_seed, train_segmentor from mmseg.datasets import build_dataset from mmseg.models import build_segmentor from mmseg.utils import collect_env, get_device, get_root_logger, setup_multi_processes def parse_args(): parser = argparse.ArgumentParser(description="Train a segmentor") parser.add_argument("config", help="train config file path") parser.add_argument("--work-dir", help="the dir to save logs and models") parser.add_argument("--load-from", help="the checkpoint file to load weights from") parser.add_argument("--resume-from", help="the checkpoint file to resume from") parser.add_argument( "--no-validate", action="store_true", help="whether not to evaluate the checkpoint during training", ) group_gpus = parser.add_mutually_exclusive_group() group_gpus.add_argument( "--gpus", type=int, help="(Deprecated, please use --gpu-id) number of gpus to use " "(only applicable to non-distributed training)", ) group_gpus.add_argument( "--gpu-ids", type=int, nargs="+", help="(Deprecated, please use --gpu-id) ids of gpus to use " "(only applicable to non-distributed training)", ) group_gpus.add_argument( "--gpu-id", type=int, default=0, help="id of gpu to use (only applicable to non-distributed training)", ) parser.add_argument("--seed", type=int, default=None, help="random seed") parser.add_argument( "--diff_seed", action="store_true", help="Whether or not set different seeds for different ranks", ) parser.add_argument( "--deterministic", action="store_true", help="whether to set deterministic options for CUDNN backend.", ) parser.add_argument( "--options", nargs="+", action=DictAction, help="--options is deprecated in favor of --cfg_options' and it will " "not be supported in version v0.22.0. Override some settings in the " "used config, the key-value pair in xxx=yyy format will be merged " "into config file. If the value to be overwritten is a list, it " 'should be like key="[a,b]" or key=a,b It also allows nested ' 'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation ' "marks are necessary and that no white space is allowed.", ) parser.add_argument( "--cfg-options", nargs="+", action=DictAction, help="override some settings in the used config, the key-value pair " "in xxx=yyy format will be merged into config file. If the value to " 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' "Note that the quotation marks are necessary and that no white space " "is allowed.", ) parser.add_argument( "--launcher", choices=["none", "pytorch", "slurm", "mpi"], default="none", help="job launcher", ) parser.add_argument("--local_rank", type=int, default=0) parser.add_argument( "--auto-resume", action="store_true", help="resume from the latest checkpoint automatically.", ) args = parser.parse_args() if "LOCAL_RANK" not in os.environ: os.environ["LOCAL_RANK"] = str(args.local_rank) if args.options and args.cfg_options: raise ValueError( "--options and --cfg-options cannot be both " "specified, --options is deprecated in favor of --cfg-options. " "--options will not be supported in version v0.22.0." ) if args.options: warnings.warn( "--options is deprecated in favor of --cfg-options. " "--options will not be supported in version v0.22.0." ) args.cfg_options = args.options return args def main(): args = parse_args() cfg = Config.fromfile(args.config) if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) # set cudnn_benchmark if cfg.get("cudnn_benchmark", False): torch.backends.cudnn.benchmark = True # work_dir is determined in this priority: CLI > segment in file > filename if args.work_dir is not None: # update configs according to CLI args if args.work_dir is not None cfg.work_dir = args.work_dir elif cfg.get("work_dir", None) is None: # use config filename as default work_dir if cfg.work_dir is None cfg.work_dir = osp.join( "./work_dirs", osp.splitext(osp.basename(args.config))[0] ) if args.load_from is not None: cfg.load_from = args.load_from if args.resume_from is not None: cfg.resume_from = args.resume_from if args.gpus is not None: cfg.gpu_ids = range(1) warnings.warn( "`--gpus` is deprecated because we only support " "single GPU mode in non-distributed training. " "Use `gpus=1` now." ) if args.gpu_ids is not None: cfg.gpu_ids = args.gpu_ids[0:1] warnings.warn( "`--gpu-ids` is deprecated, please use `--gpu-id`. " "Because we only support single GPU mode in " "non-distributed training. Use the first GPU " "in `gpu_ids` now." ) if args.gpus is None and args.gpu_ids is None: cfg.gpu_ids = [args.gpu_id] cfg.auto_resume = args.auto_resume # init distributed env first, since logger depends on the dist info. if args.launcher == "none": distributed = False else: distributed = True init_dist(args.launcher, **cfg.dist_params) # gpu_ids is used to calculate iter when resuming checkpoint _, world_size = get_dist_info() cfg.gpu_ids = range(world_size) # create work_dir mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) # dump config cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) # init the logger before other steps timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) log_file = osp.join(cfg.work_dir, f"{timestamp}.log") logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) # set multi-process settings setup_multi_processes(cfg) # init the meta dict to record some important information such as # environment info and seed, which will be logged meta = dict() # log env info env_info_dict = collect_env() env_info = "\n".join([f"{k}: {v}" for k, v in env_info_dict.items()]) dash_line = "-" * 60 + "\n" logger.info("Environment info:\n" + dash_line + env_info + "\n" + dash_line) meta["env_info"] = env_info # log some basic info logger.info(f"Distributed training: {distributed}") logger.info(f"Config:\n{cfg.pretty_text}") # set random seeds cfg.device = get_device() seed = init_random_seed(args.seed, device=cfg.device) seed = seed + dist.get_rank() if args.diff_seed else seed logger.info(f"Set random seed to {seed}, deterministic: {args.deterministic}") set_random_seed(seed, deterministic=args.deterministic) cfg.seed = seed meta["seed"] = seed meta["exp_name"] = osp.basename(args.config) model = build_segmentor( cfg.model, train_cfg=cfg.get("train_cfg"), test_cfg=cfg.get("test_cfg") ) model.init_weights() # SyncBN is not support for DP if not distributed: warnings.warn( "SyncBN is only supported with DDP. To be compatible with DP, " "we convert SyncBN to BN. Please use dist_train.sh which can " "avoid this error." ) model = revert_sync_batchnorm(model) # logger.info(model) datasets = [build_dataset(cfg.data.train)] logger.info(f"Built training dataset from config: {cfg.data.train}") logger.info(f"Number of samples in training dataset: {len(datasets[0])}") if len(datasets[0]) == 0: logger.error("Training dataset is EMPTY! Please check the following:") logger.error( f"1. Cityscapes root path (exists?): {cfg.data.train.get('data_root', 'Not Set')}" ) logger.error( f"2. Image directory (resolved path should exist): {cfg.data.train.get('img_dir', 'Not Set')}" ) logger.error( f"3. Annotation directory (resolved path should exist): {cfg.data.train.get('ann_dir', 'Not Set')}" ) logger.error( f"4. Split file path (exists and readable?): {cfg.data.train.get('split', 'Not Set')}" ) logger.error( " Ensure the paths in the split file are correct relative to img_dir and ann_dir." ) logger.error( " And that the actual image and annotation files exist at the fully resolved paths." ) # Attempt to log resolved paths if the dataset object has them # This might not work if the object is a simple list wrapper or if attributes are private try: dataset_obj = datasets[0] if hasattr(dataset_obj, "data_root"): logger.error(f" Dataset effective data_root: {dataset_obj.data_root}") if hasattr(dataset_obj, "img_dir"): logger.error(f" Dataset effective img_dir: {dataset_obj.img_dir}") if hasattr(dataset_obj, "ann_dir"): logger.error(f" Dataset effective ann_dir: {dataset_obj.ann_dir}") if hasattr(dataset_obj, "split"): logger.error(f" Dataset effective split file: {dataset_obj.split}") except Exception as e: logger.error( f" Could not retrieve effective paths from dataset object: {e}" ) return # Exit if dataset is empty if len(cfg.workflow) == 2: val_dataset = copy.deepcopy(cfg.data.val) val_dataset.pipeline = cfg.data.train.pipeline datasets.append(build_dataset(val_dataset)) if cfg.checkpoint_config is not None: # save mmseg version, config file content and class names in # checkpoints as meta data cfg.checkpoint_config.meta = dict( mmseg_version=f"{__version__}+{get_git_hash()[:7]}", config=cfg.pretty_text, CLASSES=datasets[0].CLASSES, PALETTE=datasets[0].PALETTE, ) # add an attribute for visualization convenience model.CLASSES = datasets[0].CLASSES # passing checkpoint meta for saving best checkpoint meta.update(cfg.checkpoint_config.meta) torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True train_segmentor( model, datasets, cfg, distributed=distributed, validate=(not args.no_validate), timestamp=timestamp, meta=meta, ) if __name__ == "__main__": main()