# Spaces: Running on Zero
import gc
import io
import logging
import os
import random
import warnings

import gradio as gr
import numpy as np
import requests
# ZeroGPU: keep `spaces` imported before torch.
import spaces
import torch
from diffusers import FluxImg2ImgPipeline
from gradio_imageslider import ImageSlider
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
# Real-ESRGAN is switched off for ZeroGPU (saves memory and complexity).
USE_ESRGAN = False

# Page styling passed to gr.Blocks(css=...).
css = """
#col-container {
    margin: 0 auto;
    max-width: 800px;
}
.main-header {
    text-align: center;
    margin-bottom: 2rem;
}
"""

# `power_device` is only a label shown in the UI header; `device` is where
# the captioning model is placed (models start on CPU).
power_device, device = "ZeroGPU", "cpu"

# Hugging Face access token used for the FLUX download; None when unset.
huggingface_token = os.environ.get("HF_TOKEN")
# Download FLUX model
print("π₯ Downloading FLUX model...")
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    # FluxImg2ImgPipeline.from_pretrained only reads the diffusers-format
    # subfolders (transformer/, vae/, text_encoder*/ ...). The single-file
    # checkpoints at the repo root are never loaded, so skip downloading them.
    ignore_patterns=[
        "*.md",
        "*.gitattributes",
        "flux1-dev.safetensors",
        "ae.safetensors",
    ],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)
