""" D-FINE: Redefine Regression Task of DETRs as Fine-grained Distribution Refinement Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. --------------------------------------------------------------------------------- Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR) Copyright (c) 2023 lyuwenyu. All Rights Reserved. """ import os import sys import torch sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) import argparse from src.core import YAMLConfig, yaml_utils from src.misc import dist_utils from src.solver import TASKS from pprint import pprint debug = False if debug: def custom_repr(self): return f"{{Tensor:{tuple(self.shape)}}} {original_repr(self)}" original_repr = torch.Tensor.__repr__ torch.Tensor.__repr__ = custom_repr def safe_get_rank(): if torch.distributed.is_available() and torch.distributed.is_initialized(): return torch.distributed.get_rank() else: return 0 def main(args) -> None: """main""" dist_utils.setup_distributed(args.print_rank, args.print_method, seed=args.seed) assert not all( [args.tuning, args.resume] ), "Only support from_scrach or resume or tuning at one time" update_dict = yaml_utils.parse_cli(args.update) update_dict.update( { k: v for k, v in args.__dict__.items() if k not in [ "update", ] and v is not None } ) cfg = YAMLConfig(args.config, **update_dict) if args.resume or args.tuning: if "HGNetv2" in cfg.yaml_cfg: cfg.yaml_cfg["HGNetv2"]["pretrained"] = False if safe_get_rank() == 0: print("cfg: ") pprint(cfg.__dict__) solver = TASKS[cfg.yaml_cfg["task"]](cfg) if args.test_only: solver.val() else: solver.fit() dist_utils.cleanup() if __name__ == "__main__": parser = argparse.ArgumentParser() # priority 0 parser.add_argument("-c", "--config", type=str, required=True) parser.add_argument("-r", "--resume", type=str, help="resume from checkpoint") parser.add_argument("-t", "--tuning", type=str, help="tuning from checkpoint") parser.add_argument( "-d", "--device", type=str, help="device", ) parser.add_argument("--seed", type=int, help="exp reproducibility") parser.add_argument("--use-amp", action="store_true", help="auto mixed precision training") parser.add_argument("--output-dir", type=str, help="output directoy") parser.add_argument("--summary-dir", type=str, help="tensorboard summry") parser.add_argument( "--test-only", action="store_true", default=False, ) # priority 1 parser.add_argument("-u", "--update", nargs="+", help="update yaml config") # env parser.add_argument("--print-method", type=str, default="builtin", help="print method") parser.add_argument("--print-rank", type=int, default=0, help="print rank id") parser.add_argument("--local-rank", type=int, help="local rank id") args = parser.parse_args() main(args)