import numpy as np
import math
import awkward as ak


def _concat(arrays, axis=0):
    """Concatenate a sequence of arrays along ``axis``.

    An empty sequence yields an empty numpy array. The backend is chosen from
    the type of the first element: ``np.concatenate`` for numpy arrays,
    ``ak.concatenate`` for anything else.
    """
    if len(arrays) == 0:
        return np.array([])
    if isinstance(arrays[0], np.ndarray):
        return np.concatenate(arrays, axis=axis)
    return ak.concatenate(arrays, axis=axis)


def _stack(arrays, axis=1):
    """Combine a sequence of arrays along ``axis``.

    An empty sequence yields an empty numpy array. Numpy inputs are stacked
    with ``np.stack`` (which creates a new axis).

    NOTE(review): non-numpy inputs are joined with ``ak.concatenate`` along
    ``axis``, which does not add a new axis -- confirm callers provide arrays
    that already carry the axis to join on.
    """
    if len(arrays) == 0:
        return np.array([])
    if isinstance(arrays[0], np.ndarray):
        return np.stack(arrays, axis=axis)
    return ak.concatenate(arrays, axis=axis)


def _pad_vector(a, value=-1, dtype="float32", maxlen=2000, maxlen2=5):
    """Pad a doubly nested sequence into a dense 3-D numpy array.

    Args:
        a: outer sequence; each element is a sequence of vectors.
        value: fill value for positions that hold no data.
        dtype: dtype of the returned array.
        maxlen: size of the second output axis (vectors kept per entry).
        maxlen2: size of the third output axis (components kept per vector).

    Returns:
        Array of shape ``(len(a), maxlen, maxlen2)``. Entries with more than
        ``maxlen`` vectors, and vectors with more than ``maxlen2`` components,
        are truncated.
    """
    x = (np.ones((len(a), maxlen, maxlen2)) * value).astype(dtype)
    for idx, s in enumerate(a):
        # Truncate instead of indexing past the end of the output buffer.
        for idx_vec, s_vec in enumerate(s[:maxlen]):
            vec = s_vec[:maxlen2]
            x[idx, idx_vec, : len(vec)] = vec
    return x


def _pad(a, maxlen, value=0, dtype="float32"):
    """Pad or truncate every row of ``a`` to exactly ``maxlen`` entries.

    * A numpy array with at least two dimensions whose second dimension is
      already ``maxlen`` is returned unchanged (``value`` and ``dtype`` are
      not applied).
    * An ``ak.Array`` is padded/clipped with ``ak.pad_none(..., clip=True)``,
      missing entries are replaced by ``value`` and the result is cast to
      ``dtype``. A 1-D awkward array is first turned into rows of length 1.
    * Any other sequence is copied row by row into a new
      ``(len(a), maxlen)`` numpy array pre-filled with ``value``; each
      non-empty row must support slicing and ``.astype``.
    """
    if isinstance(a, np.ndarray) and a.ndim >= 2 and a.shape[1] == maxlen:
        return a
    if isinstance(a, ak.Array):
        if a.ndim == 1:
            a = ak.unflatten(a, 1)
        a = ak.fill_none(ak.pad_none(a, maxlen, clip=True), value)
        return ak.values_astype(a, dtype)

    x = (np.ones((len(a), maxlen)) * value).astype(dtype)
    for idx, s in enumerate(a):
        if not len(s):
            continue
        trunc = s[:maxlen].astype(dtype)
        x[idx, : len(trunc)] = trunc
    return x


