Spaces:
Sleeping
Sleeping
| import time | |
| import glob | |
| import copy | |
| import numpy as np | |
| import awkward as ak | |
| from src.logger.logger import _logger | |
| from src.data.tools import _get_variable_names, _eval_expr | |
| from src.data.fileio import _read_files | |
| def _apply_selection(table, selection): | |
| if selection is None: | |
| return table | |
| selected = ak.values_astype(_eval_expr(selection, table), 'bool') | |
| return table[selected] | |
| def _build_new_variables(table, funcs): | |
| if funcs is None: | |
| return table | |
| for k, expr in funcs.items(): | |
| if k in table.fields: | |
| continue | |
| table[k] = _eval_expr(expr, table) | |
| return table | |
| def _clean_up(table, drop_branches): | |
| columns = [k for k in table.fields if k not in drop_branches] | |
| return table[columns] | |
def _build_weights(table, data_config, reweight_hists=None, warn=_logger.warning):
    """Compute one weight per entry of ``table``.

    Arguments:
        table: event table holding the reweighting variables, the class label
            columns and (optionally) the base-weight / precomputed-weight columns.
        data_config: configuration object; the ``weight_name``,
            ``use_precomputed_weights``, ``reweight_*`` and ``basewgt_name``
            attributes are read here.
        reweight_hists: mapping ``{class_label: 2D weight histogram}``. When
            ``None``, ``data_config.reweight_hists`` is used instead.
        warn: callable invoked with a message when some events were not
            assigned a weight from any histogram.

    Returns:
        The precomputed weight column converted with ``ak.to_numpy`` when
        ``use_precomputed_weights`` is set; otherwise a ``float32`` numpy array
        of length ``len(table)``.

    Raises:
        RuntimeError: if ``data_config.weight_name`` is ``None``.
    """
    if data_config.weight_name is None:
        raise RuntimeError('Error when building weights: `weight_name` is None!')

    # precomputed weights are read straight from the table
    if data_config.use_precomputed_weights:
        return ak.to_numpy(table[data_config.weight_name])

    x_var, y_var = data_config.reweight_branches
    x_bins, y_bins = data_config.reweight_bins

    # optional mask keeping only events inside the histogram ranges
    if data_config.reweight_discard_under_overflow:
        in_range = (table[x_var] >= min(x_bins)) & (table[x_var] <= max(x_bins)) & \
            (table[y_var] >= min(y_bins)) & (table[y_var] <= max(y_bins))
    else:
        in_range = None

    # start from 0: events not belonging to any class in the histograms keep a weight of 0
    weights = np.zeros(len(table), dtype='float32')
    n_used = 0

    hists = data_config.reweight_hists if reweight_hists is None else reweight_hists
    for label, hist in hists.items():
        mask = table[label] == 1
        if in_range is not None:
            mask = (mask & in_range)
        x_vals = ak.to_numpy(table[x_var][mask])
        y_vals = ak.to_numpy(table[y_var][mask])
        # digitize is 1-based; clip so that out-of-range values fall in the first/last bin
        ix = np.clip(np.digitize(x_vals, x_bins) - 1, a_min=0, a_max=len(x_bins) - 2)
        iy = np.clip(np.digitize(y_vals, y_bins) - 1, a_min=0, a_max=len(y_bins) - 2)
        weights[mask] = hist[ix, iy]
        n_used += np.sum(mask)

    if n_used != len(table):
        warn(
            'Not all selected events used in the reweighting. '
            'Check consistency between `selection` and `reweight_classes` definition, or with the `reweight_vars` binnings '
            '(under- and overflow bins are discarded by default, unless `reweight_discard_under_overflow` is set to `False` in the `weights` section).',
        )

    if data_config.reweight_basewgt:
        weights *= ak.to_numpy(table[data_config.basewgt_name])
    return weights
