|
"""
|
|
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
|
|
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
|
|
"""
|
|
|
|
import copy
|
|
import os
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import yaml
|
|
|
|
from .workspace import GLOBAL_CONFIG
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = [
    "load_config",
    "merge_config",
    "merge_dict",
    "parse_cli",
]
|
|
|
|
|
|
# Key in a YAML config whose value lists base config files that
# `load_config` loads and merges before the including file itself.
INCLUDE_KEY = "__include__"
|
|
|
|
|
|
def load_config(file_path, cfg=None):
    """Load a YAML config file, resolving included base files recursively.

    Base files listed under ``INCLUDE_KEY`` are loaded first, in order, and
    merged into ``cfg``; the contents of ``file_path`` are merged last, so
    they override values coming from the bases. Relative include paths are
    resolved against the directory of the including file, and a leading
    ``~`` is expanded to the user's home directory.

    Args:
        file_path: Path to a ``.yml`` / ``.yaml`` file.
        cfg: Optional dict to merge into (mutated in place). A fresh dict is
            created when omitted, so separate calls never share state.

    Returns:
        The merged config dict (``cfg`` itself when one was passed), or a new
        empty dict when the file parses to ``None`` (e.g. it is empty).

    Raises:
        AssertionError: If ``file_path`` does not end in ``.yml`` or ``.yaml``.
    """
    # A `dict()` default would be created once at definition time and shared
    # by every call, leaking keys from one loaded config into the next.
    if cfg is None:
        cfg = {}

    _, ext = os.path.splitext(file_path)
    assert ext in [".yml", ".yaml"], "only support yaml files"

    # NOTE(security): yaml.Loader can construct arbitrary Python objects;
    # only load config files that come from a trusted source.
    with open(file_path) as f:
        file_cfg = yaml.load(f, Loader=yaml.Loader)

    if file_cfg is None:
        return {}

    if INCLUDE_KEY in file_cfg:
        includes = file_cfg[INCLUDE_KEY]
        # Accept a single path as well as a list of paths; list("a.yml")
        # would otherwise split the string into characters.
        base_yamls = [includes] if isinstance(includes, str) else list(includes)
        for base_yaml in base_yamls:
            if base_yaml.startswith("~"):
                base_yaml = os.path.expanduser(base_yaml)

            if not os.path.isabs(base_yaml):
                base_yaml = os.path.join(os.path.dirname(file_path), base_yaml)

            # The recursive call opens the file itself and merges into `cfg`.
            base_cfg = load_config(base_yaml, cfg)
            merge_dict(cfg, base_cfg)

    # The including file is merged last so it overrides its bases.
    return merge_dict(cfg, file_cfg)
|
|
|
|
|
|
def merge_dict(dct, another_dct, inplace=True) -> Dict:
    """Recursively overlay ``another_dct`` onto ``dct``.

    When both sides hold a dict under the same key the two are merged
    recursively; in every other case the value from ``another_dct`` replaces
    (or adds) the entry in ``dct``.

    Args:
        dct: Destination mapping.
        another_dct: Mapping whose entries take precedence.
        inplace: If true (default) ``dct`` itself is modified; otherwise a
            deep copy of ``dct`` is modified and ``dct`` is left untouched.

    Returns:
        The merged mapping (``dct`` itself when ``inplace`` is true).
    """
    target = dct if inplace else copy.deepcopy(dct)

    def _overlay(dst, src) -> Dict:
        for key in src:
            incoming = src[key]
            both_dicts = (
                key in dst
                and isinstance(dst[key], dict)
                and isinstance(incoming, dict)
            )
            if both_dicts:
                _overlay(dst[key], incoming)
                continue
            dst[key] = incoming
        return dst

    return _overlay(target, another_dct)
|
|
|
|
|
|
def dictify(s: str, v: Any) -> Dict:
    """Turn a dotted key and a value into a nested single-path dict.

    ``dictify("a.b.c", 1)`` gives ``{"a": {"b": {"c": 1}}}``; a key without
    any dot gives ``{key: v}``.
    """
    nested = v
    # Wrap from the innermost key outwards.
    for part in reversed(s.split(".")):
        nested = {part: nested}
    return nested
|
|
|
|
|
|
def parse_cli(nargs: List[str]) -> Dict:
    """Parse ``key=value`` command-line overrides into a nested dict.

    Dotted keys become nested dicts and each value is parsed as YAML, so
    ``a.c=3 b=10`` is converted to ``{'a': {'c': 3}, 'b': 10}``.

    Args:
        nargs: ``key=value`` strings; ``None`` or an empty list yields ``{}``.

    Returns:
        The merged nested dict. Arguments are merged in order, so a later
        argument overrides an earlier one that targets the same key.

    Raises:
        ValueError: If an argument does not contain ``=``.
    """
    cfg = {}
    if not nargs:
        return cfg

    for arg in nargs:
        # Split on the first "=" only, so values may themselves contain "=".
        key, sep, value = arg.strip().partition("=")
        if not sep:
            # Same exception type as the bare tuple-unpacking failure, but
            # the message names the offending argument.
            raise ValueError("expected 'key=value' argument, got {!r}".format(arg))

        # NOTE(security): yaml.Loader can construct arbitrary Python objects;
        # only pass override strings that come from a trusted source.
        parsed = yaml.load(value, Loader=yaml.Loader)
        cfg = merge_dict(cfg, dictify(key, parsed))

    return cfg
|
|
|
|
|
|
def merge_config(cfg, another_cfg=GLOBAL_CONFIG, inplace: bool = False, overwrite: bool = False):
    """
    Merge another_cfg into cfg and return the merged config.

    Keys missing from cfg are taken from another_cfg. When both sides hold a
    dict under the same key, the two are merged recursively. For any other
    key present on both sides, the value in cfg is kept unless `overwrite`
    is true, in which case the value from another_cfg replaces it.

    Unless `inplace` is true, the merge is applied to a deep copy of cfg and
    the cfg passed in is left unmodified.

    Example:

        cfg1 = load_config('./dfine_r18vd_6x_coco.yml')
        cfg1 = merge_config(cfg1, inplace=True)

        cfg2 = load_config('./dfine_r50vd_6x_coco.yml')
        cfg2 = merge_config(cfg2, inplace=True)

        model1 = create(cfg1['model'], cfg1)
        model2 = create(cfg2['model'], cfg2)
    """
    merged = cfg if inplace else copy.deepcopy(cfg)

    def _fill(dst, src):
        for key in src:
            incoming = src[key]
            if key not in dst:
                dst[key] = incoming
                continue

            current = dst[key]
            if isinstance(current, dict) and isinstance(incoming, dict):
                _fill(current, incoming)
            elif overwrite:
                dst[key] = incoming

    _fill(merged, another_cfg)
    return merged
|
|
|