"""Gradio app: upscale an image 4x with ESRGAN, then refine it with FLUX img2img.

Models are downloaded and loaded at import time; GPU work happens inside
``enhance_image``, which moves each model to CUDA only while it is needed.
"""

import os
import random
import warnings
import gc

import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn as nn
from diffusers import FluxImg2ImgPipeline
from gradio_imageslider import ImageSlider
from PIL import Image
from huggingface_hub import snapshot_download
import requests


# Minimal ESRGAN implementation (without basicsr dependency)
class ResidualDenseBlock(nn.Module):
    """Five-convolution dense block with a scaled residual connection.

    Each convolution receives the block input concatenated with the outputs
    of all previous convolutions; the last output is scaled by 0.2 and added
    back to the input.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(nn.Module):
    """Three chained ``ResidualDenseBlock``s with an outer scaled residual."""

    def __init__(self, num_feat, num_grow_ch=32):
        super().__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        return out * 0.2 + x


class RRDBNet(nn.Module):
    """RRDB trunk followed by two 2x nearest-neighbour upsampling stages.

    ``scale`` is stored on the instance but ``forward`` always performs two
    2x interpolations (4x total), regardless of its value.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23,
                 num_grow_ch=32, scale=4):
        super().__init__()
        self.scale = scale

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = nn.Sequential(
            *[RRDB(num_feat, num_grow_ch) for _ in range(num_block)]
        )
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        # Upsampling
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.conv_body(self.body(fea))
        fea = fea + trunk

        fea = self.lrelu(self.conv_up1(
            nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
        fea = self.lrelu(self.conv_up2(
            nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(fea)))
        return out


# Split across adjacent literals only to keep the line short; the resulting
# string value is unchanged.
css = (
    " #col-container { margin: 0 auto; max-width: 800px; }"
    " .main-header { text-align: center; margin-bottom: 2rem; } "
)

# Get HuggingFace token
huggingface_token = os.getenv("HF_TOKEN")

# Download FLUX model if not already cached
print("📥 Downloading FLUX model...")
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)

# Load FLUX pipeline on CPU initially
print("📥 Loading FLUX Img2Img pipeline...")
pipe = FluxImg2ImgPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    use_safetensors=True
)

# Enable memory optimizations
pipe.enable_vae_tiling()
pipe.enable_vae_slicing()
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()


