import os
import spaces
import gradio as gr
from src.data_processing import pil_to_tensor, tensor_to_pil
from PIL import Image
from src.model_processing import get_model
from huggingface_hub import snapshot_download
import torch

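# Use CUDA when available; fall back to CPU otherwise.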
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on: {device}")

MODEL_DIR = "./VTBench_models"
# Bulk-downloading the whole VTBench_models repo is disabled; as noted in the
# UI text below, each model's weights are fetched on first use instead.
# if not os.path.exists(MODEL_DIR):
#     print("Downloading VTBench_models from Hugging Face...")
#     snapshot_download(
#         repo_id="huaweilin/VTBench_models",
#         local_dir=MODEL_DIR,
#         local_dir_use_symlinks=False
#     )
#     print("Download complete.")

example_image_paths = [f"assets/app_examples/{i}.png" for i in range(5)]

model_name_mapping = {
    "SD3.5L": "SD3.5L",
    "chameleon": "Chameleon",
    # "flowmo_lo": "FlowMo Lo",
    # "flowmo_hi": "FlowMo Hi",
    # "gpt4o": "GPT-4o",
    "janus_pro_1b": "Janus Pro 1B/7B",
    "llamagen-ds8": "LlamaGen ds8",
    "llamagen-ds16": "LlamaGen ds16",
    "llamagen-ds16-t2i": "LlamaGen ds16 T2I",
    "maskbit_16bit": "MaskBiT 16bit",
    "maskbit_18bit": "MaskBiT 18bit",
    "open_magvit2": "OpenMagViT",
    "titok_b64": "Titok-b64",
    "titok_bl64": "Titok-bl64",
    "titok_s128": "Titok-s128",
    "titok_bl128": "Titok-bl128",
    "titok_l32": "Titok-l32",
    "titok_sl256": "Titok-sl256",
    "var_256": "VAR-256",
    "var_512": "VAR-512",
    "FLUX.1-dev": "FLUX.1-dev",
    "infinity_d32": "Infinity-d32",
    "infinity_d64": "Infinity-d64",
    "bsqvit": "BSQ-VIT",
}

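# Reverse lookup: UI display name -> internal model identifier.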
display_to_internal = {v: k for k, v in model_name_mapping.items()}

def load_model(model_name):
    """Load one model plus its preprocessing parameters, ready for inference."""
    model, data_params = get_model(MODEL_DIR, model_name)
    model = model.to(device)
    model.eval()
    return model, data_params
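
# Example usage (hypothetical file names, for illustration only):
#   model, params = load_model("titok_b64")
#   x = pil_to_tensor(Image.open("input.png"), **params).unsqueeze(0).to(device)
#   with torch.no_grad():
#       y = model(x)[0]
#   tensor_to_pil(y[0].cpu(), **params).save("recon.png")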

# Eagerly loading every model at startup is disabled; models are loaded on
# demand inside process_selected_models and freed right after inference.
# model_dict = {
#     model_name: load_model(model_name)
#     for model_name in model_name_mapping
# }

# Transparent 512x512 image shown in tiles for models that were not run.
placeholder_image = Image.new("RGBA", (512, 512), (0, 0, 0, 0))

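# On ZeroGPU Spaces, @spaces.GPU attaches a GPU only while the decorated
# function runs; on a dedicated GPU machine the decorator changes nothing.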
@spaces.GPU
def process_selected_models(uploaded_image, selected_display_names):
    if uploaded_image is None:
        return [gr.update(value="⚠️  Please upload an image before processing.", visible=True)] + \
               [gr.update() for _ in model_name_mapping]

    if not selected_display_names:
        return [gr.update(value="⚠️  Please select at least one model.", visible=True)] + \
               [gr.update() for _ in model_name_mapping]

    selected_results = []
    placeholder_results = []

    selected_internal = [display_to_internal[d] for d in selected_display_names]

    for model_name in model_name_mapping:
        label = model_name_mapping[model_name]

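        # Run inference only for selected models; the rest get placeholders.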
        if model_name in selected_internal:
            try:
                model, data_params = load_model(model_name)
                # Preprocess with the model's own normalization, add a batch
                # dimension, reconstruct, and convert the output back to PIL.
                pixel_values = pil_to_tensor(uploaded_image, **data_params).unsqueeze(0).to(device)
                with torch.no_grad():
                    output = model(pixel_values)[0]
                reconstructed_image = tensor_to_pil(output[0].cpu(), **data_params)

                # Release the model and tensors before the next model loads.
                del model, pixel_values, output
                torch.cuda.empty_cache()

                result = gr.update(value=reconstructed_image, label=label)
            except Exception as e:
                print(f"Error in model {model_name}: {e}")
                result = gr.update(value=placeholder_image, label=f"{label} (Error)")
            selected_results.append(result)
        else:
            result = gr.update(value=placeholder_image, label=f"{label} (Not selected)")
            placeholder_results.append(result)

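    # Selected results fill the leading image slots so they render at the top
    # of the grid; the label updates keep every tile correctly captioned.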
    return [gr.update(visible=False)] + selected_results + placeholder_results


with gr.Blocks() as demo:
    gr.Markdown("## VTBench")
    gr.Markdown("---")

    gr.Markdown("👋 **Welcome to VTBench!** Upload an image, select models, and click 'Start Processing' to compare results side by side.")
    gr.Markdown("🔗 **Check out our GitHub repo:** [https://github.com/huawei-lin/VTBench](https://github.com/huawei-lin/VTBench)")
    with gr.Accordion("📘 Full Instructions", open=False):
        gr.Markdown("""
**VTBench User Guide**

- **Upload an image** or click one of the example images.
- **Select one or more models** from the list.
- Click **Start Processing** to run inference.
- Selected model outputs appear first, others show placeholders.

⚠️ *Each model is downloaded on first use, so expect the first run of a model to take longer; later runs load it from cache.*
""")

    image_input = gr.Image(
        type="pil",
        label="Upload an image",
        width=512,
        height=512,
    )

    gr.Markdown("### Click on an example image to use it as input:")
    example_rows = [example_image_paths[i:i+5] for i in range(0, len(example_image_paths), 5)]
    for row in example_rows:
        with gr.Row():
            for path in row:
                ex_img = gr.Image(
                    value=path,
                    show_label=False,
                    interactive=True,
                    width=256,
                    height=256,
                )

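                # Bind the current path as a default argument so each example
                # image loads its own file (avoids the late-binding closure bug).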
                def make_loader(p=path):
                    def load_img():
                        return Image.open(p)
                    return load_img

                ex_img.select(fn=make_loader(), outputs=image_input)

    gr.Markdown("---")
    gr.Markdown("⚠️ **The more models you select, the longer the processing time will be.**")
    gr.Markdown("*Note: Each model is downloaded on first use. Subsequent uses will load from cache and run faster.*")

    display_names = list(model_name_mapping.values())
    default_selected = ["SD3.5L", "Chameleon", "Janus Pro 1B/7B"]

    model_selector = gr.CheckboxGroup(
        choices=display_names,
        label="Select models to run",
        value=default_selected,
        interactive=True,
    )

    status_output = gr.Markdown("", visible=False)
    run_button = gr.Button("Start Processing")

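    # Build the output grid: one image tile per model, n_columns per row.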
    image_outputs = []
    model_names_ordered = list(model_name_mapping.keys())
    n_columns = 5
    output_rows = [model_names_ordered[i:i+n_columns] for i in range(0, len(model_names_ordered), n_columns)]

    with gr.Column():
        for row in output_rows:
            with gr.Row():
                for model_name in row:
                    display_name = model_name_mapping[model_name]
                    out_img = gr.Image(
                        label=f"{display_name} (Not run)",
                        value=placeholder_image,
                        width=512,
                        height=512,
                    )
                    image_outputs.append(out_img)


    run_button.click(
        fn=process_selected_models,
        inputs=[image_input, model_selector],
        outputs=[status_output] + image_outputs
    )

demo.launch()