import numpy as np
import yaml
import copy

from src.logger.logger import _logger
from src.data.tools import _get_variable_names


def _as_list(x):
    if x is None:
        return None
    elif isinstance(x, (list, tuple)):
        return x
    else:
        return [x]


def _md5(fname):
    """Compute the MD5 checksum of a file, reading it in chunks.

    Adapted from https://stackoverflow.com/questions/3431825/generating-an-md5-checksum-of-a-file
    """
    import hashlib
    hash_md5 = hashlib.md5()
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


class DataConfig(object):
    r"""Data loading configuration.
    """
    def __init__(self, print_info=True, **kwargs):

        opts = {
            'treename': None,
            'selection': None,
            'test_time_selection': None,
            'preprocess': {'method': 'manual', 'data_fraction': 0.1, 'params': None},
            'new_variables': {},
            'inputs': {},
            'labels': {},
            'observers': [],
            'monitor_variables': [],
            'weights': None,
            'graph_config': {},
            'custom_model_kwargs': {}
        }
        for k, v in kwargs.items():
            if v is not None:
                if isinstance(opts[k], dict):
                    opts[k].update(v)
                else:
                    opts[k] = v
        # only the information in `self.options` is persisted when exporting to YAML
        self.options = opts
        if print_info:
            _logger.debug(opts)

        self.selection = opts['selection']
        self.test_time_selection = opts['test_time_selection'] if opts['test_time_selection'] else self.selection
        self.var_funcs = copy.deepcopy(opts['new_variables'])

        # preprocessing config
        self.preprocess = opts['preprocess']
        self._auto_standardization = opts['preprocess']['method'].lower().startswith('auto')
        self._missing_standardization_info = False
        self.preprocess_params = opts['preprocess']['params'] if opts['preprocess']['params'] is not None else {}

        # inputs
        self.input_names = tuple(opts['inputs'].keys())
        self.input_dicts = {k: [] for k in self.input_names}
        self.input_shapes = {}
        for k, o in opts['inputs'].items():
            self.input_shapes[k] = (-1, len(o['vars']), o['length'])
            for v in o['vars']:
                # each entry in `vars` is either a plain variable name or a list
                # [name, center, scale, min, max, pad_value]
                v = _as_list(v)
                self.input_dicts[k].append(v[0])

                if opts['preprocess']['params'] is None:

                    def _get(idx, default):
                        try:
                            return v[idx]
                        except IndexError:
                            return default

                    params = {'length': o['length'], 'pad_mode': o.get('pad_mode', 'constant').lower(),
                              'center': _get(1, 'auto' if self._auto_standardization else None),
                              'scale': _get(2, 1), 'min': _get(3, -5), 'max': _get(4, 5), 'pad_value': _get(5, 0)}
                    if v[0] in self.preprocess_params and params != self.preprocess_params[v[0]]:
                        raise RuntimeError(
                            'Incompatible info for variable %s, had:\n  %s\nnow got:\n  %s' %
                            (v[0], str(self.preprocess_params[v[0]]), str(params)))
                    if k.endswith('_mask') and params['pad_mode'] != 'constant':
                        raise RuntimeError('The `pad_mode` must be set to `constant` for the mask input `%s`' % k)
                    if params['center'] == 'auto':
                        self._missing_standardization_info = True
                    self.preprocess_params[v[0]] = params
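
        # Illustrative mapping (the variable names below are hypothetical, not part of
        # this repository): an `inputs` section such as
        #     pf_features:
        #       length: 100
        #       vars:
        #         - [pf_pt, 1.5, 0.5]   # explicit center and scale
        #         - pf_eta              # plain name -> defaults for the rest
        # would, with `method: manual` and no explicit `params`, produce
        #     preprocess_params['pf_pt']  = {'length': 100, 'pad_mode': 'constant',
        #                                    'center': 1.5, 'scale': 0.5,
        #                                    'min': -5, 'max': 5, 'pad_value': 0}
        #     preprocess_params['pf_eta'] = {'length': 100, 'pad_mode': 'constant',
        #                                    'center': None, 'scale': 1,
        #                                    'min': -5, 'max': 5, 'pad_value': 0}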

        # observers
        self.observer_names = tuple(opts['observers'])
        # monitor variables
        self.monitor_variables = tuple(opts['monitor_variables'])
        # Z variables: returned as `Z` in the dataloader (use monitor_variables for training, observers for eval)
        self.z_variables = self.observer_names if len(self.observer_names) > 0 else self.monitor_variables
        # remove self mappings (e.g. `a: a`) from var_funcs;
        # iterate over a copy since entries are deleted while looping
        for k, v in list(self.var_funcs.items()):
            if k == v:
                del self.var_funcs[k]

        if print_info:

            def _log(msg, *args, **kwargs):
                _logger.info(msg, *args, color='lightgray', **kwargs)

            _log('preprocess config: %s', str(self.preprocess))
            _log('selection: %s', str(self.selection))
            _log('test_time_selection: %s', str(self.test_time_selection))
            _log('var_funcs:\n - %s', '\n - '.join(str(it) for it in self.var_funcs.items()))
            _log('input_names: %s', str(self.input_names))
            _log('input_dicts:\n - %s', '\n - '.join(str(it) for it in self.input_dicts.items()))
            _log('input_shapes:\n - %s', '\n - '.join(str(it) for it in self.input_shapes.items()))
            _log('preprocess_params:\n - %s', '\n - '.join(str(it) for it in self.preprocess_params.items()))
            # _log('label_names: %s', str(self.label_names))
            _log('observer_names: %s', str(self.observer_names))
            _log('monitor_variables: %s', str(self.monitor_variables))
            if opts['weights'] is not None:
                # NOTE: the attributes below are expected to be set by a weights-parsing
                # step that is not present in this trimmed-down config
                if self.use_precomputed_weights:
                    _log('weight: %s' % self.var_funcs[self.weight_name])
                else:
                    for k in ['reweight_method', 'reweight_basewgt', 'reweight_branches', 'reweight_bins',
                              'reweight_classes', 'class_weights', 'reweight_threshold',
                              'reweight_discard_under_overflow']:
                        _log('%s: %s' % (k, getattr(self, k)))

        # parse config
        self.keep_branches = set()
        aux_branches = set()
        # selection
        if self.selection:
            aux_branches.update(_get_variable_names(self.selection))
        # test time selection
        if self.test_time_selection:
            aux_branches.update(_get_variable_names(self.test_time_selection))
        # var_funcs
        self.keep_branches.update(self.var_funcs.keys())
        for expr in self.var_funcs.values():
            aux_branches.update(_get_variable_names(expr))
        # inputs
        for names in self.input_dicts.values():
            self.keep_branches.update(names)
        # labels
        # self.keep_branches.update(self.label_names)
        # weight
        # if self.weight_name:
        #     self.keep_branches.add(self.weight_name)
        #     if not self.use_precomputed_weights:
        #         aux_branches.update(self.reweight_branches)
        #         aux_branches.update(self.reweight_classes)
        # observers
        self.keep_branches.update(self.observer_names)
        # monitor variables
        self.keep_branches.update(self.monitor_variables)
        # keep and drop
        self.drop_branches = (aux_branches - self.keep_branches)
        self.load_branches = (aux_branches | self.keep_branches) - set(self.var_funcs.keys())  # - {self.weight_name, }
        if print_info:
            _logger.debug('drop_branches:\n  %s', ','.join(self.drop_branches))
            _logger.debug('load_branches:\n  %s', ','.join(self.load_branches))

    def __getattr__(self, name):
        return self.options[name]

    def dump(self, fp):
        with open(fp, 'w') as f:
            yaml.safe_dump(self.options, f, sort_keys=False)

    @classmethod
    def load(cls, fp, load_observers=True, load_reweight_info=True, extra_selection=None, extra_test_selection=None):
        with open(fp) as f:
            options = yaml.safe_load(f)
        if not load_observers:
            options['observers'] = None
        if not load_reweight_info:
            options['weights'] = None
        if extra_selection:
            options['selection'] = '(%s) & (%s)' % (options['selection'], extra_selection)
        if extra_test_selection:
            if 'test_time_selection' not in options:
                raise RuntimeError('`test_time_selection` is not defined in the yaml file!')
            options['test_time_selection'] = '(%s) & (%s)' % (options['test_time_selection'], extra_test_selection)
        return cls(**options)
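
    # Usage sketch (the file paths and selection expression are purely illustrative,
    # not part of this repository):
    #     cfg = DataConfig.load('data_config.yaml', extra_selection='jet_pt > 500')
    #     cfg.dump('processed_data_config.yaml')
    # `extra_selection` is AND-ed onto the `selection` already present in the YAML file,
    # and `extra_test_selection` likewise extends `test_time_selection`.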

    def copy(self):
        return self.__class__(print_info=False, **copy.deepcopy(self.options))

    def __copy__(self):
        return self.copy()

    def __deepcopy__(self, memo):
        return self.copy()

    def export_json(self, fp):
        import json
        # NOTE: `self.label_value` is expected to be provided by a labels-parsing step
        # that is not present in this trimmed-down config
        j = {'output_names': self.label_value, 'input_names': self.input_names}
        for k, v in self.input_dicts.items():
            j[k] = {'var_names': v, 'var_infos': {}}
            for var_name in v:
                j[k]['var_length'] = self.preprocess_params[var_name]['length']
                info = self.preprocess_params[var_name]
                j[k]['var_infos'][var_name] = {
                    'median': 0 if info['center'] is None else info['center'],
                    'norm_factor': info['scale'],
                    'replace_inf_value': 0,
                    'lower_bound': -1e32 if info['center'] is None else info['min'],
                    'upper_bound': 1e32 if info['center'] is None else info['max'],
                    'pad': info['pad_value']
                }
        with open(fp, 'w') as f:
            json.dump(j, f, indent=2)
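

# Minimal, self-contained sketch of how this class is typically used. The YAML content
# and file names below are hypothetical and only illustrate the schema consumed by
# `DataConfig`; adapt them to your own configuration.
if __name__ == '__main__':
    demo_yaml = """
treename: Events
selection: (jet_pt > 200)
preprocess:
  method: manual
  data_fraction: 0.1
inputs:
  pf_features:
    length: 100
    vars:
      - [pf_pt, 1.5, 0.5]
      - pf_eta
observers:
  - jet_pt
  - jet_eta
"""
    # write the illustrative config to disk, then load and inspect it
    with open('demo_data_config.yaml', 'w') as f:
        f.write(demo_yaml)
    cfg = DataConfig.load('demo_data_config.yaml')
    print('inputs:', cfg.input_names, cfg.input_shapes)
    print('load_branches:', sorted(cfg.load_branches))
    # persist the fully-resolved options back to YAML
    cfg.dump('demo_data_config_processed.yaml')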