class AutoStandardizer(object):
    r"""AutoStandardizer.

    Class to compute the variable standardization information.

    Arguments:
        filelist (list): list of files to be loaded.
        data_config (DataConfig): object containing data format information.
    """

    def __init__(self, filelist, data_config):
        if isinstance(filelist, dict):
            # flatten a {key: [files]} mapping into a single list of files
            filelist = sum(filelist.values(), [])
        # anything that is not a list/tuple is treated as a glob pattern
        self._filelist = filelist if isinstance(
            filelist, (list, tuple)) else glob.glob(filelist)
        self._data_config = data_config.copy()
        # range passed to `_read_files`; upper edge defaults to 0.1
        self.load_range = (0, data_config.preprocess.get('data_fraction', 0.1))

    def read_file(self, filelist):
        """Load the variables whose standardization is set to ``'auto'``.

        Sets ``self.keep_branches`` (variables to standardize) and
        ``self.load_branches`` (branches that must be read to build them and to
        evaluate the selection), then reads the files, applies the selection,
        builds the derived variables and drops the helper branches.

        Arguments:
            filelist: files forwarded to ``_read_files``.

        Returns:
            The resulting table.
        """
        self.keep_branches = set()
        self.load_branches = set()
        for k, params in self._data_config.preprocess_params.items():
            if params['center'] == 'auto':
                self.keep_branches.add(k)
                if k in self._data_config.var_funcs:
                    # derived variable: load the inputs of its expression instead
                    expr = self._data_config.var_funcs[k]
                    self.load_branches.update(_get_variable_names(expr))
                else:
                    self.load_branches.add(k)
        if self._data_config.selection:
            self.load_branches.update(_get_variable_names(self._data_config.selection))
        _logger.debug('[AutoStandardizer] keep_branches:\n %s', ','.join(self.keep_branches))
        _logger.debug('[AutoStandardizer] load_branches:\n %s', ','.join(self.load_branches))
        table = _read_files(filelist, self.load_branches, self.load_range,
                            show_progressbar=True, treename=self._data_config.treename)
        table = _apply_selection(table, self._data_config.selection)
        table = _build_new_variables(
            table, {k: v for k, v in self._data_config.var_funcs.items() if k in self.keep_branches})
        table = _clean_up(table, self.load_branches - self.keep_branches)
        return table

    def make_preprocess_params(self, table):
        """Compute ``center``/``scale`` for every variable marked ``'auto'``.

        For a regular variable the center is the median and the scale is the
        inverse of the larger of the (84th - 50th) and (50th - 16th) percentile
        distances (or 1 if that distance is 0). Variables whose name ends with
        ``_mask`` only get ``center`` set to ``None``.

        Arguments:
            table: table providing the values of the ``'auto'`` variables.

        Returns:
            A new dict of preprocessing parameters; the parameters stored in
            ``self._data_config`` are not modified.
        """
        _logger.info('Using %d events to calculate standardization info', len(table))
        # Work on the deep copy: iterating the original dicts would mutate the
        # config in place and make the copy pointless.
        preprocess_params = copy.deepcopy(self._data_config.preprocess_params)
        for k, params in preprocess_params.items():
            if params['center'] != 'auto':
                continue
            if k.endswith('_mask'):
                # masks are not standardized; there are no percentiles to report
                params['center'] = None
                _logger.info('[AutoStandardizer] %s is a mask, center set to None', k)
                continue
            a = ak.to_numpy(ak.flatten(table[k], axis=None))
            # check for NaN
            if np.any(np.isnan(a)):
                _logger.warning('[AutoStandardizer] Found NaN in `%s`, will convert it to 0.', k)
                # pause after the warning before continuing
                time.sleep(10)
                a = np.nan_to_num(a)
            low, center, high = np.percentile(a, [16, 50, 84])
            scale = max(high - center, center - low)
            scale = 1 if scale == 0 else 1. / scale
            params['center'] = float(center)
            params['scale'] = float(scale)
            _logger.info('[AutoStandardizer] %s low=%s, center=%s, high=%s, scale=%s', k, low, center, high, scale)
        return preprocess_params

    def produce(self, output=None):
        """Run the full chain and return the updated data config.

        Arguments:
            output: optional path; when truthy the updated config is dumped there.

        Returns:
            The internal copy of the data config with the computed parameters.
        """
        table = self.read_file(self._filelist)
        preprocess_params = self.make_preprocess_params(table)
        self._data_config.preprocess_params = preprocess_params
        # must also propagate the changes to `data_config.options` so it can be persisted
        self._data_config.options['preprocess']['params'] = preprocess_params
        if output:
            _logger.info('Writing YAML file w/ auto-generated preprocessing info to %s', output)
            self._data_config.dump(output)
        return self._data_config
