# NOTE: removed non-Python extraction artifacts (a "File size" header, a row
# of git-blame commit hashes, and a column of line numbers) that would break
# parsing of this module.
import os
import spaces
import gradio as gr
from src.data_processing import pil_to_tensor, tensor_to_pil
from PIL import Image
from src.model_processing import get_model
from huggingface_hub import snapshot_download
import torch
# Pick the compute device once at import time; every model is moved onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on: {device}")

# Local directory holding (or destined to hold) the VTBench model weights.
MODEL_DIR = "./VTBench_models"

# NOTE(review): the one-time snapshot download is currently disabled; models
# are fetched lazily on first use instead (see load_model / get_model).
# if not os.path.exists(MODEL_DIR):
#     print("Downloading VTBench_models from Hugging Face...")
#     snapshot_download(
#         repo_id="huaweilin/VTBench_models",
#         local_dir=MODEL_DIR,
#         local_dir_use_symlinks=False
#     )
#     print("Download complete.")

# Five bundled demo images surfaced as clickable examples in the UI.
example_image_paths = [f"assets/app_examples/{idx}.png" for idx in range(5)]
# Internal model identifier -> human-readable display name.
# Dict insertion order is significant: it fixes the order of the output image
# grid and of the results returned by process_selected_models.
model_name_mapping = {
    "SD3.5L": "SD3.5L",
    "chameleon": "Chameleon",
    # Disabled entries, kept for reference:
    # "flowmo_lo": "FlowMo Lo",
    # "flowmo_hi": "FlowMo Hi",
    # "gpt4o": "GPT-4o",
    "janus_pro_1b": "Janus Pro 1B/7B",
    "llamagen-ds8": "LlamaGen ds8",
    "llamagen-ds16": "LlamaGen ds16",
    "llamagen-ds16-t2i": "LlamaGen ds16 T2I",
    "maskbit_16bit": "MaskBiT 16bit",
    "maskbit_18bit": "MaskBiT 18bit",
    "open_magvit2": "OpenMagViT",
    "titok_b64": "Titok-b64",
    "titok_bl64": "Titok-bl64",
    "titok_s128": "Titok-s128",
    "titok_bl128": "Titok-bl128",
    "titok_l32": "Titok-l32",
    "titok_sl256": "Titok-sl256",
    "var_256": "VAR-256",
    "var_512": "VAR-512",
    "FLUX.1-dev": "FLUX.1-dev",
    "infinity_d32": "Infinity-d32",
    "infinity_d64": "Infinity-d64",
    "bsqvit": "BSQ-VIT",
}

# Reverse lookup: display name (as shown in the CheckboxGroup) -> internal id.
display_to_internal = dict(
    zip(model_name_mapping.values(), model_name_mapping.keys())
)
def load_model(model_name):
    """Build the tokenizer model named `model_name` from MODEL_DIR, move it to
    the module-level `device`, and return it in eval mode together with its
    preprocessing parameters (consumed by pil_to_tensor / tensor_to_pil)."""
    model, data_params = get_model(MODEL_DIR, model_name)
    return model.to(device).eval(), data_params
# NOTE(review): eagerly preloading every model at startup is disabled; models
# are instead loaded lazily per request inside process_selected_models.
# model_dict = {
#     model_name: load_model(model_name)
#     for model_name in model_name_mapping
# }

# Fully transparent 512x512 canvas used for unselected / errored model slots.
_TRANSPARENT = (0, 0, 0, 0)
placeholder_image = Image.new("RGBA", (512, 512), _TRANSPARENT)
@spaces.GPU
def process_selected_models(uploaded_image, selected_display_names):
    """Run each selected model on the uploaded image and build UI updates.

    Args:
        uploaded_image: PIL image from the Gradio input, or None if the user
            has not uploaded anything yet.
        selected_display_names: display names ticked in the CheckboxGroup.

    Returns:
        A list of gr.update objects: the status-markdown update first, then
        one update per image slot — reconstructions for the selected models
        first, followed by placeholders for the unselected ones (matching the
        "selected outputs appear first" behavior described in the UI guide).
    """
    # Guard clauses: show a visible warning and leave every image untouched.
    if uploaded_image is None:
        return [gr.update(value="⚠️ Please upload an image before processing.", visible=True)] + \
            [gr.update() for _ in model_name_mapping]
    if not selected_display_names:
        return [gr.update(value="⚠️ Please select at least one model.", visible=True)] + \
            [gr.update() for _ in model_name_mapping]

    # Set for O(1) membership tests inside the loop.
    selected_internal = {display_to_internal[d] for d in selected_display_names}
    selected_results = []
    placeholder_results = []

    for model_name, label in model_name_mapping.items():
        if model_name not in selected_internal:
            placeholder_results.append(
                gr.update(value=placeholder_image, label=f"{label} (Not selected)")
            )
            continue

        model = pixel_values = output = None
        try:
            # Models are loaded lazily per request; the first use of a model
            # downloads its weights.
            model, data_params = load_model(model_name)
            pixel_values = pil_to_tensor(uploaded_image, **data_params).unsqueeze(0).to(device)
            with torch.no_grad():
                output = model(pixel_values)[0]
            reconstructed_image = tensor_to_pil(output[0].cpu(), **data_params)
            result = gr.update(value=reconstructed_image, label=label)
        except Exception as e:
            # Best-effort: a failing model yields a placeholder, not a crash.
            print(f"Error in model {model_name}: {e}")
            result = gr.update(value=placeholder_image, label=f"{label} (Error)")
        finally:
            # BUGFIX: the original released the model and tensors only on the
            # success path, so a raising model leaked GPU memory. Always free
            # them here, and only touch the CUDA allocator when CUDA exists.
            del model, pixel_values, output
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        selected_results.append(result)

    return [gr.update(visible=False)] + selected_results + placeholder_results
