# vae-comparison / app.py
import spaces  # ZeroGPU helper; import before anything that may initialize CUDA
import gradio as gr
import torch
from diffusers import AutoencoderKL
from diffusers.utils.remote_utils import remote_decode
import torchvision.transforms.v2 as transforms
from torchvision.io import read_image
from typing import Dict
import os
import time
from huggingface_hub import login

# Read the Hub token from the Space's "access_token" secret; it is required
# for gated checkpoints such as FLUX.1-dev. Skip login when the secret is unset.
hf_token = os.getenv("access_token")
if hf_token:
    login(token=hf_token)

class PadToSquare:
    """Custom transform that edge-pads an image to square dimensions, centering the original content."""

    def __call__(self, img):
        _, h, w = img.shape  # expects a (C, H, W) tensor
        max_side = max(h, w)
        pad_h = (max_side - h) // 2
        pad_w = (max_side - w) // 2
        # torchvision padding order: (left, top, right, bottom)
        padding = (pad_w, pad_h, max_side - w - pad_w, max_side - h - pad_h)
        return transforms.functional.pad(img, padding, padding_mode="edge")
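
# Illustrative example of PadToSquare (shapes only): a 3x480x640 tensor becomes
# 3x640x640, with 80 rows of edge-replicated padding added above and below:
#   >>> x = torch.zeros(3, 480, 640, dtype=torch.uint8)
#   >>> PadToSquare()(x).shape
#   torch.Size([3, 640, 640])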

class VAETester:
    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu", img_size: int = 512):
        self.device = device
        # Model input: square-pad, resize, scale to [0, 1], then normalize to [-1, 1].
        self.input_transform = transforms.Compose([
            PadToSquare(),
            transforms.Resize((img_size, img_size)),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        # Reference image for the difference maps: same geometry, kept in [0, 1].
        self.base_transform = transforms.Compose([
            PadToSquare(),
            transforms.Resize((img_size, img_size)),
            transforms.ToDtype(torch.float32, scale=True),
        ])
        # Maps decoder output from [-1, 1] back to [0, 1] (see note below).
        self.output_transform = transforms.Normalize(mean=[-1], std=[2])
        self.vae_models = self._load_all_vaes()
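
    # Note on output_transform: Normalize(mean=[-1], std=[2]) computes
    # (x - (-1)) / 2 = (x + 1) / 2, the exact inverse of the input
    # normalization (x - 0.5) / 0.5, so decoder outputs return to [0, 1].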

    def _get_endpoint(self, base_name: str) -> str:
        """Helper method to get the endpoint for a given base model name."""
        endpoints = {
            "sd-vae-ft-mse": "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud",
            "sdxl-vae": "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud",
            "FLUX.1": "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud",
        }
        return endpoints[base_name]

    def _load_all_vaes(self) -> Dict[str, Dict]:
        """Load configurations for local and remote VAE models."""
        local_vaes = {
            "stable-diffusion-v1-4": AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(self.device),
            "eq-vae-ema": AutoencoderKL.from_pretrained("zelaki/eq-vae-ema").to(self.device),
            "eq-sdxl-vae": AutoencoderKL.from_pretrained("KBlueLeaf/EQ-SDXL-VAE").to(self.device),
            "sd-vae-ft-mse": AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(self.device),
            "sdxl-vae": AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(self.device),
            "stable-diffusion-3-medium": AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae").to(self.device),
            "FLUX.1": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae").to(self.device),
            "CogView4-6B": AutoencoderKL.from_pretrained("THUDM/CogView4-6B", subfolder="vae").to(self.device),
            "playground-v2.5": AutoencoderKL.from_pretrained("playgroundai/playground-v2.5-1024px-aesthetic", subfolder="vae").to(self.device),
        }
        # Define the desired display order of the models.
        order = [
            "stable-diffusion-v1-4",
            "eq-vae-ema",
            "eq-sdxl-vae",
            "sd-vae-ft-mse",
            # "sd-vae-ft-mse (remote)",
            "sdxl-vae",
            # "sdxl-vae (remote)",
            "playground-v2.5",
            "stable-diffusion-3-medium",
            "FLUX.1",
            # "FLUX.1 (remote)",
            "CogView4-6B",
        ]
        # Construct the vae_models dictionary in the specified order.
        vae_models = {}
        for name in order:
            if "(remote)" not in name:
                # Local model: encode and decode on this device.
                vae_models[name] = {"type": "local", "vae": local_vaes[name]}
            else:
                # Remote model: encode locally, decode on an inference endpoint.
                base_name = name.replace(" (remote)", "")
                vae_models[name] = {
                    "type": "remote",
                    "local_vae_key": base_name,
                    "endpoint": self._get_endpoint(base_name),
                }
        return vae_models
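
    # All of the above are standard AutoencoderKL checkpoints, so the same
    # encode/decode API applies throughout; the SD-family VAEs produce
    # 4-channel latents while SD3 and FLUX.1 use 16-channel latents. To
    # benchmark a remote decoder, uncomment the matching "(remote)" entry in
    # `order`; encoding still runs locally, so the base VAE must stay loaded.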

    def process_image(self, img: torch.Tensor, model_config: Dict, tolerance: float):
        """Process an image through a single VAE (local or remote) and score the reconstruction."""
        img_transformed = self.input_transform(img).to(self.device).unsqueeze(0)
        original_base = self.base_transform(img).cpu()

        # Time the full encode/decode round trip.
        start_time = time.time()
        if model_config["type"] == "local":
            vae = model_config["vae"]
            with torch.no_grad():
                encoded = vae.encode(img_transformed).latent_dist.sample()
                decoded = vae.decode(encoded).sample
        elif model_config["type"] == "remote":
            # Encode locally with the matching VAE, then decode on the endpoint.
            local_vae = self.vae_models[model_config["local_vae_key"]]["vae"]
            with torch.no_grad():
                encoded = local_vae.encode(img_transformed).latent_dist.sample()
                decoded = remote_decode(
                    endpoint=model_config["endpoint"],
                    tensor=encoded,
                    do_scaling=False,
                    output_type="pt",
                    return_type="pt",
                    partial_postprocess=False,
                )
        processing_time = time.time() - start_time

        # Map the reconstruction back to [0, 1] and compare with the reference.
        decoded_transformed = self.output_transform(decoded.squeeze(0)).cpu()
        reconstructed = decoded_transformed.clip(0, 1)
        diff = (original_base - reconstructed).abs()
        # A pixel is flagged if any channel deviates by more than the tolerance.
        bw_diff = (diff > tolerance).any(dim=0).float()
        diff_image = transforms.ToPILImage()(bw_diff)
        recon_image = transforms.ToPILImage()(reconstructed)
        diff_score = bw_diff.sum().item()
        return diff_image, recon_image, diff_score, processing_time
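
    # Scale of diff_score: it counts flagged pixels, so the maximum is
    # img_size ** 2 (262,144 at 512x512; 1,048,576 at 1024x1024), and 0 means
    # no pixel deviated from the reference by more than the tolerance.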

    def process_all_models(self, img: torch.Tensor, tolerance: float):
        """Process an image through all configured VAEs.

        Returns a dict mapping each model name to the tuple
        (diff_image, recon_image, diff_score, processing_time).
        """
        results = {}
        for name, model_config in self.vae_models.items():
            results[name] = self.process_image(img, model_config, tolerance)
        return results

@spaces.GPU(duration=15)
def test_all_vaes(image_path: str, tolerance: float, img_size: int):
    """Gradio callback: run the selected image through every VAE."""
    tester = VAETester(img_size=img_size)
    try:
        img_tensor = read_image(image_path)
        results = tester.process_all_models(img_tensor, tolerance)
        diff_images = []
        recon_images = []
        scores = []
        for name in tester.vae_models.keys():
            diff_img, recon_img, score, proc_time = results[name]
            diff_images.append((diff_img, name))
            recon_images.append((recon_img, name))
            scores.append(f"{name:<25}: {score:7,.0f} | {proc_time:.4f}s")
        return diff_images, recon_images, "\n".join(scores)
    except Exception as e:
        # Return empty galleries rather than [None], which Gradio cannot render.
        return [], [], f"Error: {str(e)}"
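
# Minimal smoke test (hypothetical example path; assumes a GPU/CPU device and
# that the bundled examples folder exists):
#   >>> diffs, recons, report = test_all_vaes("examples/photo.jpg", 0.1, 512)
#   >>> print(report)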

# Collect bundled example images, if the folder exists.
examples = [f"examples/{img_filename}" for img_filename in sorted(os.listdir("examples"))] if os.path.isdir("examples") else []

with gr.Blocks(title="VAE Performance Tester", css=".monospace-text {font-family: 'Courier New', Courier, monospace;}") as demo:
    gr.Markdown("# VAE Comparison Tool")
    gr.Markdown("""
    Upload an image or select an example to compare how different VAEs reconstruct it.

    1. The image is padded to a square and resized to the selected size (512 or 1024 pixels).
    2. Each VAE encodes the image into its latent space and decodes it back.
    3. Outputs include:
       - **Difference Maps**: where the reconstruction differs from the original (white = difference > tolerance).
       - **Reconstructed Images**: the output of each VAE.
       - **Sum of Differences and Time**: the number of pixels exceeding the tolerance (lower is better) and the processing time in seconds.

    Adjust the tolerance to change the sensitivity.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath", label="Input Image", height=512)
            tolerance_slider = gr.Slider(
                minimum=0.01,
                maximum=0.5,
                value=0.1,
                step=0.01,
                label="Difference Tolerance",
                info="Low (0.01): sensitive to small changes. High (0.5): only large changes are flagged.",
            )
            img_size = gr.Dropdown(label="Image Size", choices=[512, 1024], value=512)
            submit_btn = gr.Button("Test All VAEs")
        with gr.Column(scale=3):
            with gr.Row():
                diff_gallery = gr.Gallery(label="Difference Maps", columns=4, height=512)
                recon_gallery = gr.Gallery(label="Reconstructed Images", columns=4, height=512)
            scores_output = gr.Textbox(label="Sum of differences (lower is better) | Processing time (lower is faster)", lines=10, elem_classes="monospace-text")

    if examples:
        with gr.Row():
            gr.Examples(examples=examples, inputs=image_input, label="Example Images")

    submit_btn.click(
        fn=test_all_vaes,
        inputs=[image_input, tolerance_slider, img_size],
        outputs=[diff_gallery, recon_gallery, scores_output],
    )

if __name__ == "__main__":
    demo.launch(share=False, ssr_mode=False)