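"""Gradio demo for VTBench: run an uploaded image through a selection of
visual tokenizer / reconstruction models and compare the outputs side by side.

Model checkpoints are downloaded from the `huaweilin/VTBench_models` repository
on the Hugging Face Hub on first launch.
"""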
import os
import spaces
import subprocess
import sys
# REQUIREMENTS_FILE = "requirements.txt"
# if os.path.exists(REQUIREMENTS_FILE):
#     try:
#         print("Installing dependencies from requirements.txt...")
#         subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_FILE])
#         print("Dependencies installed successfully.")
#     except subprocess.CalledProcessError as e:
#         print(f"Failed to install dependencies: {e}")
# else:
#     print("requirements.txt not found.")
import gradio as gr
from src.data_processing import pil_to_tensor, tensor_to_pil
from PIL import Image
from src.model_processing import get_model
from huggingface_hub import snapshot_download
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on: {device}")
MODEL_DIR = "./VTBench_models"
if not os.path.exists(MODEL_DIR):
print("Downloading VTBench_models from Hugging Face...")
snapshot_download(
repo_id="huaweilin/VTBench_models",
local_dir=MODEL_DIR,
local_dir_use_symlinks=False
)
print("Download complete.")
example_image_paths = [f"assets/app_examples/{i}.png" for i in range(0, 5)]
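
# Internal model identifiers mapped to the display names shown in the UI.
# Commented-out entries are models not currently exposed in this demo.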
model_name_mapping = {
"SD3.5L": "SD3.5L",
"chameleon": "Chameleon",
# "flowmo_lo": "FlowMo Lo",
# "flowmo_hi": "FlowMo Hi",
# "gpt4o": "GPT-4o",
"janus_pro_1b": "Janus Pro 1B/7B",
"llamagen-ds8": "LlamaGen ds8",
"llamagen-ds16": "LlamaGen ds16",
"llamagen-ds16-t2i": "LlamaGen ds16 T2I",
"maskbit_16bit": "MaskBiT 16bit",
"maskbit_18bit": "MaskBiT 18bit",
"open_magvit2": "OpenMagViT",
"titok_b64": "Titok-b64",
"titok_bl64": "Titok-bl64",
"titok_s128": "Titok-s128",
"titok_bl128": "Titok-bl128",
"titok_l32": "Titok-l32",
"titok_sl256": "Titok-sl256",
"var_256": "VAR-256",
"var_512": "VAR-512",
"FLUX.1-dev": "FLUX.1-dev",
"infinity_d32": "Infinity-d32",
"infinity_d64": "Infinity-d64",
"bsqvit": "BSQ-VIT",
}
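
# Load a model plus its preprocessing parameters, move it to the target
# device, and put it in eval mode for inference.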
def load_model(model_name):
    model, data_params = get_model(MODEL_DIR, model_name)
    model = model.to(device)
    model.eval()
    return model, data_params
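
# Instantiate every model eagerly at startup so requests skip load latency;
# note this keeps all checkpoints resident in memory at once.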
model_dict = {
    model_name: load_model(model_name)
    for model_name in model_name_mapping
}
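
# Fully transparent placeholder shown in output slots that have no result.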
placeholder_image = Image.new("RGBA", (512, 512), (0, 0, 0, 0))
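
# On ZeroGPU Spaces, `spaces.GPU` allocates a GPU for the duration of the call.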
@spaces.GPU
def process_selected_models(uploaded_image, selected_models):
    """Run the uploaded image through each selected model and return one
    gr.update per output slot (placeholders for unselected or failed models)."""
    results = []
    for model_name in model_name_mapping:
        if uploaded_image is None:
            results.append(gr.update(value=placeholder_image, label=f"{model_name_mapping[model_name]} (No input)"))
        elif model_name in selected_models:
            try:
                model, data_params = model_dict[model_name]
                pixel_values = pil_to_tensor(uploaded_image, **data_params).unsqueeze(0).to(device)
                output = model(pixel_values)[0]
                reconstructed_image = tensor_to_pil(output[0].cpu(), **data_params)
                results.append(gr.update(value=reconstructed_image, label=model_name_mapping[model_name]))
            except Exception as e:
                print(f"Error in model {model_name}: {e}")
                results.append(gr.update(value=placeholder_image, label=f"{model_name_mapping[model_name]} (Error)"))
        else:
            results.append(gr.update(value=placeholder_image, label=f"{model_name_mapping[model_name]} (Not selected)"))
    return results

with gr.Blocks() as demo:
gr.Markdown("## VTBench")
gr.Markdown("---")
image_input = gr.Image(
type="pil",
label="Upload an image",
width=512,
height=512,
)
gr.Markdown("### Click on an example image to use it as input:")
example_rows = [example_image_paths[i:i+5] for i in range(0, len(example_image_paths), 5)]
for row in example_rows:
with gr.Row():
for path in row:
ex_img = gr.Image(
value=path,
show_label=False,
interactive=True,
width=256,
height=256,
)
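
                # Bind the current `path` via a default argument so each example
                # image loads its own file (avoids the late-binding closure pitfall).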
                def make_loader(p=path):
                    def load_img():
                        return Image.open(p)
                    return load_img

                ex_img.select(fn=make_loader(), outputs=image_input)
gr.Markdown("---")
gr.Markdown("⚠️ **The more models you select, the longer the processing time will be.**")
model_selector = gr.CheckboxGroup(
choices=list(model_name_mapping.keys()),
label="Select models to run",
value=["SD3.5L", "chameleon", "janus_pro_1b"],
interactive=True,
)
    run_button = gr.Button("Start Processing")

    image_outputs = []
    model_items = list(model_name_mapping.items())
    n_columns = 5
    output_rows = [model_items[i:i+n_columns] for i in range(0, len(model_items), n_columns)]
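
    # Lay out one output image per model in a grid, n_columns per row.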
    with gr.Column():
        for row in output_rows:
            with gr.Row():
                for model_name, display_name in row:
                    out_img = gr.Image(
                        label=f"{display_name} (Not run)",
                        value=placeholder_image,
                        width=512,
                        height=512,
                    )
                    image_outputs.append(out_img)
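
    # `outputs` must stay in the same order as the gr.update list returned by
    # process_selected_models (both iterate model_name_mapping in order).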
    run_button.click(
        fn=process_selected_models,
        inputs=[image_input, model_selector],
        outputs=image_outputs,
    )

demo.launch()