# ---------------------------------------------------------------------------
# Gradio UI. Component creation order defines the page layout: intro text,
# input image, clickable example gallery, model selector, run button, then a
# 5-wide grid of output image slots — one per entry in model_name_mapping,
# in dict order.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## VTBench")
    gr.Markdown("---")
    gr.Markdown("👋 **Welcome to VTBench!** Upload an image, select models, and click 'Start Processing' to compare results side by side.")
    gr.Markdown("🔗 **Check out our GitHub repo:** [https://github.com/huawei-lin/VTBench](https://github.com/huawei-lin/VTBench)")
    with gr.Accordion("📘 Full Instructions", open=False):
        gr.Markdown("""
**VTBench User Guide**
- **Upload an image** or click one of the example images.
- **Select one or more models** from the list.
- Click **Start Processing** to run inference.
- Selected model outputs appear first, others show placeholders.
⚠️ *Each model is downloaded on first use. Please wait patiently the first time you run a model.*
""")
    # Single image input shared by upload and example-click paths.
    image_input = gr.Image(
        type="pil",
        label="Upload an image",
        width=512,
        height=512,
    )
    gr.Markdown("### Click on an example image to use it as input:")
    # Chunk the example paths into rows of 5 thumbnails.
    example_rows = [example_image_paths[i:i+5] for i in range(0, len(example_image_paths), 5)]
    for row in example_rows:
        with gr.Row():
            for path in row:
                ex_img = gr.Image(
                    value=path,
                    show_label=False,
                    interactive=True,
                    width=256,
                    height=256,
                )
                # `p=path` binds the current path as a default argument,
                # avoiding Python's late-binding-closure pitfall inside
                # this loop (otherwise every thumbnail would load the
                # last example).
                def make_loader(p=path):
                    def load_img():
                        return Image.open(p)
                    return load_img
                # Clicking a thumbnail copies that example into image_input.
                ex_img.select(fn=make_loader(), outputs=image_input)
    gr.Markdown("---")
    gr.Markdown("⚠️ **The more models you select, the longer the processing time will be.**")
    gr.Markdown("*Note: Each model is downloaded on first use. Subsequent uses will load from cache and run faster.*")
    display_names = list(model_name_mapping.values())
    default_selected = ["SD3.5L", "Chameleon", "Janus Pro 1B/7B"]
    model_selector = gr.CheckboxGroup(
        choices=display_names,
        label="Select models to run",
        value=default_selected,
        interactive=True,
    )
    # Hidden until process_selected_models surfaces a warning in it.
    status_output = gr.Markdown("", visible=False)
    run_button = gr.Button("Start Processing")
    # One output slot per model, laid out 5 per row, in mapping order.
    image_outputs = []
    model_names_ordered = list(model_name_mapping.keys())
    n_columns = 5
    output_rows = [model_names_ordered[i:i+n_columns] for i in range(0, len(model_names_ordered), n_columns)]
    with gr.Column():
        for row in output_rows:
            with gr.Row():
                for model_name in row:
                    display_name = model_name_mapping[model_name]
                    out_img = gr.Image(
                        label=f"{display_name} (Not run)",
                        value=placeholder_image,
                        width=512,
                        height=512,
                    )
                    image_outputs.append(out_img)
    # Output order must mirror process_selected_models' return value:
    # status markdown first, then one update per image slot. Selected-model
    # results fill the leading slots (each update also rewrites the label),
    # so "selected outputs appear first" in the grid.
    run_button.click(
        fn=process_selected_models,
        inputs=[image_input, model_selector],
        outputs=[status_output] + image_outputs
    )
# Launch at import time — NOTE(review): presumably intentional for a Hugging
# Face Spaces app (no __main__ guard); confirm before reusing as a library.
demo.launch()
# NOTE: removed a stray "|" extraction artifact that trailed the file.