|
"""
|
|
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
|
|
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
|
|
"""
|
|
|
|
import functools
|
|
import importlib
|
|
import inspect
|
|
from collections import defaultdict
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
# Default registry used by `register` and `create`. It is a defaultdict, so
# reading a missing key stores and returns an empty dict instead of raising.
GLOBAL_CONFIG = defaultdict(dict)


def register(dct: Any = GLOBAL_CONFIG, name=None, force=False):
    """
    Build a decorator that registers a function or a class into ``dct``.

    Args:
        dct: registration target.
            If ``dct`` is a dict, a function is stored as a key-value pair.
            If ``dct`` is a class, a function is stored as a class attribute.
            A class is always stored with item assignment, as the schema
            returned by ``extract_schema``.
        name: key to register under; defaults to ``foo.__name__``.
        force: when True, skip the "already registered" check and overwrite.

    Returns:
        The decorator. Applied to a function it returns a ``functools.wraps``
        wrapper around that function; applied to a class it returns the class
        unchanged.

    Raises:
        AssertionError: the name is already taken and ``force`` is False.
        AttributeError: a function is registered into a target that is neither
            a dict nor a class.
        ValueError: the decorated object is neither a function nor a class.
    """

    def decorator(foo):
        register_name = foo.__name__ if name is None else name
        into_class = inspect.isclass(dct)

        # Check the key that will actually be written, so that a custom
        # `name` is honoured by duplicate detection as well.
        if not force:
            if into_class:
                assert not hasattr(dct, register_name), f"module {dct.__name__} has {register_name}"
            else:
                assert register_name not in dct, f"{register_name} has been already registered"

        if inspect.isfunction(foo):

            @functools.wraps(foo)
            def wrap_func(*args, **kwargs):
                return foo(*args, **kwargs)

            if isinstance(dct, dict):
                dct[register_name] = wrap_func
            elif into_class:
                setattr(dct, register_name, wrap_func)
            else:
                raise AttributeError(
                    f"cannot register function into {type(dct).__name__}; expected a dict or a class"
                )
            return wrap_func

        if inspect.isclass(foo):
            # Classes are stored as a schema (constructor description), not as
            # the class object itself.
            dct[register_name] = extract_schema(foo)
            return foo

        raise ValueError(f"Do not support {type(foo)} register")

    return decorator
|
|
|
|
|
|
def extract_schema(module: type):
    """
    Describe the constructor of ``module`` as a plain dict.

    Args:
        module (type): class whose ``__init__`` signature is inspected.

    Return:
        Dict with the bookkeeping keys ``_name`` (class name), ``_pymodule``
        (the imported module that defines the class), ``_inject`` and
        ``_share`` (the class attributes ``__inject__`` / ``__share__``,
        ``[]`` when absent) and ``_kwargs`` (parameter name -> default), plus
        one top-level entry per constructor parameter. Positional and
        keyword-only parameters are both included; parameters without a
        default map to None. ``*args`` / ``**kwargs`` are not recorded.

    Raises:
        AssertionError: a parameter listed in ``__share__`` has no default.
    """
    argspec = inspect.getfullargspec(module.__init__)
    arg_names = [arg for arg in argspec.args if arg != "self"]
    num_defaults = len(argspec.defaults) if argspec.defaults is not None else 0
    # Defaults always belong to the trailing positional parameters.
    num_requires = len(arg_names) - num_defaults

    schema = {
        "_name": module.__name__,
        "_pymodule": importlib.import_module(module.__module__),
        "_inject": getattr(module, "__inject__", []),
        "_share": getattr(module, "__share__", []),
        "_kwargs": {},
    }
    shared = schema["_share"]

    def _record(param, value):
        # Each parameter is exposed at top level and in `_kwargs`.
        schema[param] = value
        schema["_kwargs"][param] = value

    for i, name in enumerate(arg_names):
        has_default = i >= num_requires
        if name in shared:
            assert has_default, "share config must have default value."
        _record(name, argspec.defaults[i - num_requires] if has_default else None)

    # Keyword-only parameters (declared after `*`) carry their defaults in a
    # separate mapping; without this loop they would be missing from the schema.
    kwonly_defaults = argspec.kwonlydefaults or {}
    for name in argspec.kwonlyargs:
        if name in shared:
            assert name in kwonly_defaults, "share config must have default value."
        _record(name, kwonly_defaults.get(name))

    return schema
