# VTBench / app.py
import os

import spaces  # ZeroGPU: import before torch
import gradio as gr
import torch
from huggingface_hub import snapshot_download
from PIL import Image

from src.data_processing import pil_to_tensor, tensor_to_pil
from src.model_processing import get_model

# Run on GPU when available (the @spaces.GPU decorator below requests one on Spaces).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on: {device}")
MODEL_DIR = "./VTBench_models"

# Optional bulk prefetch, left disabled: each model is downloaded on first use instead.
# if not os.path.exists(MODEL_DIR):
#     print("Downloading VTBench_models from Hugging Face...")
#     snapshot_download(
#         repo_id="huaweilin/VTBench_models",
#         local_dir=MODEL_DIR,
#         local_dir_use_symlinks=False,
#     )
#     print("Download complete.")
example_image_paths = [f"assets/app_examples/{i}.png" for i in range(5)]
model_name_mapping = {
    "SD3.5L": "SD3.5L",
    "chameleon": "Chameleon",
    # "flowmo_lo": "FlowMo Lo",
    # "flowmo_hi": "FlowMo Hi",
    # "gpt4o": "GPT-4o",
    "janus_pro_1b": "Janus Pro 1B/7B",
    "llamagen-ds8": "LlamaGen ds8",
    "llamagen-ds16": "LlamaGen ds16",
    "llamagen-ds16-t2i": "LlamaGen ds16 T2I",
    "maskbit_16bit": "MaskBiT 16bit",
    "maskbit_18bit": "MaskBiT 18bit",
    "open_magvit2": "OpenMagViT",
    "titok_b64": "Titok-b64",
    "titok_bl64": "Titok-bl64",
    "titok_s128": "Titok-s128",
    "titok_bl128": "Titok-bl128",
    "titok_l32": "Titok-l32",
    "titok_sl256": "Titok-sl256",
    "var_256": "VAR-256",
    "var_512": "VAR-512",
    "FLUX.1-dev": "FLUX.1-dev",
    "infinity_d32": "Infinity-d32",
    "infinity_d64": "Infinity-d64",
    "bsqvit": "BSQ-VIT",
}
display_to_internal = {v: k for k, v in model_name_mapping.items()}
def load_model(model_name):
    """Load a model and its preprocessing params, move it to `device`, and set eval mode."""
    model, data_params = get_model(MODEL_DIR, model_name)
    model = model.to(device)
    model.eval()
    return model, data_params
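
# Example round trip (a sketch, assuming the "chameleon" weights are available under
# MODEL_DIR; pil_to_tensor/tensor_to_pil are used exactly as in process_selected_models below):
# model, data_params = load_model("chameleon")
# image = Image.open("assets/app_examples/0.png")
# pixels = pil_to_tensor(image, **data_params).unsqueeze(0).to(device)
# with torch.no_grad():
#     output = model(pixels)[0]
# tensor_to_pil(output[0].cpu(), **data_params).save("reconstruction.png")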
# Eager loading of every model is disabled; models are loaded per request in
# process_selected_models instead.
# model_dict = {
#     model_name: load_model(model_name)
#     for model_name in model_name_mapping
# }
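
# A middle ground (a sketch, not used here): cache a bounded number of loaded models
# instead of reloading on every request, trading memory for latency.
# from functools import lru_cache
#
# @lru_cache(maxsize=2)
# def load_model_cached(model_name):
#     return load_model(model_name)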
# Transparent placeholder shown for models that were not run.
placeholder_image = Image.new("RGBA", (512, 512), (0, 0, 0, 0))
@spaces.GPU
def process_selected_models(uploaded_image, selected_display_names):
    if uploaded_image is None:
        return [gr.update(value="⚠️ Please upload an image before processing.", visible=True)] + \
               [gr.update() for _ in model_name_mapping]
    if not selected_display_names:
        return [gr.update(value="⚠️ Please select at least one model.", visible=True)] + \
               [gr.update() for _ in model_name_mapping]

    selected_results = []
    placeholder_results = []
    selected_internal = [display_to_internal[d] for d in selected_display_names]

    for model_name in model_name_mapping:
        label = model_name_mapping[model_name]
        if model_name in selected_internal:
            try:
                # Load the model on demand, reconstruct the image, then free the memory.
                model, data_params = load_model(model_name)
                pixel_values = pil_to_tensor(uploaded_image, **data_params).unsqueeze(0).to(device)
                with torch.no_grad():
                    output = model(pixel_values)[0]
                reconstructed_image = tensor_to_pil(output[0].cpu(), **data_params)
                del model, pixel_values, output
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                result = gr.update(value=reconstructed_image, label=label)
            except Exception as e:
                print(f"Error in model {model_name}: {e}")
                result = gr.update(value=placeholder_image, label=f"{label} (Error)")
            selected_results.append(result)
        else:
            result = gr.update(value=placeholder_image, label=f"{label} (Not selected)")
            placeholder_results.append(result)

    # Selected outputs come first so they fill the top image slots (each update also
    # rewrites the slot's label); unselected slots fall back to the placeholder.
    return [gr.update(visible=False)] + selected_results + placeholder_results
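
# Minimal sketch of calling the handler without the UI (hypothetical, for local
# debugging; assumes @spaces.GPU degrades gracefully when not running on a Space):
# img = Image.open("assets/app_examples/0.png")
# updates = process_selected_models(img, ["Chameleon"])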
with gr.Blocks() as demo:
    gr.Markdown("## VTBench")
    gr.Markdown("---")
    gr.Markdown("👋 **Welcome to VTBench!** Upload an image, select models, and click 'Start Processing' to compare results side by side.")
    gr.Markdown("🔗 **Check out our GitHub repo:** [https://github.com/huawei-lin/VTBench](https://github.com/huawei-lin/VTBench)")

    with gr.Accordion("📘 Full Instructions", open=False):
        gr.Markdown("""
**VTBench User Guide**

- **Upload an image** or click one of the example images.
- **Select one or more models** from the list.
- Click **Start Processing** to run inference.
- Selected model outputs appear first; the others show placeholders.

⚠️ *Each model is downloaded on first use, so the first run of a model may take a while.*
        """)
    image_input = gr.Image(
        type="pil",
        label="Upload an image",
        width=512,
        height=512,
    )

    gr.Markdown("### Click on an example image to use it as input:")
    example_rows = [example_image_paths[i:i + 5] for i in range(0, len(example_image_paths), 5)]
    for row in example_rows:
        with gr.Row():
            for path in row:
                ex_img = gr.Image(
                    value=path,
                    show_label=False,
                    interactive=True,
                    width=256,
                    height=256,
                )

                # Bind `path` via a default argument so each thumbnail loads its own
                # file (a bare closure would capture the loop variable by reference).
                def make_loader(p=path):
                    def load_img():
                        return Image.open(p)
                    return load_img

                ex_img.select(fn=make_loader(), outputs=image_input)
gr.Markdown("---")
gr.Markdown("⚠️ **The more models you select, the longer the processing time will be.**")
gr.Markdown("*Note: Each model is downloaded on first use. Subsequent uses will load from cache and run faster.*")
display_names = list(model_name_mapping.values())
default_selected = ["SD3.5L", "Chameleon", "Janus Pro 1B/7B"]
model_selector = gr.CheckboxGroup(
choices=display_names,
label="Select models to run",
value=default_selected,
interactive=True,
)
status_output = gr.Markdown("", visible=False)
run_button = gr.Button("Start Processing")
    image_outputs = []
    model_names_ordered = list(model_name_mapping.keys())
    n_columns = 5
    output_rows = [model_names_ordered[i:i + n_columns] for i in range(0, len(model_names_ordered), n_columns)]

    with gr.Column():
        for row in output_rows:
            with gr.Row():
                for model_name in row:
                    display_name = model_name_mapping[model_name]
                    out_img = gr.Image(
                        label=f"{display_name} (Not run)",
                        value=placeholder_image,
                        width=512,
                        height=512,
                    )
                    image_outputs.append(out_img)
    run_button.click(
        fn=process_selected_models,
        inputs=[image_input, model_selector],
        outputs=[status_output] + image_outputs,
    )

demo.launch()