import logging
import random
import warnings
import os

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxImg2ImgPipeline
from transformers import AutoProcessor, AutoModelForCausalLM
from gradio_imageslider import ImageSlider
from PIL import Image
from huggingface_hub import snapshot_download
import requests

# For ESRGAN (optional - the app will work without it)
try:
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from basicsr.utils import img2tensor, tensor2img

    USE_ESRGAN = True
except ImportError:
    USE_ESRGAN = False
    warnings.warn("basicsr not installed; falling back to LANCZOS interpolation.")
css = """ | |
#col-container { | |
margin: 0 auto; | |
max-width: 800px; | |
} | |
.main-header { | |
text-align: center; | |
margin-bottom: 2rem; | |
} | |
""" | |

# Device setup
power_device = "ZeroGPU"
device = "cpu"  # Start on CPU; models are moved to the GPU only when needed

# Get the Hugging Face token
huggingface_token = os.getenv("HF_TOKEN")

# Download the FLUX model
print("Downloading FLUX model...")
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)

# Load the Florence-2 model for image captioning on CPU
print("Loading Florence-2 model...")
florence_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large",
    torch_dtype=torch.float32,  # Use float32 on CPU to avoid dtype issues
    trust_remote_code=True,
    attn_implementation="eager",
).to(device)
florence_processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-large",
    trust_remote_code=True,
)

# Load the FLUX Img2Img pipeline on CPU
print("Loading FLUX Img2Img...")
pipe = FluxImg2ImgPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.float32,  # Start with float32 on CPU
)
pipe.enable_vae_tiling()
pipe.enable_vae_slicing()

print("All models loaded successfully!")

# Download the ESRGAN model if basicsr is available
if USE_ESRGAN:
    try:
        esrgan_path = "4x-UltraSharp.pth"
        if not os.path.exists(esrgan_path):
            url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
            print("Downloading ESRGAN model...")
            with open(esrgan_path, "wb") as f:
                f.write(requests.get(url).content)
        esrgan_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        state_dict = torch.load(esrgan_path, map_location="cpu")["params_ema"]
        esrgan_model.load_state_dict(state_dict)
        esrgan_model.eval()
        print("ESRGAN model loaded!")
    except Exception as e:
        print(f"Failed to load ESRGAN: {e}")
        USE_ESRGAN = False

MAX_SEED = 1000000
MAX_PIXEL_BUDGET = 8192 * 8192


def make_multiple_16(n):
    """Round up to the nearest multiple of 16."""
    return ((n + 15) // 16) * 16
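
# Example: make_multiple_16(1000) -> 1008 and make_multiple_16(1024) -> 1024.
# The FLUX pipeline expects width/height divisible by 16, so every image and
# tile dimension below is rounded up with this helper before inference.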


def generate_caption(image):
    """Generate a detailed caption using Florence-2."""
    try:
        # Ensure the model is on the correct device with the correct dtype
        if florence_model.device.type == "cuda":
            florence_model.to(torch.float16)

        task_prompt = "<MORE_DETAILED_CAPTION>"
        prompt = task_prompt

        inputs = florence_processor(
            text=prompt,
            images=image,
            return_tensors="pt",
        ).to(florence_model.device)

        # Keep the image tensor dtype consistent with the model
        if florence_model.device.type == "cuda":
            if "pixel_values" in inputs:
                inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)

        generated_ids = florence_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
            do_sample=True,
        )
        generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = florence_processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height),
        )
        caption = parsed_answer[task_prompt]
        return caption
    except Exception as e:
        print(f"Caption generation failed: {e}")
        return "a high quality detailed image"


def process_input(input_image, upscale_factor):
    """Process the input image and enforce the output pixel budget."""
    w, h = input_image.size
    w_original, h_original = w, h
    was_resized = False

    if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
        warnings.warn(
            f"Requested output image is too large ({w * upscale_factor}x{h * upscale_factor}). Resizing to fit budget."
        )
        gr.Info("Requested output image is too large. Resizing input to fit within the pixel budget.")
        target_input_pixels = MAX_PIXEL_BUDGET / (upscale_factor ** 2)
        scale = (target_input_pixels / (w * h)) ** 0.5
        new_w = make_multiple_16(int(w * scale))
        new_h = make_multiple_16(int(h * scale))
        input_image = input_image.resize((new_w, new_h), resample=Image.LANCZOS)
        was_resized = True

    return input_image, w_original, h_original, was_resized
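
# Worked example: MAX_PIXEL_BUDGET = 8192 * 8192 = 67,108,864 output pixels.
# At upscale_factor = 4, target_input_pixels = 8192**2 / 16 = 2048 * 2048, so any
# input larger than roughly 4.2 MP is uniformly downscaled (aspect ratio kept)
# before the upscaling pass.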


def load_image_from_url(url):
    """Load an image from a URL."""
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()
        return Image.open(response.raw)
    except Exception as e:
        raise gr.Error(f"Failed to load image from URL: {e}")


def esrgan_upscale(image, scale=4):
    """Upscale an image using ESRGAN, or fall back to LANCZOS."""
    if not USE_ESRGAN:
        return image.resize((image.width * scale, image.height * scale), resample=Image.LANCZOS)
    try:
        img = img2tensor(np.array(image) / 255.0, bgr2rgb=False, float32=True)
        with torch.no_grad():
            # Move the model and the image tensor to the same device
            if torch.cuda.is_available():
                esrgan_model.to("cuda")
                img = img.to("cuda")
            output = esrgan_model(img.unsqueeze(0)).squeeze()
        output_img = tensor2img(output, rgb2bgr=False, min_max=(0, 1))
        return Image.fromarray(output_img)
    except Exception as e:
        print(f"ESRGAN upscale failed: {e}, falling back to LANCZOS")
        return image.resize((image.width * scale, image.height * scale), resample=Image.LANCZOS)


def create_blend_mask(width, height, overlap, edge_x, edge_y):
    """Create a gradient blend mask for smooth tile transitions."""
    mask = Image.new("L", (width, height), 255)
    pixels = mask.load()

    # Horizontal blend (fade in from the left edge)
    if edge_x and overlap > 0:
        for x in range(min(overlap, width)):
            alpha = x / overlap
            for y in range(height):
                pixels[x, y] = int(255 * alpha)

    # Vertical blend (fade in from the top edge)
    if edge_y and overlap > 0:
        for y in range(min(overlap, height)):
            alpha = y / overlap
            for x in range(width):
                # Combine with the existing alpha where both edges overlap
                existing = pixels[x, y] / 255.0
                combined = min(existing, alpha)
                pixels[x, y] = int(255 * combined)

    return mask
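
# Example: with overlap=64 and edge_x=True, column x=0 of the mask is 0 (the
# previously pasted tile shows through completely), x=32 is 127 (roughly a 50/50
# blend), and every column from x=64 onward is 255 (only the new tile is visible).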


def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator, tile_size=1024, overlap=64):
    """Tiled img2img so that large images fit in memory."""
    w, h = image.size

    # Ensure tile size and overlap are divisible by 16
    tile_size = make_multiple_16(tile_size)
    overlap = make_multiple_16(overlap)

    # If the image is small enough, process it without tiling
    if w <= tile_size and h <= tile_size:
        # Ensure dimensions are divisible by 16
        new_w = make_multiple_16(w)
        new_h = make_multiple_16(h)
        if new_w != w or new_h != h:
            padded_image = Image.new("RGB", (new_w, new_h))
            padded_image.paste(image, (0, 0))
        else:
            padded_image = image

        result = pipe(
            prompt=prompt,
            image=padded_image,
            strength=strength,
            num_inference_steps=steps,
            guidance_scale=guidance,
            height=new_h,
            width=new_w,
            generator=generator,
        ).images[0]

        # Crop back to the original size if padded
        if new_w != w or new_h != h:
            result = result.crop((0, 0, w, h))
        return result

    # Process large images tile by tile
    output = Image.new("RGB", (w, h))

    # Calculate tile positions
    tiles = []
    for y in range(0, h, tile_size - overlap):
        for x in range(0, w, tile_size - overlap):
            tile_w = min(tile_size, w - x)
            tile_h = min(tile_size, h - y)
            # Ensure tile dimensions are divisible by 16
            tile_w_padded = make_multiple_16(tile_w)
            tile_h_padded = make_multiple_16(tile_h)
            tiles.append({
                'x': x,
                'y': y,
                'w': tile_w,
                'h': tile_h,
                'w_padded': tile_w_padded,
                'h_padded': tile_h_padded,
                'edge_x': x > 0,
                'edge_y': y > 0,
            })

    # Process each tile
    for i, tile_info in enumerate(tiles):
        print(f"Processing tile {i + 1}/{len(tiles)}...")

        # Extract the tile from the image
        tile = image.crop((
            tile_info['x'],
            tile_info['y'],
            tile_info['x'] + tile_info['w'],
            tile_info['y'] + tile_info['h'],
        ))

        # Pad to a 16-multiple if necessary (padding is filled with black)
        if tile_info['w_padded'] != tile_info['w'] or tile_info['h_padded'] != tile_info['h']:
            padded_tile = Image.new("RGB", (tile_info['w_padded'], tile_info['h_padded']))
            padded_tile.paste(tile, (0, 0))
            tile = padded_tile

        # Refine the tile with FLUX
        try:
            gen_tile = pipe(
                prompt=prompt,
                image=tile,
                strength=strength,
                num_inference_steps=steps,
                guidance_scale=guidance,
                height=tile_info['h_padded'],
                width=tile_info['w_padded'],
                generator=generator,
            ).images[0]

            # Crop back to the original tile size if padded
            if tile_info['w_padded'] != tile_info['w'] or tile_info['h_padded'] != tile_info['h']:
                gen_tile = gen_tile.crop((0, 0, tile_info['w'], tile_info['h']))

            # Blend overlapping edges with a gradient mask
            if overlap > 0 and (tile_info['edge_x'] or tile_info['edge_y']):
                mask = create_blend_mask(
                    tile_info['w'],
                    tile_info['h'],
                    overlap,
                    tile_info['edge_x'],
                    tile_info['edge_y'],
                )
                # Composite with blending
                output.paste(gen_tile, (tile_info['x'], tile_info['y']), mask)
            else:
                # Direct paste for the first tile or when there is no overlap
                output.paste(gen_tile, (tile_info['x'], tile_info['y']))
        except Exception as e:
            print(f"Error processing tile: {e}")
            # Fallback: paste the original (unrefined) tile
            output.paste(tile, (tile_info['x'], tile_info['y']))

    return output
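
# Example: a 2048x2048 control image with tile_size=1024 and overlap=64 gives
# tile origins at x, y in {0, 960, 1920}, i.e. a 3x3 grid of 9 tiles. The last
# row and column of tiles are only 128 px along that axis, which is already a
# multiple of 16, so no extra padding is needed for them.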


@spaces.GPU  # Required on ZeroGPU Spaces so the function is allocated a GPU
def enhance_image(
    image_input,
    image_url,
    seed,
    randomize_seed,
    num_inference_steps,
    upscale_factor,
    denoising_strength,
    use_generated_caption,
    custom_prompt,
    progress=gr.Progress(track_tqdm=True),
):
    """Main enhancement function."""
    try:
        # Move the models to the GPU and convert them to appropriate dtypes
        pipe.to("cuda")
        pipe.to(torch.bfloat16)
        florence_model.to("cuda")
        florence_model.to(torch.float16)

        # Handle the image input
        if image_input is not None:
            input_image = image_input
        elif image_url:
            input_image = load_image_from_url(image_url)
        else:
            raise gr.Error("Please provide an image (upload or URL)")

        if randomize_seed:
            seed = random.randint(0, MAX_SEED)

        true_input_image = input_image

        # Process the input image (enforces the pixel budget)
        input_image, w_original, h_original, was_resized = process_input(
            input_image, upscale_factor
        )

        # Generate a caption if requested
        if use_generated_caption:
            gr.Info("Generating image caption...")
            generated_caption = generate_caption(input_image)
            prompt = generated_caption
            print(f"Generated caption: {prompt}")
        else:
            prompt = custom_prompt if custom_prompt.strip() else ""

        generator = torch.Generator(device="cuda").manual_seed(seed)

        gr.Info("Upscaling image...")

        # Initial upscale: ESRGAN for 4x if available, otherwise LANCZOS
        if USE_ESRGAN and upscale_factor == 4:
            if torch.cuda.is_available():
                esrgan_model.to("cuda")
            control_image = esrgan_upscale(input_image, upscale_factor)
            if torch.cuda.is_available():
                esrgan_model.to("cpu")
        else:
            w, h = input_image.size
            control_image = input_image.resize(
                (w * upscale_factor, h * upscale_factor),
                resample=Image.LANCZOS,
            )

        # Tiled FLUX img2img refinement
        image = tiled_flux_img2img(
            pipe,
            prompt,
            control_image,
            denoising_strength,
            num_inference_steps,
            1.0,  # guidance_scale fixed to 1.0
            generator,
            tile_size=1024,
            overlap=64,
        )

        if was_resized:
            gr.Info(f"Resizing output to target size: {w_original * upscale_factor}x{h_original * upscale_factor}")
            image = image.resize(
                (w_original * upscale_factor, h_original * upscale_factor),
                resample=Image.LANCZOS,
            )

        # Resize the input image to match the output size for slider alignment
        resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)

        # Move the models back to the CPU to release the GPU
        pipe.to("cpu")
        florence_model.to("cpu")
        torch.cuda.empty_cache()

        return [resized_input, image]

    except Exception as e:
        # Make sure the models are moved back to the CPU even on error
        pipe.to("cpu")
        florence_model.to("cpu")
        torch.cuda.empty_cache()
        raise gr.Error(f"Enhancement failed: {str(e)}")


# Create the Gradio interface
with gr.Blocks(css=css, title="AI Image Upscaler - Florence-2 + FLUX") as demo:
    gr.HTML(f"""
        <div class="main-header">
            <h1>AI Image Upscaler</h1>
            <p>Upload an image or provide a URL to upscale it using Florence-2 captioning and FLUX upscaling</p>
            <p>Currently running on <strong>{power_device}</strong></p>
        </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<h3>Input</h3>")

            with gr.Tabs():
                with gr.TabItem("Upload Image"):
                    input_image = gr.Image(
                        label="Upload Image",
                        type="pil",
                        height=200,
                    )
                with gr.TabItem("Image URL"):
                    image_url = gr.Textbox(
                        label="Image URL",
                        placeholder="https://example.com/image.jpg",
                        value="",
                    )

            gr.HTML("<h3>Caption Settings</h3>")
            use_generated_caption = gr.Checkbox(
                label="Use AI-generated caption (Florence-2)",
                value=True,
                info="Generate a detailed caption automatically",
            )
            custom_prompt = gr.Textbox(
                label="Custom Prompt (optional)",
                placeholder="Enter a custom prompt or leave empty to use the generated caption",
                lines=2,
            )

            gr.HTML("<h3>Upscaling Settings</h3>")
            upscale_factor = gr.Slider(
                label="Upscale Factor",
                minimum=1,
                maximum=4,
                step=1,
                value=2,
                info="How much to upscale the image",
            )
            num_inference_steps = gr.Slider(
                label="Number of Inference Steps",
                minimum=8,
                maximum=50,
                step=1,
                value=25,
                info="More steps = better quality but slower",
            )
            denoising_strength = gr.Slider(
                label="Denoising Strength",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.3,
                info="Controls how much the image is transformed",
            )

            with gr.Row():
                randomize_seed = gr.Checkbox(
                    label="Randomize seed",
                    value=True,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                    interactive=True,
                )

            enhance_btn = gr.Button(
                "Upscale Image",
                variant="primary",
                size="lg",
            )

        with gr.Column(scale=2):
            gr.HTML("<h3>Results</h3>")
            result_slider = ImageSlider(
                type="pil",
                interactive=False,
                height=600,
                elem_id="result_slider",
                label=None,
            )

    # Event handler
    enhance_btn.click(
        fn=enhance_image,
        inputs=[
            input_image,
            image_url,
            seed,
            randomize_seed,
            num_inference_steps,
            upscale_factor,
            denoising_strength,
            use_generated_caption,
            custom_prompt,
        ],
        outputs=[result_slider],
    )
gr.HTML(""" | |
<div style="margin-top: 2rem; padding: 1rem; background: #f0f0f0; border-radius: 8px;"> | |
<p><strong>Note:</strong> This upscaler uses the Flux dev model. Users are responsible for obtaining commercial rights if used commercially under their license.</p> | |
</div> | |
""") | |

if __name__ == "__main__":
    demo.queue().launch(share=True, server_name="0.0.0.0", server_port=7860)
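
# Usage note (assumption): running this app locally requires the dependencies it
# imports (gradio, gradio_imageslider, diffusers, transformers, torch, and
# optionally basicsr) plus an HF_TOKEN environment variable with access to the
# gated black-forest-labs/FLUX.1-dev repository, e.g.:
#   HF_TOKEN=hf_xxx python app.py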