import time
import glob
import copy
import numpy as np
import awkward as ak

from src.logger.logger import _logger
from src.data.tools import _get_variable_names, _eval_expr
from src.data.fileio import _read_files


def _apply_selection(table, selection):
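    """Apply the (optional) selection expression and return the filtered table."""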
    if selection is None:
        return table
    selected = ak.values_astype(_eval_expr(selection, table), 'bool')
    return table[selected]


def _build_new_variables(table, funcs):
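    """Evaluate the expressions in `funcs` and attach them as new columns; existing columns are left untouched."""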
    if funcs is None:
        return table
    for k, expr in funcs.items():
        if k in table.fields:
            continue
        table[k] = _eval_expr(expr, table)
    return table


def _clean_up(table, drop_branches):
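    """Drop the branches listed in `drop_branches` and return the reduced table."""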
    columns = [k for k in table.fields if k not in drop_branches]
    return table[columns]


def _build_weights(table, data_config, reweight_hists=None, warn=_logger.warning):
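    """Build per-event weights, either taken directly from a precomputed weight branch or looked up
    from the 2D reweighting histograms binned in the `reweight_branches` variables."""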
    if data_config.weight_name is None:
        raise RuntimeError('Error when building weights: `weight_name` is None!')
    if data_config.use_precomputed_weights:
        return ak.to_numpy(table[data_config.weight_name])
    else:
        x_var, y_var = data_config.reweight_branches
        x_bins, y_bins = data_config.reweight_bins
        rwgt_sel = None
        if data_config.reweight_discard_under_overflow:
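            # events outside the binning range keep the initial weight of 0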
            rwgt_sel = (table[x_var] >= min(x_bins)) & (table[x_var] <= max(x_bins)) & \
                (table[y_var] >= min(y_bins)) & (table[y_var] <= max(y_bins))
        # init w/ wgt=0: events not belonging to any class in `reweight_classes` will get a weight of 0 at the end
        wgt = np.zeros(len(table), dtype='float32')
        sum_evts = 0
        if reweight_hists is None:
            reweight_hists = data_config.reweight_hists
        for label, hist in reweight_hists.items():
            pos = table[label] == 1
            if rwgt_sel is not None:
                pos = (pos & rwgt_sel)
            rwgt_x_vals = ak.to_numpy(table[x_var][pos])
            rwgt_y_vals = ak.to_numpy(table[y_var][pos])
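            # np.digitize gives 1-based bin indices; subtract 1 and clip so under-/overflow values fall into the first/last bin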
            x_indices = np.clip(np.digitize(rwgt_x_vals, x_bins) - 1, a_min=0, a_max=len(x_bins) - 2)
            y_indices = np.clip(np.digitize(rwgt_y_vals, y_bins) - 1, a_min=0, a_max=len(y_bins) - 2)
            wgt[pos] = hist[x_indices, y_indices]
            sum_evts += np.sum(pos)
        if sum_evts != len(table):
            warn(
                'Not all selected events used in the reweighting. '
                'Check consistency between `selection` and `reweight_classes` definition, or with the `reweight_vars` binnings '
                '(under- and overflow bins are discarded by default, unless `reweight_discard_under_overflow` is set to `False` in the `weights` section).',
            )
        if data_config.reweight_basewgt:
            wgt *= ak.to_numpy(table[data_config.basewgt_name])
        return wgt


class AutoStandardizer(object):
    r"""AutoStandardizer.
    Computes per-variable standardization parameters (center and scale) for variables whose preprocessing `center` is set to `auto`.
    Arguments:
        filelist (list): list of files to be loaded.
        data_config (DataConfig): object containing data format information.
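
    Example (an illustrative sketch; the file pattern and output path below are placeholders):

        # `data_config` is a DataConfig object built from the training YAML config
        standardizer = AutoStandardizer('samples/*.root', data_config)
        data_config = standardizer.produce(output='data_config.auto.yaml')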
    """

    def __init__(self, filelist, data_config):
        if isinstance(filelist, dict):
            filelist = sum(filelist.values(), [])
        self._filelist = filelist if isinstance(filelist, (list, tuple)) else glob.glob(filelist)
        self._data_config = data_config.copy()
        self.load_range = (0, data_config.preprocess.get('data_fraction', 0.1))

    def read_file(self, filelist):
        self.keep_branches = set()
        self.load_branches = set()
        for k, params in self._data_config.preprocess_params.items():
            if params['center'] == 'auto':
                self.keep_branches.add(k)
                if k in self._data_config.var_funcs:
                    expr = self._data_config.var_funcs[k]
                    self.load_branches.update(_get_variable_names(expr))
                else:
                    self.load_branches.add(k)
        if self._data_config.selection:
            self.load_branches.update(_get_variable_names(self._data_config.selection))
        _logger.debug('[AutoStandardizer] keep_branches:\n  %s', ','.join(self.keep_branches))
        _logger.debug('[AutoStandardizer] load_branches:\n  %s', ','.join(self.load_branches))

        table = _read_files(filelist, self.load_branches, self.load_range,
                            show_progressbar=True, treename=self._data_config.treename)
        table = _apply_selection(table, self._data_config.selection)
        table = _build_new_variables(
            table, {k: v for k, v in self._data_config.var_funcs.items() if k in self.keep_branches})
        table = _clean_up(table, self.load_branches - self.keep_branches)
        return table

    def make_preprocess_params(self, table):
        _logger.info('Using %d events to calculate standardization info', len(table))
        preprocess_params = copy.deepcopy(self._data_config.preprocess_params)
        for k, params in self._data_config.preprocess_params.items():
            if params['center'] == 'auto':
                if k.endswith('_mask'):
                    params['center'] = None
                else:
                    a = ak.to_numpy(ak.flatten(table[k], axis=None))
                    # check for NaN
                    if np.any(np.isnan(a)):
                        _logger.warning('[AutoStandardizer] Found NaN in `%s`, will convert it to 0.', k)
                        time.sleep(10)
                        a = np.nan_to_num(a)
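                    # robust standardization: center = median; `scale` = 1 / (larger half-width of the 16th-84th percentile range, ~1 sigma for a Gaussian)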
                    low, center, high = np.percentile(a, [16, 50, 84])
                    scale = max(high - center, center - low)
                    scale = 1 if scale == 0 else 1. / scale
                    params['center'] = float(center)
                    params['scale'] = float(scale)
                    _logger.info('[AutoStandardizer] %s low=%s, center=%s, high=%s, scale=%s', k, low, center, high, scale)
                preprocess_params[k] = params
        return preprocess_params

    def produce(self, output=None):
        table = self.read_file(self._filelist)
        preprocess_params = self.make_preprocess_params(table)
        self._data_config.preprocess_params = preprocess_params
        # must also propagate the changes to `data_config.options` so it can be persisted
        self._data_config.options['preprocess']['params'] = preprocess_params
        if output:
            _logger.info(
                'Writing YAML file w/ auto-generated preprocessing info to %s' % output)
            self._data_config.dump(output)
        return self._data_config


class WeightMaker(object):
    r"""WeightMaker.
    Computes the per-class 2D reweighting histograms (`flat` or `ref` method) used to derive event weights.
    Arguments:
        filelist (list): list of files to be loaded.
        data_config (DataConfig): object containing data format information.
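
    Example (an illustrative sketch; the file pattern and output path below are placeholders):

        # `data_config` is a DataConfig object built from the training YAML config
        weight_maker = WeightMaker('samples/*.root', data_config)
        data_config = weight_maker.produce(output='data_config.weights.yaml')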
    """

    def __init__(self, filelist, data_config):
        if isinstance(filelist, dict):
            filelist = sum(filelist.values(), [])
        self._filelist = filelist if isinstance(filelist, (list, tuple)) else glob.glob(filelist)
        self._data_config = data_config.copy()

    def read_file(self, filelist):
        self.keep_branches = set(self._data_config.reweight_branches + self._data_config.reweight_classes)
        if self._data_config.reweight_basewgt:
            self.keep_branches.add(self._data_config.basewgt_name)
        self.load_branches = set()
        for k in self.keep_branches:
            if k in self._data_config.var_funcs:
                expr = self._data_config.var_funcs[k]
                self.load_branches.update(_get_variable_names(expr))
            else:
                self.load_branches.add(k)
        if self._data_config.selection:
            self.load_branches.update(_get_variable_names(self._data_config.selection))
        _logger.debug('[WeightMaker] keep_branches:\n  %s', ','.join(self.keep_branches))
        _logger.debug('[WeightMaker] load_branches:\n  %s', ','.join(self.load_branches))
        table = _read_files(filelist, self.load_branches, show_progressbar=True, treename=self._data_config.treename)
        table = _apply_selection(table, self._data_config.selection)
        table = _build_new_variables(
            table, {k: v for k, v in self._data_config.var_funcs.items() if k in self.keep_branches})
        table = _clean_up(table, self.load_branches - self.keep_branches)
        return table

    def make_weights(self, table):
        x_var, y_var = self._data_config.reweight_branches
        x_bins, y_bins = self._data_config.reweight_bins
        if not self._data_config.reweight_discard_under_overflow:
            # clip variables to be within bin ranges
            x_min, x_max = min(x_bins), max(x_bins)
            y_min, y_max = min(y_bins), max(y_bins)
            _logger.info(f'Clipping `{x_var}` to [{x_min}, {x_max}] to compute the shapes for reweighting.')
            _logger.info(f'Clipping `{y_var}` to [{y_min}, {y_max}] to compute the shapes for reweighting.')
            table[x_var] = np.clip(table[x_var], x_min, x_max)
            table[y_var] = np.clip(table[y_var], y_min, y_max)

        _logger.info('Using %d events to make weights', len(table))

        sum_evts = 0
        max_weight = 0.9
        raw_hists = {}
        class_events = {}
        result = {}
        for label in self._data_config.reweight_classes:
            pos = (table[label] == 1)
            x = ak.to_numpy(table[x_var][pos])
            y = ak.to_numpy(table[y_var][pos])
            hist, _, _ = np.histogram2d(x, y, bins=self._data_config.reweight_bins)
            _logger.info('%s (unweighted):\n %s', label, str(hist.astype('int64')))
            sum_evts += hist.sum()
            if self._data_config.reweight_basewgt:
                w = ak.to_numpy(table[self._data_config.basewgt_name][pos])
                hist, _, _ = np.histogram2d(x, y, weights=w, bins=self._data_config.reweight_bins)
                _logger.info('%s (weighted):\n %s', label, str(hist.astype('float32')))
            raw_hists[label] = hist.astype('float32')
            result[label] = hist.astype('float32')
        if sum_evts != len(table):
            _logger.warning(
                'Only %d (out of %d) events actually used in the reweighting. '
                'Check consistency between `selection` and `reweight_classes` definition, or with the `reweight_vars` binnings '
                '(under- and overflow bins are discarded by default, unless `reweight_discard_under_overflow` is set to `False` in the `weights` section).',
                sum_evts, len(table))
            time.sleep(10)

        if self._data_config.reweight_method == 'flat':
            for label, classwgt in zip(self._data_config.reweight_classes, self._data_config.class_weights):
                hist = result[label]
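                # ignore bins whose content is below 1% of the median nonzero content when picking the reference value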
                threshold_ = np.median(hist[hist > 0]) * 0.01
                nonzero_vals = hist[hist > threshold_]
                min_val, med_val = np.min(nonzero_vals), np.median(hist)  # only used for the debug printout below
                ref_val = np.percentile(nonzero_vals, self._data_config.reweight_threshold)
                _logger.debug('label:%s, median=%f, min=%f, ref=%f, ref/min=%f' %
                              (label, med_val, min_val, ref_val, ref_val / min_val))
                # wgt: bins w/ 0 elements will get a weight of 0; bins w/ content<ref_val will get 1
                wgt = np.clip(np.nan_to_num(ref_val / hist, posinf=0), 0, 1)
                result[label] = wgt
                # dividing by classwgt here effectively increases the weight later
                class_events[label] = np.sum(raw_hists[label] * wgt) / classwgt
        elif self._data_config.reweight_method == 'ref':
            # use class 0 as the reference
            hist_ref = raw_hists[self._data_config.reweight_classes[0]]
            for label, classwgt in zip(self._data_config.reweight_classes, self._data_config.class_weights):
                # wgt: ratio of the reference-class shape to this class, normalized so that bins at or above the
                # (100 - threshold) percentile of the ratio get a weight of 1; empty bins get a weight of 0
                ratio = np.nan_to_num(hist_ref / result[label], posinf=0)
                upper = np.percentile(ratio[ratio > 0], 100 - self._data_config.reweight_threshold)
                wgt = np.clip(ratio / upper, 0, 1)  # -> [0,1]
                result[label] = wgt
                # dividing by classwgt here effectively increases the weight later
                class_events[label] = np.sum(raw_hists[label] * wgt) / classwgt
        # "equalize" all classes
        # multiply by max_weight (<1) to add some randomness in the sampling
        min_nevt = min(class_events.values()) * max_weight
        for label in self._data_config.reweight_classes:
            class_wgt = float(min_nevt) / class_events[label]
            result[label] *= class_wgt

        if self._data_config.reweight_basewgt:
            wgts = _build_weights(table, self._data_config, reweight_hists=result)
            wgt_ref = np.percentile(wgts, 100 - self._data_config.reweight_threshold)
            _logger.info('Set overall reweighting scale factor (%d threshold) to %s (max %s)' %
                         (100 - self._data_config.reweight_threshold, wgt_ref, np.max(wgts)))
            for label in self._data_config.reweight_classes:
                result[label] /= wgt_ref

        _logger.info('weights:')
        for label in self._data_config.reweight_classes:
            _logger.info('%s:\n %s', label, str(result[label]))

        _logger.info('Raw hist * weights:')
        for label in self._data_config.reweight_classes:
            _logger.info('%s:\n %s', label, str((raw_hists[label] * result[label]).astype('int32')))

        return result

    def produce(self, output=None):
        table = self.read_file(self._filelist)
        wgts = self.make_weights(table)
        self._data_config.reweight_hists = wgts
        # must also propagate the changes to `data_config.options` so it can be persisted
        self._data_config.options['weights']['reweight_hists'] = {k: v.tolist() for k, v in wgts.items()}
        if output:
            _logger.info('Writing YAML file w/ reweighting info to %s' % output)
            self._data_config.dump(output)
        return self._data_config