# A simple wrapper to run the L-GATr model on HuggingFace spaces
import shutil
import glob
import argparse
import functools
import numpy as np
import math
import torch
import sys
import os
import wandb
import time
from pathlib import Path
from src.layers.object_cond import calc_eta_phi
torch.autograd.set_detect_anomaly(True)
from src.dataset.functions_data import get_batch
from src.dataset.functions_data import concat_events, Event, EventPFCands
from src.plotting.plot_event import plot_event
from src.dataset.dataset import EventDataset
from src.jetfinder.clustering import get_clustering_labels
from torch_scatter import scatter_sum

from src.utils.train_utils import (
    to_filelist,
    train_load,
    test_load,
    get_model,
    get_optimizer_and_scheduler,
    get_model_obj_score
)
from src.utils.paths import get_path
import warnings
import pickle

import fastjet

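def inference(loss_str, train_dataset_str, input_text, input_text_quarks):
    """Cluster one event with the L-GATr model and with anti-kt, and plot both.

    loss_str and train_dataset_str select the checkpoint
    models/{loss_str}/{train_dataset_str}.ckpt. input_text holds one particle
    per line as whitespace-separated "pt eta phi mass charge"; input_text_quarks
    holds one quark per line as "pt eta phi".

    Returns (model_jets, ak_jets, fig): two lists of dicts with keys "pt",
    "eta" and "phi", and the matplotlib figure.
    """
    # A plain namespace is enough here: the options are set by hand, not parsed.
    args = argparse.Namespace()
    model_path = f"models/{loss_str}/{train_dataset_str}.ckpt"
    args.spatial_part_only = True # LGATr
    args.load_model_weights = model_path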
    args.aug_soft = True # LGATr_GP etc.
    args.network_config = "src/1models/LGATr/lgatr.py"
    args.beta_type = "pt+bc"
    args.embed_as_vectors = False
    args.debug = False
    args.epsilon = 0.3
    args.gen_level = False
    args.parton_level = False
    args.global_features_obj_score = False
    args.gt_radius = 0.8
    args.no_pid = True
    args.hidden_mv_channels = 16
    args.hidden_s_channels = 64
    args.internal_dim = 128
    args.lorentz_norm = False
    args.min_cluster_size = 2
    args.min_samples = 1
    args.n_heads = 4
    args.num_blocks = 10
    args.scalars_oc = False

    dev = torch.device("cpu")
    model = get_model(args, dev)
    orig_model = model
    batch_config = {"use_p_xyz": True, "use_four_momenta": False}

    if "lgatr" in args.network_config.lower():
        batch_config = {"use_four_momenta": True}
    batch_config["no_pid"] = True

    print("batch_config:", batch_config)
    model.eval()

    # Each line of input_text holds whitespace-separated values: pt eta phi mass charge
    pt, eta, phi, mass, charge = [], [], [], [], []
    for line in input_text.strip().split('\n'):
        values = list(map(float, line.split()))
        pt.append(values[0])
        eta.append(values[1])
        phi.append(values[2])
        mass.append(values[3])
        charge.append(int(values[4]))
    # Each line of input_text_quarks holds whitespace-separated values: pt eta phi
    pt_quarks, eta_quarks, phi_quarks = [], [], []
    for line in input_text_quarks.strip().split("\n"):
        values = list(map(float, line.split()))
        pt_quarks.append(values[0])
        eta_quarks.append(values[1])
        phi_quarks.append(values[2])
    pid = torch.zeros(len(pt))
    pf_cand_jet_idx = [-1] * len(pt)

    pfcands = EventPFCands(pt, eta, phi, mass, charge, pid, pf_cand_jet_idx=pf_cand_jet_idx)
    n_soft = 0
    if "GP" in loss_str:
        n_soft = 500
    if n_soft > 0:
        pfcands = EventDataset.pfcands_add_soft_particles(pfcands, n_soft, random_generator=np.random.RandomState(seed=0))
    event = Event(pfcands=pfcands)
    event_batch = concat_events([event])
    batch, _ = get_batch(event_batch, batch_config, torch.zeros(len(pfcands)), test=True)

    with torch.no_grad():
        # Only use cpu_demo with a batch size of 1: it is a quick fix for
        # xformers attention being unavailable on CPU.
        coords = model(batch, cpu_demo=True)[:, 1:4]
    clust_labels = get_clustering_labels(
        coords.detach().cpu().numpy(),
        batch.batch_idx,
        min_cluster_size=args.min_cluster_size,
        min_samples=args.min_samples,
        epsilon=args.epsilon,
    )
    # Shift labels by 1 so that noise (-1) is summed into row 0, then drop that row.
    jets_pxyz = scatter_sum(torch.tensor(pfcands.pxyz), torch.tensor(clust_labels+1), dim=0)[1:]
    jets_pt = torch.norm(jets_pxyz[:, :2], p=2, dim=-1)
    filt = torch.where(jets_pt > 30)[0].tolist()
    jets_eta, jets_phi = calc_eta_phi(jets_pxyz, False)
    clust_assignment = {}
    for i in range(len(clust_labels)):
        if clust_labels[i] in filt and clust_labels[i] != -1:
            clust_assignment[i] = filt.index(clust_labels[i])
    jets_pt = jets_pt[filt]
    jets_eta = jets_eta[filt]
    jets_phi = jets_phi[filt]
    ak_pt, ak_eta, ak_phi, _, ak_assignment = EventDataset.get_jets_fastjets_raw_with_assignment(
        pfcands, fastjet.JetDefinition(fastjet.antikt_algorithm, 0.8), pt_cutoff=30
    )
    model_coords = calc_eta_phi(coords, return_stacked=0)
    clist = ['#1f78b4', '#b3df8a', '#33a02c', '#fb9a99', '#e31a1c', '#fdbe6f', '#ff7f00', '#cab2d6', '#6a3d9a', '#ffff99',
             '#b15928']
    # Noise (-1) is gray; clusters 0-7 take the first eight palette colors.
    colors = {-1: "gray", **{i: clist[i] for i in range(8)}}
    c = []
    c_ak = []
    for i in range(len(pfcands)):
        if i in ak_assignment:
            c_ak.append(colors.get(ak_assignment[i], "purple"))
        else:
            c_ak.append("gray")
        if i in clust_assignment:
            c.append(colors.get(clust_assignment[i], "gray"))
        else:
            c.append("gray")
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(1, 3, figsize=(10, 3.33)) # with AK colors, with model colors, with model colors in clustering space
    ax[0].set_title("Colors: AK clusters")
    ax[1].set_title("Colors: Model clusters")
    ax[2].set_title("Colors: Model clusters in cl. space")
    plot_event(event, colors=c_ak, ax=ax[0], jets=0)
    plot_event(event, colors=c, ax=ax[1], jets=0)
    plot_event(event, colors=c, ax=ax[2], custom_coords=model_coords, jets=0)
    model_jets, ak_jets = [], []

    for j in range(len(ak_pt)):
        if ak_pt[j] >= 30:
            ax[0].text(ak_eta[j] + 0.1, ak_phi[j] + 0.1,
                       "pt=" + str(round(ak_pt[j], 1)), color="blue", fontsize=6, alpha=0.5)
            ak_jets.append({"pt": ak_pt[j], "eta": ak_eta[j], "phi": ak_phi[j]})
        if ak_pt[j] >= 100:
            for k in range(3):
                circle = plt.Circle((ak_eta[j], ak_phi[j]), 0.8, color="green", fill=False, alpha=.7)
                ax[k].add_artist(circle)


    for j in range(len(jets_pt)):
        if jets_pt[j] >= 30:
            ax[1].text(jets_eta[j] + 0.1, jets_phi[j] + 0.1,
                       "pt=" + str(round(jets_pt[j].item(), 1)), color="gray", fontsize=6, alpha=0.5)
            model_jets.append({"pt": jets_pt[j].item(), "eta": jets_eta[j].item(), "phi": jets_phi[j].item()})

        if jets_pt[j] >= 100:
            for k in range(3):
                circle = plt.Circle((jets_eta[j], jets_phi[j]), 0.7, color="blue", fill=False, alpha=.7)
                ax[k].add_artist(circle)
    for k in range(3):
        # Mark the quarks with red triangles, with marker area set by their pt.
        ax[k].scatter(eta_quarks, phi_quarks, s=pt_quarks, c="red", marker="^", alpha=0.3)
        # Raw strings keep the LaTeX backslashes from being read as escape sequences.
        ax[k].set_xlabel(r"$\eta$")
        ax[k].set_ylabel(r"$\phi$")
    fig.tight_layout()
    return model_jets, ak_jets, fig
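

# --- Illustrative sketch, not part of the original wrapper ---
# The parsing loops in `inference` read one particle per line with
# whitespace-separated columns: "pt eta phi mass charge" for input_text and
# "pt eta phi" for input_text_quarks. The helper below is hypothetical (its
# name and signature are assumptions) and only shows one way such text could
# be assembled from Python sequences.
def format_particle_lines(rows):
    """Join the numbers of each row with spaces, and the rows with newlines."""
    return "\n".join(" ".join(str(value) for value in row) for row in rows)


# Possible usage with made-up values (the checkpoint names are left to the caller):
# particles = format_particle_lines([(50.0, 0.1, -1.2, 0.14, 1), (35.5, -0.4, 2.0, 0.0, 0)])
# quarks = format_particle_lines([(120.0, 0.0, -1.0)])
# model_jets, ak_jets, fig = inference(loss_str, train_dataset_str, particles, quarks)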