gregorkrzmanc commited on
Commit
4c79858
·
1 Parent(s): 0ce7f78
Files changed (2) hide show
  1. app.py +2 -2
  2. src/model_wrapper_gradio.py +2 -2
app.py CHANGED
@@ -37,8 +37,8 @@ def gradio_ui():
37
  train_dataset_dropdown = gr.Dropdown(choices=["QCD", "900_03", "900_03+700_07", "700_07", "900_03+700_07+QCD"], label="Training Dataset", value="QCD")
38
 
39
  with gr.Row():
40
- subdataset_dropdown = gr.Dropdown(choices=[x for x in os.listdir("demo_datasets") if not x.startswith(".")], label="Subdataset")
41
- event_idx_dropdown = gr.Dropdown(choices=list(range(20)), label="Event Index")
42
  prefill_btn = gr.Button("Load Event from Dataset")
43
 
44
  particles_text = gr.Textbox(label="Particles CSV (pt eta phi mass charge)", lines=6, interactive=True)
 
37
  train_dataset_dropdown = gr.Dropdown(choices=["QCD", "900_03", "900_03+700_07", "700_07", "900_03+700_07+QCD"], label="Training Dataset", value="QCD")
38
 
39
  with gr.Row():
40
+ subdataset_dropdown = gr.Dropdown(choices=[x for x in os.listdir("demo_datasets") if not x.startswith(".")], label="Subdataset", value="QCD")
41
+ event_idx_dropdown = gr.Dropdown(choices=list(range(20)), label="Event Index", value=15)
42
  prefill_btn = gr.Button("Load Event from Dataset")
43
 
44
  particles_text = gr.Textbox(label="Particles CSV (pt eta phi mass charge)", lines=6, interactive=True)
src/model_wrapper_gradio.py CHANGED
@@ -41,7 +41,7 @@ def inference(loss_str, train_dataset_str, input_text, input_text_quarks):
41
  args.spatial_part_only = True # LGATr
42
  args.load_model_weights = model_path
43
  args.aug_soft = True # LGATr_GP etc.
44
- args.network_config = "src/1models/LGATr/lgatr.py"
45
  args.beta_type = "pt+bc"
46
  args.embed_as_vectors = False
47
  args.debug = False
@@ -167,7 +167,7 @@ def inference(loss_str, train_dataset_str, input_text, input_text_quarks):
167
  for j in range(len(jets_pt)):
168
  if jets_pt[j] >= 30:
169
  ax[1].text(jets_eta[j] + 0.1, jets_phi[j] + 0.1,
170
- "pt=" + str(round(jets_pt[j].item(), 1)), color="gray", fontsize=6, alpha=0.5)
171
  model_jets.append({"pt": jets_pt[j].item(), "eta": jets_eta[j].item(), "phi": jets_phi[j].item()})
172
 
173
  if jets_pt[j] >= 100:
 
41
  args.spatial_part_only = True # LGATr
42
  args.load_model_weights = model_path
43
  args.aug_soft = True # LGATr_GP etc.
44
+ args.network_config = "src/models/LGATr/lgatr.py"
45
  args.beta_type = "pt+bc"
46
  args.embed_as_vectors = False
47
  args.debug = False
 
167
  for j in range(len(jets_pt)):
168
  if jets_pt[j] >= 30:
169
  ax[1].text(jets_eta[j] + 0.1, jets_phi[j] + 0.1,
170
+ "pt=" + str(round(jets_pt[j].item(), 1)), color="blue", fontsize=6, alpha=0.5)
171
  model_jets.append({"pt": jets_pt[j].item(), "eta": jets_eta[j].item(), "phi": jets_phi[j].item()})
172
 
173
  if jets_pt[j] >= 100: