import gc
import io
import logging
import os
import random
import warnings

import gradio as gr
import numpy as np
import requests
import spaces
import torch
from diffusers import FluxImg2ImgPipeline
from gradio_imageslider import ImageSlider
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Disable ESRGAN for ZeroGPU (saves memory and complexity)
USE_ESRGAN = False

css = """
#col-container {
    margin: 0 auto;
    max-width: 800px;
}
.main-header {
    text-align: center;
    margin-bottom: 2rem;
}
"""

# Device setup: Florence-2 captioning runs on CPU; FLUX is moved to CUDA only
# inside the @spaces.GPU-decorated handler.
power_device = "ZeroGPU"
device = "cpu"  # Start on CPU

# Get HuggingFace token (FLUX.1-dev is a gated repository)
huggingface_token = os.getenv("HF_TOKEN")

# Download FLUX model
print("📥 Downloading FLUX model...")
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)

# Load Florence-2 model
print("📥 Loading Florence-2 model...")
florence_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large",
    torch_dtype=torch.float32,
    trust_remote_code=True,
    attn_implementation="eager",
).to(device).eval()
florence_processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-large", trust_remote_code=True
)

# Load FLUX pipeline.
# bfloat16 halves the weight memory compared to float32 (2 vs 4 bytes per
# parameter) and is the dtype the diffusers FLUX examples use.
print("📥 Loading FLUX Img2Img...")
pipe = FluxImg2ImgPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)

# Enable memory optimizations.
# NOTE(review): enhance_image() also calls pipe.to("cuda") / pipe.to("cpu");
# diffusers warns that manually moving an offloaded pipeline defeats model CPU
# offload. Confirm which of the two strategies works on ZeroGPU and keep one.
pipe.enable_model_cpu_offload()
# Tiling/slicing are enabled once on the VAE itself (the pipeline-level
# enable_vae_* wrappers did the same thing a second time).
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()

print("✅ All models loaded successfully!")

MAX_SEED = 1000000
MAX_PIXEL_BUDGET = 2048 * 2048  # Reduced for ZeroGPU stability


def truncate_caption(caption, max_tokens=70, tokens_per_word=1.3):
    """Shorten *caption* so its estimated token count stays within *max_tokens*.

    The estimate is a heuristic of ``tokens_per_word`` tokens per
    whitespace-separated word, not a real tokenizer count; it is meant to keep
    prompts roughly inside CLIP's token limit.

    Args:
        caption: Text to shorten.
        max_tokens: Estimated token budget.
        tokens_per_word: Average number of tokens assumed per word.

    Returns:
        The caption re-joined with single spaces, with "..." appended when any
        words were dropped.
    """
    words = caption.split()
    kept = []
    estimated_tokens = 0.0
    for word in words:
        # Count per *word* (the old code multiplied the word's character length,
        # which cut every prompt at roughly 54 characters).
        if estimated_tokens + tokens_per_word > max_tokens:
            break
        kept.append(word)
        estimated_tokens += tokens_per_word

    result = " ".join(kept)
    if len(kept) < len(words):
        result += "..."
    return result