def _repeat_pad(a, maxlen, shuffle=False, dtype="float32"):
    """Pad rows of ``a`` to ``maxlen`` using values drawn from ``a`` itself.

    All values of ``a`` are flattened and tiled until there are at least
    ``len(a) * maxlen`` of them (optionally shuffled in place with
    ``np.random.shuffle``), then reshaped to ``(len(a), maxlen)``. Positions
    that hold real data keep their value; padded positions take the value
    from the tiled pool.

    Raises:
        ZeroDivisionError: if ``a`` contains no values at all.
    """
    x = ak.to_numpy(ak.flatten(a))
    x = np.tile(x, int(np.ceil(len(a) * maxlen / len(x))))
    if shuffle:
        np.random.shuffle(x)
    x = x[: len(a) * maxlen].reshape((len(a), maxlen))
    # mask is 0 where ``a`` has data and 1 where padding was inserted.
    mask = _pad(ak.zeros_like(a), maxlen, value=1)
    x = _pad(a, maxlen) + mask * x
    return ak.values_astype(x, dtype)


def _clip(a, a_min, a_max):
    """Clip the values of ``a`` to the range ``[a_min, a_max]``.

    ``np.clip`` is tried directly first. If it raises ``ValueError``, the
    array is flattened, clipped, and restored to its original row lengths
    with ``ak.unflatten``.
    """
    try:
        return np.clip(a, a_min, a_max)
    except ValueError:
        return ak.unflatten(np.clip(ak.flatten(a), a_min, a_max), ak.num(a))


def _knn(support, query, k, n_jobs=1):
    """Return the indices of the ``k`` nearest ``support`` points per query.

    Args:
        support: points used to build the KD-tree.
        query: points to look up.
        k: number of neighbours requested.
        n_jobs: degree of parallelism forwarded to ``cKDTree.query``.

    Returns:
        The index array returned by ``cKDTree.query`` (distances are dropped).
    """
    from scipy.spatial import cKDTree

    kdtree = cKDTree(support)
    try:
        # Current SciPy names this keyword ``workers``; the legacy ``n_jobs``
        # spelling is rejected with a TypeError on recent releases.
        _, idx = kdtree.query(query, k, workers=n_jobs)
    except TypeError:
        # Older SciPy releases only understand the legacy keyword.
        _, idx = kdtree.query(query, k, n_jobs=n_jobs)
    return idx


def _batch_knn(supports, queries, k, maxlen_s, maxlen_q=None, n_jobs=1):
    """Run ``_knn`` for every (support, query) pair of a batch.

    Args:
        supports: per-sample support points; only the first ``maxlen_s`` of
            each sample are used.
        queries: per-sample query points; only the first ``maxlen_q`` of each
            sample are used.
        k: number of neighbours per query point.
        maxlen_s: maximum number of support points per sample.
        maxlen_q: maximum number of query points per sample; defaults to
            ``maxlen_s``.
        n_jobs: forwarded to ``_knn``.

    Returns:
        int32 array of shape ``(len(supports), maxlen_q, k)``. Slots that do
        not correspond to an actual query point keep the fill value
        ``maxlen_s - 1``.
    """
    assert len(supports) == len(queries)
    if maxlen_q is None:
        maxlen_q = maxlen_s
    batch_knn_idx = np.ones((len(supports), maxlen_q, k), dtype="int32") * (
        maxlen_s - 1
    )
    for i, (s, q) in enumerate(zip(supports, queries)):
        batch_knn_idx[i, : len(q[:maxlen_q]), :] = _knn(
            s[:maxlen_s], q[:maxlen_q], k, n_jobs=n_jobs
        ).reshape(
            (-1, k)
        )  # (len(q), k)
    return batch_knn_idx


def _batch_permute_indices(array, maxlen):
    """Build per-row random permutation indices padded to ``maxlen``.

    Every row starts as ``arange(maxlen)``; its leading
    ``min(len(row), maxlen)`` slots are then overwritten with a random
    permutation of that many indices (drawn from ``np.random``).
    """
    batch_permute_idx = np.tile(np.arange(maxlen), (len(array), 1))
    for i, a in enumerate(array):
        batch_permute_idx[i, : len(a)] = np.random.permutation(len(a[:maxlen]))
    return batch_permute_idx