|
|
|
|
|
|
def _reset_and_merge(schema: dict, *overrides: dict):
    """Reset ``schema`` to its recorded defaults, apply ``overrides`` in order, pop and return ``type``."""
    # Drop every public (non "_"-prefixed) entry left over from an earlier build.
    for stale in [key for key in schema if not key.startswith("_")]:
        del schema[stale]
    schema.update(schema["_kwargs"])  # restore the defaults captured at registration
    for override in overrides:
        schema.update(override)
    # `type` arrives with the overrides and is not a constructor argument.
    return schema.pop("type")


def create(type_or_name, global_cfg=GLOBAL_CONFIG, **kwargs):
    """
    Instantiate the registered entry named by ``type_or_name``.

    Args:
        type_or_name: a class (its ``__name__`` is used as the key) or a key
            string looked up in ``global_cfg``.
        global_cfg: registry holding schemas and configuration values.
        **kwargs: extra overrides. They are only applied when the entry is a
            ``{"type": ...}`` style config; for a plain schema entry they are
            ignored.

    Returns:
        The registry entry itself when it exposes ``__dict__`` (a plain dict
        schema does not), otherwise a new instance of the described class.

    Raises:
        AssertionError: ``type_or_name`` is neither a class nor a string.
        ValueError: the name is not registered, or an injected dependency is
            missing or has an unsupported form.

    Note:
        Resolving a ``type`` style config rewrites the target schema stored in
        ``global_cfg`` in place (reset to defaults, then overridden).
    """
    # isinstance also accepts classes built by a custom metaclass and str
    # subclasses, which an exact `type(...)` comparison rejects.
    assert isinstance(type_or_name, (type, str)), "create should be modules or name."

    name = type_or_name if isinstance(type_or_name, str) else type_or_name.__name__

    if name not in global_cfg:
        raise ValueError("The module {} is not registered".format(name))

    cfg = global_cfg[name]
    if hasattr(cfg, "__dict__"):
        # Not a plain dict schema: hand the stored object back unchanged.
        return cfg

    if isinstance(cfg, dict) and "type" in cfg:
        # Indirection: this entry only names another registered schema and
        # carries overrides for it.
        target = _reset_and_merge(global_cfg[cfg["type"]], cfg, kwargs)
        return create(target, global_cfg)

    # Look the class up by its recorded class name; the registry key may be an
    # alias that does not exist as an attribute of the defining module.
    module = getattr(cfg["_pymodule"], cfg.get("_name", name))
    module_kwargs = dict(cfg)

    # Shared values: a same-named entry in global_cfg wins over the schema value.
    for k in cfg["_share"]:
        module_kwargs[k] = global_cfg[k] if k in global_cfg else cfg[k]

    # Injected dependencies: build (or fetch) the object each parameter refers to.
    for k in cfg["_inject"]:
        spec = cfg[k]

        if spec is None:
            continue

        if isinstance(spec, str):
            if spec not in global_cfg:
                raise ValueError(f"Missing inject config of {spec}.")

            dep = global_cfg[spec]
            if isinstance(dep, dict):
                module_kwargs[k] = create(dep["_name"], global_cfg)
            else:
                module_kwargs[k] = dep

        elif isinstance(spec, dict):
            if "type" not in spec:
                raise ValueError("Missing inject for `type` style.")

            _type = str(spec["type"])
            if _type not in global_cfg:
                raise ValueError(f"Missing {_type} in inspect stage.")

            dep_name = _reset_and_merge(global_cfg[_type], spec)
            module_kwargs[k] = create(dep_name, global_cfg)

        else:
            raise ValueError(f"Inject does not support {spec}")

    # Bookkeeping keys ("_name", "_pymodule", ...) are not constructor arguments.
    module_kwargs = {k: v for k, v in module_kwargs.items() if not k.startswith("_")}

    return module(**module_kwargs)
|
|
|