File size: 5,823 Bytes
b84549f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import dataclasses
from pathlib import Path
from typing import Any, Dict, Optional, Type, TypeVar
import yaml
from . import util
# Public names of this module (also what `from ... import *` exports).
__all__ = ['ConfigBase', 'PathLike']
# Type variable bound to ConfigBase, used in `load()` / `canonical()` signatures so that
# calling them on a subclass is typed as returning that subclass, not plain ConfigBase.
T = TypeVar('T', bound='ConfigBase')
# Re-export of util.PathLike so it can be imported from this module (it is listed in __all__).
PathLike = util.PathLike
def _is_missing(obj: Any) -> bool:
return isinstance(obj, type(dataclasses.MISSING))
class ConfigBase:
    """
    Base class of config classes.

    Subclasses must be dataclasses: every method iterates `dataclasses.fields(self)`.
    If a subclass lets `@dataclass` generate its own `__init__`, the constructor below
    (camelCase keys, path resolution) is bypassed, so subclasses are presumably declared
    with `@dataclass(init=False)` -- confirm against the subclasses.

    Subclass may override `_canonical_rules` and `_validation_rules`,
    and `validate()` if the logic is complex.
    """

    # Rules to convert field value to canonical format.
    # The key is field name.
    # The value is callable `value -> canonical_value`
    # It is called with the raw field value, which may be None or `dataclasses.MISSING`.
    # It is not type-hinted so dataclass won't treat it as field
    _canonical_rules = {}  # type: ignore

    # Rules to validate field value.
    # The key is field name.
    # The value is callable `value -> valid` or `value -> (valid, error_message)`
    # The rule will be called with canonical format and is only called when `value` is not None.
    # `error_message` is used when `valid` is False.
    # It will be prepended with class name and field name in exception message.
    _validation_rules = {}  # type: ignore

    def __init__(self, *, _base_path: Optional[Path] = None, **kwargs):
        """
        Initialize a config object and set its fields from keyword arguments.

        Names of keyword arguments can be either snake_case or camelCase; both the
        arguments and the field names are normalized with `util.case_insensitive`
        before being matched.

        A field not given in `kwargs` takes its dataclass default, or the result of its
        `default_factory`. If it has neither, it is set to `dataclasses.MISSING`,
        which `validate()` reports as "not set".

        For fields whose type mentions 'Path', non-None values are `expanduser()`-ed and,
        if still relative, joined onto `_base_path` (the current directory by default),
        because relative paths loaded from a config file are relative to that file.
        A keyword argument literally named `basepath` overrides `_base_path`.

        Raises:
            ValueError: if some keyword arguments do not correspond to any field.
        """
        if 'basepath' in kwargs:
            _base_path = kwargs.pop('basepath')
        if _base_path is None:
            _base_path = Path()

        # Normalize argument names, remembering how the caller spelled each one
        # so that the "unrecognized" error below shows the caller's own spelling.
        normalized = {}
        spelling = {}
        for key, value in kwargs.items():
            norm_key = util.case_insensitive(key)
            normalized[norm_key] = value
            spelling[norm_key] = key

        for field in dataclasses.fields(self):
            norm_name = util.case_insensitive(field.name)
            if norm_name in normalized:
                value = normalized.pop(norm_name)
            elif not _is_missing(field.default_factory):
                # `field(default_factory=...)` leaves `field.default` as MISSING,
                # so the factory has to be called explicitly.
                value = field.default_factory()
            else:
                value = field.default
            if value is not None and not _is_missing(value):
                # relative paths loaded from config file are not relative to pwd
                # NOTE(review): textual type check -- a container of paths (e.g. List[Path])
                # would also match and be passed to Path(); confirm no such fields exist.
                if 'Path' in str(field.type):
                    value = Path(value).expanduser()
                    if not value.is_absolute():
                        value = _base_path / value
            setattr(self, field.name, value)

        if normalized:
            cls = type(self).__name__
            fields = ', '.join(spelling[key] for key in normalized)
            raise ValueError(f'{cls}: Unrecognized fields {fields}')

    @classmethod
    def load(cls: Type[T], path: PathLike) -> T:
        """
        Load config from a YAML (or JSON) file.

        Keys in the file can be either camelCase or snake_case.
        Relative paths in the file are resolved against the file's directory.

        Raises:
            ValueError: if the top-level content of the file is not a dict/object.
        """
        # `with` closes the handle (a bare `open()` leaked it). Binary mode lets PyYAML
        # detect the Unicode encoding YAML/JSON require, instead of using the locale's.
        with open(path, 'rb') as config_file:
            data = yaml.safe_load(config_file)
        if not isinstance(data, dict):
            raise ValueError(f'Content of config file {path} is not a dict/object')
        return cls(**data, _base_path=Path(path).parent)

    def json(self) -> Dict[str, Any]:
        """
        Validate the config, then convert its canonical form to a JSON-style dict.

        The keys of the returned object (and of nested dataclass configs) are camelCase,
        and fields whose canonical value is None are omitted.

        Raises:
            ValueError: from `validate()` if the config is ill-formed.
        """
        self.validate()
        return dataclasses.asdict(
            self.canonical(),
            dict_factory=lambda items: {util.camel_case(k): v for k, v in items if v is not None}
        )

    def canonical(self: T) -> T:
        """
        Return a deep copy where fields supporting multiple formats are in canonical format.

        For each field: its `_canonical_rules` entry is applied if present; otherwise a
        nested `ConfigBase` is canonicalized recursively and a `Path` becomes `str`.
        A canonical rule may, for instance, turn a relative path into an absolute one.
        The original object is not modified.
        """
        ret = copy.deepcopy(self)
        for field in dataclasses.fields(ret):
            key, value = field.name, getattr(ret, field.name)
            rule = ret._canonical_rules.get(key)
            if rule is not None:
                setattr(ret, key, rule(value))
            elif isinstance(value, ConfigBase):
                # value will be copied twice, should not be a performance issue anyway
                setattr(ret, key, value.canonical())
            elif isinstance(value, Path):
                setattr(ret, key, str(value))
        return ret

    def validate(self) -> None:
        """
        Validate the config object and raise ValueError if it's ill-formed.

        Each field of the canonical form is checked in order:
          * it must be set (not `dataclasses.MISSING`);
          * it may be None only if its annotation is Optional[...], a Union / `|`
            union mentioning None, or Any;
          * it must pass its `_validation_rules` entry, if any;
          * a nested `ConfigBase` value must itself validate.
        """
        class_name = type(self).__name__
        config = self.canonical()
        for field in dataclasses.fields(config):
            key, value = field.name, getattr(config, field.name)

            # check existence
            if _is_missing(value):
                raise ValueError(f'{class_name}: {key} is not set')

            # check type (TODO)
            # The check is textual on str(field.type), so it also works for string annotations.
            type_name = str(field.type).replace('typing.', '')
            optional = any([
                type_name.startswith('Optional['),
                type_name.startswith('Union[') and 'None' in type_name,
                '|' in type_name and 'None' in type_name,  # PEP 604 `X | None`
                type_name == 'Any'
            ])
            if value is None:
                if optional:
                    continue
                raise ValueError(f'{class_name}: {key} cannot be None')

            # check value
            rule = config._validation_rules.get(key)
            if rule is not None:
                try:
                    result = rule(value)
                except Exception as e:
                    raise ValueError(f'{class_name}: {key} has bad value {repr(value)}') from e
                if isinstance(result, bool):
                    if not result:
                        raise ValueError(f'{class_name}: {key} ({repr(value)}) is out of range')
                elif not result[0]:
                    raise ValueError(f'{class_name}: {key} {result[1]}')

            # check nested config
            if isinstance(value, ConfigBase):
                value.validate()
|