"""Utilities for modifying the configuration."""
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
import itertools | |
import operator | |
from functools import reduce | |
from typing import Any, Generator, List | |
from omegaconf import DictConfig | |
def flatten_sweep_params(params_dict: DictConfig) -> DictConfig:
    """Collapse the nested hpo parameter section into a single-level config.

    Every leaf is stored under a key built by joining the names along its path with ``"."``,
    e.g. ``{"parent1": {"child1": [...]}}`` becomes ``{"parent1.child1": [...]}``. A flat
    layout like this is convenient when
        - building the cartesian product of all values for a grid search,
        - writing the keys out as csv column headers,
        - registering the parameters with a ``wandb`` sweep.

    Only ``DictConfig`` nodes are descended into; any other value (lists included) is kept
    as a leaf. Leaves appear in depth-first order, following the key order of each node.

    Args:
        params_dict: DictConfig: hpo parameters in their original, nested structure.

    Returns:
        A new DictConfig mapping dotted key paths to the corresponding leaf values.
    """

    def _iter_leaves(node: DictConfig, path: tuple):
        """Yield ``(dotted_key, leaf_value)`` for every non-DictConfig value below ``node``."""
        for name, child in node.items():
            child_path = path + (str(name),)
            if isinstance(child, DictConfig):
                # Sub-section: keep walking, extending the key path.
                yield from _iter_leaves(child, child_path)
            else:
                yield ".".join(child_path), child

    flattened = DictConfig({})
    for dotted_key, leaf in _iter_leaves(params_dict, ()):
        flattened[dotted_key] = leaf
    return flattened
def get_run_config(params_dict: DictConfig) -> Generator[DictConfig, None, None]:
    """Yields configuration for a single run.

    The nested grid-search configuration is flattened with ``flatten_sweep_params``. The
    cartesian product of all leaf values is then walked lazily, one run configuration at a
    time, with the last flattened key varying fastest. Each leaf is expected to hold the list
    of candidate values for its parameter. A scalar leaf (including a string) is treated as a
    single fixed value shared by every run.

    Args:
        params_dict (DictConfig): Configuration for grid search.

    Example:
        >>> dummy_config = DictConfig({
        ...     "parent1": {
        ...         "child1": ['a', 'b', 'c'],
        ...         "child2": [1, 2, 3]
        ...     },
        ...     "parent2": ['model1', 'model2']
        ... })
        >>> for run_config in get_run_config(dummy_config):
        ...     print(run_config)
        {'parent1.child1': 'a', 'parent1.child2': 1, 'parent2': 'model1'}
        {'parent1.child1': 'a', 'parent1.child2': 1, 'parent2': 'model2'}
        {'parent1.child1': 'a', 'parent1.child2': 2, 'parent2': 'model1'}
        ...

    Yields:
        Generator[DictConfig]: Dictionary containing flattened keys
            and values for current run.
    """
    from collections.abc import Iterable

    params = flatten_sweep_params(params_dict)
    keys = list(params.keys())
    # A bare scalar would make itertools.product raise (numbers) or iterate over its
    # characters (strings); wrap it so it contributes one fixed value to every run.
    candidates = [
        values if isinstance(values, Iterable) and not isinstance(values, (str, bytes)) else [values]
        for values in params.values()
    ]
    # Iterate the product lazily: materialising it up front costs memory proportional to the
    # total number of runs before the first one is even yielded.
    for combination in itertools.product(*candidates):
        run_config = DictConfig({})
        for key, val in zip(keys, combination):
            run_config[key] = val
        yield run_config
def get_from_nested_config(config: DictConfig, keymap: List) -> Any:
    """Look up a value in a nested config by walking a sequence of keys.

    Each key in ``keymap`` is applied in turn with ``[]`` indexing, starting from ``config``,
    so ``["a", "b"]`` returns ``config["a"]["b"]``. An empty ``keymap`` returns ``config``
    itself. Any lookup error raised by the indexing (e.g. ``KeyError``) propagates unchanged.

    Args:
        config: DictConfig: nested DictConfig object (any object supporting ``[]`` indexing
            at every level along the path works the same way).
        keymap: List[str]: keys leading to the item that should be retrieved.

    Returns:
        The object found at the end of the key path.
    """
    descend = operator.getitem  # one step down the tree: node -> node[key]
    target = reduce(descend, keymap, config)
    return target
def set_in_nested_config(config: DictConfig, keymap: List, value: Any):
    """Assign ``value`` to the item reached by following ``keymap`` through a nested config.

    All keys except the last select the containing node (via ``get_from_nested_config``), and
    the last key is then assigned on that node, so ``["a", "b"]`` performs
    ``config["a"]["b"] = value``. ``config`` is modified in place; nothing is returned.

    Args:
        config: DictConfig: nested DictConfig object
        keymap: List[str]: list of keys corresponding to item that should be set.
        value: Any: Value that should be assigned to the dictionary item at the specified location.

    Raises:
        IndexError: if ``keymap`` is empty (there is no final key to assign).

    Example:
        >>> dummy_config = DictConfig({
        ...     "parent1": {
        ...         "child1": ['a', 'b', 'c'],
        ...         "child2": [1, 2, 3]
        ...     },
        ...     "parent2": ['model1', 'model2']
        ... })
        >>> model_config = DictConfig({
        ...     "parent1": {
        ...         "child1": 'e',
        ...         "child2": 4,
        ...     },
        ...     "parent3": False
        ... })
        >>> for run_config in get_run_config(dummy_config):
        ...     print("Original model config", model_config)
        ...     print("Suggested config", run_config)
        ...     for param in run_config.keys():
        ...         set_in_nested_config(model_config, param.split('.'), run_config[param])
        ...     print("Replaced model config", model_config)
        ...     break
        Original model config {'parent1': {'child1': 'e', 'child2': 4}, 'parent3': False}
        Suggested config {'parent1.child1': 'a', 'parent1.child2': 1, 'parent2': 'model1'}
        Replaced model config {'parent1': {'child1': 'a', 'child2': 1}, 'parent3': False, 'parent2': 'model1'}
    """
    parent_keys = keymap[:-1]
    leaf_key = keymap[-1]  # raises IndexError for an empty keymap
    container = get_from_nested_config(config, parent_keys)
    container[leaf_key] = value