# Load Florence-2 (auto-captioning model) onto `device`, which is the CPU.
print("π₯ Loading Florence-2 model...")
florence_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large",
    torch_dtype=torch.float32,
    # Executes modeling code shipped in the model repository.
    trust_remote_code=True,
    attn_implementation="eager",
)
florence_model = florence_model.to(device)
florence_model.eval()  # inference only: disable dropout etc.
florence_processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-large", trust_remote_code=True
)
# Load FLUX pipeline
print("π₯ Loading FLUX Img2Img...")
# bfloat16 halves weight memory compared with float32. FLUX.1-dev weights are
# published in bf16, so float32 only costs memory.
pipe = FluxImg2ImgPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
)
# Memory optimizations.
# NOTE: enable_model_cpu_offload() is deliberately not used. enhance_image
# moves the whole pipeline to CUDA for each request (pipe.to("cuda")) and
# back to CPU afterwards, and diffusers warns against (newer releases reject)
# manually moving a model-offloaded pipeline to the GPU.
# The VAE is configured directly; the pipeline-level enable_vae_tiling /
# enable_vae_slicing wrappers only forward to these same calls.
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
print("β All models loaded successfully!")
# Inclusive upper bound for seeds (random.randint range and the seed slider).
MAX_SEED = 1_000_000
# Cap on *output* pixels (input width x height x upscale_factor**2); kept
# small for ZeroGPU stability.
MAX_PIXEL_BUDGET = 2048 ** 2
def truncate_caption(caption, max_tokens=70, tokens_per_word=1.3):
    """Shorten ``caption`` so it stays under the CLIP text encoder's token limit.

    The token count is estimated per whitespace-separated word (rough rule:
    1 word ≈ ``tokens_per_word`` tokens). At most
    ``floor(max_tokens / tokens_per_word)`` words are kept.

    Args:
        caption: Text to shorten. Runs of whitespace collapse to single spaces.
        max_tokens: Estimated token budget.
        tokens_per_word: Estimated tokens per word.

    Returns:
        The leading words joined by single spaces, with "..." appended when
        any words were dropped.
    """
    words = caption.split()
    # Charge the estimate per *word*. Charging per character (len(word) * 1.3)
    # would cut captions to roughly 54 characters instead of ~53 words.
    max_words = max(0, int(max_tokens // tokens_per_word))
    kept = words[:max_words]
    result = " ".join(kept)
    if len(kept) < len(words):
        result += "..."
    return result
def make_multiple_16(n):
    """Round ``n`` *up* to the next multiple of 16 (ceiling, not nearest).

    Values that are already multiples of 16 are returned unchanged:
    0 -> 0, 1 -> 16, 16 -> 16, 17 -> 32.
    """
    return ((n + 15) // 16) * 16
def generate_caption(image):
    """Describe ``image`` with Florence-2 for use as the FLUX prompt.

    Captioning runs on an RGB copy capped at 1024 px per side, so the
    caller's image is never modified. This matters because the caller goes
    on to upscale that same image.

    Args:
        image: PIL image to describe.

    Returns:
        The caption, shortened with ``truncate_caption(..., max_tokens=70)``.
        On any failure the error is printed and the generic prompt
        "high quality detailed image" is returned instead.
    """
    task_prompt = "<MORE_DETAILED_CAPTION>"
    try:
        # convert() returns a new image, so the in-place thumbnail() below
        # cannot shrink the caller's image.
        caption_image = image.convert("RGB")
        # thumbnail() keeps the aspect ratio and only ever shrinks.
        caption_image.thumbnail((1024, 1024), Image.LANCZOS)
        inputs = florence_processor(
            text=task_prompt,
            images=caption_image,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            generated_ids = florence_model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=256,  # Reduced from 1024
                num_beams=1,  # Reduced from 3
                do_sample=False,  # Greedy decoding
            )
        generated_text = florence_processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )[0]
        parsed_answer = florence_processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(caption_image.width, caption_image.height),
        )
        # Truncate to avoid the CLIP token limit.
        return truncate_caption(parsed_answer[task_prompt], max_tokens=70)
    except Exception as e:
        # Best effort: a missing caption must not abort the whole upscale.
        print(f"Caption generation failed: {e}")
        return "high quality detailed image"
def process_input(input_image, upscale_factor):
    """Prepare the input so that the upscaled output fits ``MAX_PIXEL_BUDGET``.

    The budget is checked against the *output* size, i.e.
    ``width * height * upscale_factor**2``. When it would be exceeded, the
    image is downscaled to multiple-of-16 dimensions that stay within it.

    Otherwise the image keeps its size and is not padded. Padding here would
    leak black borders into the final result. Code that feeds the model and
    needs multiple-of-16 sizes must pad and crop back itself.

    Args:
        input_image: PIL image in any mode.
        upscale_factor: Linear factor that will later be applied to the image.

    Returns:
        ``(image, w_original, h_original, was_resized)``:

        - ``image``: an RGB copy of the input, possibly downscaled.
        - ``w_original``, ``h_original``: the input's original dimensions.
        - ``was_resized``: tells the caller to resize the result back to
          the original size times ``upscale_factor``.
    """
    # FLUX's image processor expects 3-channel input; RGBA/P/L would break it.
    image = input_image.convert("RGB")
    w_original, h_original = image.size
    was_resized = False

    if w_original * h_original * upscale_factor ** 2 > MAX_PIXEL_BUDGET:
        gr.Info("Resizing input to fit within processing limits...")
        target_pixels = MAX_PIXEL_BUDGET / (upscale_factor ** 2)
        scale = (target_pixels / (w_original * h_original)) ** 0.5
        # Round *down* to a multiple of 16 (minimum 16) so the budget holds.
        new_w = max(16, int(w_original * scale) // 16 * 16)
        new_h = max(16, int(h_original * scale) // 16 * 16)
        image = image.resize((new_w, new_h), Image.LANCZOS)
        was_resized = True

    return image, w_original, h_original, was_resized
def simple_upscale(image, factor):
    """Enlarge ``image`` by ``factor`` on both axes using LANCZOS resampling."""
    target_size = tuple(side * factor for side in image.size)
    return image.resize(target_size, Image.LANCZOS)
def _load_input_image(image_input, image_url):
    """Return the uploaded image, or the one fetched from ``image_url``.

    Raises:
        gr.Error: if neither an upload nor a URL was provided.
        requests.RequestException: if the URL cannot be fetched.
    """
    if image_input is not None:
        return image_input
    if not image_url:
        raise gr.Error("Please provide an image")
    # Bounded timeout so an unresponsive host cannot stall the request.
    response = requests.get(image_url, timeout=30)
    response.raise_for_status()
    # Decode from the fully read body. `response.raw` bypasses
    # Content-Encoding (gzip/deflate) decoding.
    image = Image.open(io.BytesIO(response.content))
    image.load()  # decode now so a bad file fails here, not later
    return image


def _flux_refine(image, prompt, seed, strength, num_inference_steps):
    """Run one FLUX img2img pass on CUDA; return a result the size of ``image``.

    FLUX needs height and width that are multiples of 16. The image is
    therefore padded with black on the right and bottom as needed, and the
    padding is cropped off the output.
    """
    # Record the size *before* padding. Comparing against the padded image
    # afterwards would never trigger the crop-back.
    width, height = image.size
    proc_w, proc_h = make_multiple_16(width), make_multiple_16(height)
    if (proc_w, proc_h) != (width, height):
        padded = Image.new("RGB", (proc_w, proc_h))
        padded.paste(image, (0, 0))
        image = padded
    with torch.inference_mode():
        generator = torch.Generator(device="cuda").manual_seed(seed)
        output = pipe(
            prompt=prompt,
            image=image,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=1.0,
            height=proc_h,
            width=proc_w,
            generator=generator,
        ).images[0]
    if output.size != (width, height):
        output = output.crop((0, 0, width, height))
    return output


# ZeroGPU only attaches a GPU inside @spaces.GPU functions.
# NOTE(review): the exact duration in the original decorator was lost from
# this copy (only the "Reduced from 120" comment survived); tune as needed.
@spaces.GPU(duration=90)  # Reduced from 120
def enhance_image(
    image_input,
    image_url,
    seed,
    randomize_seed,
    num_inference_steps,
    upscale_factor,
    denoising_strength,
    use_generated_caption,
    custom_prompt,
    progress=gr.Progress(track_tqdm=True),  # lets Gradio track tqdm bars
):
    """Upscale an image with LANCZOS, then refine it with FLUX img2img.

    Images up to 1536 px per side after upscaling are refined whole. Larger
    images get only a centred square of at most 1024 px refined; the rest
    keeps the plain LANCZOS upscale.

    Returns:
        ``[input_resized, result]`` for the ImageSlider: the original input
        resized to the result's size, and the enhanced image.

    Raises:
        gr.Error: for missing input (message passed through unchanged), or
            wrapping any other failure as "Processing failed: ...".
    """
    try:
        # Clear cache at start
        torch.cuda.empty_cache()
        gc.collect()

        input_image = _load_input_image(image_input, image_url).convert("RGB")
        # Slider values may arrive as floats; PIL sizes, step counts and
        # torch seeds need ints.
        upscale_factor = int(upscale_factor)
        num_inference_steps = int(num_inference_steps)
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        seed = int(seed)
        original_image = input_image.copy()

        # Process and validate input
        input_image, w_orig, h_orig, was_resized = process_input(
            input_image, upscale_factor
        )

        # Generate or use caption (captioning stays on CPU)
        if use_generated_caption:
            gr.Info("Generating caption...")
            prompt = generate_caption(input_image)
            print(f"Caption: {prompt}")
        else:
            # A blank or whitespace-only prompt falls back to the default.
            prompt = (custom_prompt or "").strip() or "high quality image"
            prompt = truncate_caption(prompt, max_tokens=70)

        # Initial upscale with LANCZOS
        gr.Info("Upscaling image...")
        upscaled = simple_upscale(input_image, upscale_factor)

        # Move pipeline to GPU only when needed (moved back in `finally`)
        pipe.to("cuda")

        w, h = upscaled.size
        if w > 1536 or h > 1536:
            gr.Info("Processing large image in tiles...")
            # Only one centred square is refined, to stay within the GPU time
            # budget; everything outside it keeps the LANCZOS upscale.
            crop_size = min(1024, w, h)
            left = (w - crop_size) // 2
            top = (h - crop_size) // 2
            crop = upscaled.crop((left, top, left + crop_size, top + crop_size))
            refined = _flux_refine(
                crop, prompt, seed, denoising_strength, num_inference_steps
            )
            result = upscaled.copy()
            result.paste(refined, (left, top))
        else:
            result = _flux_refine(
                upscaled, prompt, seed, denoising_strength, num_inference_steps
            )

        target_size = (w_orig * upscale_factor, h_orig * upscale_factor)
        if was_resized:
            result = result.resize(target_size, Image.LANCZOS)
        elif result.size != target_size:
            # Trim any right/bottom padding the input received before upscaling.
            result = result.crop((0, 0, *target_size))

        # Prepare the "before" half of the slider at the same size
        input_resized = original_image.resize(result.size, Image.LANCZOS)
        return [input_resized, result]
    except gr.Error:
        # Already user-facing; do not wrap it in "Processing failed".
        raise
    except Exception as e:
        raise gr.Error(f"Processing failed: {str(e)}") from e
    finally:
        # Release the GPU on success and failure alike.
        pipe.to("cpu")
        torch.cuda.empty_cache()
        gc.collect()
# Gradio Interface
# Layout: a header, then a two-column row (inputs/controls on the left,
# before/after slider on the right), then a footer note. Component creation
# order inside each `with` context determines on-screen order.
with gr.Blocks(css=css) as demo:
    gr.HTML(f"""
    <div class="main-header">
        <h1>π¨ AI Image Upscaler</h1>
        <p>Upscale images using Florence-2 + FLUX (Optimized for ZeroGPU)</p>
        <p>Running on <strong>{power_device}</strong></p>
    </div>
    """)
    with gr.Row():
        # Left column: image source plus processing controls.
        with gr.Column(scale=1):
            gr.HTML("<h3>π€ Input</h3>")
            # Two alternative image sources; enhance_image uses the upload
            # when one is present and only falls back to the URL otherwise.
            with gr.Tabs():
                with gr.TabItem("Upload"):
                    input_image = gr.Image(
                        label="Upload Image",
                        type="pil",
                        height=200
                    )
                with gr.TabItem("URL"):
                    image_url = gr.Textbox(
                        label="Image URL",
                        placeholder="https://example.com/image.jpg"
                    )
            use_generated_caption = gr.Checkbox(
                label="Auto-generate caption",
                value=True
            )
            # NOTE(review): despite the placeholder, enhance_image only reads
            # this prompt when "Auto-generate caption" is unchecked.
            custom_prompt = gr.Textbox(
                label="Custom Prompt (optional)",
                placeholder="Override auto-caption if desired",
                lines=2
            )
            # Linear scale factor; process_input shrinks the input if the
            # output would exceed MAX_PIXEL_BUDGET.
            upscale_factor = gr.Slider(
                label="Upscale Factor",
                minimum=2,
                maximum=4,
                step=1,
                value=2
            )
            # Passed to the FLUX pipeline as num_inference_steps.
            num_inference_steps = gr.Slider(
                label="Quality (Steps)",
                minimum=15,
                maximum=30,
                step=1,
                value=20,
                info="Higher = better but slower"
            )
            # Passed to the FLUX img2img pipeline as `strength`.
            denoising_strength = gr.Slider(
                label="Enhancement Strength",
                minimum=0.1,
                maximum=0.5,
                step=0.05,
                value=0.3,
                info="Higher = more changes"
            )
            with gr.Row():
                # When checked, enhance_image ignores the slider's seed.
                randomize_seed = gr.Checkbox(label="Random seed", value=True)
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42
                )
            enhance_btn = gr.Button("π Upscale", variant="primary", size="lg")
        # Right column: before/after comparison of the result.
        with gr.Column(scale=2):
            gr.HTML("<h3>π Result</h3>")
            result_slider = ImageSlider(
                type="pil",
                interactive=False,
                height=500,
                label=None
            )
    # The input order must match enhance_image's positional parameters; its
    # [input_resized, result] return value feeds the slider.
    enhance_btn.click(
        fn=enhance_image,
        inputs=[
            input_image, image_url, seed, randomize_seed,
            num_inference_steps, upscale_factor, denoising_strength,
            use_generated_caption, custom_prompt
        ],
        outputs=[result_slider]
    )
    gr.HTML("""
    <div style="margin-top: 1rem; padding: 0.5rem; background: #f0f0f0; border-radius: 8px;">
        <small>β‘ Optimized for ZeroGPU: Max 2048x2048 output, simplified processing for stability</small>
    </div>
    """)
if __name__ == "__main__":
    # Allow at most 3 pending requests in the queue, then serve on the
    # standard Spaces address/port.
    queued_app = demo.queue(max_size=3)
    queued_app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,  # Don't use share on Spaces
    )