gregorkrzmanc's picture
.
e75a247
raw
history blame
5 kB
import numpy as np
import math
import awkward as ak
def _concat(arrays, axis=0):
if len(arrays) == 0:
return np.array([])
if isinstance(arrays[0], np.ndarray):
return np.concatenate(arrays, axis=axis)
else:
return ak.concatenate(arrays, axis=axis)
def _stack(arrays, axis=1):
if len(arrays) == 0:
return np.array([])
if isinstance(arrays[0], np.ndarray):
return np.stack(arrays, axis=axis)
else:
return ak.concatenate(arrays, axis=axis)
def _pad_vector(a, value=-1, dtype="float32"):
maxlen = 2000
maxlen2 = 5
x = (np.ones((len(a), maxlen, maxlen2)) * value).astype(dtype)
for idx, s in enumerate(a):
for idx_vec, s_vec in enumerate(s):
x[idx, idx_vec, : len(s_vec)] = s_vec
return x
def _pad(a, maxlen, value=0, dtype="float32"):
if isinstance(a, np.ndarray) and a.ndim >= 2 and a.shape[1] == maxlen:
return a
elif isinstance(a, ak.Array):
if a.ndim == 1:
a = ak.unflatten(a, 1)
a = ak.fill_none(ak.pad_none(a, maxlen, clip=True), value)
return ak.values_astype(a, dtype)
else:
x = (np.ones((len(a), maxlen)) * value).astype(dtype)
for idx, s in enumerate(a):
if not len(s):
continue
trunc = s[:maxlen].astype(dtype)
x[idx, : len(trunc)] = trunc
return x
def _repeat_pad(a, maxlen, shuffle=False, dtype="float32"):
x = ak.to_numpy(ak.flatten(a))
x = np.tile(x, int(np.ceil(len(a) * maxlen / len(x))))
if shuffle:
np.random.shuffle(x)
x = x[: len(a) * maxlen].reshape((len(a), maxlen))
mask = _pad(ak.zeros_like(a), maxlen, value=1)
x = _pad(a, maxlen) + mask * x
return ak.values_astype(x, dtype)
def _clip(a, a_min, a_max):
try:
return np.clip(a, a_min, a_max)
except ValueError:
return ak.unflatten(np.clip(ak.flatten(a), a_min, a_max), ak.num(a))
def _knn(support, query, k, n_jobs=1):
from scipy.spatial import cKDTree
kdtree = cKDTree(support)
d, idx = kdtree.query(query, k, n_jobs=n_jobs)
return idx
def _batch_knn(supports, queries, k, maxlen_s, maxlen_q=None, n_jobs=1):
assert len(supports) == len(queries)
if maxlen_q is None:
maxlen_q = maxlen_s
batch_knn_idx = np.ones((len(supports), maxlen_q, k), dtype="int32") * (
maxlen_s - 1
)
for i, (s, q) in enumerate(zip(supports, queries)):
batch_knn_idx[i, : len(q[:maxlen_q]), :] = _knn(
s[:maxlen_s], q[:maxlen_q], k, n_jobs=n_jobs
).reshape(
(-1, k)
) # (len(q), k)
return batch_knn_idx
def _batch_permute_indices(array, maxlen):
batch_permute_idx = np.tile(np.arange(maxlen), (len(array), 1))
for i, a in enumerate(array):
batch_permute_idx[i, : len(a)] = np.random.permutation(len(a[:maxlen]))
return batch_permute_idx
def _batch_argsort(array, maxlen):
batch_argsort_idx = np.tile(np.arange(maxlen), (len(array), 1))
for i, a in enumerate(array):
batch_argsort_idx[i, : len(a)] = np.argsort(a[:maxlen])
return batch_argsort_idx
def _batch_gather(array, indices):
out = array.zeros_like()
for i, (a, idx) in enumerate(zip(array, indices)):
maxlen = min(len(a), len(idx))
out[i][:maxlen] = a[idx[:maxlen]]
return out
def _p4_from_pxpypze(px, py, pz, energy):
import vector
vector.register_awkward()
return vector.zip({"px": px, "py": py, "pz": pz, "energy": energy})
def _p4_from_ptetaphie(pt, eta, phi, energy):
import vector
vector.register_awkward()
return vector.zip({"pt": pt, "eta": eta, "phi": phi, "energy": energy})
def _p4_from_ptetaphim(pt, eta, phi, mass):
import vector
vector.register_awkward()
return vector.zip({"pt": pt, "eta": eta, "phi": phi, "mass": mass})
def _get_variable_names(expr, exclude=["awkward", "ak", "np", "numpy", "math"]):
import ast
root = ast.parse(expr)
return sorted(
{
node.id
for node in ast.walk(root)
if isinstance(node, ast.Name) and not node.id.startswith("_")
}
- set(exclude)
)
def _eval_expr(expr, table):
tmp = {k: table[k] for k in _get_variable_names(expr)}
tmp.update(
{
"math": math,
"np": np,
"numpy": np,
"ak": ak,
"awkward": ak,
"_concat": _concat,
"_stack": _stack,
"_pad": _pad,
"_repeat_pad": _repeat_pad,
"_clip": _clip,
"_batch_knn": _batch_knn,
"_batch_permute_indices": _batch_permute_indices,
"_batch_argsort": _batch_argsort,
"_batch_gather": _batch_gather,
"_p4_from_pxpypze": _p4_from_pxpypze,
"_p4_from_ptetaphie": _p4_from_ptetaphie,
"_p4_from_ptetaphim": _p4_from_ptetaphim,
}
)
return eval(expr, tmp)