import copy

import numpy as np
import yaml

from src.logger.logger import _logger
from src.data.tools import _get_variable_names


def _as_list(x):
    """Return *x* as a list: None stays None, list/tuple pass through, scalars are wrapped."""
    if x is None:
        return None
    elif isinstance(x, (list, tuple)):
        return x
    else:
        return [x]


def _md5(fname):
    '''https://stackoverflow.com/questions/3431825/generating-an-md5-checksum-of-a-file'''
    import hashlib
    hash_md5 = hashlib.md5()
    with open(fname, "rb") as f:
        # read in 4 KiB chunks so large files do not need to fit in memory
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


class DataConfig(object):
    r"""Data loading configuration.

    Parses the keyword options (typically loaded from a YAML file via
    :meth:`load`) into the attributes used by the data pipeline:
    selections, derived-variable expressions (``var_funcs``), input
    groups with per-variable preprocessing parameters, observer and
    monitor variables, and the derived ``keep/drop/load`` branch sets.
    Only the contents of ``self.options`` are persisted by :meth:`dump`.
    """

    def __init__(self, print_info=True, **kwargs):
        # Defaults; recognized keys only — an unknown kwarg raises KeyError below.
        opts = {
            'treename': None,
            'selection': None,
            'test_time_selection': None,
            'preprocess': {'method': 'manual', 'data_fraction': 0.1, 'params': None},
            'new_variables': {},
            'inputs': {},
            'labels': {},
            'observers': [],
            'monitor_variables': [],
            'weights': None,
            'graph_config': {},
            'custom_model_kwargs': {},
        }
        for k, v in kwargs.items():
            if v is not None:
                # dict-valued options are merged (shallow), scalars are replaced
                if isinstance(opts[k], dict):
                    opts[k].update(v)
                else:
                    opts[k] = v
        # only information in ``self.options'' will be persisted when exporting to YAML
        self.options = opts
        if print_info:
            _logger.debug(opts)

        self.selection = opts['selection']
        # fall back to the training selection when no test-time selection is given
        self.test_time_selection = opts['test_time_selection'] if opts['test_time_selection'] else self.selection
        self.var_funcs = copy.deepcopy(opts['new_variables'])

        # preprocessing config
        self.preprocess = opts['preprocess']
        self._auto_standardization = opts['preprocess']['method'].lower().startswith('auto')
        self._missing_standardization_info = False
        self.preprocess_params = opts['preprocess']['params'] if opts['preprocess']['params'] is not None else {}

        # inputs
        self.input_names = tuple(opts['inputs'].keys())
        self.input_dicts = {k: [] for k in self.input_names}
        self.input_shapes = {}
        for k, o in opts['inputs'].items():
            self.input_shapes[k] = (-1, len(o['vars']), o['length'])
            for v in o['vars']:
                # each entry may be a bare name or a list:
                # [name, center, scale, min, max, pad_value]
                v = _as_list(v)
                self.input_dicts[k].append(v[0])

                # derive per-variable preprocessing params only when they were
                # not already provided via ``preprocess.params``
                if opts['preprocess']['params'] is None:

                    def _get(idx, default):
                        # positional lookup into the variable spec list with a fallback
                        try:
                            return v[idx]
                        except IndexError:
                            return default

                    params = {
                        'length': o['length'],
                        'pad_mode': o.get('pad_mode', 'constant').lower(),
                        'center': _get(1, 'auto' if self._auto_standardization else None),
                        'scale': _get(2, 1),
                        'min': _get(3, -5),
                        'max': _get(4, 5),
                        'pad_value': _get(5, 0),
                    }
                    if v[0] in self.preprocess_params and params != self.preprocess_params[v[0]]:
                        raise RuntimeError(
                            'Incompatible info for variable %s, had: \n %s\nnow got:\n %s' % (
                                v[0], str(self.preprocess_params[v[0]]), str(params)))
                    if k.endswith('_mask') and params['pad_mode'] != 'constant':
                        raise RuntimeError('The `pad_mode` must be set to `constant` for the mask input `%s`' % k)
                    if params['center'] == 'auto':
                        self._missing_standardization_info = True
                    self.preprocess_params[v[0]] = params

        # observers
        self.observer_names = tuple(opts['observers'])
        # monitor variables
        self.monitor_variables = tuple(opts['monitor_variables'])
        # Z variables: returned as `Z` in the dataloader (use monitor_variables for training, observers for eval)
        self.z_variables = self.observer_names if len(self.observer_names) > 0 else self.monitor_variables

        # remove self mapping from var_funcs
        # FIX: rebuild the dict instead of `del` inside the iteration, which
        # raises `RuntimeError: dictionary changed size during iteration` on Python 3
        self.var_funcs = {k: v for k, v in self.var_funcs.items() if k != v}

        if print_info:
            def _log(msg, *args, **kwargs):
                _logger.info(msg, *args, color='lightgray', **kwargs)

            _log('preprocess config: %s', str(self.preprocess))
            _log('selection: %s', str(self.selection))
            _log('test_time_selection: %s', str(self.test_time_selection))
            _log('var_funcs:\n - %s', '\n - '.join(str(it) for it in self.var_funcs.items()))
            _log('input_names: %s', str(self.input_names))
            _log('input_dicts:\n - %s', '\n - '.join(str(it) for it in self.input_dicts.items()))
            _log('input_shapes:\n - %s', '\n - '.join(str(it) for it in self.input_shapes.items()))
            _log('preprocess_params:\n - %s', '\n - '.join(str(it) for it in self.preprocess_params.items()))
            # _log('label_names: %s', str(self.label_names))
            _log('observer_names: %s', str(self.observer_names))
            _log('monitor_variables: %s', str(self.monitor_variables))
            if opts['weights'] is not None:
                # NOTE(review): `use_precomputed_weights`, `weight_name` and the
                # `reweight_*` attributes are produced by weight-config parsing
                # that is not present in this file; lookups fall through
                # __getattr__ to self.options — confirm upstream wiring.
                if self.use_precomputed_weights:
                    _log('weight: %s' % self.var_funcs[self.weight_name])
                else:
                    for k in ['reweight_method', 'reweight_basewgt', 'reweight_branches', 'reweight_bins',
                              'reweight_classes', 'class_weights', 'reweight_threshold',
                              'reweight_discard_under_overflow']:
                        _log('%s: %s' % (k, getattr(self, k)))

        # parse config: collect the branches to keep, and the auxiliary
        # branches that are only needed to evaluate selections / expressions
        self.keep_branches = set()
        aux_branches = set()
        # selection
        if self.selection:
            aux_branches.update(_get_variable_names(self.selection))
        # test time selection
        if self.test_time_selection:
            aux_branches.update(_get_variable_names(self.test_time_selection))
        # var_funcs
        self.keep_branches.update(self.var_funcs.keys())
        for expr in self.var_funcs.values():
            aux_branches.update(_get_variable_names(expr))
        # inputs
        for names in self.input_dicts.values():
            self.keep_branches.update(names)
        # labels
        # self.keep_branches.update(self.label_names)
        # weight
        # if self.weight_name:
        #     self.keep_branches.add(self.weight_name)
        #     if not self.use_precomputed_weights:
        #         aux_branches.update(self.reweight_branches)
        #         aux_branches.update(self.reweight_classes)
        # observers
        self.keep_branches.update(self.observer_names)
        # monitor variables
        self.keep_branches.update(self.monitor_variables)
        # keep and drop
        self.drop_branches = (aux_branches - self.keep_branches)
        # branches to read from file: everything referenced, minus derived variables
        self.load_branches = (aux_branches | self.keep_branches) - set(self.var_funcs.keys())  # - {self.weight_name, }
        if print_info:
            _logger.debug('drop_branches:\n %s', ','.join(self.drop_branches))
            _logger.debug('load_branches:\n %s', ','.join(self.load_branches))

    def __getattr__(self, name):
        # Fall back to the persisted options for attributes not set in __init__.
        # FIX: go through self.__dict__ (avoids infinite recursion when
        # `options` is not yet set, e.g. during unpickling) and raise
        # AttributeError instead of KeyError so hasattr()/getattr() behave.
        try:
            return self.__dict__['options'][name]
        except KeyError:
            raise AttributeError(name) from None

    def dump(self, fp):
        """Persist the raw options dict to a YAML file at path *fp*."""
        with open(fp, 'w') as f:
            yaml.safe_dump(self.options, f, sort_keys=False)

    @classmethod
    def load(cls, fp, load_observers=True, load_reweight_info=True, extra_selection=None, extra_test_selection=None):
        """Build a DataConfig from the YAML file *fp*.

        ``load_observers=False`` drops observers; ``load_reweight_info=False``
        drops the weights section; ``extra_selection`` / ``extra_test_selection``
        are AND-ed onto the corresponding selections from the file.

        Raises RuntimeError if ``extra_test_selection`` is given but the file
        defines no ``test_time_selection``.
        """
        with open(fp) as f:
            options = yaml.safe_load(f)
        if not load_observers:
            options['observers'] = None
        if not load_reweight_info:
            options['weights'] = None
        if extra_selection:
            options['selection'] = '(%s) & (%s)' % (options['selection'], extra_selection)
        if extra_test_selection:
            if 'test_time_selection' not in options:
                raise RuntimeError('`test_time_selection` is not defined in the yaml file!')
            options['test_time_selection'] = '(%s) & (%s)' % (options['test_time_selection'], extra_test_selection)
        return cls(**options)

    def copy(self):
        """Return a new DataConfig rebuilt from a deep copy of the options (quiet)."""
        return self.__class__(print_info=False, **copy.deepcopy(self.options))

    def __copy__(self):
        return self.copy()

    def __deepcopy__(self, memo):
        return self.copy()

    def export_json(self, fp):
        """Export input variable names and preprocessing info as JSON to *fp*.

        NOTE(review): relies on `self.label_value`, which is not defined in
        this file — presumably set by label parsing elsewhere; verify.
        """
        import json
        j = {'output_names': self.label_value, 'input_names': self.input_names}
        for k, v in self.input_dicts.items():
            j[k] = {'var_names': v, 'var_infos': {}}
            for var_name in v:
                # all vars in a group share the same length, so overwriting is fine
                j[k]['var_length'] = self.preprocess_params[var_name]['length']
                info = self.preprocess_params[var_name]
                j[k]['var_infos'][var_name] = {
                    'median': 0 if info['center'] is None else info['center'],
                    'norm_factor': info['scale'],
                    'replace_inf_value': 0,
                    'lower_bound': -1e32 if info['center'] is None else info['min'],
                    'upper_bound': 1e32 if info['center'] is None else info['max'],
                    'pad': info['pad_value'],
                }
        with open(fp, 'w') as f:
            json.dump(j, f, indent=2)