🚀 ESRGAN 4x + FLUX Enhancement
Upload an image to upscale 4x with ESRGAN and enhance with FLUX
Optimized for ZeroGPU | Max input: 512x512 → Output: 2048x2048
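The 4x factor comes from the ESRGAN model itself, so a 512x512 upload comes back at 2048x2048; larger uploads are first downscaled to fit, and FLUX additionally requires dimensions that are multiples of 16, which the app handles by padding before enhancement and cropping back afterwards. A minimal sketch of that size arithmetic, using an illustrative 500x375 upload and only the logic implemented in the app code below:

w, h = 500, 375                    # already within the 512x512 input limit, so no downscale
up_w, up_h = w * 4, h * 4          # ESRGAN 4x -> 2000 x 1500
pad_w = ((up_w + 15) // 16) * 16   # FLUX needs multiples of 16 -> 2000 (unchanged)
pad_h = ((up_h + 15) // 16) * 16   # -> 1504; padded for FLUX, cropped back to 1500 at the end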
import os
import random
import warnings
import gc

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxImg2ImgPipeline
from gradio_imageslider import ImageSlider
from PIL import Image
from huggingface_hub import snapshot_download
import requests

# ESRGAN imports
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils import img2tensor, tensor2img

css = """
#col-container {
    margin: 0 auto;
    max-width: 800px;
}
.main-header {
    text-align: center;
    margin-bottom: 2rem;
}
"""

# Get HuggingFace token
huggingface_token = os.getenv("HF_TOKEN")

# Download FLUX model if not already cached
print("📥 Downloading FLUX model...")
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)

# Load FLUX pipeline on CPU initially
print("📥 Loading FLUX Img2Img pipeline...")
pipe = FluxImg2ImgPipeline.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, use_safetensors=True
)

# Enable memory optimizations
pipe.enable_vae_tiling()
pipe.enable_vae_slicing()
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()

# Download and load ESRGAN 4x-UltraSharp model
print("📥 Loading ESRGAN 4x-UltraSharp...")
esrgan_path = "4x-UltraSharp.pth"
if not os.path.exists(esrgan_path):
    print("Downloading ESRGAN model...")
    url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
    response = requests.get(url)
    with open(esrgan_path, "wb") as f:
        f.write(response.content)

# Initialize ESRGAN model
esrgan_model = RRDBNet(
    num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4
)
state_dict = torch.load(esrgan_path, map_location='cpu')
if 'params_ema' in state_dict:
    state_dict = state_dict['params_ema']
elif 'params' in state_dict:
    state_dict = state_dict['params']
esrgan_model.load_state_dict(state_dict)
esrgan_model.eval()

print("✅ All models loaded successfully!")

MAX_SEED = 1000000
MAX_INPUT_SIZE = 512  # Max input size before upscaling


def make_multiple_16(n):
    """Round up to the nearest multiple of 16 for FLUX requirements"""
    return ((n + 15) // 16) * 16


def truncate_prompt(prompt, max_tokens=75):
    """Truncate prompt to avoid CLIP token limit (77 tokens)"""
    if not prompt:
        return ""
    # Simple truncation by character count (rough approximation)
    if len(prompt) > 250:  # ~75 tokens
        return prompt[:250] + "..."
    return prompt


def prepare_image(image, max_size=MAX_INPUT_SIZE):
    """Prepare image for processing"""
    w, h = image.size
    # Limit input size
    if w > max_size or h > max_size:
        image.thumbnail((max_size, max_size), Image.LANCZOS)
    return image


def esrgan_upscale(image):
    """Upscale image 4x using ESRGAN"""
    # Convert PIL to tensor
    img_np = np.array(image).astype(np.float32) / 255.
    img_tensor = img2tensor(img_np, bgr2rgb=False, float32=True)

    # Upscale
    with torch.no_grad():
        output = esrgan_model(img_tensor.unsqueeze(0).cpu())

    # Convert back to PIL
    output_np = tensor2img(output.squeeze(0), rgb2bgr=False, min_max=(0, 1))
    return Image.fromarray(output_np)


@spaces.GPU(duration=60)  # 60 seconds should be enough
def enhance_image(
    input_image,
    prompt,
    seed,
    randomize_seed,
    num_inference_steps,
    denoising_strength,
    progress=gr.Progress(track_tqdm=True),
):
    """Main enhancement function"""
    if input_image is None:
        raise gr.Error("Please upload an image")

    # Clear memory
    torch.cuda.empty_cache()
    gc.collect()

    try:
        # Randomize seed if needed
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)

        # Prepare and validate prompt
        prompt = truncate_prompt(prompt.strip() if prompt else "high quality, detailed")

        # Prepare input image
        input_image = prepare_image(input_image)
        original_size = input_image.size

        # Step 1: ESRGAN upscale (4x)
        gr.Info("🔍 Upscaling with ESRGAN 4x...")
        with torch.no_grad():
            # Move ESRGAN to GPU for faster processing
            esrgan_model.to("cuda")

            # Convert image for ESRGAN
            img_np = np.array(input_image).astype(np.float32) / 255.
            img_tensor = img2tensor(img_np, bgr2rgb=False, float32=True)
            img_tensor = img_tensor.unsqueeze(0).to("cuda")

            # Upscale
            output_tensor = esrgan_model(img_tensor)

            # Convert back to PIL
            output_np = tensor2img(output_tensor.squeeze(0).cpu(), rgb2bgr=False, min_max=(0, 1))
            upscaled_image = Image.fromarray(output_np)

            # Move ESRGAN back to CPU to free memory
            esrgan_model.to("cpu")
            torch.cuda.empty_cache()

        # Ensure dimensions are multiples of 16 for FLUX
        w, h = upscaled_image.size
        new_w = make_multiple_16(w)
        new_h = make_multiple_16(h)

        if new_w != w or new_h != h:
            # Pad image to meet requirements
            padded = Image.new('RGB', (new_w, new_h))
            padded.paste(upscaled_image, (0, 0))
            upscaled_image = padded

        # Step 2: FLUX enhancement
        gr.Info("🎨 Enhancing with FLUX...")

        # Move pipeline to GPU
        pipe.to("cuda")

        # Generate with FLUX
        generator = torch.Generator(device="cuda").manual_seed(seed)

        with torch.inference_mode():
            result = pipe(
                prompt=prompt,
                image=upscaled_image,
                strength=denoising_strength,
                num_inference_steps=num_inference_steps,
                guidance_scale=1.0,  # Fixed at 1.0 for FLUX dev
                height=new_h,
                width=new_w,
                generator=generator,
            ).images[0]

        # Crop back if we padded
        if new_w != w or new_h != h:
            result = result.crop((0, 0, w, h))

        # Move pipeline back to CPU
        pipe.to("cpu")
        torch.cuda.empty_cache()
        gc.collect()

        # Prepare images for slider (before/after)
        input_resized = input_image.resize(result.size, Image.LANCZOS)

        gr.Info("✅ Enhancement complete!")
        return [input_resized, result], seed

    except Exception as e:
        # Cleanup on error
        pipe.to("cpu")
        esrgan_model.to("cpu")
        torch.cuda.empty_cache()
        gc.collect()
        raise gr.Error(f"Enhancement failed: {str(e)}")


# Create Gradio interface
with gr.Blocks(css=css) as demo:
    gr.HTML("""
        <div class="main-header">
            <p>Upload an image to upscale 4x with ESRGAN and enhance with FLUX</p>
            <p>Optimized for ZeroGPU | Max input: 512x512 → Output: 2048x2048</p>
            <p>📌 Pipeline: ESRGAN 4x-UltraSharp → FLUX Dev Enhancement</p>
            <p>⚡ Optimized for ZeroGPU with automatic memory management</p>
        </div>
    """)
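The file breaks off inside the gr.Blocks header above, so the actual control layout is not part of the source. Purely as a hypothetical sketch, the remaining UI could be wired to enhance_image along these lines, continuing inside the with gr.Blocks(css=css) as demo: context; the component names, slider ranges, and default values are assumptions, and only the input/output ordering follows enhance_image's signature and return values.

    # --- Hypothetical continuation (assumed layout, not the original) ---
    with gr.Column(elem_id="col-container"):
        input_image = gr.Image(type="pil", label="Input Image (max 512x512)")
        prompt = gr.Textbox(label="Enhancement prompt", value="high quality, detailed")
        with gr.Row():
            seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
            randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
        num_inference_steps = gr.Slider(4, 50, value=20, step=1, label="Inference steps")
        denoising_strength = gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="Denoising strength")
        run_button = gr.Button("🚀 Upscale & Enhance")
        result = ImageSlider(label="Before / After", type="pil")

    # enhance_image returns ([before, after], seed): the image pair feeds the
    # slider and the seed component is updated with the seed actually used.
    run_button.click(
        fn=enhance_image,
        inputs=[input_image, prompt, seed, randomize_seed,
                num_inference_steps, denoising_strength],
        outputs=[result, seed],
    )

demo.launch()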