def make_multiple_16(n):
    """Round *n* up to the next multiple of 16 (values already aligned are unchanged)."""
    return ((n + 15) // 16) * 16


def generate_caption(image):
    """Caption *image* with Florence-2 on CPU, falling back to a generic prompt.

    The caller's image is never modified: if it is larger than 1024 px on a side,
    a shrunken copy is captioned instead.

    Args:
        image: PIL image to describe.

    Returns:
        A caption truncated by ``truncate_caption(max_tokens=70)``, or
        "high quality detailed image" if captioning raises.
    """
    try:
        # Florence-2 selects its task through a special prompt token; an empty
        # prompt is not a valid task and made post-processing fail every time.
        task_prompt = "<MORE_DETAILED_CAPTION>"

        if image.width > 1024 or image.height > 1024:
            # thumbnail() resizes IN PLACE, so work on a copy to avoid shrinking
            # the image the caller is about to upscale.
            image = image.copy()
            image.thumbnail((1024, 1024), Image.LANCZOS)

        inputs = florence_processor(
            text=task_prompt, images=image, return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            generated_ids = florence_model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=256,  # Reduced from 1024
                num_beams=1,  # Reduced from 3
                do_sample=False,  # Faster without sampling
            )

        # Special tokens are kept because post_process_generation parses them.
        generated_text = florence_processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )[0]
        parsed_answer = florence_processor.post_process_generation(
            generated_text, task=task_prompt, image_size=(image.width, image.height)
        )
        caption = parsed_answer[task_prompt]

        # Truncate to avoid CLIP token limit
        return truncate_caption(caption, max_tokens=70)
    except Exception as e:
        # Deliberate best-effort: a failed caption should not abort the upscale.
        print(f"Caption generation failed: {e}")
        return "high quality detailed image"


