import logging
import os
import random
import warnings

import gradio as gr
import numpy as np
import requests
import spaces
import torch
from diffusers import FluxImg2ImgPipeline
from gradio_imageslider import ImageSlider
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import T5TokenizerFast

# For ESRGAN (requires pip install basicsr gfpgan)
try:
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from basicsr.utils import img2tensor, tensor2img

    USE_ESRGAN = True
except ImportError:
    USE_ESRGAN = False
    warnings.warn("basicsr not installed; falling back to LANCZOS interpolation.")

css = """
#col-container {
    margin: 0 auto;
    max-width: 800px;
}
.main-header {
    text-align: center;
    margin-bottom: 2rem;
}
"""

# Device setup - default to CPU here; the actual device is detected at runtime
# inside enhance_image, where the ZeroGPU runtime attaches a GPU per request.
power_device = "ZeroGPU"
device = "cpu"

# Get the Hugging Face token
huggingface_token = os.getenv("HF_TOKEN")

MAX_SEED = 1000000
MAX_PIXEL_BUDGET = 8192 * 8192  # Increased for tiling support


def make_divisible_by_16(size):
    """Round size to the nearest multiple of 16 (down if remainder < 8, else up).

    e.g. make_divisible_by_16(100) -> 96, make_divisible_by_16(105) -> 112
    """
    return ((size // 16) * 16) if (size % 16) < 8 else ((size // 16 + 1) * 16)


def process_input(input_image, upscale_factor):
    """Process the input image and enforce the pixel budget."""
    w, h = input_image.size
    w_original, h_original = w, h
    was_resized = False

    if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
        warnings.warn(
            f"Requested output image is too large ({w * upscale_factor}x{h * upscale_factor}). Resizing to fit budget."
        )
        gr.Info("Requested output image is too large. Resizing input to fit within pixel budget.")
        target_input_pixels = MAX_PIXEL_BUDGET / (upscale_factor**2)
        scale = (target_input_pixels / (w * h)) ** 0.5
        new_w = int(w * scale) // 16 * 16  # Ensure divisibility by 16 for Flux compatibility
        new_h = int(h * scale) // 16 * 16
        new_w = max(16, new_w)  # Guard against degenerate sizes
        new_h = max(16, new_h)
        input_image = input_image.resize((new_w, new_h), resample=Image.LANCZOS)
        was_resized = True

    return input_image, w_original, h_original, was_resized


def load_image_from_url(url):
    """Load an image from a URL."""
    try:
        response = requests.get(url, stream=True, timeout=30)  # timeout so a dead URL can't hang the request
        response.raise_for_status()
        return Image.open(response.raw)
    except Exception as e:
        raise gr.Error(f"Failed to load image from URL: {e}")


def esrgan_upscale(image, scale=4):
    """Upscale with the ESRGAN model; fall back to LANCZOS if basicsr is missing."""
    if not USE_ESRGAN:
        return image.resize((image.width * scale, image.height * scale), resample=Image.LANCZOS)
    img = img2tensor(np.array(image) / 255.0, bgr2rgb=False, float32=True)
    # Move the input to the model's device (the model may live on the GPU)
    img = img.unsqueeze(0).to(next(esrgan_model.parameters()).device)
    with torch.no_grad():
        output = esrgan_model(img).squeeze()
    output_img = tensor2img(output, rgb2bgr=False, min_max=(0, 1))
    return Image.fromarray(output_img)


def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator, tile_size=1024, overlap=32):
    """Tiled img2img to mimic Ultimate SD Upscaler tiling."""
    w, h = image.size
    output = image.copy()  # Start with the control image

    for x in range(0, w, tile_size - overlap):
        for y in range(0, h, tile_size - overlap):
            tile_w = min(tile_size, w - x)
            tile_h = min(tile_size, h - y)
            if tile_h < 16 or tile_w < 16:  # Skip tiny tiles
                continue

            tile = image.crop((x, y, x + tile_w, y + tile_h))

            # Stretch the tile to dimensions divisible by 16 for Flux
            new_tile_w = make_divisible_by_16(tile_w)
            new_tile_h = make_divisible_by_16(tile_h)
            tile = tile.resize((new_tile_w, new_tile_h), resample=Image.LANCZOS)

            # Run Flux img2img on the tile
            gen_tile = pipe(
                prompt=prompt,
                image=tile,
                strength=strength,
                num_inference_steps=steps,
                guidance_scale=guidance,
                height=new_tile_h,
                width=new_tile_w,
                generator=generator,
            ).images[0]

            # Resize the generated tile back to the original tile dimensions
            gen_tile = gen_tile.resize((tile_w, tile_h), resample=Image.LANCZOS)

            # Paste with linear blending on overlapping edges
            if overlap > 0:
                paste_box = (x, y, x + tile_w, y + tile_h)
                if x > 0 or y > 0:
                    mask = Image.new("L", (tile_w, tile_h), 255)
                    effective_overlap_x = min(overlap, tile_w)
                    effective_overlap_y = min(overlap, tile_h)
                    if x > 0:  # Ramp alpha up across the left overlap
                        for i in range(effective_overlap_x):
                            for j in range(tile_h):
                                mask.putpixel((i, j), int(255 * (i / overlap)))
                    if y > 0:  # Ramp alpha up across the top overlap
                        for i in range(tile_w):
                            for j in range(effective_overlap_y):
                                mask.putpixel((i, j), int(255 * (j / overlap)))
                    output.paste(gen_tile, paste_box, mask)
                else:
                    output.paste(gen_tile, paste_box)
            else:
                output.paste(gen_tile, (x, y))

    return output
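# Optional, vectorized alternative to the per-pixel blend mask above — a sketch
# only, not wired into tiled_flux_img2img (`make_blend_mask` is a hypothetical
# helper, not part of the original file). It builds the same left/top linear
# ramps with numpy, which is far faster than putpixel on large tiles; where the
# two ramps meet in the corner it takes their minimum rather than letting the
# top ramp overwrite the left one.
def make_blend_mask(tile_w, tile_h, overlap, blend_left, blend_top):
    mask = np.full((tile_h, tile_w), 255.0, dtype=np.float32)  # fully opaque by default
    if blend_left:
        n = min(overlap, tile_w)
        ramp = np.arange(n, dtype=np.float32) / overlap * 255  # 0 at the seam, opaque inward
        mask[:, :n] = np.minimum(mask[:, :n], ramp[None, :])
    if blend_top:
        n = min(overlap, tile_h)
        ramp = np.arange(n, dtype=np.float32) / overlap * 255
        mask[:n, :] = np.minimum(mask[:n, :], ramp[:, None])
    return Image.fromarray(mask.astype(np.uint8), mode="L")
# The loop above could then use: mask = make_blend_mask(tile_w, tile_h, overlap, x > 0, y > 0)

# For scale: with the defaults (tile_size=1024, overlap=32) the loop steps by
# 992 px, so a 2048 px edge produces tiles starting at x = 0, 992 and 1984; the
# final 64 px sliver is still processed because it is at least 16 px wide.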
@spaces.GPU(duration=120)
def enhance_image(
    image_input,
    image_url,
    seed,
    randomize_seed,
    num_inference_steps,
    upscale_factor,
    denoising_strength,
    custom_prompt,
    tile_size,
    progress=gr.Progress(track_tqdm=True),
):
    """Main enhancement function."""
    global pipe, esrgan_model

    # Detect the device on every call, not only on the first one; otherwise
    # `device` would be an unbound local once the models are already loaded.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Lazy loading of models
    if "pipe" not in globals():
        try:
            dtype = torch.bfloat16 if device == "cuda" else torch.float32
            print(f"📥 Loading FLUX Img2Img on {device}...")
            tokenizer_2 = T5TokenizerFast.from_pretrained(
                "black-forest-labs/FLUX.1-schnell",
                subfolder="tokenizer_2",
                token=huggingface_token,
            )
            pipe = FluxImg2ImgPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-schnell",
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
                device_map="balanced",
                tokenizer_2=tokenizer_2,
                token=huggingface_token,
            )
            pipe.enable_vae_tiling()
            pipe.enable_vae_slicing()
            if device == "cuda":
                pipe.reset_device_map()
                pipe.enable_model_cpu_offload()

            if USE_ESRGAN:
                esrgan_path = "4x-UltraSharp.pth"
                if not os.path.exists(esrgan_path):
                    url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
                    with open(esrgan_path, "wb") as f:
                        f.write(requests.get(url).content)
                esrgan_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
                state_dict = torch.load(esrgan_path)["params_ema"]
                esrgan_model.load_state_dict(state_dict)
                esrgan_model.eval()
                esrgan_model.to(device)

            print("✅ Models loaded successfully!")
        except Exception as e:
            print(f"Model loading error: {e}, falling back to CPU")
            device = "cpu"
            dtype = torch.float32
            # Reload on CPU
            tokenizer_2 = T5TokenizerFast.from_pretrained(
                "black-forest-labs/FLUX.1-schnell",
                subfolder="tokenizer_2",
                token=huggingface_token,
            )
            pipe = FluxImg2ImgPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-schnell",
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
                device_map=None,
                tokenizer_2=tokenizer_2,
                token=huggingface_token,
            )
            pipe.enable_vae_tiling()
            pipe.enable_vae_slicing()

    # Handle image input
    if image_input is not None:
        input_image = image_input
    elif image_url:
        input_image = load_image_from_url(image_url)
    else:
        raise gr.Error("Please provide an image (upload or URL)")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    true_input_image = input_image

    # Process the input image (may shrink it to fit the pixel budget)
    input_image, w_original, h_original, was_resized = process_input(input_image, upscale_factor)

    prompt = custom_prompt if custom_prompt.strip() else ""
    generator = torch.Generator(device=device).manual_seed(seed)

    gr.Info("🚀 Upscaling image...")

    # Initial upscale: ESRGAN when it is loaded and the factor matches its 4x
    # scale, otherwise plain LANCZOS interpolation
    if USE_ESRGAN and "esrgan_model" in globals() and upscale_factor == 4:
        control_image = esrgan_upscale(input_image, upscale_factor)
    else:
        w, h = input_image.size
        control_image = input_image.resize((w * upscale_factor, h * upscale_factor), resample=Image.LANCZOS)

    # Stretch the control image to dimensions divisible by 16
    control_w, control_h = control_image.size
    new_control_w = make_divisible_by_16(control_w)
    new_control_h = make_divisible_by_16(control_h)
    if (new_control_w, new_control_h) != (control_w, control_h):
        control_image = control_image.resize((new_control_w, new_control_h), resample=Image.LANCZOS)

    # Tiled Flux img2img for refinement
    image = tiled_flux_img2img(
        pipe,
        prompt,
        control_image,
        denoising_strength,
        num_inference_steps,
        3.5,  # guidance_scale, matching the reference workflow
        generator,
        tile_size=tile_size,
        overlap=32,
    )

    # Resize back to the target size if the image was stretched or pre-shrunk
    target_w, target_h = w_original * upscale_factor, h_original * upscale_factor
    if was_resized:
        gr.Info(f"📏 Resizing output to target size: {target_w}x{target_h}")
    if image.size != (target_w, target_h):
        image = image.resize((target_w, target_h), resample=Image.LANCZOS)

    # Resize the input image to match the output size for slider alignment
    resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)

    # Move models back to the CPU to release GPU memory
    if device == "cuda":
        pipe.to("cpu")
        if USE_ESRGAN and "esrgan_model" in globals():
            esrgan_model.to("cpu")

    return [resized_input, image]
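# Worked example of the MAX_PIXEL_BUDGET guard in process_input (illustrative
# numbers, not taken from the original source): a 4000x3000 upload at 4x would
# target 16000x12000 = 192 MP, well over the 8192 * 8192 ≈ 67.1 MP budget, so
# the input is pre-scaled by sqrt((67108864 / 16) / (4000 * 3000)) ≈ 0.591 and
# snapped down to multiples of 16, giving a 2352x1760 working image (~66.2 MP
# after the 4x upscale).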
# Create the Gradio interface
with gr.Blocks(css=css, title="🎨 AI Image Upscaler - FLUX") as demo:
    gr.HTML(
        """
        <div class="main-header">
            <h1>🎨 AI Image Upscaler</h1>
            <p>Upload an image or provide a URL to upscale it using FLUX</p>
            <p>Currently running on <b>{}</b></p>
        </div>
        """.format(power_device)
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<h3>📤 Input</h3>")
            with gr.Tabs():
                with gr.TabItem("📁 Upload Image"):
                    input_image = gr.Image(
                        label="Upload Image",
                        type="pil",
                        height=200,  # Kept small; the results pane gets the space
                    )
                with gr.TabItem("🔗 Image URL"):
                    image_url = gr.Textbox(
                        label="Image URL",
                        placeholder="https://example.com/image.jpg",
                        value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg",
                    )

            gr.HTML("<h3>🎛️ Prompt Settings</h3>")
            custom_prompt = gr.Textbox(
                label="Custom Prompt (optional)",
                placeholder="Enter custom prompt or leave empty",
                lines=2,
            )

            gr.HTML("<h3>⚙️ Upscaling Settings</h3>")
            upscale_factor = gr.Slider(
                label="Upscale Factor",
                minimum=1,
                maximum=4,
                step=1,
                value=2,
                info="How much to upscale the image",
            )
            num_inference_steps = gr.Slider(
                label="Number of Inference Steps",
                minimum=1,
                maximum=50,
                step=1,
                value=4,
                info="More steps = better quality but slower (default 4 for schnell)",
            )
            denoising_strength = gr.Slider(
                label="Denoising Strength",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.3,
                info="Controls how much the image is transformed",
            )
            tile_size = gr.Slider(
                label="Tile Size",
                minimum=256,
                maximum=2048,
                step=64,
                value=1024,
                info="Size of tiles for processing (larger = faster but more memory)",
            )

            with gr.Row():
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                    interactive=True,
                )

            enhance_btn = gr.Button("🚀 Upscale Image", variant="primary", size="lg")

        with gr.Column(scale=2):  # Larger column for the results pane
            gr.HTML("<h3>📊 Results</h3>")
            result_slider = ImageSlider(
                type="pil",
                interactive=False,  # Disable interactivity to prevent uploads
                height=600,
                elem_id="result_slider",
                label=None,  # Remove the default label
            )

    # Event handler; the inputs order must match enhance_image's positional parameters
    enhance_btn.click(
        fn=enhance_image,
        inputs=[
            input_image,
            image_url,
            seed,
            randomize_seed,
            num_inference_steps,
            upscale_factor,
            denoising_strength,
            custom_prompt,
            tile_size,
        ],
        outputs=[result_slider],
    )

    gr.HTML(
        """
        <p><b>Note:</b> This upscaler uses the FLUX.1-schnell model. Users are responsible for
        obtaining commercial rights under the model's license if it is used commercially.</p>
        """
    )

    # Custom CSS for slider
    gr.HTML(""" """)

    # JS to set slider default position to middle
    gr.HTML(""" """)

if __name__ == "__main__":
    demo.queue().launch(share=True, server_name="0.0.0.0", server_port=7860)