def _batch_argsort(array, maxlen):
    """Build per-row argsort indices padded to ``maxlen``.

    Every row starts as ``arange(maxlen)``; its leading
    ``min(len(row), maxlen)`` slots are then overwritten with the argsort of
    the first ``maxlen`` values of the corresponding input row.
    """
    batch_argsort_idx = np.tile(np.arange(maxlen), (len(array), 1))
    for i, a in enumerate(array):
        batch_argsort_idx[i, : len(a)] = np.argsort(a[:maxlen])
    return batch_argsort_idx


def _batch_gather(array, indices):
    """Reorder each row of ``array`` according to the matching ``indices`` row.

    For every row, the first ``min(len(row), len(idx))`` output slots receive
    ``row[idx[:that_many]]``; remaining slots keep the zero from
    ``array.zeros_like()``.

    NOTE(review): this relies on ``array`` exposing a ``zeros_like()`` method
    and on ``out[i]`` being a writable view into ``out`` -- confirm the array
    type that callers pass actually supports both.
    """
    out = array.zeros_like()
    for i, (a, idx) in enumerate(zip(array, indices)):
        maxlen = min(len(a), len(idx))
        out[i][:maxlen] = a[idx[:maxlen]]
    return out


def _p4_from_pxpypze(px, py, pz, energy):
    """Zip ``px``, ``py``, ``pz`` and ``energy`` into a ``vector`` record."""
    import vector

    vector.register_awkward()
    return vector.zip({"px": px, "py": py, "pz": pz, "energy": energy})


def _p4_from_ptetaphie(pt, eta, phi, energy):
    """Zip ``pt``, ``eta``, ``phi`` and ``energy`` into a ``vector`` record."""
    import vector

    vector.register_awkward()
    return vector.zip({"pt": pt, "eta": eta, "phi": phi, "energy": energy})


def _p4_from_ptetaphim(pt, eta, phi, mass):
    """Zip ``pt``, ``eta``, ``phi`` and ``mass`` into a ``vector`` record."""
    import vector

    vector.register_awkward()
    return vector.zip({"pt": pt, "eta": eta, "phi": phi, "mass": mass})


def _get_variable_names(expr, exclude=("awkward", "ak", "np", "numpy", "math")):
    """Return the sorted free names referenced by the expression ``expr``.

    The expression is parsed with ``ast``; every ``ast.Name`` identifier is
    collected except those starting with an underscore and those listed in
    ``exclude``.

    Args:
        expr: Python source text to parse.
        exclude: iterable of names to drop from the result.

    Raises:
        SyntaxError: if ``expr`` cannot be parsed.
    """
    import ast

    root = ast.parse(expr)
    return sorted(
        {
            node.id
            for node in ast.walk(root)
            if isinstance(node, ast.Name) and not node.id.startswith("_")
        }
        - set(exclude)
    )


def _eval_expr(expr, table):
    """Evaluate ``expr`` with its variables looked up in ``table``.

    Every name reported by ``_get_variable_names(expr)`` is fetched from
    ``table``; ``math``, numpy (``np`` / ``numpy``), awkward (``ak`` /
    ``awkward``) and the underscore-prefixed helpers of this module are added
    to the evaluation namespace on top of them.

    SECURITY NOTE: ``expr`` is passed to ``eval`` and can therefore execute
    arbitrary code. Only evaluate expressions from trusted sources.
    """
    tmp = {k: table[k] for k in _get_variable_names(expr)}
    tmp.update(
        {
            "math": math,
            "np": np,
            "numpy": np,
            "ak": ak,
            "awkward": ak,
            "_concat": _concat,
            "_stack": _stack,
            "_pad": _pad,
            "_repeat_pad": _repeat_pad,
            "_clip": _clip,
            "_batch_knn": _batch_knn,
            "_batch_permute_indices": _batch_permute_indices,
            "_batch_argsort": _batch_argsort,
            "_batch_gather": _batch_gather,
            "_p4_from_pxpypze": _p4_from_pxpypze,
            "_p4_from_ptetaphie": _p4_from_ptetaphie,
            "_p4_from_ptetaphim": _p4_from_ptetaphim,
        }
    )
    return eval(expr, tmp)