import spaces
import gradio as gr
import torch
from diffusers import AutoencoderKL, AutoencoderDC, AutoModel
import torchvision.transforms.v2 as transforms
from torchvision.io import read_image, ImageReadMode
from typing import Dict
import os
import time
from huggingface_hub import login

# Get token from environment variable
hf_token = os.getenv("access_token")
login(token=hf_token)


class PadToSquare:
    """Custom transform to pad an image to square dimensions"""

    def __call__(self, img):
        _, h, w = img.shape  # Get the original dimensions
        max_side = max(h, w)
        pad_h = (max_side - h) // 2
        pad_w = (max_side - w) // 2
        padding = (pad_w, pad_h, max_side - w - pad_w, max_side - h - pad_h)
        return transforms.functional.pad(img, padding, padding_mode="edge")


class VAETester:
    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu", img_size: int = 512):
        self.device = device
        # Input pipeline: pad to square, resize, scale to [0, 1], then normalize to [-1, 1]
        self.input_transform = transforms.Compose([
            PadToSquare(),
            transforms.Resize((img_size, img_size)),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        # Reference pipeline: same geometry, but kept in [0, 1] for comparison
        self.base_transform = transforms.Compose([
            PadToSquare(),
            transforms.Resize((img_size, img_size)),
            transforms.ToDtype(torch.float32, scale=True),
        ])
        # Maps decoder output from [-1, 1] back to [0, 1]: (x + 1) / 2
        self.output_transform = transforms.Normalize(mean=[-1], std=[2])
        self.vae_models = self._load_all_vaes()

    def _load_all_vaes(self) -> Dict[str, Dict]:
        """Load configurations for all VAE models"""
        vaes = {
            "stable-diffusion-v1-4": AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(self.device),
            "eq-vae-ema": AutoencoderKL.from_pretrained("zelaki/eq-vae-ema").to(self.device),
            "eq-sdxl-vae": AutoencoderKL.from_pretrained("KBlueLeaf/EQ-SDXL-VAE").to(self.device),
            "sd-vae-ft-mse": AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(self.device),
            "sdxl-vae": AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(self.device),
            "stable-diffusion-3-medium": AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae").to(self.device),
            "FLUX.1": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae").to(self.device),
            "CogView4-6B": AutoencoderKL.from_pretrained("THUDM/CogView4-6B", subfolder="vae").to(self.device),
            "playground-v2.5": AutoencoderKL.from_pretrained("playgroundai/playground-v2.5-1024px-aesthetic", subfolder="vae").to(self.device),
            # "dc-ae-f32c32-sana-1.0": AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers").to(self.device),
            "FLUX.1-Kontext": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", subfolder="vae").to(self.device),
            "FLUX.2": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.2-dev", subfolder="vae").to(self.device),
            "FLUX.2-TinyAutoEncoder": AutoModel.from_pretrained("fal/FLUX.2-Tiny-AutoEncoder", trust_remote_code=True, torch_dtype=torch.bfloat16).to(self.device),
        }

        # Define the desired order of models
        order = [
            "stable-diffusion-v1-4",
            "eq-vae-ema",
            "eq-sdxl-vae",
            "sd-vae-ft-mse",
            "sdxl-vae",
            "playground-v2.5",
            "stable-diffusion-3-medium",
            "FLUX.1",
            "CogView4-6B",
            # "dc-ae-f32c32-sana-1.0",
            "FLUX.1-Kontext",
            "FLUX.2",
            "FLUX.2-TinyAutoEncoder",
        ]

        # Construct the vae_models dictionary in the specified order
        return {
            name: {"vae": vaes[name], "dtype": torch.bfloat16 if name == "FLUX.2-TinyAutoEncoder" else torch.float32}
            for name in order
        }

    def process_image(self, img: torch.Tensor, model_config: Dict, tolerance: float, vae_name: str):
        """Process image through a single VAE model"""
        dtype = model_config["dtype"]
        vae = model_config["vae"]
model_config["vae"] img_transformed = self.input_transform(img).to(dtype).to(self.device).unsqueeze(0) original_base = self.base_transform(img).cpu() # Time the encoding-decoding process start_time = time.time() with torch.no_grad(): if vae_name == "FLUX.2-TinyAutoEncoder": encoded = vae.encode(img_transformed, return_dict=False) decoded = vae.decode(encoded, return_dict=False) else: encoded = vae.encode(img_transformed).latent_dist.sample() decoded = vae.decode(encoded).sample processing_time = time.time() - start_time decoded_transformed = self.output_transform(decoded.squeeze(0).to(torch.float32)).cpu() reconstructed = decoded_transformed.clip(0, 1) diff = (original_base - reconstructed).abs() bw_diff = (diff > tolerance).any(dim=0).float() diff_image = transforms.ToPILImage()(bw_diff) recon_image = transforms.ToPILImage()(reconstructed) diff_score = bw_diff.sum().item() return diff_image, recon_image, diff_score, processing_time def process_all_models(self, img: torch.Tensor, tolerance: float): """Process image through all configured VAEs""" results = {} for vae_name, model_config in self.vae_models.items(): results[vae_name] = self.process_image(img, model_config, tolerance, vae_name) return results @spaces.GPU(duration=20) def test_all_vaes(image_path: str, tolerance: float, img_size: int): """Gradio interface function to test all VAEs""" tester = VAETester(img_size=img_size) try: img_tensor = read_image(image_path) results = tester.process_all_models(img_tensor, tolerance) diff_images = [] recon_images = [] scores = [] for name in tester.vae_models.keys(): diff_img, recon_img, score, proc_time = results[name] diff_images.append((diff_img, name)) recon_images.append((recon_img, name)) scores.append(f"{name:<25}: {score:7,.0f} | {proc_time:.4f}s") return diff_images, recon_images, "\n".join(scores) except Exception as e: error_msg = f"Error: {str(e)}" return [None], [None], error_msg examples = [f"examples/{img_filename}" for img_filename in sorted(os.listdir("examples/"))] custom_css = """ .center-header { display: flex; align-items: center; justify-content: center; margin: 0 0 10px 0; } .monospace-text { font-family: 'Courier New', Courier, monospace; } """ with gr.Blocks(title="VAE Performance Tester", css=custom_css) as demo: gr.Markdown("

    gr.Markdown('<div class="center-header"><h1>VAE Comparison Tool</h1></div>')
") gr.Markdown(""" Upload an image or select an example to compare how different VAEs reconstruct it. 1. The image is padded to a square and resized to the selected size (512 or 1024 pixels). 2. Each VAE encodes the image into a latent space and decodes it back. 3. Outputs include: - **Difference Maps**: Where reconstruction differs from the original (white = difference > tolerance). - **Reconstructed Images**: Outputs from each VAE. - **Sum of Differences and Time**: Total pixels exceeding tolerance (lower is better) and processing time in seconds. Adjust tolerance to change sensitivity. """) with gr.Row(): with gr.Column(scale=1): image_input = gr.Image(type="filepath", label="Input Image", height=512) tolerance_slider = gr.Slider( minimum=0.01, maximum=0.5, value=0.1, step=0.01, label="Difference Tolerance", info="Low (0.01): Sensitive to small changes. High (0.5): Only large changes flagged." ) img_size = gr.Dropdown(label="Image Size", choices=[512, 1024], value=512) submit_btn = gr.Button("Test All VAEs") with gr.Column(scale=3): with gr.Row(): diff_gallery = gr.Gallery(label="Difference Maps", columns=4, height=512) recon_gallery = gr.Gallery(label="Reconstructed Images", columns=4, height=512) scores_output = gr.Textbox(label="Sum of differences (lower is better) | Processing time (lower is faster)", lines=12, elem_classes="monospace-text") if examples: with gr.Row(): gr.Examples(examples=examples, inputs=image_input, label="Example Images") submit_btn.click( fn=test_all_vaes, inputs=[image_input, tolerance_slider, img_size], outputs=[diff_gallery, recon_gallery, scores_output] ) if __name__ == "__main__": demo.launch(share=False, ssr_mode=False)