James Zhou
[init]
9867d34
raw
history blame
3.72 kB
import collections.abc
from itertools import repeat
import importlib
import yaml
import time
def default(value, default_val):
return default_val if value is None else value
def default_dtype(value, default_val):
if value is not None:
assert isinstance(value, type(default_val)), f"Expect {type(default_val)}, got {type(value)}."
return value
return default_val
def repeat_interleave(lst, num_repeats):
return [item for item in lst for _ in range(num_repeats)]
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
x = tuple(x)
if len(x) == 1:
x = tuple(repeat(x[0], n))
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
def as_tuple(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
if x is None or isinstance(x, (int, float, str)):
return (x,)
else:
raise ValueError(f"Unknown type {type(x)}")
def as_list_of_2tuple(x):
x = as_tuple(x)
if len(x) == 1:
x = (x[0], x[0])
assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
lst = []
for i in range(0, len(x), 2):
lst.append((x[i], x[i + 1]))
return lst
def find_multiple(n: int, k: int) -> int:
assert k > 0
if n % k == 0:
return n
return n - (n % k) + k
def merge_dicts(dict1, dict2):
for key, value in dict2.items():
if key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict):
merge_dicts(dict1[key], value)
else:
dict1[key] = value
return dict1
def merge_yaml_files(file_list):
merged_config = {}
for file in file_list:
with open(file, "r", encoding="utf-8") as f:
config = yaml.safe_load(f)
if config:
# Remove the first level
for key, value in config.items():
if isinstance(value, dict):
merged_config = merge_dicts(merged_config, value)
else:
merged_config[key] = value
return merged_config
def merge_dict(file_list):
merged_config = {}
for file in file_list:
with open(file, "r", encoding="utf-8") as f:
config = yaml.safe_load(f)
if config:
merged_config = merge_dicts(merged_config, config)
return merged_config
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def readable_time(seconds):
""" Convert time seconds to a readable format: DD Days, HH Hours, MM Minutes, SS Seconds """
seconds = int(seconds)
days, seconds = divmod(seconds, 86400)
hours, seconds = divmod(seconds, 3600)
minutes, seconds = divmod(seconds, 60)
if days > 0:
return f"{days} Days, {hours} Hours, {minutes} Minutes, {seconds} Seconds"
if hours > 0:
return f"{hours} Hours, {minutes} Minutes, {seconds} Seconds"
if minutes > 0:
return f"{minutes} Minutes, {seconds} Seconds"
return f"{seconds} Seconds"
def get_obj_from_cfg(cfg, reload=False):
if isinstance(cfg, str):
return get_obj_from_str(cfg, reload)
elif isinstance(cfg, (list, tuple,)):
return tuple([get_obj_from_str(c, reload) for c in cfg])
else:
raise NotImplementedError(f"Not implemented for {type(cfg)}.")