🚀 ESRGAN 4x + FLUX Enhancement
Upload an image to upscale 4x with ESRGAN and enhance with FLUX
Optimized for ZeroGPU | Max input: 512x512 → Max output: 2048x2048 (4x)
import os import random import warnings import gc import gradio as gr import numpy as np import spaces import torch import torch.nn as nn from diffusers import FluxImg2ImgPipeline from gradio_imageslider import ImageSlider from PIL import Image from huggingface_hub import snapshot_download import requests # Minimal ESRGAN implementation (without basicsr dependency) class ResidualDenseBlock(nn.Module): def __init__(self, num_feat=64, num_grow_ch=32): super(ResidualDenseBlock, self).__init__() self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x): x1 = self.lrelu(self.conv1(x)) x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) return x5 * 0.2 + x class RRDB(nn.Module): def __init__(self, num_feat, num_grow_ch=32): super(RRDB, self).__init__() self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) def forward(self, x): out = self.rdb1(x) out = self.rdb2(out) out = self.rdb3(out) return out * 0.2 + x class RRDBNet(nn.Module): def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4): super(RRDBNet, self).__init__() self.scale = scale self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) self.body = nn.Sequential(*[RRDB(num_feat, num_grow_ch) for _ in range(num_block)]) self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) # Upsampling self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x): fea = self.conv_first(x) trunk = self.conv_body(self.body(fea)) fea = fea + trunk fea = self.lrelu(self.conv_up1(nn.functional.interpolate(fea, scale_factor=2, mode='nearest'))) fea = self.lrelu(self.conv_up2(nn.functional.interpolate(fea, scale_factor=2, mode='nearest'))) out = self.conv_last(self.lrelu(self.conv_hr(fea))) return out css = """ #col-container { margin: 0 auto; max-width: 800px; } .main-header { text-align: center; margin-bottom: 2rem; } """ # Get HuggingFace token huggingface_token = os.getenv("HF_TOKEN") # Download FLUX model if not already cached print("📥 Downloading FLUX model...") model_path = snapshot_download( repo_id="black-forest-labs/FLUX.1-dev", repo_type="model", ignore_patterns=["*.md", "*.gitattributes"], local_dir="FLUX.1-dev", token=huggingface_token, ) # Load FLUX pipeline on CPU initially print("📥 Loading FLUX Img2Img pipeline...") pipe = FluxImg2ImgPipeline.from_pretrained( model_path, torch_dtype=torch.bfloat16, use_safetensors=True ) # Enable memory optimizations pipe.enable_vae_tiling() pipe.enable_vae_slicing() pipe.vae.enable_tiling() pipe.vae.enable_slicing() # Download and load ESRGAN 4x-UltraSharp model print("📥 Loading ESRGAN 4x-UltraSharp...") esrgan_path = "4x-UltraSharp.pth" if not os.path.exists(esrgan_path): print("Downloading ESRGAN model...") url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth" response = requests.get(url) with open(esrgan_path, "wb") as f: f.write(response.content) # Initialize ESRGAN model esrgan_model = RRDBNet( num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4 ) # Load state dict state_dict = torch.load(esrgan_path, map_location='cpu') if 'params_ema' in state_dict: state_dict = 
state_dict['params_ema'] elif 'params' in state_dict: state_dict = state_dict['params'] # Clean state dict keys if needed cleaned_state_dict = {} for k, v in state_dict.items(): if k.startswith('module.'): cleaned_state_dict[k[7:]] = v else: cleaned_state_dict[k] = v esrgan_model.load_state_dict(cleaned_state_dict, strict=False) esrgan_model.eval() print("✅ All models loaded successfully!") MAX_SEED = 1000000 MAX_INPUT_SIZE = 512 # Max input size before upscaling def make_multiple_16(n): """Round to nearest multiple of 16 for FLUX requirements""" return ((n + 15) // 16) * 16 def truncate_prompt(prompt, max_tokens=75): """Truncate prompt to avoid CLIP token limit (77 tokens)""" if not prompt: return "" # Simple truncation by character count (rough approximation) if len(prompt) > 250: # ~75 tokens return prompt[:250] + "..." return prompt def prepare_image(image, max_size=MAX_INPUT_SIZE): """Prepare image for processing""" w, h = image.size # Limit input size if w > max_size or h > max_size: image.thumbnail((max_size, max_size), Image.LANCZOS) return image def esrgan_upscale(image, model, device='cuda'): """Upscale image 4x using ESRGAN""" # Prepare image img_np = np.array(image).astype(np.float32) / 255. 
img_np = np.transpose(img_np, (2, 0, 1)) # HWC to CHW img_tensor = torch.from_numpy(img_np).unsqueeze(0).to(device) # Upscale with torch.no_grad(): output = model(img_tensor) output = output.squeeze(0).cpu().clamp(0, 1) output_np = output.numpy() output_np = np.transpose(output_np, (1, 2, 0)) # CHW to HWC output_np = (output_np * 255).astype(np.uint8) return Image.fromarray(output_np) @spaces.GPU(duration=60) # 60 seconds should be enough def enhance_image( input_image, prompt, seed, randomize_seed, num_inference_steps, denoising_strength, progress=gr.Progress(track_tqdm=True), ): """Main enhancement function""" if input_image is None: raise gr.Error("Please upload an image") # Clear memory torch.cuda.empty_cache() gc.collect() try: # Randomize seed if needed if randomize_seed: seed = random.randint(0, MAX_SEED) # Prepare and validate prompt prompt = truncate_prompt(prompt.strip() if prompt else "high quality, detailed") # Prepare input image input_image = prepare_image(input_image) original_size = input_image.size # Step 1: ESRGAN upscale (4x) on GPU gr.Info("🔍 Upscaling with ESRGAN 4x...") # Move ESRGAN to GPU for faster processing esrgan_model.to("cuda") upscaled_image = esrgan_upscale(input_image, esrgan_model, device="cuda") # Move ESRGAN back to CPU to free memory esrgan_model.to("cpu") torch.cuda.empty_cache() # Ensure dimensions are multiples of 16 for FLUX w, h = upscaled_image.size new_w = make_multiple_16(w) new_h = make_multiple_16(h) if new_w != w or new_h != h: # Pad image to meet requirements padded = Image.new('RGB', (new_w, new_h)) padded.paste(upscaled_image, (0, 0)) upscaled_image = padded # Step 2: FLUX enhancement gr.Info("🎨 Enhancing with FLUX...") # Move pipeline to GPU pipe.to("cuda") # Generate with FLUX generator = torch.Generator(device="cuda").manual_seed(seed) with torch.inference_mode(): result = pipe( prompt=prompt, image=upscaled_image, strength=denoising_strength, num_inference_steps=num_inference_steps, guidance_scale=1.0, # Fixed 
at 1.0 for FLUX dev height=new_h, width=new_w, generator=generator, ).images[0] # Crop back if we padded if new_w != w or new_h != h: result = result.crop((0, 0, w, h)) # Move pipeline back to CPU pipe.to("cpu") torch.cuda.empty_cache() gc.collect() # Prepare images for slider (before/after) input_resized = input_image.resize(result.size, Image.LANCZOS) gr.Info("✅ Enhancement complete!") return [input_resized, result], seed except Exception as e: # Cleanup on error pipe.to("cpu") esrgan_model.to("cpu") torch.cuda.empty_cache() gc.collect() raise gr.Error(f"Enhancement failed: {str(e)}") # Create Gradio interface with gr.Blocks(css=css) as demo: gr.HTML("""
Upload an image to upscale 4x with ESRGAN and enhance with FLUX
Optimized for ZeroGPU | Max input: 512x512 → Output: 2048x2048
📌 Pipeline: ESRGAN 4x-UltraSharp → FLUX Dev Enhancement
⚡ Optimized for ZeroGPU with automatic memory management