import collections.abc from itertools import repeat import importlib import yaml import time def default(value, default_val): return default_val if value is None else value def default_dtype(value, default_val): if value is not None: assert isinstance(value, type(default_val)), f"Expect {type(default_val)}, got {type(value)}." return value return default_val def repeat_interleave(lst, num_repeats): return [item for item in lst for _ in range(num_repeats)] def _ntuple(n): def parse(x): if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): x = tuple(x) if len(x) == 1: x = tuple(repeat(x[0], n)) return x return tuple(repeat(x, n)) return parse to_1tuple = _ntuple(1) to_2tuple = _ntuple(2) to_3tuple = _ntuple(3) to_4tuple = _ntuple(4) def as_tuple(x): if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): return tuple(x) if x is None or isinstance(x, (int, float, str)): return (x,) else: raise ValueError(f"Unknown type {type(x)}") def as_list_of_2tuple(x): x = as_tuple(x) if len(x) == 1: x = (x[0], x[0]) assert len(x) % 2 == 0, f"Expect even length, got {len(x)}." lst = [] for i in range(0, len(x), 2): lst.append((x[i], x[i + 1])) return lst def find_multiple(n: int, k: int) -> int: assert k > 0 if n % k == 0: return n return n - (n % k) + k def merge_dicts(dict1, dict2): for key, value in dict2.items(): if key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict): merge_dicts(dict1[key], value) else: dict1[key] = value return dict1 def merge_yaml_files(file_list): merged_config = {} for file in file_list: with open(file, "r", encoding="utf-8") as f: config = yaml.safe_load(f) if config: # Remove the first level for key, value in config.items(): if isinstance(value, dict): merged_config = merge_dicts(merged_config, value) else: merged_config[key] = value return merged_config def merge_dict(file_list): merged_config = {} for file in file_list: with open(file, "r", encoding="utf-8") as f: config = yaml.safe_load(f) if config: merged_config = merge_dicts(merged_config, config) return merged_config def get_obj_from_str(string, reload=False): module, cls = string.rsplit(".", 1) if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) return getattr(importlib.import_module(module, package=None), cls) def readable_time(seconds): """ Convert time seconds to a readable format: DD Days, HH Hours, MM Minutes, SS Seconds """ seconds = int(seconds) days, seconds = divmod(seconds, 86400) hours, seconds = divmod(seconds, 3600) minutes, seconds = divmod(seconds, 60) if days > 0: return f"{days} Days, {hours} Hours, {minutes} Minutes, {seconds} Seconds" if hours > 0: return f"{hours} Hours, {minutes} Minutes, {seconds} Seconds" if minutes > 0: return f"{minutes} Minutes, {seconds} Seconds" return f"{seconds} Seconds" def get_obj_from_cfg(cfg, reload=False): if isinstance(cfg, str): return get_obj_from_str(cfg, reload) elif isinstance(cfg, (list, tuple,)): return tuple([get_obj_from_str(c, reload) for c in cfg]) else: raise NotImplementedError(f"Not implemented for {type(cfg)}.")