File size: 2,561 Bytes
e75a247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08310aa
 
d174ad1
e75a247
 
 
 
 
 
 
 
 
 
 
08310aa
e75a247
 
 
 
08310aa
e75a247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import os

from src.model_wrapper_gradio import inference

# === Dummy file-based prefill function ===
def _read_or_default(path, default):
    """Return the text content of *path*, or *default* if the file cannot be located.

    ``NotADirectoryError`` is treated like a missing file: it is raised when a
    path component (for example a stray file listed as a subdataset) is a
    regular file rather than a directory.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    except (FileNotFoundError, NotADirectoryError):
        return default


def prefill_event(subdataset, event_idx, base_dir="demo_datasets"):
    """Load one demo event's particle and quark text for the UI text boxes.

    Reads ``<base_dir>/<subdataset>/<event_idx>.txt`` (particles) and
    ``<base_dir>/<subdataset>/<event_idx>_quarks.txt`` (quarks). A file that
    cannot be found is replaced by a header-only placeholder, so the user can
    still type the event in by hand.

    Args:
        subdataset: Name of the subdataset directory under *base_dir*, or
            ``None`` when nothing is selected in the UI.
        event_idx: Event index used as the file stem, or ``None`` when nothing
            is selected.
        base_dir: Root directory holding the demo datasets.

    Returns:
        A ``(particles_data, quarks_data)`` tuple of strings.
    """
    particles_default = "pt eta phi mass charge\n"
    quarks_default = "pt eta phi\n"

    # Nothing selected yet: don't go looking for a literal "None/None.txt".
    if subdataset is None or event_idx is None:
        return particles_default, quarks_default

    base_path = os.path.join(base_dir, str(subdataset), str(event_idx))
    particles_data = _read_or_default(f"{base_path}.txt", particles_default)
    quarks_data = _read_or_default(f"{base_path}_quarks.txt", quarks_default)
    return particles_data, quarks_data


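# Uncomment on first run to fetch the model repo into ./models and the demo
# dataset repo into ./demo_datasets from the Hugging Face Hub. The Subdataset
# dropdown below lists the contents of ./demo_datasets.
#from huggingface_hub import snapshot_download
#snapshot_download(repo_id="gregorkrzmanc/jetclustering", local_dir="models/")
#snapshot_download(repo_id="gregorkrzmanc/jetclustering_demo", local_dir="demo_datasets/", repo_type="dataset")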

# === Interface layout ===
def _list_subdatasets(base_dir="demo_datasets"):
    """Return the sorted names of the visible subdirectories of *base_dir*.

    Hidden entries (leading ``.``) and plain files such as a README are skipped,
    because only directories can hold event files. If *base_dir* does not exist
    (e.g. the demo data has not been downloaded), an empty list is returned so
    the UI can still be built and events can be pasted in by hand.
    """
    try:
        entries = os.listdir(base_dir)
    except FileNotFoundError:
        return []
    return sorted(
        name
        for name in entries
        if not name.startswith(".") and os.path.isdir(os.path.join(base_dir, name))
    )


def gradio_ui():
    """Build and return the jet clustering demo as a ``gr.Blocks`` app.

    The layout has model selectors (loss function, training dataset) and an
    event loader. The loader fills the particle/quark text boxes through
    ``prefill_event``. A "Run Jet Clustering" button passes the selections and
    text to ``inference``, whose three outputs are routed, in order, to the
    model-jets JSON, the anti-kt-jets JSON and the plot.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Jet Clustering Demo")

        with gr.Row():
            loss_dropdown = gr.Dropdown(choices=["GP_IRC_SN", "GP_IRC_S", "GP", "base"], label="Loss Function", value="GP_IRC_SN")
            train_dataset_dropdown = gr.Dropdown(choices=["QCD", "900_03", "900_03+700_07", "700_07", "900_03+700_07+QCD"], label="Training Dataset", value="QCD")

        with gr.Row():
            # Only real directories are offered, in a stable (sorted) order.
            subdataset_dropdown = gr.Dropdown(choices=_list_subdatasets("demo_datasets"), label="Subdataset")
            event_idx_dropdown = gr.Dropdown(choices=list(range(50)), label="Event Index")
        prefill_btn = gr.Button("Load Event from Dataset")

        particles_text = gr.Textbox(label="Particles CSV (pt eta phi mass charge)", lines=6, interactive=True)
        quarks_text = gr.Textbox(label="Quarks CSV (pt eta phi) - optional", lines=3, interactive=True)

        process_btn = gr.Button("Run Jet Clustering")

        image_output = gr.Plot(label="Output")
        model_jets_output = gr.JSON(label="Model Jets")
        antikt_jets_output = gr.JSON(label="Anti-kt Jets")

        # Loading an event only fills the editable text boxes; clustering runs separately.
        prefill_btn.click(fn=prefill_event,
                          inputs=[subdataset_dropdown, event_idx_dropdown],
                          outputs=[particles_text, quarks_text])

        process_btn.click(fn=inference,
                          inputs=[loss_dropdown, train_dataset_dropdown, particles_text, quarks_text],
                          outputs=[model_jets_output, antikt_jets_output, image_output])

    return demo


# Build the interface at import time so `demo` remains a module attribute,
# but start the server only when this file is executed as a script, so that
# importing the module does not block on launch().
demo = gradio_ui()

if __name__ == "__main__":
    demo.launch()