def _download_file(url, dest_path, timeout=60, chunk_size=1 << 20):
    """Stream ``url`` to ``dest_path``, failing loudly on HTTP errors.

    The payload is written to ``dest_path + ".part"`` and renamed only after
    the download completes, so an interrupted or failed download never leaves
    a file at ``dest_path`` that a later ``os.path.exists`` check would trust.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    tmp_path = dest_path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(tmp_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=chunk_size):
                f.write(chunk)
    os.replace(tmp_path, dest_path)


# Download and load ESRGAN 4x-UltraSharp model
print("📥 Loading ESRGAN 4x-UltraSharp...")
esrgan_path = "4x-UltraSharp.pth"
if not os.path.exists(esrgan_path):
    print("Downloading ESRGAN model...")
    url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
    _download_file(url, esrgan_path)

# Initialize ESRGAN model
esrgan_model = RRDBNet(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_block=23,
    num_grow_ch=32,
    scale=4
)

# Load state dict. The checkpoint is a downloaded pickle, so restrict
# unpickling to tensors/containers instead of arbitrary objects.
state_dict = torch.load(esrgan_path, map_location='cpu', weights_only=True)
if 'params_ema' in state_dict:
    state_dict = state_dict['params_ema']
elif 'params' in state_dict:
    state_dict = state_dict['params']

# Clean state dict keys if needed (drop a leading "module." prefix)
cleaned_state_dict = {
    (k[7:] if k.startswith('module.') else k): v
    for k, v in state_dict.items()
}

# strict=False skips mismatched keys silently, which would leave the affected
# layers at their random initialisation. Surface that instead of hiding it.
# NOTE(review): if the checkpoint uses a different key naming scheme than
# RRDBNet above, every key will be reported here — confirm on first run.
_load_result = esrgan_model.load_state_dict(cleaned_state_dict, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "ESRGAN checkpoint does not fully match RRDBNet: "
        f"{len(_load_result.missing_keys)} missing and "
        f"{len(_load_result.unexpected_keys)} unexpected keys; "
        "unmatched layers keep their random initialisation."
    )
esrgan_model.eval()

print("✅ All models loaded successfully!")

MAX_SEED = 1000000
MAX_INPUT_SIZE = 512  # Max input size before upscaling


def make_multiple_16(n):
    """Round ``n`` up to the next multiple of 16 (unchanged if already one)."""
    return ((n + 15) // 16) * 16


def truncate_prompt(prompt, max_tokens=75):
    """Shorten ``prompt`` to roughly ``max_tokens`` tokens.

    Tokens are approximated by character count (10 characters per 3 tokens,
    i.e. 250 characters at the default of 75). A prompt over the limit is cut
    at the limit and "..." is appended. Falsy prompts yield "".
    """
    if not prompt:
        return ""

    # Simple truncation by character count (rough approximation)
    max_chars = max_tokens * 10 // 3
    if len(prompt) > max_chars:
        return prompt[:max_chars] + "..."
    return prompt


def prepare_image(image, max_size=MAX_INPUT_SIZE):
    """Return an RGB copy of ``image`` that fits within ``max_size`` pixels.

    The ESRGAN network is built with three input channels, so other modes
    (RGBA, L, P, ...) are converted to RGB. ``convert`` returns a new image,
    so the caller's image is never resized in place. Aspect ratio is kept.
    """
    image = image.convert("RGB")
    w, h = image.size

    # Limit input size
    if w > max_size or h > max_size:
        image.thumbnail((max_size, max_size), Image.LANCZOS)

    return image


def esrgan_upscale(image, model, device='cuda'):
    """Run ``model`` on a PIL image and return the result as a PIL image.

    Args:
        image: PIL image; converted to RGB before being fed to the model.
        model: module called on a (1, 3, H, W) float32 tensor scaled to [0, 1].
        device: device the input tensor is moved to; ``model`` must already
            be on that device.
    """
    # Prepare image
    img_np = np.array(image.convert("RGB")).astype(np.float32) / 255.
    img_np = np.transpose(img_np, (2, 0, 1))  # HWC to CHW
    img_tensor = torch.from_numpy(img_np).unsqueeze(0).to(device)

    # Upscale
    with torch.no_grad():
        output = model(img_tensor)

    output = output.squeeze(0).cpu().clamp(0, 1)
    output_np = output.numpy()
    output_np = np.transpose(output_np, (1, 2, 0))  # CHW to HWC
    # Round to nearest instead of truncating to avoid a systematic darkening.
    output_np = np.rint(output_np * 255).astype(np.uint8)

    return Image.fromarray(output_np)


@spaces.GPU(duration=60)  # 60 seconds should be enough
def enhance_image(
    input_image,
    prompt,
    seed,
    randomize_seed,
    num_inference_steps,
    denoising_strength,
    progress=gr.Progress(track_tqdm=True),
):
    """Upscale ``input_image`` with ESRGAN, then refine it with FLUX img2img.

    Args:
        input_image: PIL image from the UI, or None if nothing was uploaded.
        prompt: text prompt; empty or whitespace-only falls back to a default.
        seed: seed for the FLUX generator (ignored if ``randomize_seed``).
        randomize_seed: draw a fresh seed in ``[0, MAX_SEED]`` when true.
        num_inference_steps: passed to the FLUX pipeline.
        denoising_strength: passed to the FLUX pipeline as ``strength``.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        ``([before, after], seed)`` where ``before`` is the prepared input
        resized to the output size and ``seed`` is the seed actually used.

    Raises:
        gr.Error: if no image was supplied or any processing step fails.
    """
    if input_image is None:
        raise gr.Error("Please upload an image")

    # Clear memory
    torch.cuda.empty_cache()
    gc.collect()

    try:
        # Randomize seed if needed
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        # UI sliders may deliver floats; the generator and pipeline need ints.
        seed = int(seed)
        num_inference_steps = int(num_inference_steps)

        # Prepare and validate prompt (whitespace-only counts as empty)
        prompt = truncate_prompt(
            (prompt or "").strip() or "high quality, detailed"
        )

        # Prepare input image
        input_image = prepare_image(input_image)

        # Step 1: ESRGAN upscale (4x) on GPU
        gr.Info("🔍 Upscaling with ESRGAN 4x...")

        # Move ESRGAN to GPU for faster processing
        esrgan_model.to("cuda")
        upscaled_image = esrgan_upscale(input_image, esrgan_model, device="cuda")

        # Move ESRGAN back to CPU to free memory
        esrgan_model.to("cpu")
        torch.cuda.empty_cache()

        # Ensure dimensions are multiples of 16 for FLUX
        w, h = upscaled_image.size
        new_w = make_multiple_16(w)
        new_h = make_multiple_16(h)
        needs_padding = new_w != w or new_h != h

        if needs_padding:
            # Pad image to meet requirements
            padded = Image.new('RGB', (new_w, new_h))
            padded.paste(upscaled_image, (0, 0))
            upscaled_image = padded

        # Step 2: FLUX enhancement
        gr.Info("🎨 Enhancing with FLUX...")

        # Move pipeline to GPU
        pipe.to("cuda")

        # Generate with FLUX
        generator = torch.Generator(device="cuda").manual_seed(seed)

        with torch.inference_mode():
            result = pipe(
                prompt=prompt,
                image=upscaled_image,
                strength=denoising_strength,
                num_inference_steps=num_inference_steps,
                guidance_scale=1.0,  # Fixed at 1.0 for FLUX dev
                height=new_h,
                width=new_w,
                generator=generator,
            ).images[0]

        # Crop back if we padded
        if needs_padding:
            result = result.crop((0, 0, w, h))

        # Prepare images for slider (before/after)
        input_resized = input_image.resize(result.size, Image.LANCZOS)

        gr.Info("✅ Enhancement complete!")
        return [input_resized, result], seed

    except Exception as e:
        raise gr.Error(f"Enhancement failed: {e}") from e
    finally:
        # Always return both models to CPU and release GPU memory, whether
        # the request succeeded or failed.
        pipe.to("cpu")
        esrgan_model.to("cpu")
        torch.cuda.empty_cache()
        gc.collect()


# Create Gradio interface
with gr.Blocks(css=css) as demo:
    gr.HTML("""