def process_input(input_image, upscale_factor):
    """Fit *input_image* to the processing budget and align it to multiples of 16.

    If ``width * height * upscale_factor**2`` exceeds MAX_PIXEL_BUDGET, the image
    is downscaled (aspect ratio kept up to 16-px snapping) so the upscaled image
    stays within the budget. Otherwise it is padded with black on the right and
    bottom up to the next multiple of 16.

    Args:
        input_image: PIL image to prepare.
        upscale_factor: Integer factor the image will later be upscaled by.

    Returns:
        Tuple ``(image, w_original, h_original, was_resized)``, where the
        original dimensions are those before any resize or padding.
    """
    w_original, h_original = input_image.size
    was_resized = False

    # Check pixel budget
    if w_original * h_original * upscale_factor**2 > MAX_PIXEL_BUDGET:
        gr.Info("Resizing input to fit within processing limits...")
        target_pixels = MAX_PIXEL_BUDGET / (upscale_factor**2)
        scale = (target_pixels / (w_original * h_original)) ** 0.5
        # Snap DOWN to a multiple of 16 so the resized image never overshoots
        # the budget (rounding up could exceed it).
        new_w = max(16, int(w_original * scale) // 16 * 16)
        new_h = max(16, int(h_original * scale) // 16 * 16)
        input_image = input_image.resize((new_w, new_h), Image.LANCZOS)
        was_resized = True

    # Ensure dimensions are multiples of 16 (pad; the caller crops it off later)
    w, h = input_image.size
    new_w, new_h = make_multiple_16(w), make_multiple_16(h)
    if (new_w, new_h) != (w, h):
        padded = Image.new("RGB", (new_w, new_h))
        padded.paste(input_image, (0, 0))
        input_image = padded

    return input_image, w_original, h_original, was_resized


def simple_upscale(image, factor):
    """Upscale *image* by the integer *factor* with LANCZOS resampling."""
    return image.resize(
        (image.width * factor, image.height * factor), Image.LANCZOS
    )


def _load_input_image(image_input, image_url):
    """Return the uploaded image or the image at *image_url*, converted to RGB.

    Raises:
        gr.Error: if neither an upload nor a URL was provided.
        requests.RequestException: if the download fails or times out.
    """
    if image_input is not None:
        image = image_input
    else:
        url = (image_url or "").strip()
        if not url:
            raise gr.Error("Please provide an image")
        # NOTE(review): this fetches an arbitrary user-supplied URL from the
        # server (SSRF surface). Consider restricting schemes/hosts.
        # A timeout keeps a stalled server from consuming the whole GPU slot;
        # .content (unlike .raw) also applies Content-Encoding decoding.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))
    # FLUX's VAE expects 3 channels; URL images may be RGBA/P/L.
    return image.convert("RGB")


def _run_flux(image, prompt, seed, num_inference_steps, denoising_strength):
    """Run FLUX img2img on *image* and return a result of the same size.

    The image is padded with black to multiples of 16 for the pipeline, and the
    padding is cropped from the output again.
    """
    width, height = image.size
    proc_w, proc_h = make_multiple_16(width), make_multiple_16(height)
    needs_padding = (proc_w, proc_h) != (width, height)
    if needs_padding:
        padded = Image.new("RGB", (proc_w, proc_h))
        padded.paste(image, (0, 0))
        image = padded

    with torch.inference_mode():
        generator = torch.Generator(device="cuda").manual_seed(seed)
        result = pipe(
            prompt=prompt,
            image=image,
            strength=denoising_strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=1.0,
            height=proc_h,
            width=proc_w,
            generator=generator,
        ).images[0]

    # Compare against the pre-padding size captured above (the old tiling code
    # compared against the already-padded image, so it never cropped).
    if needs_padding:
        result = result.crop((0, 0, width, height))
    return result


@spaces.GPU(duration=90)  # Reduced from 120
def enhance_image(
    image_input,
    image_url,
    seed,
    randomize_seed,
    num_inference_steps,
    upscale_factor,
    denoising_strength,
    use_generated_caption,
    custom_prompt,
    progress=gr.Progress(track_tqdm=True),
):
    """Upscale an image with LANCZOS, then refine it with FLUX img2img.

    Args:
        image_input: Uploaded PIL image, or None.
        image_url: URL used when no image was uploaded.
        seed: Seed for the FLUX generator (ignored when *randomize_seed*).
        randomize_seed: Draw a random seed in [0, MAX_SEED] instead.
        num_inference_steps: FLUX denoising steps.
        upscale_factor: Integer upscale factor.
        denoising_strength: img2img strength (how much FLUX may change the image).
        use_generated_caption: Caption with Florence-2 instead of *custom_prompt*.
        custom_prompt: User prompt; "high quality image" when empty.
        progress: Gradio progress tracker (tracks tqdm bars from the pipeline).

    Returns:
        ``[before, after]`` for the ImageSlider. Both are sized
        ``original size * upscale_factor``.

    Raises:
        gr.Error: for missing input or any processing failure.
    """
    # Clear cache at start
    torch.cuda.empty_cache()
    gc.collect()
    try:
        input_image = _load_input_image(image_input, image_url)

        # Slider values may arrive as floats; PIL sizes and torch seeds need ints.
        upscale_factor = int(upscale_factor)
        num_inference_steps = int(num_inference_steps)
        seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

        # _load_input_image returned a fresh image and nothing below mutates it.
        original_image = input_image

        # Process and validate input (budget resize and/or 16-px padding)
        input_image, w_orig, h_orig, was_resized = process_input(
            input_image, upscale_factor
        )

        # Generate or use caption (on CPU). Caption the unpadded original so the
        # black padding does not leak into the description.
        if use_generated_caption:
            gr.Info("Generating caption...")
            prompt = generate_caption(original_image)
            print(f"Caption: {prompt}")
        else:
            prompt = (custom_prompt or "").strip() or "high quality image"
            prompt = truncate_caption(prompt, max_tokens=70)

        # Initial upscale with LANCZOS
        gr.Info("Upscaling image...")
        upscaled = simple_upscale(input_image, upscale_factor)

        # Move pipeline to GPU only when needed
        pipe.to("cuda")

        w, h = upscaled.size
        if w > 1536 or h > 1536:
            gr.Info("Processing large image in tiles...")
            # Only the central crop is refined (to avoid timeouts); the rest of
            # the image keeps the plain LANCZOS upscale.
            crop_size = min(1024, w, h)
            left = (w - crop_size) // 2
            top = (h - crop_size) // 2
            cropped = upscaled.crop((left, top, left + crop_size, top + crop_size))
            enhanced_crop = _run_flux(
                cropped, prompt, seed, num_inference_steps, denoising_strength
            )
            result = upscaled.copy()
            result.paste(enhanced_crop, (left, top))
        else:
            result = _run_flux(
                upscaled, prompt, seed, num_inference_steps, denoising_strength
            )

        target_size = (w_orig * upscale_factor, h_orig * upscale_factor)
        if was_resized:
            # Processed at reduced resolution: scale back up to the requested size.
            result = result.resize(target_size, Image.LANCZOS)
        elif result.size != target_size:
            # Remove the right/bottom padding added by process_input.
            result = result.crop((0, 0, *target_size))

        # Prepare for slider
        input_resized = original_image.resize(result.size, Image.LANCZOS)
        return [input_resized, result]
    except gr.Error:
        # Already user-facing; don't rewrap as "Processing failed: ...".
        raise
    except Exception as e:
        raise gr.Error(f"Processing failed: {str(e)}") from e
    finally:
        # Clean up on both success and error paths
        pipe.to("cpu")
        torch.cuda.empty_cache()
        gc.collect()
# Gradio Interface
# Layout: a header, an input/settings column beside a result slider, and a
# footer. The components are wired to enhance_image().
with gr.Blocks(css=css) as demo:
    # Header (styled by the .main-header rule in `css`)
    gr.HTML(f"""
    <div class="main-header">
        <h1>🎨 AI Image Upscaler</h1>
        <p>Upscale images using Florence-2 + FLUX (Optimized for ZeroGPU)</p>
        <p>Running on {power_device}</p>
    </div>
    """)

    with gr.Row():
        # Left column: image source and generation settings
        with gr.Column(scale=1):
            gr.HTML("<h3>📤 Input</h3>")

            # Either upload an image or give a URL; the upload takes precedence
            # in enhance_image().
            with gr.Tabs():
                with gr.TabItem("Upload"):
                    input_image = gr.Image(
                        label="Upload Image",
                        type="pil",
                        height=200,
                    )
                with gr.TabItem("URL"):
                    image_url = gr.Textbox(
                        label="Image URL",
                        placeholder="https://example.com/image.jpg",
                    )

            use_generated_caption = gr.Checkbox(
                label="Auto-generate caption",
                value=True,
            )
            custom_prompt = gr.Textbox(
                label="Custom Prompt (optional)",
                placeholder="Override auto-caption if desired",
                lines=2,
            )
            upscale_factor = gr.Slider(
                label="Upscale Factor",
                minimum=2,
                maximum=4,
                step=1,
                value=2,
            )
            num_inference_steps = gr.Slider(
                label="Quality (Steps)",
                minimum=15,
                maximum=30,
                step=1,
                value=20,
                info="Higher = better but slower",
            )
            denoising_strength = gr.Slider(
                label="Enhancement Strength",
                minimum=0.1,
                maximum=0.5,
                step=0.05,
                value=0.3,
                info="Higher = more changes",
            )

            with gr.Row():
                randomize_seed = gr.Checkbox(label="Random seed", value=True)
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )

            enhance_btn = gr.Button("🚀 Upscale", variant="primary", size="lg")

        # Right column: before/after comparison
        with gr.Column(scale=2):
            gr.HTML("<h3>📊 Result</h3>")
            result_slider = ImageSlider(
                type="pil",
                interactive=False,
                height=500,
                label=None,
            )

    # Input order must match enhance_image()'s positional parameters.
    enhance_btn.click(
        fn=enhance_image,
        inputs=[
            input_image,
            image_url,
            seed,
            randomize_seed,
            num_inference_steps,
            upscale_factor,
            denoising_strength,
            use_generated_caption,
            custom_prompt,
        ],
        outputs=[result_slider],
    )

    # Footer. MAX_PIXEL_BUDGET limits the *processing* resolution; the returned
    # image is resized to original size x factor, so it is not an output cap.
    gr.HTML("""
    <div style="text-align: center; margin-top: 1rem;">
        ⚡ Optimized for ZeroGPU: Max 2048x2048 processing resolution, simplified processing for stability
    </div>
    """)

if __name__ == "__main__":
    demo.queue(max_size=3).launch(
        share=False,  # Don't use share on Spaces
        server_name="0.0.0.0",
        server_port=7860,
    )