Spaces:
Sleeping
Sleeping
Commit
·
4c79858
1
Parent(s):
0ce7f78
- app.py +2 -2
- src/model_wrapper_gradio.py +2 -2
app.py
CHANGED
@@ -37,8 +37,8 @@ def gradio_ui():
|
|
37 |
train_dataset_dropdown = gr.Dropdown(choices=["QCD", "900_03", "900_03+700_07", "700_07", "900_03+700_07+QCD"], label="Training Dataset", value="QCD")
|
38 |
|
39 |
with gr.Row():
|
40 |
-
subdataset_dropdown = gr.Dropdown(choices=[x for x in os.listdir("demo_datasets") if not x.startswith(".")], label="Subdataset")
|
41 |
-
event_idx_dropdown = gr.Dropdown(choices=list(range(20)), label="Event Index")
|
42 |
prefill_btn = gr.Button("Load Event from Dataset")
|
43 |
|
44 |
particles_text = gr.Textbox(label="Particles CSV (pt eta phi mass charge)", lines=6, interactive=True)
|
|
|
37 |
train_dataset_dropdown = gr.Dropdown(choices=["QCD", "900_03", "900_03+700_07", "700_07", "900_03+700_07+QCD"], label="Training Dataset", value="QCD")
|
38 |
|
39 |
with gr.Row():
|
40 |
+
subdataset_dropdown = gr.Dropdown(choices=[x for x in os.listdir("demo_datasets") if not x.startswith(".")], label="Subdataset", value="QCD")
|
41 |
+
event_idx_dropdown = gr.Dropdown(choices=list(range(20)), label="Event Index", value=15)
|
42 |
prefill_btn = gr.Button("Load Event from Dataset")
|
43 |
|
44 |
particles_text = gr.Textbox(label="Particles CSV (pt eta phi mass charge)", lines=6, interactive=True)
|
src/model_wrapper_gradio.py
CHANGED
@@ -41,7 +41,7 @@ def inference(loss_str, train_dataset_str, input_text, input_text_quarks):
|
|
41 |
args.spatial_part_only = True # LGATr
|
42 |
args.load_model_weights = model_path
|
43 |
args.aug_soft = True # LGATr_GP etc.
|
44 |
-
args.network_config = "src/
|
45 |
args.beta_type = "pt+bc"
|
46 |
args.embed_as_vectors = False
|
47 |
args.debug = False
|
@@ -167,7 +167,7 @@ def inference(loss_str, train_dataset_str, input_text, input_text_quarks):
|
|
167 |
for j in range(len(jets_pt)):
|
168 |
if jets_pt[j] >= 30:
|
169 |
ax[1].text(jets_eta[j] + 0.1, jets_phi[j] + 0.1,
|
170 |
-
"pt=" + str(round(jets_pt[j].item(), 1)), color="
|
171 |
model_jets.append({"pt": jets_pt[j].item(), "eta": jets_eta[j].item(), "phi": jets_phi[j].item()})
|
172 |
|
173 |
if jets_pt[j] >= 100:
|
|
|
41 |
args.spatial_part_only = True # LGATr
|
42 |
args.load_model_weights = model_path
|
43 |
args.aug_soft = True # LGATr_GP etc.
|
44 |
+
args.network_config = "src/models/LGATr/lgatr.py"
|
45 |
args.beta_type = "pt+bc"
|
46 |
args.embed_as_vectors = False
|
47 |
args.debug = False
|
|
|
167 |
for j in range(len(jets_pt)):
|
168 |
if jets_pt[j] >= 30:
|
169 |
ax[1].text(jets_eta[j] + 0.1, jets_phi[j] + 0.1,
|
170 |
+
"pt=" + str(round(jets_pt[j].item(), 1)), color="blue", fontsize=6, alpha=0.5)
|
171 |
model_jets.append({"pt": jets_pt[j].item(), "eta": jets_eta[j].item(), "phi": jets_phi[j].item()})
|
172 |
|
173 |
if jets_pt[j] >= 100:
|