import spaces
import gradio as gr
import torch
from diffusers import AutoencoderKL
from diffusers.utils.remote_utils import remote_decode
import torchvision.transforms.v2 as transforms
from torchvision.io import read_image
from typing import Dict
import os
import time
from huggingface_hub import login

# Get token from environment variable
hf_token = os.getenv("access_token")
login(token=hf_token)


class PadToSquare:
    """Custom transform to pad an image to square dimensions"""

    def __call__(self, img):
        _, h, w = img.shape  # Get the original dimensions
        max_side = max(h, w)
        pad_h = (max_side - h) // 2
        pad_w = (max_side - w) // 2
        padding = (pad_w, pad_h, max_side - w - pad_w, max_side - h - pad_h)
        return transforms.functional.pad(img, padding, padding_mode="edge")


class VAETester:
    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu", img_size: int = 512):
        self.device = device
        self.input_transform = transforms.Compose([
            PadToSquare(),
            transforms.Resize((img_size, img_size)),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        self.base_transform = transforms.Compose([
            PadToSquare(),
            transforms.Resize((img_size, img_size)),
            transforms.ToDtype(torch.float32, scale=True),
        ])
        # Inverse of the input normalization: maps decoder output from [-1, 1] back to [0, 1]
        self.output_transform = transforms.Normalize(mean=[-1], std=[2])
        self.vae_models = self._load_all_vaes()

    def _get_endpoint(self, base_name: str) -> str:
        """Helper method to get the endpoint for a given base model name"""
        endpoints = {
            "sd-vae-ft-mse": "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud",
            "sdxl-vae": "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud",
            "FLUX.1": "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud",
        }
        return endpoints[base_name]

    def _load_all_vaes(self) -> Dict[str, Dict]:
        """Load configurations for local and remote VAE models"""
        local_vaes = {
            "stable-diffusion-v1-4": AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(self.device),
            "eq-vae-ema": AutoencoderKL.from_pretrained("zelaki/eq-vae-ema").to(self.device),
            "eq-sdxl-vae": AutoencoderKL.from_pretrained("KBlueLeaf/EQ-SDXL-VAE").to(self.device),
            "sd-vae-ft-mse": AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(self.device),
            "sdxl-vae": AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(self.device),
            "stable-diffusion-3-medium": AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae").to(self.device),
            "FLUX.1": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae").to(self.device),
            "CogView4-6B": AutoencoderKL.from_pretrained("THUDM/CogView4-6B", subfolder="vae").to(self.device),
            "playground-v2.5": AutoencoderKL.from_pretrained("playgroundai/playground-v2.5-1024px-aesthetic", subfolder="vae").to(self.device),
        }

        # Define the desired order of models
        order = [
            "stable-diffusion-v1-4",
            "eq-vae-ema",
            "eq-sdxl-vae",
            "sd-vae-ft-mse",
            # "sd-vae-ft-mse (remote)",
            "sdxl-vae",
            # "sdxl-vae (remote)",
            "playground-v2.5",
            "stable-diffusion-3-medium",
            "FLUX.1",
            # "FLUX.1 (remote)",
            "CogView4-6B",
        ]

        # Construct the vae_models dictionary in the specified order
        vae_models = {}
        for name in order:
            if "(remote)" not in name:
                # Local model
                vae_models[name] = {"type": "local", "vae": local_vaes[name]}
            else:
                # Remote model: encode with the local VAE, decode via the inference endpoint
                base_name = name.replace(" (remote)", "")
                vae_models[name] = {
                    "type": "remote",
                    "local_vae_key": base_name,
                    "endpoint": self._get_endpoint(base_name),
                }
        return vae_models

    def process_image(self, img: torch.Tensor,
                      model_config: Dict, tolerance: float):
        """Process image through a single VAE (local or remote)"""
        img_transformed = self.input_transform(img).to(self.device).unsqueeze(0)
        original_base = self.base_transform(img).cpu()

        # Start timer
        start_time = time.time()

        if model_config["type"] == "local":
            vae = model_config["vae"]
            with torch.no_grad():
                encoded = vae.encode(img_transformed).latent_dist.sample()
                decoded = vae.decode(encoded).sample
        elif model_config["type"] == "remote":
            local_vae = self.vae_models[model_config["local_vae_key"]]["vae"]
            with torch.no_grad():
                encoded = local_vae.encode(img_transformed).latent_dist.sample()
                decoded = remote_decode(
                    endpoint=model_config["endpoint"],
                    tensor=encoded,
                    do_scaling=False,
                    output_type="pt",
                    return_type="pt",
                    partial_postprocess=False,
                )

        # End timer
        processing_time = time.time() - start_time

        decoded_transformed = self.output_transform(decoded.squeeze(0)).cpu()
        reconstructed = decoded_transformed.clip(0, 1)
        diff = (original_base - reconstructed).abs()
        # Binary map of pixels whose error exceeds the tolerance in any channel
        bw_diff = (diff > tolerance).any(dim=0).float()
        diff_image = transforms.ToPILImage()(bw_diff)
        recon_image = transforms.ToPILImage()(reconstructed)
        diff_score = bw_diff.sum().item()
        return diff_image, recon_image, diff_score, processing_time

    def process_all_models(self, img: torch.Tensor, tolerance: float):
        """Process image through all configured VAEs"""
        results = {}
        for name, model_config in self.vae_models.items():
            results[name] = self.process_image(img, model_config, tolerance)
        return results


@spaces.GPU(duration=15)
def test_all_vaes(image_path: str, tolerance: float, img_size: int):
    """Gradio interface function to test all VAEs"""
    tester = VAETester(img_size=img_size)
    try:
        img_tensor = read_image(image_path)
        results = tester.process_all_models(img_tensor, tolerance)

        diff_images = []
        recon_images = []
        scores = []
        for name in tester.vae_models.keys():
            diff_img, recon_img, score, proc_time = results[name]
            diff_images.append((diff_img, name))
            recon_images.append((recon_img, name))
            scores.append(f"{name:<25}: {score:7,.0f} | {proc_time:.4f}s")

        return diff_images, recon_images, "\n".join(scores)
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        return [None], [None], error_msg


examples = [f"examples/{img_filename}" for img_filename in sorted(os.listdir("examples/"))]

with gr.Blocks(title="VAE Performance Tester", css=".monospace-text {font-family: 'Courier New', Courier, monospace;}") as demo:
    gr.Markdown("# VAE Comparison Tool")
    gr.Markdown("""
    Upload an image or select an example to compare how different VAEs reconstruct it.
    1. The image is padded to a square and resized to the selected size (512 or 1024 pixels).
    2. Each VAE encodes the image into a latent space and decodes it back.
    3. Outputs include:
       - **Difference Maps**: Where reconstruction differs from the original (white = difference > tolerance).
       - **Reconstructed Images**: Outputs from each VAE.
       - **Sum of Differences and Time**: Total pixels exceeding tolerance (lower is better) and processing time in seconds.

    Adjust tolerance to change sensitivity.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath", label="Input Image", height=512)
            tolerance_slider = gr.Slider(
                minimum=0.01,
                maximum=0.5,
                value=0.1,
                step=0.01,
                label="Difference Tolerance",
                info="Low (0.01): Sensitive to small changes. High (0.5): Only large changes flagged.",
            )
            img_size = gr.Dropdown(label="Image Size", choices=[512, 1024], value=512)
            submit_btn = gr.Button("Test All VAEs")

        with gr.Column(scale=3):
            with gr.Row():
                diff_gallery = gr.Gallery(label="Difference Maps", columns=4, height=512)
                recon_gallery = gr.Gallery(label="Reconstructed Images", columns=4, height=512)
            scores_output = gr.Textbox(
                label="Sum of differences (lower is better) | Processing time (lower is faster)",
                lines=10,
                elem_classes="monospace-text",
            )

    if examples:
        with gr.Row():
            gr.Examples(examples=examples, inputs=image_input, label="Example Images")

    submit_btn.click(
        fn=test_all_vaes,
        inputs=[image_input, tolerance_slider, img_size],
        outputs=[diff_gallery, recon_gallery, scores_output]
    )

if __name__ == "__main__":
    demo.launch(share=False, ssr_mode=False)