"""
Default training/testing logic

modified from detectron2(https://github.com/facebookresearch/detectron2)

Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""

import os
import sys
import argparse
import multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel


import pointcept.utils.comm as comm
from pointcept.utils.env import get_random_seed, set_seed
from pointcept.utils.config import Config, DictAction


def create_ddp_model(model, *, fp16_compression=False, **kwargs):
    """
    Create a DistributedDataParallel model if there are >1 processes.

    Args:
        model: a torch.nn.Module
        fp16_compression: add fp16 compression hooks to the ddp object.
            See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
        kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`,
            e.g. ``find_unused_parameters=True`` for models in which some
            parameters do not receive gradients every step. ``device_ids`` and
            ``output_device`` default to this process's local rank when omitted.

    Returns:
        ``model`` itself when running with a single process, otherwise a
        DistributedDataParallel wrapper around it.
    """
    if comm.get_world_size() == 1:
        return model
    local_rank = comm.get_local_rank()
    # DDP takes ``device_ids`` as a list of devices but ``output_device`` as a
    # single int / torch.device; a list is not a documented value for it.
    if "device_ids" not in kwargs:
        kwargs["device_ids"] = [local_rank]
    if "output_device" not in kwargs:
        kwargs["output_device"] = local_rank
    ddp = DistributedDataParallel(model, **kwargs)
    if fp16_compression:
        # Imported lazily: only needed when compression is requested.
        from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks

        ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
    return ddp


def worker_init_fn(worker_id, num_workers, rank, seed):
    """Worker init func for dataloader.

    The seed of each worker equals to num_worker * rank + worker_id + user_seed

    Args:
        worker_id (int): Worker id.
        num_workers (int): Number of workers.
        rank (int): The rank of current process.
        seed (int): The random seed to use.
    """

    worker_seed = num_workers * rank + worker_id + seed
    set_seed(worker_seed)


def default_argument_parser(epilog=None):
    """Build the command-line parser shared by the train/test entry points.

    Args:
        epilog (str | None): text shown after the argument help. When falsy,
            a block of usage examples is used instead.

    Returns:
        argparse.ArgumentParser: parser defining ``--config-file``,
        ``--num-gpus``, ``--num-machines``, ``--machine-rank``, ``--dist-url``
        and ``--options``.
    """
    parser = argparse.ArgumentParser(
        epilog=epilog
        or f"""
Examples:

Run on single machine:
    $ {sys.argv[0]} --num-gpus 8 --config-file cfg.py

Change some config options:
    $ {sys.argv[0]} --config-file cfg.py --options save_path=exp/debug batch_size=8

Run on multiple machines:
    (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
    (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--config-file", default="", metavar="FILE", help="path to config file"
    )
    parser.add_argument(
        "--num-gpus", type=int, default=1, help="number of gpus *per machine*"
    )
    parser.add_argument(
        "--num-machines", type=int, default=1, help="total number of machines"
    )
    parser.add_argument(
        "--machine-rank",
        type=int,
        default=0,
        help="the rank of this machine (unique per machine)",
    )
    # The default "auto" is not interpreted here; the code that consumes
    # args.dist_url (outside this module) is responsible for resolving it.
    # NOTE(review): presumably it picks a free local port — confirm in launcher.
    parser.add_argument(
        "--dist-url",
        default="auto",
        help="initialization URL for pytorch distributed backend. See "
        "https://pytorch.org/docs/stable/distributed.html for details.",
    )
    # DictAction collects ``key=value`` pairs into a dict (or None if unused).
    parser.add_argument(
        "--options", nargs="+", action=DictAction, help="custom options"
    )
    return parser


def default_config_parser(file_path, options):
    """Load a config, apply overrides, fill derived fields and save a copy.

    ``file_path`` is normally the path of an existing config file. Otherwise it
    is treated as the shorthand ``<prefix>-<rest>`` and resolved to the path
    ``<prefix>/<rest>`` by splitting at the first "-".

    Args:
        file_path (str): config file path, or the shorthand described above.
        options (dict | None): overrides (e.g. parsed from ``--options``)
            merged into the loaded config.

    Returns:
        Config: the config with ``seed`` filled in (random if unset) and
        ``data.train.loop`` set to ``epoch // eval_epoch``.

    Raises:
        FileNotFoundError: if ``file_path`` is not an existing file and
            contains no "-" to split on. ``Config.fromfile`` may raise its own
            errors for a resolved path that does not exist or does not parse.
    """
    if os.path.isfile(file_path):
        cfg = Config.fromfile(file_path)
    else:
        sep = file_path.find("-")
        if sep == -1:
            # Without this guard find() returns -1 and the slicing below would
            # silently produce a meaningless path.
            raise FileNotFoundError(
                f"Config file '{file_path}' does not exist and is not of the "
                f"form '<prefix>-<rest>'"
            )
        cfg = Config.fromfile(os.path.join(file_path[:sep], file_path[sep + 1 :]))

    if options is not None:
        cfg.merge_from_dict(options)

    if cfg.seed is None:
        cfg.seed = get_random_seed()

    # Integer division; default_setup asserts that eval_epoch divides epoch.
    cfg.data.train.loop = cfg.epoch // cfg.eval_epoch

    os.makedirs(os.path.join(cfg.save_path, "model"), exist_ok=True)
    # Only fresh runs write config.py; a resumed run leaves the saved copy as is.
    if not cfg.resume:
        cfg.dump(os.path.join(cfg.save_path, "config.py"))
    return cfg


def default_setup(cfg):
    """Derive per-process worker/batch sizes and seed the current process.

    Args:
        cfg (Config): config, typically produced by ``default_config_parser``.

    Returns:
        Config: the same ``cfg`` object, updated in place with ``num_worker``,
        ``num_worker_per_gpu``, ``batch_size_per_gpu``,
        ``batch_size_val_per_gpu`` and ``batch_size_test_per_gpu``.

    Raises:
        AssertionError: if a global batch size is not divisible by the world
            size, or ``eval_epoch`` does not divide ``epoch``.
    """
    # scale global sizes by world size
    world_size = comm.get_world_size()
    cfg.num_worker = cfg.num_worker if cfg.num_worker is not None else mp.cpu_count()
    cfg.num_worker_per_gpu = cfg.num_worker // world_size
    assert cfg.batch_size % world_size == 0, (
        f"batch_size ({cfg.batch_size}) must be divisible by world size ({world_size})"
    )
    assert cfg.batch_size_val is None or cfg.batch_size_val % world_size == 0, (
        f"batch_size_val ({cfg.batch_size_val}) must be divisible by world size "
        f"({world_size})"
    )
    assert cfg.batch_size_test is None or cfg.batch_size_test % world_size == 0, (
        f"batch_size_test ({cfg.batch_size_test}) must be divisible by world size "
        f"({world_size})"
    )
    cfg.batch_size_per_gpu = cfg.batch_size // world_size
    # Unset val/test batch sizes fall back to 1 per process.
    cfg.batch_size_val_per_gpu = (
        cfg.batch_size_val // world_size if cfg.batch_size_val is not None else 1
    )
    cfg.batch_size_test_per_gpu = (
        cfg.batch_size_test // world_size if cfg.batch_size_test is not None else 1
    )

    # data.train.loop is computed as epoch // eval_epoch, so it must divide evenly.
    assert cfg.epoch % cfg.eval_epoch == 0, (
        f"epoch ({cfg.epoch}) must be divisible by eval_epoch ({cfg.eval_epoch})"
    )

    # settle random seed; adding the rank gives each process a different seed
    rank = comm.get_rank()
    seed = None if cfg.seed is None else cfg.seed * cfg.num_worker_per_gpu + rank
    set_seed(seed)
    return cfg