File size: 3,722 Bytes
9867d34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import collections.abc
from itertools import repeat
import importlib
import yaml
import time

def default(value, default_val):
    """Return ``value``, falling back to ``default_val`` only when it is None.

    Falsy values other than None (``0``, ``""``, ``[]``, ...) are returned
    unchanged.
    """
    if value is None:
        return default_val
    return value


def default_dtype(value, default_val):
    """Return ``value`` if given, otherwise ``default_val``, with a type check.

    When ``value`` is not None it must be an instance of ``type(default_val)``;
    the check is an ``assert``, so it raises AssertionError on mismatch and is
    skipped entirely when Python runs with assertions disabled.
    """
    # Nothing supplied: use the fallback without any checking.
    if value is None:
        return default_val

    matches_default_type = isinstance(value, type(default_val))
    assert matches_default_type, f"Expect {type(default_val)}, got {type(value)}."
    return value


def repeat_interleave(lst, num_repeats):
    """Return a list in which every element of ``lst`` occurs ``num_repeats`` times in a row.

    Element order is preserved, e.g. ``[a, b]`` with ``num_repeats=2`` gives
    ``[a, a, b, b]``. A ``num_repeats`` of zero or less yields an empty list.
    """
    expanded = []
    for element in lst:
        for _ in range(num_repeats):
            expanded.append(element)
    return expanded


def _ntuple(n):
    """Build a converter that normalizes its argument to a tuple.

    The returned function behaves as follows:
      * a non-string iterable is converted to a tuple; if that tuple has
        exactly one element, the element is repeated ``n`` times, otherwise
        the tuple is returned as-is (its length is NOT checked against ``n``);
      * anything else (including ``str``) is repeated ``n`` times.
    """

    def parse(x):
        is_collection = isinstance(x, collections.abc.Iterable) and not isinstance(x, str)
        if not is_collection:
            # Scalars and strings are broadcast to length n.
            return tuple(repeat(x, n))

        items = tuple(x)
        if len(items) != 1:
            return items
        # A lone element is broadcast just like a scalar would be.
        return tuple(repeat(items[0], n))

    return parse


# Ready-made converters for the commonly used lengths.
to_1tuple, to_2tuple, to_3tuple, to_4tuple = (_ntuple(size) for size in (1, 2, 3, 4))


def as_tuple(x):
    """Convert ``x`` to a tuple.

    Non-string iterables are converted element-wise; ``None``, ``int``,
    ``float`` and ``str`` values are wrapped in a one-element tuple.

    Raises:
        ValueError: if ``x`` is neither iterable nor one of the scalar types above.
    """
    if not isinstance(x, str) and isinstance(x, collections.abc.Iterable):
        return tuple(x)

    scalar_types = (int, float, str)
    if x is not None and not isinstance(x, scalar_types):
        raise ValueError(f"Unknown type {type(x)}")
    return (x,)


def as_list_of_2tuple(x):
    """Group the elements of ``x`` into a list of consecutive pairs.

    ``x`` is first normalized with ``as_tuple``. A single element is
    duplicated to form one pair; otherwise the element count must be even
    (checked with an ``assert``) and elements ``(0, 1)``, ``(2, 3)``, ... are
    paired in order.
    """
    flat = as_tuple(x)
    if len(flat) == 1:
        # One value stands for both members of the pair.
        flat = (flat[0], flat[0])
    assert len(flat) % 2 == 0, f"Expect even length, got {len(flat)}."
    return [(flat[start], flat[start + 1]) for start in range(0, len(flat), 2)]


def find_multiple(n: int, k: int) -> int:
    """Return the smallest multiple of ``k`` that is greater than or equal to ``n``.

    ``k`` must be positive (checked with an ``assert``).
    """
    assert k > 0
    remainder = n % k
    if remainder:
        # Drop down to the previous multiple, then step up by one k.
        return n - remainder + k
    return n