class WeightMaker(object):
    r"""WeightMaker.

    Class to make reweighting information.

    Arguments:
        filelist (list): list of files to be loaded.
        data_config (DataConfig): object containing data format information.
    """

    # values of `reweight_method` handled by `make_weights`
    _SUPPORTED_METHODS = ('flat', 'ref')

    def __init__(self, filelist, data_config):
        if isinstance(filelist, dict):
            # flatten a {key: [files]} mapping into a single list of files
            filelist = sum(filelist.values(), [])
        # anything that is not a list/tuple is treated as a glob pattern
        self._filelist = filelist if isinstance(filelist, (list, tuple)) else glob.glob(filelist)
        self._data_config = data_config.copy()

    def read_file(self, filelist):
        """Load everything needed to build the reweighting histograms.

        Sets ``self.keep_branches`` (reweighting variables, class labels and
        the base weight) and ``self.load_branches`` (branches that must be read
        to build them and to evaluate the selection), then reads the files,
        applies the selection, builds the derived variables and drops the
        helper branches.

        Arguments:
            filelist: files forwarded to ``_read_files``.

        Returns:
            The resulting table.
        """
        self.keep_branches = set(self._data_config.reweight_branches + self._data_config.reweight_classes +
                                 (self._data_config.basewgt_name,))
        self.load_branches = set()
        for k in self.keep_branches:
            if k in self._data_config.var_funcs:
                # derived variable: load the inputs of its expression instead
                expr = self._data_config.var_funcs[k]
                self.load_branches.update(_get_variable_names(expr))
            else:
                self.load_branches.add(k)
        if self._data_config.selection:
            self.load_branches.update(_get_variable_names(self._data_config.selection))
        _logger.debug('[WeightMaker] keep_branches:\n %s', ','.join(self.keep_branches))
        _logger.debug('[WeightMaker] load_branches:\n %s', ','.join(self.load_branches))
        table = _read_files(filelist, self.load_branches, show_progressbar=True, treename=self._data_config.treename)
        table = _apply_selection(table, self._data_config.selection)
        table = _build_new_variables(
            table, {k: v for k, v in self._data_config.var_funcs.items() if k in self.keep_branches})
        table = _clean_up(table, self.load_branches - self.keep_branches)
        return table

    def make_weights(self, table):
        """Build the per-class 2D reweighting histograms.

        Arguments:
            table: table with the two reweighting variables, the class label
                columns and, if ``reweight_basewgt`` is set, the base weight.
                When under/overflow is not discarded, the two reweighting
                columns of ``table`` are clipped in place to the bin ranges.

        Returns:
            dict mapping each label in ``reweight_classes`` to its 2D weight array.

        Raises:
            ValueError: if ``reweight_method`` is not one of the supported methods.
        """
        method = self._data_config.reweight_method
        if method not in self._SUPPORTED_METHODS:
            # fail early with a clear message instead of an opaque
            # "min() arg is an empty sequence" after all histograms are filled
            raise ValueError('Unsupported `reweight_method`: %r (expected one of %s)' %
                             (method, ', '.join(self._SUPPORTED_METHODS)))

        x_var, y_var = self._data_config.reweight_branches
        x_bins, y_bins = self._data_config.reweight_bins
        if not self._data_config.reweight_discard_under_overflow:
            # clip variables to be within bin ranges
            x_min, x_max = min(x_bins), max(x_bins)
            y_min, y_max = min(y_bins), max(y_bins)
            _logger.info(f'Clipping `{x_var}` to [{x_min}, {x_max}] to compute the shapes for reweighting.')
            _logger.info(f'Clipping `{y_var}` to [{y_min}, {y_max}] to compute the shapes for reweighting.')
            table[x_var] = np.clip(table[x_var], x_min, x_max)
            table[y_var] = np.clip(table[y_var], y_min, y_max)

        _logger.info('Using %d events to make weights', len(table))

        sum_evts = 0
        max_weight = 0.9
        raw_hists = {}
        class_events = {}
        result = {}
        for label in self._data_config.reweight_classes:
            pos = (table[label] == 1)
            x = ak.to_numpy(table[x_var][pos])
            y = ak.to_numpy(table[y_var][pos])
            hist, _, _ = np.histogram2d(x, y, bins=self._data_config.reweight_bins)
            _logger.info('%s (unweighted):\n %s', label, str(hist.astype('int64')))
            # the event count always uses the unweighted histogram
            sum_evts += hist.sum()
            if self._data_config.reweight_basewgt:
                w = ak.to_numpy(table[self._data_config.basewgt_name][pos])
                hist, _, _ = np.histogram2d(x, y, weights=w, bins=self._data_config.reweight_bins)
                _logger.info('%s (weighted):\n %s', label, str(hist.astype('float32')))
            raw_hists[label] = hist.astype('float32')
            result[label] = hist.astype('float32')
        if sum_evts != len(table):
            _logger.warning(
                'Only %d (out of %d) events actually used in the reweighting. '
                'Check consistency between `selection` and `reweight_classes` definition, or with the `reweight_vars` binnings '
                '(under- and overflow bins are discarded by default, unless `reweight_discard_under_overflow` is set to `False` in the `weights` section).',
                sum_evts, len(table))
            # pause after the warning before continuing
            time.sleep(10)

        if method == 'flat':
            for label, classwgt in zip(self._data_config.reweight_classes, self._data_config.class_weights):
                hist = result[label]
                # ignore bins below 1% of the median non-empty bin content
                threshold_ = np.median(hist[hist > 0]) * 0.01
                nonzero_vals = hist[hist > threshold_]
                min_val, med_val = np.min(nonzero_vals), np.median(hist)  # only used in the debug message
                ref_val = np.percentile(nonzero_vals, self._data_config.reweight_threshold)
                _logger.debug('label:%s, median=%f, min=%f, ref=%f, ref/min=%f',
                              label, med_val, min_val, ref_val, ref_val / min_val)
                # wgt: bins w/ 0 elements will get a weight of 0; bins w/ content<ref_val will get 1
                # (division by empty bins is intended, so silence the numpy warnings)
                with np.errstate(divide='ignore', invalid='ignore'):
                    wgt = np.clip(np.nan_to_num(ref_val / hist, posinf=0), 0, 1)
                result[label] = wgt
                # dividing by classwgt here will effectively increase the weight later
                class_events[label] = np.sum(raw_hists[label] * wgt) / classwgt
        else:  # method == 'ref'
            # use class 0 as the reference
            hist_ref = raw_hists[self._data_config.reweight_classes[0]]
            for label, classwgt in zip(self._data_config.reweight_classes, self._data_config.class_weights):
                # ratio: bins that are empty in this class get 0 (both x/0 and 0/0)
                with np.errstate(divide='ignore', invalid='ignore'):
                    ratio = np.nan_to_num(hist_ref / result[label], posinf=0)
                upper = np.percentile(ratio[ratio > 0], 100 - self._data_config.reweight_threshold)
                wgt = np.clip(ratio / upper, 0, 1)  # -> [0,1]
                result[label] = wgt
                # dividing by classwgt here will effectively increase the weight later
                class_events[label] = np.sum(raw_hists[label] * wgt) / classwgt

        # ''equalize'' all classes
        # multiply by max_weight (<1) to add some randomness in the sampling
        min_nevt = min(class_events.values()) * max_weight
        for label in self._data_config.reweight_classes:
            class_wgt = float(min_nevt) / class_events[label]
            result[label] *= class_wgt

        if self._data_config.reweight_basewgt:
            # rescale so that the chosen percentile of the per-event weights becomes 1
            wgts = _build_weights(table, self._data_config, reweight_hists=result)
            wgt_ref = np.percentile(wgts, 100 - self._data_config.reweight_threshold)
            _logger.info('Set overall reweighting scale factor (%d threshold) to %s (max %s)',
                         100 - self._data_config.reweight_threshold, wgt_ref, np.max(wgts))
            for label in self._data_config.reweight_classes:
                result[label] /= wgt_ref

        _logger.info('weights:')
        for label in self._data_config.reweight_classes:
            _logger.info('%s:\n %s', label, str(result[label]))

        _logger.info('Raw hist * weights:')
        for label in self._data_config.reweight_classes:
            _logger.info('%s:\n %s', label, str((raw_hists[label] * result[label]).astype('int32')))

        return result

    def produce(self, output=None):
        """Run the full chain and return the updated data config.

        Arguments:
            output: optional path; when truthy the updated config is dumped there.

        Returns:
            The internal copy of the data config with ``reweight_hists`` filled.
        """
        table = self.read_file(self._filelist)
        wgts = self.make_weights(table)
        self._data_config.reweight_hists = wgts
        # must also propagate the changes to `data_config.options` so it can be persisted
        self._data_config.options['weights']['reweight_hists'] = {k: v.tolist() for k, v in wgts.items()}
        if output:
            _logger.info('Writing YAML file w/ reweighting info to %s', output)
            self._data_config.dump(output)
        return self._data_config