Nan Xue
update
4c954ae
raw
history blame
1.81 kB
import yaml
import os
from easydict import EasyDict
def load_config(config_path):
""" Load configurations from a given yaml file. """
# Check file exists
if not os.path.exists(config_path):
raise ValueError("[Error] The provided config path is not valid.")
# Load the configuration
with open(config_path, "r") as f:
config = yaml.safe_load(f)
# return EasyDict(config)
return config
def update_config(path, model_cfg=None, dataset_cfg=None):
""" Update configuration file from the resume path. """
# Check we need to update or completely override.
model_cfg = {} if model_cfg is None else model_cfg
dataset_cfg = {} if dataset_cfg is None else dataset_cfg
# Load saved configs
with open(os.path.join(path, "model_cfg.yaml"), "r") as f:
model_cfg_saved = yaml.safe_load(f)
model_cfg.update(model_cfg_saved)
with open(os.path.join(path, "dataset_cfg.yaml"), "r") as f:
dataset_cfg_saved = yaml.safe_load(f)
dataset_cfg.update(dataset_cfg_saved)
# Update the saved yaml file
if not model_cfg == model_cfg_saved:
with open(os.path.join(path, "model_cfg.yaml"), "w") as f:
yaml.dump(model_cfg, f)
if not dataset_cfg == dataset_cfg_saved:
with open(os.path.join(path, "dataset_cfg.yaml"), "w") as f:
yaml.dump(dataset_cfg, f)
return model_cfg, dataset_cfg
def record_config(model_cfg, dataset_cfg, output_path):
""" Record dataset config to the log path. """
# Record model config
with open(os.path.join(output_path, "model_cfg.yaml"), "w") as f:
yaml.safe_dump(model_cfg, f)
# Record dataset config
with open(os.path.join(output_path, "dataset_cfg.yaml"), "w") as f:
yaml.safe_dump(dataset_cfg, f)