import numpy as np
import yaml
import copy

from src.logger.logger import _logger
from src.data.tools import _get_variable_names


def _as_list(x):
    """Wrap a scalar in a list; pass lists/tuples and ``None`` through unchanged."""
    if x is None:
        return None
    elif isinstance(x, (list, tuple)):
        return x
    else:
        return [x]


def _md5(fname):
    '''Compute the MD5 checksum of a file.
    https://stackoverflow.com/questions/3431825/generating-an-md5-checksum-of-a-file'''
    import hashlib
    hash_md5 = hashlib.md5()
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()

class DataConfig(object):
    r"""Data loading configuration.
    """
    def __init__(self, print_info=True, **kwargs):

        opts = {
            'treename': None,
            'selection': None,
            'test_time_selection': None,
            'preprocess': {'method': 'manual', 'data_fraction': 0.1, 'params': None},
            'new_variables': {},
            'inputs': {},
            'labels': {},
            'observers': [],
            'monitor_variables': [],
            'weights': None,
            'graph_config': {},
            'custom_model_kwargs': {}
        }
        for k, v in kwargs.items():
            if v is not None:
                if isinstance(opts[k], dict):
                    opts[k].update(v)
                else:
                    opts[k] = v
        # only information in ``self.options`` will be persisted when exporting to YAML
        self.options = opts
        if print_info:
            _logger.debug(opts)

        self.selection = opts['selection']
        self.test_time_selection = opts['test_time_selection'] if opts['test_time_selection'] else self.selection
        self.var_funcs = copy.deepcopy(opts['new_variables'])
        # preprocessing config
        self.preprocess = opts['preprocess']
        self._auto_standardization = opts['preprocess']['method'].lower().startswith('auto')
        self._missing_standardization_info = False
        self.preprocess_params = opts['preprocess']['params'] if opts['preprocess']['params'] is not None else {}
        # inputs
        self.input_names = tuple(opts['inputs'].keys())
        self.input_dicts = {k: [] for k in self.input_names}
        self.input_shapes = {}
        for k, o in opts['inputs'].items():
            # shape convention: (batch, number of variables, sequence length)
            self.input_shapes[k] = (-1, len(o['vars']), o['length'])
            for v in o['vars']:
                v = _as_list(v)
                self.input_dicts[k].append(v[0])

                if opts['preprocess']['params'] is None:

                    def _get(idx, default):
                        try:
                            return v[idx]
                        except IndexError:
                            return default

                    params = {'length': o['length'], 'pad_mode': o.get('pad_mode', 'constant').lower(),
                              'center': _get(1, 'auto' if self._auto_standardization else None),
                              'scale': _get(2, 1), 'min': _get(3, -5), 'max': _get(4, 5), 'pad_value': _get(5, 0)}
                    if v[0] in self.preprocess_params and params != self.preprocess_params[v[0]]:
                        raise RuntimeError(
                            'Incompatible info for variable %s, had: \n %s\nnow got:\n %s' %
                            (v[0], str(self.preprocess_params[v[0]]), str(params)))
                    if k.endswith('_mask') and params['pad_mode'] != 'constant':
                        raise RuntimeError('The `pad_mode` must be set to `constant` for the mask input `%s`' % k)
                    if params['center'] == 'auto':
                        self._missing_standardization_info = True
                    self.preprocess_params[v[0]] = params
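        # For reference, a resulting per-variable entry for 'manual' preprocessing and a
        # bare variable name looks like the following (the name and length are hypothetical;
        # the other values are the defaults supplied to ``_get`` above):
        #   self.preprocess_params['part_pt'] = {
        #       'length': 128, 'pad_mode': 'constant', 'center': None,
        #       'scale': 1, 'min': -5, 'max': 5, 'pad_value': 0}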
        # observers
        self.observer_names = tuple(opts['observers'])
        # monitor variables
        self.monitor_variables = tuple(opts['monitor_variables'])
        # Z variables: returned as `Z` in the dataloader (use monitor_variables for training, observers for eval)
        self.z_variables = self.observer_names if len(self.observer_names) > 0 else self.monitor_variables

        # remove self mapping from var_funcs; iterate over a copy so entries can be
        # deleted without mutating the dict while it is being iterated
        for k, v in list(self.var_funcs.items()):
            if k == v:
                del self.var_funcs[k]
        if print_info:
            def _log(msg, *args, **kwargs):
                _logger.info(msg, *args, color='lightgray', **kwargs)
            _log('preprocess config: %s', str(self.preprocess))
            _log('selection: %s', str(self.selection))
            _log('test_time_selection: %s', str(self.test_time_selection))
            _log('var_funcs:\n - %s', '\n - '.join(str(it) for it in self.var_funcs.items()))
            _log('input_names: %s', str(self.input_names))
            _log('input_dicts:\n - %s', '\n - '.join(str(it) for it in self.input_dicts.items()))
            _log('input_shapes:\n - %s', '\n - '.join(str(it) for it in self.input_shapes.items()))
            _log('preprocess_params:\n - %s', '\n - '.join(str(it) for it in self.preprocess_params.items()))
            # _log('label_names: %s', str(self.label_names))
            _log('observer_names: %s', str(self.observer_names))
            _log('monitor_variables: %s', str(self.monitor_variables))
            if opts['weights'] is not None:
                if self.use_precomputed_weights:
                    _log('weight: %s' % self.var_funcs[self.weight_name])
                else:
                    for k in ['reweight_method', 'reweight_basewgt', 'reweight_branches', 'reweight_bins',
                              'reweight_classes', 'class_weights', 'reweight_threshold',
                              'reweight_discard_under_overflow']:
                        _log('%s: %s' % (k, getattr(self, k)))

        # parse config
        self.keep_branches = set()
        aux_branches = set()
        # selection
        if self.selection:
            aux_branches.update(_get_variable_names(self.selection))
        # test time selection
        if self.test_time_selection:
            aux_branches.update(_get_variable_names(self.test_time_selection))
        # var_funcs
        self.keep_branches.update(self.var_funcs.keys())
        for expr in self.var_funcs.values():
            aux_branches.update(_get_variable_names(expr))
        # inputs
        for names in self.input_dicts.values():
            self.keep_branches.update(names)
        # labels
        # self.keep_branches.update(self.label_names)
        # weight
        # if self.weight_name:
        #     self.keep_branches.add(self.weight_name)
        #     if not self.use_precomputed_weights:
        #         aux_branches.update(self.reweight_branches)
        #         aux_branches.update(self.reweight_classes)
        # observers
        self.keep_branches.update(self.observer_names)
        # monitor variables
        self.keep_branches.update(self.monitor_variables)
        # keep and drop
        self.drop_branches = (aux_branches - self.keep_branches)
        self.load_branches = (aux_branches | self.keep_branches) - set(self.var_funcs.keys())  # - {self.weight_name, }
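        # Worked example (hypothetical names, assuming ``_get_variable_names`` extracts the
        # branch names used in these expressions): with selection '(jet_pt > 200)',
        # var_funcs = {'jet_pt_scaled': 'jet_pt * 0.001'} and an input built only from
        # 'jet_pt_scaled', the bookkeeping above yields:
        #   keep_branches = {'jet_pt_scaled'}   # quantities used downstream
        #   aux_branches  = {'jet_pt'}          # needed only to evaluate selection / expressions
        #   drop_branches = {'jet_pt'}          # aux_branches - keep_branches
        #   load_branches = {'jet_pt'}          # (aux | keep) - var_funcs keys: branches read from file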
        if print_info:
            _logger.debug('drop_branches:\n %s', ','.join(self.drop_branches))
            _logger.debug('load_branches:\n %s', ','.join(self.load_branches))

    def __getattr__(self, name):
        # attributes not set explicitly fall back to the raw options dict
        return self.options[name]

    def dump(self, fp):
        """Write the configuration (``self.options``) to a YAML file."""
        with open(fp, 'w') as f:
            yaml.safe_dump(self.options, f, sort_keys=False)

    @classmethod
    def load(cls, fp, load_observers=True, load_reweight_info=True, extra_selection=None, extra_test_selection=None):
        """Construct a ``DataConfig`` from a YAML file, optionally amending the selections."""
        with open(fp) as f:
            options = yaml.safe_load(f)
        if not load_observers:
            options['observers'] = None
        if not load_reweight_info:
            options['weights'] = None
        if extra_selection:
            options['selection'] = '(%s) & (%s)' % (options['selection'], extra_selection)
        if extra_test_selection:
            if 'test_time_selection' not in options:
                raise RuntimeError('`test_time_selection` is not defined in the yaml file!')
            options['test_time_selection'] = '(%s) & (%s)' % (options['test_time_selection'], extra_test_selection)
        return cls(**options)

    def copy(self):
        return self.__class__(print_info=False, **copy.deepcopy(self.options))

    def __copy__(self):
        return self.copy()

    def __deepcopy__(self, memo):
        return self.copy()
    def export_json(self, fp):
        """Export per-variable preprocessing (scaling/padding) information as JSON."""
        import json
        # NOTE: ``label_value`` is resolved through ``__getattr__`` from ``self.options``;
        # label parsing is commented out above, so it must be supplied for this to work.
        j = {'output_names': self.label_value, 'input_names': self.input_names}
        for k, v in self.input_dicts.items():
            j[k] = {'var_names': v, 'var_infos': {}}
            for var_name in v:
                j[k]['var_length'] = self.preprocess_params[var_name]['length']
                info = self.preprocess_params[var_name]
                j[k]['var_infos'][var_name] = {
                    'median': 0 if info['center'] is None else info['center'],
                    'norm_factor': info['scale'],
                    'replace_inf_value': 0,
                    'lower_bound': -1e32 if info['center'] is None else info['min'],
                    'upper_bound': 1e32 if info['center'] is None else info['max'],
                    'pad': info['pad_value']
                }
        with open(fp, 'w') as f:
            json.dump(j, f, indent=2)
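
# A minimal usage sketch (the file paths are hypothetical): build the configuration from
# a YAML file and persist the merged options for bookkeeping alongside training:
#
#   data_config = DataConfig.load('configs/example_data_config.yaml')
#   data_config.dump('example_data_config.processed.yaml')  # writes ``self.options`` back out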