File size: 5,521 Bytes
c8c12e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""Utilities for modifying the configuration."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

import itertools
import operator
from functools import reduce
from typing import Any, Generator, List

from omegaconf import DictConfig


def flatten_sweep_params(params_dict: DictConfig) -> DictConfig:
    """Collapse a nested parameter config into a single-level config with dotted keys.

    Every leaf of ``params_dict`` ends up under a key built by joining the names on the path to it with ``"."``,
    e.g. ``{"parent": {"child": [1, 2]}}`` becomes ``{"parent.child": [1, 2]}``. The flat form is convenient for
    - building the cartesian product of all parameter choices for a grid search,
    - using the keys as csv headers,
    - passing the config to a `wandb` sweep.

    Args:
        params_dict (DictConfig): Parameters in their original, nested structure.

    Returns:
        DictConfig: New config mapping each dotted key to its leaf value, in depth-first traversal order.
    """

    def iter_leaves(node: DictConfig, prefix: List[str]) -> Generator[Any, None, None]:
        """Yield ``(dotted_key, value)`` for every leaf found below ``node``.

        Args:
            node (DictConfig): Sub-config currently being traversed.
            prefix (List[str]): Names of the keys leading from the root to ``node``.
        """
        for name, child in node.items():
            path = [*prefix, str(name)]
            if isinstance(child, DictConfig):
                # Nested section: descend and keep extending the key path.
                yield from iter_leaves(child, path)
            else:
                yield ".".join(path), child

    flat = DictConfig({})
    for dotted_key, leaf in iter_leaves(params_dict, []):
        flat[dotted_key] = leaf
    return flat


def _as_choices(value: Any) -> tuple:
    """Return the candidate values of a single sweep parameter as a tuple.

    Strings, bytes and non-iterable values count as one fixed choice; any other iterable contributes each of its
    elements as a separate choice.
    """
    if isinstance(value, (str, bytes)):
        # Iterating a string would sweep over its characters, which is never what a config value means.
        return (value,)
    try:
        iterator = iter(value)
    except TypeError:
        # Scalars (int, float, bool, None, ...) are a single fixed value.
        return (value,)
    return tuple(iterator)


def get_run_config(params_dict: DictConfig) -> Generator[DictConfig, None, None]:
    """Yield the configuration of each run of a grid search, one at a time.

    The nested parameters are flattened to dotted keys and the cartesian product of their candidate values is
    walked lazily, so only one combination is held in memory at a time. A parameter whose value is a string or a
    non-iterable scalar is kept fixed at that value in every run.

    Args:
        params_dict (DictConfig): Configuration for grid search.

    Example:
        >>> dummy_config = DictConfig({
        ...     "parent1": {
        ...         "child1": ['a', 'b', 'c'],
        ...         "child2": [1, 2, 3]
        ...     },
        ...     "parent2": ['model1', 'model2']
        ... })
        >>> for run_config in get_run_config(dummy_config):
        ...     print(run_config)
        {'parent1.child1': 'a', 'parent1.child2': 1, 'parent2': 'model1'}
        {'parent1.child1': 'a', 'parent1.child2': 1, 'parent2': 'model2'}
        {'parent1.child1': 'a', 'parent1.child2': 2, 'parent2': 'model1'}
        ...

    Yields:
        Generator[DictConfig]: Dictionary containing flattened keys
        and values for current run.
    """
    params = flatten_sweep_params(params_dict)
    keys = list(params.keys())
    choices = [_as_choices(value) for value in params.values()]
    # itertools.product is already lazy; iterate it directly instead of materialising every combination.
    for combination in itertools.product(*choices):
        run_config = DictConfig({})
        for key, val in zip(keys, combination):
            run_config[key] = val
        yield run_config


def get_from_nested_config(config: DictConfig, keymap: List) -> Any:
    """Look up an item in a nested config by following a sequence of keys.

    Args:
        config (DictConfig): Nested config object to read from.
        keymap (List): Keys to follow, outermost first. An empty list returns ``config`` itself.

    Returns:
        Any: The item reached after indexing with every key of ``keymap`` in order.
    """
    node = config
    for key in keymap:
        # Step one level deeper; a failed lookup propagates whatever the container raises.
        node = node[key]
    return node


def set_in_nested_config(config: DictConfig, keymap: List, value: Any):
    """Assign ``value`` to the item of a nested config addressed by a sequence of keys.

    All keys except the last are used to locate the parent container (via ``get_from_nested_config``); the last key
    is then assigned in that container. ``config`` is modified in place.

    Args:
        config (DictConfig): Nested config object to modify.
        keymap (List): Keys leading to the item that should be set, outermost first.
        value (Any): Value to store at the addressed location.

    Example:
        >>> dummy_config = DictConfig({
        ...     "parent1": {
        ...         "child1": ['a', 'b', 'c'],
        ...         "child2": [1, 2, 3]
        ...     },
        ...     "parent2": ['model1', 'model2']
        ... })
        >>> model_config = DictConfig({
        ...     "parent1": {
        ...         "child1": 'e',
        ...         "child2": 4,
        ...     },
        ...     "parent3": False
        ... })
        >>> for run_config in get_run_config(dummy_config):
        ...     print("Original model config", model_config)
        ...     print("Suggested config", run_config)
        ...     for param in run_config.keys():
        ...         set_in_nested_config(model_config, param.split('.'), run_config[param])
        ...     print("Replaced model config", model_config)
        ...     break
        Original model config {'parent1': {'child1': 'e', 'child2': 4}, 'parent3': False}
        Suggested config {'parent1.child1': 'a', 'parent1.child2': 1, 'parent2': 'model1'}
        Replaced model config {'parent1': {'child1': 'a', 'child2': 1}, 'parent3': False, 'parent2': 'model1'}
    """
    # Resolve the container that holds the target item first, then assign the final key inside it.
    parent_keys = keymap[:-1]
    parent_node = get_from_nested_config(config, parent_keys)
    leaf_key = keymap[-1]
    parent_node[leaf_key] = value