# julien.blanchon
# add app
# c8c12e9
"""Utilities for modifying the configuration."""
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
import itertools
import operator
from functools import reduce
from typing import Any, Generator, List
from omegaconf import DictConfig
def flatten_sweep_params(params_dict: DictConfig) -> DictConfig:
    """Flatten the nested parameters section of the config object.

    Every leaf is stored under a single string key made by joining its path of
    nested keys with ``"."``. For example, ``{"a": {"b": x}}`` becomes
    ``{"a.b": x}``. A flat mapping is convenient when
    - computing the cartesian product of all configuration combinations for grid search,
    - saving the keys as CSV headers,
    - adding the config to a ``wandb`` sweep.

    Only ``DictConfig`` nodes are descended into. Every other value, including
    lists, is kept as a leaf. Keys are converted with ``str()`` before joining.
    An empty nested ``DictConfig`` contributes no entries.

    Args:
        params_dict: DictConfig: The hpo parameters in their original, nested structure.

    Returns:
        A new DictConfig mapping each dotted key to its leaf value, in
        depth-first order of the original keys.
    """

    def _iter_leaves(node: DictConfig, path: List[str]):
        """Yield ``(dotted_key, leaf_value)`` pairs for ``node``, depth first.

        Args:
            node: DictConfig: the (sub-)config currently being visited.
            path: List[str]: keys leading from the root to ``node``.
        """
        for name, child in node.items():
            child_path = path + [str(name)]
            if isinstance(child, DictConfig):
                # Descend into sub-configs; their leaves are emitted before
                # the remaining siblings of `child`.
                yield from _iter_leaves(child, child_path)
            else:
                yield ".".join(child_path), child

    flat = DictConfig({})
    for dotted_key, leaf in _iter_leaves(params_dict, []):
        flat[dotted_key] = leaf
    return flat
def get_run_config(params_dict: DictConfig) -> Generator[DictConfig, None, None]:
    """Yield the configuration for each run of a grid search.

    The nested sweep parameters are first flattened into dotted keys by
    ``flatten_sweep_params``. Each flattened value is passed to
    ``itertools.product``, so it must be an iterable of candidate values.
    One run configuration is yielded per element of the cartesian product of
    those candidates, in ``itertools.product`` order: the last key varies fastest.

    NOTE(review): a plain string leaf would be iterated character by character,
    and a scalar leaf (e.g. an int) raises ``TypeError``. Wrap single values in a list.

    Args:
        params_dict (DictConfig): Configuration for grid search. Each leaf holds
            the list of candidate values for that parameter.

    Example:
        >>> dummy_config = DictConfig({
        ...     "parent1": {
        ...         "child1": ['a', 'b', 'c'],
        ...         "child2": [1, 2, 3]
        ...     },
        ...     "parent2": ['model1', 'model2']
        ... })
        >>> for run_config in get_run_config(dummy_config):
        ...     print(run_config)
        {'parent1.child1': 'a', 'parent1.child2': 1, 'parent2': 'model1'}
        {'parent1.child1': 'a', 'parent1.child2': 1, 'parent2': 'model2'}
        {'parent1.child1': 'a', 'parent1.child2': 2, 'parent2': 'model1'}
        ...

    Yields:
        Generator[DictConfig]: Dictionary containing flattened keys
        and values for current run.
    """
    params = flatten_sweep_params(params_dict)
    keys = list(params.keys())
    # Consume the product lazily rather than materialising every combination
    # up front. The grid grows multiplicatively with each swept parameter, and
    # callers may stop early (e.g. `break` after the first run).
    for combination in itertools.product(*params.values()):
        run_config = DictConfig({})
        for key, val in zip(keys, combination):
            run_config[key] = val
        yield run_config
def get_from_nested_config(config: DictConfig, keymap: List) -> Any:
    """Walk down a nested config and return the item addressed by ``keymap``.

    Args:
        config: DictConfig: nested DictConfig object. Any object that supports
            ``obj[key]`` indexing at each level is handled the same way.
        keymap: List[str]: keys to follow, outermost first. An empty list
            returns ``config`` itself.

    Returns:
        The object reached after indexing ``config`` with each key in turn.

    Raises:
        Whatever indexing raises for a missing key, e.g. ``KeyError`` for a
        plain ``dict``.
    """
    node = config
    for key in keymap:
        node = node[key]
    return node
def set_in_nested_config(config: DictConfig, keymap: List, value: Any):
    """Assign ``value`` at the location in a nested config addressed by ``keymap``.

    Every key except the last is used to walk down to the parent container.
    The last key is then assigned on that parent, so for a mapping parent this
    can add a new key as well as overwrite an existing one. ``config`` is
    modified in place and nothing is returned.

    Args:
        config: DictConfig: nested DictConfig object.
        keymap: List[str]: keys leading to the item, outermost first. It must
            not be empty; an empty list raises ``IndexError``.
        value: Any: Value that should be assigned to the dictionary item at the specified location.

    Example:
        >>> dummy_config = DictConfig({
        ...     "parent1": {
        ...         "child1": ['a', 'b', 'c'],
        ...         "child2": [1, 2, 3]
        ...     },
        ...     "parent2": ['model1', 'model2']
        ... })
        >>> model_config = DictConfig({
        ...     "parent1": {
        ...         "child1": 'e',
        ...         "child2": 4,
        ...     },
        ...     "parent3": False
        ... })
        >>> for run_config in get_run_config(dummy_config):
        ...     print("Original model config", model_config)
        ...     print("Suggested config", run_config)
        ...     for param in run_config.keys():
        ...         set_in_nested_config(model_config, param.split('.'), run_config[param])
        ...     print("Replaced model config", model_config)
        ...     break
        Original model config {'parent1': {'child1': 'e', 'child2': 4}, 'parent3': False}
        Suggested config {'parent1.child1': 'a', 'parent1.child2': 1, 'parent2': 'model1'}
        Replaced model config {'parent1': {'child1': 'a', 'child2': 1}, 'parent3': False, 'parent2': 'model1'}
    """
    # Resolve the parent container first, then index with the final key.
    # This order keeps lookup errors raised in the same sequence as a
    # get-then-set implementation.
    parent = config
    for key in keymap[:-1]:
        parent = parent[key]
    parent[keymap[-1]] = value