def merge_dicts(dict1, dict2):
    """Recursively merge ``dict2`` into ``dict1`` in place and return ``dict1``.

    When both sides hold a dict under the same key the two are merged
    recursively; in every other case the value from ``dict2`` replaces (or
    adds) the entry in ``dict1``. Values are stored by reference, not copied.
    """
    for key, incoming in dict2.items():
        both_are_dicts = (
            key in dict1
            and isinstance(dict1[key], dict)
            and isinstance(incoming, dict)
        )
        if not both_are_dicts:
            dict1[key] = incoming
            continue
        merge_dicts(dict1[key], incoming)
    return dict1


def merge_yaml_files(file_list):
    """Load several YAML files and merge them with the first level flattened.

    Each file is parsed with ``yaml.safe_load``; files that load to an empty
    or falsy document are skipped. For every top-level key of a document:
      * if its value is a dict, that dict's contents are merged into the
        result via ``merge_dicts`` (the top-level key itself is dropped);
      * otherwise the key/value pair is stored in the result directly.
    Files are processed in the given order.
    """
    result = {}

    for path in file_list:
        with open(path, "r", encoding="utf-8") as handle:
            loaded = yaml.safe_load(handle)
        if not loaded:
            continue

        # Strip the first level: merge section contents, not the section key.
        for top_key, section in loaded.items():
            if not isinstance(section, dict):
                result[top_key] = section
                continue
            result = merge_dicts(result, section)

    return result


def merge_dict(file_list):
    """Load several YAML files and deep-merge them into a single dict.

    Each file is parsed with ``yaml.safe_load`` and merged into the result
    with ``merge_dicts`` in the given order; files that load to an empty or
    falsy document are skipped. Unlike ``merge_yaml_files``, the top-level
    keys are kept.
    """
    result = {}

    for path in file_list:
        with open(path, "r", encoding="utf-8") as handle:
            loaded = yaml.safe_load(handle)
        if not loaded:
            continue
        result = merge_dicts(result, loaded)

    return result


def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to the named object.

    The text after the last dot is looked up as an attribute on the module
    named by the text before it.

    Args:
        string: dotted path containing at least one ``.``.
        reload: when true, the module is reloaded with ``importlib.reload``
            before the attribute is fetched.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    loaded_module = importlib.import_module(module_name, package=None)
    return getattr(loaded_module, attr_name)


def readable_time(seconds):
    """Format a duration in seconds as "D Days, H Hours, M Minutes, S Seconds".

    The input is truncated with ``int()``. Output starts at the largest unit
    (days, hours or minutes) whose value is positive; every smaller unit is
    then included even when it is zero. If none is positive, only the seconds
    are shown.
    """
    # Peel units off from the smallest upwards.
    minutes, seconds = divmod(int(seconds), 60)
    hours, minutes = divmod(minutes, 60)
    days, hours = divmod(hours, 24)

    if days > 0:
        text = f"{days} Days, {hours} Hours, {minutes} Minutes, {seconds} Seconds"
    elif hours > 0:
        text = f"{hours} Hours, {minutes} Minutes, {seconds} Seconds"
    elif minutes > 0:
        text = f"{minutes} Minutes, {seconds} Seconds"
    else:
        text = f"{seconds} Seconds"
    return text


def get_obj_from_cfg(cfg, reload=False):
    """Resolve one dotted path, or a list/tuple of them, via ``get_obj_from_str``.

    Args:
        cfg: a dotted-path string, or a list/tuple of such strings.
        reload: forwarded to ``get_obj_from_str`` for every path.

    Returns:
        The resolved object for a string, or a tuple of resolved objects for
        a list/tuple (in the same order).

    Raises:
        NotImplementedError: if ``cfg`` is neither a str, list nor tuple.
    """
    if isinstance(cfg, str):
        return get_obj_from_str(cfg, reload)
    if isinstance(cfg, (list, tuple)):
        return tuple(get_obj_from_str(entry, reload) for entry in cfg)
    raise NotImplementedError(f"Not implemented for {type(cfg)}.")