🚀 ESRGAN 4x + FLUX Enhancement

Upload an image to upscale 4x with ESRGAN and enhance with FLUX

Optimized for ZeroGPU | Max input: 512x512 → Output: 2048x2048

""")

    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                height=256
            )

            prompt = gr.Textbox(
                label="Enhancement Prompt",
                placeholder="Describe the desired enhancement (e.g., 'high quality, sharp details, vibrant colors')",
                value="high quality, ultra detailed, sharp",
                lines=2
            )

            with gr.Accordion("Advanced Settings", open=False):
                num_inference_steps = gr.Slider(
                    label="Enhancement Steps",
                    minimum=10,
                    maximum=25,
                    step=1,
                    value=18,
                    info="More steps = better quality but slower"
                )

                denoising_strength = gr.Slider(
                    label="Enhancement Strength",
                    minimum=0.1,
                    maximum=0.6,
                    step=0.05,
                    value=0.35,
                    info="Higher = more changes to the image"
                )

                with gr.Row():
                    randomize_seed = gr.Checkbox(
                        label="Randomize seed",
                        value=True
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=42
                    )

            enhance_btn = gr.Button(
                "🎨 Enhance Image (4x Upscale)",
                variant="primary",
                size="lg"
            )

        with gr.Column(scale=2):
            # Output section
            result_slider = ImageSlider(
                type="pil",
                label="Before / After",
                interactive=False,
                height=512
            )

            used_seed = gr.Number(
                label="Seed Used",
                interactive=False,
                visible=False
            )

    # Examples
    gr.Examples(
        examples=[
            ["high quality, ultra detailed, sharp"],
            ["cinematic, professional photography, enhanced details"],
            ["vibrant colors, high contrast, sharp focus"],
            ["photorealistic, 8k quality, fine details"],
        ],
        inputs=[prompt],
        label="Example Prompts"
    )

    # Event handler
    enhance_btn.click(
        fn=enhance_image,
        inputs=[
            input_image,
            prompt,
            seed,
            randomize_seed,
            num_inference_steps,
            denoising_strength,
        ],
        outputs=[result_slider, used_seed]
    )

    gr.HTML("""

📌 Pipeline: ESRGAN 4x-UltraSharp → FLUX Dev Enhancement

⚡ Optimized for ZeroGPU with automatic memory management

""")

if __name__ == "__main__":
    demo.queue(max_size=3).launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860
    )